In [None]:
from datasets import load_from_disk

In [None]:
import numpy as np

from pathlib import Path

In [None]:
from rae import PROJECT_ROOT
from rae.modules.attention import RelativeAttention, AttentionOutput

In [None]:
# Root folder of the HuggingFace datasets stored on disk for this project (read via `load_from_disk` below).
datasets_dir: Path = PROJECT_ROOT / "data" / "hf_datasets"

In [None]:
# Prefer the GPU when one is visible, but fall back to CPU so the notebook also runs on CPU-only machines
# (a hard-coded "cuda" makes every `.to(device)` call fail there).
import torch

device: str = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
# Pre-encoded N24News dataset (the column names suggest roberta-base text and ViT-B/16 image embeddings).
dataset = load_from_disk(
    str(datasets_dir / "N24News" / "encoded")
)
# Only the two embedding columns are converted to torch tensors; `output_all_columns=True` keeps every other
# column (e.g. the label) in the output, unconverted.
dataset.set_format(
    type="torch",
    columns=["body_roberta-base", "image_vit_base_patch16_224"],
    output_all_columns=True,
)
dataset

In [None]:
# Display the training split.
dataset["train"]

In [None]:
target_key: str = "label"  # column holding the integer class label

# Split handles used throughout the rest of the notebook.
train_dataset, test_dataset = dataset["train"], dataset["test"]

In [None]:
import random
from sklearn.model_selection import train_test_split
from rae.data.text.datamodule import AnchorsMode
from sklearn.utils import shuffle

from typing import *


def _anchor_payload(dataset, anchor_idxs) -> Dict[str, Any]:
    """Package the anchors at ``anchor_idxs`` in the dict format returned by ``get_anchors``.

    Args:
        dataset: split the indices refer to; ``dataset.features[target_key]`` provides ``int2str``.
        anchor_idxs: positions of the anchors in ``dataset`` (Python or NumPy integers).

    Returns:
        dict with the anchor indices (as given), the anchor samples, one target and one class name per
        anchor, and ``anchor_latents`` set to None.
    """
    anchors = [dataset[int(idx)] for idx in anchor_idxs]
    label_feature = dataset.features[target_key]
    return {
        "anchor_idxs": anchor_idxs,
        "anchor_samples": anchors,
        "anchor_targets": [anchor[target_key] for anchor in anchors],
        "anchor_classes": [label_feature.int2str(anchor[target_key]) for anchor in anchors],
        "anchor_latents": None,
    }


def get_anchors(dataset, anchors_mode, anchors_num) -> Dict[str, Any]:
    """Select anchor samples from a dataset split according to ``anchors_mode``.

    Args:
        dataset: split to draw anchors from. Its ``features[target_key]`` is used for ``num_classes`` /
            ``int2str``, so it presumably is a HuggingFace ``ClassLabel`` feature — TODO confirm.
        anchors_mode: ``AnchorsMode`` member selecting the strategy:
            - DATASET: every sample is an anchor (``anchors_num`` is ignored).
            - STRATIFIED_SUBSET: seeded shuffle, then round-robin over the sorted classes until
              ``anchors_num`` anchors are collected.
            - STRATIFIED: seeded, label-stratified ``train_test_split`` of size ``anchors_num``;
              requires ``anchors_num >= num_classes``.
            - RANDOM_SAMPLES: the first ``anchors_num`` indices of a seeded shuffle.
        anchors_num: number of anchors to select.

    Returns:
        dict with ``anchor_idxs``, ``anchor_samples``, ``anchor_targets``, ``anchor_classes`` (one class
        name per anchor) and ``anchor_latents`` (always None here).

    Raises:
        ValueError: if ``anchors_num`` cannot be satisfied by the chosen mode.
        NotImplementedError: for ``AnchorsMode.RANDOM_LATENTS``.
        RuntimeError: for any other, unknown mode.
    """
    dataset_to_consider = dataset
    num_samples = len(dataset_to_consider)

    if anchors_mode == AnchorsMode.DATASET:
        targets = dataset_to_consider[target_key]
        label_feature = dataset_to_consider.features[target_key]
        return {
            "anchor_idxs": list(range(num_samples)),
            "anchor_samples": list(dataset_to_consider),
            "anchor_targets": targets,
            # HF datasets expose no `.classes`; derive the per-anchor class names from the label feature.
            "anchor_classes": [label_feature.int2str(int(target)) for target in targets],
            "anchor_latents": None,
        }

    if anchors_mode == AnchorsMode.STRATIFIED_SUBSET:
        if anchors_num > num_samples:
            # Once every class is exhausted the round-robin below can never reach `anchors_num`
            # and would spin forever.
            raise ValueError(f"Requested {anchors_num} anchors but the dataset has only {num_samples} samples")
        shuffled_idxs, shuffled_targets = shuffle(
            np.asarray(list(range(num_samples))),
            np.asarray(dataset_to_consider[target_key]),
            random_state=0,
        )
        all_targets = sorted(set(shuffled_targets))
        class2idxs = {target: shuffled_idxs[shuffled_targets == target] for target in all_targets}

        anchor_indices = []
        i = 0
        while len(anchor_indices) < anchors_num:
            # Take the i-th (shuffled) sample of every class that still has one.
            for target_idxs in class2idxs.values():
                if i < len(target_idxs):
                    anchor_indices.append(target_idxs[i])
                if len(anchor_indices) == anchors_num:
                    break
            i += 1
        return _anchor_payload(dataset_to_consider, anchor_indices)

    if anchors_mode == AnchorsMode.STRATIFIED:
        num_classes = dataset_to_consider.features[target_key].num_classes
        if anchors_num < num_classes:
            # A stratified split needs room for at least one anchor per class.
            raise ValueError(
                f"STRATIFIED anchors need anchors_num >= num_classes ({anchors_num} < {num_classes})"
            )
        _, anchor_indices = train_test_split(
            list(range(num_samples)),
            test_size=anchors_num,
            stratify=dataset_to_consider[target_key],
            random_state=0,
        )
        return _anchor_payload(dataset_to_consider, anchor_indices)

    if anchors_mode == AnchorsMode.RANDOM_SAMPLES:
        if anchors_num > num_samples:
            raise ValueError(f"Requested {anchors_num} anchors but the dataset has only {num_samples} samples")
        anchor_idxs = list(range(num_samples))
        # Local seeded RNG: reproducible like the other modes, without touching the global `random` state.
        random.Random(0).shuffle(anchor_idxs)
        return _anchor_payload(dataset_to_consider, anchor_idxs[:anchors_num])

    if anchors_mode == AnchorsMode.RANDOM_LATENTS:
        raise NotImplementedError

    raise RuntimeError(f"Unknown anchors mode: {anchors_mode}")


anchors_num: int = 768

# Class-balanced anchor subset of the training split; indices are cast to plain Python ints for indexing.
anchor_idxs = list(
    map(
        int,
        get_anchors(train_dataset, anchors_mode=AnchorsMode.STRATIFIED_SUBSET, anchors_num=anchors_num)["anchor_idxs"],
    )
)
anchors = [train_dataset[idx] for idx in anchor_idxs]

In [None]:
# Relative projection module: latents are compared to the anchors by inner product and the raw similarities
# are used as the relative representation. NOTE(review): `None` normalization modes presumably mean
# "no normalization" inside RelativeAttention — confirm against rae.modules.attention.
relative_projection = RelativeAttention(
    n_anchors=anchors_num,
    n_classes=train_dataset.features[target_key].num_classes,
    similarity_mode="inner",
    values_mode="similarities",
    normalization_mode=None,
    output_normalization_mode=None,
).to(device)

In [None]:
from torch.nn import CrossEntropyLoss
from torch.optim import Adam

# def fit(X, y):
#     classifier = make_pipeline(
#         Normalizer(), StandardScaler(), SVC(gamma="auto", kernel="linear", random_state=42)
#     )  # , class_weight="balanced"))
#     classifier.fit(X, y)
#     return lambda x: classifier.predict(x)
import torch
from tqdm import tqdm
from torch import nn
from pytorch_lightning import seed_everything
from torch.utils.data import TensorDataset, DataLoader
from torch.nn import functional as F


class Aggregate(nn.Module):
    """Fuse several latent signals (e.g. one per encoder) stacked along the last axis into one latent.

    Each slice ``x[..., s]`` goes through its own ``num_dims -> num_dims`` linear map. The projected
    signals are then mixed by a learned ``num_signals -> 1`` linear combination over the signal axis.

    Args:
        num_signals: number of stacked signals (size of the last axis of the input).
        num_dims: feature size of every signal (the axis right before the signal axis).
    """

    def __init__(self, num_signals: int, num_dims: int):
        super().__init__()
        self.num_signals: int = num_signals
        self.num_dims = num_dims
        # Created first (as before) so that parameter initialization order is unchanged.
        self.projection = nn.Linear(in_features=num_signals, out_features=1)
        # One projection per signal. This was previously hard-coded to two, which raised IndexError for more
        # than two signals and left an unused layer when there was a single signal.
        self.projections = nn.ModuleList([nn.Linear(num_dims, num_dims) for _ in range(num_signals)])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``x`` of shape ``(..., num_dims, num_signals)`` to ``(..., num_dims)``.

        Raises:
            ValueError: if the last axis of ``x`` does not match ``num_signals``.
        """
        if x.size(-1) != self.num_signals:
            raise ValueError(f"Expected {self.num_signals} signals on the last axis, got {x.size(-1)}")
        projected = torch.stack(
            [projection(x[..., signal]) for signal, projection in enumerate(self.projections)],
            dim=-1,
        )
        return self.projection(projected).squeeze(-1)


def _build_classifier(num_signals: int, num_dims: int, num_classes: int) -> nn.Module:
    """Aggregate the stacked signals, then apply an MLP head of width ``num_dims`` that outputs raw logits."""
    return nn.Sequential(
        Aggregate(num_signals=num_signals, num_dims=num_dims),
        #
        nn.LayerNorm(normalized_shape=num_dims),
        nn.Linear(in_features=num_dims, out_features=num_dims),
        nn.SiLU(),
        #
        nn.BatchNorm1d(num_features=num_dims),
        nn.Linear(in_features=num_dims, out_features=num_dims),
        nn.SiLU(),
        #
        nn.BatchNorm1d(num_features=num_dims),
        # No activation after the last layer: CrossEntropyLoss expects unbounded logits. A ReLU here clamped
        # every negative logit to 0 and killed its gradient.
        nn.Linear(in_features=num_dims, out_features=num_classes),
    )


def fit(
    X: torch.Tensor,
    y,
    seed,
    num_classes: Optional[int] = None,
    epochs: int = 5,
    batch_size: int = 32,
    lr: float = 1e-3,
):
    """Train a classifier on stacked latents and return a prediction function.

    Args:
        X: latents stacked on the last axis, shape ``(num_samples, num_dims, num_signals)``. The head width
            follows ``X.size(1)``, so absolute and relative latents of any size are supported.
        y: integer class labels, one per row of ``X``.
        seed: passed to ``seed_everything`` before building the model and loader.
        num_classes: output size; defaults to ``train_dataset.features[target_key].num_classes``.
        epochs, batch_size, lr: training hyper-parameters (defaults are the previous hard-coded values).

    Returns:
        Callable mapping a CPU tensor shaped like ``X`` to a CPU tensor of predicted class indices.
    """
    seed_everything(seed)
    if num_classes is None:
        num_classes = train_dataset.features[target_key].num_classes

    dataset = TensorDataset(X, torch.as_tensor(y))
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        pin_memory=True,
        shuffle=True,
        num_workers=4,
        # BatchNorm1d raises on a training batch of size 1, so drop a trailing singleton batch.
        drop_last=len(dataset) % batch_size == 1,
    )

    model = _build_classifier(num_signals=X.size(-1), num_dims=X.size(1), num_classes=num_classes).to(device)
    opt = Adam(model.parameters(), lr=lr)
    loss_fn = CrossEntropyLoss()

    model.train()
    for _ in tqdm(range(epochs), leave=False, desc="epoch"):
        for batch_x, batch_y in loader:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)
            loss = loss_fn(model(batch_x), batch_y)
            opt.zero_grad()
            loss.backward()
            opt.step()
    model = model.eval().cpu()

    @torch.no_grad()
    def predict(x: torch.Tensor) -> torch.Tensor:
        # Inference only: with no autograd graph, predicting on a whole split does not retain activations.
        return model(x).argmax(-1).cpu()

    return predict

In [None]:
import pandas as pd
from collections import defaultdict
from sklearn.metrics import precision_recall_fscore_support
import itertools

NORMALIZE: bool = True
train_y = train_dataset[target_key]
test_y = test_dataset[target_key]

result = defaultdict(list)
for feature_names, embedding_type, seed in itertools.product(
    (("body_roberta-base", "image_vit_base_patch16_224"), ("body_roberta-base",), ("image_vit_base_patch16_224",)),
    ("absolute", "relative"),
    range(1),
):
    train_feature2latents = {feature_name: train_dataset[feature_name] for feature_name in feature_names}
    test_feature2latents = {feature_name: test_dataset[feature_name] for feature_name in feature_names}

    if NORMALIZE:
        # Unit L2 norm per row, so the inner-product similarities below are cosine similarities.
        train_feature2latents = {
            feature_name: F.normalize(latents, dim=-1, p=2) for feature_name, latents in train_feature2latents.items()
        }
        test_feature2latents = {
            feature_name: F.normalize(latents, dim=-1, p=2) for feature_name, latents in test_feature2latents.items()
        }

    if embedding_type == "relative":
        # Anchors always come from the TRAIN latents. Slice them once up front so neither projection
        # depends on the order in which the train/test dicts are rebuilt below.
        anchor_feature2latents = {
            feature_name: latents[anchor_idxs] for feature_name, latents in train_feature2latents.items()
        }
        with torch.no_grad():
            train_feature2latents = {
                feature_name: relative_projection(x=latents, anchors=anchor_feature2latents[feature_name])[
                    AttentionOutput.SIMILARITIES
                ]
                for feature_name, latents in train_feature2latents.items()
            }
            test_feature2latents = {
                feature_name: relative_projection(x=latents, anchors=anchor_feature2latents[feature_name])[
                    AttentionOutput.SIMILARITIES
                ]
                for feature_name, latents in test_feature2latents.items()
            }

    # Stack the per-encoder latents along a trailing "signal" axis; assumes each latent tensor is
    # (num_samples, dim) with the same dim for every encoder — TODO confirm for new encoders.
    train_latents = torch.stack(list(train_feature2latents.values()), dim=-1)
    model = fit(X=train_latents, y=train_y, seed=seed)

    test_latents = torch.stack(list(test_feature2latents.values()), dim=-1)
    preds = model(test_latents)

    precision, recall, fscore, _ = precision_recall_fscore_support(test_y, preds, average="weighted")
    result["embed_type"].append(embedding_type)
    result["encodings"].append(feature_names)
    result["precision"].append(precision)
    result["recall"].append(recall)
    result["fscore"].append(fscore)
    result["seed"].append(seed)

    # Progress report: the results table so far.
    print(pd.DataFrame(result))

In [None]:
df2 = pd.DataFrame(result)
# Display the table built in this cell. The previous `df1` was never defined in the notebook, so the cell
# raised NameError on Restart & Run All.
df2