
# Dependencies

In [None]:
import copy
import datetime
import glob
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import shutil
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import zipfile

from collections import OrderedDict
from PIL import Image
from sklearn.model_selection import train_test_split
from sklearn.metrics import average_precision_score
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models

# Data Pipeline

## Import from Drive

In [None]:
# Mount Google Drive so the zipped MLRSNet archives are reachable from this Colab runtime.
from google.colab import drive
drive.mount('/content/drive')

# Archives on Drive, and the local directories they are extracted into.
IMAGES_ZIP_PATH, LABELS_ZIP_PATH = (
    '/content/drive/MyDrive/CS2952Q-FP/MLRSNet/images.zip',
    '/content/drive/MyDrive/CS2952Q-FP/MLRSNet/labels.zip',
)
LOCAL_IMAGES_DIR, LOCAL_LABELS_DIR = '/content/images', '/content/labels'

def make_local_dir(dirname, zip_path):
    """Extract ``zip_path`` into ``dirname`` (once) and flatten the result to one level.

    Extraction is skipped when ``dirname`` already contains anything. Afterwards every
    file found in any nested subdirectory, at any depth, is moved directly into
    ``dirname`` and the emptied subdirectories are removed. That way a non-recursive
    ``glob(os.path.join(dirname, '*.jpg'))`` sees every file, however deeply the archive
    nested them. Running it again on an already-flat directory is a no-op.

    Files that share a basename in different subfolders are not disambiguated; one may
    overwrite another when moved.

    Args:
        dirname: Destination directory (created if missing).
        zip_path: Path of the zip archive to extract.
    """
    os.makedirs(dirname, exist_ok=True)

    if not os.listdir(dirname):
        print(f'Unzipping {zip_path}...')
        with zipfile.ZipFile(zip_path, 'r') as zipf:
            zipf.extractall(dirname)
        print(f'Completed unzipping {zip_path}.')
    else:
        print(f'{zip_path} was already unzipped.')

    # Bottom-up walk: each directory's children are already emptied and removed by the
    # time the directory itself is visited, so the final rmdir always succeeds.
    for root, _, files in os.walk(dirname, topdown=False):
        if root == dirname:
            continue
        for name in files:
            shutil.move(os.path.join(root, name), os.path.join(dirname, name))
        os.rmdir(root)

# Extract (or reuse) the image and label archives, then flatten each directory.
for _local_dir, _zip_path in ((LOCAL_IMAGES_DIR, IMAGES_ZIP_PATH),
                              (LOCAL_LABELS_DIR, LABELS_ZIP_PATH)):
    make_local_dir(_local_dir, _zip_path)

## Dataset Preparation

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

def encode_labels(labels_dir=None, images_dir=None) -> tuple[pd.DataFrame, int]:
    """Load the MLRSNet label CSVs and pair each labelled image with its multi-hot vector.

    All ``*.csv`` files in ``labels_dir`` are read in sorted order and concatenated.
    Then:

    * duplicate ``image`` rows are dropped (first occurrence wins);
    * two misspelled class columns are renamed;
    * each remaining row is matched by basename against the ``*.jpg`` files in
      ``images_dir``.

    Rows whose image file is missing are reported and skipped.

    Args:
        labels_dir: Directory holding the label CSVs. Defaults to ``LOCAL_LABELS_DIR``.
        images_dir: Directory holding the images. Defaults to ``LOCAL_IMAGES_DIR``.

    Returns:
        ``(df, num_classes)``.

        * ``df`` has an ``image_path`` column and a ``multi_labels`` column with one
          float32 NumPy vector per image, ordered like the non-``image`` CSV columns.
        * ``num_classes`` is the length of those vectors.

    Raises:
        FileNotFoundError: If ``labels_dir`` contains no CSV files.
    """
    labels_dir = LOCAL_LABELS_DIR if labels_dir is None else labels_dir
    images_dir = LOCAL_IMAGES_DIR if images_dir is None else images_dir

    # Sort so the concatenation order does not depend on filesystem listing order.
    # That order decides which duplicate survives and the row order fed to the seeded split.
    label_files = sorted(glob.glob(os.path.join(labels_dir, '*.csv')))
    if not label_files:
        raise FileNotFoundError(f'No label CSV files found in {labels_dir}')

    labels_df = pd.concat((pd.read_csv(path) for path in label_files), ignore_index=True)
    print("Columns in labels_df:", labels_df.columns.tolist())

    # Remove duplicate entries
    labels_df = labels_df.drop_duplicates(subset='image')

    # Correct typos in class names
    labels_df = labels_df.rename(columns={
        'habor': 'harbor',
        'swimmimg pool': 'swimming pool'
    })

    # Every column except 'image' is a class indicator.
    class_names = [col for col in labels_df.columns if col != 'image']
    num_classes = len(class_names)

    # Index the available images by basename once.
    image_name_to_path = {
        os.path.basename(path): path
        for path in glob.glob(os.path.join(images_dir, '*.jpg'))
    }

    # Vectorised lookup; unmatched names map to NaN.
    image_paths = labels_df['image'].map(image_name_to_path)
    missing = image_paths.isna()
    for image_name in labels_df.loc[missing, 'image']:
        print(f'Image {image_name} not found in dataset.')

    found = ~missing
    label_matrix = labels_df.loc[found, class_names].to_numpy(dtype=np.float32)
    data = {
        'image_path': image_paths[found].tolist(),
        'multi_labels': list(label_matrix),
    }
    return pd.DataFrame(data), num_classes

def train_val_test_split(df: pd.DataFrame, train_size, val_size, random_state=42) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """Shuffle-split ``df`` into train / validation / test partitions.

    ``train_size`` and ``val_size`` are fractions of the whole frame, and the test set
    gets the remainder. Sizes are first turned into integer row counts and passed to
    ``train_test_split`` as absolute numbers. This stops float round-off (``1 - 0.7`` is
    ``0.30000000000000004``, which sklearn rounds up) from moving a row between partitions.

    Args:
        df: Rows to split.
        train_size: Training fraction, in (0, 1).
        val_size: Validation fraction, in [0, 1 - train_size].
        random_state: Seed used for both ``train_test_split`` calls.

    Returns:
        ``(train, val, test)`` DataFrames. The test frame is empty when
        ``train_size + val_size == 1``.

    Raises:
        ValueError: If either fraction is out of range.
    """
    if not 0 < train_size < 1:
        raise ValueError(f'train_size must be in (0, 1), got {train_size}')
    if val_size < 0 or train_size + val_size > 1 + 1e-9:
        raise ValueError(f'val_size must be in [0, 1 - train_size], got {val_size}')

    n_total = len(df)
    n_train = int(round(train_size * n_total))
    n_val = min(int(round(val_size * n_total)), n_total - n_train)
    n_test = n_total - n_train - n_val

    def _split(frame, n_second):
        # sklearn rejects an absolute test_size of 0 or >= len(frame); handle those here.
        if n_second <= 0:
            return frame, frame.iloc[:0]
        if n_second >= len(frame):
            return frame.iloc[:0], frame
        return train_test_split(frame, test_size=n_second, random_state=random_state)

    train, val_test = _split(df, n_total - n_train)
    val, test = _split(val_test, n_test)
    print(f'Training samples: {len(train)}, Validation samples: {len(val)}, Test samples: {len(test)}')
    return train, val, test

# Build the (image_path, multi_labels) table, then split it 70% train / 20% val / 10% test.
labels_df, num_classes = encode_labels()
train_df, val_df, test_df = train_val_test_split(labels_df, train_size=.7, val_size=.2)

### Dataset Classes

#### MLRSNet Dataset

In [None]:
class MLRSNetDataset(Dataset):
    """Multi-label MLRSNet dataset yielding ``(image, multi_hot_label)`` pairs.

    ``df`` must provide an ``image_path`` column and a ``multi_labels`` column of
    equal-length NumPy vectors. All labels are stacked into one array at construction.
    If an image cannot be opened, the error is printed and a blank 224x224 RGB image is
    returned in its place, paired with the real label.
    """

    def __init__(self, df, transform=None):
        self.transform = transform
        self.image_paths = df['image_path'].tolist()
        self.multi_labels = np.stack(df['multi_labels'].values)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        path = self.image_paths[idx]
        try:
            image = Image.open(path).convert('RGB')
        except Exception as e:
            # Best effort: keep the epoch running instead of failing on one bad file.
            print(f"Error loading image {path}: {e}")
            image = Image.new('RGB', (224, 224))

        label = torch.from_numpy(self.multi_labels[idx])
        sample = self.transform(image) if self.transform else image
        return sample, label


#### Dataset Statistics (*needed for the SimCLR and evaluation normalization transforms*)

In [None]:
def compute_mean_std(loader, agg_batch_stats=False):
    """Estimate the per-channel mean and standard deviation of the images from ``loader``.

    Args:
        loader: Iterable of ``(images, labels)`` batches (e.g. a ``DataLoader``).
            ``images`` is a ``(B, C, H, W)`` tensor; labels are ignored.
        agg_batch_stats: Selects the estimator.
            * False (default): exact pooled statistics over every pixel of every image.
              Batches are merged with Chan et al.'s parallel mean/variance update, and
              the std is the population std.
            * True: a cheaper approximation that averages each image's own per-channel
              mean and (unbiased) std over all images.

    Returns:
        ``(mean, std)`` as float32 NumPy arrays of length C. If ``loader`` yields
        nothing, both are zeros of length 3.
    """
    mean = spread = None
    count = 0  # pixels per channel (exact mode) or images (aggregate mode)

    for images, _ in loader:
        images = images.detach().to('cpu', dtype=torch.float32)
        n_images, n_channels = images.size(0), images.size(1)
        if n_images == 0:
            continue
        flat = images.reshape(n_images, n_channels, -1)  # (B, C, H*W)
        if mean is None:
            # Accumulate in float64 so millions of pixels don't lose precision.
            mean = torch.zeros(n_channels, dtype=torch.float64)
            spread = torch.zeros(n_channels, dtype=torch.float64)

        if agg_batch_stats:
            mean += flat.mean(2).sum(0).double()
            spread += flat.std(2).sum(0).double()
            count += n_images
        else:
            n = n_images * flat.size(2)
            batch_mean32 = flat.mean(dim=(0, 2))
            # Sum of squared deviations from this batch's own mean (its "M2").
            batch_m2 = ((flat - batch_mean32[None, :, None]) ** 2).sum(dim=(0, 2)).double()
            batch_mean = batch_mean32.double()
            total = count + n
            delta = batch_mean - mean
            mean += delta * (n / total)
            spread += batch_m2 + delta ** 2 * (count * n / total)
            count = total

    if mean is None:
        zeros = np.zeros(3, dtype=np.float32)
        return zeros, zeros.copy()

    if agg_batch_stats:
        mean = mean / count
        std = spread / count
    else:
        std = torch.sqrt(spread / count)
    return mean.float().numpy(), std.float().numpy()


# Statistics are measured on resized but un-normalised, un-augmented training images.
compute_stats_transform = transforms.Compose([transforms.Resize(224), transforms.ToTensor()])

stats_loader = DataLoader(
    MLRSNetDataset(train_df, transform=compute_stats_transform),
    batch_size=256, shuffle=False, num_workers=8, pin_memory=True,
)

train_mean, train_std = compute_mean_std(stats_loader)
del stats_loader  # release the loader (and its worker processes) once stats are known
print(f'Dataset Mean: {train_mean}')
print(f'Dataset Std: {train_std}')

# SimCLR augmentation pipeline. SimCLRDataset calls it twice per image, once per view.
# Steps: square random crop, horizontal flip, colour jitter (80% of the time),
# grayscale (20%), Gaussian blur (always applied), then normalisation with the
# training-set statistics.
simclr_pretrain_transform = transforms.Compose([
    transforms.RandomResizedCrop(
        size=224,
        ratio=(1, 1),
        interpolation=transforms.InterpolationMode.BICUBIC,
    ),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomApply(
        [transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)],
        p=0.8,
    ),
    transforms.RandomGrayscale(p=0.2),
    transforms.GaussianBlur(kernel_size=23, sigma=(0.1, 2.0)),
    transforms.ToTensor(),
    transforms.Normalize(mean=train_mean.tolist(), std=train_std.tolist()),
])

#### SimCLR Pre-Training Dataset

In [None]:
class SimCLRDataset(Dataset):
    """Unlabelled dataset that returns two views of each image.

    With a transform, it is called twice (once per view), so random augmentations are
    drawn separately for each view. Without one, the same ``ToTensor`` output is
    returned as both views.
    """

    def __init__(self, df, transform=None):
        self.image_paths = df['image_path'].tolist()
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image = Image.open(self.image_paths[idx]).convert('RGB')
        if not self.transform:
            tensor = transforms.ToTensor()(image)
            return tensor, tensor
        return self.transform(image), self.transform(image)

# Contrastive Pre-Training

## Custom Modules

### SimCLR

In [None]:
class SimCLR(nn.Module):
    """SimCLR model: a backbone encoder followed by a two-layer MLP projection head.

    The encoder's ``fc`` attribute is replaced with ``nn.Identity()``, so the encoder
    emits its pooled feature vector. The head maps that vector into the contrastive
    embedding space.

    Args:
        encoder: Backbone module; its ``fc`` attribute is overwritten.
        projection_dim: Output size of the projection head.
        feature_dim: Size of the encoder's feature vector (2048 for ResNet-50).
        hidden_dim: Width of the projection head's hidden layer.
    """

    def __init__(self, encoder, projection_dim=128, feature_dim=2048, hidden_dim=512):
        super().__init__()
        self.encoder = encoder
        self.encoder.fc = nn.Identity()  # drop the classifier; keep pooled features

        self.projection_head = nn.Sequential(
            nn.Linear(feature_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, projection_dim)
        )

    def forward(self, x):
        """Return ``(h, z)``: flattened encoder features ``(B, feature_dim)`` and
        projections ``(B, projection_dim)``."""
        # Flatten in case an encoder returns (B, feature_dim, 1, 1) pooled maps.
        h = torch.flatten(self.encoder(x), start_dim=1)
        z = self.projection_head(h)
        return h, z


class NTXentLoss(nn.Module):
    """Normalised temperature-scaled cross-entropy (NT-Xent) loss, as used by SimCLR.

    For ``B`` pairs ``(z_i[k], z_j[k])``:

    1. The ``2B`` embeddings are L2-normalised.
    2. All pairwise cosine similarities are divided by ``temperature``.
    3. Each embedding is classified against its partner (the positive) versus the other
       ``2B - 2`` embeddings (the negatives).

    The summed cross-entropy is divided by ``2B``.

    Args:
        batch_size: Expected number of pairs per call; its mask is precomputed. Calls
            with a different number of pairs build a mask on the fly.
        temperature: Softmax temperature.
        device: Device the precomputed mask is placed on.
    """

    def __init__(self, batch_size, temperature=0.5, device='cuda'):
        super().__init__()
        self.batch_size = batch_size
        self.temperature = temperature
        self.device = device
        self.mask = self._get_correlated_mask().type(torch.bool)
        self.criterion = nn.CrossEntropyLoss(reduction="sum")

    def _get_correlated_mask(self, batch_size=None):
        """Return a float ``(2B, 2B)`` mask.

        Entries are 0 on the diagonal and at the positive-pair positions, 1 elsewhere.
        """
        b = self.batch_size if batch_size is None else batch_size
        n = 2 * b
        mask = torch.ones((n, n), dtype=torch.float32)
        mask.fill_diagonal_(0)
        idx = torch.arange(b)
        mask[idx, idx + b] = 0
        mask[idx + b, idx] = 0
        return mask.to(self.device)

    def forward(self, z_i, z_j):
        """Compute the loss for paired embeddings ``z_i`` and ``z_j``, each ``(B, D)``."""
        batch_size = z_i.size(0)
        n = 2 * batch_size

        # Normalise then matmul: same cosine similarities as a broadcast
        # cosine_similarity, without materialising a (2B, 2B, D) intermediate.
        z = F.normalize(torch.cat((z_i, z_j), dim=0), dim=1)
        sim = (z @ z.T) / self.temperature  # [2B, 2B]

        # Row k's positive sits at column (k + B) mod 2B.
        positives = torch.cat((torch.diag(sim, batch_size), torch.diag(sim, -batch_size)), dim=0)

        if batch_size == self.batch_size and self.mask.device == sim.device:
            mask = self.mask
        else:
            mask = self._get_correlated_mask(batch_size).to(sim.device).bool()
        negatives = sim[mask].view(n, -1)  # [2B, 2B-2]

        # Positive logit goes in column 0, so every target label is 0.
        logits = torch.cat((positives.unsqueeze(1), negatives), dim=1)  # [2B, 2B-1]
        labels = torch.zeros(n, dtype=torch.long, device=sim.device)
        return self.criterion(logits, labels) / n

## Pre-Training Routines

### Baseline Encoder (ImageNet weights, no contrastive pre-training)

In [None]:
# Baseline: ImageNet-pretrained ResNet-50 with its classifier replaced by Identity,
# so a forward pass returns the pooled 2048-d feature vector.
base_encoder_name = 'resnet50'
base_encoder = models.resnet50(weights='IMAGENET1K_V2')
base_encoder.fc = nn.Identity()
base_encoder.to(device)

### SimCLR

In [None]:
# Contrastive pre-training hyper-parameters.
PRETRAIN_EPOCHS_SIMCLR = 10
PRETRAIN_LR_SIMCLR = 3e-4
PRETRAIN_BATCH_SIZE_SIMCLR = 128
PRETRAIN_TEMP_SIMCLR = 0.5

# Deep copy so pre-training never touches the baseline encoder's weights.
model_simclr = SimCLR(copy.deepcopy(base_encoder)).to(device)

dataset_simclr_pretrain = SimCLRDataset(train_df, transform=simclr_pretrain_transform)
# drop_last keeps every batch at the size NTXentLoss precomputed its mask for.
loader_simclr_pretrain = DataLoader(
    dataset_simclr_pretrain,
    batch_size=PRETRAIN_BATCH_SIZE_SIMCLR,
    shuffle=True,
    num_workers=8,
    pin_memory=True,
    drop_last=True,
)
print('SimCLR dataset and loader initialized.')

opt_simclr = optim.Adam(model_simclr.parameters(), lr=PRETRAIN_LR_SIMCLR)
criterion_simclr = NTXentLoss(
    batch_size=PRETRAIN_BATCH_SIZE_SIMCLR,
    temperature=PRETRAIN_TEMP_SIMCLR,
    device=device,
)

model_simclr.train()
for epoch in range(1, PRETRAIN_EPOCHS_SIMCLR + 1):
    total_loss = 0.0
    for view1, view2 in loader_simclr_pretrain:
        view1 = view1.to(device, non_blocking=True)
        view2 = view2.to(device, non_blocking=True)

        opt_simclr.zero_grad()
        # Only the projections enter the contrastive loss; encoder features are unused here.
        _, z1 = model_simclr(view1)
        _, z2 = model_simclr(view2)
        loss = criterion_simclr(z1, z2)
        loss.backward()
        opt_simclr.step()

        total_loss += loss.item()

    avg_loss = total_loss / len(loader_simclr_pretrain)
    print(f'SimCLR Epoch [{epoch}/{PRETRAIN_EPOCHS_SIMCLR}], Loss: {avg_loss:.4f}')

    # One checkpoint per epoch; the average loss is encoded in the filename.
    checkpoint_path = f'/content/simclr/{base_encoder_name}/pretrain/epoch-{epoch}_loss-{avg_loss:.4f}.pth'
    os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
    torch.save(model_simclr.state_dict(), checkpoint_path)
    print(f'Checkpoint saved at {checkpoint_path}')

# Linear Evaluation

## Multi-Label Linear Probe

In [None]:
class MultiLabelLinearProbe(nn.Module):
    """A single linear layer mapping flattened encoder features to per-class logits.

    No sigmoid is applied; the outputs are raw logits.
    """

    def __init__(self, encoder_output_dim, num_classes):
        super().__init__()
        self.linear = nn.Linear(encoder_output_dim, num_classes)

    def forward(self, x):
        flat = x.view(len(x), -1)  # collapse all non-batch dimensions
        return self.linear(flat)

## Evaluation Routine

In [None]:
# Linear-probe training hyper-parameters: epochs, Adam learning rate, loader batch size.
NUM_EPOCHS, LEARNING_RATE, BATCH_SIZE = 10, 1e-3, 128


def linear_eval_trainval(encoder, probe, model_name, encoder_name, num_epochs, criterion, optimizer, train_set, train_load, val_set, val_load) -> dict:
    """Train a linear probe on a frozen encoder, validating and checkpointing each epoch.

    The encoder's parameters are frozen and the encoder is switched to eval mode, so its
    BatchNorm layers use (and do not overwrite) their running statistics. Only ``probe``
    is optimised. After every epoch the probe's state dict is saved to
    ``/content/{model_name}/{encoder_name}/eval/epoch-{E}_loss-{L}_mAP-{M}.pth``.

    Args:
        encoder: Feature extractor; frozen and put in eval mode by this call.
        probe: Trainable head mapping encoder features to logits.
        model_name, encoder_name: Used only to build the checkpoint directory.
        num_epochs: Number of training epochs.
        criterion: Loss on ``(logits, multi_hot_targets)``.
        optimizer: Optimiser over ``probe``'s parameters.
        train_set, val_set: Datasets; their lengths normalise the epoch losses.
        train_load, val_load: Loaders yielding ``(images, multi_labels)``.

    Returns:
        Dict with keys ``epoch``, ``train_loss``, ``val_loss`` and ``val_mAP`` (one list
        entry per epoch).
    """
    metrics = {
        'epoch': list(range(1, num_epochs+1)),
        'train_loss': list(),
        'val_loss': list(),
        'val_mAP': list()
    }

    for param in encoder.parameters():
        param.requires_grad = False
    # In train mode BatchNorm would normalise with per-batch statistics and keep
    # updating its running averages, so the "frozen" features would drift.
    encoder.eval()
    print('Encoder weights frozen.')

    for epoch in range(1, num_epochs + 1):
        # --- train the linear probe ---
        probe.train()
        running_loss_train = 0.0

        for images, multi_labels in train_load:
            images = images.to(device, non_blocking=True)
            multi_labels = multi_labels.to(device, non_blocking=True)

            optimizer.zero_grad()

            with torch.no_grad():
                embeddings = encoder(images)

            outputs = probe(embeddings)
            loss = criterion(outputs, multi_labels)
            loss.backward()
            optimizer.step()

            running_loss_train += loss.item() * images.size(0)

        epoch_loss_train = running_loss_train / len(train_set)
        print(f'Classification Training Epoch [{epoch}/{num_epochs}], Loss: {epoch_loss_train:.4f}')
        metrics['train_loss'].append(epoch_loss_train)

        # --- validation ---
        probe.eval()
        running_loss_val = 0.0
        all_targets = []
        all_outputs = []

        with torch.no_grad():
            for images, multi_labels in val_load:
                images = images.to(device, non_blocking=True)
                multi_labels = multi_labels.to(device, non_blocking=True)

                embeddings = encoder(images)
                outputs = probe(embeddings)

                loss = criterion(outputs, multi_labels)
                running_loss_val += loss.item() * images.size(0)

                all_targets.append(multi_labels.cpu().numpy())
                # Sigmoid in torch avoids np.exp overflow warnings on large negative logits.
                all_outputs.append(torch.sigmoid(outputs).cpu().numpy())

        epoch_loss_val = running_loss_val / len(val_set)
        print(f'Validation Loss: {epoch_loss_val:.4f}')
        metrics['val_loss'].append(epoch_loss_val)

        # --- mAP: mean of per-class average precision ---
        all_targets = np.vstack(all_targets)
        all_outputs = np.vstack(all_outputs)

        average_precisions = []
        for i in range(all_targets.shape[1]):
            try:
                ap = average_precision_score(all_targets[:, i], all_outputs[:, i])
            except ValueError:
                ap = 0.0  # edge case: no positive labels
            average_precisions.append(ap)

        mAP_val = np.mean(average_precisions)
        print(f'mAP (Multi-Label): {mAP_val:.4f}')
        metrics['val_mAP'].append(mAP_val)

        # Checkpoint the probe; loss and mAP are encoded in the filename.
        checkpoint_path = f'/content/{model_name}/{encoder_name}/eval/epoch-{epoch}_loss-{epoch_loss_val:.4f}_mAP-{mAP_val:.4f}.pth'
        os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
        torch.save(probe.state_dict(), checkpoint_path)
        print(f'Checkpoint saved at {checkpoint_path}')

    return metrics


def get_min_loss_pretrained(model_name, encoder_name, root='/content'):
    """Return the path of the pre-training checkpoint with the lowest recorded loss.

    Checkpoints are looked up in ``{root}/{model_name}/{encoder_name}/pretrain/`` and
    must be named ``epoch-{E}_loss-{L}.pth``. Other ``.pth`` names are ignored. Files
    are scanned in sorted order, so a tie goes to the first name alphabetically.

    Returns:
        The best checkpoint path, or ``None`` if no parsable checkpoint exists.
    """
    checkpoint_dir = os.path.join(root, model_name, encoder_name, 'pretrain')
    min_loss = float('inf')
    min_loss_path = None
    for filename in sorted(os.listdir(checkpoint_dir)):
        if not filename.endswith('.pth'):
            continue
        # Drop the extension before splitting, so the decimal point inside the value
        # survives ("loss-5.1234" -> 5.1234).
        stem = filename[:-len('.pth')]
        try:
            fields = dict(part.split('-', 1) for part in stem.split('_'))
            loss_value = float(fields['loss'])
        except (ValueError, KeyError):
            continue
        if loss_value < min_loss:
            min_loss = loss_value
            min_loss_path = os.path.join(checkpoint_dir, filename)
    return min_loss_path


def linear_eval_test(encoder, probe, test_set, test_load) -> float:
    """Return the test-set mean average precision (mAP) of ``probe`` on ``encoder`` features.

    Both modules are put in eval mode, so BatchNorm uses its running statistics.
    Average precision is computed per class on sigmoid scores and then averaged. A class
    whose AP computation raises ``ValueError`` contributes 0.

    Args:
        encoder: Frozen feature extractor.
        probe: Trained linear head producing per-class logits.
        test_set: Unused; kept so the signature matches the other evaluation helpers.
        test_load: Loader yielding ``(images, multi_labels)``.
    """
    encoder.eval()
    probe.eval()
    all_targets = []
    all_outputs = []

    with torch.no_grad():
        for images, multi_labels in test_load:
            images = images.to(device, non_blocking=True)

            embeddings = encoder(images)
            outputs = probe(embeddings)

            all_targets.append(multi_labels.cpu().numpy())
            # Sigmoid in torch avoids np.exp overflow warnings on large negative logits.
            all_outputs.append(torch.sigmoid(outputs).cpu().numpy())

    all_targets = np.vstack(all_targets)
    all_outputs = np.vstack(all_outputs)

    average_precisions = []
    for i in range(all_targets.shape[1]):
        try:
            ap = average_precision_score(all_targets[:, i], all_outputs[:, i])
        except ValueError:
            ap = 0.0  # edge case: no positive labels
        average_precisions.append(ap)

    mAP_val = np.mean(average_precisions)
    print(f'mAP (Multi-Label): {mAP_val:.4f}')
    return mAP_val


def get_max_mAP_eval(model_name, encoder_name, root='/content'):
    """Return the path of the linear-probe checkpoint with the highest validation mAP.

    Checkpoints are looked up in ``{root}/{model_name}/{encoder_name}/eval/`` and must
    be named ``epoch-{E}_loss-{L}_mAP-{M}.pth``. Other ``.pth`` names are ignored. Files
    are scanned in sorted order, so a tie goes to the first name alphabetically.

    Returns:
        The best checkpoint path, or ``None`` if no parsable checkpoint exists.
    """
    checkpoint_dir = os.path.join(root, model_name, encoder_name, 'eval')
    max_mAP = float('-inf')
    max_mAP_path = None
    for filename in sorted(os.listdir(checkpoint_dir)):
        if not filename.endswith('.pth'):
            continue
        # Drop the extension before splitting, so "mAP-0.5678" parses as 0.5678, not 0.
        stem = filename[:-len('.pth')]
        try:
            fields = dict(part.split('-', 1) for part in stem.split('_'))
            mAP_value = float(fields['mAP'])
        except (ValueError, KeyError):
            continue
        if mAP_value > max_mAP:
            max_mAP = mAP_value
            max_mAP_path = os.path.join(checkpoint_dir, filename)
    return max_mAP_path


def linear_eval(model_name, encoder_name, train_set, train_load, val_set, val_load, test_set, test_load):
    """Run the full linear-evaluation protocol for one encoder variant.

    Steps:

    1. Select the encoder. ``'baseline'`` uses ``base_encoder``. ``'simclr'`` loads
       the lowest-loss pre-training checkpoint into ``model_simclr`` and uses its encoder.
    2. Train a fresh ``MultiLabelLinearProbe`` with BCE-with-logits and Adam.
    3. Reload the probe checkpoint with the best validation mAP.
    4. Score that probe on the test set.

    Returns:
        ``(metrics, test_mAP)``: the per-epoch metrics dict from
        ``linear_eval_trainval``, and the test mAP.

    Raises:
        ValueError: For an unknown ``model_name``.
        FileNotFoundError: If no usable pre-training or probe checkpoint is found.
    """
    probe = MultiLabelLinearProbe(
        encoder_output_dim=2048,
        num_classes=num_classes
    ).to(device)

    if model_name == 'baseline':
        encoder = base_encoder.to(device)
    elif model_name == 'simclr':
        best_loss_path = get_min_loss_pretrained(model_name, encoder_name)
        if best_loss_path is None:
            raise FileNotFoundError(f'No {model_name} pre-training checkpoint found for {encoder_name}.')
        model_simclr.load_state_dict(torch.load(best_loss_path, map_location=device))
        encoder = model_simclr.encoder.to(device)
    else:
        raise ValueError('Invalid model name.')

    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(probe.parameters(), lr=LEARNING_RATE)
    print('Supervised loss function and optimizer initialized.')

    metrics = linear_eval_trainval(
        encoder=encoder,
        probe=probe,
        model_name=model_name,
        encoder_name=encoder_name,
        num_epochs=NUM_EPOCHS,
        criterion=criterion,
        optimizer=optimizer,
        train_set=train_set,
        train_load=train_load,
        val_set=val_set,
        val_load=val_load
    )

    max_mAP_path = get_max_mAP_eval(model_name, encoder_name)
    if max_mAP_path is None:
        raise FileNotFoundError(f'No linear-probe checkpoint found for {model_name}/{encoder_name}.')
    probe.load_state_dict(torch.load(max_mAP_path, map_location=device))

    mAP = linear_eval_test(
        encoder=encoder,
        probe=probe,
        test_set=test_set,
        test_load=test_load
    )

    return metrics, mAP


def _plot_metric_comparison(full_metrics, metric_key, title, ylabel, out_path):
    """Plot one per-epoch metric for every model on shared axes, then save and show it."""
    fig = plt.figure(figsize=(10, 6))
    for metrics in full_metrics.values():
        plt.plot(metrics['epoch'], metrics[metric_key])
    fig.suptitle(title)
    plt.xlabel('Epoch')
    plt.ylabel(ylabel)
    plt.legend(list(full_metrics))
    plt.savefig(out_path)
    plt.show()


def visualize(full_metrics, test_mAP):
    """Save and display side-by-side comparison plots under ``/content/output/``.

    Args:
        full_metrics: ``{model_name: metrics}``, where each metrics dict has ``epoch``,
            ``train_loss``, ``val_loss`` and ``val_mAP`` lists.
        test_mAP: ``{model_name: test mAP}``, drawn as a bar chart.
    """
    os.makedirs('/content/output/', exist_ok=True)

    _plot_metric_comparison(full_metrics, 'train_loss', 'Training Loss Comparison',
                            'Loss', '/content/output/training_loss_comparison.png')
    _plot_metric_comparison(full_metrics, 'val_loss', 'Validation Loss Comparison',
                            'Loss', '/content/output/validation_loss_comparison.png')
    _plot_metric_comparison(full_metrics, 'val_mAP', 'Validation mAP Comparison',
                            'mAP', '/content/output/validation_mAP_comparison.png')

    # Test mAP: one bar per model.
    fig = plt.figure(figsize=(10, 6))
    for model_name, mAP in test_mAP.items():
        plt.bar(model_name, mAP)
    fig.suptitle('Test mAP Comparison')
    plt.xlabel('Model')
    plt.ylabel('mAP')
    plt.savefig('/content/output/test_mAP_comparison.png')
    plt.show()


def main():
    """Linear-evaluate every implemented encoder variant, then plot the comparison.

    Builds train/val/test ``MLRSNetDataset`` loaders from the module-level splits using
    a deterministic resize + normalise transform. Runs ``linear_eval`` for each variant
    and passes the collected metrics and test mAPs to ``visualize``.
    """
    eval_transform = transforms.Compose([
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=train_mean.tolist(), std=train_std.tolist())
    ])

    def _dataset_and_loader(df, shuffle):
        """Wrap ``df`` in an evaluation dataset and a matching DataLoader."""
        dataset = MLRSNetDataset(df=df, transform=eval_transform)
        loader = DataLoader(
            dataset,
            batch_size=BATCH_SIZE,
            shuffle=shuffle,
            num_workers=8,
            pin_memory=True
        )
        return dataset, loader

    train_set, train_load = _dataset_and_loader(train_df, shuffle=True)
    val_set, val_load = _dataset_and_loader(val_df, shuffle=False)
    test_set, test_load = _dataset_and_loader(test_df, shuffle=False)

    full_metrics = {}
    test_mAP = {}
    # linear_eval only knows 'baseline' and 'simclr'; add 'mae' once it is implemented.
    for model_name in ['baseline', 'simclr']:
        metrics, mAP = linear_eval(
            model_name, base_encoder_name,
            train_set, train_load,
            val_set, val_load,
            test_set, test_load
        )
        full_metrics[model_name] = metrics
        test_mAP[model_name] = mAP

    visualize(full_metrics, test_mAP)


# Run linear evaluation for the configured encoder variants (trains probes, saves checkpoints).
main()

# Save Results to Drive

In [None]:
def zip_folders(output_path, folders=('output', 'simclr', 'baseline'), root='/content'):
    """Archive result folders under ``root`` into one deflated zip at ``output_path``.

    Entries are stored relative to ``root`` (e.g. ``simclr/resnet50/eval/...``). Folders
    that do not exist contribute nothing. The parent directory of ``output_path`` is
    created if it is missing.

    Args:
        output_path: Destination ``.zip`` path; overwritten if it exists.
        folders: Folder names under ``root`` to include.
        root: Base directory that ``folders`` and the archive paths are relative to.
    """
    # TODO: add 'mae' to the default folders once MAE pre-training exists.
    parent = os.path.dirname(output_path)
    if parent:
        os.makedirs(parent, exist_ok=True)

    with zipfile.ZipFile(output_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
        for folder_name in folders:
            folder_path = os.path.join(root, folder_name)
            for dirpath, _, filenames in os.walk(folder_path):
                for filename in sorted(filenames):
                    file_path = os.path.join(dirpath, filename)
                    zipf.write(file_path, os.path.relpath(file_path, root))

# Timestamp the archive name so each run writes its own results file on Drive.
timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H:%M:%S")
zip_folders('/content/drive/MyDrive/CS2952Q-FP/results/results_' + timestamp + '.zip')