In [13]:
import torch.nn as nn
import torch.nn.functional as F
import torch
class DistillationLoss:
    """Handles all loss computations for distillation."""
    
    def __init__(self, alpha=1.0, beta=1.0):
        """
        Args:
            alpha: Weight for MSE loss
            beta: Weight for cosine loss
        """
        self.alpha = alpha
        self.beta = beta
        # Option 1: Normalize by sum of coefficients
        self.normalizer = alpha + beta if (alpha + beta) > 0 else 1.0

    def __call__(self, student_features, teacher_features):
        """Compute all losses and return as dictionary."""

        student_norm = F.normalize(student_features, dim=1)
        teacher_norm = F.normalize(teacher_features, dim=1)
        N,C,H,W = student_norm.shape
        # MSE on normalized features
        mse = nn.MSELoss(reduction='sum')
        mse_loss = mse(student_norm, teacher_norm)/N

        cosine_sim = F.cosine_similarity(student_norm, teacher_norm, dim=1)
        cosine_loss = 1 - cosine_sim.mean()


        # Option 1: Normalize by sum of coefficients (recommended)
        total_loss = (self.alpha * mse_loss + self.beta * cosine_loss) / self.normalizer


        return {
            'total': total_loss,
            'mse': mse_loss,
            'cosine': cosine_loss
        }

loss = DistillationLoss()
student_features = torch.randn((32,2048,7,7))
teacher_features = torch.randn((32,2048,7,7))

loss(student_features, teacher_features)


{'total': tensor(49.4986), 'mse': tensor(97.9973), 'cosine': tensor(1.0000)}

In [14]:
class DistillationLoss:
    """Handles all loss computations for distillation."""
    
    def __init__(self, alpha=1.0, beta=1.0):
        """
        Args:
            alpha: Weight for MSE loss
            beta: Weight for cosine loss
        """
        self.alpha = alpha
        self.beta = beta
        # Option 1: Normalize by sum of coefficients
        self.normalizer = alpha + beta if (alpha + beta) > 0 else 1.0

    def __call__(self, student_features, teacher_features):
        """Compute all losses and return as dictionary."""
        mse_loss = F.mse_loss(student_features, teacher_features)
        
        # Cosine similarity loss
        student_norm = F.normalize(student_features, p=2, dim=1)
        teacher_norm = F.normalize(teacher_features, p=2, dim=1)
        cosine_sim = F.cosine_similarity(student_norm, teacher_norm, dim=1)
        cosine_loss = 1 - cosine_sim.mean()


        # Option 1: Normalize by sum of coefficients (recommended)
        total_loss = (self.alpha * mse_loss + self.beta * cosine_loss) / self.normalizer


        return {
            'total': total_loss,
            'mse': mse_loss,
            'cosine': cosine_loss
        }
loss = DistillationLoss()
student_features = torch.randn((32,2048,7,7))
teacher_features = torch.randn((32,2048,7,7))

loss(student_features, teacher_features)

{'total': tensor(1.4989), 'mse': tensor(1.9981), 'cosine': tensor(0.9998)}

In [7]:

transform = A.Compose([
    A.Resize(512, 1024)
    # A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# GTA5_dataset = GTA5(GTA5_path=GTA5_PATH, transform=transform)


def load_student_checkpoint(checkpoint_path: str) -> dict:
    """
    Load and process student model checkpoint.
    
    Args:
        checkpoint_path (str): Path to the checkpoint file
        
    Returns:
        dict: Processed state dict containing only student model weights
    """
    # Load checkpoint
    state_dict = torch.load(checkpoint_path)
    
    # Extract and process student weights
    student_state_dict = {k.replace('student.model.', ''): v 
                         for k, v in state_dict['state_dict'].items() 
                         if k.startswith('student') and not k.startswith('student.feature_matchers')}
    
    return student_state_dict

# Load checkpoint and save state dict
checkpoint_path = '/home/arda/dinov2/distillation/logs/stdc/distillation/version_7/checkpoints/epoch=29-val_similarity=0.36.ckpt'
student_state_dict = load_student_checkpoint(checkpoint_path)
torch.save(student_state_dict, '/home/arda/dinov2/distillation/logs/stdc/distillation/version_7/checkpoints/student_state_dict.pth')
student_state_dict.keys()






dict_keys(['model.features.0.conv.weight', 'model.features.0.bn.weight', 'model.features.0.bn.bias', 'model.features.0.bn.running_mean', 'model.features.0.bn.running_var', 'model.features.0.bn.num_batches_tracked', 'model.features.1.conv.weight', 'model.features.1.bn.weight', 'model.features.1.bn.bias', 'model.features.1.bn.running_mean', 'model.features.1.bn.running_var', 'model.features.1.bn.num_batches_tracked', 'model.features.2.conv_list.0.conv.weight', 'model.features.2.conv_list.0.bn.weight', 'model.features.2.conv_list.0.bn.bias', 'model.features.2.conv_list.0.bn.running_mean', 'model.features.2.conv_list.0.bn.running_var', 'model.features.2.conv_list.0.bn.num_batches_tracked', 'model.features.2.conv_list.1.conv.weight', 'model.features.2.conv_list.1.bn.weight', 'model.features.2.conv_list.1.bn.bias', 'model.features.2.conv_list.1.bn.running_mean', 'model.features.2.conv_list.1.bn.running_var', 'model.features.2.conv_list.1.bn.num_batches_tracked', 'model.features.2.conv_list.2

: 

In [8]:
# # Load pretrained ResNet50 from torchvision
# import torchvision.models as models
# resnet50 = models.resnet50(pretrained=True)
# resnet50.to(device)
# resnet50.state_dict().keys()

odict_keys(['conv1.weight', 'bn1.weight', 'bn1.bias', 'bn1.running_mean', 'bn1.running_var', 'bn1.num_batches_tracked', 'layer1.0.conv1.weight', 'layer1.0.bn1.weight', 'layer1.0.bn1.bias', 'layer1.0.bn1.running_mean', 'layer1.0.bn1.running_var', 'layer1.0.bn1.num_batches_tracked', 'layer1.0.conv2.weight', 'layer1.0.bn2.weight', 'layer1.0.bn2.bias', 'layer1.0.bn2.running_mean', 'layer1.0.bn2.running_var', 'layer1.0.bn2.num_batches_tracked', 'layer1.0.conv3.weight', 'layer1.0.bn3.weight', 'layer1.0.bn3.bias', 'layer1.0.bn3.running_mean', 'layer1.0.bn3.running_var', 'layer1.0.bn3.num_batches_tracked', 'layer1.0.downsample.0.weight', 'layer1.0.downsample.1.weight', 'layer1.0.downsample.1.bias', 'layer1.0.downsample.1.running_mean', 'layer1.0.downsample.1.running_var', 'layer1.0.downsample.1.num_batches_tracked', 'layer1.1.conv1.weight', 'layer1.1.bn1.weight', 'layer1.1.bn1.bias', 'layer1.1.bn1.running_mean', 'layer1.1.bn1.running_var', 'layer1.1.bn1.num_batches_tracked', 'layer1.1.conv2.we

In [10]:
# # Create mapping between torchvision ResNet50 and student model keys
# key_mapping = {
#     # Stem mapping
#     'conv1.weight': 'model.stem.conv1.weight',
#     'bn1.weight': 'model.stem.conv1.norm.weight',
#     'bn1.bias': 'model.stem.conv1.norm.bias',
#     'bn1.running_mean': 'model.stem.conv1.norm.running_mean', 
#     'bn1.running_var': 'model.stem.conv1.norm.running_var',
#     'bn1.num_batches_tracked': 'model.stem.conv1.norm.num_batches_tracked',
# }

# # Add mappings for each layer
# for i in range(1, 5):  # ResNet50 has 4 layers
#     for j in range(3):  # Each layer has 3 bottleneck blocks
#         # Map each component of the bottleneck block
#         for k in range(1, 4):  # Each block has 3 conv layers
#             torchvision_prefix = f'layer{i}.{j}'
#             student_prefix = f'model.res{i+1}.{j}'
            
#             # Conv layers
#             key_mapping[f'{torchvision_prefix}.conv{k}.weight'] = f'{student_prefix}.conv{k}.weight'
            
#             # Batch norm layers
#             key_mapping[f'{torchvision_prefix}.bn{k}.weight'] = f'{student_prefix}.conv{k}.norm.weight'
#             key_mapping[f'{torchvision_prefix}.bn{k}.bias'] = f'{student_prefix}.conv{k}.norm.bias'
#             key_mapping[f'{torchvision_prefix}.bn{k}.running_mean'] = f'{student_prefix}.conv{k}.norm.running_mean'
#             key_mapping[f'{torchvision_prefix}.bn{k}.running_var'] = f'{student_prefix}.conv{k}.norm.running_var'
#             key_mapping[f'{torchvision_prefix}.bn{k}.num_batches_tracked'] = f'{student_prefix}.conv{k}.norm.num_batches_tracked'
            
#         # Map downsample layers if they exist
#         if j == 0:  # Only first block in each layer has downsample
#             key_mapping[f'layer{i}.0.downsample.0.weight'] = f'model.res{i+1}.0.shortcut.weight'
#             key_mapping[f'layer{i}.0.downsample.1.weight'] = f'model.res{i+1}.0.shortcut.norm.weight'
#             key_mapping[f'layer{i}.0.downsample.1.bias'] = f'model.res{i+1}.0.shortcut.norm.bias'
#             key_mapping[f'layer{i}.0.downsample.1.running_mean'] = f'model.res{i+1}.0.shortcut.norm.running_mean'
#             key_mapping[f'layer{i}.0.downsample.1.running_var'] = f'model.res{i+1}.0.shortcut.norm.running_var'
#             key_mapping[f'layer{i}.0.downsample.1.num_batches_tracked'] = f'model.res{i+1}.0.shortcut.norm.num_batches_tracked'

# # Create new state dict with mapped keys
# mapped_state_dict = {}
# for k, v in resnet50.state_dict().items():
#     if k in key_mapping:
#         mapped_state_dict[key_mapping[k]] = v

# # Load the mapped state dict into student model
# student_state_dict.update(mapped_state_dict)


In [4]:
from models.resnet_wrapper import ResNetWrapper
encoder = ResNetWrapper(depth=50, out_features=['res5'])
encoder.load_state_dict(student_state_dict)

<All keys matched successfully>

In [5]:

encoder.eval()
encoder.to(device)
encoder.model.eval()
encoder = encoder.model

In [6]:
asd = torch.randn(1, 3, 512, 1024).to(device)
encoder(asd)["res5"].shape

# Freeze all parameters of the encoder
for param in encoder.parameters():
    param.requires_grad = False

In [7]:
# First, let's create a simple decoder network
import numpy as np
import tqdm as tqdm
class SegmentationDecoder(torch.nn.Module):
    def __init__(self, in_channels=2048, num_classes=19):
        super().__init__()
        self.decoder = torch.nn.Sequential(
            # 16x32 -> 32x64
            torch.nn.ConvTranspose2d(in_channels, 1024, kernel_size=4, stride=2, padding=1),
            torch.nn.BatchNorm2d(1024),
            torch.nn.ReLU(),
            
            # 32x64 -> 64x128
            torch.nn.ConvTranspose2d(1024, 512, kernel_size=4, stride=2, padding=1),
            torch.nn.BatchNorm2d(512),
            torch.nn.ReLU(),
            
            # 64x128 -> 128x256
            torch.nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),
            torch.nn.BatchNorm2d(256),
            torch.nn.ReLU(),
            
            # 128x256 -> 256x512
            torch.nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            torch.nn.BatchNorm2d(128),
            torch.nn.ReLU(),
            
            # 256x512 -> 512x1024
            torch.nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            torch.nn.BatchNorm2d(64),
            torch.nn.ReLU(),
            
            # Final 1x1 conv to get to num_classes
            torch.nn.Conv2d(64, num_classes, kernel_size=1)
        )
    def forward(self, x):
        x = self.decoder(x)
        # Ensure exact output size
        if x.shape[-2:] != (512, 1024):
            x = torch.nn.functional.interpolate(
                x, size=(512, 1024), 
                mode='bilinear', 
                align_corners=False
            )
        return x

# Initialize decoder, optimizer, and loss function
decoder = SegmentationDecoder().to(device)
optimizer = torch.optim.Adam(decoder.parameters(), lr=1e-4)
criterion = torch.nn.CrossEntropyLoss(ignore_index=255)

def fast_hist(a: np.ndarray, b: np.ndarray, n: int) -> np.ndarray:
    k = (b >= 0) & (b < n)
    return np.bincount(n * a[k].astype(int) + b[k], minlength=n ** 2).reshape(n, n)

def per_class_iou(hist: np.ndarray) -> np.ndarray:
    epsilon = 1e-5
    return (np.diag(hist)) / (hist.sum(1) + hist.sum(0) - np.diag(hist) + epsilon)

def train_epoch(encoder, decoder, dataloader, optimizer, criterion, device, num_classes=19):
    decoder.train()
    encoder.eval()  # Keep DINO frozen
    
    total_loss = 0
    hist = np.zeros((num_classes, num_classes))  # Single histogram for entire epoch
    total_pixels = 0
    correct_pixels = 0
    
    for images, labels in tqdm.tqdm(dataloader):
        images = images.to(device)
        labels = labels.to(device)
        
        # Get DINO features
        with torch.no_grad():
            features = encoder(images)["res5"]
        
        # Forward pass through decoder
        outputs = decoder(features)
        
        # Resize outputs to match label size if needed
        if outputs.shape[-2:] != labels.shape[-2:]:
            outputs = torch.nn.functional.interpolate(
                outputs, size=labels.shape[-2:], mode='bilinear', align_corners=False)
        
        # Calculate loss
        loss = criterion(outputs, labels)
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        
        # Calculate metrics
        preds = torch.argmax(torch.softmax(outputs, dim=1), dim=1)
        
        # Pixel Accuracy
        valid_mask = labels != 255  # Ignore index
        total_pixels += valid_mask.sum().item()
        correct_pixels += ((preds == labels) & valid_mask).sum().item()
        
        # IoU
        preds = preds.cpu().numpy()
        target = labels.cpu().numpy()
        hist += fast_hist(preds.flatten(), target.flatten(), num_classes)
    
    # Calculate final metrics
    pixel_acc = correct_pixels / total_pixels
    
    # Per-class accuracy (mean class accuracy)
    class_acc = np.diag(hist) / (hist.sum(1) + np.finfo(np.float32).eps)
    mean_class_acc = np.nanmean(class_acc)
    
    # IoU metrics
    iou = per_class_iou(hist)
    mean_iou = np.nanmean(iou)
    
    metrics = {
        'loss': total_loss / len(dataloader),
        'pixel_acc': pixel_acc,
        'mean_class_acc': mean_class_acc,
        'mean_iou': mean_iou,
        'class_iou': iou,
        'class_acc': class_acc
    }
    
    return metrics

train_loader = torch.utils.data.DataLoader(
    GTA5_dataset, 
    batch_size=4,
    shuffle=True,
    num_workers=4
)
# Training loop
num_epochs = 10
for epoch in range(num_epochs):
    metrics = train_epoch(encoder, decoder, train_loader, optimizer, criterion, device)
    
    print(f"\nEpoch {epoch+1}/{num_epochs}")
    print(f"Loss: {metrics['loss']:.4f}")
    print(f"Pixel Accuracy: {metrics['pixel_acc']:.4f}")
    print(f"Mean Class Accuracy: {metrics['mean_class_acc']:.4f}")
    print(f"Mean IoU: {metrics['mean_iou']:.4f}")
    
    # Optionally print per-class metrics
    print("\nPer-class metrics:")
    for i in range(19):  # Assuming 19 classes
        print(f"Class {i:2d} - Acc: {metrics['class_acc'][i]:.4f}, IoU: {metrics['class_iou'][i]:.4f}")

100%|██████████| 625/625 [04:58<00:00,  2.10it/s]



Epoch 1/10
Loss: 0.8306
Pixel Accuracy: 0.8372
Mean Class Accuracy: 0.3153
Mean IoU: 0.2264

Per-class metrics:
Class  0 - Acc: 0.9398, IoU: 0.9225
Class  1 - Acc: 0.5606, IoU: 0.3499
Class  2 - Acc: 0.6936, IoU: 0.6392
Class  3 - Acc: 0.4527, IoU: 0.1909
Class  4 - Acc: 0.0188, IoU: 0.0001
Class  5 - Acc: 0.0234, IoU: 0.0015
Class  6 - Acc: 0.0036, IoU: 0.0001
Class  7 - Acc: 0.0013, IoU: 0.0000
Class  8 - Acc: 0.6751, IoU: 0.5282
Class  9 - Acc: 0.6168, IoU: 0.3076
Class 10 - Acc: 0.9087, IoU: 0.8485
Class 11 - Acc: 0.0066, IoU: 0.0005
Class 12 - Acc: 0.0000, IoU: 0.0000
Class 13 - Acc: 0.6739, IoU: 0.4691
Class 14 - Acc: 0.3938, IoU: 0.0370
Class 15 - Acc: 0.0041, IoU: 0.0009
Class 16 - Acc: 0.0168, IoU: 0.0059
Class 17 - Acc: 0.0007, IoU: 0.0004
Class 18 - Acc: 0.0001, IoU: 0.0000


100%|██████████| 625/625 [05:03<00:00,  2.06it/s]



Epoch 2/10
Loss: 0.4657
Pixel Accuracy: 0.8774
Mean Class Accuracy: 0.4680
Mean IoU: 0.2936

Per-class metrics:
Class  0 - Acc: 0.9635, IoU: 0.9439
Class  1 - Acc: 0.6833, IoU: 0.5072
Class  2 - Acc: 0.7679, IoU: 0.7086
Class  3 - Acc: 0.5749, IoU: 0.3549
Class  4 - Acc: 0.9668, IoU: 0.0000
Class  5 - Acc: 0.5494, IoU: 0.0057
Class  6 - Acc: 0.0000, IoU: 0.0000
Class  7 - Acc: 0.0000, IoU: 0.0000
Class  8 - Acc: 0.7590, IoU: 0.6195
Class  9 - Acc: 0.7045, IoU: 0.4892
Class 10 - Acc: 0.9244, IoU: 0.8840
Class 11 - Acc: 0.2000, IoU: 0.0000
Class 12 - Acc: 0.0000, IoU: 0.0000
Class 13 - Acc: 0.7732, IoU: 0.6677
Class 14 - Acc: 0.5850, IoU: 0.3976
Class 15 - Acc: 0.0000, IoU: 0.0000
Class 16 - Acc: 0.4403, IoU: 0.0000
Class 17 - Acc: 0.0000, IoU: 0.0000
Class 18 - Acc: 0.0000, IoU: 0.0000


100%|██████████| 625/625 [05:00<00:00,  2.08it/s]



Epoch 3/10
Loss: 0.3906
Pixel Accuracy: 0.8857
Mean Class Accuracy: 0.4597
Mean IoU: 0.3160

Per-class metrics:
Class  0 - Acc: 0.9659, IoU: 0.9470
Class  1 - Acc: 0.7114, IoU: 0.5351
Class  2 - Acc: 0.7877, IoU: 0.7262
Class  3 - Acc: 0.6198, IoU: 0.4046
Class  4 - Acc: 0.5523, IoU: 0.0540
Class  5 - Acc: 0.4975, IoU: 0.1100
Class  6 - Acc: 0.0000, IoU: 0.0000
Class  7 - Acc: 0.0000, IoU: 0.0000
Class  8 - Acc: 0.7734, IoU: 0.6386
Class  9 - Acc: 0.7307, IoU: 0.5263
Class 10 - Acc: 0.9300, IoU: 0.8905
Class 11 - Acc: 0.0000, IoU: 0.0000
Class 12 - Acc: 0.0000, IoU: 0.0000
Class 13 - Acc: 0.8035, IoU: 0.7005
Class 14 - Acc: 0.6140, IoU: 0.4585
Class 15 - Acc: 0.0672, IoU: 0.0000
Class 16 - Acc: 0.6818, IoU: 0.0125
Class 17 - Acc: 0.0000, IoU: 0.0000
Class 18 - Acc: 0.0000, IoU: 0.0000


100%|██████████| 625/625 [04:55<00:00,  2.11it/s]



Epoch 4/10
Loss: 0.3518
Pixel Accuracy: 0.8925
Mean Class Accuracy: 0.5109
Mean IoU: 0.3442

Per-class metrics:
Class  0 - Acc: 0.9690, IoU: 0.9512
Class  1 - Acc: 0.7336, IoU: 0.5620
Class  2 - Acc: 0.8054, IoU: 0.7417
Class  3 - Acc: 0.6490, IoU: 0.4280
Class  4 - Acc: 0.5472, IoU: 0.1737
Class  5 - Acc: 0.5167, IoU: 0.1548
Class  6 - Acc: 0.0000, IoU: 0.0000
Class  7 - Acc: 0.0000, IoU: 0.0000
Class  8 - Acc: 0.7852, IoU: 0.6528
Class  9 - Acc: 0.7471, IoU: 0.5519
Class 10 - Acc: 0.9345, IoU: 0.8961
Class 11 - Acc: 0.5065, IoU: 0.0001
Class 12 - Acc: 0.0000, IoU: 0.0000
Class 13 - Acc: 0.8144, IoU: 0.7139
Class 14 - Acc: 0.6393, IoU: 0.4890
Class 15 - Acc: 0.4503, IoU: 0.0673
Class 16 - Acc: 0.6088, IoU: 0.1571
Class 17 - Acc: 0.0000, IoU: 0.0000
Class 18 - Acc: 0.0000, IoU: 0.0000


100%|██████████| 625/625 [04:59<00:00,  2.09it/s]



Epoch 5/10
Loss: 0.3237
Pixel Accuracy: 0.8986
Mean Class Accuracy: 0.5695
Mean IoU: 0.3784

Per-class metrics:
Class  0 - Acc: 0.9710, IoU: 0.9537
Class  1 - Acc: 0.7448, IoU: 0.5791
Class  2 - Acc: 0.8206, IoU: 0.7554
Class  3 - Acc: 0.6778, IoU: 0.4612
Class  4 - Acc: 0.5469, IoU: 0.2201
Class  5 - Acc: 0.5318, IoU: 0.1763
Class  6 - Acc: 0.0000, IoU: 0.0000
Class  7 - Acc: 0.0000, IoU: 0.0000
Class  8 - Acc: 0.7899, IoU: 0.6616
Class  9 - Acc: 0.7613, IoU: 0.5733
Class 10 - Acc: 0.9374, IoU: 0.8996
Class 11 - Acc: 0.1944, IoU: 0.0009
Class 12 - Acc: 0.0000, IoU: 0.0000
Class 13 - Acc: 0.8268, IoU: 0.7332
Class 14 - Acc: 0.7216, IoU: 0.5451
Class 15 - Acc: 0.6491, IoU: 0.2787
Class 16 - Acc: 0.6472, IoU: 0.3516
Class 17 - Acc: 1.0000, IoU: 0.0000
Class 18 - Acc: 0.0000, IoU: 0.0000


100%|██████████| 625/625 [04:59<00:00,  2.09it/s]



Epoch 6/10
Loss: 0.3063
Pixel Accuracy: 0.9027
Mean Class Accuracy: 0.5891
Mean IoU: 0.3990

Per-class metrics:
Class  0 - Acc: 0.9716, IoU: 0.9551
Class  1 - Acc: 0.7551, IoU: 0.5898
Class  2 - Acc: 0.8305, IoU: 0.7660
Class  3 - Acc: 0.6914, IoU: 0.4779
Class  4 - Acc: 0.5610, IoU: 0.2469
Class  5 - Acc: 0.5531, IoU: 0.1968
Class  6 - Acc: 1.0000, IoU: 0.0000
Class  7 - Acc: 0.0000, IoU: 0.0000
Class  8 - Acc: 0.7947, IoU: 0.6687
Class  9 - Acc: 0.7708, IoU: 0.5850
Class 10 - Acc: 0.9400, IoU: 0.9032
Class 11 - Acc: 0.3188, IoU: 0.0224
Class 12 - Acc: 0.0000, IoU: 0.0000
Class 13 - Acc: 0.8326, IoU: 0.7395
Class 14 - Acc: 0.7593, IoU: 0.5778
Class 15 - Acc: 0.6971, IoU: 0.3741
Class 16 - Acc: 0.7175, IoU: 0.4774
Class 17 - Acc: 0.0000, IoU: 0.0000
Class 18 - Acc: 0.0000, IoU: 0.0000


100%|██████████| 625/625 [05:00<00:00,  2.08it/s]



Epoch 7/10
Loss: 0.2879
Pixel Accuracy: 0.9070
Mean Class Accuracy: 0.6551
Mean IoU: 0.4237

Per-class metrics:
Class  0 - Acc: 0.9733, IoU: 0.9575
Class  1 - Acc: 0.7669, IoU: 0.6053
Class  2 - Acc: 0.8379, IoU: 0.7734
Class  3 - Acc: 0.7093, IoU: 0.4956
Class  4 - Acc: 0.5746, IoU: 0.2726
Class  5 - Acc: 0.5684, IoU: 0.2110
Class  6 - Acc: 0.6714, IoU: 0.0077
Class  7 - Acc: 0.9047, IoU: 0.0168
Class  8 - Acc: 0.8027, IoU: 0.6777
Class  9 - Acc: 0.7799, IoU: 0.6029
Class 10 - Acc: 0.9416, IoU: 0.9055
Class 11 - Acc: 0.4485, IoU: 0.1534
Class 12 - Acc: 0.0000, IoU: 0.0000
Class 13 - Acc: 0.8461, IoU: 0.7616
Class 14 - Acc: 0.7860, IoU: 0.6102
Class 15 - Acc: 0.7789, IoU: 0.4482
Class 16 - Acc: 0.7469, IoU: 0.5513
Class 17 - Acc: 0.3095, IoU: 0.0001
Class 18 - Acc: 0.0000, IoU: 0.0000


100%|██████████| 625/625 [04:57<00:00,  2.10it/s]



Epoch 8/10
Loss: 0.2735
Pixel Accuracy: 0.9108
Mean Class Accuracy: 0.6755
Mean IoU: 0.4583

Per-class metrics:
Class  0 - Acc: 0.9748, IoU: 0.9592
Class  1 - Acc: 0.7771, IoU: 0.6197
Class  2 - Acc: 0.8456, IoU: 0.7819
Class  3 - Acc: 0.7225, IoU: 0.5151
Class  4 - Acc: 0.5906, IoU: 0.2929
Class  5 - Acc: 0.5847, IoU: 0.2260
Class  6 - Acc: 0.5912, IoU: 0.0709
Class  7 - Acc: 0.8288, IoU: 0.2413
Class  8 - Acc: 0.8088, IoU: 0.6859
Class  9 - Acc: 0.7898, IoU: 0.6169
Class 10 - Acc: 0.9441, IoU: 0.9087
Class 11 - Acc: 0.5317, IoU: 0.2718
Class 12 - Acc: 0.0000, IoU: 0.0000
Class 13 - Acc: 0.8533, IoU: 0.7699
Class 14 - Acc: 0.7997, IoU: 0.6310
Class 15 - Acc: 0.7623, IoU: 0.4514
Class 16 - Acc: 0.7732, IoU: 0.5798
Class 17 - Acc: 0.6573, IoU: 0.0859
Class 18 - Acc: 0.0000, IoU: 0.0000


100%|██████████| 625/625 [04:59<00:00,  2.09it/s]



Epoch 9/10
Loss: 0.2584
Pixel Accuracy: 0.9147
Mean Class Accuracy: 0.6825
Mean IoU: 0.4991

Per-class metrics:
Class  0 - Acc: 0.9769, IoU: 0.9619
Class  1 - Acc: 0.7865, IoU: 0.6363
Class  2 - Acc: 0.8521, IoU: 0.7884
Class  3 - Acc: 0.7363, IoU: 0.5310
Class  4 - Acc: 0.6043, IoU: 0.3181
Class  5 - Acc: 0.5987, IoU: 0.2400
Class  6 - Acc: 0.5599, IoU: 0.1105
Class  7 - Acc: 0.7606, IoU: 0.3243
Class  8 - Acc: 0.8134, IoU: 0.6918
Class  9 - Acc: 0.7952, IoU: 0.6286
Class 10 - Acc: 0.9457, IoU: 0.9105
Class 11 - Acc: 0.5703, IoU: 0.3089
Class 12 - Acc: 0.0000, IoU: 0.0000
Class 13 - Acc: 0.8619, IoU: 0.7828
Class 14 - Acc: 0.8351, IoU: 0.6770
Class 15 - Acc: 0.8378, IoU: 0.5885
Class 16 - Acc: 0.8224, IoU: 0.6386
Class 17 - Acc: 0.6096, IoU: 0.3460
Class 18 - Acc: 0.0000, IoU: 0.0000


100%|██████████| 625/625 [05:03<00:00,  2.06it/s]


Epoch 10/10
Loss: 0.2473
Pixel Accuracy: 0.9175
Mean Class Accuracy: 0.7309
Mean IoU: 0.5156

Per-class metrics:
Class  0 - Acc: 0.9777, IoU: 0.9631
Class  1 - Acc: 0.7910, IoU: 0.6440
Class  2 - Acc: 0.8601, IoU: 0.7979
Class  3 - Acc: 0.7574, IoU: 0.5576
Class  4 - Acc: 0.6202, IoU: 0.3400
Class  5 - Acc: 0.6131, IoU: 0.2537
Class  6 - Acc: 0.5646, IoU: 0.1300
Class  7 - Acc: 0.7487, IoU: 0.3514
Class  8 - Acc: 0.8163, IoU: 0.6964
Class  9 - Acc: 0.7987, IoU: 0.6347
Class 10 - Acc: 0.9471, IoU: 0.9122
Class 11 - Acc: 0.5891, IoU: 0.3240
Class 12 - Acc: 0.8488, IoU: 0.0045
Class 13 - Acc: 0.8644, IoU: 0.7846
Class 14 - Acc: 0.8366, IoU: 0.6859
Class 15 - Acc: 0.8333, IoU: 0.5910
Class 16 - Acc: 0.8482, IoU: 0.6902
Class 17 - Acc: 0.5716, IoU: 0.4358
Class 18 - Acc: 0.0000, IoU: 0.0000





In [6]:
# Epoch 3/10
# Loss: 1.2027
# Pixel Accuracy: 0.6440
# Mean Class Accuracy: 0.1213
# Mean IoU: 0.0818

# Per-class metrics:
# Class  0 - Acc: 0.6736, IoU: 0.6670
# Class  1 - Acc: 0.1000, IoU: 0.0000
# Class  2 - Acc: 0.3823, IoU: 0.1999
# Class  3 - Acc: 0.0000, IoU: 0.0000
# Class  4 - Acc: 0.0000, IoU: 0.0000
# Class  5 - Acc: 0.0000, IoU: 0.0000
# Class  6 - Acc: 0.0000, IoU: 0.0000
# Class  7 - Acc: 0.0000, IoU: 0.0000
# Class  8 - Acc: 0.4519, IoU: 0.0889
# Class  9 - Acc: 0.0000, IoU: 0.0000
# Class 10 - Acc: 0.6970, IoU: 0.5976
# Class 11 - Acc: 0.0000, IoU: 0.0000
# Class 12 - Acc: 0.0000, IoU: 0.0000
# Class 13 - Acc: 0.0000, IoU: 0.0000
# Class 14 - Acc: 0.0000, IoU: 0.0000
# Class 15 - Acc: 0.0000, IoU: 0.0000
# Class 16 - Acc: 0.0000, IoU: 0.0000
# Class 17 - Acc: 0.0000, IoU: 0.0000
# Class 18 - Acc: 0.0000, IoU: 0.0000

In [7]:
import matplotlib.pyplot as plt
from datasets import get_id_to_color   


img_idx = 310

embeddings = encoder(GTA5_dataset[img_idx][0].unsqueeze(0).to(device), n=1, reshape=True, return_class_token=False, norm=False)[0]
out = decoder(embeddings)
id_to_color = get_id_to_color()

pred = out.argmax(1).cpu().numpy()
pred = pred.reshape(518, 1036)
# Convert class IDs to RGB colors
color_map = np.array([id_to_color.get(i, (0, 0, 0)) for i in range(max(id_to_color.keys()) + 1)])
pred_rgb = color_map[pred]

plt.figure(figsize=(10, 10))
plt.imshow(pred_rgb)

plt.figure(figsize=(10, 10))
labels = GTA5_dataset[img_idx][1].cpu().numpy()
color_map = np.array([id_to_color.get(i, (0, 0, 0)) for i in range(max(id_to_color.keys()) + 1)])
pred_rgb = np.zeros((*labels.shape, 3), dtype=np.uint8)
mask = labels < len(color_map)
pred_rgb[mask] = color_map[labels[mask]]
plt.imshow(pred_rgb)


ImportError: cannot import name 'get_id_to_color' from 'datasets' (/home/arda/miniconda3/envs/dinov2/lib/python3.9/site-packages/datasets/__init__.py)

In [28]:
GTA5_dataset[img_idx][1].shape


torch.Size([518, 1036])