### Generating Examples

In [None]:
# for a in ["2", "4", "8", "16", "32", "64", "128"]:
#   for b in ["2", "4", "8", "16", "32", "64", "128"]:
#     if a !=b:
#       if os.path.exists("kl_diffs/ld" + a + "_vs_ld" + b + "/diff.pt"):
#         print(f"ld{a} vs ld{b} already exists, skipping")
#       else:
#         print(f"Running {a} vs {b}")
#         gen_kl_diff(m1_path=f"ld{a}_e9.pt", m2_path=f"ld{b}_e9.pt", embed=True, embed_model_name="google/vit-base-patch16-224")

In [None]:
from tqdm import tqdm
from torchvision.utils import save_image
from pathlib import Path
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from transformers import AutoFeatureExtractor, AutoModel, ViTImageProcessor

import torch
import os


def get_embeddings(model, extractor, imgs, device):
    """Embed a batch of images with a Hugging Face vision model.

    Each image in *imgs* is detached and copied to host memory as a NumPy
    array, preprocessed by *extractor*, moved to *device* and run through
    *model*. The return value is ``last_hidden_state[:, 0]`` for every image
    (presumably the [CLS] token for ViT-style models — confirm), on *device*.
    """
    host_arrays = []
    for img in imgs:
        host_arrays.append(img.detach().cpu().numpy())

    batch = extractor(images=host_arrays, return_tensors="pt")
    batch = batch.to(device)
    model.to(device)

    hidden = model(**batch).last_hidden_state
    first_token = hidden[:, 0]
    return first_token.to(device)

def gen_kl_diff(
    m1_path: str=None,
    m2_path: str=None,
    bsize: int=100,
    n_kl_samples: int=1000,
    n_ll_samples: int=100,
    embed: bool=True,
    embed_model_name: str="google/vit-base-patch16-224",
):
    """Score samples from two generative models under both models and save the result.

    Loads two pickled model checkpoints and draws ``n_kl_samples`` samples
    from each, in batches of at most ``bsize``; a final smaller batch covers
    any remainder. For the i-th pair of samples (x1 from m1, x2 from m2) the
    stored value is::

        (log p2(x1) + log p2(x2)) - (log p1(x1) + log p1(x2))

    The own-model terms are the second value returned by ``m.sample``
    (presumably the samples' log-likelihoods — confirm). The cross terms come
    from ``m.log_likelihood(..., n_samples=n_ll_samples)``.

    Results are written to ``kl_diffs/<m1>_vs_<m2>/``:

    * ``diff.pt`` - the concatenated per-batch values above.
    * ``img.pt`` - the m1 and m2 samples (or their embeddings when ``embed``)
      concatenated along the last dimension.

    Args:
        m1_path, m2_path: checkpoint files saved as whole pickled models. The
            output directory uses each basename with its last 6 characters
            removed (e.g. ``"ld2_e9.pt"`` -> ``"ld2"``).
        bsize: maximum number of samples drawn per batch.
        n_kl_samples: total number of samples drawn from each model.
        n_ll_samples: ``n_samples`` forwarded to ``log_likelihood``.
        embed: if True, store embeddings from ``embed_model_name`` instead of
            the raw samples.
        embed_model_name: Hugging Face id used for the image processor and model.

    Raises:
        ValueError: if ``bsize`` or ``n_kl_samples`` is not positive.
    """
    if bsize <= 0:
        raise ValueError(f"bsize must be positive, got {bsize}")
    if n_kl_samples <= 0:
        raise ValueError(f"n_kl_samples must be positive, got {n_kl_samples}")

    with torch.no_grad():
        if embed:
            # do_rescale=False: samples are passed through without the processor's pixel rescaling.
            extractor = ViTImageProcessor.from_pretrained(embed_model_name, do_rescale=False)
            embed_model = AutoModel.from_pretrained(embed_model_name, output_hidden_states=True)

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # NOTE(review): weights_only=False unpickles arbitrary Python objects;
        # only load checkpoints from trusted sources.
        # map_location lets GPU-saved checkpoints load on CPU-only machines.
        m1 = torch.load(m1_path, weights_only=False, map_location=device)
        m1.to(device).eval()
        m2 = torch.load(m2_path, weights_only=False, map_location=device)
        m2.to(device).eval()
        print(device)

        # Output layout: kl_diffs/<m1>_vs_<m2>/
        m1_name = os.path.basename(m1_path)[:-6]
        m2_name = os.path.basename(m2_path)[:-6]
        save_dir = os.path.join("kl_diffs", f"{m1_name}_vs_{m2_name}")
        Path(save_dir).mkdir(parents=True, exist_ok=True)

        # Full batches plus one partial batch so exactly n_kl_samples are drawn.
        n_full, remainder = divmod(n_kl_samples, bsize)
        batch_sizes = [bsize] * n_full + ([remainder] if remainder else [])

        kl_diffs = []
        m1_feats = []
        m2_feats = []
        for cur_bsize in tqdm(batch_sizes, total=len(batch_sizes)):
            m1_samples, m1_m1_ll = m1.sample(cur_bsize, device)
            m2_samples, m2_m2_ll = m2.sample(cur_bsize, device)

            m1_m2_ll = m1.log_likelihood(m2_samples, n_samples=n_ll_samples)
            m2_m1_ll = m2.log_likelihood(m1_samples, n_samples=n_ll_samples)

            kl_diffs.append(-(m1_m1_ll + m1_m2_ll) + (m2_m1_ll + m2_m2_ll))

            if embed:
                m1_feats.append(get_embeddings(embed_model, extractor, m1_samples, device))
                m2_feats.append(get_embeddings(embed_model, extractor, m2_samples, device))
            else:
                m1_feats.append(m1_samples)
                m2_feats.append(m2_samples)

        kl_diffs = torch.cat(kl_diffs)
        samples = torch.cat((torch.cat(m1_feats), torch.cat(m2_feats)), dim=-1)
        torch.save(kl_diffs, os.path.join(save_dir, "diff.pt"))
        torch.save(samples, os.path.join(save_dir, "img.pt"))

# Compare checkpoint ld2_e9.pt against ld4_e9.pt, storing ViT embeddings of the samples.
gen_kl_diff(
    m1_path="ld2_e9.pt",
    m2_path="ld4_e9.pt",
    embed=True,
    embed_model_name="google/vit-base-patch16-224",
)

In [None]:
# !zip -r file.zip kl_diffs

In [None]:
import torch
import os
import torch.nn as nn
from torch.utils.data import Dataset

class MetaMLP(nn.Module):
    """Fully connected network mapping a feature vector to ``hidden_dims[-1]`` outputs.

    Every ``Linear(hidden_dims[i], hidden_dims[i + 1])`` is followed by the
    chosen activation, except the last one. The last layer has no activation,
    so the output is unbounded.

    Args:
        hidden_dims: layer widths including input and output. Defaults to
            ``[1536, 1000, 1000, 1000, 1]``. The 1536 input presumably matches
            two concatenated 768-d ViT embeddings — confirm.
        activation: name of the activation between layers; one of the keys of
            ``_ACTIVATIONS``.

    Raises:
        ValueError: if ``activation`` is not supported.
    """

    # Supported ``activation`` names -> module class (instantiated per layer).
    _ACTIVATIONS = {
        'relu': nn.ReLU,
        'gelu': nn.GELU,
        'tanh': nn.Tanh,
        'sigmoid': nn.Sigmoid,
        'leaky_relu': nn.LeakyReLU,
        'silu': nn.SiLU,
    }

    def __init__(
        self,
        hidden_dims=None,
        activation='relu',
    ):
        super().__init__()
        if hidden_dims is None:
            hidden_dims = [1536, 1000, 1000, 1000, 1]

        try:
            act_cls = self._ACTIVATIONS[activation]
        except KeyError:
            raise ValueError(
                f"unsupported activation {activation!r}; "
                f"expected one of {sorted(self._ACTIVATIONS)}"
            ) from None

        modules = []
        for in_dim, out_dim in zip(hidden_dims[:-1], hidden_dims[1:]):
            modules.append(nn.Linear(in_dim, out_dim))
            modules.append(act_cls())

        # Drop the trailing activation so the final Linear output is unbounded.
        # The attribute name and layer order are kept so existing state_dicts still load.
        self.model = nn.Sequential(*modules[:-1])

    def forward(self, x):
        """Apply the network to ``x``, whose last dimension must equal ``hidden_dims[0]``."""
        return self.model(x)

class MetaDataset(Dataset):
    """Map-style dataset pairing each feature row with its scaled label.

    ``data`` and ``labels`` are indexed in lockstep: item ``i`` is
    ``(data[i], labels[i].float() / label_scale)``.

    Previously the length was ``len(data) * len(data[0])``, and ``idx`` was
    divided by ``len(data[0])``. For a 2-D feature tensor that repeated every
    sample ``feature_dim`` times per epoch.

    Args:
        data: indexable collection of samples (e.g. an ``(N, D)`` tensor).
        labels: tensor of ``N`` labels aligned with ``data``.
        label_scale: divisor applied to each label (default ``1000.``).

    Raises:
        ValueError: if ``data`` and ``labels`` differ in length.
    """

    def __init__(self, data, labels, label_scale: float = 1000.):
        if len(data) != len(labels):
            raise ValueError(
                f"data and labels must have the same length, got {len(data)} and {len(labels)}"
            )
        self.data = data
        self.labels = labels
        self.label_scale = label_scale

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx].float() / self.label_scale

    def __len__(self):
        return len(self.data)


def load_data(files, device, root="kl_diffs"):
    """Load and concatenate the samples and labels of several comparison directories.

    For each name in *files*, reads ``<root>/<name>/diff.pt`` (labels) and
    ``<root>/<name>/img.pt`` (samples). Each tensor is loaded onto the CPU,
    moved to *device*, and then all are concatenated along dim 0.

    Args:
        files: iterable of directory names under *root*.
        device: target device for the returned tensors.
        root: directory containing the comparison subdirectories.

    Returns:
        ``(data, labels)`` tensors concatenated over all directories.

    Raises:
        ValueError: if *files* yields no names.
    """
    cpu = torch.device("cpu")
    data = []
    labels = []

    for name in files:
        comparison_dir = os.path.join(root, name)
        label_path = os.path.join(comparison_dir, "diff.pt")
        data_path = os.path.join(comparison_dir, "img.pt")

        labels.append(torch.load(label_path, weights_only=True, map_location=cpu).to(device))
        data.append(torch.load(data_path, weights_only=True, map_location=cpu).to(device))

    if not data:
        raise ValueError(f"no comparison directories given to load from {root!r}")

    return torch.cat(data, dim=0), torch.cat(labels, dim=0)

def split(images, labels, split_ratio=0.8, subset=1):
    """Randomly split paired samples into train and test partitions.

    If ``subset != 1``, a random ``subset`` fraction of the rows is kept first
    and a warning is printed. The kept rows are then shuffled: the first
    ``int(n_kept * split_ratio)`` go to the first partition and the rest to
    the second.

    Args:
        images: tensor indexed along dim 0 in lockstep with *labels*.
        labels: tensor whose ``size(0)`` is the number of samples.
        split_ratio: fraction of the (kept) samples in the first partition.
        subset: fraction of all samples to keep, in ``(0, 1]``.

    Returns:
        ``((train_images, train_labels), (test_images, test_labels))``.

    Raises:
        ValueError: if ``subset`` is not in ``(0, 1]`` or ``split_ratio`` is not in ``[0, 1]``.
    """
    if not 0 < subset <= 1:
        raise ValueError(f"subset must be in (0, 1], got {subset}")
    if not 0 <= split_ratio <= 1:
        raise ValueError(f"split_ratio must be in [0, 1], got {split_ratio}")

    if subset != 1:
        print("WARNING: Training on a subset")
        n_keep = int(labels.size(0) * subset)
        keep = torch.randperm(labels.size(0))[:n_keep]
        images = images[keep]
        labels = labels[keep]

    # labels already holds only the kept subset here, so subset must not be
    # applied a second time. Before this fix subset=0.5 gave a 40/60 split.
    n_train = int(labels.size(0) * split_ratio)

    perm = torch.randperm(labels.size(0))
    train_idx, test_idx = perm[:n_train], perm[n_train:]

    return (images[train_idx], labels[train_idx]), (images[test_idx], labels[test_idx])


### Training

In [None]:
import torch.nn.functional as F
from tqdm import tqdm
from torch.utils.data import DataLoader
from torch.optim import Adam
from torch.optim.lr_scheduler import ConstantLR, LinearLR, SequentialLR


def _run_epoch(model, loader, optim=None, scheduler=None):
    """Run *model* once over every batch of *loader*.

    When *optim* is given, trains the model: backward pass, optimizer step and
    (if given) scheduler step per batch. Otherwise evaluates with autograd disabled.

    Returns:
        ``(mean_loss, mean_accuracy)`` averaged over batches. Loss is the MSE
        between outputs and targets. Accuracy is the fraction of outputs whose
        sign (``> 0``) matches the target's. Both are NaN for an empty loader.
    """
    training = optim is not None
    model.train(training)

    total_loss = 0.0
    total_acc = 0.0
    # No autograd graph is needed for evaluation; skipping it saves memory and time.
    with torch.set_grad_enabled(training):
        for feats, targets in tqdm(loader, total=len(loader)):
            # squeeze(-1) (not squeeze()) keeps a batch of one as shape (1,).
            out = model(feats).squeeze(-1)
            loss = F.mse_loss(out, targets)
            if training:
                loss.backward()
                optim.step()
                optim.zero_grad()
                if scheduler is not None:
                    scheduler.step()
            total_loss += loss.item()
            total_acc += ((out > 0) == (targets > 0)).float().mean().item()

    n_batches = len(loader)
    if n_batches == 0:
        return float("nan"), float("nan")
    return total_loss / n_batches, total_acc / n_batches


def train(
    bsize: int=256,
    nepochs: int=10,
    lr: float=1e-3,
):
    """Train a MetaMLP on every comparison directory under ``kl_diffs/``.

    Loads all entries of ``kl_diffs`` with ``load_data`` and partitions them
    with ``split(..., subset=0.5)``. Trains with Adam and prints the train and
    test loss/accuracy after every epoch.

    LR schedule, stepped per batch: linear warm-up from ``lr * 1e-3`` to
    ``lr`` over the first epoch, constant for ``nepochs - 2`` epochs, then
    linear decay back to ``lr * 1e-3`` over the last epoch.

    Uses CUDA when available, otherwise the CPU.

    Raises:
        ValueError: if ``nepochs < 2`` (the three-phase schedule needs at least two epochs).
    """
    if nepochs < 2:
        raise ValueError(
            f"nepochs must be at least 2 for the warmup/constant/decay schedule, got {nepochs}"
        )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = MetaMLP().to(device)

    # sorted() makes the load order independent of filesystem listing order.
    images, labels = load_data(files=sorted(os.listdir("kl_diffs")), device=device)
    train_split, test_split = split(images, labels, subset=0.5)

    trainloader = DataLoader(MetaDataset(*train_split), batch_size=bsize, shuffle=True)
    testloader = DataLoader(MetaDataset(*test_split), batch_size=bsize, shuffle=False)

    batch_per_epoch = len(trainloader)
    optim = Adam(model.parameters(), lr=lr)
    scheduler = SequentialLR(
        optim,
        schedulers=[
            LinearLR(optim, start_factor=1e-3, end_factor=1, total_iters=batch_per_epoch),
            ConstantLR(optim, factor=1, total_iters=(nepochs - 2) * batch_per_epoch),
            LinearLR(optim, start_factor=1, end_factor=1e-3, total_iters=batch_per_epoch),
        ],
        milestones=[batch_per_epoch, (nepochs - 1) * batch_per_epoch],
    )
    optim.zero_grad()

    for _ in range(nepochs):
        train_loss, train_acc = _run_epoch(model, trainloader, optim, scheduler)
        print(f"Train Loss: {train_loss:.4f}\tAccuracy: {train_acc:.4f}")

        test_loss, test_acc = _run_epoch(model, testloader)
        print(f"Test Loss: {test_loss:.4f}\tAccuracy: {test_acc:.4f}")

# Train the meta-regressor, spelling out the default hyper-parameters.
train(bsize=256, nepochs=10, lr=1e-3)