In [1]:
import os 
# Work from the project root so the relative "datasets/..." paths used further down resolve.
# NOTE(review): this is an absolute, machine-specific path — it must be adjusted on any other machine.
os.chdir("/Users/oliverdaniels-koch/projects/elk-experiments/")

In [None]:
import torch 
import torch.nn as nn

# Explore Waterbirds

In [2]:
from datasets import load_dataset

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
import pandas as pd

dataset_dir = "datasets/waterbird_complete95_forest2water2"
metadata = pd.read_csv(dataset_dir + "/metadata.csv")

# The "split" column stores the partition as an integer code.
split_codes = {"train": 0, "val": 1, "test": 2}

# One frame per partition; the "split" column is dropped once it has been used to select the rows.
split_frames = {
    split_name: metadata[metadata["split"] == code].drop(columns=["split"])
    for split_name, code in split_codes.items()
}
train, val, test = split_frames["train"], split_frames["val"], split_frames["test"]

# Map from split name to the per-split CSV (file names are relative to dataset_dir).
split_dict = {"train": "train_metadata.csv", "val": "val_metadata.csv", "test": "test_metadata.csv"}

# Write each partition next to the original metadata file, without the pandas index.
for split_name, csv_name in split_dict.items():
    split_frames[split_name].to_csv(dataset_dir + "/" + csv_name, index=False)

In [4]:
# Build the dataset with one split per CSV file listed in `split_dict`.
# NOTE(review): the path literal repeats the value assigned to `dataset_dir` earlier; keep the two in sync.
dataset = load_dataset("datasets/waterbird_complete95_forest2water2", data_files=split_dict)

Generating train split: 4795 examples [00:00, 185660.63 examples/s]
Generating val split: 1199 examples [00:00, 119495.56 examples/s]
Generating test split: 5794 examples [00:00, 574347.64 examples/s]


In [5]:
dataset 

DatasetDict({
    train: Dataset({
        features: ['img_id', 'img_filename', 'y', 'place', 'place_filename'],
        num_rows: 4795
    })
    val: Dataset({
        features: ['img_id', 'img_filename', 'y', 'place', 'place_filename'],
        num_rows: 1199
    })
    test: Dataset({
        features: ['img_id', 'img_filename', 'y', 'place', 'place_filename'],
        num_rows: 5794
    })
})

In [6]:
# add image
from PIL import Image
def read_image(img_filename):
    """Load an image file and return it as an RGB PIL image.

    Args:
        img_filename: Path of the image file, passed straight to ``Image.open``.

    Returns:
        dict: ``{"image": <RGB image>}``. A dict is returned so the function can
        be used as the result of a ``dataset.map`` callback.
    """
    # ``convert`` reads the pixel data and returns a new image object, so the
    # file handle opened by ``Image.open`` can be released immediately instead of
    # staying open until garbage collection (this is called once per example).
    with Image.open(img_filename) as raw_image:
        image = raw_image.convert("RGB")
    return {"image": image}


# Attach the decoded image to every example, locating the file via its "img_filename" entry.
dataset = dataset.map(lambda x: read_image(dataset_dir + "/" + x["img_filename"]))

Map: 100%|██████████| 4795/4795 [06:30<00:00, 12.27 examples/s] 
Map: 100%|██████████| 1199/1199 [01:05<00:00, 18.44 examples/s]
Map: 100%|██████████| 5794/5794 [07:31<00:00,  5.59 examples/s] 

In [None]:
# remove img_filename and img_id (img_filename was only used above to locate the file for the "image" column)
# NOTE(review): presumably fails if re-run once the columns are already gone — confirm before re-executing this cell
dataset = dataset.remove_columns(["img_filename", "img_id"])

# Load Resnet 50 from HF

In [None]:
# load ResNet-50 from Hugging Face
from transformers import AutoImageProcessor, ResNetForImageClassification
# image processor loaded from the same checkpoint; it is applied to a PIL image with return_tensors="pt" in a later cell
processor = AutoImageProcessor.from_pretrained("microsoft/resnet-50")
model = ResNetForImageClassification.from_pretrained("microsoft/resnet-50", num_labels=0) # num labels is 0 because we are using it as a feature extractor

In [None]:
image = dataset["train"][0]["image"]
image

In [None]:
inputs = processor(image, return_tensors="pt")

In [None]:
inputs.keys()

In [None]:
with torch.no_grad():
    out = model(**inputs, output_hidden_states=True)


In [None]:
out["logits"].shape

In [None]:
out["hidden_states"][-1].shape

# Define DivDis Loss Function

Review: Mutual Information

Amount of information received about one random variable from observing the other random variable

It measures how different the joint distribution is from the product of the marginal distributions

Expected value of the pointwise mutual information 

KL divergence between the joint distribution and the product of the marginals

(Recall $\mathrm{KL}(P \,\|\, Q) = \sum_x P(x) \log \frac{P(x)}{Q(x)}$)

So $\mathrm{MI}(X, Y) = \mathrm{KL}(P_{XY} \,\|\, P_X P_Y) = \sum_{x, y} P_{XY}(x, y) \log \frac{P_{XY}(x, y)}{P_X(x)\, P_Y(y)}$


ok I don't understand this

seems like what they're doing is saying 
we estimate the distribution over classes of a certain head by just averaging the probabilities over a batch, then computing MI
on that 

I guess that's fine - say we have a balanced dataset of cows and camels on grass and sand 

the cow/camel classifier will be 50/50, as will the grass/sand classifier

but say within a batch, there happen to be lots of cows on sand
the cow/camel classifier will return 1 a lot (say 90/10), whereas the ...

intuitively, we would want to take this element-wise

for input x, and heads h1, h2, take 

eh whatever, I'll just reimplement it for now...

In [None]:
# from https://github.com/yoonholee/DivDis/blob/main/divdis.py
# TODO: understand this code
import torch
from einops import rearrange
from torch import nn

def to_probs(logits, heads):
    """Turn packed per-head logits into per-head class probabilities.

    Args:
        logits: 2-D tensor of shape [batch_size, heads * classes]. When the
            second dimension equals ``heads`` each head is treated as a binary
            classifier that emits a single scalar logit.
        heads: Number of heads packed along the last dimension.

    Returns:
        Tensor of shape [batch_size, heads, classes]. In the binary case
        ``classes`` is 2, with index 0 holding ``sigmoid(logit)`` and index 1
        holding ``1 - sigmoid(logit)``.
    """
    _, width = logits.shape  # unpacking also rejects inputs that are not 2-D

    if width != heads:
        # Multi-class: split the last dimension into one chunk per head and
        # normalise each chunk independently.
        per_head_logits = torch.stack(torch.chunk(logits, heads, dim=-1), dim=1)
        probs = per_head_logits.softmax(-1)
    else:
        # Binary: one scalar per head.
        positive = logits.sigmoid().unsqueeze(-1)
        probs = torch.cat([positive, 1 - positive], dim=-1)

    _, n_heads, _ = probs.shape
    assert n_heads == heads
    return probs


class DivDisLoss(nn.Module):
    """Computes pairwise repulsion losses between heads for DivDis.

    Args:
        heads (int): Number of heads packed into the logits.
        mode (str): How the disagreement between two heads is measured.
            ``"mi"`` uses the mutual information between the two heads'
            predictions, estimated from batch-averaged probabilities;
            ``"l1"`` uses the batch-mean L1 distance between their
            probability vectors.
        reduction (str): How the per-pair values are combined. ``"mean"``
            averages over all unordered pairs of distinct heads;
            ``"min_each"`` takes, for every head, the minimum repulsion to any
            other head and averages those minima.

    ``forward`` takes logits of shape [BATCH_SIZE, HEADS * DIM] (or
    [BATCH_SIZE, HEADS] for binary heads) and returns a scalar loss. With
    ``mode="mi"`` the loss is the mutual information itself (to be
    minimised); with ``mode="l1"`` it is the negated distance.

    Raises:
        ValueError: from ``forward`` if ``mode`` or ``reduction`` is unknown.
    """

    def __init__(self, heads, mode="mi", reduction="mean"):
        super().__init__()
        self.heads = heads
        self.mode = mode
        self.reduction = reduction

    def forward(self, logits):
        heads, mode, reduction = self.heads, self.mode, self.reduction
        probs = to_probs(logits, heads)  # B, H, D

        if mode == "mi":  # This was used in the paper
            marginal_p = probs.mean(dim=0)  # H, D
            marginal_p = torch.einsum(
                "hd,ge->hgde", marginal_p, marginal_p
            )  # H, H, D, D
            marginal_p = rearrange(marginal_p, "h g d e -> (h g) (d e)")  # H^2, D^2

            joint_p = torch.einsum("bhd,bge->bhgde", probs, probs).mean(
                dim=0
            )  # H, H, D, D
            joint_p = rearrange(joint_p, "h g d e -> (h g) (d e)")  # H^2, D^2

            # Compute pairwise mutual information = KL(P_XY | P_X x P_Y).
            # Saturated heads can produce probabilities that are exactly 0;
            # clamping inside the log keeps 0 * log(0) at 0 instead of NaN and
            # leaves every value above the smallest normal float untouched.
            tiny = torch.finfo(joint_p.dtype).tiny
            log_ratio = joint_p.clamp_min(tiny).log() - marginal_p.clamp_min(tiny).log()
            kl_computed = (joint_p * log_ratio).sum(dim=-1)
            kl_grid = rearrange(kl_computed, "(h g) -> h g", h=heads)
            repulsion_grid = -kl_grid
        elif mode == "l1":
            dists = (probs.unsqueeze(1) - probs.unsqueeze(2)).abs()
            dists = dists.sum(dim=-1).mean(dim=0)
            repulsion_grid = dists
        else:
            raise ValueError(f"{mode=} not implemented!")

        # Pairs are selected by position, not by value: selecting with
        # ``nonzero()`` silently drops pairs whose repulsion is exactly 0 and
        # yields NaN (or an error) when every pair is 0.
        device = repulsion_grid.device
        if reduction == "mean":  # This was used in the paper
            pair_rows, pair_cols = torch.triu_indices(
                heads, heads, offset=1, device=device
            )
            repulsions = repulsion_grid[pair_rows, pair_cols]
            repulsion_loss = -repulsions.mean()
        elif reduction == "min_each":
            off_diagonal = ~torch.eye(heads, dtype=torch.bool, device=device)
            # Boolean indexing is row-major, so each row keeps its H - 1
            # off-diagonal entries together.
            per_head = repulsion_grid[off_diagonal].reshape(heads, heads - 1)
            row_mins = per_head.min(dim=1).values
            repulsion_loss = -row_mins.mean()
        else:
            raise ValueError(f"{reduction=} not implemented!")

        return repulsion_loss

In [None]:
from torch.distributions import Categorical, kl_divergence

class RegLoss(nn.Module):
    """KL regulariser on the heads' predicted label balance.

    Compares the class balance predicted on the source batch with the class
    balance each head predicts on the target batch, so that no head collapses
    to a single class on the target data.

    ``forward`` expects binary logits of shape [batch, n_heads] for both
    arguments (the two batch sizes may differ) and returns a scalar: the
    mean over heads of ``KL(source balance || target balance of that head)``.
    """

    def forward(self, source_logits, target_logits):
        # Keep probabilities strictly inside (0, 1) so saturated heads give a
        # large but finite KL instead of inf.
        eps = 1e-6

        # P(y = 1) on the source batch, averaged over examples and heads: scalar.
        source_pos = torch.sigmoid(source_logits).mean().clamp(eps, 1 - eps)
        # P(y = 1) per head on the target batch, averaged over examples: [n_heads].
        target_pos = torch.sigmoid(target_logits).mean(0).clamp(eps, 1 - eps)

        # Categorical needs an explicit class dimension, so build the
        # two-class distributions [P(y = 1), P(y = 0)].
        target_probs = torch.stack([target_pos, 1 - target_pos], dim=-1)  # [n_heads, 2]
        source_probs = torch.stack([source_pos, 1 - source_pos]).expand_as(
            target_probs
        )  # [n_heads, 2]

        dist_source = Categorical(probs=source_probs)
        dist_target = Categorical(probs=target_probs)
        reg_loss = kl_divergence(dist_source, dist_target).mean()
        return reg_loss

# Dataset Partition

In [None]:
sum([i == j for i, j in zip(dataset["train"]["y"], dataset["train"]["place"])]) / len(dataset["train"])

In [None]:
# same computation for validation
sum([i == j for i, j in zip(dataset["val"]["y"], dataset["val"]["place"])]) / len(dataset["val"])

In [None]:
# same computation for test
sum([i == j for i, j in zip(dataset["test"]["y"], dataset["test"]["place"])]) / len(dataset["test"])

In [None]:
# balance of labels 
sum([i == 0 for i in dataset["train"]["y"]]) / len(dataset["train"]) # use this to set auxiliary KL div loss

# Probe Heads Module

In [None]:
import torch.nn as nn
class Heads(nn.Module):
    """A backbone followed by one linear layer that emits ``n_heads`` outputs.

    Args:
        base_model: Callable applied to the input first.
        hidden_size: Width of the features fed into the linear layer.
        n_heads: Number of outputs of the linear layer.
        output_fn: Optional callable applied to the backbone's result before
            the linear layer (e.g. to pick one field out of a structured
            output). Skipped when ``None``.
    """

    def __init__(self, base_model, hidden_size, n_heads, output_fn=None):
        super().__init__()
        self.base_model = base_model
        self.heads = nn.Linear(hidden_size, n_heads)
        self.output_fn = output_fn

    def forward(self, x):
        features = self.base_model(x)
        if self.output_fn is not None:
            features = self.output_fn(features)
        return self.heads(features)

# Train Loop

In [None]:
out[0].shape

In [None]:
# data preprocessing


In [None]:
sum([i == j for i, j in zip(dataset["train"]["y"], dataset["train"]["place"])]) / len(dataset["train"])

In [None]:
# hparams
lambda_1 = 10  # weight of the DivDis loss term in the total training loss
lambda_2 = 10  # weight of the regularization loss term in the total training loss
epochs = 4  # number of passes of the training loop
batch_size = 16  # batch size used by the source, target and test data loaders
learning_rate = 1e-3  # SGD learning rate
weight_decay = 1e-4  # weight decay passed to SGD
gamma = 1e-1  # NOTE(review): not referenced by any cell visible in this notebook — confirm whether it is still needed
n_heads = 2  # number of outputs of the Heads linear layer

In [None]:
import torch

In [None]:
# remove all instances where place != y on the training set
dataset["train"] = dataset["train"].filter(lambda x: x["y"] == x["place"])

# heads: 2048 is the feature width fed to the linear layer and must match what
# the backbone returns; output_fn picks the first field of the backbone output
heads = Heads(model, 2048, n_heads, output_fn=lambda x: x[0])

# loss functions
loss_func = nn.BCEWithLogitsLoss()
# use n_heads rather than a literal so the loss always matches the Heads module
divdis_loss_func = DivDisLoss(heads=n_heads, mode="mi", reduction="mean")
reg_loss_func = RegLoss()

# optimizer
optimizer = torch.optim.SGD(heads.parameters(), lr=learning_rate, weight_decay=weight_decay)


def collate_images(examples):
    """Turn a list of dataset rows into a batch of tensors.

    The rows hold PIL images, which the default DataLoader collation cannot
    batch and which the model cannot consume directly, so the images are run
    through ``processor`` here. The pixel tensor is stored under the "image"
    key so that batches are indexed the same way as single rows.

    Args:
        examples: List of dataset rows, each with "image", "y" and "place".

    Returns:
        dict with "image" (float tensor produced by the processor), "y" and
        "place" (integer tensors of shape [batch]).
    """
    pixel_values = processor(
        [example["image"] for example in examples], return_tensors="pt"
    )["pixel_values"]
    return {
        "image": pixel_values,
        "y": torch.tensor([example["y"] for example in examples]),
        "place": torch.tensor([example["place"] for example in examples]),
    }


# data loaders
source_loader = torch.utils.data.DataLoader(
    dataset["train"], batch_size=batch_size, shuffle=True, collate_fn=collate_images
)
target_loader = torch.utils.data.DataLoader(
    dataset["val"], batch_size=batch_size, shuffle=True, collate_fn=collate_images
)
test_loader = torch.utils.data.DataLoader(
    dataset["test"], batch_size=batch_size, shuffle=True, collate_fn=collate_images
)

In [None]:
from torch.distributions import Categorical, kl_divergence

In [None]:
# train loop
for epoch in range(epochs):
    # `heads` wraps the backbone, so this puts both the backbone and the linear layer in training mode
    heads.train()
    # zip stops at the shorter loader, so one epoch covers min(len(source), len(target)) batches
    for i, (source, target) in enumerate(zip(source_loader, target_loader)):
        optimizer.zero_grad()
        # compute source and target logits
        source_logits = heads(source["image"])  # batch, n_heads
        target_logits = heads(target["image"])  # batch, n_heads
        # compute source loss (binary cross entropy): BCEWithLogitsLoss needs float
        # targets with the same shape as the logits, so the [batch] integer labels are
        # cast and repeated across heads — every head is fit to the same source label
        source_labels = (
            source["y"].to(source_logits.dtype).unsqueeze(-1).expand_as(source_logits)
        )
        source_loss = loss_func(source_logits, source_labels)
        # compute divdis loss (disagreement between heads on the target batch)
        divdis_loss = divdis_loss_func(target_logits)
        # compute regularization loss
        reg_loss = reg_loss_func(source_logits, target_logits)
        loss = source_loss + lambda_1 * divdis_loss + lambda_2 * reg_loss
        loss.backward()
        optimizer.step()
        if i % 100 == 0:
            print(f"Epoch: {epoch}, Batch: {i}, Loss: {loss.item()}")

In [None]:
# what framework am I going to use?

# initialize number of heads 
# initialize optimization or whatever 
# divide train set into train and

In [None]:
# need to define mi loss function 


In [None]:
# ok, now maybe I'll just use the huggingface trainer?