In [None]:
import csv
import gzip
import os
import urllib.request

import matplotlib.pyplot as plt
import pandas as pd
import torch
import torch.nn.functional as F
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer

In [None]:
# Local path of the gzipped STS benchmark TSV used by the rest of the notebook.
sts_dataset_path = '../datasets/stsbenchmark.tsv.gz'

# Download the archive only when it is not already on disk, so re-running the
# notebook does not hit the network again.
if not os.path.exists(sts_dataset_path):
    # The import cell never brings a `util` name into scope, so the download
    # is done with the standard library instead.
    os.makedirs(os.path.dirname(sts_dataset_path), exist_ok=True)
    # Write to a temporary name first: an interrupted download must not leave
    # a truncated file that the `os.path.exists` check would accept next time.
    partial_path = sts_dataset_path + '_part'
    urllib.request.urlretrieve('https://sbert.net/datasets/stsbenchmark.tsv.gz', partial_path)
    os.replace(partial_path, sts_dataset_path)


In [None]:
# Per-split containers; this cell only initialises them and never fills them.
train_samples, dev_samples, test_samples = [], [], []

# Parse the gzipped TSV in text mode. QUOTE_NONE makes pandas treat quote
# characters inside sentences as ordinary text.
with gzip.open(sts_dataset_path, mode='rt', encoding='utf8') as tsv_file:
    df = pd.read_csv(tsv_file, delimiter='\t', quoting=csv.QUOTE_NONE)

# Rows whose "split" column equals "train" go to df_train; every other row
# goes to df_other.
is_train = df["split"] == "train"
df_train = df[is_train]
df_other = df[~is_train]

# Preview the first rows of each subset.
display(df_train.head(), df_other.head())

In [None]:
def _to_examples(frame):
    """Turn a benchmark frame into a list of ``(sentence1, sentence2, score / 5.0)`` tuples.

    Args:
        frame: DataFrame with ``sentence1``, ``sentence2`` and ``score`` columns.

    Returns:
        One tuple per row, in row order. The score is divided by 5.0
        (presumably mapping the 0-5 STS scale onto [0, 1] -- confirm against
        the dataset description).
    """
    # Zipping whole columns avoids building a Series per row, which is what
    # `iterrows()` does and why it is slow.
    scaled_scores = (frame["score"] / 5.0).tolist()
    return list(zip(frame["sentence1"].tolist(), frame["sentence2"].tolist(), scaled_scores))


train_examples = _to_examples(df_train)
other_examples = _to_examples(df_other)

len(train_examples), len(other_examples)

In [None]:
# Hugging Face checkpoint identifiers to compare.
models = [
    "bert-base-cased",
    "roberta-base",
    "microsoft/deberta-base",
    "google/electra-base-discriminator",
    "microsoft/deberta-v3-base",
    "albert-base-v2",
]

# One independent, initially empty result list per checkpoint.
models_dict = dict((checkpoint, []) for checkpoint in models)

models_dict

In [None]:
# Prefer the first GPU, but fall back to the CPU so the cell still runs on
# machines without CUDA instead of failing at the first `.to(device)`.
device = "cuda:0" if torch.cuda.is_available() else "cpu"


def batch_to_device(x):
    """Return a new dict with every value of ``x`` moved to ``device``.

    Args:
        x: Mapping whose values expose ``.to(device)`` (here, the tokenizer
            output requested with ``return_tensors="pt"``).

    Returns:
        A plain ``dict`` with the same keys and the moved values.
    """
    return {k: v.to(device) for k, v in x.items()}


def _token_std(model, tokenizer, sentence):
    """Standard deviation of ``last_hidden_state`` along dim 1 for one sentence.

    The sentence is tokenised on its own without padding, so no pad tokens
    enter the statistic.
    NOTE(review): dim 1 is presumably the token axis of a
    (1, seq_len, hidden) tensor -- confirm against the model outputs.
    """
    encoded = batch_to_device(
        tokenizer(sentence, padding="do_not_pad", return_tensors="pt")
    )
    return model(**encoded).last_hidden_state.std(dim=1)


for m in models:
    auto_model = AutoModel.from_pretrained(m).to(device)
    tokenizer = AutoTokenizer.from_pretrained(m)

    pair_errors = []
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for sentence1, sentence2, _score in tqdm(train_examples, desc=m):
            dev1 = _token_std(auto_model, tokenizer, sentence1)
            dev2 = _token_std(auto_model, tokenizer, sentence2)

            # exp of the squared difference between the two std vectors.
            # A standard deviation is never negative, so no abs() is needed.
            # [0] selects the single item along dim 0.
            error = torch.exp((dev1 - dev2) ** 2)[0]
            pair_errors.append(error.cpu().tolist())

    models_dict[m] = pair_errors