In [None]:
import numpy as np
import pandas as pd
from dataclasses import dataclass
import os
import torch
import torchvision.models
from torch import nn
from torch.utils.data import Dataset as TorchDataset, DataLoader
import torch.optim as TorchOptimizers
import torchvision.transforms.v2 as T
from torchvision.models import ResNet50_Weights
from torchvision.models.segmentation import deeplabv3_resnet50
from torchinfo import summary as torch_summary
import matplotlib.pyplot as plt
from PIL import Image, ImageFile
import math
import time
import wandb
from sklearn.model_selection import train_test_split
import json
from typing import Callable

In [None]:
@dataclass
class Config:
    """Paths and device for one environment, plus hyperparameters shared by all environments.

    The dataclass fields are supplied per environment. `__post_init__` adds the
    shared hyperparameters, and `init` builds the normalisation tensors and the
    image transform pipeline.
    """
    train_csv_filepath: str
    test_csv_filepath: str
    submission_filepath: str
    images_root_folder: str
    image_masks_root_folder: str
    training_output_folder: str
    saved_weights_filepath: str
    device: str

    def __post_init__(self):
        """ For configuration variables that are shared across environments """
        self.num_classes = 12
        # DeepLabV3 takes a lot of memory. Batch size of 8 with other settings uses just under 16 GB of VRAM
        self.batch_size = 8
        # With gradient accumulation steps = 4, effective batch size is 8 * 4 = 32.
        # This allows for most of the benefits of larger batch sizes
        # while fitting within VRAM constraints
        self.gradient_accumulation_steps = 4
        self.starting_learning_rate = 4e-4
        self.max_epochs = 100
        self.patience = 20
        self.seed = 1234
        self.num_workers = 4 if self.device == 'cuda' else 0
        self.pin_memory = self.num_workers > 0
        self.image_width = 512
        self.image_height = 512
        # Derived rather than repeated so the three values cannot drift apart.
        # Order is (height, width), which is what T.Resize below expects.
        self.image_dims = (self.image_height, self.image_width)
        # For mixed precision training, greatly reduces VRAM usage
        self.use_amp = self.device == 'cuda'

    # noinspection PyAttributeOutsideInit
    def init(self, training):
        """ Adjust configuration setup for training vs inference

        Args:
            training: truthy when preparing for training; the training output
                folder is created in that case.
        """
        self.training = training

        if self.training:
            os.makedirs(self.training_output_folder, exist_ok=True)

        self.imagenet_mean_cpu_tensor = torch.tensor(imagenet_mean_array)
        self.imagenet_std_cpu_tensor = torch.tensor(imagenet_std_array)
        self.channelwise_imagenet_mean_cpu_tensor = self.imagenet_mean_cpu_tensor.view(3, 1, 1)
        self.channelwise_imagenet_std_cpu_tensor = self.imagenet_std_cpu_tensor.view(3, 1, 1)
        # Place these on this instance's own device instead of going through a helper
        # that reads the module-level `config`; that way `init` also works on a Config
        # that is not (or not yet) the active global one.
        # Despite the "gpu" in the names, they live on `self.device`, which may be 'cpu'.
        self.imagenet_mean_gpu_tensor = torch.tensor(imagenet_mean_array, device=self.device)
        self.imagenet_std_gpu_tensor = torch.tensor(imagenet_std_array, device=self.device)
        self.channelwise_imagenet_mean_gpu_tensor = self.imagenet_mean_gpu_tensor.view(3, 1, 1)
        self.channelwise_imagenet_std_gpu_tensor = self.imagenet_std_gpu_tensor.view(3, 1, 1)

        # Image-only pipeline: convert, resize, scale to float32, then normalise.
        self.image_transforms = T.Compose([
            T.ToImage(),
            T.Resize(self.image_dims, interpolation=T.InterpolationMode.BILINEAR),
            T.ToDtype(torch.float32, scale=True),
            T.Normalize(self.imagenet_mean_cpu_tensor, self.imagenet_std_cpu_tensor),
        ])

# Active configuration, read as a module-level global by the helpers in this notebook.
# Set to environment-relevant config before training/inference
config: Config = None

In [None]:
# Local environment: relative paths, device set to CPU.
local_config = Config(
    train_csv_filepath='data/train.csv',
    test_csv_filepath='data/test.csv',
    submission_filepath='data/submission.csv',
    images_root_folder='data/images/',
    image_masks_root_folder='data/masks/',
    training_output_folder='data_gen/training_output/',
    saved_weights_filepath='data_gen/training_output/model_weights.pth',
    device='cpu',
)

# Kaggle environment: inputs under /kaggle/input, outputs under /kaggle/working, device set to CUDA.
# Keyword arguments are listed in the same order as the Config fields.
kaggle_config = Config(
    train_csv_filepath='/kaggle/input/opencv-pytorch-segmentation-project-round2/train.csv',
    test_csv_filepath='/kaggle/input/opencv-pytorch-segmentation-project-round2/test.csv',
    submission_filepath='/kaggle/working/submission.csv',
    images_root_folder='/kaggle/input/opencv-pytorch-segmentation-project-round2/imgs/imgs/',
    image_masks_root_folder='/kaggle/input/opencv-pytorch-segmentation-project-round2/masks/masks/',
    training_output_folder='/kaggle/working/training_output/',
    saved_weights_filepath='/kaggle/working/training_output/model_weights.pth',
    device='cuda',
)

In [None]:
# ImageNet normalisation statistics, one float32 value per colour channel.
imagenet_mean_array = np.array((0.485, 0.456, 0.406), dtype=np.float32)
imagenet_std_array = np.array((0.229, 0.224, 0.225), dtype=np.float32)

# Colour lookup table for drawing masks: row i is the (R, G, B) colour of class i.
# 12 visually diverse colors for semantic segmentation classes
CLASS_COLORS = np.array(
    [
        (0, 0, 0),        # 0: Black
        (255, 0, 0),      # 1: Red
        (0, 255, 0),      # 2: Green
        (0, 0, 255),      # 3: Blue
        (255, 255, 0),    # 4: Yellow
        (255, 0, 255),    # 5: Magenta
        (0, 255, 255),    # 6: Cyan
        (255, 128, 0),    # 7: Orange
        (128, 0, 255),    # 8: Purple
        (0, 128, 255),    # 9: Light Blue
        (255, 128, 128),  # 10: Pink
        (128, 255, 128),  # 11: Light Green
    ],
    dtype=np.uint8,
)

def gpu_tensor(numpy_array):
    """Build a new tensor from `numpy_array` on the active config's device (`config.device`)."""
    target_device = config.device
    return torch.tensor(numpy_array, device=target_device)

def visualize_image(image_tensor):
    """Display a normalised, channel-first image tensor with matplotlib.

    The tensor must be on the same device as the config's channelwise
    mean/std "gpu" tensors, since the normalisation is undone before moving to CPU.
    """
    restored = denormalize(
        image_tensor,
        config.channelwise_imagenet_mean_gpu_tensor,
        config.channelwise_imagenet_std_gpu_tensor,
    )
    # Clip to the displayable [0, 1] range, then move channels last for imshow.
    pixels = restored.clamp(0, 1).permute(1, 2, 0).cpu().numpy()
    pixels = (pixels * 255).astype('uint8')

    plt.imshow(pixels)
    plt.axis('off')
    plt.show()
    plt.close()

def visualize_mask(mask_tensor):
    """Display a mask tensor of class indices as a colour image.

    Values are clipped into [0, config.num_classes - 1] and then used as row
    indices into CLASS_COLORS.
    """
    highest_class = config.num_classes - 1
    class_ids = mask_tensor.cpu().numpy()
    class_ids = class_ids.clip(0, highest_class).astype(np.int32)

    # Fancy indexing turns each class id into its RGB triple.
    plt.imshow(CLASS_COLORS[class_ids])
    plt.axis('off')
    plt.show()
    plt.close()

def normalize(tensor, mean, std):
    """Return `tensor` shifted by `mean` and then divided by `std`."""
    centered = tensor - mean
    return centered / std

def denormalize(tensor, mean, std):
    """Inverse of `normalize`: multiply `tensor` by `std`, then add `mean`."""
    rescaled = tensor * std
    return rescaled + mean

In [None]:
@dataclass
class ImageSegmentationDataset(TorchDataset):
    """Dataset yielding (transformed image, class-index mask) pairs by image id.

    Images are read from `{config.images_root_folder}{id}.jpg` and masks from
    `{config.image_masks_root_folder}{id}.png`, using the module-level `config`.

    Attributes are documented inline below.
    """
    # Ids used to build the image and mask file names.
    image_ids: np.ndarray
    # Callable applied to the PIL image only; the mask bypasses it.
    image_transforms: Callable

    def __len__(self):
        """Number of image ids in this dataset."""
        return len(self.image_ids)

    def __getitem__(self, idx):
        """Load, resize and convert the sample at position `idx`.

        Returns:
            (transformed_image, transformed_mask): the output of `image_transforms`
            on the RGB image, and the mask as an int64 tensor of the raw pixel values.
        """
        image_id = self.image_ids[idx]
        image_path = f'{config.images_root_folder}{image_id}.jpg'
        mask_path = f'{config.image_masks_root_folder}{image_id}.png'

        # Context managers close the underlying files promptly instead of leaving
        # the handles open until garbage collection.
        with Image.open(image_path) as image_file:
            # Force three channels so grayscale/CMYK/alpha inputs still match
            # the 3-channel normalisation statistics.
            image = image_file.convert('RGB')
        transformed_image = self.image_transforms(image)

        with Image.open(mask_path) as mask:
            # Resize mask using nearest-neighbor to preserve class indices
            if config.image_dims is not None:
                # image_dims is (height, width) as consumed by T.Resize, whereas
                # PIL's resize expects (width, height).
                target_height, target_width = config.image_dims
                mask = mask.resize((target_width, target_height), resample=Image.NEAREST)
            # Convert mask to tensor directly - no normalization
            mask_array = np.array(mask)

        transformed_mask = torch.from_numpy(mask_array).long()

        return transformed_image, transformed_mask

In [None]:
# test_model = deeplabv3_resnet50(weights=DeepLabV3_ResNet50_Weights.DEFAULT)
# torch_summary(test_model)

In [None]:
# test_model

In [None]:
# aux_classifier_children = test_model.aux_classifier.children()
# for layer in aux_classifier_children:
#     print(layer)

In [None]:
def create_deeplab_v3_model() -> nn.Module:
    """Build a DeepLabV3-ResNet50 with a frozen backbone and trainable heads.

    The network is created with `config.num_classes` outputs and an auxiliary
    classifier, and is moved to `config.device` before being returned.
    """
    # Initialize resnet50 backbone with imagenet weights when training.
    # Otherwise pass None to prevent downloads during inference (relevant for kaggle competitions)
    if config.training:
        backbone_weights = ResNet50_Weights.DEFAULT
    else:
        backbone_weights = None

    network = deeplabv3_resnet50(
        weights=None,
        weights_backbone=backbone_weights,
        num_classes=config.num_classes,
        aux_loss=True,
    )

    # Freeze resnet50 backbone
    network.backbone.requires_grad_(False)

    # Enable fine-tuning of classifier and aux-classifier
    for trainable_head in (network.classifier, network.aux_classifier):
        trainable_head.requires_grad_(True)

    return network.to(config.device)

In [None]:
def dice_coefficient_batch(pred_logits, true_masks, num_classes):
    """Count, per class, the overlap and the combined size of prediction and truth.

    Args:
        pred_logits: (B, C, H, W) raw model output logits
        true_masks: (B, H, W) ground truth class indices
        num_classes: number of classes

    Returns:
        intersection: (num_classes,) tensor, pixels where prediction and truth
            both equal the class
        union: (num_classes,) tensor holding predicted-pixel count plus
            true-pixel count for the class. Despite the name this is the sum
            of the two sizes (the Dice denominator), not the set union.
    """
    device = pred_logits.device
    predicted = torch.argmax(pred_logits, dim=1)  # (B, H, W)

    intersection = torch.zeros(num_classes, device=device)
    union = torch.zeros(num_classes, device=device)

    for class_id in range(num_classes):
        predicted_here = predicted.eq(class_id)
        labelled_here = true_masks.eq(class_id)
        overlap = torch.logical_and(predicted_here, labelled_here)
        intersection[class_id] = overlap.sum()
        union[class_id] = predicted_here.sum() + labelled_here.sum()

    return intersection, union

def compute_mean_dice(total_intersection, total_union):
    """Compute mean Dice score from accumulated intersection/union values.

    Args:
        total_intersection: per-class overlap counts.
        total_union: per-class denominators (predicted count + true count).

    Returns:
        Python float: mean of `2 * intersection / union` over the classes whose
        denominator is positive, or 0.0 when no class has a positive denominator.
    """
    valid = total_union > 0
    # With no class present the mean of an empty selection would be NaN, which
    # breaks `>` comparisons and JSON serialisation downstream; report 0.0 instead.
    if not valid.any():
        return 0.0
    # True division yields a float result even if integer counts are passed in.
    dice = (2 * total_intersection[valid]) / total_union[valid]
    return dice.mean().item()

def train_one_epoch(start_time, model, loader, optimizer, loss_function, scaler):
    """Run one training pass over `loader` with gradient accumulation and AMP.

    Args:
        start_time: `time.time()` reference used for the elapsed-time progress messages.
        model: network to train; its output is either a dict with an 'out' entry or the logits.
        loader: iterable of (images, masks) batches exposing `len()` and `.dataset`.
        optimizer: optimizer stepped once per accumulation group.
        loss_function: callable taking (logits, masks) and returning a scalar loss.
        scaler: gradient scaler providing `scale`, `step` and `update`.

    Returns:
        (epoch_loss, epoch_score): loss averaged over `len(loader.dataset)` samples
        and the mean Dice score accumulated over the epoch.
    """
    model.train()
    running_loss = 0.0

    # Running sums for incremental Dice computation
    total_intersection = torch.zeros(config.num_classes, device=config.device)
    total_union = torch.zeros(config.num_classes, device=config.device)

    # Ask the loader for its own batch count: it stays correct whatever batch size
    # or drop_last setting the loader was built with, and it decides when the final
    # (possibly partial) accumulation group is flushed below.
    num_batches = len(loader)
    accumulation_steps = config.gradient_accumulation_steps
    on_cuda = torch.device(config.device).type == 'cuda'

    optimizer.zero_grad()

    for batch_number, (x, y) in enumerate(loader):
        print(f't={time.time() - start_time:.2f}: Loading training batch {batch_number + 1}/{num_batches}')

        x = x.to(config.device, non_blocking=True)
        y = y.to(config.device, non_blocking=True)

        if batch_number == 0:
            # CUDA memory statistics are only meaningful (and only safe to query)
            # when the run is actually on a CUDA device.
            if on_cuda:
                allocated = torch.cuda.memory_allocated(config.device) / 1024**3
                reserved = torch.cuda.memory_reserved(config.device) / 1024**3
                print(f'Memory allocated={allocated:.2f} GiB, reserved={reserved:.2f} GiB')
            print('First image:')
            visualize_image(x[0])
            print(f'First mask (unique values: {torch.unique(y[0]).tolist()}):')
            visualize_mask(y[0])

        # Mixed precision forward pass
        with torch.amp.autocast('cuda', enabled=config.use_amp):
            output = model(x)
            # NOTE(review): only the 'out' entry feeds the loss; any auxiliary output
            # of the model is ignored here - confirm that this is intended.
            preds = output['out'] if isinstance(output, dict) else output
            loss = loss_function(preds, y)
            # Scale loss for gradient accumulation
            # (a final partial group is still divided by the full accumulation_steps)
            loss = loss / accumulation_steps

        # Mixed precision backward pass
        scaler.scale(loss).backward()

        # Update weights every accumulation_steps batches, and on the last batch
        if (batch_number + 1) % accumulation_steps == 0 or (batch_number + 1) == num_batches:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

        # Undo the accumulation scaling and weight by batch size for a per-sample average
        running_loss += loss.item() * accumulation_steps * x.size(0)

        # Compute batch Dice incrementally (no accumulation)
        with torch.no_grad():
            batch_inter, batch_union = dice_coefficient_batch(preds, y, config.num_classes)
            total_intersection += batch_inter
            total_union += batch_union

        # Clean up to free memory
        del output, preds, loss, x, y

    # Clear GPU cache at end of epoch
    if on_cuda:
        torch.cuda.empty_cache()

    epoch_loss = running_loss / len(loader.dataset)
    epoch_score = compute_mean_dice(total_intersection, total_union)

    return epoch_loss, epoch_score

@torch.no_grad()
def validate_one_epoch(start_time, model, loader, loss_function):
    """Evaluate `model` over `loader` without gradient tracking.

    Args:
        start_time: `time.time()` reference used for the elapsed-time progress messages.
        model: network to evaluate; its output is either a dict with an 'out' entry or the logits.
        loader: iterable of (images, masks) batches exposing `len()` and `.dataset`.
        loss_function: callable taking (logits, masks) and returning a scalar loss.

    Returns:
        (epoch_loss, epoch_score): loss averaged over `len(loader.dataset)` samples
        and the mean Dice score accumulated over the epoch.
    """
    model.eval()
    running_loss = 0.0

    # Running sums for incremental Dice computation
    total_intersection = torch.zeros(config.num_classes, device=config.device)
    total_union = torch.zeros(config.num_classes, device=config.device)

    # Take the batch count from the loader so the progress total matches what is
    # actually iterated, whatever batch size the loader was built with.
    num_batches = len(loader)
    for batch_number, (x, y) in enumerate(loader):
        print(f't={time.time() - start_time:.2f}: Loading validation batch {batch_number + 1}/{num_batches}')

        x = x.to(config.device, non_blocking=True)
        y = y.to(config.device, non_blocking=True)

        if batch_number == 0:
            print('First image:')
            visualize_image(x[0])
            print(f'First mask (unique values: {torch.unique(y[0]).tolist()}):')
            visualize_mask(y[0])

        # Mixed precision inference
        with torch.amp.autocast('cuda', enabled=config.use_amp):
            output = model(x)
            preds = output['out'] if isinstance(output, dict) else output
            loss = loss_function(preds, y)

        # Weight by batch size so the epoch value is a per-sample average
        running_loss += loss.item() * x.size(0)

        # Compute batch Dice incrementally (no accumulation)
        batch_inter, batch_union = dice_coefficient_batch(preds, y, config.num_classes)
        total_intersection += batch_inter
        total_union += batch_union

        # Clean up to free memory
        del output, preds, loss, x, y

    # Clear GPU cache at end of epoch
    if config.device == 'cuda':
        torch.cuda.empty_cache()

    epoch_loss = running_loss / len(loader.dataset)
    epoch_score = compute_mean_dice(total_intersection, total_union)

    return epoch_loss, epoch_score

In [None]:
def _plot_score_history(history):
    """Draw train and validation score per epoch on a new figure and return that figure."""
    train_scores = history['train_score']
    val_scores = history['val_score']
    epochs = list(range(1, len(train_scores) + 1))

    figure = plt.figure(figsize=(8, 5))

    plt.plot(epochs, train_scores, label='train_score', marker='o')
    plt.plot(epochs, val_scores, label='val_score', marker='o')

    plt.xlabel('Epoch')
    # The plotted metric is the mean Dice score, not accuracy.
    plt.ylabel('Mean Dice score')
    plt.title('Train vs Validation Score per Epoch')
    plt.legend()
    plt.grid(True)

    plt.tight_layout()

    return figure


def train():
    """Train the segmentation model using the module-level `config`.

    Splits the training ids 80/20, trains with early stopping on validation
    score, saves the best-score and best-loss weights plus `history.json` to
    `config.training_output_folder`, and logs metrics and artifacts to wandb.
    """
    config.init(training=True)

    # Seed before the model is created so that the randomly initialised layers,
    # and the shuffling that follows, are reproducible from run to run.
    torch.manual_seed(config.seed)

    start_time = time.time()
    print('t=0: Starting data prep and model loading')

    effective_batch_size = config.batch_size * config.gradient_accumulation_steps
    run = wandb.init(
        project='drone_image_segmentation',
        name=f'run={int(start_time)}',
        config={
            'batch_size': config.batch_size,
            'gradient_accumulation_steps': config.gradient_accumulation_steps,
            'effective_batch_size': effective_batch_size,
            'learning_rate': config.starting_learning_rate,
            'max_epochs': config.max_epochs,
            'seed': config.seed,
            'model': 'deeplabv3_resnet50',
            'optimizer': 'Adam',
            'image_size': config.image_dims,
            'use_amp': config.use_amp,
        },
    )

    train_df = pd.read_csv(config.train_csv_filepath)
    train_array = train_df['ImageID'].to_numpy()
    train_ids, val_ids = train_test_split(train_array, test_size=0.2, random_state=config.seed)

    model = create_deeplab_v3_model()

    wandb.watch(model, log='gradients', log_freq=100)

    train_dataset = ImageSegmentationDataset(train_ids, config.image_transforms)
    val_dataset = ImageSegmentationDataset(val_ids, config.image_transforms)

    def loader(ds, shuffle):
        return DataLoader(ds, shuffle=shuffle, batch_size=config.batch_size, num_workers=config.num_workers, pin_memory=config.pin_memory)

    train_loader = loader(train_dataset, shuffle=True)
    val_loader = loader(val_dataset, shuffle=False)

    # Use CrossEntropyLoss for multi-class segmentation with class indices
    loss_function = nn.CrossEntropyLoss()
    optimizer = TorchOptimizers.Adam(model.parameters(), lr=config.starting_learning_rate)

    # Mixed precision scaler
    scaler = torch.amp.GradScaler('cuda', enabled=config.use_amp)

    best_val_loss = float('inf')
    best_val_loss_epoch = -1
    best_val_score = float('-inf')
    best_val_score_epoch = -1

    history = dict(train_loss=[], val_loss=[], train_score=[], val_score=[], best_val_score_epoch=dict(), best_val_loss_epoch=dict())

    training_start_time = time.time()
    print(f't={training_start_time - start_time:.2f}: Starting training')
    print(f'Batch size: {config.batch_size}, Gradient accumulation: {config.gradient_accumulation_steps}, Effective batch size: {effective_batch_size}')
    print(f'Image size: {config.image_dims}, AMP enabled: {config.use_amp}')

    # os.path.join gives the same paths as before when the folder ends with a
    # separator, and still works when it does not.
    best_score_weights_path = os.path.join(config.training_output_folder, 'best_model_weights.pth')
    best_loss_weights_path = os.path.join(config.training_output_folder, 'best_loss_model_weights.pth')
    history_path = os.path.join(config.training_output_folder, 'history.json')

    epochs_since_best = 0

    for epoch in range(1, config.max_epochs + 1):
        epoch_start_time = time.time()
        print(f't={epoch_start_time - start_time:.2f}: Starting epoch {epoch}')
        train_loss, train_score = train_one_epoch(start_time, model, train_loader, optimizer, loss_function, scaler)
        val_loss, val_score = validate_one_epoch(start_time, model, val_loader, loss_function)

        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)
        history['train_score'].append(train_score)
        history['val_score'].append(val_score)

        print(f'================ Epoch {epoch:03d} stats ==================')
        print(f'train_loss: {train_loss:.4f}  val_loss: {val_loss:.4f}')
        print(f'train_score: {train_score:.4f}  val_score: {val_score:.4f}')
        print('===================================================')

        wandb.log(
            {
                'epoch': epoch,
                'train_loss': train_loss,
                'val_loss': val_loss,
                'train_score': train_score,
                'val_score': val_score,
            }
        )

        # Best-loss weights are tracked independently of early stopping
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_val_loss_epoch = epoch
            torch.save(model.state_dict(), best_loss_weights_path)

        # Early stopping is driven by the validation score only
        if val_score > best_val_score:
            best_val_score = val_score
            best_val_score_epoch = epoch
            epochs_since_best = 0
            torch.save(model.state_dict(), best_score_weights_path)
        else:
            epochs_since_best += 1
            if epochs_since_best >= config.patience:
                print(f'Early stopping at epoch {epoch}: val_score has not improved for {config.patience} epochs')
                break

    history['best_val_score_epoch']['epoch'] = best_val_score_epoch
    history['best_val_score_epoch']['val_score'] = best_val_score
    history['best_val_loss_epoch']['epoch'] = best_val_loss_epoch
    history['best_val_loss_epoch']['val_loss'] = best_val_loss

    print()
    print('==================== Results ======================')
    print(f'Best val score epoch: {best_val_score_epoch}')
    print(f'Best val score: {best_val_score:.4f}')
    print(f'Best val loss epoch: {best_val_loss_epoch}')
    print(f'Best val loss: {best_val_loss:.2f}')
    print('===================================================')
    print()

    run.summary['best_val_score'] = best_val_score
    run.summary['best_val_score_epoch'] = best_val_score_epoch
    run.summary['best_val_loss'] = best_val_loss
    run.summary['best_val_loss_epoch'] = best_val_loss_epoch

    score_figure = _plot_score_history(history)

    wandb.log({"score_curve": wandb.Image(score_figure)})

    plt.show()
    plt.close()

    with open(history_path, 'w') as json_file:
        json.dump(history, json_file, indent=4)

    wandb.save(best_score_weights_path)
    wandb.save(best_loss_weights_path)
    wandb.save(history_path)

    wandb.finish()

In [None]:
# from kaggle_secrets import UserSecretsClient
# user_secrets = UserSecretsClient()
# wandb_key = user_secrets.get_secret("wandb_key")
# !wandb login $wandb_key

In [None]:
# Select the Kaggle environment as the active module-level config, then start training.
# train() reads the module-level `config`, so the assignment has to come first.
config = kaggle_config
train()