# EuroSAT Training Pipeline (PyTorch)
This notebook trains a ResNet50 on EuroSAT data and generates a CSV for test predictions.

In [None]:
# Install rasterio
# rasterio is imported in the imports cell and is what TrainDataset uses to
# open the .tif training images. The leading "!" is notebook shell-escape
# syntax, so this line only works inside IPython/Jupyter/Colab.
!pip install rasterio

In [None]:
# Imports
import os
import torch
import torch.nn as nn
from torchvision import models, transforms
from torch.utils.data import DataLoader, random_split, Dataset
import numpy as np
import pandas as pd
from PIL import Image
import rasterio
from tqdm import tqdm

In [None]:
# Mount Google Drive (optional)
# DATA_DIR in the paths cell points under /content/drive/MyDrive, so this
# mount is required when the data is read from Drive. The google.colab module
# is presumably only importable inside a Colab runtime; skip this cell (and
# change DATA_DIR) when running elsewhere.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Set paths
# Root folder that holds the EuroSAT data; place your copy here.
DATA_DIR = '/content/drive/MyDrive/eurosat'
# The 'train' and 'test' sub-folders of DATA_DIR, in that order.
TRAIN_DIR, TEST_DIR = (
    os.path.join(DATA_DIR, split_name) for split_name in ('train', 'test')
)

In [None]:
# Dataset class for training (no cache, direct band slicing)
class TrainDataset(Dataset):
    """Labelled GeoTIFF dataset laid out as ``train_dir/<class_name>/*.tif``.

    Each sub-directory of ``train_dir`` is one class. Class indices are
    assigned by the alphabetical position of the sub-directory name, counting
    directories only.

    Args:
        train_dir: Folder containing one sub-directory per class.
        transform: Optional callable applied to the PIL image before it is
            returned.

    Attributes are plain lists kept index-aligned:
        samples: path of every ``.tif`` file found.
        labels: integer class index of the corresponding sample.
    """

    def __init__(self, train_dir, transform=None):
        self.transform = transform
        self.samples = []
        self.labels = []
        # Keep directories only *before* numbering them. Enumerating the raw
        # listing would let a stray file (e.g. ".DS_Store") consume an index
        # and shift every class label that sorts after it.
        class_dirs = sorted(
            entry for entry in os.listdir(train_dir)
            if os.path.isdir(os.path.join(train_dir, entry))
        )
        for class_idx, class_name in enumerate(class_dirs):
            class_path = os.path.join(train_dir, class_name)
            # os.listdir order is arbitrary; sort so the sample order (and
            # therefore any seeded split of it) is reproducible.
            for fname in sorted(os.listdir(class_path)):
                if fname.endswith('.tif'):
                    self.samples.append(os.path.join(class_path, fname))
                    self.labels.append(class_idx)

    def __getitem__(self, index):
        """Return ``(image, label)`` for the sample at ``index``.

        The image is a 3-channel PIL image, or whatever ``transform`` returns
        when a transform was given.
        """
        img_path = self.samples[index]
        with rasterio.open(img_path) as src:
            img = src.read()  # shape: (bands, H, W)
        img = np.transpose(img, (1, 2, 0))  # (H, W, bands)
        # Zero-based band positions 3, 2, 1 become the R, G, B channels
        # (presumably Sentinel-2 B04/B03/B02 -- confirm against the data).
        image_rgb = img[:, :, [3, 2, 1]]
        # NOTE(review): values above 255 saturate here; if the rasters hold
        # 16-bit reflectance a rescale may be intended instead -- confirm.
        image_rgb = np.clip(image_rgb, 0, 255).astype(np.uint8)
        pil_img = Image.fromarray(image_rgb)
        if self.transform:
            pil_img = self.transform(pil_img)
        label = self.labels[index]
        return pil_img, label

    def __len__(self):
        """Return the number of ``.tif`` samples found."""
        return len(self.samples)

In [None]:
# Dataset class for test .npy files (nice id extraction)
class TestNPYDataset(Dataset):
    """Unlabelled dataset of ``.npy`` arrays stored directly in ``test_dir``.

    Files named ``test_<number>.npy`` are ordered by that number and use the
    digit string as their id. Any other ``.npy`` file is placed after them
    (ordered by filename) and uses its full filename as the id.

    Args:
        test_dir: Folder containing the ``.npy`` files.
        transform: Optional callable applied to the PIL image before it is
            returned.

    Attributes are plain lists kept index-aligned:
        samples: path of every ``.npy`` file found.
        ids: id string of the corresponding sample.
    """

    def __init__(self, test_dir, transform=None):
        import re  # local import: the notebook's import cell does not import re

        self.transform = transform
        self.samples = []
        self.ids = []
        id_pattern = re.compile(r'test_(\d+)\.npy')
        entries = []
        for fname in os.listdir(test_dir):
            if not fname.endswith(".npy"):
                continue
            match = id_pattern.search(fname)
            if match:
                order, img_id = int(match.group(1)), match.group(1)
            else:
                # Unnumbered files go last and keep their filename as id.
                order, img_id = float('inf'), fname
            entries.append((order, fname, img_id))
        # Numeric order first; the filename breaks ties so the result does
        # not depend on the arbitrary order returned by os.listdir.
        entries.sort()
        for _, fname, img_id in entries:
            self.samples.append(os.path.join(test_dir, fname))
            self.ids.append(img_id)

    @staticmethod
    def _to_rgb_uint8(img):
        """Pick channels 3, 2, 1 of an (H, W, C) array and return them as uint8.

        Values are clipped to [0, 255] before the cast, matching what
        TrainDataset does. A bare ``astype(np.uint8)`` would wrap values
        above 255 modulo 256 and feed the model differently-preprocessed
        pixels than it was trained on.
        """
        image_rgb = img[:, :, [3, 2, 1]]
        return np.clip(image_rgb, 0, 255).astype(np.uint8)

    def __getitem__(self, index):
        """Return ``(image, img_id)`` for the sample at ``index``."""
        img_path = self.samples[index]
        img_id = self.ids[index]
        img = np.load(img_path)
        pil_img = Image.fromarray(self._to_rgb_uint8(img))
        if self.transform:
            pil_img = self.transform(pil_img)
        return pil_img, img_id

    def __len__(self):
        """Return the number of ``.npy`` samples found."""
        return len(self.samples)

In [None]:
# Hyperparameters and transforms
# Label names; position i is the name reported for predicted class index i.
class_names = [
    'AnnualCrop',
    'Forest',
    'HerbaceousVegetation',
    'Highway',
    'Industrial',
    'Pasture',
    'PermanentCrop',
    'Residential',
    'River',
    'SeaLake',
]
input_size = 224  # side length of the square image fed to the network
batch_size = 32
epochs = 10
lr = 1e-3


def _resize_and_normalize():
    """Return fresh copies of the steps shared by both pipelines."""
    return [
        transforms.Resize(size=input_size),
        transforms.CenterCrop(size=input_size),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]


# Training: random augmentation first, then the shared resize/normalize tail.
train_transform = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(degrees=20),
        transforms.ColorJitter(
            brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1
        ),
    ]
    + _resize_and_normalize()
)
# Evaluation: the deterministic tail only.
test_transform = transforms.Compose(_resize_and_normalize())

In [None]:
# Prepare dataset and loaders
import copy
from torch.utils.data import Subset

dataset = TrainDataset(TRAIN_DIR, transform=train_transform)
# Hold out 20% of the samples for validation.
val_size = int(0.2 * len(dataset))
train_size = len(dataset) - val_size
train_ds, _val_split = random_split(dataset, [train_size, val_size])

# random_split only yields index views of `dataset`, so the held-out part
# would otherwise be read through train_transform and be randomly flipped,
# rotated and colour-jittered on every pass. A shallow copy shares the same
# sample/label lists (so the indices still line up) but lets validation use
# the deterministic test_transform, matching how test images are processed.
_eval_view = copy.copy(dataset)
_eval_view.transform = test_transform
val_ds = Subset(_eval_view, _val_split.indices)

train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=8)
val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=8)

In [None]:
# Model setup
from collections import Counter

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Build the model first: the optimizer below needs model.parameters(), so
# creating it before the model exists raises NameError on a fresh kernel.
# NOTE(review): newer torchvision releases deprecate `pretrained=True` in
# favour of the `weights=` argument; kept as-is for compatibility with the
# installed version -- confirm and migrate if a warning is shown.
model = models.resnet50(pretrained=True)
# Replace the classification head with one output per EuroSAT class.
model.fc = nn.Linear(model.fc.in_features, len(class_names))
model = model.to(device)

# Inverse-frequency class weights; a class with no samples gets weight 0.
class_counts = Counter(dataset.labels)
weights = torch.tensor(
    [1.0 / class_counts[i] if class_counts[i] > 0 else 0.0 for i in range(len(class_names))],
    dtype=torch.float,
)
criterion = nn.CrossEntropyLoss(weight=weights.to(device))
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

In [None]:
# Training loop with tqdm progress bar and checkpoints
# Each epoch: one optimisation pass over train_loader, one evaluation pass
# over val_loader (overall and per-class accuracy), then a checkpoint.
for epoch in range(epochs):
    model.train()
    running_loss = 0.0
    for imgs, labels in tqdm(train_loader, desc=f'Epoch {epoch+1}/{epochs}'):
        imgs, labels = imgs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(imgs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # loss is a batch mean; weight by batch size so the epoch figure is
        # a per-sample average even when the last batch is smaller.
        running_loss += loss.item() * imgs.size(0)
    avg_loss = running_loss / len(train_loader.dataset)
    # Validation
    # A single pass collects the per-class counts; the overall accuracy is
    # derived from them instead of running the model over val_loader twice.
    model.eval()
    class_correct = [0] * len(class_names)
    class_total = [0] * len(class_names)
    with torch.no_grad():
        for imgs, labels in val_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            outputs = model(imgs)
            _, preds = torch.max(outputs, 1)
            # tolist() moves the whole batch to Python at once rather than
            # calling .item() per element.
            for label, pred in zip(labels.tolist(), preds.tolist()):
                if label == pred:
                    class_correct[label] += 1
                class_total[label] += 1
    correct = sum(class_correct)
    total = sum(class_total)
    # Guard against an empty validation split (e.g. a very small dataset).
    val_acc = correct / total if total > 0 else 0.0
    print(f'Epoch {epoch+1}/{epochs} - Loss: {avg_loss:.4f} - Val Acc: {val_acc:.4f}')
    # Per-class accuracy
    for i, cname in enumerate(class_names):
        acc = class_correct[i] / class_total[i] if class_total[i] > 0 else 0
        print(f"Val Acc {cname}: {acc:.4f}")
    # Save checkpoint
    torch.save(model.state_dict(), f'eurosat_resnet50_epoch_{epoch+1}.pth')
    print(f'Checkpoint saved: eurosat_resnet50_epoch_{epoch+1}.pth')

In [None]:
# Save final model weights
# Only the state dict (parameters and buffers) is written, not the model
# object, so loading requires rebuilding the same architecture first.
final_weights_path = 'eurosat_resnet50.pth'
final_state = model.state_dict()
torch.save(final_state, final_weights_path)
print(f'Model trained and saved as {final_weights_path}')

In [None]:
# Test inference and CSV export
# Reloads the final weights, predicts a class name for every test sample and
# writes one (test_id, label) row per sample to test_predictions.csv.
test_dataset = TestNPYDataset(TEST_DIR, transform=test_transform)
# Run inference in batches instead of one image per forward pass; order is
# preserved because shuffle is off.
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
model.load_state_dict(torch.load('eurosat_resnet50.pth', map_location=device))
model.eval()
results = []
with torch.no_grad():
    for imgs, img_ids in tqdm(test_loader, desc='Testing'):
        outputs = model(imgs.to(device))
        _, preds = torch.max(outputs, 1)
        # img_ids holds the batch's id strings in the same order as preds.
        results.extend(
            {'test_id': img_id, 'label': class_names[pred]}
            for img_id, pred in zip(img_ids, preds.tolist())
        )
# Explicit columns keep the CSV header present even if no test file was found.
df = pd.DataFrame(results, columns=['test_id', 'label'])
df.to_csv('test_predictions.csv', index=False)
print('test_predictions.csv saved.')