In [2]:
import os
import torch
from PIL import Image
import numpy as np
from torch.utils.data import Dataset
from pycocotools.coco import COCO
# Restrict this process to physical GPU 1; this only takes effect if it runs
# before CUDA is first initialized in the process.
os.environ.update(CUDA_VISIBLE_DEVICES="1")
class CocoDataset(Dataset):
    """
    Dataset class for COCO segmentation built from instance annotations.

    Each image's instance annotations are rasterized into one single-channel
    label mask: every pixel holds the ``category_id`` of an annotation covering
    it, or 0 where no annotation covers it. Where annotations overlap, the
    larger ``category_id`` wins (``np.maximum``).
    """

    def __init__(self, root_dir: str, annotation_file: str, transform=None):
        """
        Initialize the COCO dataset.

        Args:
            root_dir (str): Directory containing the image files referenced by
                each image entry's ``file_name``.
            annotation_file (str): Path to a COCO annotation JSON file.
            transform: Optional albumentations-style callable, invoked as
                ``transform(image=..., mask=...)`` and expected to return a
                dict with numpy ``'image'`` (HWC) and ``'mask'`` (HW) entries.
        """
        self.root_dir = root_dir
        self.transform = transform
        self.coco = COCO(annotation_file)
        # Sorted ids give a deterministic index -> image mapping.
        self.ids = sorted(self.coco.imgs.keys())

    def __getitem__(self, index: int):
        """
        Get an image and its corresponding segmentation mask.

        Args:
            index (int): Index of the data point.

        Returns:
            tuple: ``(image, mask)`` where ``image`` is a float32 tensor of
            shape (3, H, W) and ``mask`` is an int64 tensor of shape (H, W)
            holding category ids. The image is normalized only if
            ``transform`` normalizes it.
        """
        img_id = self.ids[index]
        img_info = self.coco.loadImgs(img_id)[0]
        img_path = os.path.join(self.root_dir, img_info['file_name'])
        # Convert to a numpy HWC array immediately so the tensor conversion
        # below works whether or not a transform is configured.
        with Image.open(img_path) as pil_img:
            img = np.array(pil_img.convert('RGB'))

        # Rasterize all instance annotations into a single label mask.
        anns = self.coco.loadAnns(self.coco.getAnnIds(imgIds=img_id))
        mask = np.zeros((img_info['height'], img_info['width']), dtype=np.int32)
        for ann in anns:
            mask = np.maximum(mask, self.coco.annToMask(ann) * ann['category_id'])

        if self.transform:
            transformed = self.transform(image=img, mask=mask)
            img, mask = transformed['image'], transformed['mask']

        img = torch.from_numpy(img).permute(2, 0, 1).float()
        mask = torch.from_numpy(mask).long()
        # NOTE: background stays label 0 and is trained as a regular class;
        # nothing here emits 255, so a loss with ignore_index=255 ignores no
        # pixels. To ignore background instead, map 0 -> 255 and shift the
        # remaining ids down by 1 (and size the classifier accordingly).
        return img, mask

    def __len__(self) -> int:
        """
        Get the number of images in the dataset.

        Returns:
            int: Length of the dataset
        """
        return len(self.ids)


In [3]:
import albumentations as A

# Root of the local COCO 2017 copy; change this single value to relocate it.
COCO_ROOT = '/home/arda/anyma/datasets/coco'

# Shared preprocessing: resize to 504x504 (504 = 36 * 14, matching the 36x36
# feature map the decoder expects) and normalize with ImageNet mean/std.
transform = A.Compose([
    A.Resize(504, 504),
    A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# For training data (no augmentation yet; swap in a train-only pipeline here)
train_dataset = CocoDataset(
    root_dir=os.path.join(COCO_ROOT, 'train2017'),
    annotation_file=os.path.join(COCO_ROOT, 'annotations', 'instances_train2017.json'),
    transform=transform,
)

# For validation data
val_dataset = CocoDataset(
    root_dir=os.path.join(COCO_ROOT, 'val2017'),
    annotation_file=os.path.join(COCO_ROOT, 'annotations', 'instances_val2017.json'),
    transform=transform,
)

loading annotations into memory...
Done (t=13.64s)
creating index...
index created!
loading annotations into memory...
Done (t=0.37s)
creating index...
index created!


In [4]:
import sys

# Make the local distillation package importable. The membership check keeps
# re-runs of this cell from appending duplicate entries to sys.path.
DISTILLATION_DIR = '/home/arda/dinov2/distillation'
if DISTILLATION_DIR not in sys.path:
    sys.path.append(DISTILLATION_DIR)

from models.dinov2 import DINOv2ViT

device = 'cuda'
encoder = DINOv2ViT().to(device)



Using cache found in /home/arda/.cache/torch/hub/facebookresearch_dinov2_main


In [6]:
from torch.utils.data import DataLoader
from torch import nn
import torch.optim as optim
from tqdm import tqdm
from torch.amp import autocast, GradScaler
import numpy as np

# Create data loaders. Shared settings keep train/val consistent; pinned host
# memory speeds up the .to(device) copies when a GPU is present.
BATCH_SIZE = 32
NUM_WORKERS = 8
PIN_MEMORY = torch.cuda.is_available()
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False,
                        num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)

class SegmentationHead(nn.Module):
    """
    Upsampling decoder mapping an encoder feature map to per-pixel class logits.

    Four stride-2 transposed convolutions each double the spatial size
    (e.g. 36 -> 72 -> 144 -> 288 -> 576). Bilinear resampling then brings the
    map to ``output_size``, and a 1x1 convolution produces ``num_classes``
    logits per pixel.
    """

    def __init__(self, in_channels, num_classes, output_size=(504, 504)):
        """
        Args:
            in_channels (int): Channels of the incoming feature map.
            num_classes (int): Number of output logit channels.
            output_size (tuple[int, int]): Spatial size (H, W) of the output
                logits; defaults to the 504x504 training resolution.
        """
        super().__init__()
        # Layer order (and therefore state_dict keys) is unchanged so existing
        # ``decoder.*`` checkpoints still load.
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(in_channels, 256, kernel_size=4, stride=2, padding=1),  # x2
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),  # x2
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),  # x2
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),  # x2
            nn.ReLU(),
            nn.Upsample(size=tuple(output_size), mode='bilinear', align_corners=True),
            nn.Conv2d(32, num_classes, kernel_size=1),  # 1x1 conv -> class logits
        )

    def forward(self, x):
        """Map features (B, in_channels, h, w) to logits (B, num_classes, *output_size)."""
        return self.decoder(x)

# Initialize models and training components.
# COCO instance category ids go up to 90 (80 categories, with gaps in the id
# range), so 91 logits cover every id 0..90, with 0 = background/unlabeled.
num_classes = 91
# 1536 is presumably the encoder's embedding width -- confirm against DINOv2ViT.
decoder = SegmentationHead(in_channels=1536, num_classes=num_classes).to(device)
# ignore_index=255 only has an effect if masks contain 255; CocoDataset as
# written leaves background at 0, so every pixel contributes to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=255)
# Decoder and encoder are optimized jointly (the encoder is fine-tuned).
optimizer = optim.Adam([*decoder.parameters(), *encoder.parameters()], lr=1e-4)
scaler = GradScaler(device=device)

def fast_hist(a: np.ndarray, b: np.ndarray, n: int) -> np.ndarray:
    """
    Build an n x n confusion matrix from two flat integer label arrays.

    ``hist[i, j]`` counts positions where ``a == i`` and ``b == j``; rows are
    indexed by ``a`` and columns by ``b``. Positions where either value lies
    outside ``[0, n)`` (e.g. an ignore label such as 255) are skipped, so they
    can neither raise in ``bincount`` nor break the final reshape.
    """
    valid = (a >= 0) & (a < n) & (b >= 0) & (b < n)
    flat_index = n * a[valid].astype(np.int64) + b[valid].astype(np.int64)
    return np.bincount(flat_index, minlength=n ** 2).reshape(n, n)

def per_class_iou(hist: np.ndarray) -> np.ndarray:
    """
    Per-class intersection-over-union from a square confusion matrix.

    IoU_c = hist[c, c] / (row_sum_c + col_sum_c - hist[c, c]).

    Classes whose union is 0 (they appear in neither rows nor columns) get NaN
    instead of 0. A following ``np.nanmean`` therefore averages only over
    classes that actually occur, rather than being dragged down by unused
    label ids.
    """
    intersection = np.diag(hist).astype(np.float64)
    union = hist.sum(0) + hist.sum(1) - intersection
    iou = np.full(intersection.shape, np.nan)
    present = union > 0
    iou[present] = intersection[present] / union[present]
    return iou

# /home/arda/resnet_logs/v2/coco.ipynb
def train_one_epoch(encoder, decoder, dataloader, criterion, optimizer, scaler):
    """
    Train encoder and decoder for one pass over ``dataloader``.

    Each batch runs a mixed-precision forward pass, backpropagates the
    ``scaler``-scaled loss, and steps ``optimizer``. Segmentation metrics are
    accumulated along the way. Uses the notebook globals ``device``,
    ``num_classes``, ``fast_hist`` and ``per_class_iou``.

    Args:
        encoder: Backbone whose output dict provides ``'feature_map'``.
        decoder: Head mapping that feature map to per-pixel logits.
        dataloader: Yields ``(images, masks)`` batches.
        criterion: Loss taking ``(logits, masks)``.
        optimizer: Optimizer over the trainable parameters.
        scaler: ``torch.amp.GradScaler`` used for loss scaling.

    Returns:
        dict: ``loss`` (mean per-batch loss), ``pixel_acc``,
        ``mean_class_acc``, ``mean_iou`` and the per-class arrays
        ``class_iou`` and ``class_acc``. ``class_acc`` is per-class recall
        (correct / ground-truth pixels); it is NaN for classes without
        ground-truth pixels so ``np.nanmean`` skips them.
    """
    # The notebook's optimizer covers encoder AND decoder parameters, so both
    # must be in training mode (validation may have switched modes).
    encoder.train()
    decoder.train()
    total_loss = 0.0
    hist = np.zeros((num_classes, num_classes))
    total_pixels = 0
    correct_pixels = 0

    for images, masks in tqdm(dataloader, desc="Training"):
        images, masks = images.to(device), masks.to(device)

        optimizer.zero_grad()

        # Mixed-precision forward pass; GradScaler protects fp16 gradients.
        with autocast(device_type='cuda'):
            features = encoder(images)['feature_map']
            outputs = decoder(features)
            loss = criterion(outputs, masks)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        total_loss += loss.item()

        # Pixel accuracy over pixels not marked with the ignore label 255.
        preds = torch.argmax(outputs, dim=1)
        valid_mask = masks != 255
        total_pixels += valid_mask.sum().item()
        correct_pixels += ((preds == masks) & valid_mask).sum().item()

        # fast_hist(a=preds, b=target): rows = predictions, columns = ground truth.
        hist += fast_hist(preds.cpu().numpy().ravel(), masks.cpu().numpy().ravel(), num_classes)

    pixel_acc = correct_pixels / total_pixels if total_pixels else float('nan')
    # Ground-truth pixel counts are the COLUMN sums, because rows index predictions.
    gt_per_class = hist.sum(0)
    class_acc = np.full(num_classes, np.nan)
    np.divide(np.diag(hist), gt_per_class, out=class_acc, where=gt_per_class > 0)
    mean_class_acc = np.nanmean(class_acc)
    iou = per_class_iou(hist)
    mean_iou = np.nanmean(iou)

    return {
        'loss': total_loss / max(len(dataloader), 1),
        'pixel_acc': pixel_acc,
        'mean_class_acc': mean_class_acc,
        'mean_iou': mean_iou,
        'class_iou': iou,
        'class_acc': class_acc
    }

# /home/arda/resnet_logs/v2/coco.ipynb
# /home/arda/resnet_logs/v2/coco.ipynb
def validate(encoder, decoder, dataloader, criterion):
    """
    Evaluate encoder + decoder on ``dataloader`` without updating weights.

    Uses the same mixed-precision forward pass as training and accumulates the
    same metrics. Uses the notebook globals ``device``, ``num_classes``,
    ``fast_hist`` and ``per_class_iou``.

    Args:
        encoder: Backbone whose output dict provides ``'feature_map'``.
        decoder: Head mapping that feature map to per-pixel logits.
        dataloader: Yields ``(images, masks)`` batches.
        criterion: Loss taking ``(logits, masks)``.

    Returns:
        dict: ``loss`` (mean per-batch loss), ``pixel_acc``,
        ``mean_class_acc``, ``mean_iou`` and the per-class arrays
        ``class_iou`` and ``class_acc``. ``class_acc`` is per-class recall,
        NaN for classes without ground-truth pixels.
    """
    # Evaluate with both modules in eval mode. The encoder's previous mode is
    # restored afterwards so the training loop is unaffected by this call.
    encoder_was_training = encoder.training
    encoder.eval()
    decoder.eval()
    total_loss = 0.0
    hist = np.zeros((num_classes, num_classes))
    total_pixels = 0
    correct_pixels = 0

    try:
        with torch.no_grad():
            for images, masks in tqdm(dataloader, desc="Validation"):
                images, masks = images.to(device), masks.to(device)

                with autocast(device_type='cuda'):
                    features = encoder(images)['feature_map']
                    outputs = decoder(features)
                    loss = criterion(outputs, masks)

                total_loss += loss.item()

                preds = torch.argmax(outputs, dim=1)
                valid_mask = masks != 255
                total_pixels += valid_mask.sum().item()
                correct_pixels += ((preds == masks) & valid_mask).sum().item()

                # fast_hist(a=preds, b=target): rows = predictions, columns = ground truth.
                hist += fast_hist(preds.cpu().numpy().ravel(), masks.cpu().numpy().ravel(), num_classes)
    finally:
        encoder.train(encoder_was_training)

    pixel_acc = correct_pixels / total_pixels if total_pixels else float('nan')
    # Ground-truth pixel counts are the COLUMN sums, because rows index predictions.
    gt_per_class = hist.sum(0)
    class_acc = np.full(num_classes, np.nan)
    np.divide(np.diag(hist), gt_per_class, out=class_acc, where=gt_per_class > 0)
    mean_class_acc = np.nanmean(class_acc)
    iou = per_class_iou(hist)
    mean_iou = np.nanmean(iou)

    return {
        'loss': total_loss / max(len(dataloader), 1),
        'pixel_acc': pixel_acc,
        'mean_class_acc': mean_class_acc,
        'mean_iou': mean_iou,
        'class_iou': iou,
        'class_acc': class_acc
    }

# Training loop
num_epochs = 10
# Start at -inf so the first epoch always produces a checkpoint.
best_val_miou = float('-inf')

for epoch in range(num_epochs):
    print(f"\nEpoch {epoch+1}/{num_epochs}")

    train_metrics = train_one_epoch(encoder, decoder, train_loader, criterion, optimizer, scaler)
    val_metrics = validate(encoder, decoder, val_loader, criterion)

    print(f"Train - Loss: {train_metrics['loss']:.4f}, Pixel Acc: {train_metrics['pixel_acc']:.4f}, "
          f"Mean IoU: {train_metrics['mean_iou']:.4f}")
    print(f"Val   - Loss: {val_metrics['loss']:.4f}, Pixel Acc: {val_metrics['pixel_acc']:.4f}, "
          f"Mean IoU: {val_metrics['mean_iou']:.4f}")

    # Save best model based on validation mIoU
    if val_metrics['mean_iou'] > best_val_miou:
        best_val_miou = val_metrics['mean_iou']
        # The metrics hold numpy arrays/scalars. Storing builtin lists/floats
        # keeps the checkpoint loadable with torch.load(weights_only=True),
        # the default in recent PyTorch releases.
        plain_metrics = {
            key: value.tolist() if isinstance(value, np.ndarray) else float(value)
            for key, value in val_metrics.items()
        }
        torch.save({
            # The encoder is fine-tuned too (its parameters are in the
            # optimizer), so its weights are required to reproduce the model.
            'encoder_state_dict': encoder.state_dict(),
            'decoder_state_dict': decoder.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scaler_state_dict': scaler.state_dict(),
            'epoch': epoch,
            'metrics': plain_metrics,
        }, 'best_segmentation_model.pth')


Epoch 1/10


Training:  53%|█████▎    | 1957/3697 [1:36:47<1:25:21,  2.94s/it]