In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.amp import autocast, GradScaler
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import numpy as np
from tqdm import tqdm

import matplotlib.pyplot as plt
import seaborn as sns
import torchvision.transforms.v2 as transforms
from PIL import Image
import pandas as pd

In [2]:
from custom_dataloaders import *

from models import *

from custom_logging import *

from mean_teacher import *

In [3]:
class Config:
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    num_classes = 3
    initial_lr = 1e-4
    lr_backbone = 3e-5
    batch_size = 32
    num_epochs = 5
    freeze_until_epoch = 0
    checkpoint_path = "./checkpoints"
    log_interval = 10
    early_stopping_patience = 3
    use_amp = True
    consistency_weight = 0.5
    ema_decay = 0.99
    warmup_steps = 500
    
    # Create directories
    os.makedirs(checkpoint_path, exist_ok=True)

config = Config()

In [4]:
logger = Logger(log_dir="./logs", experiment_name=f"resnet152_mean_teacher_{datetime.now().strftime('%Y%m%d_%H%M%S')}")
logger.log_config(config)

print(f"Initializing training on {config.device}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name()}")
    print(f"CUDA memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

Logging to: ./logs\resnet152_mean_teacher_20250803_164505
Initializing training on cuda
CUDA available: True
CUDA device: NVIDIA GeForce RTX 3090
CUDA memory: 25.8 GB


## data transforms

In [5]:
weak_transform = transforms.Compose([
    # transforms.ToImage(), transforms.ToDtype(torch.float32, scale=True), # move this out
    
    transforms.RandomRotation(degrees=(-5, 5), interpolation=transforms.InterpolationMode.BILINEAR, expand=True, fill=0),
    transforms.Resize(size=256, interpolation=transforms.InterpolationMode.BILINEAR, antialias=True), # Resize maintaining aspect ratio, then pad to square
    transforms.RandomCrop(224), 

    transforms.ColorJitter(brightness=0.05, contrast=0.05, saturation=0.05, hue=0.02),
    transforms.Lambda(lambda x: x + torch.randn_like(x) * 0.005), transforms.Lambda(lambda x: torch.clamp(x, 0, 1)),

    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

strong_transform = transforms.Compose([
    # transforms.ToImage(), transforms.ToDtype(torch.float32, scale=True),
    
    transforms.RandomRotation(degrees=(-15, 15), interpolation=transforms.InterpolationMode.BILINEAR, expand=True, fill=0),
    transforms.Resize(size=256, interpolation=transforms.InterpolationMode.BILINEAR, antialias=True), # Resize maintaining aspect ratio, then pad to square
    transforms.RandomCrop(224), 

    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.05),
    transforms.Lambda(lambda x: x + torch.randn_like(x) * 0.01), transforms.Lambda(lambda x: torch.clamp(x, 0, 1)),
    
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

val_transform = transforms.Compose([
    transforms.ToImage(), transforms.ToDtype(torch.float32, scale=True),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    
    transforms.Resize(size=256, interpolation=transforms.InterpolationMode.BILINEAR, antialias=True), # Resize maintaining aspect ratio, then pad to square
    transforms.CenterCrop(224), 
])


In [None]:
train_dataset = StandardImageDataset(
    csv_file="train.csv", 
    root_dir=".\\data\\train",
    weak_transform=weak_transform, 
    strong_transform=strong_transform,
    device= torch.device('cuda'), # set to CPU if a batch of 64 takes longer than 5.5s on the model
    # currently it costs around 1s of GPU time to do transforms on the GPU (probably a bit less, if scheudle optimizations are happening)
)

val_dataset = ImageDataset(
    csv_file="val.csv", 
    root_dir=".\\data\\val", 
    transform=val_transform
)


# Wrap with CUDA prefetcher
train_loader = DataLoader(train_dataset, config.batch_size, True)
val_loader = DataLoader(val_dataset, config.batch_size, True)

print(f"Train batches: {len(train_loader)}")
print(f"Val batches: {len(val_loader)}")

# TODO: these batch_sizes aren't getting loaded from checkpoint. 


Creating optimized datasets...
Train batches: 7823
Val batches: 978


In [7]:

# Initialize models
print("\nInitializing models...")
model = ResNet152Classifier(num_classes=config.num_classes).to(config.device)
teacher_model = copy.deepcopy(model).to(config.device)
for p in teacher_model.parameters():
    p.requires_grad = False

# Optimizer setup
backbone_params = []
fc_params = []

for name, param in model.backbone.named_parameters():
    if name.startswith("fc."):
        fc_params.append(param)
    elif name.startswith("conv1."):
        param.requires_grad = False
    else:
        backbone_params.append(param)

optimizer = torch.optim.AdamW([
    {'params': backbone_params, 'lr': config.lr_backbone, 'weight_decay': 1e-4},
    {'params': fc_params, 'lr': config.initial_lr, 'weight_decay': 1e-4},
])

scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2)

# Training setup
criterion = nn.CrossEntropyLoss().to(config.device)
scaler = GradScaler() if config.use_amp else None
monitor = PerformanceMonitor()




Initializing models...


In [None]:
def load_checkpoint(checkpoint_path, config, device='cuda'):
    """
    Load a checkpoint and restore training state
    
    Args:
        checkpoint_path: Path to the .pth file
        config: Configuration object (will be updated with checkpoint config)
        device: Device to load the model on
    
    Returns:
        Dictionary with loaded components
    """
    print(f"Loading checkpoint from: {checkpoint_path}")
    
    # Load checkpoint
    checkpoint = torch.load(checkpoint_path, map_location=device)
    
    # Update config with checkpoint settings (optional)
    if 'config' in checkpoint:
        checkpoint_config = checkpoint['config']
        print("Checkpoint configuration:")
        for key, value in checkpoint_config.items():
            print(f"  {key}: {value}")
            # Optionally update current config
            # setattr(config, key, value)
    
    # Initialize models with same architecture
    model = ResNet152Classifier(num_classes=config.num_classes).to(device)
    teacher_model = ResNet152Classifier(num_classes=config.num_classes).to(device)
    
    # Load model weights
    if 'model_state_dict' in checkpoint:
        model.load_state_dict(checkpoint['model_state_dict'])
        print("Loaded student model weights")
    
    if 'teacher_state_dict' in checkpoint:
        teacher_model.load_state_dict(checkpoint['teacher_state_dict'])
        print("Loaded teacher model weights")
    
    # Make teacher model non-trainable
    for p in teacher_model.parameters():
        p.requires_grad = False
    
    # Initialize optimizer (same structure as training)
    backbone_params = []
    fc_params = []
    
    for name, param in model.backbone.named_parameters():
        if name.startswith("fc."):
            fc_params.append(param)
        elif name.startswith("conv1."):
            param.requires_grad = False
        else:
            backbone_params.append(param)
    
    optimizer = torch.optim.AdamW([
        {'params': backbone_params, 'lr': config.lr_backbone, 'weight_decay': 1e-4},
        {'params': fc_params, 'lr': config.initial_lr, 'weight_decay': 1e-4},
    ])
    
    # Load optimizer state
    if 'optimizer_state_dict' in checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        print("Loaded optimizer state")
    
    # Initialize scheduler
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2)
    
    # Load scheduler state
    if 'scheduler_state_dict' in checkpoint:
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        print("Loaded scheduler state")
    
    # Get training progress info
    start_epoch = checkpoint.get('epoch', -1) + 1  # Resume from next epoch
    best_accuracy = checkpoint.get('val_accuracy', 0)
    
    print(f"\nCheckpoint info:")
    print(f"  Saved at epoch: {checkpoint.get('epoch', 'unknown')}")
    print(f"  Best validation accuracy: {best_accuracy:.4f}")
    print(f"  Will resume from epoch: {start_epoch}")
    
    return {
        'model': model,
        'teacher_model': teacher_model,
        'optimizer': optimizer,
        'scheduler': scheduler,
        'start_epoch': start_epoch,
        'best_accuracy': best_accuracy,
        'checkpoint': checkpoint
    }




In [None]:
# IF WE ARE RESUMING:

loaded = load_checkpoint("checkpoints/best_model_acc0.7807.pth", config)

model = loaded['model']
teacher_model = loaded['teacher_model']
optimizer = loaded['optimizer']
scheduler = loaded['scheduler']
start_epoch = loaded['start_epoch']
best_accuracy = loaded['best_accuracy']

monitor.best_accuracy = best_accuracy


Loading checkpoint from: checkpoints/best_model_acc0.7807.pth
Checkpoint configuration:
  freeze_until_epoch: 1
  num_epochs: 2
Loaded student model weights
Loaded teacher model weights
Loaded optimizer state
Loaded scheduler state

Checkpoint info:
  Saved at epoch: 1
  Best validation accuracy: 0.7807
  Will resume from epoch: 2


In [18]:
# # run the one off code shit here
config.num_epochs = 10
# config.batch_size=
start_epoch=5

## main training loop

In [8]:

def validate(model, val_loader, config, monitor, epoch):
    """Validate model and compute metrics"""
    model.eval()
    total, correct = 0, 0
    class_correct = [0] * config.num_classes
    class_total = [0] * config.num_classes
    all_preds, all_labels = [], []
    val_loss = 0
    
    # Handle both regular DataLoader and DataPrefetcher
    if isinstance(val_loader, DataPrefetcher):
        total_batches = len(val_loader.loader)
    else:
        total_batches = len(val_loader)
    
    with torch.no_grad():
        val_pbar = tqdm(val_loader, desc=f'Validation Epoch {epoch}', total=total_batches, leave=False)
        for batch_data in val_pbar:
            
            x, y = batch_data
            x = x.to(config.device)
            y = y.to(config.device)
            
            outputs = model(x)
            loss = F.cross_entropy(outputs, y)
            val_loss += loss.item()
            
            preds = outputs.argmax(dim=1)
            preds_cpu = preds.cpu().numpy()
            y_cpu = y.cpu().numpy()
            all_preds.extend(preds_cpu)
            all_labels.extend(y_cpu)
            
            total += y.size(0)
            correct += (preds == y).sum().item()
            
            for i, label in enumerate(y_cpu):
                class_total[label] += 1
                class_correct[label] += (preds_cpu[i] == label).item()
            
            val_pbar.set_postfix({'loss': f'{loss.item():.4f}', 
                                  'acc': f'{correct/total:.4f}'})

    accuracy = correct / total 
    avg_val_loss = val_loss / total_batches
    class_names = ['bad', 'neutral', 'good']
    
    # Calculate per-class metrics
    class_recall = [class_correct[i] / class_total[i] if class_total[i] else 0 for i in range(config.num_classes)]
    
    # Calculate confusion matrix and classification report
    cm = confusion_matrix(all_labels, all_preds)
    report = classification_report(all_labels, all_preds, target_names=class_names, output_dict=True)
    
    # Log metrics
    metrics = {
        "val/accuracy": accuracy,
        "val/loss": avg_val_loss,
        "val/recall_bad": class_recall[0],
        "val/recall_neutral": class_recall[1],
        "val/recall_good": class_recall[2],
    }
    
    # Add precision and F1 scores if available
    for class_name in class_names:
        if class_name in report:
            metrics[f"val/precision_{class_name}"] = report[class_name]['precision']
            metrics[f"val/f1_{class_name}"] = report[class_name]['f1-score']
    
    # Update monitor
    monitor.accuracy_history.append(accuracy)
    
    if accuracy > monitor.best_accuracy:
        monitor.best_accuracy = accuracy
        monitor.epochs_without_improvement = 0
    else:
        monitor.epochs_without_improvement += 1
    
    model.train()
    return accuracy, metrics, cm


In [12]:
# Training loop
print("\nStarting training...")
print(f"Tensorboard logs: tensorboard --logdir {logger.get_log_dir()}")
global_step = start_epoch * len(train_loader)
best_accuracy = 0 
data_loading_times = []
gpu_compute_times = []


Starting training...
Tensorboard logs: tensorboard --logdir ./logs\resnet152_mean_teacher_20250803_164505


In [20]:
print("training epochs: ", list(range(start_epoch, config.num_epochs)))


training epochs:  [5, 6, 7, 8, 9]


In [21]:

for epoch in range(start_epoch, config.num_epochs):
    # Unfreeze backbone if needed
    if epoch == config.freeze_until_epoch:
        print(f"Unfreezing backbone at epoch {epoch}")
        for param in backbone_params:
            param.requires_grad = True
    
    model.train()
    epoch_start = time.time()
    epoch_loss = 0
    epoch_cls_loss = 0
    epoch_consistency_loss = 0

    # Training epoch
    epoch_pbar = tqdm(train_loader, desc=f'Epoch {epoch}', total=len(train_loader))
    
    for batch_idx, batch_data in enumerate(epoch_pbar):
        batch_start = time.time()
        
        x_weak, x_strong, y = batch_data
        data_load_time = time.time() - batch_start
        
        compute_start = time.time()
        
        # Mixed precision training
        if config.use_amp:
            with autocast(device_type='cuda'):
                student_outputs = model(x_strong)
                cls_loss = criterion(student_outputs, y.to(config.device))
                
                with torch.no_grad():
                    teacher_outputs = teacher_model(x_weak)
                
                consistency_weight = config.consistency_weight * min(1.0, global_step / config.warmup_steps)
                consistency = compute_consistency_loss(student_outputs, teacher_outputs)
                loss = cls_loss + consistency_weight * consistency
            
            optimizer.zero_grad()
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()
        else:
            student_outputs = model(x_strong)
            cls_loss = criterion(student_outputs, y)
            
            with torch.no_grad():
                teacher_outputs = teacher_model(x_weak)
            
            consistency_weight = config.consistency_weight * min(1.0, global_step / config.warmup_steps)
            consistency = compute_consistency_loss(student_outputs, teacher_outputs)
            loss = cls_loss + consistency_weight * consistency
            
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
        
        # Update teacher model
        update_ema_variables(model, teacher_model, alpha=config.ema_decay, global_step=global_step)
        
        # Timing
        gpu_compute_time = time.time() - compute_start
        total_time = time.time() - batch_start
        
        # Track times
        if len(data_loading_times) < 1000:
            data_loading_times.append(data_load_time)
            gpu_compute_times.append(gpu_compute_time)
        
        # Update epoch stats
        epoch_loss += loss.item()
        epoch_cls_loss += cls_loss.item()
        epoch_consistency_loss += consistency.item()
        
        # Calculate throughput
        images_per_second = config.batch_size / total_time
        
        # Update progress bar
        epoch_pbar.set_postfix({
            'loss': f'{loss.item():.4f}',
            'img/s': f'{images_per_second:.1f}',
            'data_ms': f'{data_load_time*1000:.1f}',
            'gpu_ms': f'{gpu_compute_time*1000:.1f}'
        })
        
        # Logging
        if global_step % config.log_interval == 0:
            avg_data_time = np.mean(data_loading_times[-100:]) if data_loading_times else 0
            avg_gpu_time = np.mean(gpu_compute_times[-100:]) if gpu_compute_times else 0
            current_lr = optimizer.param_groups[0]['lr']
            
            # Log to TensorBoard
            train_metrics = {
                "train/loss": loss.item(),
                "train/cls_loss": cls_loss.item(),
                "train/consistency": consistency.item(),
                "train/consistency_weight": consistency_weight,
                "train/learning_rate": current_lr,
                "train/images_per_second": images_per_second,
                "timing/data_load_ms": avg_data_time * 1000,
                "timing/gpu_compute_ms": avg_gpu_time * 1000,
                "timing/data_load_percentage": (avg_data_time / (avg_data_time + avg_gpu_time)) * 100 if (avg_data_time + avg_gpu_time) > 0 else 0,
                "system/gpu_memory_mb": get_gpu_memory_usage()
            }
            
            logger.log_metrics(train_metrics, global_step)
            
            # Log to CSV
            logger.log_train_step(global_step, epoch, {
                'loss': loss.item(),
                'cls_loss': cls_loss.item(),
                'consistency_loss': consistency.item(),
                'learning_rate': current_lr,
                'consistency_weight': consistency_weight
            })
            
            # System metrics
            queue_size = train_loader.dataset.get_stats().get('queue_size', 0) if hasattr(train_loader.dataset, 'get_stats') else 0
            logger.log_system_metrics(global_step, {
                'images_per_second': images_per_second,
                'data_load_ms': avg_data_time * 1000,
                'gpu_compute_ms': avg_gpu_time * 1000,
                'queue_size': queue_size,
                'gpu_memory_mb': get_gpu_memory_usage()
            })
        
        global_step += 1
            
    # End of epoch
    epoch_time = time.time() - epoch_start
    avg_epoch_loss = epoch_loss / len(train_loader) if len(train_loader) > 0 else 0
    avg_epoch_cls_loss = epoch_cls_loss / len(train_loader) if len(train_loader) > 0 else 0
    avg_epoch_consistency_loss = epoch_consistency_loss / len(train_loader) if len(train_loader) > 0 else 0
    
    print(f"\nEpoch {epoch} Summary:")
    print(f"  Time: {epoch_time/60:.1f} minutes")
    print(f"  Avg Loss: {avg_epoch_loss:.4f}")
    print(f"  Avg Classification Loss: {avg_epoch_cls_loss:.4f}")
    print(f"  Avg Consistency Loss: {avg_epoch_consistency_loss:.4f}")
    print(f"  Throughput: {len(train_dataset) / epoch_time:.1f} images/second")
    
    # Validation
    print("Running validation...")
    val_acc, val_metrics, cm = validate(teacher_model, val_loader, config, monitor, epoch)
    
    # Log validation metrics
    logger.log_metrics(val_metrics, epoch)
    logger.log_validation(epoch, val_metrics)
    logger.log_confusion_matrix(cm, ['bad', 'neutral', 'good'], epoch)
    
    print(f"  Validation Accuracy: {val_acc:.4f}")
    print(f"  Best Accuracy: {monitor.best_accuracy:.4f}")
    
    # Save checkpoint if best
    if val_acc > best_accuracy:
        best_accuracy = val_acc
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'teacher_state_dict': teacher_model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'val_accuracy': val_acc,
            'config': vars(config)
        }
        checkpoint_path = os.path.join(config.checkpoint_path, f"best_model_acc{val_acc:.4f}.pth")
        torch.save(checkpoint, checkpoint_path)
        print(f"Saved best model: {checkpoint_path}")
    
    # Learning rate scheduling
    scheduler.step()
    
    # Early stopping
    if monitor.epochs_without_improvement >= config.early_stopping_patience:
        print(f"Early stopping at epoch {epoch}")
        break

Epoch 5: 100%|██████████| 7823/7823 [1:03:20<00:00,  2.06it/s, loss=0.6607, img/s=473.4, data_ms=0.0, gpu_ms=67.6] 



Epoch 5 Summary:
  Time: 63.3 minutes
  Avg Loss: 0.3304
  Avg Classification Loss: 0.2513
  Avg Consistency Loss: 0.1581
  Throughput: 65.9 images/second
Running validation...


                                                                                              

  Validation Accuracy: 0.8222
  Best Accuracy: 0.8246


Epoch 6: 100%|██████████| 7823/7823 [1:02:40<00:00,  2.08it/s, loss=0.6658, img/s=533.0, data_ms=0.0, gpu_ms=60.0] 



Epoch 6 Summary:
  Time: 62.7 minutes
  Avg Loss: 0.3023
  Avg Classification Loss: 0.2190
  Avg Consistency Loss: 0.1665
  Throughput: 66.6 images/second
Running validation...


                                                                                              

  Validation Accuracy: 0.8273
  Best Accuracy: 0.8273
Saved best model: ./checkpoints\best_model_acc0.8273.pth


Epoch 7: 100%|██████████| 7823/7823 [1:02:37<00:00,  2.08it/s, loss=0.2742, img/s=509.4, data_ms=0.0, gpu_ms=62.8] 



Epoch 7 Summary:
  Time: 62.6 minutes
  Avg Loss: 0.2771
  Avg Classification Loss: 0.1935
  Avg Consistency Loss: 0.1672
  Throughput: 66.6 images/second
Running validation...


                                                                                              

  Validation Accuracy: 0.8328
  Best Accuracy: 0.8328
Saved best model: ./checkpoints\best_model_acc0.8328.pth


Epoch 8: 100%|██████████| 7823/7823 [1:02:38<00:00,  2.08it/s, loss=0.1759, img/s=532.2, data_ms=0.0, gpu_ms=60.1] 



Epoch 8 Summary:
  Time: 62.6 minutes
  Avg Loss: 0.2619
  Avg Classification Loss: 0.1776
  Avg Consistency Loss: 0.1685
  Throughput: 66.6 images/second
Running validation...


                                                                                              

  Validation Accuracy: 0.8272
  Best Accuracy: 0.8328


Epoch 9: 100%|██████████| 7823/7823 [1:02:38<00:00,  2.08it/s, loss=0.1418, img/s=503.0, data_ms=0.0, gpu_ms=63.6] 



Epoch 9 Summary:
  Time: 62.6 minutes
  Avg Loss: 0.2537
  Avg Classification Loss: 0.1687
  Avg Consistency Loss: 0.1700
  Throughput: 66.6 images/second
Running validation...


                                                                                              

  Validation Accuracy: 0.8227
  Best Accuracy: 0.8328




In [22]:

# Cleanup
print("\nCleaning up...")
if hasattr(train_loader.dataset, 'cleanup'):
    train_loader.dataset.cleanup()
if hasattr(train_dataset, 'cleanup'):
    train_dataset.cleanup()

# Final statistics
print("\n=== Training Complete ===")
print(f"Best validation accuracy: {best_accuracy:.4f}")
print(f"Logs saved to: {logger.get_log_dir()}")
print(f"View in TensorBoard: tensorboard --logdir {logger.get_log_dir()}")

# Close logger
logger.close()

checkpoint_path


Cleaning up...

=== Training Complete ===
Best validation accuracy: 0.8328
Logs saved to: ./logs\resnet152_mean_teacher_20250803_164505
View in TensorBoard: tensorboard --logdir ./logs\resnet152_mean_teacher_20250803_164505


'./checkpoints\\best_model_acc0.8328.pth'

In [23]:
test_dataset = ImageDataset(
    csv_file="test.csv", 
    root_dir=".\\data\\test", 
    transform=val_transform
)
test_loader = DataLoader(test_dataset, config.batch_size, False)
print(f"Train batches: {len(test_loader)}")

Train batches: 978


In [None]:
test_acc, test_metrics, cm = validate(teacher_model, test_loader, config, monitor, 11)

                                                                                               

In [None]:
inf_model = load_for_inference(checkpoint_path=".\\checkpoints\\best_model_acc0.8328.pth", device='cuda')

Loaded teacher model for inference
