# Sequence model
This notebook builds onto our initial [Training-the-model](https://www.kaggle.com/code/na731ff/training-the-model) notebook. Here, we implement a sequential model using a transformer to wrap our base model, which will still be a pretrained TIMM model, which will have 1 in_channel instead of 30. 

## This will allow to be more flexible by allowing a different number of images per label. 
However, we still only use the fixed 30 images per label coming from [Convert-Images-to-png](https://www.kaggle.com/code/na731ff/convert-images-to-png).

In [1]:
import os
import sys
import pandas as pd
import gc
import numpy as np
import matplotlib.pyplot as plt
from glob import glob
from PIL import Image
from tqdm import tqdm
import math
import random
import psutil
from collections import OrderedDict
import time
from datetime import timedelta

import torch
from torch import nn
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from torch.utils.data import DataLoader, Dataset
from torch.optim import AdamW
import timm
from transformers import get_cosine_schedule_with_warmup
import albumentations as A
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from torchvision.transforms import transforms
from albumentations.pytorch import ToTensorV2
from torchvision.transforms import RandomAffine, GaussianBlur
from torch.utils.checkpoint import checkpoint_sequential
from torchvision.transforms.functional import adjust_sharpness, autocontrast
from torch.nn.functional import interpolate


In [2]:
BASE_URL = '/kaggle/input/rsna-2024-lumbar-spine-degenerative-classification/'
IMAGE_URL = '/kaggle/input/convert-images-to-smaller-png/Converted_smaller_images/'
OUTPUT_DIR = 'rsna24-results'
SEED = 7620
DEBUG = False # if set to true, run fewer computations
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
N_WORKERS = os.cpu_count()
IMAGE_SIZE = [224, 224]
IN_CHANNELS = 30 # number of images stacked over each other in an np.array which constitutes the Dataset 
N_LABELS = 25
N_CLASSES = 3 * N_LABELS

EPOCHS = 100 if not DEBUG else 2

GRAD_ACCUMULATION = 4
TARGET_BATCH_SIZE = 16
BATCH_SIZE = TARGET_BATCH_SIZE // GRAD_ACCUMULATION
MAX_GRAD_NORM = 1.0
EARLY_STOPPING_EPOCH = 10

LEARNING_RATE = 2e-4 * TARGET_BATCH_SIZE / 32 #could implement a lr scheduler
WEIGHT_DECAY = 1e-2
AUGMENTATION = True

P_DROPOUT = 0.2 # Dropout intensity in the transformer

USE_AUTOMATIC_MIXED_PRECISION = True # can change True if using T4 or newer than Ampere
AUGMENTATION_PROBABILITY = 0.75
EPS = 10e-6

LEVELS = ['L1/L2', 'L2/L3', 'L3/L4', 'L4/L5', 'L5/S1']

In [3]:
train_df = pd.read_csv(f'{BASE_URL}train.csv')
descriptions_df = pd.read_csv(f'{BASE_URL}train_series_descriptions.csv')
coordinates_df = pd.read_csv(f'{BASE_URL}train_label_coordinates.csv')

In [4]:
coordinates_df = pd.merge(coordinates_df, descriptions_df, on = ['study_id', 'series_id'])
coordinates_df.head()

Unnamed: 0,study_id,series_id,instance_number,condition,level,x,y,series_description
0,4003253,702807833,8,Spinal Canal Stenosis,L1/L2,322.831858,227.964602,Sagittal T2/STIR
1,4003253,702807833,8,Spinal Canal Stenosis,L2/L3,320.571429,295.714286,Sagittal T2/STIR
2,4003253,702807833,8,Spinal Canal Stenosis,L3/L4,323.030303,371.818182,Sagittal T2/STIR
3,4003253,702807833,8,Spinal Canal Stenosis,L4/L5,335.292035,427.327434,Sagittal T2/STIR
4,4003253,702807833,8,Spinal Canal Stenosis,L5/S1,353.415929,483.964602,Sagittal T2/STIR


In [5]:
os.makedirs(OUTPUT_DIR, exist_ok = True)

In [6]:
def set_random_seed(seed: int = 8620, deterministic: bool = False):
    """Set seeds"""
    random.seed(seed)
    np.random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)  # type: ignore
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = deterministic  # type: ignore

set_random_seed(SEED)

In [7]:
condition_encoding = {'Normal/Mild': 0, 'Moderate': 1, 'Severe': 2}

def map_condition(condition):
    return condition_encoding.get(condition, -100) # we need to make sure that these NA filled with -100 are not used

train_df.iloc[:, 1:] = train_df.iloc[:, 1:].map(map_condition)
train_df.head()

Unnamed: 0,study_id,spinal_canal_stenosis_l1_l2,spinal_canal_stenosis_l2_l3,spinal_canal_stenosis_l3_l4,spinal_canal_stenosis_l4_l5,spinal_canal_stenosis_l5_s1,left_neural_foraminal_narrowing_l1_l2,left_neural_foraminal_narrowing_l2_l3,left_neural_foraminal_narrowing_l3_l4,left_neural_foraminal_narrowing_l4_l5,...,left_subarticular_stenosis_l1_l2,left_subarticular_stenosis_l2_l3,left_subarticular_stenosis_l3_l4,left_subarticular_stenosis_l4_l5,left_subarticular_stenosis_l5_s1,right_subarticular_stenosis_l1_l2,right_subarticular_stenosis_l2_l3,right_subarticular_stenosis_l3_l4,right_subarticular_stenosis_l4_l5,right_subarticular_stenosis_l5_s1
0,4003253,0,0,0,0,0,0,0,0,1,...,0,0,0,1,0,0,0,0,0,0
1,4646740,0,0,1,2,0,0,0,0,1,...,0,0,0,2,0,0,1,1,1,0
2,7143189,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
3,8785691,0,0,0,0,0,0,0,0,1,...,0,0,0,0,0,0,0,0,0,0
4,10728036,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,1,0


In [8]:
CONDITIONS = [
    'Spinal Canal Stenosis', 
    'Left Neural Foraminal Narrowing', 
    'Right Neural Foraminal Narrowing',
    'Left Subarticular Stenosis',
    'Right Subarticular Stenosis'
]

LEVELS = [
    'L1/L2',
    'L2/L3',
    'L3/L4',
    'L4/L5',
    'L5/S1',
]

## Defining the Dataset
This implementation supposedly leaves a lot of room for improvement.

In [9]:
class RSNA24Dataset(Dataset):
    def __init__(
            self, 
            df = train_df, 
            descriptions_df = descriptions_df, 
            phase='train', 
            transform=None
            ):
        
        self.df = df
        self.descriptions_df = descriptions_df
        self.transform = transform
        self.phase = phase
        self.PILToTensor = transforms.Compose([transforms.PILToTensor()])
        self.images = {}
        self.load_images()

    def load_images(self):
        study_ids = self.df['study_id'].unique()
        for study_id in study_ids:
            if study_id not in self.images:
                self.images[study_id] = {}
                
                # loading Sagittal T2
                self._load_sagittal_t2(study_id)
                
                # loading Sagittal T1
                self._load_sagittal_t1(study_id)
                
                # loading Axial T2
                self._load_axial_t2(study_id)

    def _load_image_type(self, study_id, image_type, description):
        description_ = description.replace(' ', '_').replace('/', '-')
        series_id_df = self.descriptions_df.query('@study_id == study_id and @description == series_description')

        if not series_id_df.empty:
            series_id = series_id_df['series_id'].iloc[0]
            image_paths = glob(f'{IMAGE_URL}{study_id}/{description_}/*.png')
            if len(image_paths) == 0:
                print(f'{image_type} Study id: {study_id} has no images')
                return [torch.zeros(1, IMAGE_SIZE[0], IMAGE_SIZE[1]) for _ in range(10)]
            try:
                images = [self.PILToTensor(Image.open(path).convert('L')) for path in image_paths[:10]]
                if len(images) < 10:
                    images.extend([torch.zeros(1, IMAGE_SIZE[0], IMAGE_SIZE[1]) for _ in range(10 - len(images))])
                return images
            except Exception as e:
                print(f'Study id: {study_id} {image_type} error while loading image: {str(e)}')
                return [torch.zeros(1, IMAGE_SIZE[0], IMAGE_SIZE[1]) for _ in range(10)]
        return [torch.zeros(1, IMAGE_SIZE[0], IMAGE_SIZE[1]) for _ in range(10)]

    def _load_sagittal_t2(self, study_id):
        self.images[study_id]['Sagittal T2'] = self._load_image_type(study_id, 'Sagittal T2', 'Sagittal T2/STIR')

    def _load_sagittal_t1(self, study_id):
        self.images[study_id]['Sagittal T1'] = self._load_image_type(study_id, 'Sagittal T1', 'Sagittal T1')

    def _load_axial_t2(self, study_id):
        self.images[study_id]['Axial T2'] = self._load_image_type(study_id, 'Axial T2', 'Axial T2')
            
    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        images = []
        row = self.df.iloc[idx]
        study_id = int(row['study_id'])
        labels = row[1:].values.astype(np.int64)
        
        for image_type in ['Sagittal T2', 'Sagittal T1', 'Axial T2']:
            images.extend(self.images[study_id][image_type])

        if sum([torch.sum(t) for t in images]) == 0:
            raise ValueError(f'No valid images found for study_id: {study_id}')

        if self.transform is not None:
            # Generate a random seed for this sequence
            seed = np.random.randint(2147483647)

            transformed_images = []
            for img in images:
                # Set the seed for this image (ensures same transform for all images in sequence)
                np.random.seed(seed)
                img = img.squeeze(0).numpy()  # Remove channel dimension and convert to numpy
                transformed = self.transform(image=img)['image']
                transformed_images.append(transformed)
            
            images = torch.stack(transformed_images)
        else:
            images = torch.stack(images)

        return images, torch.tensor(labels)

## Data augmentation
Taken from Kaggle winner, but for a different problem. See, for instance, [this notebook](https://www.kaggle.com/code/haqishen/1st-place-soluiton-code-small-ver) (past Kaggle competition winner)

In [10]:
transforms_train = A.Compose([
    A.RandomBrightnessContrast(brightness_limit=(-0.2, 0.2), contrast_limit=(-0.2, 0.2), p=AUGMENTATION_PROBABILITY),
    A.OneOf([
        A.MotionBlur(blur_limit=5),
        A.MedianBlur(blur_limit=5),
        A.GaussianBlur(blur_limit=5),
        A.GaussNoise(var_limit=(5.0, 30.0)),
    ], p=AUGMENTATION_PROBABILITY),
    A.OneOf([
        A.OpticalDistortion(distort_limit=1.0),
        A.GridDistortion(num_steps=5, distort_limit=1.),
        A.ElasticTransform(alpha=3),
    ], p=AUGMENTATION_PROBABILITY),
    A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=15, border_mode=0, p=AUGMENTATION_PROBABILITY),
    A.Resize(IMAGE_SIZE[0], IMAGE_SIZE[1]),
    A.CoarseDropout(max_holes=4, max_height=8, max_width=8, min_holes=1, min_height=4, min_width=4, p=AUGMENTATION_PROBABILITY),    
    A.Normalize(mean=0.5, std=0.5),
    ToTensorV2()
])

transforms_validation = A.Compose([
    A.Resize(IMAGE_SIZE[0], IMAGE_SIZE[1]),
    A.Normalize(mean=0.5, std=0.5),
    ToTensorV2()
])

if DEBUG or not AUGMENTATION:
    transforms_train = transforms_validation

  validated_self = self.__pydantic_validator__.validate_python(data, self_instance=self)


## Let's try here the Dataloader

In [11]:
def test_rsna24_dataset(num_studies=10):
    print("Testing RSNA24Dataset...")
    
    # Use the first num_studies rows from train_df
    test_df = train_df.head(num_studies)
    
    # Create a dataset instance
    dataset = RSNA24Dataset(df=test_df, descriptions_df=descriptions_df, phase='train', transform=transforms_train)
    
    # Test __len__ method
    assert len(dataset) == num_studies, f"Expected {num_studies} items, but got {len(dataset)}"
    print(f"Dataset length: {len(dataset)} (as expected)")
    
    # Test __getitem__ method for each study
    for idx in range(num_studies):
        try:
            images, labels = dataset[idx]
            
            # Check images
            assert images.shape == (30, 1, 224, 224), f"Expected shape (30, 1, 224, 224), but got {images.shape}"
            assert images.dtype == torch.float32, f"Expected dtype torch.float32, but got {images.dtype}"
            assert images.min() >= -1 and images.max() <= 1, f"Expected values in range [-1, 1], but got [{images.min()}, {images.max()}]"
            
            # Check labels
            assert labels.shape == (25,), f"Expected shape (25,), but got {labels.shape}"
            assert labels.dtype == torch.int64, f"Expected dtype torch.int64, but got {labels.dtype}"
            assert all(label in [0, 1, 2] for label in labels), f"Expected labels to be 0, 1, or 2, but got {labels.tolist()}"
            
            print(f"Study {idx+1}/{num_studies} passed all checks")
        
        except Exception as e:
            print(f"Error processing study {idx+1}/{num_studies}: {str(e)}")
    
    print("RSNA24Dataset test completed.")

# Run the test
if DEBUG:
    test_rsna24_dataset(10)

## Define the Model
Didn't have much time so just sticking to the stuff on the notebook we are drawing inspiration from

In [12]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.2, max_len=10):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-torch.log(torch.tensor(10000.0)) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # Shape: (1, max_len, d_model)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)


In [13]:
class ResidualTransformerEncoder(nn.Module):
    def __init__(self, d_model, nhead, num_layers, dropout):
        super(ResidualTransformerEncoder, self).__init__()
        self.layers = nn.ModuleList([
            ResidualTransformerEncoderLayer(d_model, nhead, dropout)
            for _ in range(num_layers)
        ])
        self.norm = nn.LayerNorm(d_model)
    
    def forward(self, src):
        for layer in self.layers:
            src = layer(src)
        return self.norm(src)

class ResidualTransformerEncoderLayer(nn.Module):
    def __init__(self, d_model, nhead, dropout):
        super(ResidualTransformerEncoderLayer, self).__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.ReLU(),
            nn.Linear(d_model * 4, d_model)
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, src):
        # Multi-head attention with residual connection
        attn_output, _ = self.self_attn(src, src, src)
        src = src + self.dropout(attn_output)
        src = self.norm1(src)
        
        # Feed-forward network with residual connection
        ff_output = self.feed_forward(src)
        src = src + self.dropout(ff_output)
        src = self.norm2(src)
        
        return src

In [14]:
# 'tf_efficientnetv2_b3'
class OptimizedMultiScaleEfficientNet(nn.Module):
    def __init__(self, model_name='densenet121', pretrained=True, num_classes=256):
        super(OptimizedMultiScaleEfficientNet, self).__init__()
        
        # Load the EfficientNet model
        self.efficientnet = timm.create_model(model_name, pretrained=pretrained, features_only=True, in_chans=1)
        
        # Get the number of channels in each feature map
        self.channels = self.efficientnet.feature_info.channels()
        
        # Select three scales: early, middle, and late
        self.scales = [-3, -2, -1]
        
        # Global average pooling
        self.gap = nn.AdaptiveAvgPool2d(1)
        
        # Projection layers for each selected scale
        self.projections = nn.ModuleList([
            nn.Linear(self.channels[i], num_classes // 4) for i in self.scales[:2]
        ] + [nn.Linear(self.channels[self.scales[-1]], num_classes // 2)])
        
        # Final fusion layer
        self.fusion = nn.Sequential(
            nn.Linear(num_classes, num_classes),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(num_classes, num_classes)
        )
        
        # Freeze all layers initially
        self.freeze_all_layers()
        self.num_unfrozen_layers = 0
        self.unfreeze_layers(num_layers_to_unfreeze = 2)

    def freeze_all_layers(self):
        for param in self.efficientnet.parameters():
            param.requires_grad = False

    def unfreeze_layers(self, num_layers_to_unfreeze=1):
        layers_to_unfreeze = [
            'features_denseblock4',
            'features_transition3',
            'features_denseblock3',
            'features_transition2',
            'features_denseblock2',
            'features_transition1',
            'features_denseblock1',
            'features_norm0',
            'features_conv0'
        ]
        
        curr_unfrozen = self.num_unfrozen_layers
        
        if curr_unfrozen >= len(layers_to_unfreeze) - 3:
            print('All layers are trainable')
        else:
            for i, layer_name in enumerate(layers_to_unfreeze[curr_unfrozen: min(curr_unfrozen + num_layers_to_unfreeze, len(layers_to_unfreeze))]):
                self.num_unfrozen_layers += 1
                if hasattr(self.efficientnet, layer_name):
                    for param in getattr(self.efficientnet, layer_name).parameters():
                        param.requires_grad = True
                    print(f"Unfrozen layer: {layer_name}")

    def forward(self, x):
        # Get features from EfficientNet
        features = self.efficientnet(x)
        
        # Process selected scales
        multi_scale_features = []
        for i, proj in zip(self.scales, self.projections):
            feat = features[i]
            feat = self.gap(feat).squeeze(-1).squeeze(-1)
            feat = proj(feat)
            multi_scale_features.append(feat)
        
        # Concatenate features from all scales
        combined_features = torch.cat(multi_scale_features, dim=1)
        
        # Final fusion
        output = self.fusion(combined_features)
        
        return output

class ImageSequenceModel(nn.Module):
    def __init__(self, num_classes, nhead, num_encoder_layers, reduced_dim=256, num_images=30, dropout=0.2):
        super(ImageSequenceModel, self).__init__()
        
        # Use OptimizedMultiScaleEfficientNet as the base model
        self.base_model = OptimizedMultiScaleEfficientNet(num_classes=reduced_dim)
        
        self.reduced_dim = reduced_dim
        
        # Positional encoding with fixed number of images
        self.positional_encoding = PositionalEncoding(reduced_dim, dropout, num_images)
        
        # Multi-head attention layer
        self.multihead_attn = nn.MultiheadAttention(reduced_dim, nhead, dropout=dropout, batch_first=True)
        
        # Residual Transformer encoder
        self.transformer_encoder = ResidualTransformerEncoder(reduced_dim, nhead, num_encoder_layers, dropout)
        
        # Final classification layer
        self.classifier = nn.Linear(reduced_dim, num_classes)

    def forward(self, x):
        batch_size, num_images, channels, height, width = x.size()
        
        # Reshape to process all images at once
        x = x.view(batch_size * num_images, channels, height, width)
        
        # Extract multi-scale features using the base model
        features = self.base_model(x)  # Shape: (batch_size * num_images, reduced_dim)
        
        # Reshape to form a sequence
        embeddings = features.view(batch_size, num_images, -1)  # Shape: (batch_size, num_images, reduced_dim)
        
        # Apply positional encoding
        embeddings = self.positional_encoding(embeddings)
        
        # Apply multi-head attention
        attn_output, _ = self.multihead_attn(embeddings, embeddings, embeddings)
        
        # Residual connection after multi-head attention
        embeddings = embeddings + attn_output
        
        # Transformer encoding with residual connections
        transformer_output = self.transformer_encoder(embeddings)  # Shape: (batch_size, num_images, reduced_dim)
        
        # Mean pooling across the sequence dimension (average across images)
        sequence_representation = transformer_output.mean(dim=1)  # Shape: (batch_size, reduced_dim)
        
        # Classification
        output = self.classifier(sequence_representation)
        
        return output

In [15]:
def test_model_memory_usage(batch_size=1, num_images=30, image_size=(230, 230), num_classes=75, nhead=8, num_encoder_layers=4):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    model = ImageSequenceModel(
        num_classes=num_classes,
        nhead=nhead,
        num_encoder_layers=num_encoder_layers,
        reduced_dim=384,
        num_images=num_images,
        dropout=0.1
    ).to(device)

    dummy_input = torch.randn(batch_size, num_images, 1, *image_size).to(device)
    dummy_target = torch.randint(0, num_classes, (batch_size,)).to(device)
    print(f"Input shape: {dummy_input.shape}")

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters())

    torch.cuda.empty_cache()
    gc.collect()

    initial_memory = torch.cuda.memory_allocated(device)

    try:
        # Forward pass
        output = model(dummy_input)
        loss = criterion(output, dummy_target)

        print(f"Forward pass successful.")
        print(f"Output shape: {output.shape}")
        
        forward_memory = torch.cuda.memory_allocated(device)
        print(f"Memory after forward pass: {forward_memory / 1e6:.2f} MB")
        print(f"Memory used in forward pass: {(forward_memory - initial_memory) / 1e6:.2f} MB")

        # Backward pass
        loss.backward()

        backward_memory = torch.cuda.memory_allocated(device)
        print(f"Memory after backward pass: {backward_memory / 1e6:.2f} MB")
        print(f"Memory used in backward pass: {(backward_memory - forward_memory) / 1e6:.2f} MB")

        # Optimizer step
        optimizer.step()
        optimizer.zero_grad()

        final_memory = torch.cuda.memory_allocated(device)
        print(f"Final memory usage: {final_memory / 1e6:.2f} MB")
        print(f"Total memory difference: {(final_memory - initial_memory) / 1e6:.2f} MB")

        peak_memory = torch.cuda.max_memory_allocated(device)
        print(f"Peak memory usage: {peak_memory / 1e6:.2f} MB")

    except RuntimeError as e:
        print(f"Error during computation: {e}")
        import traceback
        traceback.print_exc()
    
    torch.cuda.empty_cache()
    gc.collect()

# Run the test
if DEBUG:
    test_model_memory_usage(batch_size=4, num_images=30, image_size=(224, 224), num_classes=75, nhead=16, num_encoder_layers=8)

In [16]:
def test_model_performance(batch_size=2, grad_acc=8, num_images=30, image_size=(230, 230), num_classes=75, nhead=12, num_encoder_layers=6):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    model = ImageSequenceModel(
        num_classes=num_classes,
        nhead=nhead,
        num_encoder_layers=num_encoder_layers,
        reduced_dim=512,
        num_images=num_images,
        dropout=0.1
    ).to(device)

    dummy_input = torch.randn(batch_size, num_images, 1, *image_size).to(device)
    dummy_target = torch.randint(0, num_classes, (batch_size,)).to(device)

    criterion = nn.CrossEntropyLoss()
    autocast = torch.cuda.amp.autocast(enabled=USE_AUTOMATIC_MIXED_PRECISION, dtype=torch.half)
    scaler = torch.cuda.amp.GradScaler(enabled=USE_AUTOMATIC_MIXED_PRECISION, init_scale=4096)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=WEIGHT_DECAY)
    torch.cuda.empty_cache()
    gc.collect()

    try:
        for i in range(grad_acc):
            with autocast:
                output = model(dummy_input)
                loss = criterion(output, dummy_target)
                loss = loss / grad_acc

            scaler.scale(loss).backward()

        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

        peak_memory = torch.cuda.max_memory_allocated(device)
        print(f"Peak memory usage: {peak_memory / 1e6:.2f} MB")
        print(f"Effective batch size: {batch_size * grad_acc}")

    except RuntimeError as e:
        print(f"Error during computation: {e}")
        import traceback
        traceback.print_exc()
    
    torch.cuda.empty_cache()
    gc.collect()

# Run the test
if DEBUG:
    test_model_performance(batch_size=4, grad_acc=4, num_images=30, image_size=(224, 224), num_classes=75, nhead=16, num_encoder_layers=8)

## Check if the model works
Commenting these out to save memory. Somehow it does not reset when I delete. Commented out during actual training

## Training the model

In [17]:
## helper functions to keep track of memory usage 
def get_ram_usage():
    """Returns the current CPU RAM usage in MB."""
    process = psutil.Process()
    mem_info = process.memory_info()
    return int(mem_info.rss / (1024 ** 2))  # Convert bytes to MB

def get_gpu_memory_usage(device=0):
    """Returns the current GPU VRAM usage in MB for the specified device."""
    return int(torch.cuda.memory_allocated(device) / (1024 ** 2))  # Convert bytes to MB

In [18]:
def get_optimizer_fine_grained(model, base_lr=1e-3, lr_factor=0.1):
    # Group parameters by component, with finer granularity in the base model
    base_model_early_params = list(model.base_model.efficientnet.features_conv0.parameters()) + \
                              list(model.base_model.efficientnet.features_norm0.parameters()) + \
                              list(model.base_model.efficientnet.features_pool0.parameters()) + \
                              list(model.base_model.efficientnet.features_denseblock1.parameters())
    
    base_model_middle_params = list(model.base_model.efficientnet.features_transition1.parameters()) + \
                               list(model.base_model.efficientnet.features_denseblock2.parameters()) + \
                               list(model.base_model.efficientnet.features_transition2.parameters())
    
    base_model_late_params = list(model.base_model.efficientnet.features_denseblock3.parameters()) + \
                             list(model.base_model.efficientnet.features_transition3.parameters()) + \
                             list(model.base_model.efficientnet.features_denseblock4.parameters()) + \
                             list(model.base_model.efficientnet.features_norm5.parameters())
    
    base_model_proj_params = list(model.base_model.projections.parameters()) + list(model.base_model.fusion.parameters())
    transformer_params = list(model.multihead_attn.parameters()) + list(model.transformer_encoder.parameters())
    classifier_params = list(model.classifier.parameters())

    # Create parameter groups with different learning rates
    param_groups = [
        {'params': base_model_early_params, 'lr': base_lr * (lr_factor ** 2)},  # Lowest learning rate
        {'params': base_model_middle_params, 'lr': base_lr * (lr_factor ** 1.5)},
        {'params': base_model_late_params, 'lr': base_lr * lr_factor},
        {'params': base_model_proj_params, 'lr': base_lr * (lr_factor ** 0.5)},
        {'params': transformer_params, 'lr': base_lr * (lr_factor ** 0.25)},
        {'params': classifier_params, 'lr': base_lr}  # Highest learning rate
    ]

    # Create optimizer with parameter groups
    optimizer = AdamW(param_groups, lr=base_lr, weight_decay=0.01)
    
    return optimizer

In [19]:
autocast = torch.cuda.amp.autocast(enabled=USE_AUTOMATIC_MIXED_PRECISION, dtype=torch.half)
scaler = torch.cuda.amp.GradScaler(enabled=USE_AUTOMATIC_MIXED_PRECISION, init_scale=4096)

from sklearn.model_selection import train_test_split
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

training_index, validation_index = train_test_split(range(len(train_df)), test_size=0.15, random_state=123)

if DEBUG:
    training_index = training_index[:32]
    validation_index = validation_index[:32]

print('#' * 30)
print(f'Starting Training')
print('#' * 30)
print('Training length: ', len(training_index), 'Validation length: ', len(validation_index))
training_rows = train_df.iloc[training_index]
validation_rows = train_df.iloc[validation_index]

start_time = time.time()
max_training_time = 9 * 3600  # 9 hours in seconds

training_dataset = RSNA24Dataset(training_rows, phase='train', transform=transforms_train)
training_dataloader = DataLoader(
    training_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    pin_memory=True,
    drop_last=True,
    num_workers=N_WORKERS
)

validation_dataset = RSNA24Dataset(validation_rows, phase='val', transform=transforms_validation)
validation_dataloader = DataLoader(
    validation_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    pin_memory=True,
    drop_last=False,
    num_workers=N_WORKERS
)

model = ImageSequenceModel(
            num_classes=75,
            nhead=16,
            num_encoder_layers=8,
            reduced_dim=384,
            num_images=30,
            dropout=0.2
        )

model.to(device)
optimizer = get_optimizer_fine_grained(model, base_lr=5e-4, lr_factor=0.01)

# Define warmup steps and total steps
num_warmup_steps = EPOCHS // 10 * len(training_dataloader) // GRAD_ACCUMULATION
num_training_steps = num_total_steps = EPOCHS * len(training_dataloader) // GRAD_ACCUMULATION

# Create a custom learning rate scheduler with warmup
class WarmupCosineSchedule(CosineAnnealingWarmRestarts):
    def __init__(self, optimizer, num_warmup_steps, num_training_steps, **kwargs):
        self.num_warmup_steps = num_warmup_steps
        super().__init__(optimizer, T_0=num_training_steps, **kwargs)

    def get_lr(self):
        if self.last_epoch < self.num_warmup_steps:
            return [base_lr * (self.last_epoch + 1) / self.num_warmup_steps for base_lr in self.base_lrs]
        return super().get_lr()

# Create the scheduler with warmup
scheduler = WarmupCosineSchedule(
    optimizer,
    num_warmup_steps=num_warmup_steps,
    num_training_steps=num_training_steps,
    T_mult=1,  # Multiplier for increasing T_0 after a restart
    eta_min=5e-8  # Minimum learning rate
)

weights = torch.tensor([1.0, 2.0, 4.0])
criterion = nn.CrossEntropyLoss(weight=weights.to(device))
criterion2 = nn.CrossEntropyLoss(weight=weights)

best_loss = float('inf')
best_wall = float('inf')
es_step = 0

max_cpu_ram_usage = 0
max_gpu_ram_usage = 0

for epoch in range(1, EPOCHS + 1):
    print(f'Starting epoch {epoch}')
    gc.collect()
    torch.cuda.empty_cache()
    
    if epoch % 10 == 0 and epoch > 0:
        model.base_model.unfreeze_layers()
        # Recreate optimizer to update learning rates for newly unfrozen layers
        optimizer = get_optimizer_fine_grained(model, base_lr=5e-4, lr_factor=0.1)

    model.train()
    total_loss = 0
    with tqdm(training_dataloader, leave=True) as loaded_items:
        optimizer.zero_grad()
        for idx, (x, labels) in enumerate(loaded_items):
            x = x.to(device)
            labels = labels.to(device)
            
            with torch.autograd.set_detect_anomaly(True):
                with autocast:
                    loss = 0
                    output = model(x)
                    for idx_label in range(N_LABELS):
                        prediction = output[:, idx_label * 3 : idx_label * 3 + 3]
                        actual_label = labels[:, idx_label]
                        loss += criterion(prediction, actual_label) / N_LABELS
                        
                    if torch.isnan(loss):
                        print("NaN loss encountered, skipping update")
                        continue
                    
                    total_loss += loss.item()
                    if GRAD_ACCUMULATION > 1:
                        loss /= GRAD_ACCUMULATION

                cpu_ram_before = get_ram_usage()
                gpu_ram_before = get_gpu_memory_usage()

                scaler.scale(loss).backward()

                torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM or 1e9)

                if (idx + 1) % GRAD_ACCUMULATION == 0:
                    scaler.step(optimizer)
                    scaler.update()
                    optimizer.zero_grad()
                    scheduler.step()

                loaded_items.set_postfix(
                    OrderedDict(
                        loss=f'{loss.item() * GRAD_ACCUMULATION:.6f}',
                        lr=f'{optimizer.param_groups[0]["lr"]:.3e}',
                        max_ram_usage=f'CPU:{cpu_ram_before} GPU:{gpu_ram_before}'
                    )
                )

                max_cpu_ram_usage = max(max_cpu_ram_usage, cpu_ram_before)
                max_gpu_ram_usage = max(max_gpu_ram_usage, gpu_ram_before)

    train_loss = total_loss / len(training_dataloader)
    print(f'train_loss:{train_loss:.6f}')
    print(f'max CPU RAM usage: {max_cpu_ram_usage}, max GPU RAM usage: {max_gpu_ram_usage}')
    
    #### model evaluation phase ####
        
    total_loss = 0
    output_predictions = []
    row_names = []

    model.eval()
    with tqdm(validation_dataloader, leave=True) as loaded_items:
        with torch.no_grad():
            for idx, (x, labels) in enumerate(loaded_items):
                x = x.to(device)
                labels = labels.to(device)

                with autocast:
                    loss = 0
                    loss_ema = 0
                    output = model(x)
                    for idx_label in range(N_LABELS):
                        prediction = output[:, idx_label * 3 : idx_label * 3 + 3]
                        actual_label = labels[:, idx_label]
                        loss += criterion(prediction, actual_label) / N_LABELS

                        output_prediction = prediction.float()
                        output_predictions.append(output_prediction.cpu())
                        row_names.append(actual_label.cpu())

                    total_loss += loss.item()

    validation_loss = total_loss / len(validation_dataloader)

    output_predictions = torch.cat(output_predictions, dim = 0)
    row_names = torch.cat(row_names)
    validation_wall = criterion2(output_predictions, row_names)

    print(f'Validation loss: {validation_loss:.6f}, validation wall:{validation_wall:.6f}')

    if validation_loss < best_loss or validation_wall < best_wall:
        es_step = 0
        if device != 'cuda:0':
            model.to('cuda:0')

        if validation_loss < best_loss:
            print(f'epoch:{epoch}, best loss updated from {best_loss:.6f} to {validation_loss:.6f}')
            best_loss = validation_loss

        if validation_wall < best_wall:
            print(f'epoch:{epoch}, best wall_metric updated from {best_wall:.6f} to {validation_wall:.6f}')
            best_wall = validation_wall
            model_path = f'{OUTPUT_DIR}/best_wall_model.pt'
            torch.save(model.state_dict(), model_path)

        if device != 'cuda:0': # why do we need to do this again?
            model.to(device)

    else:
        es_step += 1
        if es_step >= EARLY_STOPPING_EPOCH:
            print('Early stopping')
            break
            
    elapsed_time = time.time() - start_time
    if elapsed_time > max_training_time:
        print(f'Training time limit of {max_training_time/3600:.1f} hours reached. Stopping training.')
        break

  autocast = torch.cuda.amp.autocast(enabled=USE_AUTOMATIC_MIXED_PRECISION, dtype=torch.half)
  scaler = torch.cuda.amp.GradScaler(enabled=USE_AUTOMATIC_MIXED_PRECISION, init_scale=4096)


##############################
Starting Training
##############################
Training length:  1678 Validation length:  297


model.safetensors:   0%|          | 0.00/32.3M [00:00<?, ?B/s]

Unfrozen layer: features_denseblock4
Unfrozen layer: features_transition3
Starting epoch 1


100%|██████████| 419/419 [08:24<00:00,  1.21s/it, loss=0.883516, lr=5.014e-09, max_ram_usage=CPU:4562 GPU:843]


train_loss:0.874785
max CPU RAM usage: 4570, max GPU RAM usage: 843


100%|██████████| 75/75 [00:14<00:00,  5.35it/s]


Validation loss: 0.745720, validation wall:0.859779
epoch:1, best loss updated from inf to 0.745720
epoch:1, best wall_metric updated from inf to 0.859779
Starting epoch 2


100%|██████████| 419/419 [08:40<00:00,  1.24s/it, loss=0.748714, lr=9.981e-09, max_ram_usage=CPU:4584 GPU:844]


train_loss:0.761642
max CPU RAM usage: 4584, max GPU RAM usage: 844


100%|██████████| 75/75 [00:12<00:00,  5.81it/s]


Validation loss: 0.755455, validation wall:0.889153
Starting epoch 3


100%|██████████| 419/419 [08:23<00:00,  1.20s/it, loss=0.968956, lr=1.495e-08, max_ram_usage=CPU:4585 GPU:844]


train_loss:0.763171
max CPU RAM usage: 4585, max GPU RAM usage: 844


100%|██████████| 75/75 [00:12<00:00,  5.82it/s]


Validation loss: 0.747127, validation wall:0.864441
Starting epoch 4


100%|██████████| 419/419 [08:26<00:00,  1.21s/it, loss=1.004040, lr=1.991e-08, max_ram_usage=CPU:4585 GPU:844]


train_loss:0.765166
max CPU RAM usage: 4585, max GPU RAM usage: 844


100%|██████████| 75/75 [00:12<00:00,  5.81it/s]


Validation loss: 0.743839, validation wall:0.865927
epoch:4, best loss updated from 0.745720 to 0.743839
Starting epoch 5


100%|██████████| 419/419 [08:27<00:00,  1.21s/it, loss=0.715825, lr=2.488e-08, max_ram_usage=CPU:4585 GPU:844]


train_loss:0.761185
max CPU RAM usage: 4585, max GPU RAM usage: 844


100%|██████████| 75/75 [00:13<00:00,  5.76it/s]


Validation loss: 0.750783, validation wall:0.873277
Starting epoch 6


100%|██████████| 419/419 [08:41<00:00,  1.25s/it, loss=0.528168, lr=2.985e-08, max_ram_usage=CPU:4585 GPU:844]


train_loss:0.764291
max CPU RAM usage: 4585, max GPU RAM usage: 844


100%|██████████| 75/75 [00:12<00:00,  5.81it/s]


Validation loss: 0.753264, validation wall:0.885000
Starting epoch 7


100%|██████████| 419/419 [08:17<00:00,  1.19s/it, loss=0.752792, lr=3.481e-08, max_ram_usage=CPU:4585 GPU:844]


train_loss:0.760335
max CPU RAM usage: 4585, max GPU RAM usage: 844


100%|██████████| 75/75 [00:12<00:00,  5.77it/s]


Validation loss: 0.763301, validation wall:0.875693
Starting epoch 8


100%|██████████| 419/419 [08:25<00:00,  1.21s/it, loss=0.621087, lr=3.978e-08, max_ram_usage=CPU:4585 GPU:844]


train_loss:0.764135
max CPU RAM usage: 4585, max GPU RAM usage: 844


100%|██████████| 75/75 [00:12<00:00,  5.82it/s]


Validation loss: 0.755734, validation wall:0.871146
Starting epoch 9


100%|██████████| 419/419 [08:44<00:00,  1.25s/it, loss=1.011812, lr=4.475e-08, max_ram_usage=CPU:4585 GPU:844]


train_loss:0.759353
max CPU RAM usage: 4585, max GPU RAM usage: 844


100%|██████████| 75/75 [00:12<00:00,  5.77it/s]


Validation loss: 0.752968, validation wall:0.866627
Starting epoch 10
Unfrozen layer: features_denseblock3


100%|██████████| 419/419 [09:16<00:00,  1.33s/it, loss=0.956790, lr=5.000e-06, max_ram_usage=CPU:4592 GPU:2662]


train_loss:0.767902
max CPU RAM usage: 4592, max GPU RAM usage: 2665


100%|██████████| 75/75 [00:12<00:00,  5.77it/s]


Validation loss: 0.771808, validation wall:0.917785
Starting epoch 11


100%|██████████| 419/419 [09:17<00:00,  1.33s/it, loss=1.103177, lr=5.000e-06, max_ram_usage=CPU:4592 GPU:2657]


train_loss:0.760465
max CPU RAM usage: 4592, max GPU RAM usage: 2665


100%|██████████| 75/75 [00:13<00:00,  5.76it/s]


Validation loss: 0.746938, validation wall:0.861647
Starting epoch 12


100%|██████████| 419/419 [09:13<00:00,  1.32s/it, loss=0.492192, lr=5.000e-06, max_ram_usage=CPU:4592 GPU:2659]


train_loss:0.754141
max CPU RAM usage: 4592, max GPU RAM usage: 2665


100%|██████████| 75/75 [00:12<00:00,  5.78it/s]


Validation loss: 0.739221, validation wall:0.870628
epoch:12, best loss updated from 0.743839 to 0.739221
Starting epoch 13


100%|██████████| 419/419 [09:11<00:00,  1.32s/it, loss=0.634630, lr=5.000e-06, max_ram_usage=CPU:4592 GPU:2658]


train_loss:0.756724
max CPU RAM usage: 4592, max GPU RAM usage: 2665


100%|██████████| 75/75 [00:12<00:00,  5.77it/s]


Validation loss: 0.742862, validation wall:0.872121
Starting epoch 14


100%|██████████| 419/419 [09:10<00:00,  1.31s/it, loss=0.782513, lr=5.000e-06, max_ram_usage=CPU:4592 GPU:2660]


train_loss:0.753497
max CPU RAM usage: 4592, max GPU RAM usage: 2665


100%|██████████| 75/75 [00:12<00:00,  5.80it/s]


Validation loss: 0.738924, validation wall:0.863090
epoch:14, best loss updated from 0.739221 to 0.738924
Starting epoch 15


100%|██████████| 419/419 [09:14<00:00,  1.32s/it, loss=0.514151, lr=5.000e-06, max_ram_usage=CPU:4592 GPU:2665]


train_loss:0.749649
max CPU RAM usage: 4592, max GPU RAM usage: 2665


100%|██████████| 75/75 [00:12<00:00,  5.78it/s]


Validation loss: 0.741401, validation wall:0.859196
epoch:15, best wall_metric updated from 0.859779 to 0.859196
Starting epoch 16


100%|██████████| 419/419 [09:20<00:00,  1.34s/it, loss=0.651989, lr=5.000e-06, max_ram_usage=CPU:4594 GPU:2656]


train_loss:0.750850
max CPU RAM usage: 4594, max GPU RAM usage: 2665


100%|██████████| 75/75 [00:12<00:00,  5.78it/s]


Validation loss: 0.755044, validation wall:0.903780
Starting epoch 17


100%|██████████| 419/419 [09:20<00:00,  1.34s/it, loss=1.298919, lr=5.000e-06, max_ram_usage=CPU:4594 GPU:2660]


train_loss:0.762770
max CPU RAM usage: 4594, max GPU RAM usage: 2666


100%|██████████| 75/75 [00:12<00:00,  5.78it/s]


Validation loss: 0.743100, validation wall:0.872114
Starting epoch 18


100%|██████████| 419/419 [09:16<00:00,  1.33s/it, loss=0.852115, lr=5.000e-06, max_ram_usage=CPU:4594 GPU:2657]


train_loss:0.749955
max CPU RAM usage: 4594, max GPU RAM usage: 2666


100%|██████████| 75/75 [00:13<00:00,  5.77it/s]


Validation loss: 0.745258, validation wall:0.882182
Starting epoch 19


100%|██████████| 419/419 [09:28<00:00,  1.36s/it, loss=0.537852, lr=5.000e-06, max_ram_usage=CPU:4594 GPU:2657]


train_loss:0.748433
max CPU RAM usage: 4594, max GPU RAM usage: 2666


100%|██████████| 75/75 [00:12<00:00,  5.77it/s]


Validation loss: 0.735408, validation wall:0.855303
epoch:19, best loss updated from 0.738924 to 0.735408
epoch:19, best wall_metric updated from 0.859196 to 0.855303
Starting epoch 20
Unfrozen layer: features_transition2


100%|██████████| 419/419 [09:20<00:00,  1.34s/it, loss=0.579649, lr=5.000e-06, max_ram_usage=CPU:4592 GPU:2895]


train_loss:0.752727
max CPU RAM usage: 4594, max GPU RAM usage: 2895


100%|██████████| 75/75 [00:12<00:00,  5.81it/s]


Validation loss: 0.732983, validation wall:0.854072
epoch:20, best loss updated from 0.735408 to 0.732983
epoch:20, best wall_metric updated from 0.855303 to 0.854072
Starting epoch 21


100%|██████████| 419/419 [09:22<00:00,  1.34s/it, loss=0.694799, lr=5.000e-06, max_ram_usage=CPU:4595 GPU:2892]


train_loss:0.748324
max CPU RAM usage: 4595, max GPU RAM usage: 2895


100%|██████████| 75/75 [00:12<00:00,  5.79it/s]


Validation loss: 0.732453, validation wall:0.851338
epoch:21, best loss updated from 0.732983 to 0.732453
epoch:21, best wall_metric updated from 0.854072 to 0.851338
Starting epoch 22


100%|██████████| 419/419 [09:30<00:00,  1.36s/it, loss=1.143981, lr=5.000e-06, max_ram_usage=CPU:4592 GPU:2895]


train_loss:0.745383
max CPU RAM usage: 4640, max GPU RAM usage: 2895


100%|██████████| 75/75 [00:12<00:00,  5.79it/s]


Validation loss: 0.732243, validation wall:0.854292
epoch:22, best loss updated from 0.732453 to 0.732243
Starting epoch 23


100%|██████████| 419/419 [09:23<00:00,  1.34s/it, loss=0.778565, lr=5.000e-06, max_ram_usage=CPU:4591 GPU:2894]


train_loss:0.746266
max CPU RAM usage: 4640, max GPU RAM usage: 2896


100%|██████████| 75/75 [00:13<00:00,  5.76it/s]


Validation loss: 0.741625, validation wall:0.878360
Starting epoch 24


100%|██████████| 419/419 [09:22<00:00,  1.34s/it, loss=0.669409, lr=5.000e-06, max_ram_usage=CPU:4592 GPU:2888]


train_loss:0.748699
max CPU RAM usage: 4640, max GPU RAM usage: 2896


100%|██████████| 75/75 [00:13<00:00,  5.76it/s]


Validation loss: 0.737470, validation wall:0.854423
Starting epoch 25


100%|██████████| 419/419 [09:24<00:00,  1.35s/it, loss=0.742395, lr=5.000e-06, max_ram_usage=CPU:4591 GPU:2892]


train_loss:0.748191
max CPU RAM usage: 4640, max GPU RAM usage: 2896


100%|██████████| 75/75 [00:12<00:00,  5.78it/s]


Validation loss: 0.742914, validation wall:0.852151
Starting epoch 26


100%|██████████| 419/419 [09:28<00:00,  1.36s/it, loss=0.906344, lr=5.000e-06, max_ram_usage=CPU:4591 GPU:2892]


train_loss:0.745003
max CPU RAM usage: 4640, max GPU RAM usage: 2896


100%|██████████| 75/75 [00:12<00:00,  5.78it/s]


Validation loss: 0.738639, validation wall:0.865132
Starting epoch 27


100%|██████████| 419/419 [09:32<00:00,  1.37s/it, loss=0.723347, lr=5.000e-06, max_ram_usage=CPU:4592 GPU:2891]


train_loss:0.744401
max CPU RAM usage: 4640, max GPU RAM usage: 2896


100%|██████████| 75/75 [00:12<00:00,  5.80it/s]


Validation loss: 0.735670, validation wall:0.852354
Starting epoch 28


100%|██████████| 419/419 [09:24<00:00,  1.35s/it, loss=0.550682, lr=5.000e-06, max_ram_usage=CPU:4592 GPU:2892]


train_loss:0.745636
max CPU RAM usage: 4640, max GPU RAM usage: 2896


100%|██████████| 75/75 [00:13<00:00,  5.76it/s]


Validation loss: 0.730743, validation wall:0.849712
epoch:28, best loss updated from 0.732243 to 0.730743
epoch:28, best wall_metric updated from 0.851338 to 0.849712
Starting epoch 29


100%|██████████| 419/419 [09:28<00:00,  1.36s/it, loss=0.542613, lr=5.000e-06, max_ram_usage=CPU:4595 GPU:2887]


train_loss:0.744655
max CPU RAM usage: 4640, max GPU RAM usage: 2896


100%|██████████| 75/75 [00:12<00:00,  5.77it/s]


Validation loss: 0.743927, validation wall:0.873017
Starting epoch 30
Unfrozen layer: features_denseblock2


100%|██████████| 419/419 [10:35<00:00,  1.52s/it, loss=0.601260, lr=5.000e-06, max_ram_usage=CPU:4598 GPU:4763]


train_loss:0.749352
max CPU RAM usage: 4640, max GPU RAM usage: 4763


100%|██████████| 75/75 [00:12<00:00,  5.79it/s]


Validation loss: 0.735465, validation wall:0.869267
Starting epoch 31


100%|██████████| 419/419 [10:32<00:00,  1.51s/it, loss=0.772769, lr=5.000e-06, max_ram_usage=CPU:4598 GPU:4765]


train_loss:0.744308
max CPU RAM usage: 4640, max GPU RAM usage: 4765


100%|██████████| 75/75 [00:12<00:00,  5.78it/s]


Validation loss: 0.737538, validation wall:0.859309
Starting epoch 32


100%|██████████| 419/419 [10:32<00:00,  1.51s/it, loss=0.605375, lr=5.000e-06, max_ram_usage=CPU:4598 GPU:4766]


train_loss:0.743937
max CPU RAM usage: 4640, max GPU RAM usage: 4766


100%|██████████| 75/75 [00:13<00:00,  5.77it/s]


Validation loss: 0.738773, validation wall:0.856385
Starting epoch 33


100%|██████████| 419/419 [10:36<00:00,  1.52s/it, loss=1.211774, lr=5.000e-06, max_ram_usage=CPU:4598 GPU:4767]


train_loss:0.742407
max CPU RAM usage: 4640, max GPU RAM usage: 4767


100%|██████████| 75/75 [00:13<00:00,  5.75it/s]


Validation loss: 0.734369, validation wall:0.864106
Starting epoch 34


100%|██████████| 419/419 [10:33<00:00,  1.51s/it, loss=0.738635, lr=5.000e-06, max_ram_usage=CPU:4598 GPU:4768]


train_loss:0.743080
max CPU RAM usage: 4640, max GPU RAM usage: 4768


100%|██████████| 75/75 [00:12<00:00,  5.78it/s]


Validation loss: 0.735711, validation wall:0.866466
Starting epoch 35


  3%|▎         | 14/419 [00:29<09:45,  1.45s/it, loss=0.841403, lr=5.000e-06, max_ram_usage=CPU:4598 GPU:4676]

NaN loss encountered, skipping update


100%|██████████| 419/419 [10:38<00:00,  1.52s/it, loss=0.711236, lr=5.000e-06, max_ram_usage=CPU:4598 GPU:4763]


train_loss:0.746790
max CPU RAM usage: 4640, max GPU RAM usage: 4772


100%|██████████| 75/75 [00:12<00:00,  5.77it/s]


Validation loss: 0.735455, validation wall:0.866101
Starting epoch 36


100%|██████████| 419/419 [10:34<00:00,  1.52s/it, loss=0.830175, lr=5.000e-06, max_ram_usage=CPU:4598 GPU:4763]


train_loss:0.744409
max CPU RAM usage: 4640, max GPU RAM usage: 4772


100%|██████████| 75/75 [00:12<00:00,  5.80it/s]


Validation loss: 0.735311, validation wall:0.863491
Starting epoch 37


100%|██████████| 419/419 [10:33<00:00,  1.51s/it, loss=1.217796, lr=5.000e-06, max_ram_usage=CPU:4598 GPU:4763]


train_loss:0.744406
max CPU RAM usage: 4640, max GPU RAM usage: 4772


100%|██████████| 75/75 [00:12<00:00,  5.81it/s]


Validation loss: 0.738601, validation wall:0.874054
Starting epoch 38


100%|██████████| 419/419 [10:30<00:00,  1.50s/it, loss=0.661406, lr=5.000e-06, max_ram_usage=CPU:4598 GPU:4766]


train_loss:0.744076
max CPU RAM usage: 4640, max GPU RAM usage: 4772


100%|██████████| 75/75 [00:12<00:00,  5.79it/s]


Validation loss: 0.734342, validation wall:0.863219
Early stopping
