In [None]:
import random
from pathlib import Path
from typing import List, Tuple

import pandas as pd

In [None]:
import plotly.express as px
import sklearn.pipeline
import torch
from nn_core.serialization import load_model, NNCheckpointIO
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm
from transformers import AutoModel, PreTrainedModel, PreTrainedTokenizer, AutoTokenizer

In [None]:
from rae.data.text import TREC
from rae.modules.attention import RelativeAttention, AttentionOutput
from rae.pl_modules.pl_text_classifier import LightningTextClassifier

In [None]:
# Run on the GPU when torch can see one, otherwise fall back to the CPU so the
# notebook still executes end-to-end on machines without CUDA.
device: str = "cuda" if torch.cuda.is_available() else "cpu"

# Read by the "amazon_reviews_multi" branch of the dataset cell (5-way vs. binary target),
# by fit() (number of epochs) and by the name of the exported results file.
fine_grained: bool = True

In [None]:
from datasets import load_dataset, ClassLabel

# Positional arguments forwarded to `load_dataset`: (dataset name[, configuration name]).
dataset_key: Tuple[str, ...] = ("trec",)
# dataset_key: Tuple[str, ...] = ("amazon_reviews_multi", "en")
datasets = load_dataset(*dataset_key)

# Every branch defines `target_key` (label column) and `data_key` (text column),
# which the rest of the notebook reads as globals.
if dataset_key[0] == "dbpedia_14":

    def clean_sample(example):
        """Strip surrounding double quotes and whitespace from the `content` field."""
        example["content"] = example["content"].strip('"').strip()
        return example

    datasets = datasets.map(clean_sample)
    target_key: str = "label"
    data_key: str = "content"

elif dataset_key[0] == "trec":
    target_key: str = "label-coarse"
    data_key: str = "text"

elif dataset_key[0] == "amazon_reviews_multi":

    def clean_sample(sample):
        """Merge review title and body into `content` and rewrite the `stars` target."""
        title: str = sample["review_title"].strip('"').strip(".").strip()
        body: str = sample["review_body"].strip('"').strip(".").strip()

        # Drop the title when the body already starts with it (case-insensitive).
        if body.lower().startswith(title.lower()):
            title = ""

        # Terminate a title that ends with a letter so it reads as its own sentence.
        if len(title) > 0 and title[-1].isalpha():
            title = f"{title}."

        sample["content"] = f"{title} {body}".lstrip(".").strip()
        # `target_key` is resolved when `datasets.map` runs, i.e. after it is assigned below.
        # NOTE(review): the fine-grained value is str(stars - 1) ("0".."4") while the ClassLabel
        # names below are "1".."5"; the binary value is a bool while the names are "0"/"1".
        # Confirm that `cast_column` maps these values to the intended classes.
        if fine_grained:
            sample[target_key] = str(sample["stars"] - 1)
        else:
            sample[target_key] = sample["stars"] > 3
        return sample

    target_key: str = "stars"
    data_key: str = "content"
    datasets = datasets.map(clean_sample)
    datasets = datasets.cast_column(
        target_key,
        ClassLabel(num_classes=5 if fine_grained else 2, names=list(map(str, range(1, 6) if fine_grained else (0, 1)))),
    )


else:
    # Fail with an explicit message (also under `python -O`, where asserts are stripped)
    # instead of a NameError on `target_key` further down.
    raise ValueError(
        f"Unsupported dataset {dataset_key[0]!r}: expected 'dbpedia_14', 'trec' or 'amazon_reviews_multi'"
    )

datasets

In [None]:
# Full train/test splits; append `.select(range(1000))` to either side for a quick smoke run.
train_dataset, test_dataset = datasets["train"], datasets["test"]
train_dataset, test_dataset

In [None]:
# Label feature of the training split and its name -> index conversion method.
label_feature = train_dataset.features[target_key]
class2idx = label_feature.str2int
label_feature, class2idx

In [None]:
def load_transformer(transformer_name):
    """Load a frozen, eval-mode encoder together with its tokenizer.

    Args:
        transformer_name: identifier handed to both `from_pretrained` calls.

    Returns:
        A `(model, tokenizer)` tuple. The model is loaded with
        `output_hidden_states=True` and `return_dict=True`, has gradients
        disabled on all parameters and is switched to eval mode.
    """
    encoder = AutoModel.from_pretrained(transformer_name, output_hidden_states=True, return_dict=True)
    # The encoder is only used as a feature extractor: no gradients, inference behaviour.
    frozen_encoder = encoder.requires_grad_(False)
    frozen_encoder.eval()
    tokenizer = AutoTokenizer.from_pretrained(transformer_name)
    return encoder, tokenizer

In [None]:
# Model identifiers of the encoders to compare; each one is passed to `load_transformer`.
transformer_names: List[str] = [
    "bert-base-cased",
    "bert-base-uncased",
    "google/electra-base-discriminator",
    "roberta-base",
    # "albert-base-v2",
    # "distilbert-base-uncased",
    # "distilbert-base-cased",
    "xlm-roberta-base",
]

# Maps each model name to the `(model, tokenizer)` tuple returned by `load_transformer`.
# NOTE(review): the original comment said the latents of all these models are already
# cached in "latents.pt"; `load_latents` actually looks for one "<model name>.pt" file
# per model under LATENTS_DIR — confirm which is current.
transformers = {
    transformer_name: load_transformer(transformer_name=transformer_name)
    for transformer_name in transformer_names
}

In [None]:
import numpy as np
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler, Normalizer
from sklearn.svm import SVC
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.metrics import classification_report

In [None]:
# Label columns of both splits as NumPy arrays, followed by the number of distinct labels in each.
train_y, test_y = (np.array(split[target_key]) for split in (train_dataset, test_dataset))
len(set(train_y)), len(set(test_y))

In [None]:
@torch.no_grad()
def call_transformer(batch, transformer):
    """Encode one collated batch and return one vector per sample.

    Args:
        batch: dict with an "encoding" entry that supports `.to(device)` and
            unpacks into keyword arguments for `transformer`.
        transformer: callable whose output exposes "hidden_states".

    Returns:
        The last entry of "hidden_states", sliced at position 0 along dim 1
        (the CLS token, per the original comment).
    """
    inputs = batch["encoding"].to(device)
    outputs = transformer(**inputs)
    last_hidden = outputs["hidden_states"][-1]
    # TODO: aggregation mode. The alternative sketched in the original was a per-sample mean:
    #   torch.stack([enc[mask].mean(dim=0) for enc, mask in zip(last_hidden, batch["mask"])], dim=0)
    return last_hidden[:, 0, :]  # CLS

In [None]:
from rae.data.text.datamodule import AnchorsMode
from sklearn.utils import shuffle

In [None]:
from typing import *


def _anchor_record(dataset, anchor_indices) -> Dict[str, Any]:
    """Fetch the samples at `anchor_indices` and package them in the anchor dict format."""
    anchors = [dataset[int(idx)] for idx in anchor_indices]
    return {
        "anchor_idxs": anchor_indices,
        "anchor_samples": anchors,
        "anchor_targets": [anchor[target_key] for anchor in anchors],
        "anchor_classes": [dataset.features[target_key].int2str(anchor[target_key]) for anchor in anchors],
        "anchor_latents": None,
    }


def get_anchors(dataset, anchors_mode, anchors_num) -> Dict[str, Any]:
    """Select anchor samples from `dataset` according to `anchors_mode`.

    Args:
        dataset: indexable dataset; an integer index yields a sample dict and
            `dataset[target_key]` yields the whole label column.
        anchors_mode: an `AnchorsMode` member selecting the strategy.
        anchors_num: number of anchors to select (not used by the DATASET and
            RANDOM_SAMPLES modes).

    Returns:
        Dict with keys "anchor_idxs", "anchor_samples", "anchor_targets",
        "anchor_classes" and "anchor_latents" (always None here).

    Raises:
        ValueError: STRATIFIED_SUBSET is asked for more anchors than the dataset
            has samples (this used to loop forever).
        NotImplementedError: for `AnchorsMode.RANDOM_LATENTS`.
        RuntimeError: for any other mode.
    """
    if anchors_mode == AnchorsMode.DATASET:
        # Every sample is an anchor.
        # NOTE(review): `.classes` is read here and in the STRATIFIED branch, while the other
        # branches go through `dataset.features[target_key]` — confirm the dataset type
        # passed in actually exposes `.classes`.
        return {
            "anchor_idxs": list(range(len(dataset))),
            "anchor_samples": list(dataset),
            "anchor_targets": dataset[target_key],
            "anchor_classes": dataset.classes,
            "anchor_latents": None,
        }

    if anchors_mode == AnchorsMode.STRATIFIED_SUBSET:
        if anchors_num > len(dataset):
            # The round-robin below can never collect more indices than there are samples.
            raise ValueError(f"Cannot select {anchors_num} anchors from a dataset of {len(dataset)} samples")

        shuffled_idxs, shuffled_targets = shuffle(
            np.arange(len(dataset)),
            np.asarray(dataset[target_key]),
            random_state=0,
        )
        class2idxs = {
            target: shuffled_idxs[shuffled_targets == target] for target in sorted(set(shuffled_targets))
        }

        # Round-robin over the classes (in sorted order): the i-th pass takes the i-th
        # shuffled sample of every class that still has one, until enough are collected.
        anchor_indices = []
        i = 0
        while len(anchor_indices) < anchors_num:
            for target_idxs in class2idxs.values():
                if i < len(target_idxs):
                    anchor_indices.append(target_idxs[i])
                if len(anchor_indices) == anchors_num:
                    break
            i += 1

        return _anchor_record(dataset, anchor_indices)

    if anchors_mode == AnchorsMode.STRATIFIED:
        if anchors_num >= len(dataset.classes):
            # Imported here because the notebook never imports it at the top.
            from sklearn.model_selection import train_test_split

            _, anchor_indices = train_test_split(
                list(range(len(dataset))),
                test_size=anchors_num,
                stratify=dataset[target_key],
                random_state=0,
            )
        else:
            # NOTE(review): HARDCODED_ANCHORS is not defined anywhere in the visible notebook;
            # this path raises NameError unless it is provided elsewhere.
            anchor_indices = HARDCODED_ANCHORS[:anchors_num]
        return _anchor_record(dataset, anchor_indices)

    if anchors_mode == AnchorsMode.RANDOM_SAMPLES:
        # NOTE(review): `anchors_num` is not applied here, every sample is returned in
        # random order — confirm whether truncating to `anchors_num` was intended.
        anchor_idxs = list(range(len(dataset)))
        random.shuffle(anchor_idxs)
        return _anchor_record(dataset, anchor_idxs)

    if anchors_mode == AnchorsMode.RANDOM_LATENTS:
        raise NotImplementedError

    raise RuntimeError()


# Number of anchors. fit() sizes its layers with this value for both absolute and relative
# latents — NOTE(review): so it presumably has to equal the encoders' hidden size; confirm.
anchors_num: int = 768
# Class-balanced anchor indices from the training split, cast to plain Python ints.
anchor_idxs = list(
    map(
        int,
        get_anchors(train_dataset, anchors_mode=AnchorsMode.STRATIFIED_SUBSET, anchors_num=anchors_num)[
            "anchor_idxs"
        ],
    )
)
anchors = [train_dataset[position] for position in anchor_idxs]

In [None]:
# Settings of the module that get_latents() uses (via `.encode`) to turn a batch of absolute
# latents plus the anchor latents into the AttentionOutput.SIMILARITIES representation.
relative_projection_kwargs = dict(
    n_anchors=anchors_num,
    normalization_mode="l2",
    similarity_mode="inner",
    values_mode="similarities",
    n_classes=train_dataset.features[target_key].num_classes,
    output_normalization_mode=None,
)
relative_projection = RelativeAttention(**relative_projection_kwargs).to(device)

In [None]:
def collate_fn(batch, tokenizer):
    """Tokenize the text field of a list of samples into a single padded batch.

    Args:
        batch: list of sample dicts; the text is read from `sample[data_key]`.
        tokenizer: callable tokenizer, invoked with PyTorch tensors, truncation
            and padding enabled.

    Returns:
        `{"encoding": <tokenizer output without "special_tokens_mask">}`.
    """
    texts = [sample[data_key] for sample in batch]
    tokenizer_options = dict(
        return_tensors="pt",
        return_special_tokens_mask=True,
        truncation=True,
        padding=True,
    )
    tokenized = tokenizer(texts, **tokenizer_options)
    # The special-tokens mask is requested but currently unused (it would feed a pooling mask:
    # attention_mask * ~special_tokens_mask); drop it so the encoding holds model inputs only.
    del tokenized["special_tokens_mask"]
    return {"encoding": tokenized}

In [None]:
def get_latents(dataloader, anchors, split: str, transformer) -> Dict[str, torch.Tensor]:
    """Encode every batch of `dataloader` and collect absolute (and relative) latents.

    Args:
        dataloader: iterable of batches accepted by `call_transformer`.
        anchors: anchor latents passed to `relative_projection.encode`, or None
            to skip the relative projection.
        split: label shown in the progress bar.
        transformer: encoder; moved to `device` for the pass and back to the CPU afterwards.

    Returns:
        Dict with "absolute" (all batch latents concatenated on the CPU) and
        "relative" (the concatenated similarities, or an empty list when no
        relative latents were produced).
    """
    absolute_chunks: List = []
    relative_chunks: List = []

    transformer = transformer.to(device)
    for batch in tqdm(dataloader, desc=f"[{split}] Computing latents"):
        with torch.no_grad():
            encoded = call_transformer(batch=batch, transformer=transformer)
            absolute_chunks.append(encoded.cpu())

            if anchors is None:
                continue
            projected = relative_projection.encode(x=encoded, anchors=anchors)
            relative_chunks.append(projected[AttentionOutput.SIMILARITIES].cpu())

    absolute = torch.cat(absolute_chunks, dim=0)
    if relative_chunks:
        relative = torch.cat(relative_chunks, dim=0).cpu()
    else:
        # Nothing was projected: hand back the (empty) list itself.
        relative = relative_chunks

    # Move the encoder back to the CPU once the pass is over.
    transformer.cpu()
    return {"absolute": absolute, "relative": relative}

In [None]:
from rae import PROJECT_ROOT

# Cache directory for the latents of the selected dataset:
# <PROJECT_ROOT>/data/latents/<dataset_key elements joined with "/">.
LATENTS_DIR: Path = PROJECT_ROOT.joinpath("data", "latents", "/".join(dataset_key))
LATENTS_DIR.mkdir(parents=True, exist_ok=True)
LATENTS_DIR

In [None]:
def load_latents():
    """Load the cached latents of every configured transformer that has a cache file.

    Returns:
        Dict mapping transformer name to whatever was stored in
        `LATENTS_DIR / "<name with '/' replaced by '-'>.pt"`; models without a
        cache file are left out.
    """
    cached = {}
    for name in transformers:
        cache_file = LATENTS_DIR / f"{name.replace('/', '-')}.pt"
        if not cache_file.exists():
            continue
        # torch.load unpickles the file: only point LATENTS_DIR at caches you produced yourself.
        cached[name] = torch.load(cache_file)
    return cached


# Names of the models whose latents are already cached on disk.
list(load_latents())

In [None]:
from functools import partial

# True: ignore the on-disk cache and recompute the latents of every model.
FORCE_RECOMPUTE: bool = False
# True: write freshly computed latents to LATENTS_DIR.
CACHE_LATENTS: bool = True


def _latents_loader(samples, tokenizer):
    """Build the DataLoader used to extract latents from `samples` with `tokenizer`."""
    return DataLoader(
        samples,
        num_workers=8,
        pin_memory=True,
        collate_fn=partial(collate_fn, tokenizer=tokenizer),
        batch_size=32,
    )


latents = load_latents()

# NOTE(review): cache files are keyed by model name only, so cached latents are reused even
# when `anchors` / `anchors_num` changed — enable FORCE_RECOMPUTE in that case.
if FORCE_RECOMPUTE:
    missing_transformers = transformer_names
else:
    missing_transformers = [t_name for t_name in transformer_names if t_name not in latents]

for transformer_name in missing_transformers:
    encoder, encoder_tokenizer = transformers[transformer_name]

    # Anchors go first: their absolute latents are the reference for the relative
    # latents of both splits below.
    anchors_latents = get_latents(
        dataloader=_latents_loader(anchors, encoder_tokenizer),
        split=f"{transformer_name}, anchor",
        anchors=None,
        transformer=encoder,
    )["absolute"]

    transformer_latents = {"anchors_latents": anchors_latents}
    for dataset_split in [train_dataset, test_dataset]:
        transformer_latents[str(dataset_split.split)] = get_latents(
            dataloader=_latents_loader(dataset_split, encoder_tokenizer),
            split=f"{transformer_name}, {str(dataset_split.split)}",
            anchors=anchors_latents.to(device),
            transformer=encoder,
        )
    latents[transformer_name] = transformer_latents

    # Save latents
    if CACHE_LATENTS:
        transformer_path = LATENTS_DIR / f"{transformer_name.replace('/', '-')}.pt"
        torch.save(latents[transformer_name], transformer_path)
latents

In [None]:
from torch.nn import CrossEntropyLoss
from torch.optim import Adam


# def fit(X, y):
#     classifier = make_pipeline(
#         Normalizer(), StandardScaler(), SVC(gamma="auto", kernel="linear", random_state=42)
#     )  # , class_weight="balanced"))
#     classifier.fit(X, y)
#     return lambda x: classifier.predict(x)


from torch import nn
from pytorch_lightning import seed_everything


class Lambda(nn.Module):
    """Adapter that lets a plain callable be used as a layer, e.g. inside `nn.Sequential`."""

    def __init__(self, func):
        super(Lambda, self).__init__()
        # Kept as a regular attribute; the wrapped callable is invoked as-is in forward().
        self.func = func

    def forward(self, x):
        """Apply the wrapped callable to `x` and return its result."""
        wrapped = self.func
        return wrapped(x)


# Learning rate handed to Adam in fit(), looked up by dataset name (`dataset_key[0]`).
DATASET2LR = dict(
    trec=1e-4,
    amazon_reviews_multi=1e-3,
    dbpedia_14=1e-4,
)


def fit(X, y, seed, num_workers: int = 4):
    """Train a small MLP classifier on pre-computed latents and return a predictor.

    Args:
        X: tensor of latents, one row per sample. Its last dimension sets the
            width of every hidden layer, so absolute and relative latents of any
            size work (previously the global `anchors_num` was assumed).
        y: integer class labels aligned with the rows of `X`.
        seed: forwarded to `seed_everything` before the loader and model are built.
        num_workers: DataLoader worker processes (default 4, the previous fixed value).

    Returns:
        A callable mapping a tensor of latents to a CPU tensor of predicted class
        indices. The model lives on the CPU in eval mode and runs without
        gradient tracking.
    """
    seed_everything(seed)
    loader = DataLoader(
        TensorDataset(X, torch.as_tensor(y)),
        batch_size=32,
        pin_memory=True,
        shuffle=True,
        num_workers=num_workers,
    )

    width = X.shape[-1]
    num_classes = train_dataset.features[target_key].num_classes

    model = nn.Sequential(
        nn.LayerNorm(normalized_shape=width),
        nn.Linear(in_features=width, out_features=width),
        nn.SiLU(),
        # (batch, features) -> (features, batch) around InstanceNorm1d, then back again.
        Lambda(lambda x: x.permute(1, 0)),
        nn.InstanceNorm1d(num_features=width),
        Lambda(lambda x: x.permute(1, 0)),
        nn.Linear(in_features=width, out_features=width),
        nn.SiLU(),
        Lambda(lambda x: x.permute(1, 0)),
        nn.InstanceNorm1d(num_features=width),
        Lambda(lambda x: x.permute(1, 0)),
        nn.Linear(in_features=width, out_features=num_classes),
        # NOTE(review): this ReLU clamps the logits to be non-negative before the
        # cross-entropy loss; kept as in the original — confirm it is intended.
        nn.ReLU(),
    ).to(device)
    opt = Adam(model.parameters(), lr=DATASET2LR[dataset_key[0]])
    loss_fn = CrossEntropyLoss()

    for _ in tqdm(range(5 if fine_grained else 3), leave=False, desc="epoch"):
        for batch_x, batch_y in loader:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)
            loss = loss_fn(model(batch_x), batch_y)
            loss.backward()
            opt.step()
            opt.zero_grad()

    model = model.cpu().eval()

    def predict(x):
        # Inference only: skip building an autograd graph over the whole input.
        with torch.no_grad():
            return model(x).argmax(-1).cpu()

    return predict

In [None]:
# One independent training run per seed.
SEEDS = [0, 1, 2, 3, 4]


# fitted_classifiers[seed][transformer_name][embedding_type] -> predictor returned by fit(),
# trained on that transformer's "train" latents of the given type.
fitted_classifiers = {}
for seed in SEEDS:
    per_transformer = {}
    for transformer_name in tqdm(transformers, desc="transformer"):
        per_embedding = {}
        for embedding_type in tqdm(["absolute", "relative"], leave=False, desc="embedding_type"):
            train_latents = latents[transformer_name]["train"][embedding_type]
            per_embedding[embedding_type] = fit(train_latents, train_y, seed)
        per_transformer[transformer_name] = per_embedding
    fitted_classifiers[seed] = per_transformer
fitted_classifiers

In [None]:
from sklearn.metrics import precision_recall_fscore_support

# Column-oriented results: one entry per (seed, embedding type, encoder that produced the
# test latents, encoder whose classifier is applied to them).
result_columns = (
    "seed",
    "embed_type",
    "embed_transformer",
    "classifier_transformer",
    "precision",
    "recall",
    "fscore",
    "stitched",
)
numeric_results = {column: [] for column in result_columns}

for seed in SEEDS:
    for embed_type in ["absolute", "relative"]:
        for embed_transformer in transformers:
            for classifier_transformer in transformers:
                classifier = fitted_classifiers[seed][classifier_transformer][embed_type]
                preds = classifier(latents[embed_transformer]["test"][embed_type])

                precision, recall, fscore = precision_recall_fscore_support(test_y, preds, average="weighted")[:3]
                row = {
                    "seed": seed,
                    "embed_type": embed_type,
                    "embed_transformer": embed_transformer,
                    "classifier_transformer": classifier_transformer,
                    "precision": precision,
                    "recall": recall,
                    "fscore": fscore,
                    # "Stitched" = the classifier was trained on a different encoder's latents.
                    "stitched": embed_transformer != classifier_transformer,
                }
                for column, value in row.items():
                    numeric_results[column].append(value)


import pandas as pd

# Show every row and column when a frame is displayed.
pd.set_option("display.max_columns", None)
pd.set_option("display.max_rows", None)

df = pd.DataFrame(numeric_results)

# Raw (un-aggregated) results go to a TSV in the working directory; the granularity suffix
# is only added for amazon_reviews_multi.
dataset_name = "_".join(dataset_key)
fine_grained_str = "_fine_grained" if fine_grained else "_coarse"
df.to_csv(
    f"nlp_stitching-{dataset_name}{'' if dataset_key[0] != 'amazon_reviews_multi' else fine_grained_str}.tsv", sep="\t"
)

# Mean and number of runs of every remaining column for each
# (embedding type, stitched flag, embedding encoder, classifier encoder) combination.
# The string alias "mean" is used instead of `np.mean`: same result and column label,
# without the FutureWarning recent pandas versions emit for NumPy callables.
# Note that `df` is rebound to the aggregated frame from here on.
df = df.groupby(
    [
        "embed_type",
        "stitched",
        "embed_transformer",
        "classifier_transformer",
    ]
).agg(["mean", "count"])
df

In [None]:
# Display the dataset identifier that is embedded in the exported results file name.
dataset_name

In [None]:
# Summary by embedding type and stitched flag. `df` is already the per-pair aggregate at
# this point, so "embed_type" / "stitched" are index levels and the values are aggregates
# of the per-pair statistics. The string alias "mean" replaces `np.mean` (same result,
# no pandas FutureWarning).
df.groupby(
    [
        "embed_type",
        "stitched",
    ]
).agg(["mean", "count"])