# **Initial Setup**

In [1]:
import gc
import os

# CUDA-related environment variables are read when PyTorch/CUDA initialise, so they are set
# before `import torch` and before check_gpu() touches the device. Setting them after a CUDA
# call may have no effect.
# NOTE(review): if this kernel has already initialised CUDA, restart it for these to apply.
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'  # synchronous kernel launches: a debugging aid that slows training
os.environ['TORCH_USE_CUDA_DSA'] = '1'    # NOTE(review): device-side assertions are a PyTorch build option; likely a no-op at runtime
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import torch


def check_gpu():
    """Print whether CUDA is available and, if it is, the GPU count, device 0's name and the current device."""
    print(f"CUDA Available: {torch.cuda.is_available()}")

    if torch.cuda.is_available():
        print(f"Number of GPUs: {torch.cuda.device_count()}")
        print(f"GPU Device Name: {torch.cuda.get_device_name(0)}")
        print(f"Current GPU Device: {torch.cuda.current_device()}")
    else:
        print("No GPU detected. Running on CPU.")


check_gpu()

# Start from a clean allocator / Python heap.
torch.cuda.empty_cache()
gc.collect()

CUDA Available: True
Number of GPUs: 1
GPU Device Name: Tesla V100-SXM2-32GB
Current GPU Device: 0


107

In [1]:
!pip install opencv-python
!pip install numpy
!pip install pandas
!pip install torch torchvision
!pip install transformers
!pip install tqdm
!pip install matplotlib
!pip install albumentations
!pip install ipywidgets
!pip install jupyterlab-widgets

from tqdm.auto import tqdm
from tqdm import TqdmWarning
import warnings
warnings.filterwarnings('ignore', category=TqdmWarning)

Looking in links: /cvmfs/soft.computecanada.ca/custom/python/wheelhouse/gentoo2023/x86-64-v3, /cvmfs/soft.computecanada.ca/custom/python/wheelhouse/gentoo2023/generic, /cvmfs/soft.computecanada.ca/custom/python/wheelhouse/generic
Looking in links: /cvmfs/soft.computecanada.ca/custom/python/wheelhouse/gentoo2023/x86-64-v3, /cvmfs/soft.computecanada.ca/custom/python/wheelhouse/gentoo2023/generic, /cvmfs/soft.computecanada.ca/custom/python/wheelhouse/generic
Looking in links: /cvmfs/soft.computecanada.ca/custom/python/wheelhouse/gentoo2023/x86-64-v3, /cvmfs/soft.computecanada.ca/custom/python/wheelhouse/gentoo2023/generic, /cvmfs/soft.computecanada.ca/custom/python/wheelhouse/generic
Looking in links: /cvmfs/soft.computecanada.ca/custom/python/wheelhouse/gentoo2023/x86-64-v3, /cvmfs/soft.computecanada.ca/custom/python/wheelhouse/gentoo2023/generic, /cvmfs/soft.computecanada.ca/custom/python/wheelhouse/generic
^C
[31mERROR: Operation cancelled by user[0m[31m
[0m^C
Traceback (most recen

ModuleNotFoundError: No module named 'tqdm'

# **Main Code**

In [2]:
import os
import cv2
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import OneCycleLR
from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
import albumentations as A
from albumentations.pytorch import ToTensorV2
from sklearn.metrics import confusion_matrix

# ---------------------------------------------------------------------------
# Paths: dataset root, per-split image/label folders, class table, output folders
# ---------------------------------------------------------------------------
BASE_DIR = "/home/thatkar/projects/def-saadi/thatkar"
ROOT_DIR = os.path.join(BASE_DIR, "CamVid")

TRAIN_DIR, TRAIN_LABELS_DIR = (os.path.join(ROOT_DIR, sub) for sub in ("train", "train_labels"))
VAL_DIR, VAL_LABELS_DIR = (os.path.join(ROOT_DIR, sub) for sub in ("val", "val_labels"))
TEST_DIR, TEST_LABELS_DIR = (os.path.join(ROOT_DIR, sub) for sub in ("test", "test_labels"))
CLASS_DICT_PATH = os.path.join(ROOT_DIR, "class_dict.csv")
CHECKPOINT_DIR, VISUALIZATION_DIR = (os.path.join(ROOT_DIR, sub) for sub in ("checkpoints", "visualizations"))

# Training writes checkpoints and plots into these, so make sure both exist up front.
for output_dir in (CHECKPOINT_DIR, VISUALIZATION_DIR):
    os.makedirs(output_dir, exist_ok=True)

# ---------------------------------------------------------------------------
# Training hyperparameters (kept same as B5 for fair comparison)
# ---------------------------------------------------------------------------
# NOTE(review): the transforms resize to (height=IMAGE_HEIGHT, width=IMAGE_WIDTH), i.e. a
# portrait 960-tall x 768-wide frame. If the source images are landscape, confirm that
# height and width are not swapped.
IMAGE_HEIGHT, IMAGE_WIDTH = 960, 768
BATCH_SIZE = 1
ACCUMULATION_STEPS = 16   # optimizer steps once per this many batches (effective batch = BATCH_SIZE * this)
NUM_EPOCHS = 45
LEARNING_RATE = 2e-4      # also the OneCycleLR peak learning rate in main()
WEIGHT_DECAY = 0.01
GRAD_CLIP_VALUE = 1.0     # max global gradient norm

# Training augmentation: random rescaled crop to the training resolution, one photometric
# jitter (applied 70% of the time), one geometric warp (applied 50% of the time), then
# ImageNet mean/std normalisation and conversion to a CHW tensor.
#
# Geometric warps with a constant border pad the exposed area. The label mask must be padded
# with the dataset's ignore index (255) so those pixels are excluded from the loss; padding
# with 0 would silently label them as class 0.
# NOTE(review): `fill_mask` is the albumentations 2.x name (1.x used `cval_mask`, `mode`).
train_transforms = A.Compose([
    A.RandomResizedCrop(
        size=(IMAGE_HEIGHT, IMAGE_WIDTH),
        scale=(0.8, 1.0),
        ratio=(0.75, 1.33),
        p=1.0,
        interpolation=cv2.INTER_LINEAR
    ),
    A.OneOf([
        A.RandomBrightnessContrast(
            brightness_limit=0.2,
            contrast_limit=0.2,
            p=0.5
        ),
        A.RandomGamma(gamma_limit=(80, 120), p=0.5),
        A.RandomShadow(
            shadow_roi=(0, 0.5, 1, 1),
            p=0.5
        ),
    ], p=0.7),
    A.OneOf([
        A.Affine(
            scale=(0.9, 1.1),
            translate_percent={'x': (-0.1, 0.1), 'y': (-0.1, 0.1)},
            rotate=(-15, 15),
            border_mode=cv2.BORDER_CONSTANT,
            fill_mask=255,  # ignore index: padded pixels carry no label
        ),
        A.ElasticTransform(
            alpha=120,
            sigma=6,
            fill_mask=255,  # only used when the border mode is constant
            p=0.5
        ),
    ], p=0.5),
    A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ToTensorV2()
])

# Validation/inference: deterministic resize + the same normalisation as training.
val_transforms = A.Compose([
    A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH, interpolation=cv2.INTER_LINEAR),
    A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ToTensorV2()
])

class CamVidDataset(Dataset):
    """
    CamVid image/label dataset.

    Each ``<name>.png`` in ``image_dir`` is paired with the colour-coded label
    ``<name>_L.png`` in ``label_dir``. Label colours are mapped to class indices using
    the rows of the CSV at ``CLASS_DICT_PATH`` (columns ``r``, ``g``, ``b``; the row index
    is the class id). Colours not listed there become ``ignore_index`` (255).

    Args:
        image_dir: folder with the input ``.png`` images.
        label_dir: folder with the matching ``*_L.png`` label images.
        feature_extractor: stored on the instance but not used by this class.
        transforms: optional callable taking ``image=`` and ``mask=`` keyword arguments and
            returning a dict with ``'image'`` and ``'mask'`` (albumentations-style).

    Each item is ``{'pixel_values': image, 'labels': LongTensor mask}``.
    """
    def __init__(self, image_dir, label_dir, feature_extractor, transforms=None):
        self.image_dir = image_dir
        self.label_dir = label_dir
        self.transforms = transforms
        self.feature_extractor = feature_extractor
        self.ignore_index = 255

        # Load class definitions; the DataFrame index (0..N-1 by default) is the class id.
        self.class_df = pd.read_csv(CLASS_DICT_PATH)
        self.num_classes = len(self.class_df)

        # Colour -> class id. Keys are BGR because cv2.imread returns BGR images.
        self.color_mapping = {}
        for idx, row in self.class_df.iterrows():
            bgr_color = (int(row['b']), int(row['g']), int(row['r']))
            self.color_mapping[bgr_color] = idx

        # Keep only images whose image and label files both exist and decode.
        all_images = sorted([f for f in os.listdir(image_dir) if f.endswith('.png')])
        self.images = []

        print(f"Validating images in {image_dir}...")
        for img_name in all_images:
            img_path = os.path.join(image_dir, img_name)
            label_path = os.path.join(label_dir, self._label_name(img_name))

            if os.path.exists(label_path):
                img_test = cv2.imread(img_path)
                label_test = cv2.imread(label_path)
                if img_test is not None and label_test is not None:
                    self.images.append(img_name)
                else:
                    print(f"Warning: Failed to load {img_name} or its label")
            else:
                print(f"Warning: Missing label for {img_name}")

        print(f"Found {len(self.images)} valid image-label pairs")

        if len(self.images) == 0:
            raise RuntimeError(f"No valid image-label pairs found in {image_dir}")

    @staticmethod
    def _label_name(img_name):
        """File name of the colour-coded label that belongs to image ``img_name``."""
        return img_name.replace('.png', '_L.png')

    def _label_to_mask(self, label):
        """Convert a BGR colour label image (H, W, 3) into an (H, W) class-index mask.

        Pixels whose colour is not in ``self.color_mapping`` get ``self.ignore_index``.
        The mask is uint8 when all ids fit (friendlier to OpenCV-based transforms),
        otherwise int64.
        """
        h, w = label.shape[:2]

        # Pack each BGR triple into one integer so every pixel is matched with a single
        # sorted lookup instead of one full-image comparison per class.
        pixels = label.reshape(-1, 3).astype(np.int64)
        codes = (pixels[:, 0] << 16) | (pixels[:, 1] << 8) | pixels[:, 2]

        colors = np.array(list(self.color_mapping.keys()), dtype=np.int64).reshape(-1, 3)
        class_ids = np.array(list(self.color_mapping.values()), dtype=np.int64)
        color_codes = (colors[:, 0] << 16) | (colors[:, 1] << 8) | colors[:, 2]
        order = np.argsort(color_codes)
        color_codes, class_ids = color_codes[order], class_ids[order]

        all_values = [self.ignore_index, *class_ids.tolist()]
        out_dtype = np.uint8 if min(all_values) >= 0 and max(all_values) <= 255 else np.int64

        mask = np.full(codes.shape, self.ignore_index, dtype=out_dtype)
        if len(color_codes):
            pos = np.clip(np.searchsorted(color_codes, codes), 0, len(color_codes) - 1)
            hit = color_codes[pos] == codes
            mask[hit] = class_ids[pos[hit]]
        return mask.reshape(h, w)

    def _load_sample(self, idx):
        """Read, encode and transform sample ``idx``; raises on any failure."""
        image_path = os.path.join(self.image_dir, self.images[idx])
        image = cv2.imread(image_path)
        if image is None:
            raise ValueError(f"Failed to load image: {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        label_path = os.path.join(self.label_dir, self._label_name(self.images[idx]))
        label = cv2.imread(label_path)
        if label is None:
            raise ValueError(f"Failed to load label: {label_path}")

        label_mask = self._label_to_mask(label)

        if self.transforms:
            transformed = self.transforms(image=image, mask=label_mask)
            image = transformed['image']
            label_mask = transformed['mask']

        return {
            'pixel_values': image,
            'labels': torch.as_tensor(label_mask, dtype=torch.long)
        }

    def __getitem__(self, idx):
        n = len(self)
        # Out-of-range indices must raise IndexError (sequence protocol), not wrap around.
        if not -n <= idx < n:
            raise IndexError(f"index {idx} out of range for dataset of size {n}")
        idx %= n

        # Best effort: a sample that fails is replaced by the next one, but at most once
        # around the dataset, so a systematic failure raises instead of recursing forever.
        for offset in range(n):
            current = (idx + offset) % n
            try:
                return self._load_sample(current)
            except Exception as e:
                print(f"Error processing image {self.images[current]}: {str(e)}")
        raise RuntimeError(f"Every sample failed to load (starting from index {idx})")

    def __len__(self):
        return len(self.images)

class EnhancedSegmentationLoss(nn.Module):
    """Weighted cross-entropy + soft-IoU + multi-scale boundary loss for semantic segmentation.

    ``forward(outputs, targets)`` expects logits of shape (B, C, H, W) and integer class
    targets of shape (B, H, W). Pixels equal to ``ignore_index`` are excluded from the
    cross-entropy and IoU terms. Returns ``CE + 0.4 * IoU_loss + 0.8 * boundary_loss``.

    Args:
        num_classes: number of classes C.
        ignore_index: target value marking unlabelled pixels.
    """
    def __init__(self, num_classes, ignore_index=255):
        super().__init__()
        self.num_classes = num_classes
        self.ignore_index = ignore_index

        # Per-class CE weights. Groups are checked top-down and the first match wins:
        # every index in vehicle_classes is also in dark_objects, so the 2.5 branch never
        # fires as written, and 8/23 get 3.5 rather than 3.0.
        # NOTE(review): confirm this precedence is intended.
        class_weights = torch.ones(num_classes)
        rare_classes = [6, 9, 10, 15, 16, 23, 24]
        small_objects = [6, 8, 23, 24]
        vehicle_classes = [5, 22, 27]
        dark_objects = [5, 22, 27, 8, 23]

        for class_idx in range(num_classes):
            if class_idx in dark_objects:
                class_weights[class_idx] = 3.5
            elif class_idx in small_objects:
                class_weights[class_idx] = 3.0
            elif class_idx in vehicle_classes:
                class_weights[class_idx] = 2.5
            elif class_idx in rare_classes:
                class_weights[class_idx] = 2.0

        # The weight tensor lives inside this submodule, so .to(device) moves it too.
        self.ce_loss = nn.CrossEntropyLoss(weight=class_weights, ignore_index=ignore_index)
        self.smooth = 1e-5

    def get_boundaries(self, tensor):
        """Multi-scale soft edge map of an integer label map, same shape as ``tensor``.

        For each window size k in (3, 5, 7, 9) a pixel adds weight (0.4, 0.3, 0.2, 0.1)
        when the k x k max-pool around it differs from its own value, i.e. a larger label
        lies inside the window. A 3-D (B, H, W) input is pooled per image.
        """
        boundaries = torch.zeros_like(tensor, dtype=torch.float)
        kernel_sizes = [3, 5, 7, 9]
        weights = [0.4, 0.3, 0.2, 0.1]

        for k_size, weight in zip(kernel_sizes, weights):
            pooled = F.max_pool2d(
                tensor.float(),
                kernel_size=k_size,
                stride=1,
                padding=k_size//2
            )
            boundaries += weight * (pooled != tensor.float()).float()
        return boundaries

    def calculate_iou_loss(self, pred, target):
        """``1 - mean`` soft IoU over batch and classes, ignoring ``ignore_index`` pixels.

        Ignored pixels are zeroed in both the softmax probabilities and the one-hot target,
        so they contribute to neither intersection nor union. (``F.one_hot`` would also
        raise on an index >= num_classes such as 255.) A (batch, class) pair with
        near-zero union scores IoU 0.
        """
        probs = F.softmax(pred, dim=1).flatten(2)                    # (B, C, H*W)
        valid = target != self.ignore_index                          # (B, H, W)
        safe_target = torch.where(valid, target, torch.zeros_like(target))
        one_hot = F.one_hot(safe_target, num_classes=self.num_classes).permute(0, 3, 1, 2).flatten(2)
        valid = valid.flatten(1).unsqueeze(1).to(probs.dtype)        # (B, 1, H*W)
        probs = probs * valid
        one_hot = one_hot.to(probs.dtype) * valid

        intersection = (probs * one_hot).sum(-1)
        total = (probs + one_hot).sum(-1)
        union = total - intersection
        valid_mask = union > self.smooth
        iou = torch.zeros_like(intersection)
        iou[valid_mask] = (intersection[valid_mask] + self.smooth) / (union[valid_mask] + self.smooth)

        return 1 - iou.mean()

    def forward(self, outputs, targets):
        ce_loss = self.ce_loss(outputs, targets)
        iou_loss = self.calculate_iou_loss(outputs, targets)

        # argmax has no gradient, so this term changes the reported loss value but does not
        # contribute gradients to the model as written.
        # NOTE(review): ignore_index pixels are treated as a label here, so the edges of
        # ignored regions count as boundaries.
        edges = self.get_boundaries(targets)
        pred_edges = self.get_boundaries(torch.argmax(outputs, dim=1))
        boundary_loss = F.mse_loss(pred_edges, edges)

        return ce_loss + 0.4 * iou_loss + 0.8 * boundary_loss

def train_epoch(model, train_loader, optimizer, scheduler, scaler, criterion, device, epoch):
    """Run one training epoch with float16 autocast and gradient accumulation.

    The optimizer and scheduler step once every ACCUMULATION_STEPS batches, after unscaling
    and clipping gradients to GRAD_CLIP_VALUE. A batch that raises is reported and skipped.

    Returns:
        (mean loss over batches that succeeded, or NaN if none did;
         list of per-batch losses)
    """
    model.train()
    # Gradients from an incomplete accumulation window at the end of the previous epoch
    # (len(train_loader) need not be a multiple of ACCUMULATION_STEPS) would otherwise
    # leak into this epoch's first optimizer step.
    optimizer.zero_grad()
    # autocast wants a device *type* ('cuda'/'cpu'); str(device) may be 'cuda:0'.
    device_type = torch.device(device).type
    epoch_loss = 0.0
    batch_losses = []

    pbar = tqdm(train_loader, desc=f'Training Epoch {epoch}')
    for batch_idx, batch in enumerate(pbar):
        try:
            pixel_values = batch['pixel_values'].to(device)
            labels = batch['labels'].to(device)

            with torch.amp.autocast(device_type=device_type, dtype=torch.float16):
                outputs = model(pixel_values=pixel_values)
                # SegFormer logits are lower-resolution than the labels; upsample to match.
                logits = F.interpolate(
                    outputs.logits,
                    size=labels.shape[-2:],
                    mode="bilinear",
                    align_corners=False
                )
                loss = criterion(logits, labels) / ACCUMULATION_STEPS

            scaler.scale(loss).backward()

            if (batch_idx + 1) % ACCUMULATION_STEPS == 0:
                scaler.unscale_(optimizer)  # clip on real (unscaled) gradient values
                torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP_VALUE)
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
                scheduler.step()

            current_loss = loss.item() * ACCUMULATION_STEPS
            epoch_loss += current_loss
            batch_losses.append(current_loss)

            pbar.set_postfix({
                'loss': current_loss,
                'avg_loss': epoch_loss / len(batch_losses)
            })

        except Exception as e:
            print(f"Error in batch {batch_idx}: {str(e)}")
            continue

    # Failed batches are not zero-loss batches; average only over those that ran.
    mean_loss = epoch_loss / len(batch_losses) if batch_losses else float('nan')
    return mean_loss, batch_losses

@torch.no_grad()
def validate(model, val_loader, criterion, device):
    """Evaluate ``model`` on ``val_loader`` without gradients.

    A batch that raises is reported and skipped.

    Returns:
        (mean loss over batches that succeeded, or NaN if none did;
         list of per-batch argmax predictions on CPU;
         list of per-batch label tensors on CPU)
    """
    model.eval()
    val_loss = 0.0
    num_ok = 0
    predictions = []
    ground_truths = []

    for batch in tqdm(val_loader, desc='Validation'):
        try:
            pixel_values = batch['pixel_values'].to(device)
            labels = batch['labels'].to(device)

            outputs = model(pixel_values=pixel_values)
            # Upsample logits to label resolution before computing the loss.
            logits = F.interpolate(
                outputs.logits,
                size=labels.shape[-2:],
                mode="bilinear",
                align_corners=False
            )

            loss = criterion(logits, labels)
            val_loss += loss.item()
            num_ok += 1

            predictions.append(torch.argmax(logits, dim=1).cpu())
            ground_truths.append(labels.cpu())

        except Exception as e:
            print(f"Error during validation: {str(e)}")
            continue

    # Counting failed batches as zero loss would bias the mean low, and an all-failing epoch
    # would report 0.0 and be treated as the best model by main(); NaN never compares lower.
    mean_loss = val_loss / num_ok if num_ok else float('nan')
    return mean_loss, predictions, ground_truths

def plot_training_curves(train_losses, val_losses, save_path):
    """Save a PNG of per-epoch training/validation loss to ``save_path``.

    Once at least 5 epochs exist, dashed 5-epoch moving averages are overlaid.
    The figure is closed afterwards so repeated calls do not accumulate figures.
    """
    window = 5
    kernel = np.ones(window) / window

    fig, ax = plt.subplots(figsize=(12, 6))
    ax.plot(train_losses, label='Training Loss', color='blue', alpha=0.7)
    ax.plot(val_losses, label='Validation Loss', color='red', alpha=0.7)

    if len(train_losses) >= window:
        smoothing_specs = (
            (train_losses, 'darkblue', 'Train Moving Avg'),
            (val_losses, 'darkred', 'Val Moving Avg'),
        )
        smoothed = [np.convolve(series, kernel, mode='valid') for series, _, _ in smoothing_specs]
        for (series, colour, name), averaged in zip(smoothing_specs, smoothed):
            ax.plot(range(window - 1, len(series)), averaged,
                    '--', color=colour, alpha=0.5, label=name)

    ax.set_title('Training and Validation Loss Over Time')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.legend()
    ax.grid(True, alpha=0.3)
    fig.savefig(save_path)
    plt.close(fig)


def main():
    """Train SegFormer-B3 on CamVid, validating after every epoch.

    Uses the module-level path and hyperparameter constants. Side effects:
      * whenever validation loss improves, saves ``best_model_loss_b3_<val_loss>.pth`` in
        CHECKPOINT_DIR (model, optimizer, scheduler and grad-scaler state);
      * after every epoch, writes ``training_curves_epoch_<n>.png`` in VISUALIZATION_DIR.
    Any exception is printed and re-raised.
    """
    # Clear CUDA cache before starting
    torch.cuda.empty_cache()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    try:
        # Passed to the datasets, which store it but do not use it; preprocessing is done by
        # the albumentations pipelines (train_transforms / val_transforms).
        feature_extractor = SegformerImageProcessor.from_pretrained(
            "nvidia/mit-b3",  # Changed to B3
            do_reduce_labels=True,
            do_rescale=False,
            size={"height": IMAGE_HEIGHT, "width": IMAGE_WIDTH}
        )

        train_dataset = CamVidDataset(TRAIN_DIR, TRAIN_LABELS_DIR, feature_extractor, transforms=train_transforms)
        val_dataset = CamVidDataset(VAL_DIR, VAL_LABELS_DIR, feature_extractor, transforms=val_transforms)

        train_loader = DataLoader(
            train_dataset,
            batch_size=BATCH_SIZE,
            shuffle=True,
            num_workers=2,
            pin_memory=True
        )
        val_loader = DataLoader(
            val_dataset,
            batch_size=BATCH_SIZE,
            shuffle=False,
            num_workers=2,
            pin_memory=True
        )

        num_classes = train_dataset.num_classes
        # Pretrained B3 encoder; the decode head is re-initialised for CamVid's classes.
        model = SegformerForSemanticSegmentation.from_pretrained(
            "nvidia/mit-b3",  # Changed to B3
            num_labels=num_classes,
            id2label={str(i): str(i) for i in range(num_classes)},
            label2id={str(i): i for i in range(num_classes)},
            ignore_mismatched_sizes=True
        ).to(device)

        optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=LEARNING_RATE,
            weight_decay=WEIGHT_DECAY
        )

        # train_epoch steps the scheduler once per ACCUMULATION_STEPS batches, i.e.
        # NUM_EPOCHS * (len(train_loader) // ACCUMULATION_STEPS) times in total, which never
        # exceeds this value. max(1, ...) keeps OneCycleLR valid for very small datasets.
        total_steps = max(1, len(train_loader) * NUM_EPOCHS // ACCUMULATION_STEPS)
        scheduler = OneCycleLR(
            optimizer,
            max_lr=LEARNING_RATE,
            total_steps=total_steps,
            pct_start=0.1
        )

        criterion = EnhancedSegmentationLoss(num_classes).to(device)
        # Build the scaler for the device actually in use (torch.amp.GradScaler defaults to 'cuda').
        scaler = torch.amp.GradScaler(device.type)

        train_losses = []
        val_losses = []
        best_val_loss = float('inf')

        for epoch in range(NUM_EPOCHS):
            print(f"\nEpoch {epoch+1}/{NUM_EPOCHS}")

            train_loss, _ = train_epoch(model, train_loader, optimizer, scheduler, scaler, criterion, device, epoch)
            train_losses.append(train_loss)

            # Per-batch predictions/labels returned by validate() are not needed here.
            val_loss, _, _ = validate(model, val_loader, criterion, device)
            val_losses.append(val_loss)

            print(f"Train Loss: {train_loss:.4f}")
            print(f"Val Loss: {val_loss:.4f}")

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                checkpoint = {
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict(),
                    # Needed to resume mixed-precision training with the same loss scale.
                    'scaler_state_dict': scaler.state_dict(),
                    'train_loss': train_loss,
                    'val_loss': val_loss,
                }
                torch.save(checkpoint, os.path.join(CHECKPOINT_DIR, f'best_model_loss_b3_{val_loss:.4f}.pth'))
                print(f"New best model saved! Val Loss: {val_loss:.4f}")

            plot_training_curves(
                train_losses,
                val_losses,
                os.path.join(VISUALIZATION_DIR, f'training_curves_epoch_{epoch+1}.png')
            )

            # Clear cache after each epoch
            torch.cuda.empty_cache()

    except Exception as e:
        print(f"Error during training: {str(e)}")
        raise

# if __name__ == '__main__':
#     main()

  check_for_updates()


# **Visualization**

In [3]:
import glob
import os
import cv2
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import OneCycleLR
from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import albumentations as A
from albumentations.pytorch import ToTensorV2

# --- Checkpoint and sample paths ---------------------------------------------
# Best checkpoint written by the training cell (its file name embeds the val loss).
checkpoint_dir = "/home/thatkar/projects/def-saadi/thatkar/CamVid/checkpoints/"
best_model_path = "/home/thatkar/projects/def-saadi/thatkar/CamVid/checkpoints/best_model_loss_b3_0.6197.pth"
print(f"Loading checkpoint from: {best_model_path}")

# Sample to visualise; its ground truth lives in test_labels/ with an `_L` suffix.
TEST_IMAGE_PATH = "/home/thatkar/projects/def-saadi/thatkar/CamVid/test/0001TP_006690.png"
# TEST_IMAGE_PATH = "/home/thatkar/projects/def-saadi/thatkar/CamVid/test/0016E5_04530.png"
CLASS_DICT_PATH = "/home/thatkar/projects/def-saadi/thatkar/CamVid/class_dict.csv"

# --- Model -------------------------------------------------------------------
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
class_df = pd.read_csv(CLASS_DICT_PATH)

# NOTE(review): training used IMAGE_HEIGHT x IMAGE_WIDTH (960x768), not 640x640.
feature_extractor = SegformerImageProcessor.from_pretrained(
    "nvidia/mit-b3", do_reduce_labels=True, do_rescale=False, size={"height": 640, "width": 640},
)

# Head sized to the class table, then overwritten by the trained weights.
model = SegformerForSemanticSegmentation.from_pretrained(
    "nvidia/mit-b3", num_labels=len(class_df), ignore_mismatched_sizes=True,
).to(device)

checkpoint = torch.load(best_model_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])

# test/<name>.png  ->  test_labels/<name>_L.png
LABEL_PATH = TEST_IMAGE_PATH.replace('/test/', '/test_labels/').replace('.png', '_L.png')

def visualize_prediction(model, image_path, label_path, feature_extractor, class_df, device,
                         input_size=(640, 640)):
    """Show an image, its colour-coded ground truth and the model prediction side by side.

    Args:
        model: segmentation model called as ``model(pixel_values=...)`` whose output has ``.logits``.
        image_path: path of the input image (read with OpenCV, converted to RGB).
        label_path: path of the colour-coded ground-truth image (compared in BGR).
        feature_extractor: unused; kept for call compatibility.
        class_df: class table with columns ``r``, ``g``, ``b``; the row index is the class id.
        device: device the model lives on.
        input_size: (height, width) the image is resized to before inference. Defaults to
            the previously hard-coded 640x640. NOTE(review): the training cell trains at
            IMAGE_HEIGHT x IMAGE_WIDTH; pass that to match training preprocessing.

    Raises:
        FileNotFoundError: if the image or the label cannot be read.
    """
    image = cv2.imread(image_path)
    if image is None:
        raise FileNotFoundError(f"Could not read image: {image_path}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    label = cv2.imread(label_path)  # kept in BGR; matched against BGR colours below
    if label is None:
        raise FileNotFoundError(f"Could not read label: {label_path}")
    orig_h, orig_w = image.shape[:2]

    in_h, in_w = input_size
    test_transform = A.Compose([
        A.Resize(height=in_h, width=in_w),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2()
    ])
    image_tensor = test_transform(image=image)['image'].unsqueeze(0)

    model.eval()
    with torch.no_grad():
        logits = model(pixel_values=image_tensor.to(device)).logits
        upsampled_logits = F.interpolate(
            logits,
            size=(orig_h, orig_w),
            mode="bilinear",
            align_corners=False
        )
        # Index the batch dim explicitly; .squeeze() would also drop a size-1 H or W.
        predicted = upsampled_logits.argmax(dim=1)[0].cpu().numpy()

    # Paint both masks with the class colours (RGB, for display).
    pred_mask = np.zeros_like(image)
    truth_mask = np.zeros_like(image)
    for idx, row in class_df.iterrows():
        class_color = np.array([row['r'], row['g'], row['b']])
        pred_mask[predicted == idx] = class_color
        bgr_color = (int(row['b']), int(row['g']), int(row['r']))
        truth_mask[np.all(label == bgr_color, axis=2)] = class_color

    fig, axes = plt.subplots(1, 3, figsize=(18, 6))
    panels = (
        (image, 'Original Image'),
        (truth_mask, 'Ground Truth'),
        (pred_mask, 'Model Prediction (B3)'),
    )
    for ax, (panel, title) in zip(axes, panels):
        ax.imshow(panel)
        ax.set_title(title)
        ax.axis('off')

    plt.tight_layout()
    plt.show()

# Run visualization
# visualize_prediction(model, TEST_IMAGE_PATH, LABEL_PATH, feature_extractor, class_df, device)

Loading checkpoint from: /home/thatkar/projects/def-saadi/thatkar/CamVid/checkpoints/best_model_loss_b3_0.6197.pth


  return func(*args, **kwargs)
Some weights of SegformerForSemanticSegmentation were not initialized from the model checkpoint at nvidia/mit-b3 and are newly initialized: ['decode_head.batch_norm.bias', 'decode_head.batch_norm.num_batches_tracked', 'decode_head.batch_norm.running_mean', 'decode_head.batch_norm.running_var', 'decode_head.batch_norm.weight', 'decode_head.classifier.bias', 'decode_head.classifier.weight', 'decode_head.linear_c.0.proj.bias', 'decode_head.linear_c.0.proj.weight', 'decode_head.linear_c.1.proj.bias', 'decode_head.linear_c.1.proj.weight', 'decode_head.linear_c.2.proj.bias', 'decode_head.linear_c.2.proj.weight', 'decode_head.linear_c.3.proj.bias', 'decode_head.linear_c.3.proj.weight', 'decode_head.linear_fuse.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


# **Model Interpretability (XAI)**

In [3]:
import matplotlib.cm as cm

from transformers import SegformerImageProcessor

def visualize_attention(model, image_tensor, processor, original_image, class_names, device):
    """
    Plot an image next to the model's predicted segmentation.

    Despite the name, no attention maps are extracted. The model is run once, its logits
    are upsampled to the size of ``original_image``, and the per-pixel argmax is shown
    with a legend of the classes present.

    Args:
        model: segmentation model called as ``model(pixel_values=...)`` returning ``.logits``.
        image_tensor: one preprocessed image (C, H, W); a batch dimension is added here.
        processor: unused; kept for call compatibility.
        original_image: image array shown in the left panel; its first two dimensions set
            the size the prediction is upsampled to.
        class_names: class names indexed by class id, used for the legend.
        device: device the model lives on.
    """
    from matplotlib.patches import Patch  # only needed for the legend

    model.eval()
    batch = image_tensor.unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(pixel_values=batch).logits  # (1, num_classes, h, w)
        upsampled = F.interpolate(logits, size=tuple(original_image.shape[:2]),
                                  mode='bilinear', align_corners=False)
    predicted_class = torch.argmax(upsampled, dim=1)[0].cpu().numpy()
    num_classes = logits.shape[1]

    fig, axs = plt.subplots(1, 2, figsize=(12, 6))

    axs[0].imshow(original_image)
    axs[0].set_title("Original Image")
    axs[0].axis('off')

    # Fix the colour scale to the full class range so a class gets the same colour in every
    # image (without vmin/vmax it depends on which classes happen to be present).
    # tab20 has only 20 colours, so with more classes some share a colour; the legend names them.
    im = axs[1].imshow(predicted_class, cmap='tab20', vmin=0, vmax=max(num_classes - 1, 1),
                       interpolation='nearest')
    axs[1].set_title("Predicted Segmentation")
    axs[1].axis('off')

    handles = [
        Patch(color=im.cmap(im.norm(c)), label=class_names[c] if c < len(class_names) else str(c))
        for c in np.unique(predicted_class).tolist()
    ]
    axs[1].legend(handles=handles, loc='center left', bbox_to_anchor=(1.0, 0.5), fontsize='small')

    plt.tight_layout()
    plt.show()


In [4]:
from PIL import Image

# Requires the Main Code cell (val_transforms, IMAGE_HEIGHT/IMAGE_WIDTH, CHECKPOINT_DIR) and the
# Visualization cell (TEST_IMAGE_PATH, CLASS_DICT_PATH, model, device) to have run first.
processor = SegformerImageProcessor.from_pretrained(
    "nvidia/mit-b3",
    do_reduce_labels=True,
    do_rescale=False,
    size={"height": IMAGE_HEIGHT, "width": IMAGE_WIDTH}
)

# Load one test image for explainability
sample_img_path = TEST_IMAGE_PATH
original = cv2.imread(sample_img_path)
if original is None:
    raise FileNotFoundError(f"Could not read test image: {sample_img_path}")
original_rgb = cv2.cvtColor(original, cv2.COLOR_BGR2RGB)
resized = cv2.resize(original_rgb, (IMAGE_WIDTH, IMAGE_HEIGHT))

# Preprocess exactly like validation during training (A.Normalize scales 0-255 to [0, 1] and
# then applies ImageNet mean/std). The processor above has do_rescale=False, so it would not
# divide the 0-255 pixels by 255 and the model would see inputs on a different scale than in training.
normalized = val_transforms(image=resized)["image"]

# Load the best trained model
model_path = os.path.join(CHECKPOINT_DIR, "best_model_loss_b3_0.6197.pth")  # Use the actual best file name
checkpoint = torch.load(model_path, map_location=device)
model.load_state_dict(checkpoint["model_state_dict"])

# Class names from CSV (row order = class id)
class_names = list(pd.read_csv(CLASS_DICT_PATH)['name'])

# Run visualization
visualize_attention(model, normalized, processor, resized, class_names, device)

  return func(*args, **kwargs)


NameError: name 'TEST_IMAGE_PATH' is not defined

In [6]:
from torchcam.methods import GradCAM
from torchvision.transforms.functional import normalize
import matplotlib.pyplot as plt

# Requires the Main Code cell (CHECKPOINT_DIR, CLASS_DICT_PATH, IMAGE_HEIGHT/IMAGE_WIDTH,
# val_transforms, F, pd) and the Visualization cell (TEST_IMAGE_PATH) to have run first.

# Load trained model; the head must have one output per row of the class table (as in training).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
class_names = list(pd.read_csv(CLASS_DICT_PATH)['name'])
model = SegformerForSemanticSegmentation.from_pretrained(
    "nvidia/mit-b3",
    num_labels=len(class_names),
    ignore_mismatched_sizes=True
)
checkpoint = torch.load(os.path.join(CHECKPOINT_DIR, "best_model_loss_b3_0.6197.pth"), map_location=device)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval().to(device)

# Grad-CAM needs a layer with a spatial (B, C, H, W) output. 'encoder.encoder.layer.11' is not
# a module path of SegformerForSemanticSegmentation; 'decode_head.linear_fuse' (a parameterised
# module of this model) is the fusion conv feeding the classifier.
# NOTE(review): any other conv in 'segformer.encoder' / 'decode_head' can be used instead.
cam_extractor = GradCAM(model, target_layer="decode_head.linear_fuse")

# Pick a sample image from test set
sample_img_path = TEST_IMAGE_PATH
img = cv2.imread(sample_img_path)
if img is None:
    raise FileNotFoundError(f"Could not read test image: {sample_img_path}")
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img_resized = cv2.resize(img, (IMAGE_WIDTH, IMAGE_HEIGHT))
input_tensor = val_transforms(image=img_resized)["image"].unsqueeze(0).to(device)

# Forward pass WITH autograd: Grad-CAM back-propagates the class score to the hooked layer,
# which is impossible under torch.no_grad().
out = model(pixel_values=input_tensor)
logits = out.logits
preds = torch.argmax(logits, dim=1)

# Class to explain. The original intent was "car"; look it up by name rather than trusting
# a hard-coded index. NOTE(review): falls back to index 10 if no 'Car' row exists.
target_class = class_names.index("Car") if "Car" in class_names else 10
cams = cam_extractor(class_idx=target_class, scores=logits)
cam_extractor.remove_hooks()

# The CAM is at decoder resolution and may carry a batch dim; reshape to (1, 1, h, w) and
# upsample to the displayed image so the overlay lines up.
cam = cams[0].detach().float()
cam = F.interpolate(cam.reshape(1, 1, *cam.shape[-2:]), size=img_resized.shape[:2],
                    mode="bilinear", align_corners=False)
activation_map = cam[0, 0].cpu().numpy()

# Plot
plt.figure(figsize=(12, 5))
plt.subplot(1, 3, 1)
plt.imshow(img_resized)
plt.title("Original Image")
plt.axis("off")

plt.subplot(1, 3, 2)
plt.imshow(preds[0].cpu().numpy(), cmap="tab20")
plt.title("Predicted Segmentation")
plt.axis("off")

plt.subplot(1, 3, 3)
plt.imshow(img_resized)
plt.imshow(activation_map, cmap="jet", alpha=0.5)
plt.title(f"Grad-CAM for class {target_class} ({class_names[target_class] if target_class < len(class_names) else '?'})")
plt.axis("off")
plt.tight_layout()
plt.show()

ModuleNotFoundError: No module named 'torchcam'