In [None]:
# =============================================================================
# 1. IMPORT LIBRARIES
# =============================================================================
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision.models import resnet34, ResNet34_Weights
from datasets import load_dataset
from tqdm import tqdm
from sklearn.metrics import f1_score, precision_score, recall_score, accuracy_score

# Report the framework versions in use.
# `torchvision` has to be imported by name here: the earlier
# `import torchvision.transforms as transforms` and
# `from torchvision.models import ...` statements bind only `transforms` and the
# imported symbols, so referencing `torchvision.__version__` would otherwise
# raise NameError on a fresh kernel.
import torchvision

print(f"PyTorch Version: {torch.__version__}")
print(f"Torchvision Version: {torchvision.__version__}")

In [None]:
# Shell escape: run the NVIDIA `nvidia-smi` tool and show its report in the cell output
!nvidia-smi

In [None]:
import os

# Expose only physical GPU 1 to this process. The variable is consulted when
# CUDA is initialised, so it must be set before the first CUDA call is made.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

# Re-importing is harmless if torch was already imported by an earlier cell.
import torch

# With a single visible GPU, that GPU is addressed as index 0 inside this process.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Only query the device name when CUDA is actually usable; on the CPU fallback
# (e.g. a machine with no GPU at index 1) `get_device_name` would raise.
if device.type == "cuda":
    print(f"Device name: {torch.cuda.get_device_name(0)}")
else:
    print("Device name: CPU (no CUDA device visible with CUDA_VISIBLE_DEVICES=1)")

In [None]:
# =============================================================================
# 2. SETUP AND CONFIGURATION
# =============================================================================
# Identifier of the dataset on the Hugging Face Hub
HF_DATASET_NAME = "jonathan-roberts1/MLRSNet"

# Training hyperparameters
NUM_EPOCHS = 15        # number of full passes over the training data
BATCH_SIZE = 64        # tune to the available VRAM
LEARNING_RATE = 0.001  # optimizer step size

In [None]:
# =============================================================================
# 3. DATA LOADING AND PREPARATION FROM HUGGING FACE
# =============================================================================
# Fetch the dataset from the Hub. The download is cached locally
# (under `~/.cache/huggingface/datasets`), so later runs reuse it.
print("Loading dataset from Hugging Face Hub...")
dataset = load_dataset(HF_DATASET_NAME)
print("Dataset loaded successfully.")

# The class list is read from the feature behind the train split's 'labels' column.
# NOTE(review): assumes that column is a sequence of class labels — confirm
# against the dataset card.
class_names = dataset['train'].features['labels'].feature.names
NUM_CLASSES = len(class_names)
print(f"Number of classes: {NUM_CLASSES}")

# Per-image preprocessing, chained with Compose:
#   1. resize to 224x224,
#   2. convert to a tensor,
#   3. normalise each channel with the ImageNet mean / std.
image_transforms = transforms.Compose([
    transforms.Resize(size=(224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

# Create a transformation function that processes both images and labels
def apply_transforms(batch):
    """Preprocess one batch of examples in place and return it.

    Args:
        batch: dict with an 'image' list (each item is converted with
            ``.convert("RGB")`` — presumably PIL images, TODO confirm) and a
            'labels' list holding, per example, the indices of its active classes.

    Returns:
        The same dict, now with:
          * 'pixel_values': list of tensors produced by ``image_transforms``,
          * 'labels': float tensor of shape (batch, NUM_CLASSES) that is 1.0 at
            every active class index and 0.0 elsewhere,
          * the 'image' key removed.
    """
    def to_multi_hot(label_indices):
        # One row of the label matrix: zeros with ones at the listed indices
        encoded = torch.zeros(NUM_CLASSES)
        encoded[label_indices] = 1.0
        return encoded

    # Force three channels before running the image pipeline
    batch['pixel_values'] = [
        image_transforms(img.convert("RGB")) for img in batch['image']
    ]

    # Replace the index lists with a stacked multi-hot matrix
    batch['labels'] = torch.stack([to_multi_hot(indices) for indices in batch['labels']])

    # The raw images are no longer needed once pixel_values exists
    del batch['image']
    return batch

# Register the transform so it is applied on-the-fly whenever examples are read
dataset.set_transform(apply_transforms)

# Build the loaders: both share batch size and worker count, only training is shuffled.
# NOTE(review): assumes the Hub dataset provides a 'test' split — confirm;
# otherwise dataset['test'] raises KeyError.
loader_options = {'batch_size': BATCH_SIZE, 'num_workers': 4}
train_loader = DataLoader(dataset['train'], shuffle=True, **loader_options)
val_loader = DataLoader(dataset['test'], shuffle=False, **loader_options)
print("DataLoaders created.")

In [None]:
# =============================================================================
# 4. MODEL DEFINITION
# =============================================================================
# ResNet-34 initialised with the IMAGENET1K_V1 pretrained weights
model = resnet34(weights=ResNet34_Weights.IMAGENET1K_V1)

# Replace the final fully connected layer with one that outputs a logit per
# dataset class, keeping the original input width
num_ftrs = model.fc.in_features
model.fc = nn.Linear(in_features=num_ftrs, out_features=NUM_CLASSES)

# Move all parameters and buffers to the selected device
model = model.to(device)

In [None]:
# =============================================================================
# 5. LOSS, OPTIMIZER, AND METRICS
# =============================================================================
# Adam over every model parameter, using the configured learning rate
optimizer = optim.Adam(params=model.parameters(), lr=LEARNING_RATE)

# Multi-label objective: an independent sigmoid + binary cross-entropy per class.
# CrossEntropyLoss is not applicable because an image can carry several labels.
criterion = nn.BCEWithLogitsLoss()

def calculate_metrics(preds, targets, threshold=0.5):
    """Score multi-label predictions against ground-truth targets.

    Args:
        preds: tensor of raw logits; a sigmoid is applied here.
            Assumed shape (num_samples, num_classes) — TODO confirm with callers.
        targets: tensor of 0/1 indicators with the same shape as ``preds``.
        threshold: probability at or above which a class counts as predicted.

    Returns:
        dict with keys 'accuracy', 'f1', 'precision' and 'recall'. The last
        three are averaged per sample (``average='samples'``) and score 0 where
        the metric is undefined (``zero_division=0``).
    """
    # Logits -> probabilities -> boolean decisions, as a NumPy array for scikit-learn
    y_pred = (torch.sigmoid(preds) >= threshold).cpu().numpy()
    y_true = targets.cpu().numpy()

    # Keyword arguments shared by the three sample-averaged metrics
    per_sample = {'average': 'samples', 'zero_division': 0}

    return {
        'accuracy': accuracy_score(y_true, y_pred),
        'f1': f1_score(y_true, y_pred, **per_sample),
        'precision': precision_score(y_true, y_pred, **per_sample),
        'recall': recall_score(y_true, y_pred, **per_sample),
    }

In [None]:
# =============================================================================
# 6. TRAINING AND VALIDATION LOOP
# =============================================================================
# Each epoch runs one optimisation pass over the training loader, then one
# gradient-free pass over the validation loader, and prints losses and metrics.
for epoch in range(NUM_EPOCHS):
    # --- Training Phase ---
    model.train()
    running_loss = 0.0

    train_pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS} [Training]")
    for batch in train_pbar:
        # Batches are dictionaries; pull out the image tensor and the label tensor
        inputs, labels = batch['pixel_values'].to(device), batch['labels'].to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Read the scalar once: each .item() call copies the value to the host
        batch_loss = loss.item()
        # Weight by batch size so a smaller final batch does not skew the epoch
        # average (assumes the criterion returns a per-batch mean)
        running_loss += batch_loss * inputs.size(0)
        train_pbar.set_postfix({'loss': batch_loss})

    epoch_train_loss = running_loss / len(dataset['train'])

    # --- Validation Phase ---
    model.eval()
    val_loss = 0.0
    all_preds = []
    all_targets = []

    with torch.no_grad():
        val_pbar = tqdm(val_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS} [Validation]")
        for batch in val_pbar:
            # Batches are dictionaries; pull out the image tensor and the label tensor
            inputs, labels = batch['pixel_values'].to(device), batch['labels'].to(device)

            outputs = model(inputs)
            loss = criterion(outputs, labels)
            val_loss += loss.item() * inputs.size(0)

            # Store results on the CPU so the whole validation set is not held
            # in GPU memory; the metrics are computed on the CPU anyway
            all_preds.append(outputs.cpu())
            all_targets.append(labels.cpu())

    epoch_val_loss = val_loss / len(dataset['test'])

    # Score the full validation split at once rather than averaging per batch
    all_preds_tensor = torch.cat(all_preds, dim=0)
    all_targets_tensor = torch.cat(all_targets, dim=0)
    val_metrics = calculate_metrics(all_preds_tensor, all_targets_tensor)

    print(f"Epoch {epoch+1}/{NUM_EPOCHS} -> Train Loss: {epoch_train_loss:.4f} | Val Loss: {epoch_val_loss:.4f}")
    print(f"Validation Metrics -> F1: {val_metrics['f1']:.4f} | Precision: {val_metrics['precision']:.4f} | Recall: {val_metrics['recall']:.4f}")

# Persist only the learned parameters (the state_dict), not the whole model
# object. The relative path resolves against the current working directory.
print("Finished Training!")
torch.save(model.state_dict(), 'resnet34_mlrsnet_hf.pth')
print("Model saved to resnet34_mlrsnet_hf.pth")