# Inference on a model loaded from a checkpoint
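Use this notebook to evaluate the fine-tuned model after training. It restores the model from a saved checkpoint, so the model does not have to be retrained from scratch. It then reports the mean Intersection-over-Union (mIoU) on the Cityscapes validation split.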

## 1. Validation data preparation

In [1]:
from torch.utils.data import DataLoader

import albumentations as A
from albumentations.pytorch import ToTensorV2

In [2]:
# Validation preprocessing, applied in this order:
#   1. resize to height=768, width=1024 (presumably the training resolution, given
#      the "7681024" in the checkpoint name; confirm),
#   2. normalize with the ImageNet mean/std on a 0-255 pixel scale,
#   3. convert the result to a torch tensor.
val_transform = A.Compose(
    [
        A.Resize(768, 1024),
        A.Normalize(
            mean=(0.485, 0.456, 0.406),
            std=(0.229, 0.224, 0.225),
            max_pixel_value=255.0,
        ),
        ToTensorV2(),
    ]
)

In [3]:
from Datasets import Cityscapes

In [4]:
root_dir = 'data'

# Cityscapes validation split with fine annotations. The targets are the
# 'labelTrainIds' masks. `val_transform` is passed as `transforms`, so it is
# presumably applied to both image and mask by the project-local
# Datasets.Cityscapes (confirm).
val_dataset = Cityscapes(
    root=root_dir,
    split='val',
    mode='fine',
    target_type='labelTrainIds',
    transforms=val_transform,
)

In [5]:
# Validation loader. Shuffling is disabled: the accumulated mIoU does not depend
# on batch order, and a fixed order keeps evaluation runs reproducible and easier
# to debug. pin_memory=True allows asynchronous host-to-GPU copies.
val_dataloader = DataLoader(val_dataset, batch_size=24, shuffle=False,
                            num_workers=8, pin_memory=True)

## 2. Inference preparation

In [6]:
import torch
import torch.nn as nn
import torchvision

from torchmetrics import JaccardIndex

# Run on the current CUDA device when PyTorch can see one; otherwise use the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Record whether more than one GPU is visible and report what will be used.
multi_gpu = torch.cuda.device_count() > 1
print(f"Using {torch.cuda.device_count()} GPUs" if multi_gpu else f"Using device: {device}")

INFO:matplotlib.font_manager:generated new fontManager


Using device: cuda


In [8]:
# Number of output channels of both segmentation heads. It must match the
# checkpoint and is shared with the mIoU metric below.
num_classes = 20

# Build the DeepLabV3-ResNet50 architecture directly with `num_classes` outputs in
# the main head (classifier[4]) and the auxiliary head (aux_classifier[4]).
# No pretrained weights are downloaded (weights=None, weights_backbone=None):
# every parameter and buffer is overwritten by the strict load_state_dict below.
# aux_loss=True keeps the auxiliary head, so the checkpoint's aux_classifier
# keys still have somewhere to load into.
model = torchvision.models.segmentation.deeplabv3_resnet50(
    weights=None,
    weights_backbone=None,
    num_classes=num_classes,
    aux_loss=True,
)

# Device configuration.
# The model is wrapped in DataParallel *before* loading the checkpoint. This is
# presumably because the checkpoint was saved from a DataParallel-wrapped model,
# so its keys carry the 'module.' prefix (confirm).
model = torch.nn.DataParallel(model)
model = model.to(device)

# Load checkpoint.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code.
# Only load checkpoints from a trusted source. On torch >= 1.13, consider
# passing weights_only=True.
checkpoint_path = 'deeplabv3resnet50_finetuned_7681024_1.pth'
model.load_state_dict(torch.load(checkpoint_path, map_location=device))

# Evaluation metric (not a loss): multiclass Jaccard index, i.e. mean IoU over
# `num_classes` classes.
miou = JaccardIndex(task="multiclass", num_classes=num_classes).to(device)

## 3. Evaluation

In [9]:
# Evaluate the restored model on the whole validation split and report mIoU.
model.eval()  # switch BatchNorm / Dropout layers to inference behaviour

# Clear any state left by a previous run of this cell, so the result always
# covers exactly one pass over the validation set.
miou.reset()

with torch.no_grad():
    # Reuse the validation loader defined in section 1 instead of duplicating its
    # configuration. The accumulated mIoU does not depend on batch order.
    for inputs, labels in val_dataloader:
        # non_blocking=True lets the copy overlap with compute, because the
        # loader uses pin_memory=True.
        inputs = inputs.to(device, non_blocking=True)
        # Convert labels to int64 class indices. squeeze(1) removes a singleton
        # channel dimension if the dataset yields one.
        labels = labels.to(device, non_blocking=True).long().squeeze(1)

        # 'out' holds the main-head logits. Class scores are on dim 1.
        outputs = model(inputs)['out']
        predicted = torch.argmax(outputs, dim=1)

        miou.update(predicted, labels)

# NOTE(review): the metric averages over all 20 classes. If index 19 is the
# Cityscapes void/ignore class (labelTrainIds has 19 evaluation classes), the
# standard benchmark excludes it, e.g. via JaccardIndex(..., ignore_index=19).
# Confirm against Datasets.Cityscapes.
miou_accuracy = miou.compute().item()
print(f"Validation mIoU: {miou_accuracy}")

Validation mIoU: 0.6161433458328247
