In [1]:
import os
from datasets_gta5 import GTA5
import albumentations as A
import torch

GTA5_PATH = '/home/arda/.cache/kagglehub/datasets/ardaerendoru/gtagta/versions/1/GTA5/GTA5'
GTA5_IMAGES = os.path.join(GTA5_PATH, 'images')
GTA5_LABELS = os.path.join(GTA5_PATH, 'labels')

device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')

In [2]:

transform = A.Compose([
    A.Resize(512, 1024)
])
GTA5_dataset = GTA5(GTA5_path=GTA5_PATH, transform=transform)



In [3]:
from models import CustomResNet
encoder = CustomResNet()
# encoder.load_state_dict(torch.load('student_backbone_checkpoint.pth')['backbone_state_dict'])

encoder.eval()
encoder.to(device)




CustomResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), strid

In [4]:
asd = torch.randn(1, 3, 512, 1024).to(device)
encoder(asd)["feature_map"].shape

# Freeze all parameters of the encoder
for param in encoder.parameters():
    param.requires_grad = True



In [5]:
# First, let's create a simple decoder network
import numpy as np
import tqdm as tqdm
class SegmentationDecoder(torch.nn.Module):
    def __init__(self, in_channels=2048, num_classes=19):
        super().__init__()
        self.decoder = torch.nn.Sequential(
            # 16x32 -> 32x64
            torch.nn.ConvTranspose2d(in_channels, 1024, kernel_size=4, stride=2, padding=1),
            torch.nn.BatchNorm2d(1024),
            torch.nn.ReLU(),
            
            # 32x64 -> 64x128
            torch.nn.ConvTranspose2d(1024, 512, kernel_size=4, stride=2, padding=1),
            torch.nn.BatchNorm2d(512),
            torch.nn.ReLU(),
            
            # 64x128 -> 128x256
            torch.nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),
            torch.nn.BatchNorm2d(256),
            torch.nn.ReLU(),
            
            # 128x256 -> 256x512
            torch.nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            torch.nn.BatchNorm2d(128),
            torch.nn.ReLU(),
            
            # 256x512 -> 512x1024
            torch.nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            torch.nn.BatchNorm2d(64),
            torch.nn.ReLU(),
            
            # Final 1x1 conv to get to num_classes
            torch.nn.Conv2d(64, num_classes, kernel_size=1)
        )
    def forward(self, x):
        x = self.decoder(x)
        # Ensure exact output size
        if x.shape[-2:] != (512, 1024):
            x = torch.nn.functional.interpolate(
                x, size=(512, 1024), 
                mode='bilinear', 
                align_corners=False
            )
        return x

# Initialize decoder, optimizer, and loss function
decoder = SegmentationDecoder().to(device)
optimizer = torch.optim.Adam(decoder.parameters(), lr=1e-4)
criterion = torch.nn.CrossEntropyLoss(ignore_index=255)

def fast_hist(a: np.ndarray, b: np.ndarray, n: int) -> np.ndarray:
    k = (b >= 0) & (b < n)
    return np.bincount(n * a[k].astype(int) + b[k], minlength=n ** 2).reshape(n, n)

def per_class_iou(hist: np.ndarray) -> np.ndarray:
    epsilon = 1e-5
    return (np.diag(hist)) / (hist.sum(1) + hist.sum(0) - np.diag(hist) + epsilon)

def train_epoch(encoder, decoder, dataloader, optimizer, criterion, device, num_classes=19):
    decoder.train()
    encoder.train()  # Keep DINO frozen
    
    total_loss = 0
    hist = np.zeros((num_classes, num_classes))  # Single histogram for entire epoch
    total_pixels = 0
    correct_pixels = 0
    
    for images, labels in tqdm.tqdm(dataloader):
        images = images.to(device)
        labels = labels.to(device)
        
        # Get DINO features
        # with torch.no_grad():
        features = encoder(images)["feature_map"]
        
        # Forward pass through decoder
        outputs = decoder(features)
        
        # Resize outputs to match label size if needed
        if outputs.shape[-2:] != labels.shape[-2:]:
            outputs = torch.nn.functional.interpolate(
                outputs, size=labels.shape[-2:], mode='bilinear', align_corners=False)
        
        # Calculate loss
        loss = criterion(outputs, labels)
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        
        # Calculate metrics
        preds = torch.argmax(torch.softmax(outputs, dim=1), dim=1)
        
        # Pixel Accuracy
        valid_mask = labels != 255  # Ignore index
        total_pixels += valid_mask.sum().item()
        correct_pixels += ((preds == labels) & valid_mask).sum().item()
        
        # IoU
        preds = preds.cpu().numpy()
        target = labels.cpu().numpy()
        hist += fast_hist(preds.flatten(), target.flatten(), num_classes)
    
    # Calculate final metrics
    pixel_acc = correct_pixels / total_pixels
    
    # Per-class accuracy (mean class accuracy)
    class_acc = np.diag(hist) / (hist.sum(1) + np.finfo(np.float32).eps)
    mean_class_acc = np.nanmean(class_acc)
    
    # IoU metrics
    iou = per_class_iou(hist)
    mean_iou = np.nanmean(iou)
    
    metrics = {
        'loss': total_loss / len(dataloader),
        'pixel_acc': pixel_acc,
        'mean_class_acc': mean_class_acc,
        'mean_iou': mean_iou,
        'class_iou': iou,
        'class_acc': class_acc
    }
    
    return metrics

train_loader = torch.utils.data.DataLoader(
    GTA5_dataset, 
    batch_size=4,
    shuffle=True,
    num_workers=4
)
# Training loop
num_epochs = 10
for epoch in range(num_epochs):
    metrics = train_epoch(encoder, decoder, train_loader, optimizer, criterion, device)
    
    print(f"\nEpoch {epoch+1}/{num_epochs}")
    print(f"Loss: {metrics['loss']:.4f}")
    print(f"Pixel Accuracy: {metrics['pixel_acc']:.4f}")
    print(f"Mean Class Accuracy: {metrics['mean_class_acc']:.4f}")
    print(f"Mean IoU: {metrics['mean_iou']:.4f}")
    
    # Optionally print per-class metrics
    print("\nPer-class metrics:")
    for i in range(19):  # Assuming 19 classes
        print(f"Class {i:2d} - Acc: {metrics['class_acc'][i]:.4f}, IoU: {metrics['class_iou'][i]:.4f}")

100%|██████████| 625/625 [05:04<00:00,  2.05it/s]



Epoch 1/10
Loss: 1.0042
Pixel Accuracy: 0.8131
Mean Class Accuracy: 0.3035
Mean IoU: 0.2119

Per-class metrics:
Class  0 - Acc: 0.9479, IoU: 0.9112
Class  1 - Acc: 0.4739, IoU: 0.3257
Class  2 - Acc: 0.6635, IoU: 0.5876
Class  3 - Acc: 0.3823, IoU: 0.0700
Class  4 - Acc: 0.0076, IoU: 0.0012
Class  5 - Acc: 0.0200, IoU: 0.0007
Class  6 - Acc: 0.0012, IoU: 0.0005
Class  7 - Acc: 0.0030, IoU: 0.0016
Class  8 - Acc: 0.6669, IoU: 0.4891
Class  9 - Acc: 0.6008, IoU: 0.2622
Class 10 - Acc: 0.8642, IoU: 0.8031
Class 11 - Acc: 0.0023, IoU: 0.0003
Class 12 - Acc: 0.0005, IoU: 0.0000
Class 13 - Acc: 0.5293, IoU: 0.4233
Class 14 - Acc: 0.5114, IoU: 0.1243
Class 15 - Acc: 0.0106, IoU: 0.0013
Class 16 - Acc: 0.0783, IoU: 0.0213
Class 17 - Acc: 0.0027, IoU: 0.0020
Class 18 - Acc: 0.0001, IoU: 0.0001


100%|██████████| 625/625 [05:02<00:00,  2.06it/s]



Epoch 2/10
Loss: 0.5420
Pixel Accuracy: 0.8578
Mean Class Accuracy: 0.3853
Mean IoU: 0.2714

Per-class metrics:
Class  0 - Acc: 0.9575, IoU: 0.9333
Class  1 - Acc: 0.6296, IoU: 0.4457
Class  2 - Acc: 0.7383, IoU: 0.6627
Class  3 - Acc: 0.5219, IoU: 0.2756
Class  4 - Acc: 0.0285, IoU: 0.0000
Class  5 - Acc: 0.5090, IoU: 0.0061
Class  6 - Acc: 0.0000, IoU: 0.0000
Class  7 - Acc: 0.0207, IoU: 0.0028
Class  8 - Acc: 0.7244, IoU: 0.5757
Class  9 - Acc: 0.6691, IoU: 0.4452
Class 10 - Acc: 0.8985, IoU: 0.8491
Class 11 - Acc: 0.0000, IoU: 0.0000
Class 12 - Acc: 0.0000, IoU: 0.0000
Class 13 - Acc: 0.7140, IoU: 0.5871
Class 14 - Acc: 0.5454, IoU: 0.3653
Class 15 - Acc: 0.0000, IoU: 0.0000
Class 16 - Acc: 0.2911, IoU: 0.0080
Class 17 - Acc: 0.0736, IoU: 0.0001
Class 18 - Acc: 0.0000, IoU: 0.0000


100%|██████████| 625/625 [05:09<00:00,  2.02it/s]



Epoch 3/10
Loss: 0.4474
Pixel Accuracy: 0.8694
Mean Class Accuracy: 0.4404
Mean IoU: 0.2925

Per-class metrics:
Class  0 - Acc: 0.9605, IoU: 0.9393
Class  1 - Acc: 0.6845, IoU: 0.4889
Class  2 - Acc: 0.7617, IoU: 0.6874
Class  3 - Acc: 0.5666, IoU: 0.3313
Class  4 - Acc: 0.5337, IoU: 0.0441
Class  5 - Acc: 0.4405, IoU: 0.0674
Class  6 - Acc: 0.0000, IoU: 0.0000
Class  7 - Acc: 0.0418, IoU: 0.0063
Class  8 - Acc: 0.7413, IoU: 0.5980
Class  9 - Acc: 0.7017, IoU: 0.4898
Class 10 - Acc: 0.9103, IoU: 0.8633
Class 11 - Acc: 0.0000, IoU: 0.0000
Class 12 - Acc: 0.0000, IoU: 0.0000
Class 13 - Acc: 0.7591, IoU: 0.6366
Class 14 - Acc: 0.5490, IoU: 0.3969
Class 15 - Acc: 0.0137, IoU: 0.0000
Class 16 - Acc: 0.1809, IoU: 0.0080
Class 17 - Acc: 0.5217, IoU: 0.0001
Class 18 - Acc: 0.0000, IoU: 0.0000


100%|██████████| 625/625 [05:21<00:00,  1.94it/s]



Epoch 4/10
Loss: 0.3998
Pixel Accuracy: 0.8774
Mean Class Accuracy: 0.4623
Mean IoU: 0.3114

Per-class metrics:
Class  0 - Acc: 0.9630, IoU: 0.9424
Class  1 - Acc: 0.6986, IoU: 0.5060
Class  2 - Acc: 0.7858, IoU: 0.7109
Class  3 - Acc: 0.6149, IoU: 0.3757
Class  4 - Acc: 0.5206, IoU: 0.1590
Class  5 - Acc: 0.4578, IoU: 0.1030
Class  6 - Acc: 0.0000, IoU: 0.0000
Class  7 - Acc: 0.0656, IoU: 0.0054
Class  8 - Acc: 0.7545, IoU: 0.6145
Class  9 - Acc: 0.7230, IoU: 0.5221
Class 10 - Acc: 0.9161, IoU: 0.8715
Class 11 - Acc: 0.0000, IoU: 0.0000
Class 12 - Acc: 0.0000, IoU: 0.0000
Class 13 - Acc: 0.7828, IoU: 0.6665
Class 14 - Acc: 0.5560, IoU: 0.4277
Class 15 - Acc: 0.7750, IoU: 0.0002
Class 16 - Acc: 0.1692, IoU: 0.0113
Class 17 - Acc: 0.0000, IoU: 0.0000
Class 18 - Acc: 0.0000, IoU: 0.0000


100%|██████████| 625/625 [05:18<00:00,  1.97it/s]



Epoch 5/10
Loss: 0.3691
Pixel Accuracy: 0.8840
Mean Class Accuracy: 0.5277
Mean IoU: 0.3295

Per-class metrics:
Class  0 - Acc: 0.9657, IoU: 0.9464
Class  1 - Acc: 0.7189, IoU: 0.5318
Class  2 - Acc: 0.7985, IoU: 0.7254
Class  3 - Acc: 0.6484, IoU: 0.4077
Class  4 - Acc: 0.5407, IoU: 0.2147
Class  5 - Acc: 0.4760, IoU: 0.1247
Class  6 - Acc: 0.0000, IoU: 0.0000
Class  7 - Acc: 0.1132, IoU: 0.0053
Class  8 - Acc: 0.7630, IoU: 0.6253
Class  9 - Acc: 0.7370, IoU: 0.5430
Class 10 - Acc: 0.9210, IoU: 0.8773
Class 11 - Acc: 1.0000, IoU: 0.0000
Class 12 - Acc: 0.0000, IoU: 0.0000
Class 13 - Acc: 0.7987, IoU: 0.6872
Class 14 - Acc: 0.6004, IoU: 0.4685
Class 15 - Acc: 0.5672, IoU: 0.0035
Class 16 - Acc: 0.3776, IoU: 0.0990
Class 17 - Acc: 0.0000, IoU: 0.0000
Class 18 - Acc: 0.0000, IoU: 0.0000


100%|██████████| 625/625 [05:21<00:00,  1.95it/s]



Epoch 6/10
Loss: 0.3422
Pixel Accuracy: 0.8916
Mean Class Accuracy: 0.5649
Mean IoU: 0.3674

Per-class metrics:
Class  0 - Acc: 0.9678, IoU: 0.9492
Class  1 - Acc: 0.7343, IoU: 0.5554
Class  2 - Acc: 0.8114, IoU: 0.7392
Class  3 - Acc: 0.6759, IoU: 0.4458
Class  4 - Acc: 0.5666, IoU: 0.2546
Class  5 - Acc: 0.4890, IoU: 0.1428
Class  6 - Acc: 0.0053, IoU: 0.0000
Class  7 - Acc: 0.3257, IoU: 0.0115
Class  8 - Acc: 0.7729, IoU: 0.6382
Class  9 - Acc: 0.7548, IoU: 0.5681
Class 10 - Acc: 0.9247, IoU: 0.8828
Class 11 - Acc: 0.8371, IoU: 0.0013
Class 12 - Acc: 0.0000, IoU: 0.0000
Class 13 - Acc: 0.8124, IoU: 0.7043
Class 14 - Acc: 0.7332, IoU: 0.5560
Class 15 - Acc: 0.7383, IoU: 0.1539
Class 16 - Acc: 0.5842, IoU: 0.3779
Class 17 - Acc: 0.0000, IoU: 0.0000
Class 18 - Acc: 0.0000, IoU: 0.0000


100%|██████████| 625/625 [05:01<00:00,  2.07it/s]



Epoch 7/10
Loss: 0.3161
Pixel Accuracy: 0.8988
Mean Class Accuracy: 0.6234
Mean IoU: 0.4084

Per-class metrics:
Class  0 - Acc: 0.9703, IoU: 0.9527
Class  1 - Acc: 0.7466, IoU: 0.5767
Class  2 - Acc: 0.8263, IoU: 0.7558
Class  3 - Acc: 0.7073, IoU: 0.4823
Class  4 - Acc: 0.6031, IoU: 0.3065
Class  5 - Acc: 0.5035, IoU: 0.1583
Class  6 - Acc: 0.5900, IoU: 0.0231
Class  7 - Acc: 0.7235, IoU: 0.1230
Class  8 - Acc: 0.7802, IoU: 0.6478
Class  9 - Acc: 0.7666, IoU: 0.5857
Class 10 - Acc: 0.9281, IoU: 0.8875
Class 11 - Acc: 0.6169, IoU: 0.0523
Class 12 - Acc: 0.0000, IoU: 0.0000
Class 13 - Acc: 0.8327, IoU: 0.7291
Class 14 - Acc: 0.8073, IoU: 0.6301
Class 15 - Acc: 0.7692, IoU: 0.3528
Class 16 - Acc: 0.6729, IoU: 0.4956
Class 17 - Acc: 0.0000, IoU: 0.0000
Class 18 - Acc: 0.0000, IoU: 0.0000


100%|██████████| 625/625 [05:01<00:00,  2.07it/s]



Epoch 8/10
Loss: 0.2998
Pixel Accuracy: 0.9030
Mean Class Accuracy: 0.6303
Mean IoU: 0.4445

Per-class metrics:
Class  0 - Acc: 0.9711, IoU: 0.9539
Class  1 - Acc: 0.7571, IoU: 0.5845
Class  2 - Acc: 0.8360, IoU: 0.7655
Class  3 - Acc: 0.7286, IoU: 0.5109
Class  4 - Acc: 0.6284, IoU: 0.3354
Class  5 - Acc: 0.5278, IoU: 0.1734
Class  6 - Acc: 0.5810, IoU: 0.0764
Class  7 - Acc: 0.6968, IoU: 0.2643
Class  8 - Acc: 0.7860, IoU: 0.6541
Class  9 - Acc: 0.7736, IoU: 0.6015
Class 10 - Acc: 0.9303, IoU: 0.8904
Class 11 - Acc: 0.5340, IoU: 0.1885
Class 12 - Acc: 0.0000, IoU: 0.0000
Class 13 - Acc: 0.8442, IoU: 0.7447
Class 14 - Acc: 0.8279, IoU: 0.6598
Class 15 - Acc: 0.8079, IoU: 0.4820
Class 16 - Acc: 0.7457, IoU: 0.5606
Class 17 - Acc: 0.0000, IoU: 0.0000
Class 18 - Acc: 0.0000, IoU: 0.0000


100%|██████████| 625/625 [05:06<00:00,  2.04it/s]



Epoch 9/10
Loss: 0.2808
Pixel Accuracy: 0.9083
Mean Class Accuracy: 0.6407
Mean IoU: 0.4691

Per-class metrics:
Class  0 - Acc: 0.9735, IoU: 0.9573
Class  1 - Acc: 0.7711, IoU: 0.6084
Class  2 - Acc: 0.8466, IoU: 0.7781
Class  3 - Acc: 0.7457, IoU: 0.5347
Class  4 - Acc: 0.6598, IoU: 0.3831
Class  5 - Acc: 0.5488, IoU: 0.1873
Class  6 - Acc: 0.5466, IoU: 0.1055
Class  7 - Acc: 0.6838, IoU: 0.3030
Class  8 - Acc: 0.7932, IoU: 0.6649
Class  9 - Acc: 0.7872, IoU: 0.6204
Class 10 - Acc: 0.9330, IoU: 0.8935
Class 11 - Acc: 0.5515, IoU: 0.2404
Class 12 - Acc: 0.0000, IoU: 0.0000
Class 13 - Acc: 0.8499, IoU: 0.7519
Class 14 - Acc: 0.8496, IoU: 0.6852
Class 15 - Acc: 0.8560, IoU: 0.5703
Class 16 - Acc: 0.7770, IoU: 0.6282
Class 17 - Acc: 0.0000, IoU: 0.0000
Class 18 - Acc: 0.0000, IoU: 0.0000


100%|██████████| 625/625 [05:02<00:00,  2.07it/s]


Epoch 10/10
Loss: 0.2632
Pixel Accuracy: 0.9133
Mean Class Accuracy: 0.6571
Mean IoU: 0.4930

Per-class metrics:
Class  0 - Acc: 0.9747, IoU: 0.9590
Class  1 - Acc: 0.7830, IoU: 0.6237
Class  2 - Acc: 0.8578, IoU: 0.7914
Class  3 - Acc: 0.7660, IoU: 0.5668
Class  4 - Acc: 0.6851, IoU: 0.4189
Class  5 - Acc: 0.5689, IoU: 0.2065
Class  6 - Acc: 0.5763, IoU: 0.1325
Class  7 - Acc: 0.7205, IoU: 0.3465
Class  8 - Acc: 0.7998, IoU: 0.6733
Class  9 - Acc: 0.7946, IoU: 0.6361
Class 10 - Acc: 0.9361, IoU: 0.8980
Class 11 - Acc: 0.5662, IoU: 0.2605
Class 12 - Acc: 0.0000, IoU: 0.0000
Class 13 - Acc: 0.8638, IoU: 0.7721
Class 14 - Acc: 0.8702, IoU: 0.7258
Class 15 - Acc: 0.8836, IoU: 0.6627
Class 16 - Acc: 0.8378, IoU: 0.6929
Class 17 - Acc: 0.0000, IoU: 0.0000
Class 18 - Acc: 0.0000, IoU: 0.0000





In [6]:
# Epoch 3/10
# Loss: 1.2027
# Pixel Accuracy: 0.6440
# Mean Class Accuracy: 0.1213
# Mean IoU: 0.0818

# Per-class metrics:
# Class  0 - Acc: 0.6736, IoU: 0.6670
# Class  1 - Acc: 0.1000, IoU: 0.0000
# Class  2 - Acc: 0.3823, IoU: 0.1999
# Class  3 - Acc: 0.0000, IoU: 0.0000
# Class  4 - Acc: 0.0000, IoU: 0.0000
# Class  5 - Acc: 0.0000, IoU: 0.0000
# Class  6 - Acc: 0.0000, IoU: 0.0000
# Class  7 - Acc: 0.0000, IoU: 0.0000
# Class  8 - Acc: 0.4519, IoU: 0.0889
# Class  9 - Acc: 0.0000, IoU: 0.0000
# Class 10 - Acc: 0.6970, IoU: 0.5976
# Class 11 - Acc: 0.0000, IoU: 0.0000
# Class 12 - Acc: 0.0000, IoU: 0.0000
# Class 13 - Acc: 0.0000, IoU: 0.0000
# Class 14 - Acc: 0.0000, IoU: 0.0000
# Class 15 - Acc: 0.0000, IoU: 0.0000
# Class 16 - Acc: 0.0000, IoU: 0.0000
# Class 17 - Acc: 0.0000, IoU: 0.0000
# Class 18 - Acc: 0.0000, IoU: 0.0000

In [7]:
import matplotlib.pyplot as plt
from datasets import get_id_to_color   


img_idx = 310

embeddings = encoder.get_intermediate_layers(GTA5_dataset[img_idx][0].unsqueeze(0).to(device), n=1, reshape=True, return_class_token=False, norm=False)[0]
out = decoder(embeddings)
id_to_color = get_id_to_color()

pred = out.argmax(1).cpu().numpy()
pred = pred.reshape(518, 1036)
# Convert class IDs to RGB colors
color_map = np.array([id_to_color.get(i, (0, 0, 0)) for i in range(max(id_to_color.keys()) + 1)])
pred_rgb = color_map[pred]

plt.figure(figsize=(10, 10))
plt.imshow(pred_rgb)

plt.figure(figsize=(10, 10))
labels = GTA5_dataset[img_idx][1].cpu().numpy()
color_map = np.array([id_to_color.get(i, (0, 0, 0)) for i in range(max(id_to_color.keys()) + 1)])
pred_rgb = np.zeros((*labels.shape, 3), dtype=np.uint8)
mask = labels < len(color_map)
pred_rgb[mask] = color_map[labels[mask]]
plt.imshow(pred_rgb)


ImportError: cannot import name 'get_id_to_color' from 'datasets' (unknown location)

In [28]:
GTA5_dataset[img_idx][1].shape


torch.Size([518, 1036])