In [None]:
%load_ext lab_black

In [None]:
from transformers import DistilBertTokenizer, DistilBertModel

import torch
import torch.nn as nn

import pandas as pd
import ast

import time

### load data

In [None]:
# Read the emoji description table from the processed-data folder.
df = pd.read_csv("data/processed/emoji_descriptions.csv")

# Each string cell in the alias column is parsed as a Python literal
# (presumably a list of alias strings -- confirm against the CSV) and flattened
# into one space-separated string; non-string cells (e.g. NaN) become "".
df["emjpd_aliases"] = df["emjpd_aliases"].map(
    lambda cell: " ".join(ast.literal_eval(cell) if isinstance(cell, str) else [])
)

### load model

In [None]:
# DistilBERT has fewer parameters than BERT; the author expects comparable performance.
# TODO: a fine-tuned checkpoint might work better -- work out how to fine-tune it.
base_model_name = "distilbert-base-uncased"  # note: this tokenizer does not know the emojis
tokenizer = DistilBertTokenizer.from_pretrained(base_model_name)
model = DistilBertModel.from_pretrained(base_model_name)

### compute embeddings

In [None]:
def get_model_embedding(text_ls):
    """Embed a batch of texts as mean-pooled last-layer hidden states.

    Parameters
    ----------
    text_ls : list of str
        Texts that are tokenized together, with padding and truncation
        enabled in the tokenizer call.

    Returns
    -------
    list of torch.Tensor
        One vector per input text: the mean of the model's last hidden states
        over the first ``n`` positions, where ``n`` is the sum of that text's
        attention mask.
    """
    batch = tokenizer(text_ls, return_tensors="pt", truncation=True, padding=True)

    # Inference only, so gradient tracking is switched off.
    with torch.no_grad():
        hidden_states = model(**batch).last_hidden_state

    token_counts = batch.attention_mask.sum(dim=1).tolist()
    # Slicing to the first `n_tokens` positions assumes the padding sits at
    # the end of each row -- TODO confirm the tokenizer's padding side.
    return [
        hidden_states[row, :n_tokens].mean(dim=0)
        for row, n_tokens in enumerate(token_counts)
    ]

In [None]:
def get_embeddings(batch_size=100):
    """Embed one combined text per row of ``df`` and key it by emoji character.

    The text for a row is the emoji name, its aliases and its main
    description, each followed by a filler of three U+25A1 characters;
    newlines inside the description are replaced by the same filler.

    Parameters
    ----------
    batch_size : int, default 100
        Number of texts handed to ``get_model_embedding`` per call.
        Must be at least 1.

    Returns
    -------
    dict
        Maps each ``df.emoji_char`` value to the embedding of that row's text.
        If the same character occurs in several rows, the last row wins.

    Raises
    ------
    ValueError
        If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    filler = "\u25A1" * 3
    s = df.emoji_name_og.fillna("") + filler
    s += df.emjpd_aliases.fillna("") + filler
    s += df.emjpd_description_main.fillna("").str.replace("\n", filler) + filler
    s_ls = s.tolist()

    embedding_ls = []

    # Iterate over batch *start* offsets from 0. The previous loop iterated over
    # batch end offsets below len(s_ls), so the final batch (partial, or full
    # when the row count is a multiple of batch_size) was never embedded.
    for batch_start in range(0, len(s_ls), batch_size):
        batch_end = min(batch_start + batch_size, len(s_ls))
        start = time.time()
        embedding_ls += get_model_embedding(s_ls[batch_start:batch_end])
        print(f"processed {batch_start} to {batch_end}, took {time.time() - start}")

    # Every row has an embedding now, so keys and values line up one-to-one.
    return dict(zip(df.emoji_char.tolist(), embedding_ls))

In [None]:
# Build the emoji -> embedding mapping, sending texts to the model 128 at a time.
data = get_embeddings(128)

In [None]:
# Persist the mapping returned by get_embeddings so it can be reloaded later
# without re-running the model.
torch.save(data, "data/processed/emojipedia_description_embedding.ckpt")