In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.data import Data
import numpy as np
from torch_geometric.utils import dense_to_sparse
from models import TeethGNN

In [None]:
"""Creating Graph Data for a Single Case"""
# Seed every RNG used in this notebook so the placeholder data below (and the
# model's weight initialisation in the next cell) survive Restart & Run All.
SEED = 42
rng = np.random.default_rng(SEED)
torch.manual_seed(SEED)

NUM_TEETH = 32  # one graph node per tooth
POSE_DIM = 7    # per-tooth pose: 3 translation + 4 quaternion components

# Node features (initial positions: 3 translation + 4 quaternion for each tooth)
initial_positions = rng.random((NUM_TEETH, POSE_DIM))  # This would be your actual data
x = torch.tensor(initial_positions, dtype=torch.float)

# Adjacency matrix: a simple chain linking tooth i to teeth i-1 and i+1.
# Replace with a proximity- or clinically-motivated adjacency as needed.
adj_matrix = np.eye(NUM_TEETH, k=1) + np.eye(NUM_TEETH, k=-1)

# dense_to_sparse returns (edge_index, edge_attr); only the connectivity is kept.
edge_index = dense_to_sparse(torch.tensor(adj_matrix, dtype=torch.long))[0]

# Target: Corrected positions (ground truth for training, 3 for translation + 4 for quaternion)
corrected_positions = rng.random((NUM_TEETH, POSE_DIM))  # Replace with your actual corrected positions
y = torch.tensor(corrected_positions, dtype=torch.float)

# Create graph data object
data = Data(x=x, edge_index=edge_index, y=y)

# Invariants: one pose row per tooth for inputs and targets, and the chain
# adjacency yields two directed edges per neighbouring pair.
assert data.x.shape == (NUM_TEETH, POSE_DIM)
assert data.y.shape == data.x.shape
assert data.edge_index.shape == (2, 2 * (NUM_TEETH - 1))


In [None]:
"""Training Loop"""

# Model dimensions. Each tooth pose is 3 translation + 4 quaternion values,
# both as input (initial pose) and as output (predicted corrected pose).
in_channels = 7
out_channels = 7
hidden_channels = 64  # width of the hidden layer inside the GNN

# Objective: mean-squared error between predicted and target poses.
# NOTE(review): plain MSE treats quaternions q and -q as different even though
# they encode the same rotation; consider a sign-invariant rotation loss.
criterion = nn.MSELoss()

# Network and its Adam optimiser (learning rate 0.01)
model = TeethGNN(in_channels, hidden_channels, out_channels)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# Training loop
def train(data, model, epochs=200, opt=None, loss_fn=None, log_every=10):
    """Fit ``model`` to a single graph with full-batch gradient steps.

    Args:
        data: graph object exposing ``x``, ``edge_index`` and ``y``
            (e.g. a ``torch_geometric.data.Data``).
        model: module called as ``model(data.x, data.edge_index)``; its output
            is compared against ``data.y``.
        epochs: number of optimisation steps, each a full pass over the graph.
        opt: optimiser over ``model``'s parameters. Defaults to the
            notebook-level ``optimizer`` for backward compatibility; pass it
            explicitly to avoid depending on hidden kernel state.
        loss_fn: callable ``loss_fn(pred, target)`` returning a scalar tensor.
            Defaults to the notebook-level ``criterion``.
        log_every: print the loss every ``log_every`` epochs and on the final
            epoch. ``0``/``None`` disables logging.

    Returns:
        None. ``model`` is updated in place and left in training mode.
    """
    # Resolve defaults lazily: callers that pass both never touch the globals.
    opt = optimizer if opt is None else opt
    loss_fn = criterion if loss_fn is None else loss_fn

    model.train()
    for epoch in range(epochs):
        opt.zero_grad()
        # Forward pass through the model
        out = model(data.x, data.edge_index)
        # Compute loss (difference between predicted and actual positions)
        loss = loss_fn(out, data.y)
        # Backpropagation
        loss.backward()
        opt.step()
        # Always report the last epoch too, so the final loss is visible.
        if log_every and (epoch % log_every == 0 or epoch == epochs - 1):
            print(f'Epoch {epoch}, Loss: {loss.item()}')

# Fit the model on the single example graph for 200 full-batch steps
train(data=data, model=model, epochs=200)


In [None]:
"""Inference on the training graph (sanity check only, not a held-out validation)"""
model.eval()
with torch.no_grad():
    # NOTE(review): this re-uses the training graph `data`; pass an unseen case
    # (its x / edge_index / y) here to obtain a genuine validation score.
    predicted_positions = model(data.x, data.edge_index)
    # Columns 3: hold the quaternion part of each pose. Unless TeethGNN
    # normalises internally (not visible here), raw outputs need not be
    # unit-norm, so normalise before interpreting them as rotations.
    predicted_rotations = F.normalize(predicted_positions[:, 3:], dim=1)
    fit_mse = criterion(predicted_positions, data.y).item()

print(f'MSE on training graph: {fit_mse:.6f}')
print(predicted_positions)
