## Imports

In [69]:
import argparse
from filelock import FileLock
import platform
import os
import random
import torch
import torch.nn as nn
import torch.optim as optim
import torch.multiprocessing as mp
import numpy as np
from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score, classification_report, roc_auc_score
from torch.utils.data import DataLoader

import torchvision
from torchvision.io import decode_image
from torchvision.transforms import v2 as T2
import torchvision.transforms.v2.functional as F2
from torchvision.transforms.v2.functional import InterpolationMode
from torchvision.models import list_models, get_model, get_model_weights

from sklearn.model_selection import train_test_split, GroupShuffleSplit
from collections import defaultdict
from PIL import Image
from torch.amp import GradScaler, autocast
from concurrent.futures import ThreadPoolExecutor
from itertools import islice
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from math import pi
from datetime import datetime
import time
import hashlib
import json

import torch.backends.cudnn as cudnn
import warnings

In [70]:
# Let cuDNN benchmark candidate convolution algorithms and cache the fastest one;
# this pays off when input shapes stay the same from batch to batch.
torch.backends.cudnn.benchmark = True

In [71]:
# Silence UserWarnings emitted from Pillow's PNG plugin module only.
warnings.filterwarnings(action="ignore", category=UserWarning, module="PIL.PngImagePlugin")

## Global variables

In [72]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

use_amp = True
scaler = torch.amp.GradScaler(enabled=use_amp)

script_dir = os.getcwd()
machine_name = platform.node()
user = os.getenv("USER") 
if user == "jon":
    model_names = ['densenet121', 'densenet161', 'densenet169', 'densenet201', 'alexnet']

    if "dataset" not in os.listdir():
        script_dir = "/mnt/b/Xray"

elif user == "jonal":
    model_names = ['resnet50', 'resnet101', 'resnet152', 'resnext101_32x8d', 'mobilenet_v3_large', 'googlenet', 'inception_v3']

else:
    model_names = ['mobilenet_v3_large', 'googlenet', 'inception_v3', 'alexnet','convnext_base', 'convnext_large','vit_b_16', 'swin_b','vgg16', 'vgg19'] 
           
#'efficientnet_b4', 'efficientnet_b5', 'efficientnet_b6', 'efficientnet_b7',

print(f"script_dir: {script_dir}")

script_dir: /home/jon/projects/Xrays


In [73]:
image_size = 224
batch_size = 128
# Use a third of the CPU cores for data loading (at least one worker). os.cpu_count()
# can return None when the count is undetermined, so fall back to 1 in that case.
num_workers = max(1, (os.cpu_count() or 1) // 3)
prefetch_factor = 3
enable_cache = True
rebuild_cache = False
num_train_images = None  # None = use every training image
num_test_images = None  # None = use every test image
checkpoint_interval = 10
num_epochs = 1

In [74]:
runs_per_model = 5
# Run locks older than this many seconds are treated as stale: 24 hours
# (use 7 * 24 * 60 * 60 for one week).
lock_timeout = 24 * 60 * 60

In [75]:
timestamp = f"{datetime.now():%Y%m%d}"  # YYYYMMDD stamp used in output file names

In [76]:
# Input data: per-split image folders suffixed with the configured image size, plus the label CSV.
train_dir = f"{script_dir}/dataset/data/train_{image_size}"
test_dir = f"{script_dir}/dataset/data/test_{image_size}"
labels_file = f"{script_dir}/dataset/Data_Entry_2017_v2020.csv"

# Output folders. NOTE: locks_dir is relative to the current working directory,
# not to script_dir like the others.
models_dir = f"{script_dir}/models"
results_dir = f"{script_dir}/results"
locks_dir = "locks"
for _directory in (models_dir, results_dir, locks_dir):
    os.makedirs(_directory, exist_ok=True)
del _directory

# Result files are per machine, image size and day.
detailed_results_path = f"{results_dir}/detailed_model_results_{machine_name}_{image_size}_{timestamp}.csv"
summary_results_path = f"{results_dir}/summary_model_results_{machine_name}_{image_size}_{timestamp}.csv"

## Classes

In [77]:
class ChestXray14Dataset(torch.utils.data.Dataset):
    """
    Dataset that decodes chest X-ray image files and pairs them with multi-label targets.

    Each item is a dict ``{"img": image, "lab": label_vector}``, where ``image`` is the
    decoded (and optionally transformed) image and ``label_vector`` is a float tensor
    with one entry per pathology.
    """

    def __init__(self, data: list, label_mapping: dict, pathologies: list, transform: callable = None):
        """
        Args:
            data (list): List of (patient_id, image_path, image_name) triples,
                as produced by ``load_dataset``.
            label_mapping (dict): Maps image file names to label vectors whose
                entries are expected to be ordered like ``pathologies``.
            pathologies (list): Pathology names; the returned label vector keeps
                one entry per name.
            transform (callable, optional): Transformation applied to each decoded image.
        """
        self.data = data
        self.label_mapping = label_mapping
        self.pathologies = pathologies
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def _label_vector(self, image_name):
        """Return the float label tensor of ``image_name``, one entry per pathology."""
        labels = self.label_mapping[image_name]
        # Entry i of the stored vector belongs to pathology i, so the aligned vector is
        # simply its first len(pathologies) entries (no per-pathology index() search).
        return torch.tensor(labels[: len(self.pathologies)], dtype=torch.float)

    def __getitem__(self, idx):
        _patient_id, image_path, image_name = self.data[idx]

        # NOTE(review): decode_image keeps the file's own channel count; the 3-value
        # Normalize in get_data_transforms assumes 3 channels — confirm the PNGs are RGB.
        image = decode_image(image_path)
        if self.transform:
            image = self.transform(image)

        return {"img": image, "lab": self._label_vector(image_name)}

## Functions

In [78]:
def batch_iterable(iterable: iter, batch_size: int) -> iter:
    """
    Lazily split ``iterable`` into consecutive lists of at most ``batch_size`` items.

    The last list may be shorter; nothing is yielded for an empty iterable.
    """
    source = iter(iterable)
    while chunk := list(islice(source, batch_size)):
        yield chunk

In [79]:
def preprocess_and_save(dataset: list, transform: callable, cache_dir: str, num_workers: int = 1, batch_size: int = 32, enable_cache: bool = True, rebuild_cache: bool = False) -> list:
    """
    Preprocess dataset images in batches, with optional caching and multiprocessing.

    Args:
        dataset (list): Records whose first two fields are (patient_id, image_path);
            extra fields (e.g. the image name added by load_dataset) are ignored.
        transform (callable): Transformations to apply to each RGB PIL image.
        cache_dir (str): Directory receiving one ``<image file name>.pt`` per image.
        num_workers (int): Number of worker processes (values below 1 are treated as 1).
        batch_size (int): Number of records handed to a worker at a time.
        enable_cache (bool): If False, transform every image in this process and return
            the tensors instead of cache paths.
        rebuild_cache (bool): If True, delete existing cache files and regenerate them.

    Returns:
        list: With caching, (patient_id, cache_path) pairs in the order worker batches
        finish (not necessarily input order); images that fail to load or transform are
        reported and left out. Without caching, (patient_id, transformed_image) pairs in
        input order.
    """

    def _load_rgb(image_path):
        # Open, convert and release the file handle.
        with Image.open(image_path) as img:
            return img.convert("RGB")

    if not enable_cache:
        print("Caching is disabled. Processing images in memory.")
        return [(item[0], transform(_load_rgb(item[1]))) for item in dataset]

    print("\nBuilding cache...")
    os.makedirs(cache_dir, exist_ok=True)
    if rebuild_cache:
        print(f"Rebuilding cache. Clearing directory: {cache_dir}")
        for entry in os.listdir(cache_dir):
            entry_path = os.path.join(cache_dir, entry)
            # Only plain files are cache entries; os.remove would fail on a directory.
            if os.path.isfile(entry_path):
                os.remove(entry_path)

    # With zero workers nobody would answer the collection loop below and it would block forever.
    num_workers = max(1, num_workers)

    def process_batch(batch):
        results = []
        for item in batch:
            patient_id, image_path = item[0], item[1]
            cache_path = os.path.join(cache_dir, f"{os.path.basename(image_path)}.pt")
            if rebuild_cache or not os.path.exists(cache_path):
                try:
                    torch.save(transform(_load_rgb(image_path)), cache_path)
                except Exception as e:
                    # Leave the record out: its cache file was never written.
                    print(f"Error processing {image_path}: {e}")
                    continue
            results.append((patient_id, cache_path))
        return results

    def worker(input_queue, output_queue):
        # Process batches until the None sentinel arrives.
        while True:
            batch = input_queue.get()
            if batch is None:
                break
            output_queue.put(process_batch(batch))

    # NOTE(review): `worker` is a nested function, so it can only be handed to a child
    # process under the "fork" start method; "spawn"/"forkserver" cannot pickle it.
    input_queue = mp.Queue()
    output_queue = mp.Queue()
    workers = []

    try:
        for i in range(num_workers):
            print(f"Starting worker process {i+1}/{num_workers}", end="\r")
            process = mp.Process(target=worker, args=(input_queue, output_queue))
            process.start()
            workers.append(process)
        print()

        # Divide dataset into batches and add to queue
        total_batches = (len(dataset) + batch_size - 1) // batch_size
        for i, batch in enumerate(batch_iterable(dataset, batch_size)):
            print(f"Adding batches to queue: {i+1}/{total_batches}", end="\r")
            input_queue.put(batch)
        print()

        # One sentinel per worker so every worker terminates.
        for _ in range(num_workers):
            input_queue.put(None)

        # Drain all results before joining: joining a process that still has queued
        # output can deadlock.
        preprocessed_dataset = []
        start_time = time.time()
        for i in range(total_batches):
            preprocessed_dataset.extend(output_queue.get())

            batches_processed = i + 1
            elapsed_time = time.time() - start_time
            remaining_time = elapsed_time / batches_processed * (total_batches - batches_processed)
            eta = time.strftime('%H:%M:%S', time.gmtime(remaining_time))
            print(f"Collecting results: {batches_processed}/{total_batches}, ETA: {eta}", end="\r")
        print()

        for process in workers:
            process.join()
    finally:
        # If anything above raised (including KeyboardInterrupt), don't leave orphaned workers.
        for process in workers:
            if process.is_alive():
                process.terminate()

    print(f"Preprocessing complete. Total processed items: {len(preprocessed_dataset)}")
    return preprocessed_dataset

In [80]:
def load_dataset(directory: str, max_total_images: int = None, random_selection: bool = False, seed: int = None) -> list:
    """
    Collect the ``.png`` files in ``directory`` as (patient_id, image_path, image_name) triples.

    The patient id is the part of the file name before the first underscore. Files are
    read in sorted name order and returned grouped by patient.

    Args:
        directory (str): Folder containing the image files.
        max_total_images (int, optional): Upper bound on the total number of returned
            images across all patients (None = no limit).
        random_selection (bool): When True and ``seed`` is given, seeds Python's global
            ``random`` module. No random sampling is performed by this function.
        seed (int, optional): Seed used together with ``random_selection``.

    Returns:
        list: (patient_id, image_path, image_name) triples.
    """
    if random_selection and seed is not None:
        random.seed(seed)

    patient_images = defaultdict(list)
    for filename in sorted(os.listdir(directory)):
        if filename.endswith(".png"):
            patient_id = filename.split("_")[0]
            patient_images[patient_id].append(os.path.join(directory, filename))

    selected_images = []
    for patient_id, images in patient_images.items():
        for image in images:
            # The cap applies to the total across all patients, so stop the whole scan
            # (not just this patient) as soon as it is reached.
            if max_total_images is not None and len(selected_images) >= max_total_images:
                return selected_images
            selected_images.append((patient_id, image, os.path.basename(image)))

    return selected_images

In [81]:
def load_labels(csv_path: str, conditions: list) -> dict:
    """
    Build a 0/1 label vector for every image listed in the label CSV.

    Args:
        csv_path (str): CSV with an ``Image Index`` column and a ``|``-separated
            ``Finding Labels`` column.
        conditions (list): Condition names; they fix the order of the vector entries.

    Returns:
        dict: Image index -> list of 0/1 ints, one per entry of ``conditions``. A repeated
        image index keeps the vector of its last row.
    """
    table = pd.read_csv(csv_path)
    mapping = {}
    for image_index, finding_labels in zip(table['Image Index'], table['Finding Labels']):
        present = finding_labels.split('|')
        mapping[image_index] = [int(condition in present) for condition in conditions]
    return mapping

In [82]:
def get_data_transforms(image_size: int=512, mean: tuple=(0.485, 0.456, 0.406), std: tuple=(0.229, 0.224, 0.225), crop_size: int = 224) -> dict:
    """
    Build the torchvision v2 pipelines for training ("train") and evaluation ("val").

    Both pipelines operate on image tensors (ChestXray14Dataset feeds them the output
    of ``decode_image``), convert them to float32 in [0, 1] and normalize them.

    Args:
        image_size (int): Side length the "val" pipeline resizes to before center-cropping.
        mean (sequence of float): Per-channel normalization mean.
        std (sequence of float): Per-channel normalization standard deviation.
        crop_size (int): Side length of the final square crop in both pipelines.

    Returns:
        dict: {"train": augmenting Compose, "val": deterministic Compose}.
    """
    return {
        "train": T2.Compose([
            T2.RandomHorizontalFlip(),
            T2.RandomRotation(7),
            T2.RandomResizedCrop(
                size=(crop_size, crop_size),
                scale=(0.08, 1.0),
                ratio=(3/4, 4/3),
                antialias=True,
            ),
            T2.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
            T2.ToDtype(torch.float32, scale=True),
            T2.Normalize(mean, std),
        ]),
        "val": T2.Compose([
            # LANCZOS is only implemented for PIL images; tensor inputs support
            # NEAREST/NEAREST_EXACT/BILINEAR/BICUBIC, so use antialiased bicubic.
            T2.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC, antialias=True),
            T2.CenterCrop(crop_size),
            T2.ToDtype(torch.float32, scale=True),
            T2.Normalize(mean, std),
        ]),
    }


In [83]:
def prepare_model(model_name: str, num_classes: int, weights: str = "DEFAULT") -> nn.Module:
    """
    Prepare a classification model with custom output classes.

    Args:
        model_name (str): Name of the model (must be a valid torchvision model name).
        num_classes (int): Number of output classes.
        weights (str): Pretrained weights to use. Default is "DEFAULT".

    Returns:
        model (torch.nn.Module): The prepared model with the custom classification head.
        Its forward pass returns a plain logits tensor in both train and eval mode.

    Raises:
        ValueError: If no replaceable linear classification head is found.
    """
    model = get_model(model_name, weights=weights)

    # Replace the classification head based on the model architecture
    if hasattr(model, "fc"):  # e.g. ResNet, ResNeXt, RegNet, GoogLeNet, Inception v3
        model.fc = nn.Linear(model.fc.in_features, num_classes)
    elif hasattr(model, "classifier"):  # e.g. DenseNet, VGG, AlexNet, MobileNet, ConvNeXt
        classifier = model.classifier
        if isinstance(classifier, nn.Linear):
            model.classifier = nn.Linear(classifier.in_features, num_classes)
        elif isinstance(classifier, nn.Sequential):
            # Replace the last Linear layer, which is not always the last module.
            for idx in reversed(range(len(classifier))):
                if isinstance(classifier[idx], nn.Linear):
                    classifier[idx] = nn.Linear(classifier[idx].in_features, num_classes)
                    break
            else:
                raise ValueError(f"Model {model_name} has no Linear layer in its classifier.")
        else:
            raise ValueError(f"Model {model_name} has an unsupported classifier type: {type(classifier).__name__}.")
    elif hasattr(model, "heads"):  # Vision Transformers (ViT)
        model.heads.head = nn.Linear(model.heads.head.in_features, num_classes)
    elif isinstance(getattr(model, "head", None), nn.Linear):  # e.g. Swin Transformer
        model.head = nn.Linear(model.head.in_features, num_classes)
    else:
        raise ValueError(f"Model {model_name} does not have a recognized classification head.")

    # Inception v3 / GoogLeNet auxiliary classifiers make train-mode forward() return a
    # namedtuple (and keep a 1000-class aux head). Drop them so the loss always receives
    # a plain logits tensor.
    if getattr(model, "aux_logits", False):
        model.aux_logits = False
        for aux_name in ("AuxLogits", "aux1", "aux2"):
            if hasattr(model, aux_name):
                setattr(model, aux_name, None)

    return model


In [84]:
def train_model(model: nn.Module,
                train_loader: DataLoader,
                val_loader: DataLoader,
                num_epochs: int,
                lr: float,
                weight_decay: float,
                retrain: bool=True,
                grad_clip: float=None,
                models_dir: str = "models",
                checkpoint_interval: int = 5) -> nn.Module:


    criterion = nn.BCEWithLogitsLoss()
    if retrain:
        for param in model.parameters():
            param.requires_grad = True
        
    optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()),
                            lr=lr, weight_decay=weight_decay, betas=(0.9, 0.999), eps=1e-08, amsgrad=False)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10, eta_min=0)
    
    model.to(device)

    train_losses, val_losses = [], []
    epoch_times = [] 

    for epoch in range(num_epochs):
        start_time = time.time() 
        train_loss, val_loss = 0.0, 0.0
        os.makedirs(models_dir, exist_ok=True)  

        # Training Phase
        model.train()
        for batch in train_loader:
            images, labels = batch['img'].to(device, non_blocking=True), batch['lab'].to(device, non_blocking=True)

            optimizer.zero_grad()
            with autocast(device_type="cuda", dtype=torch.float16, enabled=use_amp): 
                outputs = model(images)
                loss = criterion(outputs, labels.float())

            scaler.scale(loss).backward()
            if grad_clip is not None:
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            scaler.step(optimizer)
            scaler.update()

            train_loss += loss.item()

        # Validation Phase
        model.eval()
        with torch.no_grad():
            for batch in val_loader:
                images, labels = batch['img'].to(device, non_blocking=True), batch['lab'].to(device, non_blocking=True)
                with autocast(device_type="cuda", dtype=torch.float16, enabled=use_amp):
                    outputs = model(images)
                    loss = criterion(outputs, labels.float())
                val_loss += loss.item()

        train_losses.append(train_loss / len(train_loader))
        val_losses.append(val_loss / len(val_loader))

        scheduler.step()  # Adjust learning rate

        # Calculate epoch duration and remaining time
        epoch_duration = time.time() - start_time
        epoch_times.append(epoch_duration)
        avg_epoch_time = sum(epoch_times) / len(epoch_times)
        remaining_time = avg_epoch_time * (num_epochs - (epoch + 1))

        # Format remaining time as HH:MM:SS
        remaining_time_str = time.strftime('%H:%M:%S', time.gmtime(remaining_time))

        # Print epoch summary with timing and remaining time
        print(f"    Epoch {epoch+1:03d}/{num_epochs:03d}, "
              f"Train Loss: {train_losses[-1]:.6f}, "
              f"Val Loss: {val_losses[-1]:.6f}, "
              f"Time: {epoch_duration:.2f} sec, "
              f"ETA: {remaining_time_str}", end=" ")
        
        if (epoch + 1) % checkpoint_interval == 0:
            checkpoint_path = f"{models_dir}/checkpoint_epoch_{epoch + 1}.pth"
            torch.save(model.state_dict(), checkpoint_path)
            print(", Checkpoint saved")
        else:
            print()

    return model


In [85]:
def evaluate_model(model: nn.Module, val_loader: DataLoader, target_names: list) -> dict:
    """
    Run ``model`` over ``val_loader`` and compute one ROC AUC per target label.

    Returns:
        dict with
          'predictions': array of sigmoid probabilities, one row per sample;
          'actuals': array of the matching labels;
          'auc_scores': one entry per name in ``target_names``, None when that label
                        has only one class in ``actuals``;
          'avg_auc': mean of the non-None AUCs, or None if there are none.
    """
    model.eval()

    probability_rows, label_rows = [], []
    with torch.no_grad():
        for batch in val_loader:
            inputs = batch['img'].to(device)
            targets = batch['lab'].to(device)
            probabilities = torch.sigmoid(model(inputs))
            probability_rows += list(probabilities.cpu().numpy())
            label_rows += list(targets.cpu().numpy())

    predictions = np.array(probability_rows)
    actuals = np.array(label_rows)

    auc_scores = []
    for index, name in enumerate(target_names):
        column = actuals[:, index]
        positives = np.sum(column)
        # ROC AUC is undefined when a label is all-negative or all-positive.
        if positives == 0 or positives == len(actuals):
            print(f"Skipping AUC calculation for {name} (only one class present in labels).")
            auc_scores.append(None)
            continue
        auc_scores.append(roc_auc_score(column, predictions[:, index]))

    defined_scores = [score for score in auc_scores if score is not None]
    avg_auc = np.mean(defined_scores) if defined_scores else None

    return {'predictions': predictions, 'actuals': actuals, 'auc_scores': auc_scores, 'avg_auc': avg_auc}

In [86]:
def plot_combined_radar_chart(results_df: pd.DataFrame) -> None:
    """
    Overlay every model's mean test AUC per pathology on a single radar chart.

    Args:
        results_df (pd.DataFrame): Per-run results with "Model", "Pathology" and
            "Test AUC" columns.
    """
    pathologies = results_df["Pathology"].unique()
    num_pathologies = len(pathologies)

    # Create angle for each pathology
    angles = np.linspace(0, 2 * np.pi, num_pathologies, endpoint=False).tolist()
    angles += angles[:1]  # Complete the loop

    # Prepare figure
    fig, ax = plt.subplots(figsize=(10, 10), subplot_kw=dict(polar=True))
    ax.set_theta_offset(pi / 2)
    ax.set_theta_direction(-1)

    # Draw one axis per pathology and add labels
    ax.set_xticks(angles[:-1])
    ax.set_xticklabels(pathologies, fontsize=10)

    # Draw y-labels
    ax.set_rscale("linear")
    ax.set_rlabel_position(0)
    ax.set_yticks([0.2, 0.4, 0.6, 0.8, 1.0])
    ax.set_yticklabels(["0.2", "0.4", "0.6", "0.8", "1.0"], color="grey", size=10)
    ax.set_ylim(0, 1)

    # Colors for each model
    colors = plt.cm.tab20.colors

    # Plot test AUCs for each model
    for i, model_name in enumerate(results_df["Model"].unique()):
        model_results = results_df[results_df["Model"] == model_name]
        # groupby sorts its keys, while the spokes follow first-appearance order; reindex
        # so each value lands on its own pathology's spoke (missing pathologies -> NaN).
        avg_auc_per_pathology = model_results.groupby("Pathology")["Test AUC"].mean().reindex(pathologies)

        test_aucs = avg_auc_per_pathology.tolist()
        test_aucs += test_aucs[:1]  # Complete the loop

        ax.plot(angles, test_aucs, label=model_name, linestyle='-', color=colors[i % len(colors)])
        ax.fill(angles, test_aucs, color=colors[i % len(colors)], alpha=0.1)

    # Add legend and title
    plt.legend(loc="upper right", bbox_to_anchor=(1.3, 1.1), fontsize=9)
    plt.title("Combined Radar Chart for Test AUCs of All Models", size=15, y=1.1)
    plt.show()

In [87]:
def plot_radar_chart(model_name: str, results_df: pd.DataFrame) -> None:
    """
    Plot one model's mean validation and test AUC per pathology on a radar chart.

    Args:
        model_name (str): Value of the "Model" column to plot.
        results_df (pd.DataFrame): Per-run results with "Model", "Pathology",
            "Validation AUC" and "Test AUC" columns.
    """
    pathologies = results_df["Pathology"].unique()
    num_pathologies = len(pathologies)

    # Prepare data for the specified model. groupby sorts its keys while the spokes use
    # first-appearance order, so reindex to keep values aligned with their labels.
    model_results = results_df[results_df["Model"] == model_name]
    avg_auc_per_pathology = (
        model_results.groupby("Pathology")[["Validation AUC", "Test AUC"]].mean().reindex(pathologies)
    )

    # Create angle for each pathology
    angles = np.linspace(0, 2 * np.pi, num_pathologies, endpoint=False).tolist()
    angles += angles[:1]  # Complete the loop

    # Prepare data for radar chart
    validation_aucs = avg_auc_per_pathology["Validation AUC"].tolist()
    test_aucs = avg_auc_per_pathology["Test AUC"].tolist()
    validation_aucs += validation_aucs[:1]  # Complete the loop
    test_aucs += test_aucs[:1]

    # Start the radar plot
    fig, ax = plt.subplots(figsize=(8, 8), subplot_kw=dict(polar=True))
    ax.set_theta_offset(pi / 2)
    ax.set_theta_direction(-1)

    # Draw one axis per pathology and add labels
    ax.set_xticks(angles[:-1])
    ax.set_xticklabels(pathologies, fontsize=10)

    # Draw y-labels
    ax.set_rscale("linear")
    ax.set_rlabel_position(0)
    ax.set_yticks([0.2, 0.4, 0.6, 0.8, 1.0])
    ax.set_yticklabels(["0.2", "0.4", "0.6", "0.8", "1.0"], color="grey", size=10)
    ax.set_ylim(0, 1)

    # Plot data
    ax.plot(angles, validation_aucs, label="Validation AUC", linestyle='--', color="blue")
    ax.fill(angles, validation_aucs, color="blue", alpha=0.1)

    ax.plot(angles, test_aucs, label="Test AUC", linestyle='-', color="orange")
    ax.fill(angles, test_aucs, color="orange", alpha=0.1)

    # Add legend and title
    plt.legend(loc="upper right", bbox_to_anchor=(0.1, 0.1))
    plt.title(f"Radar Chart for Model: {model_name}", size=15, y=1.1)
    plt.show()

## Data

In [88]:
# The 14 finding labels; this order defines the position of each entry in every label vector.
common_pathologies = (
    "Atelectasis Cardiomegaly Consolidation Edema Effusion "
    "Emphysema Fibrosis Hernia Infiltration Mass "
    "Nodule Pleural_Thickening Pneumonia Pneumothorax"
).split()

# Image file name -> 0/1 vector aligned with common_pathologies.
label_mapping = load_labels(labels_file, common_pathologies)

In [89]:
# Echo the image-name -> label-vector mapping as the cell output for a quick visual check.
label_mapping

{'00000001_000.png': [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 '00000001_001.png': [0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
 '00000001_002.png': [0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 '00000002_000.png': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 '00000003_001.png': [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
 '00000003_002.png': [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
 '00000003_003.png': [0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0],
 '00000003_004.png': [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
 '00000003_005.png': [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
 '00000003_006.png': [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
 '00000003_007.png': [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
 '00000003_000.png': [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
 '00000004_000.png': [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0],
 '00000005_000.png': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 '00000005_001.png': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 '00000005_002.png': [0, 

In [90]:
# Pass the configured image_size explicitly. The function's default (512) would make the
# "val" pipeline upscale each image to 512x512 before its 224-pixel center crop, which
# throws away the border of every validation/test image.
data_transforms = get_data_transforms(image_size=image_size)

In [91]:
# Patient-grouped 80/20 train/validation split of the training folder: GroupShuffleSplit
# keeps all images of a patient on the same side.
train_val_dataset = load_dataset(train_dir, random_selection=True, seed=42, max_total_images=num_train_images)

ids = [patient_id for patient_id, _, _ in train_val_dataset]
paths = [image_path for _, image_path, _ in train_val_dataset]
names = [image_name for _, _, image_name in train_val_dataset]

splitter = GroupShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
# n_splits=1, so the generator yields exactly one (train, validation) index pair.
train_idx, val_idx = next(splitter.split(paths, groups=ids))
train_data = [(ids[i], paths[i], names[i]) for i in train_idx]
val_data = [(ids[i], paths[i], names[i]) for i in val_idx]

test_data = load_dataset(test_dir, random_selection=False, max_total_images=num_test_images)

In [92]:
# Number of records in the train / validation / test splits.
len(train_data), len(val_data), len(test_data)

(69625, 16899, 25596)

In [93]:
# check for data leakage
ids1 = {item[0] for item in train_data}
ids2 = {item[0] for item in val_data}
common_ids = ids1.intersection(ids2)

if common_ids:
    print("Data leakage detected! Common IDs:", common_ids)
else:
    print("No data leakage detected.")

No data leakage detected.


In [94]:
# Peek at one training record: (patient_id, image_path, image_name).
train_data[0]

('00000001',
 '/home/jon/projects/Xrays/dataset/data/train_224/00000001_000.png',
 '00000001_000.png')

In [95]:
def _make_loader(records, transform, shuffle):
    """Wrap ``records`` in a ChestXray14Dataset and a DataLoader with the shared settings."""
    return DataLoader(
        ChestXray14Dataset(records, label_mapping, common_pathologies, transform=transform),
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=True,
        prefetch_factor=prefetch_factor,
        persistent_workers=True,
    )


# Only training data is shuffled and augmented; validation and test share the "val" pipeline.
train_loader = _make_loader(train_data, data_transforms["train"], shuffle=True)
val_loader = _make_loader(val_data, data_transforms["val"], shuffle=False)
test_loader = _make_loader(test_data, data_transforms["val"], shuffle=False)
del _make_loader

## Models

In [96]:
# Optimizer settings (presumably from an earlier tuning run — confirm their origin).
lr, weight_decay, grad_clip = 0.00013334120505282098, 1.8857522141696178e-06, 0.47836246526814713
# Overrides the num_epochs value set in the configuration cell.
num_epochs = 20

## Train and Evaluation

In [None]:

import traceback

# Resume support: (model, run) pairs already recorded in today's detailed results file
# are skipped below, so re-running this cell does not repeat finished runs.
if os.path.exists(detailed_results_path):
    results_df = pd.read_csv(detailed_results_path)
    completed_runs = set(zip(results_df["Model"], results_df["Run"]))
else:
    results_df = pd.DataFrame(columns=["Model", "Run", "Pathology", "Validation AUC", "Test AUC", "Validation Predictions", "Validation Actuals", "Test Predictions", "Test Actuals"])
    completed_runs = set()

# Serializes appends to / reads of the detailed results CSV across processes.
file_lock_path = detailed_results_path + ".lock"

# Iterate over models
for model_name in model_names:
    print(f"Running experiments for model: {model_name}")

    for run in range(1, runs_per_model + 1):
        run_identifier = (model_name, run)
        if run_identifier in completed_runs:
            print(f"  Skipping completed run {run} for model {model_name}")
            continue

        run_lock_path = os.path.join(locks_dir, f"{model_name}_run_{run}.lock")

        # A lock younger than lock_timeout means another process owns (or recently owned)
        # this run; an older one is considered stale and removed.
        if os.path.exists(run_lock_path):
            lock_age = time.time() - os.path.getmtime(run_lock_path)
            if lock_age < lock_timeout:
                print(f"  Skipping locked run {run} for model {model_name} (lock age: {lock_age:.1f} seconds)")
                continue
            print(f"  Found stale lock for run {run} of model {model_name}. Removing...")
            try:
                os.remove(run_lock_path)
            except FileNotFoundError:
                pass  # another process removed it first

        # Create the lock atomically: mode "x" fails if the file already exists.
        try:
            with open(run_lock_path, "x") as lock_file:
                lock_file.write(str(time.time()))  # Write timestamp to the lock file
        except FileExistsError:
            print(f"  Skipping locked run {run} for model {model_name} (race condition)")
            continue

        run_succeeded = False
        try:
            print(f"  Starting run {run} for model {model_name}")

            # Prepare and train the model
            print(f"    Preparing model...")
            model = prepare_model(model_name=model_name, num_classes=len(common_pathologies), weights="DEFAULT")
            print(f"    Training model...")
            model = train_model(model,
                                train_loader,
                                val_loader,
                                num_epochs,
                                lr,
                                weight_decay,
                                retrain=True,
                                grad_clip=grad_clip,
                                # One checkpoint folder per model/run so experiments do not
                                # overwrite each other's checkpoint_epoch_<n>.pth files.
                                models_dir=os.path.join(models_dir, model_name, f"run_{run}"),
                                checkpoint_interval=checkpoint_interval)

            # Evaluate and test the model
            print(f"    Evaluating model...")
            results_eval = evaluate_model(model, val_loader, target_names=common_pathologies)
            print(f"    Testing model...")
            results_test = evaluate_model(model, test_loader, target_names=common_pathologies)

            # Collect one row per pathology; undefined AUCs (None) are stored as NaN.
            print(f"    Collecting results...")
            new_results = []
            for i, pathology in enumerate(common_pathologies):
                val_auc = results_eval['auc_scores'][i]
                test_auc = results_test['auc_scores'][i]
                new_results.append({
                    "Model": model_name,
                    "Run": run,
                    "Pathology": pathology,
                    "Validation AUC": np.nan if val_auc is None else val_auc,
                    "Test AUC": np.nan if test_auc is None else test_auc,
                    # NOTE(review): pandas writes these ndarray cells via their str() form,
                    # which NumPy abbreviates with "..." beyond 1000 elements — verify the
                    # CSV really holds the full vectors before relying on these columns.
                    "Validation Predictions": results_eval['predictions'][:, i],
                    "Validation Actuals": results_eval['actuals'][:, i],
                    "Test Predictions": results_test['predictions'][:, i],
                    "Test Actuals": results_test['actuals'][:, i],
                })

            print(f"    Results collected for model {model_name}, run {run}")
            # Append results safely with a global file lock; the header is written only
            # when the file is created.
            with FileLock(file_lock_path):
                new_results_df = pd.DataFrame(new_results)
                write_header = not os.path.exists(detailed_results_path)
                new_results_df.to_csv(detailed_results_path, mode='a', header=write_header, index=False)

            completed_runs.add(run_identifier)
            run_succeeded = True
            print(f"  Results saved for model {model_name}, run {run}")

        except Exception as e:
            print(f"Error during run {run} for model {model_name}: {e}")
            traceback.print_exc()

        finally:
            # A failed or interrupted run releases its lock so it can be retried instead of
            # being skipped until the lock goes stale; a finished run keeps its lock as a
            # "done" marker for other processes sharing locks_dir.
            if not run_succeeded:
                try:
                    os.remove(run_lock_path)
                except FileNotFoundError:
                    pass
            torch.cuda.empty_cache()

# Generate summary statistics after all runs (only if any results have been recorded).
if os.path.exists(detailed_results_path):
    with FileLock(file_lock_path):
        results_df = pd.read_csv(detailed_results_path)

    results_df["Validation AUC"] = pd.to_numeric(results_df["Validation AUC"], errors="coerce")
    results_df["Test AUC"] = pd.to_numeric(results_df["Test AUC"], errors="coerce")

    summary = results_df.groupby(["Model", "Pathology"]).agg({
        "Validation AUC": ["mean", "std"],
        "Test AUC": ["mean", "std"]
    }).reset_index()

    summary.to_csv(summary_results_path, index=False)

    print("Summary of Results:")
    from IPython.display import display
    display(summary)
else:
    print(f"No results found at {detailed_results_path}; skipping summary.")

Running experiments for model: densenet121
  Skipping locked run 1 for model densenet121 (lock age: 35890.4 seconds)
  Skipping locked run 2 for model densenet121 (lock age: 35852.1 seconds)
  Skipping locked run 3 for model densenet121 (lock age: 16167.0 seconds)
  Starting run 4 for model densenet121
    Preparing model...


    Training model...


In [None]:
# Persist the weights currently held in `model` (the most recent model prepared by the
# experiment loop). NOTE(review): the file name carries only the date, so repeated saves
# on the same day overwrite each other and the file does not record the architecture/run.
model_filename = f"{models_dir}/model_{timestamp}.pth"
torch.save(obj=model.state_dict(), f=model_filename)
print("Model saved successfully.")

## Visualize

In [None]:
# Reload all recorded runs from disk and compare every model on a single radar chart.
results_df = pd.read_csv(filepath_or_buffer=detailed_results_path)
plot_combined_radar_chart(results_df=results_df)

In [None]:
# Generate radar chart for each model
for model_name in results_df["Model"].unique():
    plot_radar_chart(model_name, results_df)