In [None]:
# Reload edited modules automatically before each cell is executed.
%load_ext autoreload
%autoreload 2

In [None]:
import random
from pathlib import Path
from rae import PROJECT_ROOT
import numpy as np
import torch.nn.functional as F
from pytorch_lightning import seed_everything
from torch import nn

In [None]:
import torch
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer

In [None]:
from rae.modules.attention import RelativeAttention, AttentionOutput

In [None]:
# Experiment configuration shared by the cells below.
dataset_name: str = "cifar100"
data_key: str = "image"
device: str = "cuda"

# Label granularity: selects either the "fine_label" or the "coarse_label" column.
fine_grained: bool = False
target_key: str = ("fine" if fine_grained else "coarse") + "_label"

num_anchors: int = 768
# Kept as the integer 1 on purpose: the value is interpolated into a result file name further down.
train_perc: float = 1

In [None]:
from datasets import load_dataset, ClassLabel


def get_dataset(split: str, perc: float, seed: int = 42):
    """Load one split of ``dataset_name``, optionally reduced to a random subset.

    Args:
        split: name of the split to load, e.g. ``"train"`` or ``"test"``.
        perc: fraction of the split to keep, in ``(0, 1]``.
        seed: seed handed to ``seed_everything`` before the subset is drawn.

    Returns:
        The whole split when ``perc == 1``; otherwise a random subset holding
        ``int(len(split) * perc)`` samples.

    Raises:
        AssertionError: if ``perc`` lies outside ``(0, 1]``.
    """
    # Validate first, so a bad argument neither re-seeds the global RNGs nor loads the dataset.
    assert 0 < perc <= 1
    seed_everything(seed)
    dataset = load_dataset(dataset_name)[split]

    # Select a random subset
    if perc != 1:
        indices = list(range(len(dataset)))
        random.shuffle(indices)
        dataset = dataset.select(indices[: int(len(indices) * perc)])

    return dataset

In [None]:
# Training split, reduced to the `train_perc` fraction of its samples.
train_dataset = get_dataset(split="train", perc=train_perc)

In [None]:
# Label metadata taken from the feature description of the target column.
class2idx = train_dataset.features[target_key].str2int
num_classes = train_dataset.features[target_key].num_classes
# Display the number of classes and the number of training samples.
num_classes, len(train_dataset)

In [None]:
import timm
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
from transformers import AutoFeatureExtractor, AutoModelForImageClassification, AutoModel


def load_transformer(transformer_name):
    """Create the pretrained timm model ``transformer_name`` as a frozen feature extractor.

    The model is created with ``num_classes=0``, gradients are disabled on it
    and it is switched to evaluation mode before being returned.
    """
    backbone = timm.create_model(transformer_name, pretrained=True, num_classes=0)
    frozen = backbone.requires_grad_(False)
    return frozen.eval()

In [None]:
# Test split; note that it is subsampled with `train_perc` as well.
test_dataset = get_dataset(split="test", perc=train_perc)
len(test_dataset)

In [None]:
@torch.no_grad()
def call_transformer(batch, transformer):
    """Run ``transformer`` on ``batch["encoding"]`` without tracking gradients.

    The input tensor is moved to the notebook-level ``device`` first. The raw
    model output is returned under the ``"hidden"`` key.
    """
    inputs = batch["encoding"].to(device)
    hidden = transformer(inputs)
    return {"hidden": hidden}

In [None]:
from typing import *


def get_anchors(num_anchors: int, seed: int):
    """Draw ``num_anchors`` distinct sample indices of ``train_dataset``.

    The global RNGs are re-seeded with ``seed`` first, so the same
    ``(num_anchors, seed)`` pair always yields the same indices.
    """
    seed_everything(seed)
    assert num_anchors <= len(train_dataset)
    shuffled = [*range(len(train_dataset))]
    random.shuffle(shuffled)
    return shuffled[:num_anchors]

In [None]:
# Name of the timm model that `load_transformer` instantiates as the encoder.
transformer_name = "vit_small_patch16_224"

In [None]:
# Frozen encoder, plus the input transform that timm builds from the model's data config.
transformer = load_transformer(transformer_name=transformer_name)
config = resolve_data_config({}, model=transformer)
transform = create_transform(**config)

In [None]:
def relative_projection(x, anchors):
    """Return the cosine similarity between every row of ``x`` and every anchor.

    Both inputs are L2-normalised along their last dimension; the result has
    one row per sample of ``x`` and one column per anchor.
    """
    unit_x, unit_anchors = (F.normalize(t, p=2, dim=-1) for t in (x, anchors))
    return torch.einsum("bm, am -> ba", unit_x, unit_anchors)

In [None]:
def collate_fn(batch, feature_extractor, transform):
    """Turn a list of samples into a single ``{"encoding": tensor}`` batch.

    Every sample's ``"img"`` entry is converted to RGB, passed through
    ``transform`` and the results are stacked along a new first dimension.
    ``feature_extractor`` is accepted but not used.
    """
    encoded = []
    for sample in batch:
        rgb_image = sample["img"].convert("RGB")
        encoded.append(transform(rgb_image))
    return {"encoding": torch.stack(encoded, dim=0)}

In [None]:
def get_latents(dataloader, anchors, split: str, transformer) -> Dict[str, torch.Tensor]:
    """Encode every batch of ``dataloader`` and collect the latents on the CPU.

    Args:
        dataloader: iterable of batches accepted by ``call_transformer``.
        anchors: anchor latents used for the relative projection, or ``None``
            to skip the relative representation.
        split: label shown in the progress bar.
        transformer: encoder module; it is moved to ``device`` for the duration
            of the call and moved back to the CPU afterwards.

    Returns:
        A dict with ``"absolute"`` (the concatenated encoder outputs) and
        ``"relative"`` (the concatenated relative projections, or an empty
        list when ``anchors`` is ``None`` — kept for backward compatibility).
    """
    absolute_chunks: List[torch.Tensor] = []
    relative_chunks: List[torch.Tensor] = []

    if anchors is not None:
        # The encoder output lives on `device`; projecting it against anchors
        # stored elsewhere would fail with a device mismatch.
        anchors = anchors.to(device)

    transformer = transformer.to(device)
    try:
        for batch in tqdm(dataloader, desc=f"[{split}] Computing latents"):
            with torch.no_grad():
                hidden = call_transformer(batch=batch, transformer=transformer)["hidden"]
                absolute_chunks.append(hidden.cpu())

                if anchors is not None:
                    relative_chunks.append(relative_projection(x=hidden, anchors=anchors).cpu())
    finally:
        # Always move the encoder off the accelerator, even if encoding fails midway.
        transformer.cpu()

    return {
        "absolute": torch.cat(absolute_chunks, dim=0),
        "relative": torch.cat(relative_chunks, dim=0) if len(relative_chunks) > 0 else relative_chunks,
    }

In [None]:
from functools import partial  # needed to bind the extra collate_fn arguments below

# Absolute (encoder-space) latents of both splits, computed with the frozen encoder.
absolute_latents = {
    split: get_latents(
        dataloader=DataLoader(
            train_dataset if split == "train" else test_dataset,
            num_workers=4,
            pin_memory=True,
            collate_fn=partial(collate_fn, feature_extractor=None, transform=transform),
            batch_size=32,
        ),
        split=f"{split}/{transformer_name}",
        anchors=None,  # no relative projection here; only the "absolute" entry is kept
        transformer=transformer,
    )["absolute"]
    for split in ("train", "test")
}
absolute_latents

In [None]:
# Display the shapes of the train and test latents.
absolute_latents["train"].shape, absolute_latents["test"].shape

In [None]:
# Size of the last dimension of the training latents; used as the classifier's hidden width below.
encoding_dim: int = absolute_latents["train"].size(-1)
encoding_dim

In [None]:
from sklearn.model_selection import ParameterGrid

# Sweep definition: three seeds, crossed with anchor counts that are sampled
# densely for small values and more coarsely as the count grows.
sweep = {
    "seed": [*range(3)],
    "num_anchors": [*range(1, 50, 2), *range(50, 100, 5), *range(100, 1000, 20)],
}

experiments = ParameterGrid(sweep)
experiments

In [None]:
# When True, latents are L2-normalised (F.normalize, p=2) before being fed to the classifiers below.
latent_normalize: bool = True

In [None]:
from sklearn.model_selection import ParameterSampler, ParameterGrid
import logging
from tqdm import tqdm
from pprint import pprint
from rae.utils.utils import to_device
from pytorch_lightning.utilities.seed import log as seed_log
from sklearn.metrics import precision_recall_fscore_support, accuracy_score
import pandas as pd


class Lambda(nn.Module):
    """Adapter that lets a plain callable be used as a layer inside ``nn.Sequential``."""

    def __init__(self, func):
        super().__init__()
        self.func = func

    def forward(self, x):
        return self.func(x)


# Number of full-batch optimisation steps per experiment.
num_epochs: int = 10

stats = {x: [] for x in ("experiment", "epoch", "loss", "val_fscore", "val_acc", "num_anchors")}

# These values are identical for every experiment, so they are built once
# instead of once per sweep entry.
train_absolute = absolute_latents["train"].to(device)
test_absolute = absolute_latents["test"].to(device)
train_y = torch.as_tensor(train_dataset[target_key]).to(device)
test_y = np.asarray(test_dataset[target_key])
n_classes = train_dataset.features[target_key].num_classes

for i, experiment in enumerate(pbar := tqdm(ParameterGrid(sweep), desc="Experiment", leave=True)):
    seed: int = experiment["seed"]
    # Deliberately not called `num_anchors`: that name holds the notebook-level
    # configuration value, which later cells still read.
    n_anchors: int = experiment["num_anchors"]

    # `get_anchors` re-seeds with the same seed, so the seed logger is silenced
    # around both calls and restored even if one of them fails.
    temp_log_level = seed_log.getEffectiveLevel()
    seed_log.setLevel(logging.ERROR)
    try:
        seed_everything(seed)
        anchor_idxs = get_anchors(num_anchors=n_anchors, seed=seed)
    finally:
        seed_log.setLevel(temp_log_level)

    anchors = train_absolute[anchor_idxs]

    # Relative representation: similarity of every sample to every anchor.
    train_data = relative_projection(x=train_absolute, anchors=anchors)
    test_data = relative_projection(x=test_absolute, anchors=anchors)
    if latent_normalize:
        train_data = F.normalize(train_data, p=2, dim=-1)
        test_data = F.normalize(test_data, p=2, dim=-1)

    # NOTE(review): the 2-D input is transposed to (features, samples) around each
    # InstanceNorm1d — presumably so that every feature is standardised across the
    # whole batch; confirm this is the intended normalisation.
    model = nn.Sequential(
        Lambda(lambda x: x.permute(1, 0)),
        nn.InstanceNorm1d(num_features=n_anchors),
        Lambda(lambda x: x.permute(1, 0)),
        nn.Linear(in_features=n_anchors, out_features=encoding_dim),
        nn.Tanh(),
        Lambda(lambda x: x.permute(1, 0)),
        nn.InstanceNorm1d(num_features=encoding_dim),
        Lambda(lambda x: x.permute(1, 0)),
        nn.Linear(in_features=encoding_dim, out_features=n_classes),
    ).to(device)

    # Referenced through the modules that are already imported; the bare names
    # `Adam` and `CrossEntropyLoss` are not defined in this notebook.
    opt = torch.optim.Adam(model.parameters(), lr=1e-3)
    loss_fn = nn.CrossEntropyLoss()

    for epoch in tqdm(range(num_epochs), leave=False, desc="epoch"):
        # One optimisation step on the whole training set.
        model.train()

        pred_y = model(train_data)
        loss = loss_fn(pred_y, train_y)
        loss.backward()
        opt.step()
        opt.zero_grad()

        # Evaluate on the test split after every step.
        model.eval()
        with torch.no_grad():
            test_preds = model(test_data).softmax(-1).argmax(-1).cpu().numpy()

        precision, recall, fscore, _ = precision_recall_fscore_support(test_y, test_preds, average="weighted")
        acc = accuracy_score(test_y, test_preds)

        loss = loss.detach().cpu().numpy()

        stats["experiment"].append(i)
        stats["epoch"].append(epoch)
        stats["loss"].append(loss)
        stats["val_fscore"].append(fscore)
        stats["val_acc"].append(acc)
        stats["num_anchors"].append(n_anchors)

        pbar.set_description(f"Epoch: {epoch}, Loss: {loss:.4f}, Val F1: {fscore:.4f} num_anchors: {n_anchors}")

    model = model.cpu().eval()

stats = pd.DataFrame(stats)
stats_path = Path(
    PROJECT_ROOT / "experiments" / "sec:anchor-analysis" / f"{dataset_name}_data_manifold_stats_anchors_analysis.tsv"
)
# The target directory may not exist yet on a fresh checkout.
stats_path.parent.mkdir(parents=True, exist_ok=True)
stats.to_csv(
    stats_path,
    sep="\t",
)

In [None]:
# Maps every encoder name to the first dimension of its first "absolute" training latent
# (presumably the encoder's latent size — confirm).
# NOTE(review): `transformer2train_latents` is not defined in any cell visible here — confirm where it
# is supposed to come from before running this cell on a fresh kernel.
transformer_name2hidden_dim = {
    transformer_name: latents["absolute"][0].shape[0] for transformer_name, latents in transformer2train_latents.items()
}
transformer_name2hidden_dim

In [None]:
# Builds one classifier per (seed, embedding type, encoder) by calling `fit`; the result is nested
# as train_classifiers[seed][embedding_type][transformer_name].
# NOTE(review): neither `fit` nor `transformer2train_latents` is defined in any cell visible here — confirm.
# NOTE(review): `num_anchors` is a notebook-level variable; make sure it holds the configured value
# when this cell runs.
SEEDS = list(range(3))
train_classifiers = {
    seed: {
        embedding_type: {
            transformer_name: fit(
                train_latents[embedding_type],
                train_dataset[target_key],
                seed=seed,
                normalize=latent_normalize,
                hidden_dim=num_anchors
                if embedding_type == "relative"
                else transformer_name2hidden_dim[transformer_name],
            )
            for transformer_name, train_latents in tqdm(
                transformer2train_latents.items(), leave=False, desc="transformer"
            )
        }
        for embedding_type in tqdm(["absolute", "relative"], leave=False, desc="embedding_type")
    }
    for seed in tqdm(SEEDS, leave=False, desc="seed")
}

In [None]:
# Display the nested dict of fitted classifiers.
train_classifiers

In [None]:
from sklearn.metrics import precision_recall_fscore_support, mean_absolute_error
import itertools

import pandas as pd

# Models left out of the aggregated view below (they are still written to the TSV).
EXCLUDED_MODELS = ("regnetx_002", "rexnet_100")

numeric_results = {
    "seed": [],
    "embed_type": [],
    "train_model": [],
    "test_model": [],
    "precision": [],
    "recall": [],
    "fscore": [],
    "stitched": [],
}

# The labels do not depend on any loop variable, so they are materialised once.
test_y = np.array(test_dataset[target_key])

# Every classifier is evaluated on the test latents of every encoder: latents come from
# `transformer_name1`, the classifier from `transformer_name2`.
for seed, embed_type2transformer2classifier in train_classifiers.items():
    for embed_type, transformer2classifier in embed_type2transformer2classifier.items():
        for (transformer_name1, classifier1), (transformer_name2, classifier2) in itertools.product(
            transformer2classifier.items(), repeat=2
        ):
            if embed_type == "absolute" and (
                transformer_name2hidden_dim[transformer_name1] != transformer_name2hidden_dim[transformer_name2]
            ):
                # Absolute latents of different sizes cannot be fed to the other model's classifier.
                precision = recall = fscore = np.nan
            else:
                test_latents = transformer2test_latents[transformer_name1][embed_type]
                if latent_normalize:
                    test_latents = F.normalize(test_latents, p=2, dim=-1)
                preds = classifier2(test_latents)

                precision, recall, fscore, _ = precision_recall_fscore_support(test_y, preds, average="weighted")
            numeric_results["embed_type"].append(embed_type)
            numeric_results["train_model"].append(transformer_name1)
            numeric_results["test_model"].append(transformer_name2)
            numeric_results["precision"].append(precision)
            numeric_results["recall"].append(recall)
            numeric_results["fscore"].append(fscore)
            numeric_results["stitched"].append(transformer_name1 != transformer_name2)
            numeric_results["seed"].append(seed)


pd.options.display.max_columns = None
pd.options.display.max_rows = None
df = pd.DataFrame(numeric_results)
df.to_csv(
    f"vision_transformer-stitching-{dataset_name}-{'fine' if fine_grained else 'coarse'}-{train_perc}.tsv",
    sep="\t",
)

# One mask instead of four chained boolean lookups: the chained form indexes an already
# filtered frame with masks built on the unfiltered one, which pandas has to realign (and warns about).
excluded_rows = df.train_model.isin(EXCLUDED_MODELS) | df.test_model.isin(EXCLUDED_MODELS)
df = df[~excluded_rows]
df = df.groupby(
    [
        "embed_type",
        "stitched",
        "train_model",
        "test_model",
    ]
).agg(["mean"])  # string alias: passing np.mean to agg is deprecated in recent pandas
df

In [None]:
# Reload the raw per-seed results from the TSV; no model is filtered out here.
full_df = pd.read_csv(
    f"vision_transformer-stitching-{dataset_name}-{'fine' if fine_grained else 'coarse'}-{train_perc}.tsv",
    sep="\t",
    index_col=0,
)

# Mean over the rows of each group, plus the number of rows that went into it.
df = full_df.groupby(
    [
        "embed_type",
        "stitched",
        "train_model",
        "test_model",
    ]
).agg(["mean", "count"])  # string alias: passing np.mean to agg is deprecated in recent pandas
df

In [None]:
# Compact view: mean f-score per (embedding type, train model, test model), rounded to 3 decimals.
# The string alias replaces np.mean, whose use inside agg is deprecated in recent pandas.
full_df.drop(columns=["stitched", "seed", "precision", "recall"]).groupby(
    ["embed_type", "train_model", "test_model"]
).agg(["mean"]).round(3)