In [1]:
from Todd import MahalanobisFilter, extract_embeddings
import torch
from torch.utils.data import Dataset, DataLoader

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from toddbenchmark.generation_data import prep_dataset, prep_model, GenerationDataset
from datasets import load_dataset


In [2]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up on the next cell execution without restarting the kernel.
%load_ext autoreload
%autoreload 2


In [3]:
# Checkpoint handed to ToddBenchmark's prep_model, which returns the model
# together with its tokenizer.
checkpoint_name = "Helsinki-NLP/opus-mt-de-en"
model, tokenizer = prep_model(checkpoint_name)




# Load and prep dataset using ToddBenchmark

In [4]:

# Number of examples kept from each split. Deliberately small so that this
# example runs quickly; change it in this one place to scale the demo.
# NOTE(review): the reference covariance printed further down is 512x512, and
# a covariance estimated from only 100 points cannot be full rank — consider a
# larger value when the scores themselves matter.
N_SAMPLES = 100

# "de-en" matches the language pair in the checkpoint name loaded above and is
# used as the in-distribution data; "ro-en" is used as the out-of-distribution data.
in_dataset = prep_dataset("wmt16", "de-en", tokenizer=tokenizer)
out_dataset = prep_dataset("wmt16", "ro-en", tokenizer=tokenizer)

# The result of prep_dataset is indexed positionally; presumably the order is
# (train, validation, test) — TODO confirm against ToddBenchmark.
in_val = in_dataset[1][:N_SAMPLES]
in_test = in_dataset[2][:N_SAMPLES]
out_test = out_dataset[2][:N_SAMPLES]

# The full datasets are no longer needed once the small subsets are taken.
del in_dataset
del out_dataset


Found cached dataset wmt16 (/home/mdarrin/.cache/huggingface/datasets/wmt16/de-en/1.0.0/746749a11d25c02058042da7502d973ff410e73457f3d305fc1177dc0e8c4227)


  0%|          | 0/3 [00:00<?, ?it/s]

Found cached dataset wmt16 (/home/mdarrin/.cache/huggingface/datasets/wmt16/ro-en/1.0.0/746749a11d25c02058042da7502d973ff410e73457f3d305fc1177dc0e8c4227)


  0%|          | 0/3 [00:00<?, ?it/s]

In [4]:
# Wrap each subset in a DataLoader. All three loaders share the same settings:
# fixed order (no shuffling) and batches of four examples.
loader_options = dict(shuffle=False, batch_size=4)

in_val_loader = DataLoader(in_val, **loader_options)
in_test_loader = DataLoader(in_test, **loader_options)
out_test_loader = DataLoader(out_test, **loader_options)


## Extracting reference data to fit the detectors

In [16]:
# Classes do not matter in this setting, so we skip retrieving them and every
# embedding is stored under the single class key 0.
# In a classification problem with enough data, we would instead keep one
# reference set per class.
# The second return value of extract_embeddings is not used here.
ref_embeddings, _ = extract_embeddings(
    model,
    tokenizer,
    in_val_loader,
    layers=[6],
)

In [18]:
# Inspect how the reference embeddings are keyed. The recorded output shows a
# single key (6, 0) — presumably (layer, class); confirm against Todd.
ref_embeddings.keys()

dict_keys([(6, 0)])

In [30]:
# Count the reference embeddings stored under the only key present, (6, 0).
reference_key = (6, 0)
len(ref_embeddings[reference_key])

100

In [162]:
# Build the Mahalanobis detector on layer 6 (the layer extracted above) and
# fit it on the reference embeddings.
# NOTE(review): only 100 reference embeddings are available while the fitted
# covariance is 512x512, so it cannot be full rank. This may be why the scores
# recorded at the end of the notebook are huge and partly negative — confirm
# how Todd inverts the covariance, or fit on more reference samples.
maha_detector = MahalanobisFilter(
    layers=[6],
    threshold=0.5,
)
maha_detector.fit(ref_embeddings)

In [163]:
# Sanity check: shape of the covariance fitted for key (6, 0).
fitted_covariance = maha_detector.covs[(6, 0)]
print(fitted_covariance.shape)

torch.Size([512, 512])


In [164]:
# Run the fitted detector on one batch of the in-distribution test set.
tokenizer_options = dict(padding=True, truncation=True, return_tensors="pt")
generation_options = dict(
    return_dict_in_generate=True,
    output_hidden_states=True,
    output_scores=True,
)

# Gradients are not needed for generation or scoring.
with torch.no_grad():
    for batch in in_test_loader:
        inputs = tokenizer(batch["source"], **tokenizer_options)
        output = model.generate(**inputs, **generation_options)

        # In the recorded output, calling the detector prints a boolean tensor
        # and compute_scores prints a float tensor, one entry per example.
        print(maha_detector(output))
        print(maha_detector.compute_scores(output))

        # Only the first batch is needed for this demonstration.
        break


tensor([ True, False,  True,  True])
tensor([-1.6155e+08,  2.2526e+08, -3.4467e+08, -1.8851e+08])
