## Train Linear Head

### 1. Imports

In [None]:
import os
import glob
import time
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from torchvision import transforms, models
from torch.utils.data import DataLoader, Dataset
from torch.utils.tensorboard import SummaryWriter
from PIL import Image
import matplotlib.pyplot as plt
from tqdm import tqdm
from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix
from sklearn.manifold import TSNE
import seaborn as sns
import numpy as np
import random
import matplotlib.colors as mcolors
import torch.nn.functional as F

### 2. Configuration

In [None]:
# --- Configurations ---
# Compute device: use CUDA when it is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Input locations.
DATA_ROOT = "data/eurosat/eurosat/2750"  # root directory of the EuroSAT dataset
BACKBONE_CKPT = "models/res18/backbone/backbone_res18.pt"  # pretrained backbone checkpoint

# Output locations.
CKPT_DIR = "runs/linear_probe_eurosat_res18/checkpoints_linear_probe_res18"  # saved heads
LOG_DIR = "runs/linear_probe_eurosat_res18"  # TensorBoard logs

# Training schedule.
NUM_EPOCHS = 100
BATCH_SIZE = 256  # shared by the training and validation loaders

# SGD hyper-parameters.
LR = 0.01
MOMENTUM = 0.9
WEIGHT_DECAY = 1e-4

# Number of classes in the EuroSAT dataset (matches the length of class_names below).
NUM_CLASSES = 10

# Make sure both output directories exist before anything is written to them.
for _output_dir in (CKPT_DIR, LOG_DIR):
    os.makedirs(_output_dir, exist_ok=True)

# Class names for the EuroSAT dataset; a class's label is its position in this list.
class_names = [
    "AnnualCrop",
    "Forest",
    "HerbaceousVegetation",
    "Highway",
    "Industrial",
    "Pasture",
    "PermanentCrop",
    "Residential",
    "River",
    "SeaLake",
]

### 3. Data Preparation

In [None]:
# Per-channel mean and standard deviation used by the Normalize step
mean = [0.48241806, 0.48080587, 0.47794071]
std = [0.19021621, 0.16879530, 0.14623168]

# Deterministic preprocessing pipeline, applied in this order:
#   1. bicubic resize to size 256
#   2. central 224x224 crop
#   3. PIL image -> float tensor
#   4. per-channel normalization with the statistics above
_preproc_steps = [
    transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
]
preproc = transforms.Compose(_preproc_steps)

# Custom dataset class for EuroSAT
class EuroSATJPG(Dataset):
    """JPEG image dataset laid out as ``<root_dir>/<folder>/<file>.jpg``.

    Each item is ``(image, label)`` where ``label`` is the position of the
    image's parent folder name in the module-level ``class_names`` list.
    """

    def __init__(self, root_dir, transform=None):
        """Collect every ``*.jpg`` one directory level below ``root_dir``.

        Args:
            root_dir: Directory whose immediate sub-folders hold the images.
            transform: Optional callable applied to each loaded RGB image.
        """
        super().__init__()
        self.transform = transform
        pattern = os.path.join(root_dir, "*", "*.jpg")
        # Sorted so the sample order does not depend on the order glob returns.
        self.paths = sorted(glob.glob(pattern))

    def __len__(self):
        """Return the number of image files found."""
        return len(self.paths)

    def __getitem__(self, idx):
        """Load the ``idx``-th image and derive its label from the folder name.

        Raises:
            ValueError: If the parent folder name is not in ``class_names``.
        """
        img_path = self.paths[idx]
        folder_name = os.path.basename(os.path.dirname(img_path))
        target = class_names.index(folder_name)
        image = Image.open(img_path).convert("RGB")
        if not self.transform:
            return image, target
        return self.transform(image), target

# Create the dataset and split into training and validation sets
ds = EuroSATJPG(DATA_ROOT, transform=preproc)
n = len(ds)
if n == 0:
    # A wrong DATA_ROOT yields an empty dataset; fail here with a clear message
    # instead of a later, less obvious error from the data loader.
    raise FileNotFoundError(f"No .jpg images found under {DATA_ROOT!r}")
n_train = int(0.8 * n)  # 80% train, remainder validation
# Seed the split so the train/validation partition is identical on every run;
# without a generator the partition changes each time the cell is executed.
split_generator = torch.Generator().manual_seed(42)
train_ds, val_ds = torch.utils.data.random_split(
    ds, [n_train, n - n_train], generator=split_generator
)
# Create data loaders for training and validation.
# Pinned host memory only helps host-to-GPU copies, so enable it only with CUDA.
pin_memory = torch.cuda.is_available()
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, pin_memory=pin_memory)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=4, pin_memory=pin_memory)

# Function to save the validation dataset
def save_dataset(dataset, filename):
    """Materialize ``dataset`` and write it to ``filename`` with ``torch.save``.

    Args:
        dataset: Iterable of ``(image_tensor, label)`` pairs; all image
            tensors must share one shape so they can be stacked.
        filename: Destination passed to ``torch.save``.

    The saved object is the tuple ``(stacked_images, labels_tensor)``.
    """
    # Iterate exactly once so each sample is produced a single time.
    samples = list(dataset)
    stacked_images = torch.stack([image for image, _ in samples])
    label_tensor = torch.tensor([target for _, target in samples])
    torch.save((stacked_images, label_tensor), filename)

# Save the validation dataset as a (images_tensor, labels_tensor) tuple.
# The path is relative, so the file lands in the current working directory.
# NOTE(review): the evaluation cell in this file loads "~/val_dataset.pth";
# presumably the file is copied there by hand - confirm, or align the two paths.
save_dataset(val_ds, "val_dataset.pth")

### 4. Model Definition

In [None]:
# Feature dimension for ResNet-18 (input size of the linear head)
feat_dim = 512
# Build the ResNet-18 architecture without pretrained weights: every parameter is
# replaced by the strict checkpoint load below, so fetching ImageNet weights first
# would be a wasted download.
encoder = models.resnet18(pretrained=False)
# Remove the final classification layer so the encoder outputs features directly
encoder.fc = nn.Identity()
# Load the pretrained backbone checkpoint
ckpt = torch.load(BACKBONE_CKPT, map_location=DEVICE)
# Keep only the "backbone." entries and strip exactly that leading prefix
# (slicing, unlike str.replace, cannot touch a later occurrence inside a key).
backbone_prefix = "backbone."
state = {
    k[len(backbone_prefix):]: v
    for k, v in ckpt["model_state_dict"].items()
    if k.startswith(backbone_prefix)
}
# strict=True: fail loudly if the checkpoint and the architecture disagree
encoder.load_state_dict(state, strict=True)
# Move the encoder to the device and set to evaluation mode
encoder.to(DEVICE).eval()
# Freeze the encoder weights
for p in encoder.parameters():
    p.requires_grad = False

# Define the linear head for classification
head = nn.Linear(feat_dim, NUM_CLASSES).to(DEVICE)

### 5. Training

In [None]:
# Loss function, optimizer, and learning rate scheduler.
# Only the head's parameters are given to the optimizer; the encoder is not updated.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(head.parameters(), lr=LR, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)
# Multiply the learning rate by 0.1 every 30 epochs
scheduler = StepLR(optimizer, step_size=30, gamma=0.1)
# TensorBoard writer
writer = SummaryWriter(log_dir=LOG_DIR)

# Projector config for embeddings
metadata_path = os.path.join(LOG_DIR, "metadata.tsv")
with open(metadata_path, "w") as f:
    f.write("Label\n")
projector_config = f"""
embeddings {{
  tensor_name: "LinearProbe/Embeddings"
  metadata_path: "{os.path.basename(metadata_path)}"
}}
"""
with open(os.path.join(LOG_DIR, "projector_config.pbtxt"), "w") as f:
    f.write(projector_config)

# Upper bound on the number of validation points logged as embeddings / t-SNE per epoch
max_pts = 1000

# Training and validation loop
best_val_acc = 0.0
for epoch in range(1, NUM_EPOCHS + 1):
    # Training
    head.train()
    running_loss, correct, total = 0.0, 0, 0
    all_preds, all_labels = [], []
    for x, y in tqdm(train_loader, desc=f"[Train] Epoch {epoch}/{NUM_EPOCHS}"):
        x, y = x.to(DEVICE), y.to(DEVICE)
        # The encoder is frozen, so its forward pass needs no autograd graph
        with torch.no_grad():
            feats = encoder(x)
        logits = head(feats)
        loss = criterion(logits, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Weight the batch-mean loss by batch size so the epoch average is per-sample
        running_loss += loss.item() * y.size(0)
        pred = logits.argmax(dim=1)
        correct += (pred == y).sum().item()
        total += y.size(0)
        all_preds.append(pred.cpu())
        all_labels.append(y.cpu())

    # Log training metrics
    train_loss = running_loss / total
    train_acc = correct / total
    scheduler.step()
    writer.add_scalar("LinearProbe/Train_Loss", train_loss, epoch)
    writer.add_scalar("LinearProbe/Train_Acc", train_acc, epoch)
    all_preds = torch.cat(all_preds)
    all_labels = torch.cat(all_labels)
    writer.add_scalar("LinearProbe/Train_Precision", precision_score(all_labels, all_preds, average='weighted', zero_division=1), epoch)
    writer.add_scalar("LinearProbe/Train_Recall", recall_score(all_labels, all_preds, average='weighted', zero_division=1), epoch)
    writer.add_scalar("LinearProbe/Train_F1", f1_score(all_labels, all_preds, average='weighted', zero_division=1), epoch)

    # Validation
    head.eval()
    correct, total = 0, 0
    # Only features, predictions and labels are kept. The input images are not
    # collected: nothing reads them, and holding the whole validation set in host
    # memory every epoch is pure overhead.
    all_feats, all_preds_list, all_labels_list = [], [], []
    with torch.no_grad():
        for x, y in val_loader:
            x, y = x.to(DEVICE), y.to(DEVICE)
            feats = encoder(x)
            logits = head(feats)
            pred = logits.argmax(dim=1)
            correct += (pred == y).sum().item()
            total += y.size(0)

            all_feats.append(feats.cpu())
            all_preds_list.append(pred.cpu())
            all_labels_list.append(y.cpu())

    # Log validation metrics
    val_acc = correct / total
    writer.add_scalar("LinearProbe/Val_Acc", val_acc, epoch)

    val_preds = torch.cat(all_preds_list)
    val_labels = torch.cat(all_labels_list)
    writer.add_scalar("LinearProbe/Val_Precision", precision_score(val_labels, val_preds, average='weighted', zero_division=1), epoch)
    writer.add_scalar("LinearProbe/Val_Recall", recall_score(val_labels, val_preds, average='weighted', zero_division=1), epoch)
    writer.add_scalar("LinearProbe/Val_F1", f1_score(val_labels, val_preds, average='weighted', zero_division=1), epoch)

    # Log embeddings (a random subset of at most max_pts validation points)
    feats_cat = torch.cat(all_feats)

    if feats_cat.size(0) > max_pts:
        idx = torch.randperm(feats_cat.size(0))[:max_pts]
        feats_sample = feats_cat[idx]
        labels_sample = val_labels[idx].tolist()
    else:
        feats_sample = feats_cat
        labels_sample = val_labels.tolist()

    writer.add_embedding(
        mat=feats_sample,
        metadata=labels_sample,
        global_step=epoch,
        tag="LinearProbe/Embeddings"
    )

    # Log t-SNE plot of the same subset, colored by true label
    tsne = TSNE(n_components=2)
    feats_2d = tsne.fit_transform(feats_sample.numpy())
    fig_tsne, ax_tsne = plt.subplots()
    ax_tsne.scatter(feats_2d[:, 0], feats_2d[:, 1], c=labels_sample, s=5)
    ax_tsne.set_title(f"t-SNE Val (Epoch {epoch})")
    writer.add_figure("LinearProbe/tSNE_Val", fig_tsne, epoch)
    plt.close(fig_tsne)

    # Log confusion matrix over the full validation set
    cm = confusion_matrix(val_labels.tolist(), val_preds.tolist())
    fig_cm, ax_cm = plt.subplots(figsize=(6, 6))
    sns.heatmap(cm, annot=True, fmt="d",
                xticklabels=class_names, yticklabels=class_names,
                ax=ax_cm)
    ax_cm.set_xlabel("Predicted")
    ax_cm.set_ylabel("True")
    writer.add_figure("LinearProbe/ConfusionMatrix_Val", fig_cm, epoch)
    plt.close(fig_cm)

    print(f"Epoch {epoch:03d} | Train: {train_loss:.4f}/{train_acc:.3f} | Val: {val_acc:.3f}")

    # Save the best model (a new file each time validation accuracy improves)
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        path = os.path.join(CKPT_DIR, f"best_head_epoch{epoch:03d}.pth")
        torch.save({"epoch": epoch, "head_state_dict": head.state_dict(), "val_acc": val_acc}, path)

# Save the final model (a bare state dict, unlike the "best" checkpoints above)
final_path = os.path.join(CKPT_DIR, "head_final.pth")
torch.save(head.state_dict(), final_path)
print(f"Final head saved at {final_path}")
writer.close()

### 6. Evaluation

In [None]:
import os
import torch
from PIL import Image
import matplotlib.pyplot as plt
import torch.nn.functional as F
from torchvision import transforms, models
from sklearn.metrics import confusion_matrix
from sklearn.manifold import TSNE
import numpy as np
import random
import matplotlib.colors as mcolors

# --- Configuration ---
# Run on the GPU when CUDA is available; otherwise use the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Class labels for the EuroSAT dataset; a label index is a position in this list
class_names = [
    "AnnualCrop",
    "Forest",
    "HerbaceousVegetation",
    "Highway",
    "Industrial",
    "Pasture",
    "PermanentCrop",
    "Residential",
    "River",
    "SeaLake",
]

# --- Utility: Inverse normalization for visualization ---
# Per-channel statistics, identical to the values used during training.
# Held as tensors so they can be used in tensor arithmetic when undoing normalization.
mean = torch.tensor([0.48241806, 0.48080587, 0.47794071])
std = torch.tensor([0.19021621, 0.16879530, 0.14623168])

def unnormalize(img_tensor):
    """
    Reverses normalization on a tensor and converts it to a PIL Image.

    Args:
        img_tensor: Normalized image tensor in C×H×W layout with three
            channels, on any device. It is not modified.

    Returns:
        A PIL Image built from the H×W×C uint8 array.
    """
    # Work on a detached CPU copy so the module-level mean/std tensors broadcast
    # without a device mismatch and the caller's tensor is left untouched.
    img = img_tensor.detach().cpu().float()
    img = img * std.view(-1, 1, 1) + mean.view(-1, 1, 1)  # undo per-channel normalization
    # Clamp before the uint8 cast: values slightly outside [0, 1] (float error)
    # would otherwise wrap around. Round instead of truncating so a pixel that
    # comes back as e.g. 254.9999 maps to 255 rather than 254.
    img = img.clamp(0.0, 1.0).mul(255).round().byte()
    img = img.permute(1, 2, 0).contiguous()  # to H×W×C uint8
    return Image.fromarray(img.numpy())

# --- Preprocessing transform for raw images ---
# Same steps, in the same order, as the training-time pipeline:
# bicubic resize to 256 -> central 224x224 crop -> tensor -> per-channel normalization.
_eval_steps = [
    transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    # Normalize takes plain sequences here, hence the tensor -> list conversion
    transforms.Normalize(mean=mean.tolist(), std=std.tolist()),
]
transform = transforms.Compose(_eval_steps)

# --- Load SSL-trained ResNet-18 encoder ---
encoder = models.resnet18(pretrained=False)
encoder.fc = torch.nn.Identity()  # remove classification head
encoder.to(device).eval()

# Load the checkpoint and extract only the backbone parameters
ckpt_ssl = torch.load(
    "models/res18/backbone/backbone_res18.pt",
    map_location=device
)
# Strip exactly the leading "backbone." prefix; slicing (unlike str.replace)
# cannot alter a later occurrence of the same text inside a key.
backbone_prefix = "backbone."
state = {
    k[len(backbone_prefix):]: v
    for k, v in ckpt_ssl["model_state_dict"].items()
    if k.startswith(backbone_prefix)
}
# strict=True: fail loudly if the checkpoint and the architecture disagree
encoder.load_state_dict(state, strict=True)

# --- Load linear probing head ---
feat_dim = 512
head = torch.nn.Linear(feat_dim, len(class_names)).to(device).eval()

ckpt_lp = torch.load(
    "models/res18/linear_probe/head_final.pth",
    map_location=device
)
# Support two checkpoint formats:
#   * a dict wrapping the weights under "head_state_dict"
#   * a bare state dict with "weight" and "bias" entries
if "head_state_dict" in ckpt_lp:
    head_state = ckpt_lp["head_state_dict"]
else:
    head_state = {"weight": ckpt_lp["weight"], "bias": ckpt_lp["bias"]}
# Go through load_state_dict rather than assigning .data: it copies into the
# existing parameters and raises on a shape mismatch instead of silently
# installing a wrongly sized tensor.
head.load_state_dict(head_state)

# --- Load validation dataset ---
# Two on-disk layouts are accepted:
#   * a 2-element list/tuple whose first element is a tensor: (images_tensor, labels_tensor)
#   * anything else: treated as an indexable collection of samples
# NOTE(review): this reads from the home directory; confirm the file is placed there.
val_dataset = torch.load(os.path.expanduser("~/val_dataset.pth"), map_location=device)
use_tensor_input = (
    isinstance(val_dataset, (list, tuple))
    and len(val_dataset) == 2
    and isinstance(val_dataset[0], torch.Tensor)
)
if use_tensor_input:
    images_tensor, labels_tensor = val_dataset
    total = images_tensor.size(0)  # number of samples = leading dimension
else:
    total = len(val_dataset)

# --- Create a random collage of samples with predictions ---
num_samples = 20
# Never ask for more samples than exist: random.sample raises ValueError otherwise
indices = random.sample(range(total), min(num_samples, total))
fig, axes = plt.subplots(4, 5, figsize=(15, 12))

# Blank every cell up front so cells that receive no image do not show empty axes
for ax in axes.flatten():
    ax.axis('off')

for ax, idx in zip(axes.flatten(), indices):
    # Select image and true label
    if use_tensor_input:
        # Stored tensors are used as-is (no transform is applied in this branch)
        x = images_tensor[idx].unsqueeze(0).to(device)
        true_label = int(labels_tensor[idx].item())
    else:
        sample = val_dataset[idx]
        # Support dict or tuple formats
        if isinstance(sample, dict):
            img = sample.get("image", sample.get("img"))
            # Cast to int so the label can index class_names below
            true_label = int(sample.get("label", sample.get("target")))
        else:
            img, true_label = sample[0], int(sample[1])
        x = transform(img).unsqueeze(0).to(device)

    # Forward pass through encoder + head
    with torch.no_grad():
        feats  = encoder(x)
        logits = head(feats)
        probs  = F.softmax(logits, dim=1)
    pred = probs.argmax(dim=1).item()

    # Unnormalize and plot; the title shows true class, predicted class and its probability
    img_pil = unnormalize(x[0])
    ax.imshow(img_pil)
    ax.set_title(f"T: {class_names[true_label]}\nP: {class_names[pred]} ({probs[0,pred]:.2f})")
    ax.axis('off')

fig.tight_layout()
fig.savefig("collage_res18.png", dpi=150, bbox_inches='tight')
plt.close(fig)
print("Collage saved as collage_res18.png")

# --- Inference over the entire validation set and feature extraction ---
# Collected per sample, in dataset order:
#   y_true        - true label (int)
#   y_pred        - predicted class index (int)
#   features_list - encoder feature vector (1-D numpy array)
y_true, y_pred, features_list = [], [], []

# Run the models on mini-batches instead of one image per forward pass.
eval_batch_size = 256

with torch.no_grad():
    for start in range(0, total, eval_batch_size):
        stop = min(start + eval_batch_size, total)

        if use_tensor_input:
            # Stored tensors are used as-is (no transform is applied in this branch)
            x = images_tensor[start:stop].to(device)
            batch_labels = [int(v) for v in labels_tensor[start:stop].tolist()]
        else:
            batch_imgs, batch_labels = [], []
            for idx in range(start, stop):
                sample = val_dataset[idx]
                # Support dict or tuple formats
                if isinstance(sample, dict):
                    img = sample.get("image", sample.get("img"))
                    label = int(sample.get("label", sample.get("target", -1)))
                else:
                    img, label = sample[0], int(sample[1])
                batch_imgs.append(transform(img))
                batch_labels.append(label)
            x = torch.stack(batch_imgs).to(device)

        feat = encoder(x)
        logits = head(feat)
        # argmax of the logits equals argmax of the softmax, so no softmax is needed
        preds = logits.argmax(dim=1)

        y_true.extend(batch_labels)
        y_pred.extend(preds.tolist())
        # Extending with a 2-D array appends one 1-D row per sample
        features_list.extend(feat.cpu().numpy())

# --- Plot Confusion Matrix for ResNet-18 ---
# Passing labels= explicitly keeps the matrix at one row/column per class
# even if some class never occurs in y_true or y_pred.
class_ticks = np.arange(len(class_names))
cm = confusion_matrix(y_true, y_pred, labels=list(range(len(class_names))))

fig_cm, ax_cm = plt.subplots(figsize=(8, 8))
cm_image = ax_cm.imshow(cm, interpolation='nearest', cmap='viridis')
ax_cm.set_title("Confusion Matrix ResNet-18", fontsize=16)
ax_cm.set_xlabel("Predicted Label", fontsize=14)
ax_cm.set_ylabel("True Label", fontsize=14)
ax_cm.set_xticks(class_ticks)
ax_cm.set_xticklabels(class_names, rotation=45, ha='right')
ax_cm.set_yticks(class_ticks)
ax_cm.set_yticklabels(class_names)
fig_cm.colorbar(cm_image, ax=ax_cm, label="Number of Samples")
fig_cm.tight_layout()
fig_cm.savefig("confusion_matrix_res18.png", dpi=150, bbox_inches='tight')
plt.show()

# --- t-SNE Visualization of Feature Embeddings for ResNet-18 ---
# Stack the per-sample feature vectors into one 2-D array and project it to 2-D.
features = np.stack(features_list)
tsne = TSNE(n_components=2, random_state=42)  # fixed seed for a repeatable layout
z = tsne.fit_transform(features)

n_classes = len(class_names)
fig_tsne, ax_tsne = plt.subplots(figsize=(8, 8))

# Discrete colormap: one color per class, with bin edges at -0.5, 0.5, 1.5, ...
# so that each integer label falls in the middle of its own bin.
cmap = plt.get_cmap('tab10', n_classes)
norm = mcolors.BoundaryNorm(boundaries=np.arange(n_classes + 1) - 0.5, ncolors=cmap.N)

scatter = ax_tsne.scatter(z[:, 0], z[:, 1], c=y_true, cmap=cmap, norm=norm, s=10)
ax_tsne.set_title("t-SNE Feature Embeddings (ResNet-18)", fontsize=16)
ax_tsne.set_xlabel("t-SNE Component 1", fontsize=14)
ax_tsne.set_ylabel("t-SNE Component 2", fontsize=14)

# Colorbar with one tick per class, labeled with the class names
cbar = fig_tsne.colorbar(scatter, ax=ax_tsne, ticks=np.arange(n_classes))
cbar.ax.set_yticklabels(class_names)
cbar.set_label("Classes", fontsize=14)

fig_tsne.tight_layout()
fig_tsne.savefig("tsne_res18.png", dpi=150, bbox_inches='tight')
plt.show()
