In [None]:
import sys
import pandas as pd
import numpy as np
import torch
import pytorch_lightning as pl
import torch.utils.data as data

from transformers import AutoModel, AutoTokenizer

sys.path.append('../scripts/')
from forward_simulation import BowStudent, BowRationaleDataset
from forward_simulation import BowNLIStudent, BowNLIRationaleDataset
from forward_simulation import collate_fn

In [None]:
# Tokenizer passed to every dataset built by compute_simulability below.
t5_tokenizer = AutoTokenizer.from_pretrained(
    't5-small',
)

# A single CPU trainer with the progress bar disabled, reused for all
# prediction runs in this notebook.
trainer = pl.Trainer(
    enable_progress_bar=False,
    accelerator='cpu',
)

In [None]:
def compute_simulability(fname, dataset_cls, student_model, batch_size=16):
    """Return the fraction of examples on which the student's 'pred' equals 'gold'.

    The file is loaded through ``dataset_cls`` (using the notebook-level
    ``t5_tokenizer``), batched with the project's ``collate_fn`` and run through
    ``student_model`` with the notebook-level ``trainer``.

    Args:
        fname: Path handed to ``dataset_cls`` as its first argument.
        dataset_cls: Callable invoked as ``dataset_cls(fname, t5_tokenizer)``
            that returns the dataset to evaluate.
        student_model: Model passed to ``trainer.predict``. Each per-batch
            output must be a mapping holding 'gold' and 'pred' tensors.
        batch_size: Number of examples per prediction batch (default 16).

    Returns:
        float: mean of ``gold == pred`` over all examples, in [0, 1].

    Raises:
        ValueError: if prediction yields no batches (e.g. an empty dataset).
    """
    ds = dataset_cls(fname, t5_tokenizer)
    dl = data.DataLoader(ds, batch_size=batch_size, collate_fn=collate_fn, shuffle=False)
    outputs = trainer.predict(student_model, dl)

    # Fail with a clear message instead of an IndexError / NaN further down.
    if not outputs:
        raise ValueError(f"No predictions were produced for {fname!r}; is the dataset empty?")

    # Only the two tensors needed for the agreement score are gathered.
    gold = torch.cat([batch['gold'] for batch in outputs])
    pred = torch.cat([batch['pred'] for batch in outputs])
    return (gold == pred).float().mean().item()

---

### IMDB

In [None]:
# Restore the IMDB student from its checkpoint and put it in eval mode.
imdb_student = (
    BowStudent
    .load_from_checkpoint("../lightning_logs/version_2/checkpoints/epoch=4-step=140.ckpt")
    .eval()
)

In [None]:
# Simulability of the IMDB student on the revised IMDB test file.
compute_simulability(
    student_model=imdb_student,
    dataset_cls=BowRationaleDataset,
    fname='../data/edits/revised_imdb_test_beam_15_sparsemap_30p.tsv',
)

---

### SNLI

In [None]:
# The SNLI file is read with BowNLIRationaleDataset, so the NLI student class
# (imported at the top of the notebook) is loaded here instead of BowStudent.
# NOTE(review): this checkpoint path is identical to the one used for the IMDB
# student; confirm it points to the SNLI-trained run before trusting the score.
snli_student = BowNLIStudent.load_from_checkpoint("../lightning_logs/version_2/checkpoints/epoch=4-step=140.ckpt").eval()

In [None]:
# Simulability of the SNLI student on the revised SNLI test file.
compute_simulability(
    student_model=snli_student,
    dataset_cls=BowNLIRationaleDataset,
    fname='../data/edits/revised_snli_test_beam_15_sparsemap_30p.tsv',
)