In [11]:
import os
from datasets_gta5 import GTA5, CityScapes
import albumentations as A
import torch

CITYSCAPES_PATH = '/home/arda/.cache/kagglehub/datasets/ardaerendoru/gtagta/versions/1/Cityscapes/Cityscapes'


device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')

In [12]:

transform = A.Compose([
    A.Resize(512, 1024),
    A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
CITYSCAPES_dataset = CityScapes(CITYSCAPES_PATH, train_val='train', transform=transform)


def load_student_checkpoint(checkpoint_path: str) -> dict:
    """
    Load and process student model checkpoint.
    
    Args:
        checkpoint_path (str): Path to the checkpoint file
        
    Returns:
        dict: Processed state dict containing only student model weights
    """
    # Load checkpoint
    state_dict = torch.load(checkpoint_path)
    
    # Extract and process student weights
    student_state_dict = {k.replace('student.model.', ''): v 
                         for k, v in state_dict['state_dict'].items() 
                         if k.startswith('student') and not k.startswith('student.feature_matchers')}
    
    return student_state_dict

# Load checkpoint and save state dict
# checkpoint_path = '/home/arda/dinov2/distillation/logs/resnet50/distillation/version_7/checkpoints/epoch=24-val_similarity=0.37.ckpt'
# student_state_dict = load_student_checkpoint(checkpoint_path)
# torch.save(student_state_dict, '/home/arda/dinov2/distillation/logs/resnet50/distillation/version_7/checkpoints/student_state_dict.pth')
# student_state_dict.keys()






In [13]:
# from models.resnet_wrapper import ResNetWrapper
import torchvision
encoder = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.IMAGENET1K_V1)
# encoder = torchvision.models.resnet50(pretrained=True)
encoder = torch.nn.Sequential(*list(encoder.children())[:-2])  # Remove pooling and fc layers

In [14]:

encoder.eval()
encoder.to(device)


Sequential(
  (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU(inplace=True)
  (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (4): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)


In [15]:
asd = torch.randn(1, 3, 512, 1024).to(device)
encoder(asd)

# # # Freeze all parameters of the encoder
# for param in encoder.parameters():
#     param.requires_grad = False

tensor([[[[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0212, 0.0141, 0.0000,  ..., 0.2158, 0.1861, 0.0000],
          [0.0000, 0.1595, 0.2351,  ..., 0.1786, 0.0847, 0.0000],
          ...,
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.1378, 0.0811],
          [0.0893, 0.0688, 0.0000,  ..., 0.1983, 0.4144, 0.4044],
          [0.2365, 0.3231, 0.3871,  ..., 0.5128, 0.6075, 0.6301]],

         [[0.0000, 0.0000, 0.0000,  ..., 0.0675, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.1769,  ..., 0.2568, 0.1633, 0.0000],
          [0.1378, 0.2916, 0.4574,  ..., 0.3505, 0.1265, 0.0000],
          ...,
          [0.0000, 0.0000, 0.0673,  ..., 0.3052, 0.2187, 0.0205],
          [0.0000, 0.0000, 0.0347,  ..., 0.2146, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0040]],

         [[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0

In [16]:
# First, let's create a simple decoder network
import numpy as np
from tqdm.notebook import tqdm
class SegmentationDecoder(torch.nn.Module):
    def __init__(self, in_channels=2048, num_classes=19):
        super().__init__()
        self.decoder = torch.nn.Sequential(
            # 16x32 -> 32x64
            torch.nn.ConvTranspose2d(in_channels, 1024, kernel_size=4, stride=2, padding=1),
            torch.nn.BatchNorm2d(1024),
            torch.nn.ReLU(),
            
            # 32x64 -> 64x128
            torch.nn.ConvTranspose2d(1024, 512, kernel_size=4, stride=2, padding=1),
            torch.nn.BatchNorm2d(512),
            torch.nn.ReLU(),
            
            # 64x128 -> 128x256
            torch.nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),
            torch.nn.BatchNorm2d(256),
            torch.nn.ReLU(),
            
            # 128x256 -> 256x512
            torch.nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            torch.nn.BatchNorm2d(128),
            torch.nn.ReLU(),
            
            # 256x512 -> 512x1024
            torch.nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            torch.nn.BatchNorm2d(64),
            torch.nn.ReLU(),
            
            # Final 1x1 conv to get to num_classes
            torch.nn.Conv2d(64, num_classes, kernel_size=1)
        )
    def forward(self, x):
        x = self.decoder(x)
        # Ensure exact output size
        if x.shape[-2:] != (512, 1024):
            x = torch.nn.functional.interpolate(
                x, size=(512, 1024), 
                mode='bilinear', 
                align_corners=False
            )
        return x

# Initialize decoder, optimizer, and loss function
decoder = SegmentationDecoder().to(device)
optimizer = torch.optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), lr=1e-4)
criterion = torch.nn.CrossEntropyLoss(ignore_index=255)

def fast_hist(a: np.ndarray, b: np.ndarray, n: int) -> np.ndarray:
    k = (b >= 0) & (b < n)
    return np.bincount(n * a[k].astype(int) + b[k], minlength=n ** 2).reshape(n, n)

def per_class_iou(hist: np.ndarray) -> np.ndarray:
    epsilon = 1e-5
    return (np.diag(hist)) / (hist.sum(1) + hist.sum(0) - np.diag(hist) + epsilon)

def train_epoch(encoder, decoder, dataloader, optimizer, criterion, device, num_classes=19):
    decoder.train()
    encoder.eval()  # Keep DINO frozen
    
    total_loss = 0
    hist = np.zeros((num_classes, num_classes))  # Single histogram for entire epoch
    total_pixels = 0
    correct_pixels = 0
    
    for images, labels in tqdm(dataloader):
        images = images.to(device)
        labels = labels.to(device)
        
        # Get DINO features
        # with torch.no_grad():
        features = encoder(images)
        
        # Forward pass through decoder
        outputs = decoder(features)
        
        # Resize outputs to match label size if needed
        if outputs.shape[-2:] != labels.shape[-2:]:
            outputs = torch.nn.functional.interpolate(
                outputs, size=labels.shape[-2:], mode='bilinear', align_corners=False)
        
        # Calculate loss
        loss = criterion(outputs, labels)
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        
        # Calculate metrics
        preds = torch.argmax(torch.softmax(outputs, dim=1), dim=1)
        
        # Pixel Accuracy
        valid_mask = labels != 255  # Ignore index
        total_pixels += valid_mask.sum().item()
        correct_pixels += ((preds == labels) & valid_mask).sum().item()
        
        # IoU
        preds = preds.cpu().numpy()
        target = labels.cpu().numpy()
        hist += fast_hist(preds.flatten(), target.flatten(), num_classes)
    
    # Calculate final metrics
    pixel_acc = correct_pixels / total_pixels
    
    # Per-class accuracy (mean class accuracy)
    class_acc = np.diag(hist) / (hist.sum(1) + np.finfo(np.float32).eps)
    mean_class_acc = np.nanmean(class_acc)
    
    # IoU metrics
    iou = per_class_iou(hist)
    mean_iou = np.nanmean(iou)
    
    metrics = {
        'loss': total_loss / len(dataloader),
        'pixel_acc': pixel_acc,
        'mean_class_acc': mean_class_acc,
        'mean_iou': mean_iou,
        'class_iou': iou,
        'class_acc': class_acc
    }
    
    return metrics

train_loader = torch.utils.data.DataLoader(
    CITYSCAPES_dataset, 
    batch_size=4,
    shuffle=True,
    num_workers=4
)
# Training loop
num_epochs = 10
for epoch in range(num_epochs):
    metrics = train_epoch(encoder, decoder, train_loader, optimizer, criterion, device)
    
    print(f"\nEpoch {epoch+1}/{num_epochs}")
    print(f"Loss: {metrics['loss']:.4f}")
    print(f"Pixel Accuracy: {metrics['pixel_acc']:.4f}")
    print(f"Mean Class Accuracy: {metrics['mean_class_acc']:.4f}")
    print(f"Mean IoU: {metrics['mean_iou']:.4f}")
    
    # Optionally print per-class metrics
    print("\nPer-class metrics:")
    for i in range(19):  # Assuming 19 classes
        print(f"Class {i:2d} - Acc: {metrics['class_acc'][i]:.4f}, IoU: {metrics['class_iou'][i]:.4f}")

  0%|          | 0/393 [00:00<?, ?it/s]


Epoch 1/10
Loss: 1.0925
Pixel Accuracy: 0.8250
Mean Class Accuracy: 0.2539
Mean IoU: 0.2155

Per-class metrics:
Class  0 - Acc: 0.9333, IoU: 0.8834
Class  1 - Acc: 0.6320, IoU: 0.4549
Class  2 - Acc: 0.8220, IoU: 0.7669
Class  3 - Acc: 0.0146, IoU: 0.0001
Class  4 - Acc: 0.0000, IoU: 0.0000
Class  5 - Acc: 0.0178, IoU: 0.0001
Class  6 - Acc: 0.0005, IoU: 0.0001
Class  7 - Acc: 0.0055, IoU: 0.0004
Class  8 - Acc: 0.8338, IoU: 0.7382
Class  9 - Acc: 0.0074, IoU: 0.0005
Class 10 - Acc: 0.8233, IoU: 0.6128
Class 11 - Acc: 0.0176, IoU: 0.0004
Class 12 - Acc: 0.0108, IoU: 0.0057
Class 13 - Acc: 0.6775, IoU: 0.6276
Class 14 - Acc: 0.0025, IoU: 0.0020
Class 15 - Acc: 0.0028, IoU: 0.0000
Class 16 - Acc: 0.0005, IoU: 0.0000
Class 17 - Acc: 0.0027, IoU: 0.0020
Class 18 - Acc: 0.0199, IoU: 0.0003


  0%|          | 0/393 [00:00<?, ?it/s]


Epoch 2/10
Loss: 0.5447
Pixel Accuracy: 0.8854
Mean Class Accuracy: 0.3830
Mean IoU: 0.2704

Per-class metrics:
Class  0 - Acc: 0.9699, IoU: 0.9448
Class  1 - Acc: 0.7202, IoU: 0.6388
Class  2 - Acc: 0.8488, IoU: 0.8156
Class  3 - Acc: 0.2197, IoU: 0.0023
Class  4 - Acc: 0.0000, IoU: 0.0000
Class  5 - Acc: 0.2267, IoU: 0.0019
Class  6 - Acc: 0.0000, IoU: 0.0000
Class  7 - Acc: 0.0850, IoU: 0.0052
Class  8 - Acc: 0.8706, IoU: 0.8195
Class  9 - Acc: 0.0222, IoU: 0.0003
Class 10 - Acc: 0.8605, IoU: 0.8129
Class 11 - Acc: 0.7324, IoU: 0.2778
Class 12 - Acc: 0.0185, IoU: 0.0031
Class 13 - Acc: 0.8484, IoU: 0.8001
Class 14 - Acc: 0.0045, IoU: 0.0000
Class 15 - Acc: 0.0000, IoU: 0.0000
Class 16 - Acc: 0.0000, IoU: 0.0000
Class 17 - Acc: 0.0803, IoU: 0.0157
Class 18 - Acc: 0.7692, IoU: 0.0000


  0%|          | 0/393 [00:00<?, ?it/s]


Epoch 3/10
Loss: 0.4005
Pixel Accuracy: 0.9020
Mean Class Accuracy: 0.5113
Mean IoU: 0.3197

Per-class metrics:
Class  0 - Acc: 0.9760, IoU: 0.9553
Class  1 - Acc: 0.8015, IoU: 0.7176
Class  2 - Acc: 0.8747, IoU: 0.8412
Class  3 - Acc: 0.2969, IoU: 0.0194
Class  4 - Acc: 0.4698, IoU: 0.0206
Class  5 - Acc: 0.4092, IoU: 0.1434
Class  6 - Acc: 0.4286, IoU: 0.0000
Class  7 - Acc: 0.7664, IoU: 0.0786
Class  8 - Acc: 0.8761, IoU: 0.8322
Class  9 - Acc: 0.6666, IoU: 0.1855
Class 10 - Acc: 0.9343, IoU: 0.8853
Class 11 - Acc: 0.7321, IoU: 0.5781
Class 12 - Acc: 0.0232, IoU: 0.0002
Class 13 - Acc: 0.8522, IoU: 0.8151
Class 14 - Acc: 0.0259, IoU: 0.0000
Class 15 - Acc: 0.0000, IoU: 0.0000
Class 16 - Acc: 0.0000, IoU: 0.0000
Class 17 - Acc: 0.1122, IoU: 0.0003
Class 18 - Acc: 0.4692, IoU: 0.0016


  0%|          | 0/393 [00:00<?, ?it/s]


Epoch 4/10
Loss: 0.3208
Pixel Accuracy: 0.9169
Mean Class Accuracy: 0.5693
Mean IoU: 0.4014

Per-class metrics:
Class  0 - Acc: 0.9787, IoU: 0.9607
Class  1 - Acc: 0.8477, IoU: 0.7570
Class  2 - Acc: 0.9035, IoU: 0.8685
Class  3 - Acc: 0.4271, IoU: 0.1017
Class  4 - Acc: 0.4895, IoU: 0.2524
Class  5 - Acc: 0.5942, IoU: 0.2640
Class  6 - Acc: 0.8340, IoU: 0.0101
Class  7 - Acc: 0.7153, IoU: 0.4072
Class  8 - Acc: 0.9112, IoU: 0.8608
Class  9 - Acc: 0.7011, IoU: 0.4697
Class 10 - Acc: 0.9504, IoU: 0.9085
Class 11 - Acc: 0.7403, IoU: 0.6165
Class 12 - Acc: 0.0811, IoU: 0.0000
Class 13 - Acc: 0.8749, IoU: 0.8376
Class 14 - Acc: 0.1875, IoU: 0.0001
Class 15 - Acc: 0.0000, IoU: 0.0000
Class 16 - Acc: 0.0000, IoU: 0.0000
Class 17 - Acc: 0.0000, IoU: 0.0000
Class 18 - Acc: 0.5805, IoU: 0.3111


  0%|          | 0/393 [00:00<?, ?it/s]


Epoch 5/10
Loss: 0.2748
Pixel Accuracy: 0.9247
Mean Class Accuracy: 0.6589
Mean IoU: 0.4457

Per-class metrics:
Class  0 - Acc: 0.9811, IoU: 0.9641
Class  1 - Acc: 0.8695, IoU: 0.7835
Class  2 - Acc: 0.9158, IoU: 0.8784
Class  3 - Acc: 0.5393, IoU: 0.2383
Class  4 - Acc: 0.5712, IoU: 0.3390
Class  5 - Acc: 0.6505, IoU: 0.3242
Class  6 - Acc: 0.7641, IoU: 0.1565
Class  7 - Acc: 0.7253, IoU: 0.4891
Class  8 - Acc: 0.9212, IoU: 0.8726
Class  9 - Acc: 0.7496, IoU: 0.5518
Class 10 - Acc: 0.9485, IoU: 0.9064
Class 11 - Acc: 0.7468, IoU: 0.6294
Class 12 - Acc: 0.0000, IoU: 0.0000
Class 13 - Acc: 0.8918, IoU: 0.8545
Class 14 - Acc: 0.3498, IoU: 0.0050
Class 15 - Acc: 0.8054, IoU: 0.0021
Class 16 - Acc: 0.4835, IoU: 0.0028
Class 17 - Acc: 0.0000, IoU: 0.0000
Class 18 - Acc: 0.6054, IoU: 0.4708


  0%|          | 0/393 [00:00<?, ?it/s]


Epoch 6/10
Loss: 0.2403
Pixel Accuracy: 0.9315
Mean Class Accuracy: 0.6507
Mean IoU: 0.4890

Per-class metrics:
Class  0 - Acc: 0.9820, IoU: 0.9664
Class  1 - Acc: 0.8806, IoU: 0.7943
Class  2 - Acc: 0.9276, IoU: 0.8914
Class  3 - Acc: 0.6490, IoU: 0.3822
Class  4 - Acc: 0.6604, IoU: 0.4315
Class  5 - Acc: 0.6780, IoU: 0.3619
Class  6 - Acc: 0.6902, IoU: 0.3300
Class  7 - Acc: 0.7944, IoU: 0.5492
Class  8 - Acc: 0.9273, IoU: 0.8801
Class  9 - Acc: 0.7591, IoU: 0.5696
Class 10 - Acc: 0.9547, IoU: 0.9170
Class 11 - Acc: 0.7720, IoU: 0.6628
Class 12 - Acc: 0.0000, IoU: 0.0000
Class 13 - Acc: 0.9056, IoU: 0.8690
Class 14 - Acc: 0.4647, IoU: 0.1399
Class 15 - Acc: 0.1896, IoU: 0.0018
Class 16 - Acc: 0.4095, IoU: 0.0520
Class 17 - Acc: 0.1250, IoU: 0.0000
Class 18 - Acc: 0.5942, IoU: 0.4916


  0%|          | 0/393 [00:00<?, ?it/s]


Epoch 7/10
Loss: 0.2089
Pixel Accuracy: 0.9392
Mean Class Accuracy: 0.7336
Mean IoU: 0.5295

Per-class metrics:
Class  0 - Acc: 0.9859, IoU: 0.9726
Class  1 - Acc: 0.9012, IoU: 0.8298
Class  2 - Acc: 0.9380, IoU: 0.9029
Class  3 - Acc: 0.7254, IoU: 0.4954
Class  4 - Acc: 0.7104, IoU: 0.5000
Class  5 - Acc: 0.6972, IoU: 0.3927
Class  6 - Acc: 0.6896, IoU: 0.3745
Class  7 - Acc: 0.7981, IoU: 0.5661
Class  8 - Acc: 0.9335, IoU: 0.8881
Class  9 - Acc: 0.7796, IoU: 0.6132
Class 10 - Acc: 0.9583, IoU: 0.9234
Class 11 - Acc: 0.7854, IoU: 0.6804
Class 12 - Acc: 0.7092, IoU: 0.0080
Class 13 - Acc: 0.9215, IoU: 0.8820
Class 14 - Acc: 0.4576, IoU: 0.2350
Class 15 - Acc: 0.0227, IoU: 0.0000
Class 16 - Acc: 0.5411, IoU: 0.2726
Class 17 - Acc: 0.7723, IoU: 0.0011
Class 18 - Acc: 0.6120, IoU: 0.5234


  0%|          | 0/393 [00:00<?, ?it/s]


Epoch 8/10
Loss: 0.1941
Pixel Accuracy: 0.9417
Mean Class Accuracy: 0.7760
Mean IoU: 0.5683

Per-class metrics:
Class  0 - Acc: 0.9853, IoU: 0.9716
Class  1 - Acc: 0.9000, IoU: 0.8243
Class  2 - Acc: 0.9422, IoU: 0.9066
Class  3 - Acc: 0.7435, IoU: 0.5203
Class  4 - Acc: 0.7343, IoU: 0.5280
Class  5 - Acc: 0.7138, IoU: 0.4199
Class  6 - Acc: 0.6852, IoU: 0.4073
Class  7 - Acc: 0.8111, IoU: 0.5926
Class  8 - Acc: 0.9346, IoU: 0.8905
Class  9 - Acc: 0.7876, IoU: 0.6253
Class 10 - Acc: 0.9570, IoU: 0.9195
Class 11 - Acc: 0.7925, IoU: 0.6838
Class 12 - Acc: 0.6662, IoU: 0.1942
Class 13 - Acc: 0.9445, IoU: 0.9082
Class 14 - Acc: 0.5465, IoU: 0.3848
Class 15 - Acc: 0.5594, IoU: 0.0082
Class 16 - Acc: 0.5207, IoU: 0.3639
Class 17 - Acc: 0.8717, IoU: 0.1009
Class 18 - Acc: 0.6473, IoU: 0.5480


  0%|          | 0/393 [00:00<?, ?it/s]


Epoch 9/10
Loss: 0.1651
Pixel Accuracy: 0.9495
Mean Class Accuracy: 0.8138
Mean IoU: 0.6284

Per-class metrics:
Class  0 - Acc: 0.9877, IoU: 0.9758
Class  1 - Acc: 0.9136, IoU: 0.8467
Class  2 - Acc: 0.9506, IoU: 0.9179
Class  3 - Acc: 0.7955, IoU: 0.5987
Class  4 - Acc: 0.7807, IoU: 0.6031
Class  5 - Acc: 0.7322, IoU: 0.4535
Class  6 - Acc: 0.7228, IoU: 0.4516
Class  7 - Acc: 0.8253, IoU: 0.6220
Class  8 - Acc: 0.9410, IoU: 0.9005
Class  9 - Acc: 0.8268, IoU: 0.6775
Class 10 - Acc: 0.9613, IoU: 0.9281
Class 11 - Acc: 0.8274, IoU: 0.7240
Class 12 - Acc: 0.6658, IoU: 0.3621
Class 13 - Acc: 0.9544, IoU: 0.9222
Class 14 - Acc: 0.6454, IoU: 0.5188
Class 15 - Acc: 0.8196, IoU: 0.0579
Class 16 - Acc: 0.6389, IoU: 0.5182
Class 17 - Acc: 0.7746, IoU: 0.2719
Class 18 - Acc: 0.6986, IoU: 0.5890


  0%|          | 0/393 [00:00<?, ?it/s]


Epoch 10/10
Loss: 0.1515
Pixel Accuracy: 0.9527
Mean Class Accuracy: 0.8330
Mean IoU: 0.6640

Per-class metrics:
Class  0 - Acc: 0.9883, IoU: 0.9773
Class  1 - Acc: 0.9200, IoU: 0.8551
Class  2 - Acc: 0.9547, IoU: 0.9239
Class  3 - Acc: 0.8204, IoU: 0.6405
Class  4 - Acc: 0.8081, IoU: 0.6445
Class  5 - Acc: 0.7439, IoU: 0.4724
Class  6 - Acc: 0.7384, IoU: 0.4752
Class  7 - Acc: 0.8364, IoU: 0.6432
Class  8 - Acc: 0.9434, IoU: 0.9036
Class  9 - Acc: 0.8194, IoU: 0.6683
Class 10 - Acc: 0.9629, IoU: 0.9305
Class 11 - Acc: 0.8365, IoU: 0.7350
Class 12 - Acc: 0.7009, IoU: 0.4385
Class 13 - Acc: 0.9544, IoU: 0.9223
Class 14 - Acc: 0.7445, IoU: 0.5843
Class 15 - Acc: 0.8245, IoU: 0.0964
Class 16 - Acc: 0.6154, IoU: 0.5346
Class 17 - Acc: 0.8647, IoU: 0.5406
Class 18 - Acc: 0.7510, IoU: 0.6294
