In [1]:
import pandas as pd 
import numpy as np 
import torch

from twemoji.twemoji_dataset import TwemojiData
from embert import SimpleSembert, TopKAccuracy
from tqdm import tqdm

In [2]:
# Use the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Consecutive index lists 0..1710 and 0..1809. TEST_IDX is handed to the model
# as its second argument inside evaluate_on_dataset.
# NOTE(review): presumably these are emoji class ids — confirm where 1711/1810 come from.
TRAIN_IDX = [*range(1711)]
TEST_IDX = [*range(1810)]

In [3]:
def evaluate_on_dataset(model, data, k=1):
    """Compute the size-weighted top-k accuracy of ``model`` over ``data``.

    Every item yielded by ``data`` is treated as a batch: ``batch[0]`` is fed
    to the model and ``batch[1]`` is the target. The model is called as
    ``model(X, TEST_IDX)`` and the result is scored with ``TopKAccuracy(k)``.
    Each batch score is weighted by ``len(X)`` before averaging.

    The forward passes run under ``torch.no_grad()``. The model's train/eval
    mode is not touched here; set it before calling.

    Args:
        model: Callable invoked as ``model(X, TEST_IDX)``.
        data: Iterable of batches supporting ``batch[0]`` and ``batch[1]``.
        k: Forwarded to ``TopKAccuracy``.

    Returns:
        The weighted mean accuracy, or ``float("nan")`` when the accumulated
        weight is zero (for example when ``data`` yields nothing).
    """
    weighted_sum = 0
    n_seen = 0
    score = TopKAccuracy(k)
    # Evaluation needs no gradients; no_grad avoids building autograd graphs.
    # enumerate() is kept so the progress bar renders as before (no total).
    with torch.no_grad(), tqdm(enumerate(data)) as tbatch:
        for _, batch in tbatch:
            X = batch[0]
            y = batch[1]
            outputs = model(X, TEST_IDX)
            batch_accuracy = score(outputs, y)
            weighted_sum += len(X) * batch_accuracy
            n_seen += len(X)

            # Skip the running average while nothing has been counted yet,
            # otherwise an empty first batch would divide by zero.
            if n_seen:
                tbatch.set_postfix(
                    batch_accuracy=batch_accuracy,
                    running_accuracy=weighted_sum / n_seen,
                )

    if not n_seen:
        # Accuracy over zero samples is undefined; report it instead of crashing.
        print("total accuracy is undefined: no samples were evaluated")
        return float("nan")

    total_accuracy = weighted_sum / n_seen
    print(f"total accuracy is {total_accuracy}")
    return total_accuracy

### load prevalence

In [4]:
# Read the prevalence table and append "faktor": the natural log of each row's
# prevalence divided by the mean prevalence over all rows. The "faktor" column
# is read again when model.normalize is built.
prevalence = pd.read_csv("twemoji/data/twemoji_prevalence.csv").assign(
    faktor=lambda frame: np.log(frame["prevalence"] / frame["prevalence"].mean())
)

### load pretrained model

In [5]:
%%capture
# Checkpoint whose parameters are restored into the model below.
pretrained_path = "trained_models/main_run/sembert_chunk51.ckpt"

# Instantiate the model on the selected device, then restore the saved state.
model = SimpleSembert().to(device)
# NOTE(review): torch.load unpickles the file — only load checkpoints from a
# trusted source.
state_dict = torch.load(pretrained_path, map_location=device)
model.load_state_dict(state_dict)
print(f"loaded pretrained params from: {pretrained_path}")
# Switch to evaluation mode before scoring.
model.eval()

### evaluate 

In [6]:
# Dataset that the evaluation cells below iterate over.
# NOTE(review): "balanced_train_v2" is presumably a dataset/split name — confirm.
data = TwemojiData("balanced_train_v2")

In [8]:
# evaluate_on_dataset(model, data, k=1)

#### change normalization

In [9]:
def _id_and_factor(row):
    """Return ``(int(emoji_ids), faktor)`` for one row of ``prevalence``."""
    return int(row.emoji_ids), row.faktor


# Assign model.normalize a list with one (id, log-factor) tuple per prevalence row.
model.normalize = prevalence.apply(_id_and_factor, axis=1).tolist()

In [10]:
# Top-1 evaluation, run after model.normalize has been reassigned.
evaluate_on_dataset(model=model, data=data, k=1)

11it [00:30,  2.76s/it, batch_accuracy=0.125, running_accuracy=0.169]


KeyboardInterrupt: 

In [33]:
# Peek at the first item; the recorded output is a (text, list of ints) tuple.
data[0]

('Are you even joking with these right now? But it fits my theme though Late night so bad… https://t.co/EZURbZGynz',
 [1448, 371])