In [1]:
%load_ext lab_black

In [2]:
import pandas as pd
import numpy as np
import torch

from twemoji.twemoji_dataset import TwemojiData, TwemojiBalancedData, TwemojiDataChunks
from embert import SimpleSembert, TopKAccuracy
from tqdm import tqdm

In [3]:
# Compute device: CUDA when available, otherwise CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Candidate emoji-index lists. The sizes are magic numbers (presumably the
# number of emoji ids seen in training vs. in the test vocabulary — TODO confirm
# and give them named constants). TEST_IDX is a superset of TRAIN_IDX.
TRAIN_IDX = [*range(1711)]
TEST_IDX = [*range(1810)]

### Load the trained models (unbalanced and balanced)

In [4]:
def get_model(balanced=False, checkpoint_path=None, dropout=0.2):
    """Build a SimpleSembert, load trained weights, and put it in eval mode.

    Parameters
    ----------
    balanced : bool
        Select the default "balanced" checkpoint (presumably the model trained
        on TwemojiBalancedData — TODO confirm) instead of the default
        unbalanced one. Ignored when ``checkpoint_path`` is given.
    checkpoint_path : str or None
        Explicit state-dict file to load (e.g. a different ``..._chunkNN.ckpt``).
        ``None`` keeps the original default checkpoints.
    dropout : float
        Dropout rate passed to ``SimpleSembert``.

    Returns
    -------
    The model, moved to the module-level ``device`` and switched to eval mode.
    """
    if checkpoint_path is None:
        checkpoint_path = (
            "trained_models/balanced_sembert_dropout/balanced_sembert_dropout_chunk106.ckpt"
            if balanced
            else "trained_models/sembert_dropout/sembert_dropout_chunk77.ckpt"
        )
    model = SimpleSembert(dropout=dropout).to(device)
    # map_location remaps stored tensors onto `device`, so GPU-saved weights
    # also load on a CPU-only machine.
    # NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
    state_dict = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(state_dict)
    model.eval()
    return model

In [5]:
%%capture
model = get_model()
model_balanced = get_model(balanced=True)

### Load the emoji id → character mapping and the prevalence ranking

In [6]:
df_des = pd.read_csv("emoji_embedding/data/processed/emoji_descriptions.csv")
# Lookup table: emoji id -> emoji character (later rows win on duplicate ids).
emoji_id_char = dict(zip(df_des.emoji_id, df_des.emoji_char))

In [7]:
# Emoji ids ordered from most to least prevalent.
TOP_EMOJIS = (
    pd.read_csv("twemoji/data/twemoji_prevalence.csv")
    .sort_values("prevalence", ascending=False)
    .emoji_ids.to_list()
)

### load dataset

In [8]:
# Evaluation data: the "test_v2" split (project class TwemojiData).
# NOTE(review): slicing presumably yields (model inputs, per-example target
# emoji-id lists) — see the `X, y = data[:32]` cell; TODO confirm.
data = TwemojiData("test_v2")

### Helper functions: combined top-k prediction and hit-rate accuracy

In [19]:
def get_combined_prediction(X, model1, model2, weighting, candidate_ids=None):
    """Concatenate the top-k predictions of two models.

    Parameters
    ----------
    X : batch of model inputs, passed unchanged to both models.
    model1, model2 : callables invoked as ``model(X, candidate_ids)`` that
        return a score tensor whose last dimension indexes the candidates.
        Assumes 2-D scores of shape (batch, n_candidates) — TODO confirm.
    weighting : pair ``(k1, k2)``; keep the top ``k1`` candidates of
        ``model1`` followed by the top ``k2`` of ``model2``.
    candidate_ids : candidate emoji ids passed to both models. ``None``
        (default) uses the module-level ``TEST_IDX``.

    Returns
    -------
    Integer tensor of shape (batch, k1 + k2) holding positions along the score
    tensor's last dimension (model1's picks first). These coincide with emoji
    ids when ``candidate_ids`` is ``list(range(n))``, as ``TEST_IDX`` is.

    NOTE: the two top-k lists are not de-duplicated, so a row can contain fewer
    than ``k1 + k2`` distinct ids when the models agree.
    """
    if candidate_ids is None:
        candidate_ids = TEST_IDX
    k1, k2 = weighting[0], weighting[1]
    # Pure inference: disable autograd so no computation graph is built/kept.
    with torch.no_grad():
        scores1 = model1(X, candidate_ids)
        _, top1 = torch.topk(scores1, k1, dim=-1)

        scores2 = model2(X, candidate_ids)
        _, top2 = torch.topk(scores2, k2, dim=-1)

    return torch.cat([top1, top2], dim=1)


def get_accuracy(p_emoji_ids, t_emoji_ids):
    """Fraction of examples whose predicted ids hit at least one target id.

    Parameters
    ----------
    p_emoji_ids : sequence of per-example predictions; each row is a tensor
        (converted with ``.tolist()``) or any iterable of ids.
    t_emoji_ids : sequence of per-example target ids, indexed in parallel with
        ``p_emoji_ids``.

    Returns
    -------
    float in [0, 1]; 0.0 for an empty batch.
    """
    n = len(p_emoji_ids)
    if n == 0:
        return 0.0
    hits = 0
    for i in range(n):
        row = p_emoji_ids[i]
        predicted = set(row.tolist() if hasattr(row, "tolist") else row)
        if not predicted.isdisjoint(t_emoji_ids[i]):
            hits += 1
    # Divide once: summing 1/n per hit accumulates rounding error
    # (ten hits of 0.1 sum to 0.9999999999999999).
    return hits / n

In [20]:
# Take the first 32 items as a small evaluation batch.
# NOTE(review): assumes slicing returns (model inputs, per-example lists of
# target emoji ids), as get_accuracy expects — TODO confirm against TwemojiData.
X, y = data[:32]

In [21]:
# Top-7 ids from the unbalanced model followed by top-3 from the balanced one.
pred = get_combined_prediction(X, model, model_balanced, weighting=(7, 3))
get_accuracy(pred, y)

0.6875