In [None]:
# =============================================================================
# 1. IMPORTS
# =============================================================================
import os
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from datasets import load_dataset
from sklearn.metrics import f1_score, precision_score, recall_score, accuracy_score
from torch.utils.data import DataLoader
from torchvision import models
from tqdm import tqdm

# NOTE(review): this silences *every* warning for the rest of the session,
# including deprecation notices from torch/torchvision/datasets.
warnings.filterwarnings(action="ignore")

# Report library versions so results can be tied to a specific environment.
for lib_label, lib_module in (("PyTorch", torch), ("Torchvision", torchvision)):
    print(f"{lib_label} Version: {lib_module.__version__}")

In [None]:
# Shell out to nvidia-smi to list GPUs and their current memory/utilisation,
# so a free one can be chosen for CUDA_VISIBLE_DEVICES in the next cell.
!nvidia-smi

In [None]:
# Restrict this process to a single physical GPU.
# CUDA_VISIBLE_DEVICES is only honoured if it is set before CUDA is first
# initialised in this kernel. Importing torch does not initialise CUDA (it is
# lazily initialised), but any earlier CUDA call in this kernel would.
GPU_ID = "1"  # set according to whichever GPU is free (see nvidia-smi output above)
os.environ["CUDA_VISIBLE_DEVICES"] = GPU_ID

# Only the selected GPU is visible now, so it is enumerated as device index 0.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# get_device_name() raises when no CUDA device is available, so only query it on GPU.
if device.type == "cuda":
    print(f"Device name: {torch.cuda.get_device_name(0)}")

In [None]:
# =============================================================================
# 2. CONFIGURATION
# =============================================================================
# --- Experiment Settings ---
MODELS_TO_TEST = ['resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152']
# Note: Deeper models like resnet101/152 will require significant VRAM and time.

# --- Hyperparameters ---
NUM_EPOCHS = 10  # compromise between total training time and convergence
BATCH_SIZE = 256  # large batch; reduce if the deeper models run out of GPU memory
LEARNING_RATE = 1e-3
HF_DATASET_NAME = "jonathan-roberts1/MLRSNet"

# --- Reproducibility ---
# Seeds weight init of the new classifier heads and DataLoader shuffling.
# (Does not force deterministic cuDNN kernels, so GPU runs may still differ slightly.)
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the RNG on all devices, CUDA included

# --- Setup ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

In [None]:
# =============================================================================
# 3. DATA LOADING & ANALYSIS
# =============================================================================
# --- Load Dataset ---
print("Loading dataset from Hugging Face Hub...")
full_dataset = load_dataset(HF_DATASET_NAME)
# The 'label' feature is a sequence (multi-label); its element feature holds the class names.
class_names = full_dataset['train'].features['label'].feature.names
NUM_CLASSES = len(class_names)

# --- 1. CLASS FREQUENCY ANALYSIS ---
# Read only the 'label' column: iterating full rows would decode every image,
# which is very slow and not needed just to count labels.
print("\nAnalyzing class distribution...")
class_counts = np.zeros(NUM_CLASSES, dtype=np.int64)
for label_indices in tqdm(full_dataset['train']['label'], desc="Counting classes"):
    for label_index in label_indices:
        class_counts[label_index] += 1

# Create a pandas DataFrame for nice printing
freq_df = pd.DataFrame({
    'Class Name': class_names,
    'Count': class_counts
}).sort_values(by='Count', ascending=False).reset_index(drop=True)

print("Class Distribution in MLRSNet:")
print(freq_df)

# --- Split and Transform Data ---
# 85% train / 15% held-out test; the fixed seed makes the split reproducible.
split_dataset = full_dataset['train'].train_test_split(test_size=0.15, seed=42)
dataset = split_dataset
assert len(dataset['train']) + len(dataset['test']) == len(full_dataset['train'])

# Resize to 224x224 and normalise with the standard ImageNet channel mean/std,
# matching the statistics the pre-trained ResNet weights expect.
image_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

def apply_transforms(batch):
    """Convert a raw dataset batch into model-ready tensors.

    Registered via ``dataset.set_transform`` below, so it receives a dict of
    column lists. Adds ``'pixel_values'`` (a list of normalised image tensors
    produced by ``image_transforms``) and replaces ``'label'`` with a stacked
    multi-hot float tensor of width ``NUM_CLASSES``. The ``'image'`` column is
    removed. The input dict is modified in place and also returned.
    """
    pixel_tensors = []
    for pil_image in batch['image']:
        # Force 3 channels; some images may be greyscale/RGBA (TODO confirm for MLRSNet).
        pixel_tensors.append(image_transforms(pil_image.convert("RGB")))
    batch['pixel_values'] = pixel_tensors

    def _to_multi_hot(class_indices):
        # One slot per class; 1.0 for every class index present in this example.
        encoded = torch.zeros(NUM_CLASSES)
        encoded[class_indices] = 1.0
        return encoded

    batch['label'] = torch.stack([_to_multi_hot(indices) for indices in batch['label']])
    del batch['image']
    return batch

# Apply apply_transforms on the fly when examples are accessed, instead of
# preprocessing the whole dataset up front.
dataset.set_transform(apply_transforms)

# --- Create DataLoaders ---
# Settings shared by both loaders; only shuffling differs.
shared_loader_args = {'batch_size': BATCH_SIZE, 'num_workers': 4}
train_loader = DataLoader(dataset['train'], shuffle=True, **shared_loader_args)
val_loader = DataLoader(dataset['test'], shuffle=False, **shared_loader_args)
print("\nDataLoaders created.")

In [None]:
# =============================================================================
# 4. MODEL TRAINING & EVALUATION
# =============================================================================
# --- Helper Functions ---
def get_model(model_name, num_classes):
    """Build an ImageNet-pretrained ResNet with a fresh classification head.

    Args:
        model_name: One of 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152'.
        num_classes: Number of output logits for the new final linear layer.

    Returns:
        The model, moved to the global ``device``.

    Raises:
        ValueError: If ``model_name`` is not one of the supported names.
    """
    builders = {
        'resnet18': models.resnet18,
        'resnet34': models.resnet34,
        'resnet50': models.resnet50,
        'resnet101': models.resnet101,
        'resnet152': models.resnet152,
    }
    builder = builders.get(model_name)
    if builder is None:
        raise ValueError(f"Model {model_name} not supported.")

    backbone = builder(weights='IMAGENET1K_V1')
    # Swap the 1000-way ImageNet classifier for one sized to this task.
    backbone.fc = nn.Linear(backbone.fc.in_features, num_classes)
    return backbone.to(device)

def calculate_metrics(preds, targets, threshold=0.5):
    """Compute multi-label classification metrics from raw logits.

    Args:
        preds: Tensor of raw model outputs (logits); the sigmoid is applied here.
        targets: Tensor of multi-hot ground-truth labels, same shape as ``preds``.
        threshold: Probability at or above which a class counts as predicted.

    Returns:
        dict with 'accuracy' (sklearn accuracy_score: exact match of the full
        label set per sample) and sample-averaged 'f1', 'precision', 'recall'.
    """
    y_pred = (torch.sigmoid(preds) >= threshold).cpu().numpy()
    y_true = targets.cpu().numpy()
    # zero_division=0: samples with no predicted/true labels score 0 instead of warning.
    sample_avg = {'average': 'samples', 'zero_division': 0}
    return {
        'accuracy': accuracy_score(y_true, y_pred),
        'f1': f1_score(y_true, y_pred, **sample_avg),
        'precision': precision_score(y_true, y_pred, **sample_avg),
        'recall': recall_score(y_true, y_pred, **sample_avg),
    }

# --- Main Experiment Loop ---
def _train_one_epoch(model, loader, criterion, optimizer, desc):
    """Run one optimisation pass over ``loader``; return the mean per-sample training loss."""
    model.train()
    running_loss = 0.0
    n_samples = 0
    for batch in tqdm(loader, desc=desc):
        inputs, labels = batch['pixel_values'].to(device), batch['label'].to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # criterion returns a batch mean; re-weight by batch size for an exact epoch mean.
        running_loss += loss.item() * inputs.size(0)
        n_samples += inputs.size(0)
    return running_loss / max(n_samples, 1)


def _evaluate(model, loader, criterion, desc):
    """Evaluate ``model`` on ``loader`` without gradients.

    Returns:
        (mean per-sample loss, concatenated logits, concatenated targets); the
        tensors are moved to CPU so they do not accumulate in GPU memory.
    """
    model.eval()
    total_loss = 0.0
    n_samples = 0
    all_preds, all_targets = [], []
    with torch.no_grad():
        for batch in tqdm(loader, desc=desc):
            # apply_transforms stores the multi-hot targets under 'label' (not 'labels').
            inputs, labels = batch['pixel_values'].to(device), batch['label'].to(device)
            outputs = model(inputs)
            total_loss += criterion(outputs, labels).item() * inputs.size(0)
            n_samples += inputs.size(0)
            all_preds.append(outputs.cpu())
            all_targets.append(labels.cpu())
    return total_loss / max(n_samples, 1), torch.cat(all_preds, dim=0), torch.cat(all_targets, dim=0)


results = {}

for model_name in MODELS_TO_TEST:
    print(f"\n{'='*20} Training {model_name.upper()} {'='*20}")

    # Initialize model, loss, and optimizer
    model = get_model(model_name, NUM_CLASSES)
    criterion = nn.BCEWithLogitsLoss()  # multi-label: independent sigmoid per class, not softmax/CrossEntropy
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    history = {'train_loss': [], 'val_loss': [], 'val_f1': [], 'val_precision': [], 'val_recall': [], 'val_accuracy': []}

    for epoch in range(NUM_EPOCHS):
        epoch_tag = f"Epoch {epoch+1}/{NUM_EPOCHS}"
        epoch_train_loss = _train_one_epoch(model, train_loader, criterion, optimizer, desc=f"{epoch_tag} [Train]")
        epoch_val_loss, val_logits, val_targets = _evaluate(model, val_loader, criterion, desc=f"{epoch_tag} [Val]")
        val_metrics = calculate_metrics(val_logits, val_targets)

        history['train_loss'].append(epoch_train_loss)
        history['val_loss'].append(epoch_val_loss)
        # Record every metric, including accuracy (read by the bar chart below).
        for metric_key in ('accuracy', 'f1', 'precision', 'recall'):
            history[f'val_{metric_key}'].append(val_metrics[metric_key])

        print(f"{epoch_tag} -> Train Loss: {epoch_train_loss:.4f} | Val Loss: {epoch_val_loss:.4f} | Val F1: {val_metrics['f1']:.4f}")

    results[model_name] = history
    print(f"Finished training for {model_name}.")

    # Release this model's weights and optimizer state before the next (deeper)
    # network is built; otherwise both are alive on the GPU during get_model().
    del model, optimizer
    if device.type == "cuda":
        torch.cuda.empty_cache()

In [None]:
# =============================================================================
# 5. RESULTS VISUALIZATION
# =============================================================================
print("\nVisualizing results...")

# --- 2. PLOT LOSS CURVES ---
plt.style.use('seaborn-v0_8-whitegrid')
fig, ax1 = plt.subplots(1, 1, figsize=(12, 7))

for model_name, history in results.items():
    # Derive the x-axis from the recorded history, not NUM_EPOCHS, so the plot
    # still works if NUM_EPOCHS was edited after training finished.
    epochs_run = range(1, len(history['val_loss']) + 1)
    ax1.plot(epochs_run, history['val_loss'], 'o-', label=f'{model_name} Val Loss')

ax1.set_title('Validation Loss Comparison Across ResNet Architectures', fontsize=16)
ax1.set_xlabel('Epochs', fontsize=12)
ax1.set_ylabel('BCEWithLogitsLoss', fontsize=12)
ax1.legend(fontsize=12)
ax1.grid(True)
plt.tight_layout()
plt.show()

# --- 3. PLOT PERFORMANCE BAR CHART ---
# Display column -> history key; only the final epoch's value is reported.
# 'Accuracy' is exact-match (subset) accuracy over the full label set.
metric_columns = {
    'Accuracy': 'val_accuracy',
    'F1-Score': 'val_f1',
    'Precision': 'val_precision',
    'Recall': 'val_recall',
}
final_metrics = {'Model': list(results.keys())}
for column_name, history_key in metric_columns.items():
    final_metrics[column_name] = [run[history_key][-1] for run in results.values()]

metrics_df = pd.DataFrame(final_metrics)

fig, ax2 = plt.subplots(1, 1, figsize=(12, 7))
metrics_df.plot(x='Model', y=list(metric_columns), kind='bar', ax=ax2, zorder=3)

ax2.set_title('Final Performance Metrics Comparison', fontsize=16)
ax2.set_ylabel('Score', fontsize=12)
ax2.set_xlabel('ResNet Architecture', fontsize=12)
ax2.tick_params(axis='x', rotation=0)
ax2.grid(axis='y', zorder=0)  # zorder below the bars (3) so gridlines sit behind them

# Write each bar's value just above its top edge.
for bar in ax2.patches:
    bar_height = bar.get_height()
    bar_center_x = bar.get_x() + bar.get_width() / 2.
    ax2.annotate(f"{bar_height:.3f}", (bar_center_x, bar_height),
                 ha='center', va='center', xytext=(0, 9), textcoords='offset points')

plt.tight_layout()
plt.show()