In [None]:
import os 
# Run from the project root so relative paths resolve. Set ELK_EXPERIMENTS_DIR to use a
# different checkout; the default is the original author's location.
os.chdir(os.environ.get("ELK_EXPERIMENTS_DIR", "/Users/oliverdaniels-koch/projects/elk-experiments/"))

In [None]:
from tqdm import tqdm
import random
from itertools import cycle
from PIL import Image
import numpy as np
import torch 
import torch.nn as nn
from datasets import load_dataset

In [None]:
def set_seed(seed):
    """Seed every random number generator this notebook draws from.

    Seeds Python's `random` module, NumPy's global RNG and torch (all devices; CUDA devices
    are additionally seeded explicitly when available). It also disables cuDNN autotuning and
    requests deterministic cuDNN kernels.

    Args:
        seed: integer seed applied to every generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        # manual_seed_all covers every visible GPU, not just the current one.
        torch.cuda.manual_seed_all(seed)
    # Autotuning can select different kernels run-to-run; turn it off for reproducibility.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

In [None]:
set_seed(seed=1)  # seed all RNGs before the model head is created and data is shuffled

In [None]:
# Directory holding the images and the per-split metadata CSVs. Set WATERBIRDS_DIR to use a
# local copy; the default is the original author's location.
dataset_dir = os.environ.get("WATERBIRDS_DIR", "/Users/oliverdaniels-koch/data/waterbird_complete95_forest2water2")
split_dict = {"train": "train_metadata.csv", "val": "val_metadata.csv", "test": "test_metadata.csv"}
dataset = load_dataset(dataset_dir, data_files=split_dict)

In [None]:
# Drop metadata columns that are not used below.
dataset = dataset.remove_columns(column_names=["img_id", "place_filename"])

In [None]:
# y -> labels (main target), place -> aux_labels (the place attribute, tracked as an auxiliary label).
dataset = dataset.rename_columns({"y": "labels", "place": "aux_labels"})

In [None]:
# Remove training examples where y != place, so that on the train split the label and the
# place attribute are perfectly correlated.
dataset["train"] = dataset["train"].filter(lambda x: x["labels"] == x["aux_labels"])

In [None]:
# Sanity check: after the filter every training label must equal its place label. An assert
# (rather than a displayed bool) makes Restart & Run All stop if the invariant is broken.
train_split = dataset["train"]
mismatches = sum(label != aux_label for label, aux_label in zip(train_split["labels"], train_split["aux_labels"]))
assert mismatches == 0, f"{mismatches} training examples still have y != place"

In [None]:
from transformers import AutoImageProcessor, ResNetForImageClassification
processor = AutoImageProcessor.from_pretrained("microsoft/resnet-50")
model = ResNetForImageClassification.from_pretrained(
    "microsoft/resnet-50",
    num_labels=2,
    # num_labels=2 presumably differs from the checkpoint's head, so let those weights be re-initialized
    ignore_mismatched_sizes=True,
)

In [None]:
n_heads = 2
n_classes_per_head = 2
# One linear layer emits every head's logits side by side: [head0 classes | head1 classes | ...].
# The loss cells split this last dimension with torch.chunk(..., n_heads, dim=-1).
model.classifier = nn.Sequential(
    nn.Flatten(start_dim=1, end_dim=-1),
    # Pooled backbone feature width (2048 for resnet-50), read from the config so other
    # ResNet checkpoints work without editing a magic number.
    nn.Linear(
        in_features=model.config.hidden_sizes[-1],
        out_features=n_heads * n_classes_per_head,
        bias=True,
    ),
)

In [None]:
# Preview one training image. Index the row first so only that example is materialized,
# instead of reading the entire `img_filename` column.
image_filename = dataset["train"][0]["img_filename"]
with Image.open(os.path.join(dataset_dir, image_filename)) as raw_image:
    # convert() returns a new, fully loaded image, so closing the file afterwards is safe.
    image = raw_image.convert("RGB")
image

In [None]:
# Smoke test: forward the preview image through the (untrained) multi-head model without gradients.
inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(pixel_values=inputs["pixel_values"])
    logits = outputs.logits  # [1, n_heads * 2]
logits

In [None]:
def transform_images(examples):
    """Batch transform for `with_transform`: load image files and run the image processor.

    Args:
        examples: batch dict with an "img_filename" list of paths relative to `dataset_dir`.

    Returns:
        The same dict with "pixel_values" (the processor outputs concatenated along dim 0)
        added and "img_filename" removed.
    """
    pixel_batches = []
    for img_filename in examples["img_filename"]:
        # Context manager closes the file handle. RGB conversion matches the preview cell and
        # guards against grayscale/RGBA files reaching the processor.
        with Image.open(os.path.join(dataset_dir, img_filename)) as img:
            rgb_image = img.convert("RGB")
        pixel_batches.append(processor(rgb_image, return_tensors="pt").pixel_values)
    examples["pixel_values"] = torch.concat(pixel_batches)
    del examples["img_filename"]
    return examples


In [None]:
# Images are loaded and preprocessed lazily, on access, via transform_images.
dataset = dataset.with_transform(transform=transform_images)

In [None]:
# DivDis loss
# Adapted from https://github.com/yoonholee/DivDis/blob/main/divdis.py
# TODO: understand this code
from einops import rearrange

def to_probs(logits, heads):
    """Convert flat multi-head logits to per-head class probabilities.

    Args:
        logits: tensor of shape [batch_size, heads * classes]; head i owns the contiguous
            slice [i * classes, (i + 1) * classes) of the last dimension.
        heads: number of heads.

    Returns:
        Tensor of shape [batch_size, heads, classes], softmax-normalized within each head.

    Raises:
        ValueError: if `logits` is not 2-D, `heads` is not positive, or the last dimension
            cannot be split evenly into `heads` groups.
    """
    if logits.dim() != 2:
        raise ValueError(f"expected 2-D logits [batch, heads * classes], got shape {tuple(logits.shape)}")
    if heads <= 0:
        raise ValueError(f"heads must be positive, got {heads}")
    width = logits.shape[-1]
    if width == 0 or width % heads != 0:
        raise ValueError(f"logits width {width} cannot be split evenly into {heads} heads")
    # chunk + stack keeps each head's classes contiguous: [B, H*C] -> [B, H, C]
    per_head_logits = torch.stack(torch.chunk(logits, heads, dim=-1), dim=1)
    return per_head_logits.softmax(dim=-1)

class DivDisLoss(nn.Module):
    """Pairwise mutual-information repulsion loss between heads (DivDis).

    For each unordered pair of heads (h, g) it estimates the mutual information between their
    predictions over the batch as KL(P_hg || P_h x P_g). P_hg is the batch mean of the outer
    product of the two heads' probability vectors; P_h and P_g are the batch-mean probabilities.
    The loss is the mean over all pairs, so minimizing it pushes heads toward statistically
    independent predictions.

    Args:
        heads (int): Number of heads.

    Forward args:
        logits (torch.Tensor): Input logits with shape [BATCH_SIZE, HEADS * DIM].

    Returns:
        Scalar tensor; zero when there are fewer than two heads (no pairs).
    """

    def __init__(self, heads):
        super().__init__()
        self.heads = heads

    def forward(self, logits):
        heads = self.heads
        probs = to_probs(logits, heads)  # [B, H, D]

        # Product of per-head marginals: marginal_p[h, g, d, e] = P_h(d) * P_g(e)
        mean_probs = probs.mean(dim=0)  # [H, D]
        marginal_p = torch.einsum("hd,ge->hgde", mean_probs, mean_probs)  # [H, H, D, D]

        # Empirical joint: joint_p[h, g, d, e] = mean_b p_h(d | b) * p_g(e | b)
        joint_p = torch.einsum("bhd,bge->bhgde", probs, probs).mean(dim=0)  # [H, H, D, D]

        # KL(joint || marginal) per head pair. Clamping before log keeps 0 * log(0) at 0
        # (instead of NaN) when a softmax probability underflows; values >= tiny are unchanged.
        tiny = torch.finfo(joint_p.dtype).tiny
        kl_terms = joint_p * (joint_p.clamp_min(tiny).log() - marginal_p.clamp_min(tiny).log())
        kl_grid = kl_terms.sum(dim=(-2, -1))  # [H, H]

        # Mean over the strict upper triangle (each unordered pair once). Dividing by the pair
        # count keeps pairs whose MI is exactly 0. A nonzero() mask would drop them, biasing
        # the mean, and would produce NaN when all pairs are 0.
        n_pairs = heads * (heads - 1) // 2
        if n_pairs == 0:
            return kl_grid.new_zeros(())
        return torch.triu(kl_grid, diagonal=1).sum() / n_pairs

In [None]:
from torch.distributions import Categorical, kl_divergence

class RegLoss(nn.Module):
    """KL regularizer between the source prediction prior and per-example target predictions.

    The source prior is the softmax output averaged over the batch and over all heads.
    Each target example's distribution is its softmax output averaged over heads. The loss is
    KL(source prior || target example), averaged over target examples.

    Args:
        n_heads: number of heads packed into the last logits dimension.
        n_classes: number of classes per head (stored; the split uses n_heads).
    """

    def __init__(self, n_heads, n_classes):
        super().__init__()
        self.n_heads = n_heads
        self.n_classes = n_classes

    def forward(self, source_logits, target_logits):
        def per_head_probs(flat_logits):
            # [B, H*C] -> [B, H, C], softmax applied independently within each head
            head_chunks = flat_logits.chunk(self.n_heads, dim=-1)
            return torch.stack(head_chunks, dim=1).softmax(dim=-1)

        source_prior = per_head_probs(source_logits).mean(dim=[0, 1])  # [C]
        target_per_example = per_head_probs(target_logits).mean(dim=1)  # [B, C]
        kl_per_example = kl_divergence(
            Categorical(probs=source_prior),
            Categorical(probs=target_per_example),
        )  # [B]
        return kl_per_example.mean()

In [None]:
class CrossEntropyLossHeads(nn.Module):
    """Cross-entropy averaged over heads, each head scored against the same labels.

    Args:
        n_heads: number of heads packed side by side in the last logits dimension.
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads
        self.loss = nn.CrossEntropyLoss()

    def forward(self, logits, labels):
        running_total = 0
        for head_logits in torch.chunk(logits, self.n_heads, dim=-1):
            running_total = running_total + self.loss(head_logits, labels)
        return running_total / self.n_heads

In [None]:
# hparams
batch_size = 16
num_epochs = 10
learning_rate = 1e-3
weight_decay = 1e-4
momentum = 0.9  # only meaningful for SGD; not passed to the Adam optimizer used below
diversity_weight = 100  # multiplier on the DivDis mutual-information loss
reg_weight = 10  # multiplier on the source/target KL regularizer

# Pick the best available accelerator instead of hard-coding "mps", which fails off Apple silicon.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

In [None]:
loss = CrossEntropyLossHeads(n_heads)
dividis_loss = DivDisLoss(n_heads)
reg_loss = RegLoss(n_heads, 2)

# A single optimizer is enough: the heads live in model.classifier, so model.parameters()
# already includes them. Adam is used because SGD gave degenerate results in earlier runs.
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

# train: labeled source data (shuffled); val: target batches whose images feed the DivDis and
# regularization losses; test: evaluation split (bound to `val_lodaer`, a name later cells use).
train_loader, target_loader, val_lodaer = (
    torch.utils.data.DataLoader(dataset[split_name], batch_size=batch_size, shuffle=split_shuffle)
    for split_name, split_shuffle in (("train", True), ("val", False), ("test", False))
)

In [None]:
# TODO: log to tensorboard

In [None]:
# train loop with accuracy logging


def _repeat_forever(loader):
    """Yield batches from `loader` indefinitely, starting a fresh pass each time it is exhausted.

    This replaces itertools.cycle, which keeps a copy of every element from its first pass
    (here, full pixel tensors for the whole target split) and then replays those cached
    batches. Like cycle() on an empty iterable, it stops if a pass yields nothing.
    """
    while True:
        produced_any = False
        for batch in loader:
            produced_any = True
            yield batch
        if not produced_any:
            return


model.to(device)
for epoch in range(num_epochs):
    # Train: cross-entropy on labeled source batches, DivDis + regularizer on target batches.
    model.train()
    progress_bar = tqdm(
        zip(train_loader, _repeat_forever(target_loader)),
        total=len(train_loader),
        desc=f"train Epoch {epoch}",
    )
    for train_batch, target_batch in progress_bar:
        optimizer.zero_grad()
        outputs = model(pixel_values=train_batch["pixel_values"].to(device))
        target_outputs = model(pixel_values=target_batch["pixel_values"].to(device))
        loss_value = loss(outputs.logits, train_batch["labels"].to(device))
        divdis_loss_value = dividis_loss(target_outputs.logits)
        reg_loss_value = reg_loss(outputs.logits, target_outputs.logits)
        total_loss_value = loss_value + divdis_loss_value * diversity_weight + reg_loss_value * reg_weight
        total_loss_value.backward()
        optimizer.step()
        progress_bar.set_postfix({"training_loss": f"{loss_value.item():.3f}",
                                  "divdis_loss": f"{divdis_loss_value.item():.5f}",
                                  "reg_loss": f"{reg_loss_value.item():.3f}",
                                  "total_loss": f"{total_loss_value.item():.3f}"})

    # Eval: per-head accuracy against both the main labels and the aux (place) labels.
    model.eval()
    correct = np.zeros(n_heads)
    correct_aux = np.zeros(n_heads)
    total = 0
    with torch.no_grad():
        progress_bar = tqdm(val_lodaer, desc=f"val Epoch {epoch}")
        for batch in progress_bar:
            outputs = model(pixel_values=batch["pixel_values"].to(device)) # (batch, heads * classes)
            chunked_logits = torch.chunk(outputs.logits, n_heads, dim=-1) # [[batch, classes] * heads]
            predictions = torch.stack([logits.argmax(dim=-1) for logits in chunked_logits], dim=1) # (batch, heads)
            correct += (predictions == batch["labels"].unsqueeze(-1).to(device)).sum(dim=0).cpu().numpy()
            correct_aux += (predictions == batch["aux_labels"].unsqueeze(-1).to(device)).sum(dim=0).cpu().numpy()
            total += batch["labels"].shape[0]
    for i in range(n_heads):
        print(f"head {i}: Epoch {epoch}: Accuracy: {correct[i] / total} (aux: {correct_aux[i] / total})")

In [None]:
# Final per-head evaluation on the test split. The counters are reset here; otherwise this
# cell would add onto the totals left over from the last training epoch.
model.eval()
correct = np.zeros(n_heads)
correct_aux = np.zeros(n_heads)
total = 0
with torch.no_grad():
    progress_bar = tqdm(val_lodaer, desc=f"val Epoch {epoch}")
    for batch in progress_bar:
        outputs = model(pixel_values=batch["pixel_values"].to(device)) # (batch, heads * classes)
        chunked_logits = torch.chunk(outputs.logits, n_heads, dim=-1) # [[batch, classes] * heads]
        predictions = torch.stack([logits.argmax(dim=-1) for logits in chunked_logits], dim=1) # (batch, heads)
        correct += (predictions == batch["labels"].unsqueeze(-1).to(device)).sum(dim=0).cpu().numpy()
        correct_aux += (predictions == batch["aux_labels"].unsqueeze(-1).to(device)).sum(dim=0).cpu().numpy()
        total += batch["labels"].shape[0]
for i in range(n_heads):
    print(f"head {i}: Epoch {epoch}: Accuracy: {correct[i] / total} (aux: {correct_aux[i] / total})")