In [None]:
%load_ext lab_black

In [None]:
import torch
import torch.nn as nn
from cnn import Img2Vec

from emoji_image_dataset import EmojiClassificationDataset

import pandas as pd

from embedding_analysis import EmbeddingAnalysis

### load data

In [None]:
# Lookup table pairing each emoji name with its character: one row per
# distinct (name, character) pair, with the name column exposed as `label`
# so it shares a column name with the image metadata frames merged below.
df = (
    pd.read_csv(
        "data/processed/emoji_descriptions.csv",
    )[["emjpd_emoji_name", "emoji_char"]]
    .drop_duplicates()
    .rename(columns={"emjpd_emoji_name": "label"})
)

In [None]:
# Training split of the image dataset.
# NOTE(review): label_flag=False presumably makes items plain image tensors
# (no label) — confirm in emoji_image_dataset.
X_train = EmojiClassificationDataset("train", label_flag=False)

# Keep the first metadata row of every label. The row's position in the CSV
# survives as `old_index` (create_embeddings uses it to index the dataset),
# and the emoji character is attached from `df` with a left merge.
df_train = (
    pd.read_csv("data/meta/img_meta_train.csv")
    .drop_duplicates(subset=["label"])
    .reset_index()
    .rename(columns={"index": "old_index"})
    .merge(df, how="left")
)

In [None]:
# Zero-shot split of the image dataset.
# NOTE(review): label_flag=False presumably makes items plain image tensors
# (no label) — confirm in emoji_image_dataset.
X_zero = EmojiClassificationDataset("zeroshot", label_flag=False)

# Keep the first metadata row of every label. The row's position in the CSV
# survives as `old_index` (create_embeddings uses it to index the dataset),
# and the emoji character is attached from `df` with a left merge.
df_zero = (
    pd.read_csv("data/meta/img_meta_zeroshot.csv")
    .drop_duplicates(subset=["label"])
    .reset_index()
    .rename(columns={"index": "old_index"})
    .merge(df, how="left")
)

### load model

In [None]:
# Build the image encoder from its checkpoint and switch it to eval mode.
# NOTE(review): 200 is presumably the embedding size — confirm in cnn.Img2Vec.
model = Img2Vec(200, "model/emoji_image_embedding/emimem.ckpt")
model.eval()

# Cosine similarity module; dim/eps are the library defaults, written out.
sim = nn.CosineSimilarity(dim=1, eps=1e-8)

### create embeddings

In [None]:
def create_embeddings(X, df, batch_size=64, encoder=None):
    """Run the encoder over the dataset items referenced by ``df``.

    Parameters
    ----------
    X : indexable dataset
        ``X[idx]`` must return a tensor; items of one batch are stacked
        along a new leading dimension.
    df : pandas.DataFrame
        Must contain an ``old_index`` column holding the positions in ``X``
        to embed. Rows are processed in frame order.
    batch_size : int, default 64
        Number of items passed to the encoder per forward call.
    encoder : callable, optional
        Maps a batch tensor to a batch of embeddings. Defaults to the
        notebook-level ``model``.

    Returns
    -------
    torch.Tensor
        Encoder outputs concatenated along dim 0, in the row order of
        ``df``. The result carries no autograd history.

    Raises
    ------
    ValueError
        If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    if encoder is None:
        # Fall back to the global encoder defined in the "load model" cell.
        encoder = model

    outputs = []
    # Inference only: without no_grad every batch would keep its autograd
    # graph alive and the returned embeddings would require grad.
    with torch.no_grad():
        for start in range(0, len(df), batch_size):
            indices = df.old_index[start : start + batch_size]
            batch = torch.stack([X[idx] for idx in indices], dim=0)
            outputs.append(encoder(batch))

    return torch.cat(outputs, dim=0)

In [None]:
%%time
# Embed the training images referenced by df_train (one embedding per row).
embeddings_train = create_embeddings(X_train, df_train)

# Two lookups handed to EmbeddingAnalysis: row position -> label, and
# label -> emoji character.
index_label_train = df_train["label"].to_dict()
label_emoji = dict(zip(df_train["label"], df_train["emoji_char"]))
ea_train = EmbeddingAnalysis(embeddings_train, index_label_train, label_emoji)

In [None]:
%%time 
# Embed the zero-shot images referenced by df_zero (one embedding per row).
embeddings_zero = create_embeddings(X_zero, df_zero)

# Two lookups handed to EmbeddingAnalysis: row position -> label, and
# label -> emoji character. Note that `label_emoji` is rebound here and no
# longer refers to the training mapping after this cell runs.
index_label_zero = df_zero["label"].to_dict()
label_emoji = dict(zip(df_zero["label"], df_zero["emoji_char"]))
ea_zero = EmbeddingAnalysis(embeddings_zero, index_label_zero, label_emoji)

In [None]:
# Combine both splits, training rows first and zero-shot rows second, in the
# metadata frame and in the embedding matrix alike, so that row i of
# `df_total` lines up with row i of `embeddings_total`.
embeddings_total = torch.cat((embeddings_train, embeddings_zero), dim=0)
df_total = pd.concat([df_train, df_zero]).reset_index(drop=True)

# Row position -> label and label -> emoji character for the combined set.
index_label_total = df_total["label"].to_dict()
label_emoji_total = dict(zip(df_total["label"], df_total["emoji_char"]))
ea_total = EmbeddingAnalysis(embeddings_total, index_label_total, label_emoji_total)

In [None]:
df_total

In [None]:
embeddings_total.shape

## similarity check

### training data

In [None]:
# Show ten random training rows. The fixed seed keeps the displayed rows the
# same on every Restart & Run All.
df_train.sample(10, random_state=0)

In [None]:
ea_train.most_similar("pensive_face")

In [None]:
ea_train.most_similar("smiling_face_with_smiling_eyes")

In [None]:
ea_train.most_similar("face_vomiting")

In [None]:
ea_train.most_similar("see-no-evil_monkey")

In [None]:
ea_train.most_similar("flag-_lebanon")

In [None]:
ea_train.most_similar("eggplant")

### zero shot (queried against the combined train + zero-shot embeddings, `ea_total`)

In [None]:
# Show ten random rows whose `class` value is missing. The fixed seed keeps
# the displayed rows the same on every Restart & Run All.
# NOTE(review): rows without a `class` are presumably the zero-shot ones —
# confirm against the metadata CSVs.
df_total.loc[df_total["class"].isna()].sample(10, random_state=0)

In [None]:
ea_total.most_similar("sparkling_heart")

In [None]:
ea_total.most_similar("weary_face")

In [None]:
ea_total.most_similar("last_quarter_moon")

In [None]:
ea_total.most_similar("flag-_bouvet_island")

In [None]:
ea_total.most_similar("face_with_open_mouth")

In [None]:
ea_total.most_similar("enraged_face")

In [None]:
ea_total.most_similar("flag-_china")

In [None]:
ea_total.most_similar("minibus")

In [None]:
ea_total.most_similar("1st_place_medal")

In [None]:
ea_total.most_similar("waffle")

In [None]:
ea_total.most_similar("chart_increasing")

In [None]:
ea_total.most_similar("sauropod")

In [None]:
ea_total.most_similar("flexed_biceps")

In [None]:
ea_total.most_similar("bacon")