In [2]:
import os
from datasets_gta5 import GTA5, CityScapes
import albumentations as A
import torch

CITYSCAPES_PATH = '/home/arda/.cache/kagglehub/datasets/ardaerendoru/gtagta/versions/1/Cityscapes/Cityscapes'


device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

  check_for_updates()


In [3]:
transform = A.Compose([
    A.Resize(256, 512),
    # A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
CITYSCAPES_dataset = CityScapes(CITYSCAPES_PATH, train_val='train', transform=transform)


def load_student_checkpoint(checkpoint_path: str) -> dict:
    """
    Load and process student model checkpoint.
    
    Args:
        checkpoint_path (str): Path to the checkpoint file
        
    Returns:
        dict: Processed state dict containing only student model weights
    """
    # Load checkpoint
    state_dict = torch.load(checkpoint_path)
    
    # Extract and process student weights, excluding batch norm layers
    student_state_dict = {k.replace('student.model.', ''): v 
                         for k, v in state_dict['state_dict'].items() 
                         if k.startswith('student.model.') and not k.startswith('student.batch_norm_layers')}
    
    return student_state_dict

# Load checkpoint and save state dict
checkpoint_path = '/home/arda/dinov2/distillation/checkpoints/resnet50/epoch=19-val_similarity=0.37.ckpt'
student_state_dict = load_student_checkpoint(checkpoint_path)
# torch.save(student_state_dict, '/home/arda/dinov2/distillation/logs/resnet50/distillation/version_7/checkpoints/student_state_dict.pth')
# student_state_dict.keys()


  state_dict = torch.load(checkpoint_path)


In [4]:
from models.resnet_wrapper import ResNetWrapper
encoder = ResNetWrapper(depth=50, out_features=['res5'])
encoder.load_state_dict(student_state_dict)

<All keys matched successfully>

In [5]:

encoder.eval()
encoder.to(device)
encoder.model.eval()
encoder = encoder.model

In [6]:
asd = torch.randn(1, 3, 512, 1024).to(device)
encoder(asd)["res5"].shape

# Freeze all parameters of the encoder
for param in encoder.parameters():
    param.requires_grad = False

In [7]:
# First, let's create a simple decoder network
import numpy as np
import tqdm as tqdm
class SegmentationDecoder(torch.nn.Module):
    def __init__(self, in_channels=2048, num_classes=19):
        super().__init__()
        self.decoder = torch.nn.Sequential(
            # 16x32 -> 32x64
            torch.nn.ConvTranspose2d(in_channels, 1024, kernel_size=4, stride=2, padding=1),
            torch.nn.BatchNorm2d(1024),
            torch.nn.ReLU(),
            
            # 32x64 -> 64x128
            torch.nn.ConvTranspose2d(1024, 512, kernel_size=4, stride=2, padding=1),
            torch.nn.BatchNorm2d(512),
            torch.nn.ReLU(),
            
            # 64x128 -> 128x256
            torch.nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),
            torch.nn.BatchNorm2d(256),
            torch.nn.ReLU(),
            
            # 128x256 -> 256x512
            torch.nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            torch.nn.BatchNorm2d(128),
            torch.nn.ReLU(),
            
            # 256x512 -> 512x1024
            torch.nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            torch.nn.BatchNorm2d(64),
            torch.nn.ReLU(),
            
            # Final 1x1 conv to get to num_classes
            torch.nn.Conv2d(64, num_classes, kernel_size=1)
        )
    def forward(self, x):
        x = self.decoder(x)
        # Ensure exact output size
        if x.shape[-2:] != (256, 512):
            x = torch.nn.functional.interpolate(
                x, size=(256, 512), 
                mode='bilinear', 
                align_corners=False
            )
        return x

# Initialize decoder, optimizer, and loss function
decoder = SegmentationDecoder().to(device)
optimizer = torch.optim.Adam(decoder.parameters(), lr=1e-4)
criterion = torch.nn.CrossEntropyLoss(ignore_index=255)

def fast_hist(a: np.ndarray, b: np.ndarray, n: int) -> np.ndarray:
    k = (b >= 0) & (b < n)
    return np.bincount(n * a[k].astype(int) + b[k], minlength=n ** 2).reshape(n, n)

def per_class_iou(hist: np.ndarray) -> np.ndarray:
    epsilon = 1e-5
    return (np.diag(hist)) / (hist.sum(1) + hist.sum(0) - np.diag(hist) + epsilon)

def train_epoch(encoder, decoder, dataloader, optimizer, criterion, device, num_classes=19):
    decoder.train()
    encoder.eval()  # Keep DINO frozen
    
    total_loss = 0
    hist = np.zeros((num_classes, num_classes))  # Single histogram for entire epoch
    total_pixels = 0
    correct_pixels = 0
    
    for images, labels in tqdm.tqdm(dataloader):
        images = images.to(device)
        labels = labels.to(device)
        
        # Get DINO features
        with torch.no_grad():
            features = encoder(images)["res5"]
        
        # Forward pass through decoder
        outputs = decoder(features)
        
        # Resize outputs to match label size if needed
        if outputs.shape[-2:] != labels.shape[-2:]:
            outputs = torch.nn.functional.interpolate(
                outputs, size=labels.shape[-2:], mode='bilinear', align_corners=False)
        
        # Calculate loss
        loss = criterion(outputs, labels)
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        
        # Calculate metrics
        preds = torch.argmax(torch.softmax(outputs, dim=1), dim=1)
        
        # Pixel Accuracy
        valid_mask = labels != 255  # Ignore index
        total_pixels += valid_mask.sum().item()
        correct_pixels += ((preds == labels) & valid_mask).sum().item()
        
        # IoU
        preds = preds.cpu().numpy()
        target = labels.cpu().numpy()
        hist += fast_hist(preds.flatten(), target.flatten(), num_classes)
    
    # Calculate final metrics
    pixel_acc = correct_pixels / total_pixels
    
    # Per-class accuracy (mean class accuracy)
    class_acc = np.diag(hist) / (hist.sum(1) + np.finfo(np.float32).eps)
    mean_class_acc = np.nanmean(class_acc)
    
    # IoU metrics
    iou = per_class_iou(hist)
    mean_iou = np.nanmean(iou)
    
    metrics = {
        'loss': total_loss / len(dataloader),
        'pixel_acc': pixel_acc,
        'mean_class_acc': mean_class_acc,
        'mean_iou': mean_iou,
        'class_iou': iou,
        'class_acc': class_acc
    }
    
    return metrics

train_loader = torch.utils.data.DataLoader(
    CITYSCAPES_dataset, 
    batch_size=4,
    shuffle=True,
    num_workers=4
)
# Training loop
num_epochs = 10
for epoch in range(num_epochs):
    metrics = train_epoch(encoder, decoder, train_loader, optimizer, criterion, device)
    
    print(f"\nEpoch {epoch+1}/{num_epochs}")
    print(f"Loss: {metrics['loss']:.4f}")
    print(f"Pixel Accuracy: {metrics['pixel_acc']:.4f}")
    print(f"Mean Class Accuracy: {metrics['mean_class_acc']:.4f}")
    print(f"Mean IoU: {metrics['mean_iou']:.4f}")
    
    # Optionally print per-class metrics
    print("\nPer-class metrics:")
    for i in range(19):  # Assuming 19 classes
        print(f"Class {i:2d} - Acc: {metrics['class_acc'][i]:.4f}, IoU: {metrics['class_iou'][i]:.4f}")

100%|██████████| 393/393 [00:38<00:00, 10.11it/s]



Epoch 1/10
Loss: 1.1080
Pixel Accuracy: 0.8075
Mean Class Accuracy: 0.2705
Mean IoU: 0.2137

Per-class metrics:
Class  0 - Acc: 0.9423, IoU: 0.8843
Class  1 - Acc: 0.6248, IoU: 0.4569
Class  2 - Acc: 0.8034, IoU: 0.7337
Class  3 - Acc: 0.0080, IoU: 0.0001
Class  4 - Acc: 0.0108, IoU: 0.0000
Class  5 - Acc: 0.0165, IoU: 0.0005
Class  6 - Acc: 0.0001, IoU: 0.0000
Class  7 - Acc: 0.0096, IoU: 0.0001
Class  8 - Acc: 0.7590, IoU: 0.6499
Class  9 - Acc: 0.0644, IoU: 0.0119
Class 10 - Acc: 0.8298, IoU: 0.6172
Class 11 - Acc: 0.3396, IoU: 0.0786
Class 12 - Acc: 0.0031, IoU: 0.0026
Class 13 - Acc: 0.6892, IoU: 0.6111
Class 14 - Acc: 0.0125, IoU: 0.0021
Class 15 - Acc: 0.0003, IoU: 0.0001
Class 16 - Acc: 0.0136, IoU: 0.0068
Class 17 - Acc: 0.0044, IoU: 0.0035
Class 18 - Acc: 0.0078, IoU: 0.0017


100%|██████████| 393/393 [00:36<00:00, 10.74it/s]



Epoch 2/10
Loss: 0.6376
Pixel Accuracy: 0.8524
Mean Class Accuracy: 0.3375
Mean IoU: 0.2520

Per-class metrics:
Class  0 - Acc: 0.9518, IoU: 0.9150
Class  1 - Acc: 0.6669, IoU: 0.5514
Class  2 - Acc: 0.8315, IoU: 0.7739
Class  3 - Acc: 0.0189, IoU: 0.0000
Class  4 - Acc: 0.3279, IoU: 0.0049
Class  5 - Acc: 0.0000, IoU: 0.0000
Class  6 - Acc: 0.0000, IoU: 0.0000
Class  7 - Acc: 0.0000, IoU: 0.0000
Class  8 - Acc: 0.8054, IoU: 0.7144
Class  9 - Acc: 0.0762, IoU: 0.0002
Class 10 - Acc: 0.8773, IoU: 0.7913
Class 11 - Acc: 0.4674, IoU: 0.3198
Class 12 - Acc: 0.0000, IoU: 0.0000
Class 13 - Acc: 0.7882, IoU: 0.7152
Class 14 - Acc: 0.0000, IoU: 0.0000
Class 15 - Acc: 0.0000, IoU: 0.0000
Class 16 - Acc: 0.1782, IoU: 0.0017
Class 17 - Acc: 0.0000, IoU: 0.0000
Class 18 - Acc: 0.4225, IoU: 0.0007


 24%|██▍       | 96/393 [00:10<00:37,  7.96it/s]

In [41]:
# import tarfile
# from tqdm import tqdm
# data_dir = '/home/arda/data/sam/downloads/sa_000020.tar'

# with tarfile.open(data_dir, 'r') as tar:
#     for member in tqdm(tar.getmembers()):
#         tar.extract(member, path='/home/arda/data/sam/extracted')