In [None]:
import sys
# Put the parent directory on the import path.
# NOTE(review): presumably needed so the project-local imports in the next cell resolve -- confirm.
sys.path.append("..")

In [None]:
from functools import partial
from typing import Any

import numpy as np
import transformers
from bert_finetune import tokenize_function
from omegaconf import OmegaConf
from repsim.nlp import get_dataset, get_tokenizer, ShortcutAdder

In [None]:
def get_dataset_cfg(config_name, config_root="/root/similaritybench/nlp/config"):
    """Load the dataset configuration called ``config_name``.

    Args:
        config_name: File stem of the YAML file inside ``<config_root>/dataset``.
        config_root: Directory that contains the ``dataset`` config folder. The default
            is the location this notebook has always used, so existing calls behave
            the same; pass another directory when the repository lives elsewhere.

    Returns:
        The object produced by ``OmegaConf.load`` for
        ``<config_root>/dataset/<config_name>.yaml``.
    """
    return OmegaConf.load(f"{config_root}/dataset/{config_name}.yaml")


def get_model_cfg(config_name, config_root="/root/similaritybench/nlp/config"):
    """Load the model configuration called ``config_name``.

    Args:
        config_name: File stem of the YAML file inside ``<config_root>/model``.
        config_root: Directory that contains the ``model`` config folder. The default
            is the location this notebook has always used, so existing calls behave
            the same; pass another directory when the repository lives elsewhere.

    Returns:
        The object produced by ``OmegaConf.load`` for
        ``<config_root>/model/<config_name>.yaml``.
    """
    return OmegaConf.load(f"{config_root}/model/{config_name}.yaml")


def get_tokenizer(dataset_cfg, model_cfg):
    """Build the pretrained tokenizer extended with one ``[CLASS<i>]`` token per label.

    Note: this definition rebinds the name ``get_tokenizer`` that the notebook also
    imports from ``repsim.nlp``; after this cell runs, the local version is the one used.

    Args:
        dataset_cfg: Config exposing ``finetuning.num_labels``.
        model_cfg: Config exposing ``kwargs.tokenizer_name``.

    Returns:
        The result of ``transformers.AutoTokenizer.from_pretrained`` called with the
        class tokens as ``additional_special_tokens``.
    """
    pretrained_name = model_cfg.kwargs.tokenizer_name
    class_tokens = []
    for class_idx in range(dataset_cfg.finetuning.num_labels):
        class_tokens.append(f"[CLASS{class_idx}]")
    return transformers.AutoTokenizer.from_pretrained(pretrained_name, additional_special_tokens=class_tokens)


def ratio_shortcut_equals_true_label(shortcut_strength, dataset, dataset_cfg, tokenizer, ratio_subset="validation"):
    """Measure how often the injected shortcut token agrees with the true label.

    A ``ShortcutAdder`` built with ``p=shortcut_strength`` is mapped over ``dataset``, the
    result is tokenized with ``tokenize_function``, and on the ``ratio_subset`` split the
    class index encoded in the token at position 1 of ``input_ids`` is compared with
    ``label``.

    Args:
        shortcut_strength: Forwarded to ``ShortcutAdder`` as ``p``.
        dataset: Object with a ``map`` method whose result can be indexed by split name
            (presumably a ``datasets.DatasetDict`` -- confirm against ``get_dataset``).
        dataset_cfg: Config providing ``path``, ``name``, ``feature_column`` and
            ``finetuning.num_labels``.
        tokenizer: Tokenizer exposing ``additional_special_tokens`` and
            ``additional_special_tokens_ids``.
        ratio_subset: Name of the split on which the ratio is computed.

    Returns:
        A dict with ``"ratio"`` (fraction of examples in ``ratio_subset`` whose shortcut
        class equals the label) and ``"dataset"`` (the tokenized dataset).

    Raises:
        KeyError: If the token at position 1 of an example is not one of the tokenizer's
            additional special tokens.
    """
    if dataset_cfg.name is not None:
        dataset_name = dataset_cfg.path + "__" + dataset_cfg.name
    else:
        dataset_name = dataset_cfg.path
    # Only the first configured feature column receives the shortcut.
    feature_column = dataset_cfg.feature_column[0]
    sc_adder = ShortcutAdder(
        num_labels=dataset_cfg.finetuning.num_labels, p=shortcut_strength, feature_column=feature_column
    )
    ds_w_shortcut = dataset.map(sc_adder)
    tokenized_dataset = ds_w_shortcut.map(
        partial(
            tokenize_function,
            tokenizer=tokenizer,
            dataset_name=dataset_name,
            feature_column=sc_adder.new_feature_column,
        ),
        batched=True,
    )
    # Token id -> token text for the additional special tokens.
    additional_tokids_to_toks = dict(
        zip(tokenizer.additional_special_tokens_ids, tokenizer.additional_special_tokens)
    )

    def shortcut_eq_label(example: dict[str, Any]) -> dict[str, bool]:
        # Position 1 is assumed to hold the shortcut token, i.e. exactly one token
        # precedes it -- TODO confirm against tokenize_function.
        added_tok_id = example["input_ids"][1]
        # The slice drops the 6-character "[CLASS" prefix and the trailing "]".
        shortcut_label = int(additional_tokids_to_toks[added_tok_id][6:-1])
        return {"shortcut_eq_label": example["label"] == shortcut_label}

    # Read the result column once; it is needed for both the sum and the length.
    matches = tokenized_dataset[ratio_subset].map(shortcut_eq_label)["shortcut_eq_label"]
    return {
        "ratio": sum(matches) / len(matches),
        "dataset": tokenized_dataset,
    }


def get_class_distribution(dataset):
    """Count how often each label occurs in every subset of ``dataset``.

    Args:
        dataset: Mapping-like object with ``keys()``; each ``dataset[subset]["label"]``
            must be convertible to a NumPy array.

    Returns:
        A dict mapping each subset name to a ``{label: count}`` dict, where labels and
        counts are the values returned by ``np.unique``.
    """

    def _label_histogram(labels):
        # np.unique yields the distinct labels together with their occurrence counts.
        distinct, occurrences = np.unique(np.array(labels), return_counts=True)
        return {label: n for label, n in zip(distinct, occurrences)}

    return {subset: _label_histogram(dataset[subset]["label"]) for subset in dataset.keys()}


def print_class_distributions(dataset):
    """Print the label counts of every subset of ``dataset``.

    For each subset returned by ``get_class_distribution`` a header line is printed,
    followed by one ``<label>: <count>`` line per label.
    """
    for subset, distribution in get_class_distribution(dataset).items():
        report = [f"Class distribution for {subset}:"]
        report.extend(f"{label}: {count}" for label, count in distribution.items())
        print("\n".join(report))

## SST-2

In [None]:
# SST-2 setup: configs, tokenizer with one [CLASS<i>] token per label, and the raw dataset.
dataset_cfg = get_dataset_cfg("sst2")
model_cfg = get_model_cfg("multibert")
tokenizer = get_tokenizer(dataset_cfg, model_cfg)
dataset = get_dataset(dataset_cfg.path, dataset_cfg.name)

# Every argument is passed by keyword: Python rejects positional arguments that
# follow a keyword argument with a SyntaxError.
result = ratio_shortcut_equals_true_label(
    shortcut_strength=0.75,
    dataset=dataset,
    dataset_cfg=dataset_cfg,
    tokenizer=tokenizer,
)
print(result["ratio"])

In [None]:
# Display the tokenizer created in the previous cell for inspection.
tokenizer

In [None]:
# Take the tokenized dataset returned next to the ratio and show the validation example at index 1.
tokenized_dataset = result["dataset"]
tokenized_dataset["validation"][1]

## MNLI

In [None]:
# MNLI setup: configs, tokenizer with one [CLASS<i>] token per label, and the raw dataset.
# These notebook-level names are rebound here and read again by the cells that follow.
dataset_cfg = get_dataset_cfg("mnli")
model_cfg = get_model_cfg("multibert")
tokenizer = get_tokenizer(dataset_cfg, model_cfg)
dataset = get_dataset(dataset_cfg.path, dataset_cfg.name)

# The ratio is computed on the "validation_matched" split instead of the default "validation".
result = ratio_shortcut_equals_true_label(
    shortcut_strength=0.25,
    dataset=dataset,
    dataset_cfg=dataset_cfg,
    tokenizer=tokenizer,
    ratio_subset="validation_matched",
)
print(result["ratio"])

In [None]:
# Print per-subset label counts for the dataset bound to `dataset` at this point.
print_class_distributions(dataset)


In [None]:
# Relative frequency of each class in the "validation_matched" split.
distrs = get_class_distribution(dataset)
distr = distrs["validation_matched"]
# Sum the counts once rather than once per class.
total_count = sum(distr.values())
[count / total_count for count in distr.values()]

In [None]:
# Five evenly spaced values from 0.354 up to and including 1.
# NOTE(review): 0.354 is presumably the largest class share from the cell before this one -- confirm.
np.linspace(0.354, 1, num=5)