In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import random
from pathlib import Path
from rae import PROJECT_ROOT
import numpy as np
import torch.nn.functional as F
from pytorch_lightning import seed_everything
from torch import nn

In [None]:
import torch
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer

In [None]:
from rae.modules.attention import RelativeAttention, AttentionOutput

In [None]:
# Compute device: fall back to CPU so the notebook still runs on machines without a GPU.
device: str = "cuda" if torch.cuda.is_available() else "cpu"
# Dataset columns holding the class label and the input image.
target_key: str = "label"
data_key: str = "image"
dataset_name: str = "imagenet-1k"
# Number of training samples used as anchors for the relative representation.
num_anchors: int = 768
# Fraction of each split that is kept (used for both the train and the validation split below).
train_perc: float = 0.2

In [None]:
from datasets import load_dataset, ClassLabel


def get_dataset(split: str, perc: float, seed: int = 42):
    """Load `split` of `dataset_name` and keep a reproducible random `perc` fraction of it.

    Args:
        split: name of the dataset split to load (e.g. "train", "validation").
        perc: fraction of samples to keep, in (0, 1].
        seed: seed passed to `seed_everything` before the indices are shuffled.

    Returns:
        The subsampled dataset, with its samples in shuffled order.
    """
    assert 0 < perc <= 1, f"perc must be in (0, 1], got {perc}"
    seed_everything(seed)
    # Only the requested split is needed; `split=` returns it directly.
    dataset = load_dataset(dataset_name, split=split)

    # Shuffle all indices with the (just seeded) global `random` RNG and keep the first `perc` fraction.
    indices = list(range(len(dataset)))
    random.shuffle(indices)
    indices = indices[: int(len(indices) * perc)]
    return dataset.select(indices)

In [None]:
# Reproducible `train_perc` subset of the training split.
train_dataset = get_dataset(perc=train_perc, split="train")

In [None]:
# Presumably a `datasets.ClassLabel` feature (it exposes `str2int` and `num_classes`).
label_feature = train_dataset.features[target_key]
class2idx = label_feature.str2int
label_feature.num_classes, len(train_dataset)

In [None]:
# Shows the label-name -> id converter; `class2idx` is the bound `str2int` method, not a mapping.
class2idx

In [None]:
import timm
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform

In [None]:
from transformers import AutoFeatureExtractor, AutoModelForImageClassification, AutoModel


def load_transformer(transformer_name):
    """Create a pretrained timm model with its classification head removed, frozen and in eval mode.

    Args:
        transformer_name: timm model identifier (e.g. "vit_base_patch16_224").

    Returns:
        The model, with `requires_grad` disabled on every parameter and `eval()` applied.
    """
    # num_classes=0 asks timm for the model without a classifier head.
    model = timm.create_model(transformer_name, pretrained=True, num_classes=0)
    model.requires_grad_(False)
    return model.eval()

In [None]:
# The validation split is used as the test set and subsampled with the same fraction as train.
test_dataset = get_dataset(perc=train_perc, split="validation")
len(test_dataset)

In [None]:
@torch.no_grad()
def call_transformer(batch, transformer):
    """Run `transformer` on the pre-processed images of `batch` without tracking gradients.

    Args:
        batch: mapping with an "encoding" tensor, as produced by `collate_fn`.
        transformer: model to call; its raw output is returned unchanged.

    Returns:
        dict with key "hidden" holding the model output.
    """
    pixels = batch["encoding"].to(device)
    return {"hidden": transformer(pixels)}


#     hidden = F.avg_pool2d(hidden, hidden.size(-1))

#     return {"hidden": hidden[:, 0, :].flatten(1).squeeze(), "logits": sample_encodings["logits"]}

# TODO: aggregation mode
# result = []
# for sample_encoding, sample_mask in zip(sample_encodings, batch["mask"]):
#     result.append(sample_encoding[sample_mask].mean(dim=0))

# return torch.stack(result, dim=0)
#     return sample_encodings[:, 0, :]  # CLS

In [None]:
# t = load_transformer(transformer_name="vit_base_patch16_224")

In [None]:
# config = resolve_data_config({}, model=t)
# transform = create_transform(**config)
# call_transformer(collate_fn(train_dataset.select(range(2)), None, transform), t.to(device))

In [None]:
from typing import *


assert num_anchors <= len(train_dataset)

# Reproducibly pick `num_anchors` training samples to act as anchors.
seed_everything(42)
anchor_idxs = list(range(len(train_dataset)))
random.shuffle(anchor_idxs)
del anchor_idxs[num_anchors:]

anchor_dataset = train_dataset.select(anchor_idxs)
len(anchor_dataset)

In [None]:
# timm encoders to compare. `dict.fromkeys` deduplicates while keeping the listed order; a set
# would make the order depend on string hash randomization and change between kernel restarts.
transformer_names = list(
    dict.fromkeys(
        [
            "vit_base_patch16_224",
            "vit_small_patch16_224",
            "vit_base_resnet50_384",
            "rexnet_100",
            # "regnetx_002",
            # "google/vit-base-patch16-224",
            # "nvidia/mit-b0",
            # "nvidia/mit-b2",
            # "nvidia/mit-b3",
            # "facebook/vit-mae-base",
        ]
    )
)

In [None]:
# relative_projection = RelativeAttention(
#     n_anchors=num_anchors,
#     normalization_mode="l2",
#     similarity_mode="inner",
#     values_mode="similarities",
#     n_classes=train_dataset.features[target_key].num_classes,
#     output_normalization_mode=None,
# ).to(device)

In [None]:
def relative_projection(x, anchors):
    """Cosine similarity between every vector of `x` and every anchor.

    Both inputs are L2-normalized along their last dimension, then each vector of `x` is
    dotted with each anchor.

    Args:
        x: tensor of shape (..., dim); any number of leading batch dimensions is allowed.
        anchors: tensor of shape (num_anchors, dim).

    Returns:
        Tensor of shape (..., num_anchors) holding the cosine similarities.
    """
    x = F.normalize(x, p=2, dim=-1)
    anchors = F.normalize(anchors, p=2, dim=-1)
    # The ellipsis keeps the old (batch, dim) behaviour and also accepts extra leading dims.
    return torch.einsum("...m, am -> ...a", x, anchors)

In [None]:
# dummy_x = torch.randn(32, 512, 16, 16)
# dummy_anchors = torch.randn(42, 512)
# relative_projection(x=dummy_x, anchors=dummy_anchors).shape

In [None]:
def collate_fn(batch, feature_extractor, transform):
    """Transform each sample's image and stack the results into one tensor.

    Args:
        batch: list of dataset samples; the image is read from the `data_key` column
            (presumably a PIL image, since `.convert("RGB")` is called on it).
        feature_extractor: unused; kept so existing `partial(collate_fn, ...)` call sites work.
        transform: callable mapping an RGB image to a tensor (e.g. a timm transform).

    Returns:
        dict with key "encoding": the transformed images stacked along a new leading dimension.
    """
    # Use the configured image column instead of a hard-coded "image" key.
    images = [transform(sample[data_key].convert("RGB")) for sample in batch]
    return {"encoding": torch.stack(images, dim=0)}

In [None]:
def get_latents(dataloader, anchors, split: str, transformer) -> Dict[str, torch.Tensor]:
    """Encode every batch of `dataloader` with `transformer`, optionally projecting onto `anchors`.

    Args:
        dataloader: iterable of batches accepted by `call_transformer`.
        anchors: anchor latents (already on `device`) for `relative_projection`, or None to skip it.
        split: label shown in the progress bar.
        transformer: model used for encoding. It is moved to `device` for the loop and moved
            back to CPU afterwards, even if encoding raises or is interrupted.

    Returns:
        dict with:
            "absolute": per-batch transformer outputs concatenated along dim 0 (on CPU).
            "relative": per-batch relative projections concatenated along dim 0 (on CPU), or an
                empty list when `anchors` is None.
    """
    absolute_latents: List[torch.Tensor] = []
    relative_latents: List[torch.Tensor] = []

    transformer = transformer.to(device)
    try:
        for batch in tqdm(dataloader, desc=f"[{split}] Computing latents"):
            with torch.no_grad():
                hidden = call_transformer(batch=batch, transformer=transformer)["hidden"]
                absolute_latents.append(hidden.cpu())
                if anchors is not None:
                    relative_latents.append(relative_projection(x=hidden, anchors=anchors).cpu())
    finally:
        # Free GPU memory even when the loop is interrupted (e.g. KeyboardInterrupt in a notebook).
        transformer.cpu()

    return {
        # Chunks were moved to CPU one by one, so the concatenations are already on CPU.
        "absolute": torch.cat(absolute_latents, dim=0),
        "relative": torch.cat(relative_latents, dim=0) if relative_latents else relative_latents,
    }

In [None]:
from rae import PROJECT_ROOT

# Cache directory for latents: <PROJECT_ROOT>/data/latents/imagenet/<train_perc>
LATENTS_DIR: Path = Path(PROJECT_ROOT, "data", "latents", "imagenet", str(train_perc))
LATENTS_DIR.mkdir(parents=True, exist_ok=True)

In [None]:
def load_latents(split: str, transformer_names: Sequence[str]):
    """Load cached latents for each transformer, silently skipping those with no cache file.

    Args:
        split: sub-directory of `LATENTS_DIR` to read from (e.g. "train", "test").
        transformer_names: names to look up; "/" is replaced by "-" to build the file name.

    Returns:
        dict mapping transformer name -> `torch.load` of its ".pt" file. Names without a
        cached file are absent.
    """
    # NOTE: torch.load unpickles; only point LATENTS_DIR at files produced by this notebook.
    split_dir = LATENTS_DIR / split
    candidate_paths = {name: split_dir / f"{name.replace('/', '-')}.pt" for name in transformer_names}
    return {name: torch.load(path) for name, path in candidate_paths.items() if path.exists()}

In [None]:
from functools import partial

from torchvision.transforms import (
    CenterCrop,
    Compose,
    Normalize,
    Resize,
    ToTensor,
)


def encode_latents(transformer_names: Sequence[str], dataset, transformer_name2latents, split: str):
    """Encode the anchors and `dataset` with each transformer, filling `transformer_name2latents`.

    For every name, `transformer_name2latents[name]` becomes a dict with "anchors_latents"
    (absolute latents of the global `anchor_dataset`) followed by the entries returned by
    `get_latents` on `dataset`, whose relative latents are computed against those anchors.
    When the global `CACHE_LATENTS` is true, each entry is also saved with `torch.save` to
    `LATENTS_DIR / split / "<name with '/' replaced by '-'>.pt"`.
    """
    for transformer_name in transformer_names:
        transformer = load_transformer(transformer_name=transformer_name)
        transform = create_transform(**resolve_data_config({}, model=transformer))

        def make_loader(data):
            # Same loader settings for the anchors and for the dataset being encoded.
            return DataLoader(
                data,
                num_workers=4,
                pin_memory=True,
                collate_fn=partial(collate_fn, feature_extractor=None, transform=transform),
                batch_size=32,
            )

        anchors_latents = get_latents(
            dataloader=make_loader(anchor_dataset),
            split=f"{transformer_name}, anchor, {split}",
            anchors=None,
            transformer=transformer,
        )["absolute"]
        dataset_latents = get_latents(
            dataloader=make_loader(dataset),
            split=f"{split}/{transformer_name}",
            anchors=anchors_latents.to(device),
            transformer=transformer,
        )
        transformer_name2latents[transformer_name] = {"anchors_latents": anchors_latents, **dataset_latents}

        if CACHE_LATENTS:
            cache_path = LATENTS_DIR / split / f"{transformer_name.replace('/', '-')}.pt"
            cache_path.parent.mkdir(exist_ok=True, parents=True)
            torch.save(transformer_name2latents[transformer_name], cache_path)

In [None]:
# Compute test latents: load what is cached, then encode only the missing transformers
# (or every transformer when FORCE_RECOMPUTE is set).

FORCE_RECOMPUTE: bool = False
CACHE_LATENTS: bool = True

transformer2test_latents: Dict[str, Mapping[str, torch.Tensor]] = load_latents(
    split="test", transformer_names=transformer_names
)
missing_transformers = [
    name for name in transformer_names if FORCE_RECOMPUTE or name not in transformer2test_latents
]
encode_latents(
    transformer_names=missing_transformers,
    dataset=test_dataset,
    transformer_name2latents=transformer2test_latents,
    split="test",
)

In [None]:
# Compute train latents: load what is cached, then encode only the missing transformers
# (or every transformer when FORCE_RECOMPUTE is set).

FORCE_RECOMPUTE: bool = False
CACHE_LATENTS: bool = True

transformer2train_latents: Dict[str, Mapping[str, torch.Tensor]] = load_latents(
    split="train", transformer_names=transformer_names
)
missing_transformers = [
    name for name in transformer_names if FORCE_RECOMPUTE or name not in transformer2train_latents
]
encode_latents(
    transformer_names=missing_transformers,
    dataset=train_dataset,
    transformer_name2latents=transformer2train_latents,
    split="train",
)

In [None]:
# Feature size of each encoder: length of its first absolute training latent.
transformer_name2hidden_dim = dict(
    (name, cached["absolute"][0].size(0)) for name, cached in transformer2train_latents.items()
)
transformer_name2hidden_dim

In [None]:
# Whether to L2-normalize latents before training the classifiers and before evaluating them.
latent_normalize: bool = True

In [None]:
from torch.nn import CrossEntropyLoss
from torch.optim import Adam


# def fit(X, y, seed, **kwargs):
#     classifier = make_pipeline(
#         Normalizer(), StandardScaler(), SVC(gamma="auto", kernel="linear", max_iter=200, random_state=seed)
#     )  # , class_weight="balanced"))
#     classifier.fit(X, y)
#     return lambda x: classifier.predict(x)


class Lambda(nn.Module):
    """Wrap an arbitrary callable as an `nn.Module` so it can be placed inside `nn.Sequential`."""

    def __init__(self, func):
        super(Lambda, self).__init__()
        # Plain attribute: a non-Module callable registers no parameters or buffers.
        self.func = func

    def forward(self, x):
        wrapped = self.func
        return wrapped(x)


def fit(
    X: torch.Tensor,
    y,
    seed,
    normalize: bool,
    hidden_dim: int,
    num_epochs: int = 1,
    lr: float = 1e-4,
    batch_size: int = 32,
):
    """Train a small MLP classifier mapping `X` to `y` and return a predictor function.

    Architecture: LayerNorm -> Linear(hidden_dim, num_anchors) -> SiLU -> InstanceNorm1d over
    the batch axis (each feature normalized across the batch) -> Linear -> SiLU ->
    InstanceNorm1d over the batch axis -> Linear to the number of classes of `train_dataset`.
    Training uses Adam and cross-entropy on `device`.

    Args:
        X: training features; the last dimension must equal `hidden_dim`.
        y: integer class labels, one per row of `X` (anything `torch.as_tensor` accepts).
        seed: passed to `seed_everything` so initialization and shuffling are reproducible.
        normalize: if True, L2-normalize each row of `X` before training.
        hidden_dim: input feature size of the first layer.
        num_epochs: number of passes over the data.
        lr: Adam learning rate.
        batch_size: mini-batch size.

    Returns:
        A function mapping a CPU feature tensor to predicted class indices (CPU tensor). It does
        not apply `normalize` itself. Because the InstanceNorm layers use the statistics of
        whatever is passed in, the input must contain more than one row.
    """
    seed_everything(seed)
    if normalize:
        X = F.normalize(X, p=2, dim=-1)
    dataset = TensorDataset(X, torch.as_tensor(y))
    loader = DataLoader(dataset, batch_size=batch_size, pin_memory=True, shuffle=True, num_workers=4)
    num_classes = train_dataset.features[target_key].num_classes

    model = nn.Sequential(
        nn.LayerNorm(normalized_shape=hidden_dim),
        nn.Linear(in_features=hidden_dim, out_features=num_anchors),
        nn.SiLU(),
        # (batch, features) -> (features, batch): InstanceNorm1d then treats each feature as a
        # channel and normalizes it across the batch.
        Lambda(lambda x: x.permute(1, 0)),
        nn.InstanceNorm1d(num_features=num_anchors),
        Lambda(lambda x: x.permute(1, 0)),
        nn.Linear(in_features=num_anchors, out_features=num_anchors),
        nn.SiLU(),
        Lambda(lambda x: x.permute(1, 0)),
        nn.InstanceNorm1d(num_features=num_anchors),
        Lambda(lambda x: x.permute(1, 0)),
        nn.Linear(in_features=num_anchors, out_features=num_classes),
    ).to(device)
    opt = Adam(model.parameters(), lr=lr)
    loss_fn = CrossEntropyLoss()
    for _ in tqdm(range(num_epochs), leave=False, desc="epoch"):
        for batch_x, batch_y in loader:
            # InstanceNorm1d over the batch axis raises on a single-sample batch in training mode
            # ("Expected more than 1 spatial element"). Only the last batch can be that small,
            # so skip it instead of crashing.
            if batch_x.size(0) < 2:
                continue
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)
            loss = loss_fn(model(batch_x), batch_y)
            loss.backward()
            opt.step()
            opt.zero_grad()
    model = model.cpu().eval()

    @torch.no_grad()
    def predict(x):
        # no_grad: the predictor is applied to the whole test set at once; without it an
        # autograd graph over all those activations would be kept alive.
        return model(x).argmax(-1).cpu()

    return predict

In [None]:
SEEDS = list(range(3))
# The labels are the same for every (seed, embedding type, transformer): read them once.
train_targets = train_dataset[target_key]

# {seed: {embedding_type: {transformer_name: predictor}}}
train_classifiers = {
    seed: {
        embedding_type: {
            transformer_name: fit(
                train_latents[embedding_type],
                train_targets,
                seed=seed,
                normalize=latent_normalize,
                # Relative latents have one coordinate per anchor; absolute ones use the encoder size.
                hidden_dim=(
                    transformer_name2hidden_dim[transformer_name] if embedding_type == "absolute" else num_anchors
                ),
            )
            for transformer_name, train_latents in tqdm(
                transformer2train_latents.items(), leave=False, desc="transformer"
            )
        }
        for embedding_type in tqdm(["absolute", "relative"], leave=False, desc="embedding_type")
    }
    for seed in tqdm(SEEDS, leave=False, desc="seed")
}

In [None]:
# Nested mapping {seed: {embedding_type: {transformer_name: predictor}}} built in the previous cell.
train_classifiers

In [None]:
from sklearn.metrics import precision_recall_fscore_support, mean_absolute_error
import itertools

# Stitching evaluation. For every ordered pair (train_model, test_model), the test set encoded by
# `test_model` is classified by the classifier trained on `train_model`'s latents. Absolute
# latents can only be stitched between encoders with the same hidden size; other pairs get NaN.
test_y = np.array(test_dataset[target_key])  # identical for every pair: read the labels once

numeric_results = {
    "seed": [],
    "embed_type": [],
    "train_model": [],
    "test_model": [],
    "precision": [],
    "recall": [],
    "fscore": [],
    "stitched": [],
}
for seed, embed_type2transformer2classifier in train_classifiers.items():
    for embed_type, transformer2classifier in embed_type2transformer2classifier.items():
        for train_name, test_name in itertools.product(transformer2classifier, repeat=2):
            if embed_type == "absolute" and (
                transformer_name2hidden_dim[train_name] != transformer_name2hidden_dim[test_name]
            ):
                precision = recall = fscore = np.nan
            else:
                test_latents = transformer2test_latents[test_name][embed_type]
                if latent_normalize:
                    test_latents = F.normalize(test_latents, p=2, dim=-1)
                # Inference only: avoid building an autograd graph over the whole test set.
                with torch.no_grad():
                    preds = transformer2classifier[train_name](test_latents)
                precision, recall, fscore, _ = precision_recall_fscore_support(test_y, preds, average="weighted")
            numeric_results["seed"].append(seed)
            numeric_results["embed_type"].append(embed_type)
            numeric_results["train_model"].append(train_name)
            numeric_results["test_model"].append(test_name)
            numeric_results["precision"].append(precision)
            numeric_results["recall"].append(recall)
            numeric_results["fscore"].append(fscore)
            numeric_results["stitched"].append(train_name != test_name)

import pandas as pd

pd.options.display.max_columns = None
pd.options.display.max_rows = None

# One row per (seed, embedding type, train model, test model); persisted for later reloading.
results_df = pd.DataFrame(numeric_results)
results_df.to_csv(
    f"vision_transformer-stitching-{dataset_name}-{train_perc}.tsv",
    sep="\t",
)

# Mean of the remaining numeric columns over seeds. The "mean" alias avoids pandas'
# FutureWarning about passing np.mean to groupby.agg.
df = results_df.groupby(
    [
        "embed_type",
        "stitched",
        "train_model",
        "test_model",
    ]
).agg(["mean"])
df

In [None]:
# Reload the persisted per-seed results (column 0 is the index written by `to_csv`).
full_df = pd.read_csv(
    f"vision_transformer-stitching-{dataset_name}-{train_perc}.tsv",
    sep="\t",
    index_col=0,
)

# Mean and number of rows per group; string aliases avoid pandas' FutureWarning about np.mean.
df = full_df.groupby(
    [
        "embed_type",
        "stitched",
        "train_model",
        "test_model",
    ]
).agg(["mean", "count"])
df

In [None]:
# Name of the results file (no trailing comma, so the cell shows the string rather than a 1-tuple).
f"vision_transformer-stitching-{dataset_name}-{train_perc}.tsv"

In [None]:
# Mean F-score per (embedding type, train model, test model), averaged over seeds.
full_df.drop(columns=["stitched", "seed", "precision", "recall"]).groupby(
    ["embed_type", "train_model", "test_model"]
).agg(["mean"]).round(3)

In [None]:
# it_dataset = get_samples(lang="it", sample_idxs=list(range(1000)))
# it_transformer_name: str = "dbmdz/bert-base-italian-cased"
# transformer, tokenizer = load_transformer(transformer_name=it_transformer_name)
# it_anchor_latents = get_latents(
#     dataloader=DataLoader(
#         get_samples("it", sample_idxs=anchor_idxs),
#         num_workers=16,
#         pin_memory=True,
#         collate_fn=partial(collate_fn, tokenizer=tokenizer),
#         batch_size=32,
#     ),
#     split=f"{it_transformer_name}",
#     anchors=None,
#     transformer=transformer,
# )
# it_latents = get_latents(
#     dataloader=DataLoader(
#         it_dataset,
#         num_workers=16,
#         pin_memory=True,
#         collate_fn=partial(collate_fn, tokenizer=tokenizer),
#         batch_size=32,
#     ),
#     split=f"{it_transformer_name}",
#     anchors=it_anchor_latents["absolute"].to(device),
#     transformer=transformer,
# )
# subsample_anchors = it_latents["relative"][:31, :]
# for i_sample, sample in enumerate(it_samples):
#     if sample["target"] == 3:
#         continue
#     for embed_type in ("relative", "absolute"):
#         latents = it_latents[embed_type]
#         latents = torch.cat([latents[i_sample, :].unsqueeze(0), subsample_anchors], dim=0)
#         classifier = train_classifiers[SEEDS[0]][embed_type]["en"]
#         print(
#             embed_type,
#             classifier(latents)[0].item(),
#             sample["class"],
#         )
#     print()
#     if i_sample > 100:
#         break