In [None]:
import random
from pathlib import Path
from typing import List, Tuple

In [None]:
import plotly.express as px
import sklearn.pipeline
import torch
from nn_core.serialization import load_model, NNCheckpointIO
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoModel, PreTrainedModel, PreTrainedTokenizer, AutoTokenizer

In [None]:
from rae.data.text import TREC
from rae.modules.attention import RelativeAttention, AttentionOutput
from rae.pl_modules.pl_text_classifier import LightningTextClassifier

In [None]:
def get_model_cfg(ckpt_path: Path):
    """Return the configuration stored under the ``"cfg"`` key of a checkpoint.

    Args:
        ckpt_path: Path handed to ``NNCheckpointIO.load``.

    Returns:
        The value found under ``"cfg"`` in the loaded checkpoint.
    """
    checkpoint = NNCheckpointIO.load(path=ckpt_path)
    return checkpoint["cfg"]

In [None]:
def plot_latent_space(metadata, validation_stats_df, x_data: str, y_data: str):
    """Scatter-plot two columns of the validation stats, coloured by class.

    Args:
        metadata: Object exposing a ``class_to_idx`` mapping; its keys provide the
            class names, their ordering and their colours.
        validation_stats_df: Data handed to ``px.scatter``. The call references the
            columns ``class_name``, ``image_index``, ``anchor_index`` and ``is_anchor``.
        x_data: Name of the column plotted on the x axis.
        y_data: Name of the column plotted on the y axis.

    Returns:
        The figure produced by ``px.scatter``, faceted by ``is_anchor``.
    """
    class_to_idx = metadata.class_to_idx
    # One colour of the qualitative Plotly palette per class (zip stops at the shorter sequence)
    palette = px.colors.qualitative.Plotly[: len(class_to_idx)]
    class_colors = dict(zip(class_to_idx, palette))

    # TODO: fixme, plotly crashes with any column name to set the anchor size (e.g. size="std_0")
    # Disabled options kept for reference: symbol="is_anchor" with
    # symbol_map={False: "circle", True: "star"}, range_x=[-5, 5], range_y=[-5, 5]
    return px.scatter(
        validation_stats_df,
        x=x_data,
        y=y_data,
        color="class_name",
        color_discrete_map=class_colors,
        color_continuous_scale=None,
        category_orders={"class_name": class_to_idx.keys()},
        facet_col="is_anchor",
        hover_name="image_index",
        hover_data=["image_index", "anchor_index"],
        size_max=40,
    )

In [None]:
def load_ckpt(ckpt_path: Path):
    """Load a ``LightningTextClassifier`` checkpoint, move it to ``device`` and set eval mode.

    Args:
        ckpt_path: Checkpoint path forwarded to ``load_model``.

    Returns:
        The loaded module (loaded with ``strict=False``) after ``.to(device).eval()``.
    """
    model = load_model(module_class=LightningTextClassifier, checkpoint_path=ckpt_path, strict=False)
    model = model.to(device)
    return model.eval()

In [None]:
# Version tag of the code in this notebook
CODE_VERSION = 0.1

# Device on which the transformer, the relative projection and the anchor latents live.
# Fall back to the CPU so the notebook still runs on machines without CUDA.
device: str = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
from datasets import load_dataset

# Load the "trec" dataset through `load_dataset` and keep its two splits
datasets = load_dataset("trec")
train_dataset = datasets["train"]
test_dataset = datasets["test"]
# Show both splits as the cell output
train_dataset, test_dataset

In [None]:
# Column holding the classification target (read for the label arrays and the anchor selection)
target_key: str = "label-coarse"
# Column holding the raw text handed to the tokenizer in `collate_fn`
data_key: str = "text"

In [None]:
# Bound `str2int` method of the "label-fine" feature of the train split.
# NOTE(review): this reads "label-fine" while `target_key` is "label-coarse", and `class2idx`
# is not referenced again in the visible cells — confirm which label set is intended.
class2idx = train_dataset.features["label-fine"].str2int
# Show the "label-fine" feature description as the cell output
train_dataset.features["label-fine"]

In [None]:
def load_transformer(transformer_name):
    """Load a frozen pretrained transformer together with its tokenizer.

    Args:
        transformer_name: Identifier passed to ``AutoModel.from_pretrained`` and
            ``AutoTokenizer.from_pretrained``.

    Returns:
        Tuple ``(transformer, tokenizer)``. The transformer is configured with
        ``output_hidden_states=True`` and ``return_dict=True``, has gradients disabled,
        is in eval mode and has been moved to ``device``.
    """
    model = AutoModel.from_pretrained(transformer_name, output_hidden_states=True, return_dict=True)
    # Inference only: freeze the parameters, switch to eval mode, then move to the target device
    model.requires_grad_(False)
    model.eval()
    model.to(device)

    tokenizer = AutoTokenizer.from_pretrained(transformer_name)
    return model, tokenizer

In [None]:
# Pretrained checkpoints this notebook can be run with; the first entry is the one in use
_TRANSFORMER_CANDIDATES = (
    "bert-base-cased",
    "bert-base-uncased",
    "google/electra-base-discriminator",
    "roberta-base",
    "albert-base-v2",
    "distilbert-base-uncased",
    "distilbert-base-cased",
    "xlm-roberta-base",
)
transformer_name: str = _TRANSFORMER_CANDIDATES[0]
transformer, tokenizer = load_transformer(transformer_name=transformer_name)

In [None]:
import numpy as np
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler, Normalizer
from sklearn.svm import SVC
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.metrics import classification_report

In [None]:
# Target column of each split as a NumPy array
train_y, test_y = (np.array(split[target_key]) for split in (train_dataset, test_dataset))
# Number of distinct labels found in each split, shown as the cell output
len(np.unique(train_y)), len(np.unique(test_y))

In [None]:
def call_transformer(batch, transformer):
    """Encode a batch and mean-pool the last hidden layer over the masked-in tokens.

    Args:
        batch: Mapping with an ``"encoding"`` entry (moved to ``device`` and unpacked as
            keyword arguments of ``transformer``) and a ``"mask"`` entry holding one
            token mask per sample.
        transformer: Callable whose output exposes ``"hidden_states"``; the last element
            of that sequence is pooled.

    Returns:
        Tensor stacking one pooled vector per sample along dimension 0.
    """
    model_inputs = batch["encoding"].to(device)
    last_hidden = transformer(**model_inputs)["hidden_states"][-1]

    # TODO: aggregation mode
    pooled = [
        # Keep only the tokens selected by the mask, then average them
        token_states[token_mask].mean(dim=0)
        for token_states, token_mask in zip(last_hidden, batch["mask"])
    ]
    return torch.stack(pooled, dim=0)

In [None]:
from rae.data.text.datamodule import AnchorsMode
from sklearn.utils import shuffle

In [None]:
from typing import *


def _dataset_class_names(dataset):
    """Return the dataset-level class names, preferring a ``classes`` attribute when present."""
    if hasattr(dataset, "classes"):
        return dataset.classes
    # Datasets without a `classes` attribute: read the names from the target feature instead
    return dataset.features[target_key].names


def _anchors_from_indices(dataset, anchor_indices) -> Dict[str, Any]:
    """Fetch the samples at ``anchor_indices`` and package them in the common result layout."""
    anchors = [dataset[int(idx)] for idx in anchor_indices]
    targets = [anchor[target_key] for anchor in anchors]
    return {
        "anchor_idxs": anchor_indices,
        "anchor_samples": anchors,
        "anchor_targets": targets,
        "anchor_classes": [dataset.features[target_key].int2str(target) for target in targets],
        "anchor_latents": None,
    }


def get_anchors(dataset, anchors_mode, anchors_num) -> Dict[str, Any]:
    """Select anchor samples from ``dataset`` according to ``anchors_mode``.

    Args:
        dataset: Dataset to draw the anchors from. It is indexed both by integer position
            and by ``target_key``; class names come from ``dataset.features[target_key]``.
        anchors_mode: ``AnchorsMode`` member choosing the selection strategy:
            ``DATASET`` uses every sample; ``STRATIFIED_SUBSET`` shuffles with a fixed seed
            and picks samples round-robin across the sorted targets; ``STRATIFIED`` draws a
            stratified split of size ``anchors_num``; ``RANDOM_SAMPLES`` shuffles the
            indices with the global ``random`` state and keeps the first ``anchors_num``.
        anchors_num: Number of anchors requested (unused by ``DATASET``).

    Returns:
        Dict with the keys ``anchor_idxs``, ``anchor_samples``, ``anchor_targets``,
        ``anchor_classes`` and ``anchor_latents`` (always ``None``).

    Raises:
        ValueError: ``STRATIFIED_SUBSET`` is asked for more anchors than there are samples.
        NotImplementedError: ``anchors_mode`` is ``RANDOM_LATENTS``.
        RuntimeError: ``anchors_mode`` is not a supported mode.
    """
    if anchors_mode == AnchorsMode.DATASET:
        return {
            "anchor_idxs": list(range(len(dataset))),
            "anchor_samples": list(dataset),
            "anchor_targets": dataset[target_key],
            "anchor_classes": _dataset_class_names(dataset),
            "anchor_latents": None,
        }

    if anchors_mode == AnchorsMode.STRATIFIED_SUBSET:
        # Without this guard the round-robin loop below would never terminate
        if anchors_num > len(dataset):
            raise ValueError(f"Requested {anchors_num} anchors but the dataset has only {len(dataset)} samples")

        shuffled_idxs, shuffled_targets = shuffle(
            np.arange(len(dataset)),
            np.asarray(dataset[target_key]),
            random_state=0,
        )
        class2idxs = {
            target: shuffled_idxs[shuffled_targets == target] for target in sorted(set(shuffled_targets))
        }

        # Round-robin: take the i-th shuffled sample of every class in turn until enough are collected
        anchor_indices = []
        i = 0
        while len(anchor_indices) < anchors_num:
            for target_idxs in class2idxs.values():
                if i < len(target_idxs):
                    anchor_indices.append(target_idxs[i])
                if len(anchor_indices) == anchors_num:
                    break
            i += 1

        return _anchors_from_indices(dataset, anchor_indices)

    if anchors_mode == AnchorsMode.STRATIFIED:
        # Imported here because no cell of the notebook imports it
        from sklearn.model_selection import train_test_split

        if anchors_num >= len(_dataset_class_names(dataset)):
            _, anchor_indices = train_test_split(
                list(range(len(dataset))),
                test_size=anchors_num,
                stratify=dataset[target_key],
                random_state=0,
            )
        else:
            # NOTE(review): HARDCODED_ANCHORS is not defined in the visible cells; this branch
            # raises NameError unless it is defined elsewhere before the call.
            anchor_indices = HARDCODED_ANCHORS[:anchors_num]

        return _anchors_from_indices(dataset, anchor_indices)

    if anchors_mode == AnchorsMode.RANDOM_SAMPLES:
        anchor_idxs = list(range(len(dataset)))
        random.shuffle(anchor_idxs)
        # Honour the requested number of anchors (a None value keeps every sample)
        anchor_idxs = anchor_idxs[:anchors_num]
        return _anchors_from_indices(dataset, anchor_idxs)

    if anchors_mode == AnchorsMode.RANDOM_LATENTS:
        raise NotImplementedError

    raise RuntimeError(f"Unsupported anchors mode: {anchors_mode}")

In [None]:
# Number of anchors requested from `get_anchors` and declared to the projection below
anchors_num: int = 768
# Module whose `encode` method maps absolute latents to anchor-relative ones in `get_latents`.
# NOTE(review): "l2" normalization with "inner" similarity presumably amounts to cosine
# similarity against the anchors — confirm against the RelativeAttention implementation.
relative_projection = RelativeAttention(
    n_anchors=anchors_num,
    normalization_mode="l2",
    similarity_mode="inner",
    values_mode="similarities",
    n_classes=train_dataset.features[target_key].num_classes,
    output_normalization_mode=None,
).to(device)

In [None]:
def collate_fn(batch):
    """Tokenize a list of samples and build the mask of tokens to pool over.

    Args:
        batch: Sequence of samples; the text is read from ``sample[data_key]``.

    Returns:
        Dict with ``"encoding"`` (the tokenizer output without its
        ``"special_tokens_mask"`` entry) and ``"mask"`` (boolean tensor that is true
        where the attention mask is set and the token is not a special token).
    """
    texts = [sample[data_key] for sample in batch]
    encoding = tokenizer(
        texts,
        padding=True,
        truncation=True,
        return_special_tokens_mask=True,
        return_tensors="pt",
    )

    not_special = encoding["special_tokens_mask"].bool().logical_not()
    mask = encoding["attention_mask"] * not_special
    # The special-tokens mask was only needed to build `mask`; drop it from the encoding
    del encoding["special_tokens_mask"]

    return {"encoding": encoding, "mask": mask.bool()}

In [None]:
# Select the anchors from the train split with the STRATIFIED_SUBSET strategy
_anchor_selection = get_anchors(
    train_dataset,
    anchors_mode=AnchorsMode.STRATIFIED_SUBSET,
    anchors_num=anchors_num,
)
# Cast the selected indices to plain Python ints
anchor_idxs = list(map(int, _anchor_selection["anchor_idxs"]))

In [None]:
def get_latents(dataloader, anchors, split: str, transformer) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute absolute latents and, when anchors are given, their relative projection.

    Args:
        dataloader: Iterable of batches accepted by ``call_transformer``.
        anchors: Anchor latents handed to ``relative_projection.encode``, or ``None`` to
            skip the relative projection.
        split: Label shown in the progress-bar description.
        transformer: Model forwarded to ``call_transformer``.

    Returns:
        Tuple ``(absolute_latents, relative_latents)``. The first item concatenates the
        per-batch latents on CPU along dimension 0. The second does the same for the
        ``AttentionOutput.SIMILARITIES`` output, or is an empty list when no relative
        latent was computed.
    """
    absolute_chunks: List[torch.Tensor] = []
    relative_chunks: List[torch.Tensor] = []
    project = anchors is not None

    for batch in tqdm(dataloader, desc=f"[{split}] Computing relative latents"):
        with torch.no_grad():
            latents = call_transformer(batch=batch, transformer=transformer)
            absolute_chunks.append(latents.cpu())

            if not project:
                continue

            projected = relative_projection.encode(x=latents, anchors=anchors)
            relative_chunks.append(projected[AttentionOutput.SIMILARITIES].cpu())

    absolute_latents = torch.cat(absolute_chunks, dim=0)
    if relative_chunks:
        relative_latents = torch.cat(relative_chunks, dim=0)
    else:
        # Nothing was projected: hand back the empty list unchanged
        relative_latents = relative_chunks

    return absolute_latents, relative_latents

In [None]:
%%capture
# Gather the selected anchor samples from the train split
anchors = [train_dataset[anchor_idx] for anchor_idx in anchor_idxs]
# Encode them with anchors=None: only the absolute latents are computed and the second
# return value is discarded. Note that `anchors` is rebound from a list of samples to a tensor.
anchors, _ = get_latents(
    dataloader=DataLoader(anchors, num_workers=8, pin_memory=True, collate_fn=collate_fn, batch_size=32),
    split="anchor",
    anchors=None,
    transformer=transformer,
)
# `get_latents` returns CPU tensors; move the anchor latents to `device` for the later projections
anchors = anchors.to(device)

In [None]:
# Absolute and anchor-relative latents of the train split
absolute_train_X, relative_train_X = get_latents(
    dataloader=DataLoader(train_dataset, num_workers=8, pin_memory=True, collate_fn=collate_fn, batch_size=32),
    split="train",
    anchors=anchors,
    transformer=transformer,
)
# Show both shapes as the cell output
absolute_train_X.shape, relative_train_X.shape

In [None]:
# Absolute and anchor-relative latents of the test split.
# NOTE(review): this loader uses num_workers=16 and batch_size=64 while the train and anchor
# loaders use 8 and 32 — confirm the difference is intentional.
absolute_test_X, relative_test_X = get_latents(
    dataloader=DataLoader(test_dataset, num_workers=16, pin_memory=True, collate_fn=collate_fn, batch_size=64),
    split="test",
    anchors=anchors,
    transformer=transformer,
)

In [None]:
def svm_fit(X, y):
    """Fit a ``Normalizer`` -> ``StandardScaler`` -> linear ``SVC`` pipeline.

    Args:
        X: Training features passed to the pipeline's ``fit``.
        y: Training targets passed to the pipeline's ``fit``.

    Returns:
        Whatever the pipeline's ``fit`` returns.
    """
    steps = (
        Normalizer(),
        StandardScaler(),
        # class_weight="balanced" is intentionally left disabled
        SVC(gamma="auto", kernel="linear", random_state=42),
    )
    pipeline = make_pipeline(*steps)
    return pipeline.fit(X, y)

In [None]:
# Class names of the target feature, used to label the classification report
target_names = test_dataset.features[target_key].names
# Show the number of classes as the cell output
len(target_names)

In [None]:
# absolute_classifier: sklearn.pipeline.Pipeline = svm_fit(absolute_train_X, train_y)
# relative_classifier: sklearn.pipeline.Pipeline = svm_fit(relative_train_X, train_y)

In [None]:
from rae import PROJECT_ROOT

# Cache file holding the (absolute, relative) classifier pair
classifiers_path = Path(PROJECT_ROOT / "test.pt")

if classifiers_path.exists():
    # NOTE(review): torch.load unpickles arbitrary Python objects — only load a file produced
    # by this notebook. Recent torch releases default to weights_only=True, which rejects
    # pickled scikit-learn pipelines; confirm the torch version in use.
    absolute_classifier, relative_classifier = torch.load(classifiers_path)
else:
    # No cache yet (e.g. fresh checkout): fit both classifiers and store them so that
    # later runs reuse them instead of refitting
    absolute_classifier = svm_fit(absolute_train_X, train_y)
    relative_classifier = svm_fit(relative_train_X, train_y)
    torch.save((absolute_classifier, relative_classifier), classifiers_path)

In [None]:
# Test-split predictions of the classifier that works on the absolute latents
absolute_y_pred = absolute_classifier.predict(absolute_test_X)

In [None]:
# Text classification report for the absolute-latent predictions, labelled with the class names
print(classification_report(test_y, absolute_y_pred, target_names=target_names, output_dict=False))

In [None]:
# Test-split predictions of the classifier that works on the anchor-relative latents
relative_y_pred = relative_classifier.predict(relative_test_X)

In [None]:
# Text classification report for the relative-latent predictions. `target_names` is passed so
# the rows carry class names, matching the report printed for the absolute-latent predictions.
print(classification_report(test_y, relative_y_pred, target_names=target_names, output_dict=False))

In [None]:
# torch.save((absolute_classifier, relative_classifier), PROJECT_ROOT / "test.pt")