In [None]:
import os
import sys
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from glob import glob
from PIL import Image
from tqdm import tqdm
import math
import seaborn as sns
import random
from collections import OrderedDict
from sklearn.model_selection import KFold

import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
from torch.optim import AdamW
import torch.nn.functional as F
import torchvision.transforms as transforms
import timm
from transformers import get_cosine_schedule_with_warmup
import albumentations as A
from albumentations.pytorch import ToTensorV2
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from torch.optim.lr_scheduler import ReduceLROnPlateau
import matplotlib.patches as patches
import torchvision.utils as vutils

In [None]:
# --- Paths -----------------------------------------------------------------
BASE_URL = '/kaggle/input/rsna-2024-lumbar-spine-degenerative-classification/'
OUTPUT_DIR = 'rsna24-results'
IMAGE_URL = '/kaggle/input/crop-and-convert-to-png-axial/Converted_images/'

# --- Problem definition ----------------------------------------------------
LEVELS = ['L1/L2', 'L2/L3', 'L3/L4', 'L4/L5', 'L5/S1']
N_LABELS = 25
N_CLASSES = N_LABELS * 3  # 3 severity classes per label

# --- Run mode / hardware ---------------------------------------------------
SEED = 8620
DEBUG = False  # True -> fewer folds/epochs, resnet18 instead of densenet161, no augmentation
if DEBUG:
    import unittest
device = 'cuda' if torch.cuda.is_available() else 'cpu'
N_WORKERS = os.cpu_count()

# --- Training schedule -----------------------------------------------------
N_FOLDS = 2 if DEBUG else 4
EPOCHS = 20 if DEBUG else 30
MODEL_NAME = 'resnet18' if DEBUG else 'densenet161.tv_in1k'

GRAD_ACCUMULATION = 4
TARGET_BATCH_SIZE = 16  # effective batch size after gradient accumulation
BATCH_SIZE = TARGET_BATCH_SIZE // GRAD_ACCUMULATION  # per-step batch size
MAX_GRAD_NORM = 1
EARLY_STOPPING_EPOCH = 6

# Learning rate scaled linearly with the effective batch size (2e-4 at batch 32).
LEARNING_RATE = 2e-4 * TARGET_BATCH_SIZE / 32
WEIGHT_DECAY = 1e-2

# --- Augmentation / precision ----------------------------------------------
AUGMENTATION = True
AUGMENTATION_PROBABILITY = 0.75
USE_AUTOMATIC_MIXED_PRECISION = True  # intended for T4 or Ampere-and-newer GPUs; set False otherwise

# Image shapes as (height, width), derived from an image that was (620, 620).
SAGITTAL_IMAGE_SHAPE = (490, 275)
AXIAL_IMAGE_SHAPE = (310, 250)

In [None]:
os.makedirs(name=OUTPUT_DIR, exist_ok=True)  # results folder; no-op if it already exists

In [None]:
def set_random_seed(seed: int = 8620, deterministic: bool = False):
    """Seed the Python, NumPy and PyTorch RNGs and configure cuDNN.

    Args:
        seed: value passed to every RNG seeded below.
        deterministic: when True, request deterministic cuDNN kernels and turn
            off the cuDNN autotuner. When False (default), let cuDNN benchmark
            and pick the fastest kernels, which may be non-deterministic.

    Note:
        Setting ``PYTHONHASHSEED`` here only affects Python interpreters started
        after this call (child processes). The running interpreter's hash seed
        is fixed at startup.
    """
    random.seed(seed)
    np.random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # every visible GPU, not only the current one
    # The cuDNN autotuner may select non-deterministic kernels, so it must be
    # disabled whenever determinism is requested.
    torch.backends.cudnn.benchmark = not deterministic
    torch.backends.cudnn.deterministic = deterministic  # type: ignore

set_random_seed(seed=SEED)  # seed once, before any data loading or model construction

In [None]:
train_df = pd.read_csv(f'{BASE_URL}train.csv')
train_descriptions = pd.read_csv(f'{BASE_URL}train_series_descriptions.csv')
# Attach each annotated point's series_description (e.g. 'Sagittal T1') via (study_id, series_id).
coordinates_df = pd.read_csv(f'{BASE_URL}train_label_coordinates.csv').merge(
    train_descriptions, on=['study_id', 'series_id']
)
coordinates_df.head()

In [None]:
# One coordinates CSV per series type from the crop-and-convert dataset
# (presumably matching the cropped PNGs under IMAGE_URL — confirm).
sagittal_t2_coordinates = pd.read_csv(
    '/kaggle/input/crop-and-convert-to-png-axial/Updated_coordinates/new_sagittal_t2_coordinates.csv'
)
sagittal_t1_coordinates = pd.read_csv(
    '/kaggle/input/crop-and-convert-to-png-axial/Updated_coordinates/new_sagittal_t1_coordinates.csv'
)
axial_t2_coordinates = pd.read_csv(
    '/kaggle/input/crop-and-convert-to-png-axial/Updated_coordinates/new_axial_t2_coordinates.csv'
)

### Sorting the coordinates dataframes

In [None]:
# Fix the row order within each study. The three series deliberately use
# different key orders and directions.
sagittal_t2_coordinates = sagittal_t2_coordinates.sort_values(by=['study_id', 'level', 'condition'])
sagittal_t1_coordinates = sagittal_t1_coordinates.sort_values(
    by=['study_id', 'condition', 'level'],
    ascending=[True, False, True],
)
axial_t2_coordinates = axial_t2_coordinates.sort_values(
    by=['study_id', 'level', 'condition'],
    ascending=[True, True, False],
)

In [None]:
condition_encoding = {'Normal/Mild': 0, 'Moderate': 1, 'Severe': 2}

def map_condition(condition):
    """Encode a severity string as its class index.

    Unknown or missing values (e.g. NaN for an unlabelled condition) become
    -100, a sentinel that must not be used as a training target (presumably
    chosen to match the default ``ignore_index`` of ``nn.CrossEntropyLoss`` —
    confirm in the loss).

    Values that are already encoded (0, 1, 2 or -100) are returned unchanged.
    Without this, re-running the encoding cell on an already-encoded frame
    would map every label to -100.
    """
    already_encoded = (
        isinstance(condition, (int, np.integer))
        and not isinstance(condition, bool)
        and condition in (*condition_encoding.values(), -100)
    )
    if already_encoded:
        return int(condition)
    return condition_encoding.get(condition, -100)

# Encode every label column, i.e. every column except the first (study_id).
encoded_labels = train_df.iloc[:, 1:].map(map_condition)
train_df.iloc[:, 1:] = encoded_labels
train_df.head()

## Keeping only study ids that have all the coordinates (every level in all three series, both sides in Sagittal T1 and Axial T2)

In [None]:
def get_side(instance_number, num_images):
    """Label a slice 'Right' if it lies in the lower half of the series, else 'Left'.

    NOTE(review): assumes slices are ordered right-to-left — confirm.
    """
    midpoint = num_images // 2
    if instance_number < midpoint:
        return 'Right'
    return 'Left'

def get_num_images(study_id, series_id):
    """Count the ``*dcm`` files of one series under ``BASE_URL/train_images``."""
    pattern = f'{BASE_URL}train_images/{study_id}/{series_id}/*dcm'
    return len(glob(pattern))

def filter_study_ids(sagittal_t2_coordinates, sagittal_t1_coordinates, axial_t2_coordinates,
                     required_levels=('L1/L2', 'L2/L3', 'L3/L4', 'L4/L5', 'L5/S1')):
    """Return the study ids annotated at exactly ``required_levels`` in all three frames.

    A study qualifies only if, in each coordinate frame, the set of ``level``
    values across its rows equals ``required_levels``. A missing level or an
    extra level both disqualify it.

    Args:
        sagittal_t2_coordinates, sagittal_t1_coordinates, axial_t2_coordinates:
            DataFrames with at least ``study_id`` and ``level`` columns.
        required_levels: levels every study must cover. Defaults to the five
            lumbar levels used throughout this notebook.

    Returns:
        Set of study ids that are complete in all three frames.
    """
    required = set(required_levels)

    def _complete_studies(coordinates):
        # Keep the study groups whose level set matches exactly.
        complete = coordinates.groupby('study_id').filter(lambda g: set(g['level']) == required)
        return set(complete['study_id'].unique())

    return (
        _complete_studies(sagittal_t2_coordinates)
        & _complete_studies(sagittal_t1_coordinates)
        & _complete_studies(axial_t2_coordinates)
    )

def filter_sides(sagittal_t1_coordinates, axial_t2_coordinates, valid_study_ids):
    """Keep the studies whose Sagittal T1 and Axial T2 annotations cover both sides.

    For each study, the check runs over every level in ``LEVELS`` that has rows
    in both frames:

    * Sagittal T1 sides come from slice position: ``get_side`` labels an
      ``instance_number`` below half the series' DICOM count 'Right', else 'Left'.
    * Axial T2 sides come from the first word of ``condition``
      (e.g. 'Left ...' -> 'Left').

    Rows whose ``instance_number`` is not below the number of DICOM files in
    their series are ignored. A study is kept only if both side sets are
    exactly {'Right', 'Left'}.

    Args:
        sagittal_t1_coordinates: DataFrame with study_id, series_id, level,
            instance_number and condition columns.
        axial_t2_coordinates: same layout, for the Axial T2 series.
        valid_study_ids: iterable of study ids to check.

    Returns:
        Set of study ids that pass the side check.
    """
    final_study_ids = set()
    # get_num_images globs the filesystem. Levels of a study usually share one
    # series, so memoise the count per (study_id, series_id) instead of
    # re-globbing once per level and modality.
    num_images_cache = {}

    def _num_images(study_id, series_id):
        key = (study_id, series_id)
        if key not in num_images_cache:
            num_images_cache[key] = get_num_images(study_id, series_id)
        return num_images_cache[key]

    for study_id in valid_study_ids:
        sagittal_t1_study = sagittal_t1_coordinates[sagittal_t1_coordinates['study_id'] == study_id]
        axial_t2_study = axial_t2_coordinates[axial_t2_coordinates['study_id'] == study_id]

        sagittal_t1_sides = set()
        axial_t2_sides = set()

        for level in LEVELS:
            sagittal_t1_level = sagittal_t1_study[sagittal_t1_study['level'] == level]
            axial_t2_level = axial_t2_study[axial_t2_study['level'] == level]

            if sagittal_t1_level.empty or axial_t2_level.empty:
                continue

            # NOTE(review): if DICOM instance numbers are 1-based, `< num_images`
            # drops the last slice — confirm the intended bound.
            num_images_sagittal_t1 = _num_images(study_id, sagittal_t1_level.iloc[0]['series_id'])
            instance_numbers = sagittal_t1_level['instance_number']
            instance_numbers = instance_numbers[instance_numbers < num_images_sagittal_t1]
            sagittal_t1_sides.update(instance_numbers.map(lambda n: get_side(n, num_images_sagittal_t1)))

            num_images_axial_t2 = _num_images(study_id, axial_t2_level.iloc[0]['series_id'])
            axial_t2_level = axial_t2_level[axial_t2_level['instance_number'] < num_images_axial_t2]
            axial_t2_sides.update(axial_t2_level['condition'].apply(lambda cond: cond.split()[0]))

        if sagittal_t1_sides == {'Right', 'Left'} and axial_t2_sides == {'Right', 'Left'}:
            final_study_ids.add(study_id)

    return final_study_ids

valid_study_ids = filter_study_ids(sagittal_t2_coordinates, sagittal_t1_coordinates, axial_t2_coordinates)
final_study_ids = filter_sides(sagittal_t1_coordinates, axial_t2_coordinates, valid_study_ids)

# Labels restricted to studies with complete coordinates (all levels, both sides).
has_complete_annotations = train_df['study_id'].isin(final_study_ids)
all_labels_df = train_df.loc[has_complete_annotations]
all_labels_df


# Dataloader

In [None]:
class MultiTaskDataset(Dataset):
    """Per-study dataset combining Sagittal T2, Sagittal T1 and Axial T2 images.

    Everything is preloaded in ``__init__`` into three caches keyed by study_id:

    * ``self.images[study_id]``:
      - ``'Sagittal T2'`` -> one image tensor;
      - ``'Sagittal T1'`` -> list of 2 tensors, ordered as ``self.sides`` (right, left);
      - ``'Axial T2'`` -> list of 5 tensors, ordered as ``self.levels``.
      Missing images are stored as zero placeholders.
    * ``self.keypoints[study_id]``: matching nested dicts of ``[x, y]`` lists.
    * ``self.masks[study_id]``: 1 where the image was found, 0 for placeholders.

    ``__getitem__`` reads these caches without modifying them. It returns a dict
    with:
      - a nested tensor of the three image groups;
      - the int64 labels (every ``label_df`` column after the first);
      - the keypoints;
      - an 8-element float mask (1 Sagittal T2 + 2 Sagittal T1 + 5 Axial T2);
      - the study id.

    Args:
        label_df: one row per study. ``study_id`` is expected to be the first
            column, since labels are taken from ``iloc[1:]``.
        series_descriptions: rows with study_id, series_id and series_description.
        sagittal_t2_coordinates, sagittal_t1_coordinates, axial_t2_coordinates:
            keypoint frames with study_id, level, condition, x and y (plus
            series_id for Sagittal T1).
        phase: stored as ``self.phase``; not otherwise used in this class.
        transform: albumentations-style callable taking ``image=`` and
            ``keypoints=``; see ``augment_image_and_keypoints``.
    """
    def __init__(
        self,
        label_df = train_df,
        series_descriptions = train_descriptions,
        sagittal_t2_coordinates = sagittal_t2_coordinates,
        sagittal_t1_coordinates = sagittal_t1_coordinates,
        axial_t2_coordinates = axial_t2_coordinates,
        phase = 'train',
        transform = None
    ):
        self.label_df = label_df
        self.series_descriptions = series_descriptions
        self.sagittal_t2_coordinates = sagittal_t2_coordinates
        self.sagittal_t1_coordinates = sagittal_t1_coordinates
        self.axial_t2_coordinates = axial_t2_coordinates
        self.transform = transform
        self.phase = phase
        self.levels = ['L1/L2', 'L2/L3', 'L3/L4', 'L4/L5', 'L5/S1']
        self.sides = ['right', 'left']

        # Per-study caches, filled by preload_images().
        self.images = {}
        self.keypoints = {}
        self.masks = {}

        self.PILToTensor = transforms.Compose([transforms.PILToTensor()])
        # Preload images
        self.preload_images()

    def preload_images(self):
        """Load images, keypoints and masks for every study in label_df not yet cached."""
        study_ids = self.label_df['study_id']
        for idx in range(len(self.label_df)):
            study_id = study_ids.iloc[idx]
            if study_id not in self.images:
                self.images[study_id] = {}
                self.keypoints[study_id] = {}
                self.masks[study_id] = {}

                # loading Sagittal T2
                self._load_sagittal_t2(study_id)

                # loading Sagittal T1
                self._load_sagittal_t1(study_id)

                # loading Axial T2
                self._load_axial_t2(study_id)

    def _load_sagittal_t2(self, study_id):
        """Load the Sagittal T2/STIR image and its per-level keypoints for ``study_id``.

        On success, stores the image tensor, ``{level: [[x, y], ...]}`` keypoints
        and mask 1. If the study has no such series, or its PNG cannot be loaded,
        stores a zero placeholder image, empty keypoints and mask 0 instead.
        """
        description = 'Sagittal T2/STIR'
        series_id_df = self.series_descriptions.query('@study_id == study_id and @description == series_description')
        study_id_coordinates = self.sagittal_t2_coordinates.query('@study_id == study_id')
        self.keypoints[study_id]['Sagittal T2'] = {level: [] for level in self.levels}
        self.masks[study_id]['Sagittal T2'] = 0

        if not series_id_df.empty:
            image_paths = glob(f'{IMAGE_URL}{study_id}/Sagittal_T2/*.png')
            try:
                # A single PNG gives (3, 1, H, W) after stacking; squeeze drops the 1.
                self.images[study_id]['Sagittal T2'] = torch.stack([self.PILToTensor(Image.open(path).convert('RGB')) for path in image_paths], dim=1).squeeze()
            except Exception as e:
                print(f'Study id: {study_id} Sagittal T2 error while loading image: {str(e)}')
                # Treat a failed load like a missing series. This includes an empty
                # glob, which makes torch.stack raise. Otherwise the image would stay
                # unset (KeyError in __getitem__) while the mask became 1.
                self.images[study_id]['Sagittal T2'] = torch.zeros((3, SAGITTAL_IMAGE_SHAPE[0], SAGITTAL_IMAGE_SHAPE[1]))
                return
            coordinates_dict = {}
            for level in self.levels:
                level_coords = study_id_coordinates.query('level == @level')
                if level_coords.empty:
                    coordinates_dict[level] = []
                else:
                    coordinates_dict[level] = level_coords[['x', 'y']].astype(float).values.tolist()

            self.keypoints[study_id]['Sagittal T2'] = coordinates_dict
            self.masks[study_id]['Sagittal T2'] = 1
        else:
            self.images[study_id]['Sagittal T2'] = torch.zeros((3, SAGITTAL_IMAGE_SHAPE[0], SAGITTAL_IMAGE_SHAPE[1]))
            self.keypoints[study_id]['Sagittal T2'] = {level: [] for level in self.levels}
            self.masks[study_id]['Sagittal T2'] = 0

    def _load_sagittal_t1(self, study_id):
        """Load one Sagittal T1 image per side (``right.png`` / ``left.png``) plus keypoints.

        Keypoints for a side come from rows whose condition starts with the
        capitalised side name. A missing series or PNG yields a zero
        placeholder, empty keypoints and mask 0 for that side.
        """
        description = 'Sagittal T1'
        series_ids = self.series_descriptions.query('@study_id == study_id and @description == series_description')
        self.images[study_id]['Sagittal T1'] = []
        self.keypoints[study_id]['Sagittal T1'] = {side: {level: [] for level in self.levels} for side in self.sides}
        self.masks[study_id]['Sagittal T1'] = {side: {level: 0 for level in self.levels} for side in self.sides}
        for side in self.sides:
            if not series_ids.empty:
                series_id = series_ids['series_id'].iloc[0]
                series_id_coordinates = self.sagittal_t1_coordinates.query('@study_id == study_id and @series_id == series_id and condition.str.startswith(@side.capitalize())')
                image_paths = glob(f'{IMAGE_URL}{study_id}/Sagittal_T1/{side.lower()}.png')
                if image_paths:
                    self.images[study_id]['Sagittal T1'].append(torch.stack([self.PILToTensor(Image.open(path).convert('RGB')) for path in image_paths], dim=1).squeeze())
                    side_keypoints = {level: series_id_coordinates.query('level == @level')[['x', 'y']].astype(float).values.tolist() for level in self.levels}
                    self.keypoints[study_id]['Sagittal T1'][side] = side_keypoints
                    self.masks[study_id]['Sagittal T1'][side] = 1
                else:
                    self.images[study_id]['Sagittal T1'].append(torch.zeros((3, SAGITTAL_IMAGE_SHAPE[0], SAGITTAL_IMAGE_SHAPE[1])))
                    self.keypoints[study_id]['Sagittal T1'][side] = {level: [] for level in self.levels}
                    self.masks[study_id]['Sagittal T1'][side] = 0
            else:
                self.images[study_id]['Sagittal T1'].append(torch.zeros((3, SAGITTAL_IMAGE_SHAPE[0], SAGITTAL_IMAGE_SHAPE[1])))
                self.keypoints[study_id]['Sagittal T1'][side] = {level: [] for level in self.levels}
                self.masks[study_id]['Sagittal T1'][side] = 0

    def _load_axial_t2(self, study_id):
        """Load one Axial T2 image per level (``L1-L2.png`` ...) with right/left keypoints.

        A missing series or PNG yields a zero placeholder, empty keypoints and
        mask 0 for that level.
        """
        description = 'Axial T2'
        series_id = self.series_descriptions.query('@study_id == study_id and @description == series_description')
        self.images[study_id]['Axial T2'] = []
        self.keypoints[study_id]['Axial T2'] = {level: {side: [] for side in self.sides} for level in self.levels}
        self.masks[study_id]['Axial T2'] = {level: {side: 0 for side in self.sides} for level in self.levels}
        for level in self.levels:
            if not series_id.empty:
                level_path = level.replace('/', '-')
                image_paths = glob(f'{IMAGE_URL}{study_id}/Axial_T2/{level_path}.png')
                if image_paths:
                    self.images[study_id]['Axial T2'].append(torch.stack([self.PILToTensor(Image.open(path).convert('RGB')) for path in image_paths], dim=1).squeeze())
                    level_keypoints = {}
                    for side in self.sides:
                        query_result = self.axial_t2_coordinates.query('@study_id == study_id and level == @level and condition.str.startswith(@side.capitalize())')
                        if not query_result.empty:
                            level_keypoints[side] = query_result[['x', 'y']].astype(float).values.tolist()
                        else:
                            level_keypoints[side] = []

                    self.keypoints[study_id]['Axial T2'][level] = level_keypoints
                    self.masks[study_id]['Axial T2'][level] = 1
                else:
                    self.images[study_id]['Axial T2'].append(torch.zeros((3, AXIAL_IMAGE_SHAPE[0], AXIAL_IMAGE_SHAPE[1])))
                    self.keypoints[study_id]['Axial T2'][level] = {side: [] for side in self.sides}
                    self.masks[study_id]['Axial T2'][level] = 0
            else:
                self.images[study_id]['Axial T2'].append(torch.zeros((3, AXIAL_IMAGE_SHAPE[0], AXIAL_IMAGE_SHAPE[1])))
                self.keypoints[study_id]['Axial T2'][level] = {side: [] for side in self.sides}
                self.masks[study_id]['Axial T2'][level] = 0

    def extend_dataset(self, new_label_df):
        """Replace label_df and preload any studies that are not cached yet."""
        self.label_df = new_label_df
        self.preload_images()  # This will only load new images

    def __len__(self):
        return len(self.label_df)

    def augment_image_and_keypoints(self, image_tensor, keypoints):
        """Apply ``self.transform`` to one image and its keypoint dict.

        Args:
            image_tensor: (C, H, W) tensor; permuted to an (H, W, C) numpy array
                for the transform. Must not be all zeros.
            keypoints: dict mapping a label (level or side) to a list of [x, y].
                The points are flattened for the transform and regrouped in the
                same order afterwards.

        Returns:
            (augmented image tensor, keypoints dict). Points that fall outside
            the augmented image are dropped.
        """
        # Convert image tensor to numpy (H, W, C)
        image_np = image_tensor.permute(1, 2, 0).cpu().numpy()
        assert np.sum(image_np) > 0, "Input image is all zeros"

        # Flatten keypoints for augmentation
        flat_keypoints = [point for level_points in keypoints.values() for point in level_points] # they should be side points for axial

        # Apply augmentations to both image and keypoints
        augmented = self.transform(image=image_np, keypoints=flat_keypoints)
        augmented_image_tensor = augmented['image']

        # Perform checks on the augmented image
        assert augmented_image_tensor.numel() > 0, "Augmented image is empty"
        assert torch.any(augmented_image_tensor != 0), "Augmented image is all zeros"
        assert augmented_image_tensor.dtype == torch.float32, "Augmented image is not float32"
        assert torch.all(augmented_image_tensor >= -3) and torch.all(augmented_image_tensor <= 3), \
            f"Augmented image values out of expected range: min={augmented_image_tensor.min()}, max={augmented_image_tensor.max()}"
        assert augmented_image_tensor.shape[0] == 3 and augmented_image_tensor.shape[1] > 0 and augmented_image_tensor.shape[2] > 0, \
            f"Unexpected augmented image shape: {augmented_image_tensor.shape}"

        # Reconstruct keypoints dictionary: consume the flat list in the original order.
        augmented_keypoints = augmented['keypoints']
        reconstructed_keypoints = {}
        idx = 0
        for level, points in keypoints.items():
            level_points = augmented_keypoints[idx:idx+len(points)]
            # Check if keypoints are still in the image
            valid_points = [point for point in level_points if 0 <= point[0] < augmented_image_tensor.shape[2] and 0 <= point[1] < augmented_image_tensor.shape[1]]
            reconstructed_keypoints[level] = valid_points if valid_points else []
            idx += len(points)

        # Perform checks on the reconstructed keypoints
        assert len(reconstructed_keypoints) == len(keypoints), "Number of keypoint levels changed after augmentation"

        return augmented_image_tensor, reconstructed_keypoints

    def __getitem__(self, idx):
        """Return one study's images, labels, keypoints, mask and id (see class docstring).

        Only images whose mask is 1 are passed through ``self.transform``; zero
        placeholders are returned as-is. NOTE(review): the float32 assertion
        below means a transform that produces float tensors is effectively
        required.
        """
        labels_row = self.label_df.iloc[idx]
        study_id = labels_row['study_id']
        labels = labels_row.iloc[1:].values.astype(np.int64)

        # The Sagittal T1 / Axial T2 containers are shallow-copied because the
        # augmentation loop below assigns into them by index/key. Writing into the
        # cached list/dict would replace the preloaded images and keypoints with
        # augmented, normalised ones, and every later access (next epoch,
        # validation pass) would augment and normalise them again.
        sagittal_t2_image = self.images[study_id]['Sagittal T2']
        sagittal_t1_images = list(self.images[study_id]['Sagittal T1'])
        axial_t2_images = list(self.images[study_id]['Axial T2'])

        sagittal_t2_keypoints = self.keypoints[study_id]['Sagittal T2']
        sagittal_t1_keypoints = dict(self.keypoints[study_id]['Sagittal T1'])
        axial_t2_keypoints = dict(self.keypoints[study_id]['Axial T2'])

        # Mask order: Sagittal T2, Sagittal T1 (right, left), Axial T2 (5 levels).
        mask = [
            self.masks[study_id]['Sagittal T2'],
            *self.masks[study_id]['Sagittal T1'].values(),
            *self.masks[study_id]['Axial T2'].values()
        ]
        mask = torch.tensor(mask, dtype=torch.float32)

        # applying transformations
        if self.transform is not None:
            if self.masks[study_id]['Sagittal T2'] == 1:
                sagittal_t2_image, sagittal_t2_keypoints = self.augment_image_and_keypoints(sagittal_t2_image, sagittal_t2_keypoints)

            for i, side in enumerate(self.sides):
                if self.masks[study_id]['Sagittal T1'][side] == 1:
                    sagittal_t1_images[i], sagittal_t1_keypoints[side] = self.augment_image_and_keypoints(sagittal_t1_images[i], sagittal_t1_keypoints[side])

            for i, level in enumerate(self.levels):
                if self.masks[study_id]['Axial T2'][level] == 1:
                    axial_t2_images[i], axial_t2_keypoints[level] = self.augment_image_and_keypoints(axial_t2_images[i], axial_t2_keypoints[level])

        x = torch.nested.nested_tensor([
            sagittal_t2_image.unsqueeze(0),
            torch.stack(sagittal_t1_images),
            torch.stack(axial_t2_images)
        ])

        assert x.numel() > 0, "Transformed image is empty"
        assert any(torch.any(tensor != 0) for tensor in x.unbind()), "Transformed image is all zeros"
        assert x.dtype == torch.float32, "Transformed image is not float32"
        assert len([_ for _ in x.unbind()]) == 3  # Nested tensor has 3 elements: sagittal T2, sagittal T1, axial T2
        assert x[0].shape == (1, 3, SAGITTAL_IMAGE_SHAPE[0], SAGITTAL_IMAGE_SHAPE[1])  # Sagittal T2
        assert x[1].shape == (2, 3, SAGITTAL_IMAGE_SHAPE[0], SAGITTAL_IMAGE_SHAPE[1])  # Sagittal T1 (2 images)
        assert x[2].shape == (5, 3, AXIAL_IMAGE_SHAPE[0], AXIAL_IMAGE_SHAPE[1])  # Axial T2 (5 images)

        keypoints = {
            'Sagittal T2': sagittal_t2_keypoints,
            'Sagittal T1': sagittal_t1_keypoints,
            'Axial T2': axial_t2_keypoints
        }

        return {
            'images': x,
            'labels': labels,
            'keypoints': keypoints,
            'mask': mask,
            'study_id': study_id
        }
    
def custom_collate(batch):
    """Collate MultiTaskDataset samples into one batch dict.

    Returns a dict with:
      - 'images': nested tensor whose k-th component is the k-th image group
        stacked over the batch, i.e. shape (B, *group_shape);
      - 'labels': tensor of stacked per-sample label vectors;
      - 'mask': tensor of stacked per-sample masks;
      - 'keypoints', 'study_id': plain per-sample lists.
    """
    sample_images = [sample['images'] for sample in batch]
    n_groups = len(sample_images[0].unbind())
    # Image groups have different shapes, so stack each group across the batch
    # separately and wrap the results in a nested tensor.
    per_group = [
        torch.stack([images[group] for images in sample_images])
        for group in range(n_groups)
    ]

    return {
        'images': torch.nested.nested_tensor(per_group),
        'labels': torch.stack([torch.tensor(sample['labels']) for sample in batch]),
        'keypoints': [sample['keypoints'] for sample in batch],
        'mask': torch.stack([sample['mask'] for sample in batch]),
        'study_id': [sample['study_id'] for sample in batch],
    }

## Data Augmentation
Adapted from a past Kaggle competition winner's solution to a different problem; see, for instance, [this notebook](https://www.kaggle.com/code/haqishen/1st-place-soluiton-code-small-ver).

In [None]:
# Training pipeline:
#   - photometric jitter;
#   - one blur/noise op;
#   - one distortion;
#   - a small shift/scale/rotate with constant border;
#   - cutout;
#   - ImageNet normalisation.
# Keypoints are (x, y) pixels. remove_invisible=False keeps points that leave
# the frame, so the keypoint count never changes; MultiTaskDataset drops
# out-of-frame points afterwards.
transforms_train = A.Compose(
    [
        A.RandomBrightnessContrast(
            brightness_limit=(-0.2, 0.2),
            contrast_limit=(-0.2, 0.2),
            p=AUGMENTATION_PROBABILITY,
        ),
        A.OneOf(
            [
                A.MotionBlur(blur_limit=5),
                A.MedianBlur(blur_limit=5),
                A.GaussianBlur(blur_limit=5),
                A.GaussNoise(var_limit=(5.0, 30.0)),
            ],
            p=AUGMENTATION_PROBABILITY,
        ),
        A.OneOf(
            [
                A.OpticalDistortion(distort_limit=0.1),
                A.GridDistortion(num_steps=5, distort_limit=0.1),
                A.ElasticTransform(alpha=1),
            ],
            p=AUGMENTATION_PROBABILITY,
        ),
        A.ShiftScaleRotate(
            shift_limit=0.1, scale_limit=0.1, rotate_limit=15, border_mode=0,
            p=AUGMENTATION_PROBABILITY,
        ),
        A.CoarseDropout(
            max_holes=16, max_height=16, max_width=16,
            min_holes=1, min_height=8, min_width=8,
            p=AUGMENTATION_PROBABILITY,
        ),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # ImageNet stats
        ToTensorV2(),
    ],
    keypoint_params=A.KeypointParams(format='xy', remove_invisible=False),
)

# Validation pipeline: normalisation and tensor conversion only.
transforms_validation = A.Compose(
    [
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # ImageNet stats
        ToTensorV2(),
    ],
    keypoint_params=A.KeypointParams(format='xy', remove_invisible=False),
)

# Debug runs, and runs with AUGMENTATION disabled, train on un-augmented images.
if DEBUG or not AUGMENTATION:
    transforms_train = transforms_validation

In [None]:
if DEBUG:
    def test_transformations():
        """Smoke-test transforms_train and transforms_validation on a random 224x224 RGB image."""

        def check_image_values(image):
            # Value checks shared by both pipelines.
            assert image.numel() > 0, "Transformed image is empty"
            assert torch.any(image != 0), "Transformed image is all zeros"
            assert image.dtype == torch.float32, "Transformed image is not float32"
            assert torch.all(image >= -3) and torch.all(image <= 3), f"Transformed image values out of expected range: min={image.min()}, max={image.max()}"

        def check_image_shape_and_type(image):
            assert image.shape == (3, 224, 224), f"Unexpected shape: {image.shape}"
            assert isinstance(image, torch.Tensor), "Transformed image is not a torch.Tensor"

        image_in = np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8)
        keypoints_in = [(100, 100), (150, 150)]

        print("Testing training transformations:")
        for _ in range(5):  # several draws, since the training pipeline is random
            result = transforms_train(image=image_in, keypoints=keypoints_in)
            image_out, keypoints_out = result['image'], result['keypoints']

            check_image_values(image_out)

            assert keypoints_out is not None, "Transformation resulted in no keypoints"
            assert len(keypoints_out) == len(keypoints_in), "Transformation changed number of keypoints"
            assert all(isinstance(kp, tuple) and len(kp) == 2 for kp in keypoints_out), "Invalid keypoint format"
            assert all(0 <= kp[0] < 224 and 0 <= kp[1] < 224 for kp in keypoints_out), "Keypoints out of image bounds"

            check_image_shape_and_type(image_out)
            print("  Passed")

        print("\nTesting validation transformations:")
        result = transforms_validation(image=image_in, keypoints=keypoints_in)
        image_out, keypoints_out = result['image'], result['keypoints']

        check_image_values(image_out)
        check_image_shape_and_type(image_out)
        # Validation must leave keypoints untouched.
        assert keypoints_out == keypoints_in, "Validation transformation modified keypoints"
        print("  Passed")

        print("\nAll transformation tests passed!")

    # Run the test
    test_transformations()

In [None]:
if DEBUG:
    def test_multitask_dataset():
        """Smoke-test MultiTaskDataset and custom_collate on 10 studies.

        Processes only the first batch:
          - prints batch and shape information;
          - draws the first sample's Sagittal T2, right Sagittal T1 and L1/L2
            Axial T2 images with their keypoints;
          - prints that sample's keypoints.
        """
        # Create a small subset of your data for testing
        test_df = all_labels_df.head(10)  # Use first 10 rows for testing

        test_dataset = MultiTaskDataset(
            label_df=test_df,
            series_descriptions=train_descriptions,
            sagittal_t2_coordinates=sagittal_t2_coordinates,
            sagittal_t1_coordinates=sagittal_t1_coordinates,
            axial_t2_coordinates=axial_t2_coordinates,
            phase='train',
            transform=transforms_train
        )

        test_dataloader = DataLoader(
            test_dataset,
            batch_size=2,
            shuffle=False,
            collate_fn=custom_collate,
            num_workers=0  # Use 0 for easier debugging
        )

        def normalize_for_plot(img):
            """Min-max scale an image tensor to [0, 1] for imshow."""
            return (img - img.min()) / (img.max() - img.min())

        def add_keypoint_patches(ax, points_by_label, height, width, color='r', report_skipped=False):
            """Draw a 10x10 box and a text label at the first keypoint of every entry.

            Entries with no keypoints are skipped instead of raising IndexError
            on ``points[0]``. This covers both missing annotations and points
            dropped by the dataset's post-augmentation bounds check.
            """
            for label, points in points_by_label.items():
                if not points:
                    continue
                x, y = points[0]
                if 0 <= x < width and 0 <= y < height:
                    rect = plt.Rectangle((x-5, y-5), 10, 10, fill=False, edgecolor=color)
                    ax.add_patch(rect)
                    ax.text(x, y-15, label, color=color, fontsize=8, ha='center')
                elif report_skipped:
                    print(f"Skipping keypoint {label} at {x}, {y} as it is out of bounds")

        for batch in test_dataloader:
            images = batch['images']
            labels = batch['labels']
            keypoints = batch['keypoints']
            masks = batch['mask']
            study_ids = batch['study_id']

            print(f"Batch size: {len(study_ids)}")
            print(f"Study IDs: {study_ids}")
            print(f"Labels shape: {labels.shape}")
            print(f"Masks shape: {masks.shape}")

            # Check nested tensor structure
            assert len(images.unbind()) == 3, "Nested tensor should have 3 elements"
            print(f"Sagittal T2 shape: {images[0].shape}")
            print(f"Sagittal T1 shape: {images[1].shape}")
            print(f"Axial T2 shape: {images[2].shape}")

            # First sample of the batch: Sagittal T2, right Sagittal T1 (index 0 of
            # the sides), and the L1/L2 Axial T2 image (index 0 of the levels).
            panels = [
                ("Sagittal T2", images[0][0][0], keypoints[0]['Sagittal T2'], False),
                ("Sagittal T1", images[1][0][0], keypoints[0]['Sagittal T1']['right'], False),
                ("Axial T2", images[2][0][0], keypoints[0]['Axial T2']['L1/L2'], True),
            ]
            fig, axes = plt.subplots(1, 3, figsize=(15, 5))
            for ax, (title, img, points_by_label, report_skipped) in zip(axes, panels):
                ax.imshow(normalize_for_plot(img).permute(1, 2, 0))
                ax.set_title(title)
                add_keypoint_patches(ax, points_by_label, height=img.shape[1], width=img.shape[2], report_skipped=report_skipped)

            plt.tight_layout()
            plt.show()
            plt.close(fig)

            # Print keypoints for one study
            print("Keypoints for first study:")
            for img_type, kp in keypoints[0].items():
                print(f"  {img_type}:")
                for level, points in kp.items():
                    print(f"    {level}: {points}")

            # Only process one batch for this test
            break

    if __name__ == "__main__":
        test_multitask_dataset()

## TODO: remove the transformations that distort the image too much
### Albumentations runs on the CPU, which may slow everything down. Below is an implementation of the transformations with torchvision, which is, however, more limited in what it can do.

# GPU Cropping

In [None]:
def gpu_crop_region(x, keypoints_to_use, crop_size=(64, 64), epoch=0, sr_model=None):
    """
    Crop regions around keypoints for every image type, with optional super-resolution.

    x: nested tensor with three components, in this order:
        0: Sagittal T2, (B, 1, C, H, W) -- the singleton image dimension is squeezed away
        1: Sagittal T1, (B, n_images, C, H, W)
        2: Axial T2,    (B, n_images, C, H, W)
    keypoints_to_use: list with one dict per sample, laid out as
        {'Sagittal T2': {level: kp},
         'Sagittal T1': {side: {level: kp}},
         'Axial T2':    {level: {side: kp}}}
        where kp is [x, y] (or a 2-element tensor). Anything else (e.g. []) or an
        out-of-bounds point marks the keypoint as missing.
    crop_size: (height, width) of each crop.
    epoch: current training epoch; only used to pick the super-resolution output size.
    sr_model: optional module called as sr_model(crops, output_size=(h, w)).

    Returns: (nested tensor of crops, masks)
        crops[k] has shape (B, num_crops_k, C, h, w), with num_crops = 5, 10, 10 for the
        three image types. (h, w) is crop_size, or the super-resolution output size when
        sr_model is given.
        masks has shape (B, 5 + 10 + 10): 1.0 where the keypoint was valid, 0.0 where it
        was missing. The crop for a missing keypoint is all zeros.
    """
    dev = x[0].device
    all_cropped_imgs = []
    all_masks = []

    for i, img_tensor in enumerate(x.unbind()):
        if i == 0:  # Sagittal T2: one image, one crop per level
            img = img_tensor.squeeze(1)  # Remove the extra image dimension
            keypoint_order = [('Sagittal T2', level) for level in LEVELS]
        elif i == 1:  # Sagittal T1: all right-side levels, then all left-side levels
            img = img_tensor
            keypoint_order = [('Sagittal T1', side, level) for side in ['right', 'left'] for level in LEVELS]
        else:  # Axial T2: for each level, the right then the left keypoint
            img = img_tensor
            keypoint_order = [('Axial T2', level, side) for level in LEVELS for side in ['right', 'left']]

        num_crops = len(keypoint_order)
        batch_size = img.shape[0]
        # Validate against the real image size; crop_image clamps to these same bounds.
        img_h, img_w = img.shape[-2:]

        # Collect every (x, y) on the host, then move them to the device in one copy
        # instead of one tiny transfer per keypoint.
        coords = []
        mask_values = []
        for j in range(batch_size):
            for key in keypoint_order:
                kp = keypoints_to_use[j]
                for part in key:  # walk the nested dict: image type -> level/side -> ...
                    kp = kp[part]
                if is_valid_keypoint(kp, img_h, img_w):
                    coords.append(_keypoint_xy(kp))
                    mask_values.append(1.0)
                else:
                    # Missing keypoint: crop the image centre; the crop is zeroed below.
                    coords.append([float(img_w // 2), float(img_h // 2)])
                    mask_values.append(0.0)

        keypoints = torch.tensor(coords, device=dev).view(batch_size, num_crops, 2)
        masks = torch.tensor(mask_values, device=dev).view(batch_size, num_crops)

        cropped_imgs = crop_image(img, keypoints, crop_size)

        # Crops of missing keypoints carry no image content.
        cropped_imgs[masks == 0] = 0

        if sr_model is not None:
            output_size = _sr_output_size(epoch, crop_size)
            valid = masks.bool()
            # Allocate at the SR size even when no crop is valid, so every image type
            # always ends up with the same spatial size.
            out_shape = (batch_size, num_crops, cropped_imgs.shape[2], *output_size)
            with torch.no_grad():
                sr_cropped_imgs = cropped_imgs.new_zeros(out_shape)
                valid_crops = cropped_imgs[valid]
                if valid_crops.numel() > 0:
                    # Boolean-mask assignment fills the valid slots in row-major
                    # (sample, crop) order, matching the order of valid_crops.
                    sr_cropped_imgs[valid] = sr_model(valid_crops, output_size=output_size)
            cropped_imgs = sr_cropped_imgs

        all_cropped_imgs.append(cropped_imgs)
        all_masks.append(masks)

    # Return as nested tensors
    return torch.nested.as_nested_tensor(all_cropped_imgs), torch.cat(all_masks, dim=1)


def _sr_output_size(epoch, crop_size):
    """Super-resolution output size schedule: crops are upscaled more as training progresses."""
    if epoch > 15:
        return (256, 256)
    if epoch > 10:
        return (128, 128)
    if epoch > 5:
        return (64, 64)
    return crop_size


def _keypoint_xy(kp):
    """Return an already validated keypoint (list/tuple or 2-element tensor) as [x, y] floats."""
    if isinstance(kp, torch.Tensor):
        return [float(v) for v in kp.reshape(-1).tolist()]
    return [float(kp[0]), float(kp[1])]

def is_valid_keypoint(kp, img_h, img_w):
    """
    Tell whether `kp` is a usable (x, y) keypoint lying inside an img_h x img_w image.

    kp: [x, y] / (x, y), or a tensor with exactly two elements. Any other value
        (None, [], a 3-element list, ...) is rejected.
    img_h, img_w: image height and width; accepted points satisfy
        0 <= x < img_w and 0 <= y < img_h.

    Returns a bool, or a 0-dim bool tensor when kp is a 2-element tensor.
    """
    if isinstance(kp, torch.Tensor):
        if kp.numel() != 2:
            return False
        upper_bounds = torch.tensor([img_w, img_h], device=kp.device)
        return torch.all((kp >= 0) & (kp < upper_bounds))

    if not isinstance(kp, (list, tuple)):
        return False
    if len(kp) != 2:
        return False

    x_coord, y_coord = kp
    return 0 <= x_coord < img_w and 0 <= y_coord < img_h

def crop_image(img, keypoints, crop_size):
    """
    GPU-friendly function to crop images around keypoints (gather-based, differentiable w.r.t. img).

    img: Tensor of shape (batch_size, num_images, channels, height, width) or
         (batch_size, channels, height, width); the latter is treated as num_images == 1.
    keypoints: Tensor of shape (batch_size, num_crops, 2) holding (x, y) crop centres.
    crop_size: Tuple (height, width) for the size of crops.

    Crops are assigned to images in contiguous groups: crop c is taken from image
    c * num_images // num_crops. For example, 10 crops over 2 images take crops 0-4
    from image 0 and 5-9 from image 1. 10 crops over 5 images take crops (0, 1) from
    image 0, (2, 3) from image 1, and so on.

    Crop windows are shifted to stay inside the image. If the image is smaller than the
    crop, the edge pixels are repeated rather than wrapping around.

    Returns: Tensor of cropped images (batch_size, num_crops, channels, crop_height, crop_width)
    """
    batch_size, num_crops, _ = keypoints.shape
    h, w = crop_size

    if img.dim() == 4:
        img = img.unsqueeze(1)  # Add dimension for single image if not present
    _, num_images, _, img_h, img_w = img.shape
    dev = img.device

    # Top-left corner of each crop. The upper bound is floored at 0 so a crop larger than
    # the image starts at 0 instead of at a negative (wrapping) index.
    left = (keypoints[:, :, 0] - w // 2).long().clamp(min=0, max=max(img_w - w, 0))
    top = (keypoints[:, :, 1] - h // 2).long().clamp(min=0, max=max(img_h - h, 0))

    # Per-pixel row/column indices, (B, N, h, 1) and (B, N, 1, w). Clamping to the image
    # repeats edge pixels when the crop does not fit.
    y_indices = (top.view(batch_size, num_crops, 1, 1) + torch.arange(h, device=dev).view(1, 1, h, 1)).clamp(0, img_h - 1)
    x_indices = (left.view(batch_size, num_crops, 1, 1) + torch.arange(w, device=dev).view(1, 1, 1, w)).clamp(0, img_w - 1)

    # Sample and source-image index of every crop; all index tensors broadcast to (B, N, h, w).
    b_indices = torch.arange(batch_size, device=dev).view(-1, 1, 1, 1)
    i_indices = (torch.arange(num_crops, device=dev) * num_images // max(num_crops, 1)).view(1, -1, 1, 1)

    # Advanced indexing around the channel slice yields (B, N, h, w, C); move C forward.
    crops = img[b_indices, i_indices, :, y_indices, x_indices].permute(0, 1, 4, 2, 3)

    return crops

In [None]:
if DEBUG:
    class TestGPUCropRegion(unittest.TestCase):
        """Unit tests for gpu_crop_region / crop_image on small random CPU tensors."""

        def setUp(self):
            """Build random images for the three image types plus in-bounds keypoints for every crop."""
            self.device = 'cpu'
            self.batch_size = 2
            self.crop_size = (64, 64)

            # Create dummy nested tensor input
            self.sagittal_t2 = torch.randn(self.batch_size, 1, 3, 224, 224, device=self.device)
            self.sagittal_t1 = torch.randn(self.batch_size, 2, 3, 224, 224, device=self.device)
            self.axial = torch.randn(self.batch_size, 5, 3, 128, 128, device=self.device)
            self.dummy_input = torch.nested.nested_tensor([self.sagittal_t2, self.sagittal_t1, self.axial])

            # Create dummy keypoints (all inside their images)
            self.keypoints_to_use = [
                {
                    'Sagittal T2': {level: [112, 112] for level in LEVELS},
                    'Sagittal T1': {
                        'right': {level: [112, 112] for level in LEVELS},
                        'left': {level: [112, 112] for level in LEVELS}
                    },
                    'Axial T2': {level: {'right': [64, 64], 'left': [64, 64]} for level in LEVELS}
                }
                for _ in range(self.batch_size)
            ]

        def test_gpu_crop_region_output_shape(self):
            """One crop tensor per image type, with 5/10/10 crops, and a (B, N_LABELS) mask."""
            cropped_images, crop_masks = gpu_crop_region(self.dummy_input, self.keypoints_to_use, self.crop_size)

            self.assertEqual(len([_ for _ in cropped_images.unbind()]), 3, "Should have 3 image types")
            self.assertEqual(cropped_images[0].shape, (self.batch_size, 5, 3, *self.crop_size), "Incorrect shape for Sagittal T2")
            self.assertEqual(cropped_images[1].shape, (self.batch_size, 10, 3, *self.crop_size), "Incorrect shape for Sagittal T1")
            self.assertEqual(cropped_images[2].shape, (self.batch_size, 10, 3, *self.crop_size), "Incorrect shape for Axial T2")

            self.assertEqual(crop_masks.shape, (self.batch_size, N_LABELS), "Incorrect mask shape for masks")

        def test_gpu_crop_region_mask_values(self):
            """Every mask entry is 1 when every keypoint is present and in bounds."""
            _, crop_masks = gpu_crop_region(self.dummy_input, self.keypoints_to_use, self.crop_size)

            for mask in crop_masks:
                self.assertTrue(torch.all(mask == 1), "All masks should be 1 when all keypoints are provided")

        def test_gpu_crop_region_missing_keypoints(self):
            """An empty keypoint gives a 0 mask at its position in the concatenated mask."""
            # Remove some keypoints
            self.keypoints_to_use[0]['Sagittal T2']['L1/L2'] = []
            self.keypoints_to_use[0]['Sagittal T1']['right']['L2/L3'] = []
            self.keypoints_to_use[0]['Axial T2']['L3/L4']['left'] = []

            _, crop_masks = gpu_crop_region(self.dummy_input, self.keypoints_to_use, self.crop_size)

            # Mask layout: 5 Sagittal T2, then 10 Sagittal T1 (right levels, left levels),
            # then 10 Axial T2 (per level: right, left).
            self.assertEqual(crop_masks[0, 0].item(), 0, "Mask should be 0 for missing Sagittal T2 keypoint")
            self.assertEqual(crop_masks[0, 5 + 1].item(), 0, "Mask should be 0 for missing Sagittal T1 right keypoint")
            self.assertEqual(crop_masks[0, 15 + 5].item(), 0, "Mask should be 0 for missing Axial T2 left keypoint")

        def test_crop_image_output(self):
            """crop_image returns crops of crop_size centred on the keypoints."""
            img = torch.randn(2, 1, 3, 224, 224, device=self.device)
            keypoints = torch.tensor([[[112, 112]], [[56, 56]]], device=self.device)
            crops = crop_image(img, keypoints, self.crop_size)

            self.assertEqual(crops.shape, (2, 1, 3, *self.crop_size), "Incorrect crop shape")

            # Check if the crop is centered around the keypoint
            self.assertTrue(torch.allclose(crops[0, 0, :, 32, 32], img[0, 0, :, 112, 112], atol=1e-6))
            self.assertTrue(torch.allclose(crops[1, 0, :, 32, 32], img[1, 0, :, 56, 56], atol=1e-6))

        def test_gpu_crop_region_super_resolution(self):
            """With an SR model, crops are resized according to the epoch schedule."""
            # Mock super-resolution model
            class MockSRModel(torch.nn.Module):
                def forward(self, x, output_size):
                    return F.interpolate(x, size=output_size, mode='bilinear', align_corners=False)

            sr_model = MockSRModel()

            for epoch in [0, 6, 11, 16]:
                cropped_images, _ = gpu_crop_region(self.dummy_input, self.keypoints_to_use, self.crop_size, epoch=epoch, sr_model=sr_model)

                if epoch > 15:
                    expected_size = (256, 256)
                elif epoch > 10:
                    expected_size = (128, 128)
                elif epoch > 5:
                    expected_size = (64, 64)
                else:
                    expected_size = self.crop_size

                for img_type in cropped_images:
                    self.assertEqual(img_type.shape[-2:], expected_size, f"Incorrect super-resolution size for epoch {epoch}")

        def test_gpu_crop_region_preserves_gradients(self):
            """Gradients flow from the crops back to the input images (no SR model)."""
            # Create input tensors that require gradients
            sagittal_t2 = torch.randn(self.batch_size, 1, 3, 224, 224, device=self.device, requires_grad=True)
            sagittal_t1 = torch.randn(self.batch_size, 2, 3, 224, 224, device=self.device, requires_grad=True)
            axial = torch.randn(self.batch_size, 5, 3, 128, 128, device=self.device, requires_grad=True)
            dummy_input = torch.nested.as_nested_tensor([sagittal_t2, sagittal_t1, axial])

            # Perform the crop operation
            cropped_images, _ = gpu_crop_region(dummy_input, self.keypoints_to_use, self.crop_size)

            # Check if cropped images require gradients
            self.assertTrue(cropped_images[0].requires_grad, "Cropped Sagittal T2 images should require gradients")
            self.assertTrue(cropped_images[1].requires_grad, "Cropped Sagittal T1 images should require gradients")
            self.assertTrue(cropped_images[2].requires_grad, "Cropped Axial T2 images should require gradients")

            # Compute a dummy loss and perform backward pass
            dummy_loss = cropped_images[0].sum() + cropped_images[1].sum() + cropped_images[2].sum()
            dummy_loss.backward()

            # Check if gradients are computed for input tensors
            self.assertIsNotNone(sagittal_t2.grad, "Sagittal T2 input should have gradients")
            self.assertIsNotNone(sagittal_t1.grad, "Sagittal T1 input should have gradients")
            self.assertIsNotNone(axial.grad, "Axial T2 input should have gradients")

            # Check if gradients are non-zero
            self.assertGreater(sagittal_t2.grad.abs().sum().item(), 0, "Sagittal T2 gradients should be non-zero")
            self.assertGreater(sagittal_t1.grad.abs().sum().item(), 0, "Sagittal T1 gradients should be non-zero")
            self.assertGreater(axial.grad.abs().sum().item(), 0, "Axial T2 gradients should be non-zero")

    # Run only this test case. unittest.main() would also discover and re-run every other
    # TestCase class living in the notebook's __main__ namespace.
    unittest.TextTestRunner().run(unittest.TestLoader().loadTestsFromTestCase(TestGPUCropRegion))



# Model

In [None]:
class MultiTaskModel(nn.Module):
    """
    Keypoint localization on full images and per-crop classification, in one module.

    task='localization': x holds full images (nested tensor or list of 3 tensors):
        x[0] Sagittal T2 (B, 1, C, H, W) -> keypoints (B, 1, sagittal_num_keypoints, 2)
        x[1] Sagittal T1 (B, n, C, H, W) -> keypoints (B, n, sagittal_num_keypoints, 2)
        x[2] Axial T2    (B, m, C, H, W) -> keypoints (B, m, axial_num_keypoints, 2)
      Each image predicts its own keypoints. The result is one nested tensor.

    task='classification': x holds keypoint-centred crops, x[k] of shape (B, n_crops_k, C, h, w).
      n_crops_k equals the number of heads for that image type (sagittal_t2_num_classes,
      sagittal_t1_num_classes, axial_num_classes). Returns raw logits of shape
      (B, total_heads, 3); no softmax is applied.
    """

    def __init__(self, backbone_name=MODEL_NAME, sagittal_t2_num_classes=5, sagittal_t1_num_classes=10, sagittal_num_keypoints=5, axial_num_classes=10, axial_num_keypoints=2):
        super(MultiTaskModel, self).__init__()
        # Localization backbones: feature maps only (no classifier, no global pooling).
        # Sagittal T1 and T2 share self.sagittal_backbone.
        self.sagittal_backbone = timm.create_model(backbone_name, pretrained=True, num_classes=0, global_pool='')
        self.axial_backbone = timm.create_model(backbone_name, pretrained=True, num_classes=0, global_pool='')
        # Classification backbones, one per image type, applied to the cropped regions.
        self.sagittal_t2_cropped_backbone = timm.create_model('efficientnet_b1', pretrained=True, num_classes=0, global_pool='')
        self.sagittal_t1_cropped_backbone = timm.create_model('efficientnet_b1', pretrained=True, num_classes=0, global_pool='')
        self.axial_cropped_backbone = timm.create_model('efficientnet_b1', pretrained=True, num_classes=0, global_pool='')

        # Ensure the localization backbone parameters require gradients.
        for backbone in (self.sagittal_backbone, self.axial_backbone):
            for param in backbone.parameters():
                param.requires_grad = True

        # Keypoints predicted per image; used to reshape the localization outputs.
        self.sagittal_num_keypoints = sagittal_num_keypoints
        self.axial_num_keypoints = axial_num_keypoints

        # Number of channels in the last feature map of each backbone family.
        self.num_features = self.sagittal_backbone.feature_info[-1]['num_chs']
        self.num_features_cropped = self.sagittal_t2_cropped_backbone.num_features

        # Global average pooling of the backbone feature maps.
        self.adaptive_pool = nn.AdaptiveAvgPool2d((1, 1))

        # Keypoint prediction branches: (x, y) per keypoint for every image.
        self.sagittal_t2_keypoint_fc = self._make_head(self.num_features, (1024, 512, 256), sagittal_num_keypoints * 2)
        self.sagittal_t1_keypoint_fc = self._make_head(self.num_features, (1024, 512, 256), sagittal_num_keypoints * 2)
        self.axial_t2_keypoint_fc = self._make_head(self.num_features, (1024, 512), axial_num_keypoints * 2)

        # Classification branches: one 3-class head per crop.
        self.sagittal_t2_fc_cropped = nn.ModuleList(
            [self._make_head(self.num_features_cropped, (512, 256), 3) for _ in range(sagittal_t2_num_classes)]
        )
        self.sagittal_t1_fc_cropped = nn.ModuleList(
            [self._make_head(self.num_features_cropped, (512, 256), 3) for _ in range(sagittal_t1_num_classes)]
        )
        self.axial_fc_cropped = nn.ModuleList(
            [self._make_head(self.num_features_cropped, (512, 256), 3) for _ in range(axial_num_classes)]
        )

    @staticmethod
    def _make_head(in_features, hidden_dims, out_features, dropout=0.3):
        """MLP head: [Linear -> BatchNorm1d -> ReLU -> Dropout] per hidden size, then a final Linear."""
        layers = []
        prev = in_features
        for hidden in hidden_dims:
            layers += [nn.Linear(prev, hidden), nn.BatchNorm1d(hidden), nn.ReLU(), nn.Dropout(dropout)]
            prev = hidden
        layers.append(nn.Linear(prev, out_features))
        return nn.Sequential(*layers)

    @staticmethod
    def _merge_image_dim(images):
        """Fold the per-sample image axis into the batch: (B, N, C, H, W) -> (B*N, C, H, W)."""
        return images.reshape(-1, *images.shape[2:])

    def _pooled_features(self, backbone, images):
        """Run `backbone` on an (N, C, H, W) batch and global-average-pool to (N, features)."""
        return torch.flatten(self.adaptive_pool(backbone(images)), 1)

    @staticmethod
    def _apply_heads(features, heads, batch_size):
        """Apply heads[k] to crop k: features (B*K, F) -> stacked outputs (B, K, out)."""
        per_crop = features.view(batch_size, len(heads), -1)
        return torch.stack([head(per_crop[:, k, :]) for k, head in enumerate(heads)], dim=1)

    def forward(self, x, task='classification'):
        """Dispatch to localization or classification (see class docstring); raises ValueError otherwise."""
        batch_size = x[0].size(0)  # All components share the same batch size

        if task == 'localization':
            # Sagittal T2 has a single image per sample: drop that axis.
            sagittal_t2_features = self._pooled_features(self.sagittal_backbone, x[0].squeeze(1))
            sagittal_t1_features = self._pooled_features(self.sagittal_backbone, self._merge_image_dim(x[1]))
            axial_features = self._pooled_features(self.axial_backbone, self._merge_image_dim(x[2]))

            # Each image predicts its own keypoints; -1 recovers the images-per-sample axis.
            sagittal_t2_keypoints = self.sagittal_t2_keypoint_fc(sagittal_t2_features).view(batch_size, -1, self.sagittal_num_keypoints, 2)
            sagittal_t1_keypoints = self.sagittal_t1_keypoint_fc(sagittal_t1_features).view(batch_size, -1, self.sagittal_num_keypoints, 2)
            axial_keypoints = self.axial_t2_keypoint_fc(axial_features).view(batch_size, -1, self.axial_num_keypoints, 2)

            return torch.nested.as_nested_tensor([sagittal_t2_keypoints, sagittal_t1_keypoints, axial_keypoints])

        if task == 'classification':
            sagittal_t2_features = self._pooled_features(self.sagittal_t2_cropped_backbone, self._merge_image_dim(x[0]))
            sagittal_t1_features = self._pooled_features(self.sagittal_t1_cropped_backbone, self._merge_image_dim(x[1]))
            axial_features = self._pooled_features(self.axial_cropped_backbone, self._merge_image_dim(x[2]))

            outputs = [
                self._apply_heads(sagittal_t2_features, self.sagittal_t2_fc_cropped, batch_size),
                self._apply_heads(sagittal_t1_features, self.sagittal_t1_fc_cropped, batch_size),
                self._apply_heads(axial_features, self.axial_fc_cropped, batch_size),
            ]
            return torch.cat(outputs, dim=1)

        raise ValueError(f"Invalid task: {task}")

## Testing the model

In [None]:
if DEBUG:
    class TestMultiTaskModel(unittest.TestCase):
        """Shape, structure and gradient checks for MultiTaskModel on random inputs."""

        def setUp(self):
            """Build a model (DEBUG backbone), full-size random images, and crops derived from them."""
            self.model = MultiTaskModel(
                backbone_name=MODEL_NAME,  # the lightweight backbone selected by DEBUG
                sagittal_t2_num_classes=5,
                sagittal_t1_num_classes=10,
                sagittal_num_keypoints=5,
                axial_num_classes=10,
                axial_num_keypoints=2
            )
            self.batch_size = 2
            self.crop_size = (64, 64)

            # Create dummy nested tensor input for localization task
            sagittal_t2 = torch.randn(self.batch_size, 1, 3, 490, 275)  # (B, 1, C, H, W)
            sagittal_t1 = torch.randn(self.batch_size, 2, 3, 490, 275)  # (B, 2, C, H, W)
            axial = torch.randn(self.batch_size, 5, 3, 310, 250)        # (B, 5, C, H, W)
            self.dummy_input_localization = torch.nested.nested_tensor([sagittal_t2, sagittal_t1, axial])

            # Create dummy keypoints for classification task
            self.keypoints_to_use = [
                {
                    'Sagittal T2': {level: [112, 112] for level in LEVELS},
                    'Sagittal T1': {
                        'right': {level: [112, 112] for level in LEVELS},
                        'left': {level: [112, 112] for level in LEVELS}
                    },
                    'Axial T2': {level: {'right': [64, 64], 'left': [64, 64]} for level in LEVELS}
                }
                for _ in range(self.batch_size)
            ]

            # Create dummy nested tensor input for classification task
            self.dummy_input_classification, _ = gpu_crop_region(self.dummy_input_localization, self.keypoints_to_use, self.crop_size)

        def test_model_structure(self):
            """Backbones are modules, heads are MLPs (nn.Sequential), one classifier head per crop."""
            self.assertIsInstance(self.model.sagittal_backbone, nn.Module)
            self.assertIsInstance(self.model.axial_backbone, nn.Module)
            self.assertIsInstance(self.model.sagittal_t2_keypoint_fc, nn.Sequential)
            self.assertIsInstance(self.model.sagittal_t1_keypoint_fc, nn.Sequential)
            self.assertIsInstance(self.model.axial_t2_keypoint_fc, nn.Sequential)
            # The final layer of each keypoint head emits (x, y) for every keypoint.
            self.assertEqual(self.model.sagittal_t2_keypoint_fc[-1].out_features, 5 * 2)
            self.assertEqual(self.model.sagittal_t1_keypoint_fc[-1].out_features, 5 * 2)
            self.assertEqual(self.model.axial_t2_keypoint_fc[-1].out_features, 2 * 2)
            self.assertEqual(len(self.model.sagittal_t2_fc_cropped), 5)
            self.assertEqual(len(self.model.sagittal_t1_fc_cropped), 10)
            self.assertEqual(len(self.model.axial_fc_cropped), 10)

        def test_localization_task(self):
            """Localization returns a nested tensor of per-image keypoints for the 3 image types."""
            self.model.eval()
            with torch.no_grad():
                output = self.model(self.dummy_input_localization, task='localization')

            self.assertIsInstance(output, torch.Tensor)
            self.assertEqual(len([_ for _ in output.unbind()]), 3)  # Should have 3 image types
            self.assertEqual(output[0].squeeze().shape, (self.batch_size, 5, 2))
            self.assertEqual(output[1].shape, (self.batch_size, 2, 5, 2))
            self.assertEqual(output[2].shape, (self.batch_size, 5, 2, 2))

        def test_classification_task(self):
            """Classification returns (B, N_LABELS, 3) logits."""
            self.model.eval()
            with torch.no_grad():
                output = self.model(self.dummy_input_classification, task='classification')

            self.assertIsInstance(output, torch.Tensor)

            self.assertEqual(output.shape, (self.batch_size, N_LABELS, 3))

        def test_invalid_task(self):
            """Unknown task names raise ValueError."""
            with self.assertRaises(ValueError):
                self.model(self.dummy_input_localization, task='invalid_task')

        def test_model_trainable_parameters(self):
            """Every parameter is trainable."""
            trainable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
            total_params = sum(p.numel() for p in self.model.parameters())

            # Report offenders before asserting; a failing assert would skip this.
            if trainable_params != total_params:
                for name, param in self.model.named_parameters():
                    if not param.requires_grad:
                        print(f"Parameter {name} is not trainable")

            self.assertGreater(trainable_params, 0, "Model should have trainable parameters")
            self.assertEqual(trainable_params, total_params, "All parameters should be trainable")

        def test_all_parameters_require_grad(self):
            """Every parameter has requires_grad set."""
            all_require_grad = all(p.requires_grad for p in self.model.parameters())

            # Report offenders before asserting; a failing assert would skip this.
            if not all_require_grad:
                for name, param in self.model.named_parameters():
                    if not param.requires_grad:
                        print(f"Parameter {name} does not require gradients")

            self.assertTrue(all_require_grad, "Not all parameters require gradients")

        def test_model_output_range(self):
            """Outputs are finite, and softmax over the class logits yields valid distributions."""
            self.model.eval()
            with torch.no_grad():
                loc_output = self.model(self.dummy_input_localization, task='localization')
                class_output = self.model(self.dummy_input_classification, task='classification')

            # Keypoint heads are unconstrained linear regressors: only finiteness is guaranteed.
            for keypoints in loc_output.unbind():
                self.assertTrue(torch.all(torch.isfinite(keypoints)))

            # Classification outputs are raw logits; check the probabilities derived from them.
            probs = torch.softmax(class_output, dim=-1)
            self.assertTrue(torch.all(probs >= 0))
            self.assertTrue(torch.all(probs <= 1))
            self.assertTrue(torch.allclose(probs.sum(dim=-1), torch.ones_like(probs[..., 0]), atol=1e-5))

        def test_backbone_freeze_unfreeze(self):
            """requires_grad on the sagittal backbone can be toggled off and back on."""
            # Test freezing backbone
            for param in self.model.sagittal_backbone.parameters():
                param.requires_grad = False

            self.assertTrue(all(not p.requires_grad for p in self.model.sagittal_backbone.parameters()))

            # Test unfreezing backbone
            for param in self.model.sagittal_backbone.parameters():
                param.requires_grad = True

            self.assertTrue(all(p.requires_grad for p in self.model.sagittal_backbone.parameters()))

        def test_forward_pass_with_grad(self):
            """A localization backward pass produces non-zero gradients for some parameters."""
            self.model.train()
            # Ensure input requires grad
            self.dummy_input_localization = torch.nested.nested_tensor([t.requires_grad_() for t in self.dummy_input_localization.unbind()])
            output = self.model(self.dummy_input_localization, task='localization')

            # Sum all outputs
            loss = sum(o.sum() for o in output)

            # Print debug information
            print(f"Loss: {loss}")
            print(f"Loss requires grad: {loss.requires_grad}")

            for i, o in enumerate(output):
                print(f"Output {i} requires grad: {o.requires_grad}")
                print(f"Output {i} grad_fn: {o.grad_fn}")

            # Try to backward
            try:
                loss.backward()
            except Exception as e:
                print(f"Error during backward pass: {e}")
                for i, o in enumerate(output):
                    print(f"Output {i} grad_fn after error: {o.grad_fn}")
                raise

            # Check gradients for all parameters
            for name, param in self.model.named_parameters():
                if param.requires_grad:
                    if param.grad is None:
                        print(f"Gradient is None for parameter: {name}")
                    elif not torch.any(param.grad):
                        print(f"Gradient is all zeros for parameter: {name}")
                else:
                    print(f"Parameter does not require grad: {name}")

            # Assert that at least some parameters have non-zero gradients
            # (any non-zero entry; a plain sum can cancel out to 0).
            grads = [param.grad for param in self.model.parameters() if param.requires_grad]
            self.assertTrue(any(grad is not None and bool(torch.any(grad)) for grad in grads),
                            "No parameter has non-zero gradient")

        def test_forward_pass_with_grad_classification(self):
            """A cross-entropy backward pass on the classification logits produces non-zero gradients."""
            self.model.train()
            # Ensure input requires grad
            self.dummy_input_classification = torch.nested.nested_tensor([t.requires_grad_() for t in self.dummy_input_classification.unbind()])
            output = self.model(self.dummy_input_classification, task='classification')

            # Create dummy labels
            dummy_labels = torch.randint(0, 3, (self.batch_size, N_LABELS)).to(output.device)

            # Use CrossEntropyLoss
            criterion = nn.CrossEntropyLoss()
            loss = criterion(output.view(-1, 3), dummy_labels.view(-1))

            # Print debug information
            print(f"Classification Loss: {loss}")
            print(f"Loss requires grad: {loss.requires_grad}")
            print(f"Output requires grad: {output.requires_grad}")
            print(f"Output grad_fn: {output.grad_fn}")

            # Try to backward
            try:
                loss.backward()
            except Exception as e:
                print(f"Error during backward pass: {e}")
                print(f"Output grad_fn after error: {output.grad_fn}")
                raise

            # Check gradients for all parameters
            for name, param in self.model.named_parameters():
                if param.requires_grad:
                    if param.grad is None:
                        print(f"Gradient is None for parameter: {name}")
                    elif not torch.any(param.grad):
                        print(f"Gradient is all zeros for parameter: {name}")
                else:
                    print(f"Parameter does not require grad: {name}")

            # Assert that at least some parameters have non-zero gradients
            grads = [param.grad for param in self.model.parameters() if param.requires_grad]
            self.assertTrue(any(grad is not None and bool(torch.any(grad)) for grad in grads),
                            "No parameter has non-zero gradient")

    # Run the tests
    unittest.TextTestRunner().run(unittest.TestLoader().loadTestsFromTestCase(TestMultiTaskModel))

# Training

## Helper functions

### To possibly not evaluate models with missing labels at the beginning of training

## Actual Training Loop

### TODO: implement stratified sampling

In [None]:
def points_to_tensor(points):
    """
    Convert a keypoint annotation to a float tensor on the global `device`.

    points: [[x, y]] for an annotated keypoint, or [] / another falsy value when missing.
    Returns a tensor of shape (2,) for [[x, y]], or an empty tensor for a missing keypoint.
    """
    if points:
        # Drop the leading singleton dimension of [[x, y]].
        return torch.tensor(points, dtype=torch.float32).squeeze(0).to(device)
    return torch.tensor([]).to(device)

In [None]:
class CustomClassifierCriterion(nn.Module):
    """
    Average a per-label classification loss over all labels, and report each image type's share.

    criterion: loss applied to each label's (batch, classes) logits and (batch,) targets.
        Defaults to a fresh nn.CrossEntropyLoss() per instance.
    """

    def __init__(self, criterion=None):
        super(CustomClassifierCriterion, self).__init__()
        # Build the default per instance: a default-argument module would be shared by all instances.
        self.criterion = criterion if criterion is not None else nn.CrossEntropyLoss()
        # Label positions of each image type along dim 1 of the predictions.
        self.sagittal_t2_indices = range(0, 5)
        self.sagittal_t1_indices = range(5, 15)
        self.axial_t2_indices = range(15, 25)

    def forward(self, predictions, labels):
        """
        predictions: (batch, num_labels, num_classes) logits.
        labels: (batch, num_labels) integer class targets, given as a tensor, numpy array,
            nested list, or pandas DataFrame.

        Returns (total_loss, parts):
            total_loss: differentiable tensor, the mean over labels of the per-label loss.
            parts: {'Sagittal T2', 'Sagittal T1', 'Axial T2'} -> that group's contribution
                to total_loss, as a Python float.
        """
        device = predictions.device

        if isinstance(labels, torch.Tensor):
            labels = labels.to(device)
        else:
            if hasattr(labels, 'to_numpy'):  # pandas DataFrame / Series
                labels = labels.to_numpy()
            labels = torch.as_tensor(labels, dtype=torch.long, device=device)

        num_labels = predictions.shape[1]
        groups = {
            'Sagittal T2': self.sagittal_t2_indices,
            'Sagittal T1': self.sagittal_t1_indices,
            'Axial T2': self.axial_t2_indices,
        }

        total_loss = 0
        group_losses = {name: 0 for name in groups}
        for idx in range(num_labels):
            loss = self.criterion(predictions[:, idx, :], labels[:, idx]) / num_labels
            total_loss = total_loss + loss
            for name, indices in groups.items():
                if idx in indices:
                    # Accumulate detached tensors; converting once per group avoids
                    # one device sync per label.
                    group_losses[name] = group_losses[name] + loss.detach()
                    break

        return total_loss, {name: float(value) for name, value in group_losses.items()}

In [None]:
def visualize_keypoints(model, images, true_keypoints, device, epoch, fold, output_dir, study_id):
    """Save predicted (red) vs. labelled (green) keypoints for the first sample of a batch.

    Args:
        model: model called as ``model(x, task='localization')``.
        images: batched multi-series input. ``images.unbind()`` yields one
            tensor per series, and element ``[0]`` (first sample) of each is
            used.
        true_keypoints: label dict of ONE sample. Read keys are
            ``['Sagittal T2']``, ``['Sagittal T1']['right']`` and
            ``['Axial T2']['L1/L2']``; each maps to a dict of point lists.
        device: device the images are moved to.
        epoch, fold, study_id: used only in the output filename.
        output_dir: directory the PNG is written to.

    Side effects:
        Writes ``keypoint_visualization_study_{study_id}_epoch_{epoch}_fold-{fold}.png``.
        The model's train/eval mode is restored afterwards.
    """
    was_training = model.training
    model.eval()
    try:
        images = images.clone().detach().to(device)
        # First sample of every series, kept as a batch of one.
        x = [series[0].unsqueeze(dim=0) for series in images.unbind()]

        with torch.no_grad():
            predicted_keypoints = model(x, task='localization')
    finally:
        # Do not leak eval mode into a caller that is still training.
        model.train(was_training)

    # Unpack the nested tensor
    sagittal_t2_keypoints, sagittal_t1_keypoints, axial_keypoints = predicted_keypoints.unbind()

    fig, axs = plt.subplots(1, 3, figsize=(15, 5))
    image_types = ['Sagittal T2', 'Sagittal T1', 'Axial T2']

    for ax, image_type in zip(axs, image_types):
        if image_type == 'Sagittal T2':
            # NOTE(review): squeeze() drops every size-1 dim; for a
            # single-channel image the permute below would fail — confirm.
            img = x[0].squeeze().cpu()
            pred_kp = sagittal_t2_keypoints[0, 0].cpu()
            true_kp = true_keypoints['Sagittal T2']
        elif image_type == 'Sagittal T1':
            img = x[1][0, 0, ...].cpu()
            pred_kp = sagittal_t1_keypoints[0, 0].cpu()  # side index 0 (right)
            true_kp = true_keypoints['Sagittal T1']['right']
        else:  # Axial T2
            img = x[2][0, 0, ...].cpu()  # first axial slice
            pred_kp = axial_keypoints[0, 0].cpu()  # level index 0 (L1/L2)
            true_kp = true_keypoints['Axial T2']['L1/L2']

        ax.imshow(img.clone().detach().permute(1, 2, 0))
        ax.set_title(image_type)

        # Predicted keypoints
        for kp in pred_kp:
            ax.add_patch(patches.Rectangle((kp[0] - 5, kp[1] - 5), 10, 10,
                                           linewidth=1, edgecolor='r', facecolor='none'))

        # Labelled keypoints: check for an empty list BEFORE indexing it
        # (indexing first raised IndexError for unlabelled entries).
        for points in true_kp.values():
            if not points:
                continue
            kp = points[0]
            if kp:
                ax.add_patch(patches.Rectangle((kp[0] - 5, kp[1] - 5), 10, 10,
                                               linewidth=1, edgecolor='g', facecolor='none'))

    fig.tight_layout()
    fig.savefig(f'{output_dir}/keypoint_visualization_study_{study_id}_epoch_{epoch}_fold-{fold}.png')
    plt.close(fig)

def save_crops_for_study(cropped_imgs, study_id, epoch, output_dir, fold=None):
    """Save one image grid per series showing the crops of the first sample in a batch.

    Args:
        cropped_imgs: indexed by series (0 Sagittal T2, 1 Sagittal T1,
            2 Axial T2). For each series, ``crops[0]`` (first sample) is tiled
            with ``crops.shape[1]`` images per row.
        study_id, epoch: used in the output filename.
        output_dir: created if missing.
        fold: optional fold number, appended as ``_fold-{fold}`` (the same
            convention ``visualize_keypoints`` uses) so grids from different
            folds do not overwrite each other. ``None`` keeps the original
            filename.
    """
    os.makedirs(output_dir, exist_ok=True)
    fold_suffix = '' if fold is None else f'_fold-{fold}'
    image_types = ['Sagittal_T2', 'Sagittal_T1', 'Axial_T2']

    for i, img_type in enumerate(image_types):
        crops = cropped_imgs[i]
        filename = f'{output_dir}/{img_type}_study_{study_id}_epoch_{epoch}{fold_suffix}.png'
        grid = vutils.make_grid(crops[0], nrow=crops.shape[1], padding=2, normalize=True)
        vutils.save_image(grid, filename)

In [None]:
from torch.nn.utils import clip_grad_norm_

def clip_and_count(parameters, max_norm, norm_type=2):
    """Clip gradients in place to a global norm of ``max_norm`` and report whether clipping kicked in.

    Delegates to ``torch.nn.utils.clip_grad_norm_``. That function uses the
    same formula as the previous hand-rolled loop: coefficient
    ``max_norm / (total_norm + 1e-6)``, clamped to 1. It also:

    * returns a zero norm (so ``False``) instead of crashing in
      ``torch.stack([])`` when no parameter has a gradient, e.g. a frozen
      sub-network or grads reset to ``None`` by ``zero_grad``;
    * iterates ``parameters`` once, so generators such as
      ``model.parameters()`` are clipped too.

    With a ``GradScaler``, call ``scaler.unscale_(optimizer)`` first.
    Otherwise the norm is measured on scaled gradients.

    Returns:
        0-d bool tensor: ``total_norm > max_norm`` (usable directly in ``if``).
    """
    total_norm = clip_grad_norm_(parameters, max_norm, norm_type=norm_type)
    return total_norm > max_norm

In [None]:
class ReentrantAutocast:
    """Reusable autocast context manager that is safe to nest.

    ``torch.amp.autocast`` saves the previous autocast state on the instance
    in ``__enter__`` and restores it in ``__exit__``. Entering the SAME
    instance twice (``with autocast: ... with autocast: ...``) overwrites the
    saved state, so autocast stays enabled after the outer block exits. This
    wrapper opens a fresh ``torch.amp.autocast`` on every entry and closes
    them in LIFO order.
    """

    def __init__(self, **autocast_kwargs):
        self._autocast_kwargs = autocast_kwargs
        self._open_contexts = []

    def __enter__(self):
        context = torch.amp.autocast(**self._autocast_kwargs)
        context.__enter__()
        self._open_contexts.append(context)
        return context

    def __exit__(self, exc_type, exc_value, traceback):
        return self._open_contexts.pop().__exit__(exc_type, exc_value, traceback)

    def __call__(self, func):
        # Keep decorator usage (``@autocast``) working like torch.amp.autocast.
        return torch.amp.autocast(**self._autocast_kwargs)(func)


autocast = ReentrantAutocast(enabled=USE_AUTOMATIC_MIXED_PRECISION, dtype=torch.float16, device_type=device)
# One scaler per optimizer: the localizer and classifier are stepped separately.
localizer_scaler = torch.amp.GradScaler(enabled=USE_AUTOMATIC_MIXED_PRECISION, init_scale=4096)
classifier_scaler = torch.amp.GradScaler(enabled=USE_AUTOMATIC_MIXED_PRECISION, init_scale=4096)

# K-fold training of the two-stage model: a keypoint localizer whose keypoints
# drive GPU crops, followed by a classifier on those crops. Each stage has its
# own optimizer, GradScaler and LR scheduler.

# Series names used for per-series loss bookkeeping, and the two annotated sides.
_SERIES_TYPES = ('Sagittal T2', 'Sagittal T1', 'Axial T2')
_SIDES = ('right', 'left')


def _make_loader(dataset, train):
    """DataLoader with the settings shared by every loader below.

    Training loaders shuffle and drop the last incomplete batch. Validation
    loaders keep order and every sample.
    """
    return DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=train,
        pin_memory=True,
        drop_last=train,
        num_workers=N_WORKERS,
        collate_fn=custom_collate,
    )


def _has_grads(params):
    """Return True if at least one parameter in ``params`` currently holds a gradient."""
    return any(p.grad is not None for p in params)


def _unpack_keypoints(predicted_keypoints):
    """Split the localizer output into (sagittal_t2, sagittal_t1, axial) keypoints.

    Sagittal T2 is reshaped to (batch, len(LEVELS), -1) so it can be indexed
    as ``[batch_idx, level_idx]``. A bare ``.squeeze()`` would also drop the
    batch dimension for a single-sample batch. That can happen on the last
    validation batch, because validation does not use drop_last.
    NOTE(review): assumes dim 0 is the batch dimension, as the indexing below
    requires — confirm against MultiTaskModel.
    """
    sagittal_t2, sagittal_t1, axial = predicted_keypoints.unbind()
    sagittal_t2 = sagittal_t2.reshape(sagittal_t2.shape[0], len(LEVELS), -1)
    return sagittal_t2, sagittal_t1, axial


def _build_keypoints_to_use(keypoint_labels, mask, sagittal_t2_keypoints, sagittal_t1_keypoints, axial_keypoints):
    """Choose, per sample, the keypoints used to crop that sample's images.

    Each entry is the labelled point when one exists, otherwise the detached
    prediction. Series masked out (mask == 0) get ``[]``. Layout mirrors the
    label dicts: ``{'Sagittal T2': {level: pts}, 'Sagittal T1': {side: {level: pts}},
    'Axial T2': {level: {side: pts}}}``.
    Mask columns: 0 = Sagittal T2, 1/2 = Sagittal T1 right/left,
    3 + level_idx = Axial T2 level.
    """
    keypoints_to_use = []
    for batch_idx, batch_item in enumerate(keypoint_labels):
        sample_keypoints = {}

        for image_type, keypoints in batch_item.items():
            sample_keypoints[image_type] = {}

            if image_type == 'Sagittal T2':
                for level_idx, level in enumerate(LEVELS):
                    points = keypoints.get(level, [])
                    if mask[batch_idx, 0] == 0:
                        sample_keypoints[image_type][level] = []
                    else:
                        sample_keypoints[image_type][level] = points_to_tensor(points) if points else sagittal_t2_keypoints[batch_idx, level_idx].detach()

            elif image_type == 'Sagittal T1':
                sample_keypoints[image_type] = {'right': {}, 'left': {}}
                for side_idx, side in enumerate(_SIDES):
                    mask_idx = 1 if side == 'right' else 2
                    for level_idx, level in enumerate(LEVELS):
                        points = keypoints[side].get(level, [])
                        if mask[batch_idx, mask_idx] == 0:
                            sample_keypoints[image_type][side][level] = []
                        else:
                            sample_keypoints[image_type][side][level] = points_to_tensor(points) if points else sagittal_t1_keypoints[batch_idx, side_idx, level_idx].detach()

            elif image_type == 'Axial T2':
                for level_idx, level in enumerate(LEVELS):
                    sides = keypoints.get(level, {})
                    mask_idx = 3 + level_idx
                    sample_keypoints[image_type][level] = {}
                    for side_idx, side in enumerate(_SIDES):
                        points = sides.get(side, [])
                        if mask[batch_idx, mask_idx] == 0:
                            sample_keypoints[image_type][level][side] = []
                        else:
                            sample_keypoints[image_type][level][side] = points_to_tensor(points) if points else axial_keypoints[batch_idx, level_idx, side_idx].detach()

        keypoints_to_use.append(sample_keypoints)
    return keypoints_to_use


def _localization_loss(keypoints_per_sample, mask, sagittal_t2_keypoints, sagittal_t1_keypoints, axial_keypoints, criterion, per_type_losses):
    """Masked localization loss for one batch.

    ``keypoints_per_sample`` uses the nested layout of ``_build_keypoints_to_use``,
    which the raw label dicts share. Targets may be tensors or raw point lists
    (converted with ``points_to_tensor``). Empty targets are skipped. Each term
    is multiplied by the sample's mask column for that series.

    Detached per-series sums are added to ``per_type_losses`` in place.
    Returns the summed loss tensor, or ``None`` when no target was non-empty.
    """
    loc_loss = None
    batch_type_losses = dict.fromkeys(_SERIES_TYPES, 0.0)

    def add(image_type, prediction, points, weight):
        nonlocal loc_loss
        target = points if torch.is_tensor(points) else points_to_tensor(points)
        level_loss = criterion(prediction, target) * weight
        loc_loss = level_loss if loc_loss is None else loc_loss + level_loss
        batch_type_losses[image_type] = batch_type_losses[image_type] + level_loss.detach()

    for batch_idx, sample_keypoints in enumerate(keypoints_per_sample):
        for image_type, keypoints in sample_keypoints.items():
            if image_type == 'Sagittal T2':
                for level_idx, level in enumerate(LEVELS):
                    points = keypoints.get(level, [])
                    if len(points) > 0:
                        add(image_type, sagittal_t2_keypoints[batch_idx, level_idx], points, mask[batch_idx, 0])

            elif image_type == 'Sagittal T1':
                for side_idx, side in enumerate(_SIDES):
                    mask_idx = 1 if side == 'right' else 2
                    for level_idx, level in enumerate(LEVELS):
                        points = keypoints[side].get(level, [])
                        if len(points) > 0:
                            add(image_type, sagittal_t1_keypoints[batch_idx, side_idx, level_idx], points, mask[batch_idx, mask_idx])

            elif image_type == 'Axial T2':
                for level_idx, level in enumerate(LEVELS):
                    sides = keypoints.get(level, {})
                    for side_idx, side in enumerate(_SIDES):
                        points = sides.get(side, [])
                        if len(points) > 0:
                            add(image_type, axial_keypoints[batch_idx, level_idx, side_idx], points, mask[batch_idx, 3 + level_idx])

    # One host sync per series instead of one per keypoint.
    for image_type, value in batch_type_losses.items():
        per_type_losses[image_type] += float(value)
    return loc_loss


kfold = KFold(n_splits=N_FOLDS, shuffle=True, random_state=SEED)
for fold, (training_index, validation_index) in enumerate(kfold.split(range(len(train_df)))):
    print('#' * 30)
    print(f'Starting fold {fold + 1}')
    print('#' * 30)

    # Start with studies that have every label. The rest of the fold is added
    # once the localizer is good enough (see the end of the epoch loop).
    # NOTE(review): `.iloc` receives index labels here, which assumes train_df
    # has a default RangeIndex.
    fully_labelled_index = train_df[train_df['study_id'].isin(all_labels_df['study_id'])].index
    current_training_index = fully_labelled_index.intersection(training_index)
    current_validation_index = fully_labelled_index.intersection(validation_index)
    study_id_for_kp_plotting = train_df.iloc[current_validation_index[0]]['study_id']
    training_rows = train_df.iloc[current_training_index]
    validation_rows = train_df.iloc[current_validation_index]

    print('Training length: ', len(training_rows), 'Validation length: ', len(validation_rows))

    training_dataset = MultiTaskDataset(label_df=training_rows, phase='train', transform=transforms_train)
    training_dataloader = _make_loader(training_dataset, train=True)

    validation_dataset = MultiTaskDataset(label_df=validation_rows, phase='val', transform=transforms_validation)
    validation_dataloader = _make_loader(validation_dataset, train=False)

    model = MultiTaskModel()
    model.to(device)

    keypoint_params = list(model.sagittal_backbone.parameters()) + \
                      list(model.axial_backbone.parameters()) + \
                      list(model.sagittal_t2_keypoint_fc.parameters()) + \
                      list(model.sagittal_t1_keypoint_fc.parameters()) + \
                      list(model.axial_t2_keypoint_fc.parameters())

    classification_params = list(model.sagittal_t2_cropped_backbone.parameters()) + \
                            list(model.sagittal_t1_cropped_backbone.parameters()) + \
                            list(model.axial_cropped_backbone.parameters()) + \
                            [p for fc_list in [model.sagittal_t2_fc_cropped,
                                               model.sagittal_t1_fc_cropped,
                                               model.axial_fc_cropped]
                             for fc in fc_list for p in fc.parameters()]

    # LEARNING_RATE (config cell) evaluates to 1e-4, the value previously hard-coded here.
    keypoint_optimizer = AdamW(keypoint_params, lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
    classification_optimizer = AdamW(classification_params, lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

    # Separate schedulers
    keypoint_scheduler = ReduceLROnPlateau(keypoint_optimizer, mode='min', factor=0.5, patience=2, verbose=True)
    classification_scheduler = ReduceLROnPlateau(classification_optimizer, mode='min', factor=0.5, patience=2, verbose=True)

    localizer_criterion = nn.MSELoss()

    # Per-class weights for the 3-way classification loss.
    weights = torch.tensor([1.0, 2.0, 4.0], device=device)
    classifier_criterion = CustomClassifierCriterion(nn.CrossEntropyLoss(weight=weights))

    best_loss = float('inf')
    best_wall = float('inf')
    es_step = 0

    # NOTE(review): 10e5 == 1e6 and 10e3 == 1e4, so these thresholds are 1.2e6
    # and 5e4 — confirm that is what was intended.
    localizer_performance_threshold = 1.2 * 10e5
    localizer_freeze_threshold = 5.0 * 10e3
    full_dataset_used = False
    localizer_frozen = False

    # For plotting
    epoch_metrics = {
        'train_loss': [], 'val_loss': [], 'val_wall': [],
        'train_class_loss': [], 'train_loc_loss': [], 'val_class_loss': [], 'val_loc_loss': [],
        'train_loc_losses': {'Sagittal T2': [], 'Sagittal T1': [], 'Axial T2': []},
        'train_class_losses': {'Sagittal T2': [], 'Sagittal T1': [], 'Axial T2': []},
        'val_loc_losses': {'Sagittal T2': [], 'Sagittal T1': [], 'Axial T2': []},
        'val_class_losses': {'Sagittal T2': [], 'Sagittal T1': [], 'Axial T2': []}
    }

    # Number of optimizer steps at which clipping was applied.
    keypoint_clip_count = 0
    classification_clip_count = 0

    for epoch in tqdm(range(1, EPOCHS + 1), desc="Epochs"):
        print(f'Starting epoch {epoch}')

        # ---------------- Training phase ----------------
        model.train()
        total_class_loss = 0
        total_loc_loss = 0
        loc_losses = dict.fromkeys(_SERIES_TYPES, 0)
        class_losses = dict.fromkeys(_SERIES_TYPES, 0)
        n_train_batches = len(training_dataloader)

        keypoint_optimizer.zero_grad()
        classification_optimizer.zero_grad()
        with tqdm(training_dataloader, leave=False, desc="Training") as loaded_items:
            for idx, batch in enumerate(loaded_items):
                images = batch['images'].to(device)
                keypoint_labels = batch['keypoints']
                class_labels = batch['labels'].to(device)
                mask = batch['mask'].to(device)

                # Step 1: localization. Autocast blocks are entered one after
                # another (never nested), and backward runs outside them as the
                # AMP docs recommend.
                loc_loss = None
                with autocast:
                    with torch.set_grad_enabled(not localizer_frozen):
                        predicted_keypoints = model(images, task='localization')
                    sagittal_t2_keypoints, sagittal_t1_keypoints, axial_keypoints = _unpack_keypoints(predicted_keypoints)

                    # Step 2: labelled keypoints where available, predictions otherwise.
                    keypoints_to_use = _build_keypoints_to_use(
                        keypoint_labels, mask, sagittal_t2_keypoints, sagittal_t1_keypoints, axial_keypoints)

                    if not localizer_frozen:
                        loc_loss = _localization_loss(
                            keypoints_to_use, mask, sagittal_t2_keypoints, sagittal_t1_keypoints, axial_keypoints,
                            localizer_criterion, loc_losses)

                # loc_loss is None when the batch had no usable keypoint targets.
                if loc_loss is not None:
                    localizer_scaler.scale(loc_loss / GRAD_ACCUMULATION).backward()
                    total_loc_loss += loc_loss.item()

                # Step 3: crop images on GPU
                cropped_images, keypoint_masks = gpu_crop_region(images, keypoints_to_use, epoch=epoch)

                # Step 4: classification
                with autocast:
                    class_outputs = model(cropped_images, task='classification')
                    class_loss, individual_losses = classifier_criterion(class_outputs, class_labels)

                if not math.isfinite(class_loss.item()):
                    # Raise instead of sys.exit so the notebook kernel survives.
                    raise FloatingPointError(f"Loss is {class_loss.item()}, stopping training")

                # Losses are averaged over the accumulation window.
                classifier_scaler.scale(class_loss / GRAD_ACCUMULATION).backward()

                for k, v in individual_losses.items():
                    class_losses[k] += v
                total_class_loss += class_loss.item()

                loaded_items.set_postfix(
                    OrderedDict(
                        class_loss=f'{class_loss.item():.6f}',
                        loc_loss=f'{loc_loss.item():.6f}' if loc_loss is not None else 'N/A',
                        lr=f'{classification_optimizer.param_groups[0]["lr"]:.3e}'
                    )
                )

                # Optimizer step at the end of each accumulation window; the last
                # batch also steps so its gradients are not discarded.
                if (idx + 1) % GRAD_ACCUMULATION == 0 or idx + 1 == n_train_batches:
                    # Unscale before clipping so MAX_GRAD_NORM applies to the real
                    # gradients rather than the loss-scaled ones.
                    if not localizer_frozen and _has_grads(keypoint_params):
                        localizer_scaler.unscale_(keypoint_optimizer)
                        if clip_and_count(keypoint_params, MAX_GRAD_NORM):
                            keypoint_clip_count += 1
                        localizer_scaler.step(keypoint_optimizer)
                        localizer_scaler.update()

                    classifier_scaler.unscale_(classification_optimizer)
                    if clip_and_count(classification_params, MAX_GRAD_NORM):
                        classification_clip_count += 1
                    classifier_scaler.step(classification_optimizer)
                    classifier_scaler.update()

                    keypoint_optimizer.zero_grad()
                    classification_optimizer.zero_grad()

        train_class_loss = total_class_loss / n_train_batches
        train_loc_loss = total_loc_loss / n_train_batches if not localizer_frozen else 0
        epoch_metrics['train_class_loss'].append(train_class_loss)
        epoch_metrics['train_loc_loss'].append(train_loc_loss)
        epoch_metrics['train_loss'].append(train_class_loss + train_loc_loss)
        for k in loc_losses.keys():
            epoch_metrics['train_loc_losses'][k].append(loc_losses[k] / n_train_batches)
            epoch_metrics['train_class_losses'][k].append(class_losses[k] / n_train_batches)

        # ---------------- Validation phase ----------------
        model.eval()
        total_val_class_loss = 0
        total_val_loc_loss = 0
        val_loc_losses = dict.fromkeys(_SERIES_TYPES, 0)
        val_class_losses = dict.fromkeys(_SERIES_TYPES, 0)
        all_predictions = []
        all_labels = []
        n_val_batches = len(validation_dataloader)

        with torch.no_grad():
            for idx, batch in enumerate(tqdm(validation_dataloader, leave=False, desc="Validation")):
                images = batch['images'].to(device)
                keypoint_labels = batch['keypoints']
                class_labels = batch['labels'].to(device)
                mask = batch['mask'].to(device)

                with autocast:
                    predicted_keypoints = model(images, task='localization')
                    sagittal_t2_keypoints, sagittal_t1_keypoints, axial_keypoints = _unpack_keypoints(predicted_keypoints)

                    # Localization loss against the raw labels of this batch.
                    loc_loss = _localization_loss(
                        keypoint_labels, mask, sagittal_t2_keypoints, sagittal_t1_keypoints, axial_keypoints,
                        localizer_criterion, val_loc_losses)

                    # Crop keypoints for THIS batch. The old code reused the last
                    # training batch's keypoints here. Same rule as training:
                    # labels where available, predictions otherwise.
                    keypoints_to_use = _build_keypoints_to_use(
                        keypoint_labels, mask, sagittal_t2_keypoints, sagittal_t1_keypoints, axial_keypoints)

                cropped_images, keypoint_masks = gpu_crop_region(images, keypoints_to_use, epoch=epoch)

                with autocast:
                    class_outputs = model(cropped_images, task='classification')
                    class_loss, individual_losses = classifier_criterion(class_outputs, class_labels)

                for k, v in individual_losses.items():
                    val_class_losses[k] += v

                if loc_loss is not None:
                    total_val_loc_loss += loc_loss.item()
                total_val_class_loss += class_loss.item()

                all_predictions.append(class_outputs.cpu())
                all_labels.append(class_labels.cpu())

                # Optional monitoring (disabled):
                # if idx == 0:
                #     if (epoch - 1) % 6 == 0:
                #         save_crops_for_study(cropped_images, study_id_for_kp_plotting, epoch, OUTPUT_DIR)
                #     if epoch % 2 == 0:
                #         visualize_keypoints(model, images, keypoint_labels[0], device, epoch, fold, OUTPUT_DIR, study_id_for_kp_plotting)

        validation_class_loss = total_val_class_loss / n_val_batches
        validation_loc_loss = total_val_loc_loss / n_val_batches if not localizer_frozen else 0

        # Step the schedulers on per-batch averages. Raw sums jump when the
        # dataset grows to the full fold, which ReduceLROnPlateau would read as
        # a regression.
        if not localizer_frozen:
            keypoint_scheduler.step(validation_loc_loss)
        classification_scheduler.step(validation_class_loss)

        all_predictions = torch.cat(all_predictions, dim=0).to(device=device, dtype=torch.float32)
        all_labels = torch.cat(all_labels, dim=0).to(device=device, dtype=torch.long)

        # Calculate validation metrics over the whole validation set at once.
        with torch.no_grad():
            validation_wall = classifier_criterion(all_predictions, all_labels)[0].item()

        epoch_metrics['val_class_loss'].append(validation_class_loss)
        epoch_metrics['val_loc_loss'].append(validation_loc_loss)
        epoch_metrics['val_wall'].append(validation_wall)
        epoch_metrics['val_loss'].append(validation_class_loss + validation_loc_loss)
        for k in val_loc_losses.keys():
            epoch_metrics['val_loc_losses'][k].append(val_loc_losses[k] / n_val_batches)
            epoch_metrics['val_class_losses'][k].append(val_class_losses[k] / n_val_batches)

        # Print summary
        print(f'\n{"="*50}')
        print(f'Epoch {epoch} Summary:')
        print(f'{"="*50}')

        print('\nTraining Losses:')
        print(f'  Total Classification Loss: {train_class_loss:.6f}')
        print(f'  Total Localization Loss: {train_loc_loss:.6f}')
        print(f'  Total Classification clip count: {classification_clip_count}')
        print(f'  Total Localization clip count: {keypoint_clip_count}')
        print('  Localization Losses:')
        for k, v in loc_losses.items():
            print(f'    {k}: {v / n_train_batches:.6f}')
        print('  Classification Losses:')
        for k, v in class_losses.items():
            print(f'    {k}: {v / n_train_batches:.6f}')

        print('\nValidation Losses:')
        print(f'  Total Classification Loss: {validation_class_loss:.6f}')
        print(f'  Total Localization Loss: {validation_loc_loss:.6f}')
        print(f'  Validation Wall: {validation_wall:.6f}')
        print('  Localization Losses:')
        for k, v in val_loc_losses.items():
            print(f'    {k}: {v / n_val_batches:.6f}')
        print('  Classification Losses:')
        for k, v in val_class_losses.items():
            print(f'    {k}: {v / n_val_batches:.6f}')

        print(f'\nLearning Rate: {classification_optimizer.param_groups[0]["lr"]:.3e}')
        print('\nFull dataset is being used.' if full_dataset_used else '\nPartial dataset is being used.')
        print('Localizer is frozen.' if localizer_frozen else 'Localizer is trainable.')
        print(f'{"="*50}\n')

        # Once the localizer is good enough, train/validate on the whole fold.
        if not full_dataset_used and validation_loc_loss < localizer_performance_threshold:
            print("Localizer performance threshold reached. Switching to full dataset.")
            training_rows = train_df.iloc[training_index]
            validation_rows = train_df.iloc[validation_index]
            print('Training length: ', len(training_rows), 'Validation length: ', len(validation_rows))
            training_dataset.extend_dataset(training_rows)
            training_dataloader = _make_loader(training_dataset, train=True)
            validation_dataset.extend_dataset(validation_rows)
            validation_dataloader = _make_loader(validation_dataset, train=False)
            full_dataset_used = True

        if not localizer_frozen and validation_loc_loss < localizer_freeze_threshold:
            print("Localizer freeze threshold reached. Freezing localizer.")
            # Freeze the same keypoint heads that keypoint_params was built from
            # (the previous code referenced `axial_keypoint_fc` here).
            for head in (model.sagittal_t2_keypoint_fc,
                         model.sagittal_t1_keypoint_fc,
                         model.axial_t2_keypoint_fc):
                for param in head.parameters():
                    param.requires_grad = False
            localizer_frozen = True

        improved_loss = validation_class_loss < best_loss
        improved_wall = validation_wall < best_wall
        if improved_loss or improved_wall:
            es_step = 0
            # Save the state dict where it lives. Moving the model to 'cuda:0'
            # first crashed on CPU-only runs.
            if improved_loss:
                print(f'epoch:{epoch}, best loss updated from {best_loss:.6f} to {validation_class_loss:.6f}')
                best_loss = validation_class_loss
                model_path = f'{OUTPUT_DIR}/best_loss_model_fold-{fold}.pt'
                torch.save(model.state_dict(), model_path)

            if improved_wall:
                print(f'epoch:{epoch}, best wall_metric updated from {best_wall:.6f} to {validation_wall:.6f}')
                best_wall = validation_wall
                model_path = f'{OUTPUT_DIR}/best_wall_model_fold-{fold}.pt'
                torch.save(model.state_dict(), model_path)
        else:
            es_step += 1
            if es_step >= EARLY_STOPPING_EPOCH:
                print('Early stopping')
                break

# After training, plot the metrics. `epoch_metrics` is re-initialised at the
# start of every fold, so these figures describe the most recent fold only.

def to_cpu(data):
    """Recursively convert tensors inside nested lists/dicts to NumPy arrays.

    Non-tensor leaves (e.g. Python floats) are returned unchanged.
    """
    if isinstance(data, torch.Tensor):
        return data.cpu().numpy()
    elif isinstance(data, list):
        return [to_cpu(item) for item in data]
    elif isinstance(data, dict):
        return {k: to_cpu(v) for k, v in data.items()}
    else:
        return data


def _plot_loss_curves(ax, series_by_label, title):
    """Plot every non-empty series on ``ax`` against epoch and label the axes.

    Empty series (metrics never recorded) are skipped so the legend lists only
    real curves. The legend is drawn only when at least one curve exists.
    """
    for label, values in series_by_label.items():
        if len(values) > 0:
            ax.plot(values, label=label)
    ax.set_title(title)
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    if ax.get_legend_handles_labels()[0]:
        ax.legend()


# Move all data to CPU
epoch_metrics = to_cpu(epoch_metrics)

fig, axs = plt.subplots(2, 2, figsize=(20, 15))

_plot_loss_curves(axs[0, 0], {
    'Train Class Loss': epoch_metrics['train_class_loss'],
    'Val Class Loss': epoch_metrics['val_class_loss'],
    'Validation Wall': epoch_metrics['val_wall'],
}, 'Classification Losses')

_plot_loss_curves(axs[0, 1], {
    'Train Loc Loss': epoch_metrics['train_loc_loss'],
    'Val Loc Loss': epoch_metrics['val_loc_loss'],
}, 'Localization Losses')

detailed_class_series = {}
for k in epoch_metrics['train_class_losses'].keys():
    detailed_class_series[f'Train {k}'] = epoch_metrics['train_class_losses'][k]
    detailed_class_series[f'Val {k}'] = epoch_metrics['val_class_losses'][k]
_plot_loss_curves(axs[1, 0], detailed_class_series, 'Detailed Classification Losses')

detailed_loc_series = {}
for k in epoch_metrics['train_loc_losses'].keys():
    detailed_loc_series[f'Train {k}'] = epoch_metrics['train_loc_losses'][k]
    detailed_loc_series[f'Val {k}'] = epoch_metrics['val_loc_losses'][k]
_plot_loss_curves(axs[1, 1], detailed_loc_series, 'Detailed Localization Losses')

fig.tight_layout()
fig.savefig(f'{OUTPUT_DIR}/metrics_fold-{fold}.png')
plt.close(fig)

# Validation Classification Losses Heatmap (rows: series, columns: epochs)
fig, ax = plt.subplots(figsize=(10, 8))
val_class_losses_df = pd.DataFrame(epoch_metrics['val_class_losses'])
sns.heatmap(val_class_losses_df.T, annot=True, cmap='YlOrRd', ax=ax)
ax.set_title('Validation Classification Losses Heatmap')
fig.tight_layout()
fig.savefig(f'{OUTPUT_DIR}/val_class_losses_heatmap_fold-{fold}.png')
plt.close(fig)

print("Training completed. Metrics plots saved.")

In [None]:
# Display one saved keypoint visualization. The file exists only if the
# (optional) visualize_keypoints hook in the training loop was enabled.
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

# Build the path from OUTPUT_DIR instead of hard-coding /kaggle/working.
path = f'{OUTPUT_DIR}/keypoint_visualization_study_8785691_epoch_2_fold-0.png'

if os.path.exists(path):
    img = mpimg.imread(path)
    plt.figure(figsize=(20, 15))
    plt.imshow(img)
    plt.axis('off')  # Turn off axis numbers and ticks
    plt.show()
else:
    print(f'No visualization found at {path}; enable the visualize_keypoints hook in the training loop to create it.')