In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch.utils.data import Dataset
from torch_geometric.data import Data, DataLoader
import numpy as np
from torch_geometric.utils import dense_to_sparse
from model import TeethGNN
from torch.utils.tensorboard import SummaryWriter  # For TensorBoard

# Use the current CUDA device when PyTorch reports one is available; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# TensorBoard writer, constructed with no arguments so SummaryWriter picks its default log directory.
writer = SummaryWriter()

In [None]:
class ClinicalCaseDataset(Dataset):
    """Map-style dataset that serves one PyG ``Data`` graph per clinical case.

    Args:
        clinical_data: Indexable collection whose items unpack into an
            ``(input_data, labels)`` pair. Both members are converted to float
            tensors on access.
        edge_index: Graph connectivity attached unchanged to every case.
    """

    def __init__(self, clinical_data, edge_index):
        # The same edge_index object is shared by every graph this dataset returns.
        self.edge_index = edge_index
        self.clinical_data = clinical_data

    def __len__(self):
        """Return the number of stored cases."""
        return len(self.clinical_data)

    def __getitem__(self, idx):
        """Build the graph for case ``idx``.

        Returns:
            A ``Data`` object with ``x`` (float tensor of the inputs), the shared
            ``edge_index`` and ``y`` (float tensor of the labels).
        """
        features, targets = self.clinical_data[idx]
        return Data(
            x=torch.tensor(features, dtype=torch.float),
            edge_index=self.edge_index,
            y=torch.tensor(targets, dtype=torch.float),
        )

# Example dataset with random data (replace with actual clinical data)
SEED = 42  # Fixed seed so the placeholder data, batch shuffling and weight init repeat on Restart & Run All
BATCH_SIZE = 16
TRAIN_FRACTION = 0.8  # Share of cases used for training; the remainder is held out for validation

num_cases = 1000  # Assume 1000 clinical cases
num_teeth = 32  # Assume each case has 32 teeth

rng = np.random.default_rng(SEED)
torch.manual_seed(SEED)

# One (inputs, labels) pair per case, each of shape (num_teeth, 7)
clinical_data = [
    (rng.random((num_teeth, 7)), rng.random((num_teeth, 7)))
    for _ in range(num_cases)
]

# Graph adjacency matrix (remains the same for all cases): tooth i is linked to teeth i-1 and i+1
adj_matrix = np.eye(num_teeth, k=1) + np.eye(num_teeth, k=-1)
# Keep the connectivity on the CPU next to the per-case features. The training loop moves each
# collated batch to `device`, so putting edge_index on the accelerator here would only mix devices
# inside every sample (and would break DataLoader worker processes).
edge_index = dense_to_sparse(torch.tensor(adj_matrix, dtype=torch.long))[0]

# Create datasets
train_size = int(TRAIN_FRACTION * num_cases)
train_dataset = ClinicalCaseDataset(clinical_data[:train_size], edge_index)
val_dataset = ClinicalCaseDataset(clinical_data[train_size:], edge_index)

# Create DataLoader for batching (using torch_geometric DataLoader)
# NOTE(review): PyG 2.x deprecates torch_geometric.data.DataLoader in favour of
# torch_geometric.loader.DataLoader - switch the import once the installed version is confirmed.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
in_channels = 7  # Input: Initial positions (3 for translation + 4 for quaternion)
hidden_channels = 64  # Hidden layer size
out_channels = 7  # Output: Predicted corrected positions (3 for translation + 4 for quaternion)

# The training loop moves every batch to `device`, so the model's weights must live there too;
# otherwise the forward pass fails with a device mismatch as soon as CUDA is available.
# Move the model before building the optimizer so the optimizer tracks the on-device parameters.
model = TeethGNN(in_channels, hidden_channels, out_channels).to(device)

# Optimizer and Loss function
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = nn.MSELoss()


In [None]:
"""Training Loop"""
def _run_epoch(loader, model, criterion, run_device, optimizer=None):
    """Make one pass over ``loader`` and return the mean per-batch loss.

    When ``optimizer`` is given, every batch also takes a gradient step; otherwise
    the pass is forward-only. The caller is responsible for ``model.train()`` /
    ``model.eval()`` and for wrapping evaluation in ``torch.no_grad()``.

    Returns ``float('nan')`` when the loader yields no batches.
    """
    total_loss = 0.0
    num_batches = 0
    for batch in loader:
        batch = batch.to(run_device)  # Move the entire batch to the target device
        if optimizer is not None:
            optimizer.zero_grad()
        out = model(batch.x, batch.edge_index)  # Pass input features and graph structure
        loss = criterion(out, batch.y)
        if optimizer is not None:
            loss.backward()  # Backpropagation
            optimizer.step()  # Update weights
        total_loss += loss.item()
        # Count batches while iterating instead of calling len(loader): this works for
        # loaders without __len__ and lets the empty case be handled explicitly.
        num_batches += 1
    return total_loss / num_batches if num_batches else float('nan')


def train_model(train_loader, val_loader, model, optimizer, criterion, epochs=50,
                target_device=None, tb_writer=None):
    """Train ``model`` and validate it after every epoch.

    Each epoch runs one optimisation pass over ``train_loader`` followed by a
    gradient-free pass over ``val_loader``. The mean per-batch losses are logged
    under ``'Loss/train'`` and ``'Loss/val'`` and printed.

    Args:
        train_loader: Iterable of batches exposing ``.to()``, ``.x``, ``.edge_index`` and ``.y``.
        val_loader: Iterable of validation batches with the same interface.
        model: Callable as ``model(x, edge_index)`` with ``train()`` / ``eval()`` methods.
        optimizer: Optimizer that updates ``model``'s parameters.
        criterion: Loss callable invoked as ``criterion(prediction, target)``.
        epochs: Number of passes over ``train_loader``.
        target_device: Device batches are moved to. Defaults to the notebook-level ``device``.
        tb_writer: Object with ``add_scalar`` and ``flush``. Defaults to the notebook-level ``writer``.

    Returns:
        dict with keys ``'train_loss'`` and ``'val_loss'``, each a list holding one
        mean loss per epoch. An entry is ``nan`` when the matching loader was empty.
    """
    # Fall back to the notebook globals only when no explicit override was supplied.
    run_device = device if target_device is None else target_device
    log_writer = writer if tb_writer is None else tb_writer

    history = {'train_loss': [], 'val_loss': []}

    for epoch in range(epochs):
        model.train()  # Training mode
        avg_train_loss = _run_epoch(train_loader, model, criterion, run_device, optimizer)
        log_writer.add_scalar('Loss/train', avg_train_loss, epoch)

        model.eval()  # Evaluation mode
        with torch.no_grad():
            avg_val_loss = _run_epoch(val_loader, model, criterion, run_device)
        log_writer.add_scalar('Loss/val', avg_val_loss, epoch)

        history['train_loss'].append(avg_train_loss)
        history['val_loss'].append(avg_val_loss)

        print(f'Epoch {epoch+1}/{epochs}, Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}')

    # Scalars are buffered by the writer; flush so the curves are on disk when training returns.
    log_writer.flush()
    return history

# Train the model, reusing the optimizer and loss function created in the model-setup cell
# (re-declaring them here was a copy-paste duplicate of that cell).
train_model(train_loader, val_loader, model, optimizer, criterion)

# TensorBoard events are buffered; flush so the logged curves are on disk once training ends.
writer.flush()

In [None]:
"""validation"""
model.eval()
with torch.no_grad():
    # Run inference on the first validation batch as a stand-in for new cases; swap in a
    # Data object holding real initial teeth positions (x) and their graph structure (edge_index).
    # The batch is moved to wherever the model's weights live so the two always agree.
    model_device = next(model.parameters()).device
    sample_batch = next(iter(val_loader)).to(model_device)
    predicted_positions = model(sample_batch.x, sample_batch.edge_index)
    print(predicted_positions)
