In [7]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import cv2
import numpy as np
from PIL import Image
import os

class DentalDataset(Dataset):
    def __init__(self, csv_file, img_dir, transform=None):
        """
        Args:
            csv_file (string): Path to the CSV file with annotations
            img_dir (string): Directory with all the images
            transform (callable, optional): Optional transform to be applied on a sample
        """
        self.annotations = pd.read_csv(csv_file)
        self.img_dir = img_dir
        self.transform = transform
        
        # Create class mapping
        self.classes = {
            'Class 0- No endodontic treatment': 0,
            'Class 1- complete endodontic treatment': 1,
            'Class 2- incomplete endodontic treatment': 2,
            'cavity': 3
        }
        
    def __len__(self):
        return len(self.annotations.drop_duplicates('filename'))
        
    def __getitem__(self, idx):
        img_name = self.annotations.iloc[idx]['filename']
        img_path = os.path.join(self.img_dir, img_name)
        image = Image.open(img_path).convert('RGB')
        
        # Get all annotations for this image
        img_annotations = self.annotations[self.annotations['filename'] == img_name]
        
        # Initialize heatmaps and other targets
        num_classes = len(self.classes)
        heatmap = np.zeros((num_classes, 160, 160))  # Downsampled by 4
        wh = np.zeros((500, 2))  # Max 500 objects per image
        reg = np.zeros((500, 2))
        reg_mask = np.zeros((500,))
        ind = np.zeros((500,))
        
        draw_gaussian = lambda x, y, sigma: np.exp(-((x[..., np.newaxis] - y[np.newaxis, ...])**2).sum(axis=2) / (2*sigma**2))
        
        num_objs = 0
        for _, ann in img_annotations.iterrows():
            class_id = self.classes[ann['class']]
            bbox = np.array([ann['xmin'], ann['ymin'], ann['xmax'], ann['ymax']], dtype=np.float32)
            
            # Convert to center point and width/height
            h, w = bbox[3] - bbox[1], bbox[2] - bbox[0]
            center = np.array([(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2])
            
            # Scale to heatmap size
            center = center / 4
            
            center_int = center.astype(np.int32)
            
            # Generate gaussian heatmap
            radius = gaussian_radius((h, w))
            radius = max(0, int(radius))
            diameter = 2 * radius + 1
            gaussian = gaussian2D((diameter, diameter))
            
            x, y = center_int[0], center_int[1]
            height, width = heatmap.shape[1:]
            
            left, right = min(x, radius), min(width - x, radius + 1)
            top, bottom = min(y, radius), min(height - y, radius + 1)
            
            masked_heatmap = heatmap[class_id]
            masked_gaussian = gaussian[radius - top:radius + bottom, radius - left:radius + right]
            
            if min(masked_gaussian.shape) > 0 and min(masked_heatmap[y - top:y + bottom, x - left:x + right].shape) > 0:
                masked_heatmap[y - top:y + bottom, x - left:x + right] = np.maximum(
                    masked_heatmap[y - top:y + bottom, x - left:x + right],
                    masked_gaussian
                )
            
            wh[num_objs] = 1. * w / 4, 1. * h / 4
            reg[num_objs] = center - center_int
            reg_mask[num_objs] = 1
            ind[num_objs] = y * width + x
            
            num_objs = num_objs + 1
            
        if self.transform:
            image = self.transform(image)
            
        ret = {
            'input': image,
            'hm': torch.from_numpy(heatmap),
            'reg_mask': torch.from_numpy(reg_mask),
            'ind': torch.from_numpy(ind),
            'wh': torch.from_numpy(wh),
            'reg': torch.from_numpy(reg)
        }
        
        return ret

class CenterNet(nn.Module):
    def __init__(self, num_classes):
        super(CenterNet, self).__init__()
        
        # Use ResNet50 as backbone (you could also use DLA-34 as in original paper)
        self.backbone = torch.hub.load('pytorch/vision:v0.10.0', 'resnet50', pretrained=True)
        
        # Remove the last few layers
        self.backbone = nn.Sequential(*list(self.backbone.children())[:-2])
        
        # Deconvolution layers
        self.deconv_layers = nn.Sequential(
            nn.ConvTranspose2d(2048, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(256, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(256, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True)
        )
        
        # Heads
        self.hm = nn.Sequential(
            nn.Conv2d(256, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, num_classes, kernel_size=1)
        )
        
        self.wh = nn.Sequential(
            nn.Conv2d(256, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 2, kernel_size=1)
        )
        
        self.reg = nn.Sequential(
            nn.Conv2d(256, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 2, kernel_size=1)
        )
        
    def forward(self, x):
        x = self.backbone(x)
        x = self.deconv_layers(x)
        
        return {
            'hm': torch.sigmoid(self.hm(x)),
            'wh': self.wh(x),
            'reg': self.reg(x)
        }

def focal_loss(pred, gt):
    pos_inds = gt.eq(1).float()
    neg_inds = gt.lt(1).float()
    
    neg_weights = torch.pow(1 - gt, 4)
    
    loss = 0
    
    pos_loss = torch.log(pred) * torch.pow(1 - pred, 2) * pos_inds
    neg_loss = torch.log(1 - pred) * torch.pow(pred, 2) * neg_weights * neg_inds
    
    num_pos = pos_inds.float().sum()
    pos_loss = pos_loss.sum()
    neg_loss = neg_loss.sum()
    
    if num_pos == 0:
        loss = loss - neg_loss
    else:
        loss = loss - (pos_loss + neg_loss) / num_pos
        
    return loss

def reg_l1_loss(pred, target, mask):
    mask = mask.unsqueeze(2).expand_as(pred).float()
    loss = F.l1_loss(pred * mask, target * mask, reduction='sum')
    loss = loss / (mask.sum() + 1e-4)
    return loss

# Training setup
def train(model, train_loader, optimizer, epoch, device):
    model.train()
    
    for batch_idx, batch in enumerate(train_loader):
        for k in batch:
            batch[k] = batch[k].to(device)
            
        outputs = model(batch['input'])
        
        hm_loss = focal_loss(outputs['hm'], batch['hm'])
        wh_loss = reg_l1_loss(outputs['wh'], batch['wh'], batch['reg_mask'])
        reg_loss = reg_l1_loss(outputs['reg'], batch['reg'], batch['reg_mask'])
        
        loss = hm_loss + 0.1 * wh_loss + reg_loss
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        if batch_idx % 10 == 0:
            print(f'Epoch: {epoch}, Batch: {batch_idx}, Loss: {loss.item():.6f}')

# Helper functions for gaussian generation
def gaussian2D(shape, sigma=1):
    m, n = [(ss - 1.) / 2. for ss in shape]
    y, x = np.ogrid[-m:m+1,-n:n+1]
    
    h = np.exp(-(x * x + y * y) / (2 * sigma * sigma))
    h[h < np.finfo(h.dtype).eps * h.max()] = 0
    return h

def gaussian_radius(det_size, min_overlap=0.7):
    height, width = det_size
    
    a1  = 1
    b1  = (height + width)
    c1  = width * height * (1 - min_overlap) / (1 + min_overlap)
    sq1 = np.sqrt(b1 ** 2 - 4 * a1 * c1)
    r1  = (b1 + sq1) / 2
    
    a2  = 4
    b2  = 2 * (height + width)
    c2  = (1 - min_overlap) * width * height
    sq2 = np.sqrt(b2 ** 2 - 4 * a2 * c2)
    r2  = (b2 + sq2) / 2
    
    a3  = 4 * min_overlap
    b3  = -2 * min_overlap * (height + width)
    c3  = (min_overlap - 1) * width * height
    sq3 = np.sqrt(b3 ** 2 - 4 * a3 * c3)
    r3  = (b3 + sq3) / 2
    
    return min(r1, r2, r3)

# Main training setup
def main():
    # Determine device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Hyperparameters
    batch_size = 8
    num_epochs = 100
    learning_rate = 1e-4
    
    # Dataset and DataLoader
    transform = transforms.Compose([
        transforms.Resize((640, 640)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    
    dataset = DentalDataset(
        csv_file="train/_annotations.csv",
        img_dir="train",
        transform=transform
    )
    
    train_loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=2,
        pin_memory=True
    )
    
    # Model
    model = CenterNet(num_classes=4).to(device)
    
    # Optimizer
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    
    # Training loop
    for epoch in range(num_epochs):
        train(model, train_loader, optimizer, epoch, device)
        
        # Save checkpoint
        if (epoch + 1) % 10 == 0:
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
            }, f'checkpoint_epoch_{epoch+1}.pth')

if __name__ == '__main__':
    main()

Using device: cpu


Using cache found in C:\Users\ajani/.cache\torch\hub\pytorch_vision_v0.10.0


RuntimeError: DataLoader worker (pid(s) 40804, 18904) exited unexpectedly