In [1]:
import pandas as pd
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image

class ToothDataset(Dataset):
    def __init__(self, csv_file, transform=None, tooth_cols=None):
        self.df = pd.read_csv(csv_file)
        self.transform = transform
        self.tooth_cols = tooth_cols or [col for col in self.df.columns if col.startswith('tooth_')]
    
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        image = Image.open(row['image_path']).convert('RGB')
        label = row[self.tooth_cols].values.astype('float32')
        if self.transform:
            image = self.transform(image)
        return image, torch.tensor(label)


In [2]:
from torchvision import models
import torch.nn as nn

def get_resnet50_multilabel(num_teeth):
    model = models.resnet50(pretrained=True)
    in_features = model.fc.in_features
    model.fc = nn.Linear(in_features, num_teeth)
    return model


In [3]:
import torch.nn.functional as F

class FocalLoss(nn.Module):
    def __init__(self, alpha=0.25, gamma=2.0, reduction='mean'):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction
    
    def forward(self, logits, targets):
        probas = torch.sigmoid(logits)
        ce_loss = F.binary_cross_entropy_with_logits(logits, targets, reduction='none')
        pt = torch.where(targets == 1, probas, 1 - probas)
        focal_weight = self.alpha * (1 - pt) ** self.gamma
        loss = focal_weight * ce_loss
        return loss.mean() if self.reduction == 'mean' else loss.sum()


In [None]:
from torch.utils.data import DataLoader

# setup
num_teeth = 16  # or 32 for both jaws
transform = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
dataset = ToothDataset(csv_file='tooth_labels.csv', transform=transform)
train_loader = DataLoader(dataset, batch_size=16, shuffle=True)

model = get_resnet50_multilabel(num_teeth)
criterion = FocalLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# training loop
model.train()
epochs = 10  # set number of epochs
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
for epoch in range(epochs):
    for images, targets in train_loader:
        optimizer.zero_grad()
        logits = model(images)
        loss = criterion(logits, targets)
        loss.backward()
        optimizer.step()
        print(f'Epoch {epoch}, Loss: {loss.item()}')



RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same or input should be a MKLDNN tensor and weight is a dense tensor

: 