In [1]:
import torch
import emoji
import os
import sys
sys.path.insert(0, '../')

from config import GPT2EmojiConfig
from model import GPT2LMEmojiModel
from transformers import GPT2Tokenizer
from run_language_modeling import load_and_cache_examples, targets_mask

In [2]:
# Registry: model-type key -> (config class, model class, tokenizer class).
MODEL_CLASSES = dict(
    gpt2=(GPT2EmojiConfig, GPT2LMEmojiModel, GPT2Tokenizer),
)

In [3]:
# Relative path of the checkpoint directory that the loading cells read from.
MODEL_PATH = '../checkpoint-180000'

In [4]:
# Restore the object saved as `training_args.bin` inside the checkpoint directory.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code;
# only point MODEL_PATH at checkpoints from a trusted source.
args = torch.load(os.path.join(MODEL_PATH, 'training_args.bin'))

In [5]:
# Unpack the (config, model, tokenizer) classes registered under the 'gpt2' key.
config_class, model_class, tokenizer_class = MODEL_CLASSES['gpt2']

In [6]:
# Build the model configuration from the checkpoint directory.
config = config_class.from_pretrained(MODEL_PATH)

In [7]:
# Build the tokenizer from the same checkpoint directory as the config and model.
tokenizer = tokenizer_class.from_pretrained(MODEL_PATH)

In [8]:
# Instantiate the model from the checkpoint, reusing the config loaded above.
model = model_class.from_pretrained(MODEL_PATH, config=config)

In [9]:
# Prefix the eval file path stored in `args` with '../' (the same parent-directory
# prefix used for MODEL_PATH). The unprefixed value is remembered on `args` so that
# re-running this cell does not keep stacking '../' onto the path.
if not hasattr(args, '_eval_data_file_unprefixed'):
    args._eval_data_file_unprefixed = args.eval_data_file
args.eval_data_file = '../' + args._eval_data_file_unprefixed
eval_dataset = load_and_cache_examples(args, tokenizer, evaluate=True)

In [10]:
# Inspect the first evaluation example.
eval_dataset[0]

tensor([15883,   683,   262, 49414,     0,  1312,   760,   334,  3730,   481,
        52291])

In [11]:
# Lookup: encoded emoji token id -> target index (position of the emoji in
# emoji.UNICODE_EMOJI's key order).
map_token_id_to_target = {
    token_id: target
    for token_id, target in zip(
        tokenizer.encode(list(emoji.UNICODE_EMOJI.keys())),
        range(0, len(emoji.UNICODE_EMOJI.keys())),
    )
}

Token indices sequence length is longer than the specified maximum sequence length for this model (2811 > 1024). Running this sequence through the model will result in indexing errors


In [12]:
# Lookup: target index (position of the emoji in emoji.UNICODE_EMOJI's key
# order) -> encoded emoji token id.
map_target_to_token_id = {
    target: token_id
    for target, token_id in zip(
        range(0, len(emoji.UNICODE_EMOJI.keys())),
        tokenizer.encode(list(emoji.UNICODE_EMOJI.keys())),
    )
}

Token indices sequence length is longer than the specified maximum sequence length for this model (2811 > 1024). Running this sequence through the model will result in indexing errors


In [13]:
def map_targets_to_tokens(targets):
    """Decode a sequence of target indices into text.

    Each target index is looked up in the notebook-level
    ``map_target_to_token_id`` dict and the resulting token ids are handed
    to ``tokenizer.decode``. An unknown target raises ``KeyError``.
    """
    token_ids = list(map(map_target_to_token_id.__getitem__, targets))
    return tokenizer.decode(token_ids)

In [14]:
# Reshape the first example into a single-row batch, then derive the labels and
# mask for it with the project's `targets_mask` helper.
inputs = eval_dataset[0].view(1, -1)
labels, mask = targets_mask(inputs, tokenizer, map_token_id_to_target)

In [15]:
# Inference only: run the forward pass without autograd tracking so no graph is
# built or kept alive by `outputs`.
with torch.no_grad():
    outputs = model(inputs, labels=labels, inputs_mask=mask)

In [16]:
# Flatten the scores in outputs[1] to one row per input position and keep the
# indices of the five highest-scoring targets in each row.
top5_preds = torch.topk(
    outputs[1].view(inputs.size(1), -1), 5, dim=1
)[1].tolist()

In [17]:
# Raw top-5 target indices, one row per input position.
top5_preds

[[607, 391, 403, 1134, 806],
 [806, 1134, 1723, 884, 804],
 [939, 2160, 683, 2239, 1059],
 [1134, 1936, 2041, 2154, 806],
 [2154, 1134, 1936, 2182, 2305],
 [1936, 1982, 2035, 2305, 1134],
 [1134, 1936, 2041, 1723, 2154],
 [1936, 1134, 2305, 2035, 793],
 [1134, 1936, 1982, 2305, 2182],
 [1936, 2305, 2182, 1134, 1982],
 [2035, 1982, 1134, 793, 2305]]

In [18]:
# Token strings for the ids in the single batch row.
tokenizer.convert_ids_to_tokens(inputs[0])

['make',
 'Ġhim',
 'Ġthe',
 'Ġhappiest',
 '!',
 'Ġi',
 'Ġknow',
 'Ġu',
 'Ġguys',
 'Ġwill',
 '❤']

In [19]:
# Pair each position's decoded top-5 predictions with the token at that position.
list(zip(
    map(map_targets_to_tokens, top5_preds),
    tokenizer.convert_ids_to_tokens(inputs[0]),
))

[('👏 👇 👉 😭 😂', 'make'),
 ('😂 😭 😔 😳 😤', 'Ġhim'),
 ('🐐 🐍 👑 🍊 🔑', 'Ġthe'),
 ('😭 🥺 😌 😍 😂', 'Ġhappiest'),
 ('😍 😭 🥺 💖 💕', '!'),
 ('🥺 💜 ❤ 💕 😭', 'Ġi'),
 ('😭 🥺 😌 😔 😍', 'Ġknow'),
 ('🥺 😭 💕 ❤ 😘', 'Ġu'),
 ('😭 🥺 💜 💕 💖', 'Ġguys'),
 ('🥺 💕 💖 😭 💜', 'Ġwill'),
 ('❤ 💜 😭 😘 💕', '❤')]

In [20]:
def get_preds(model, item, topk=5):
    """Return the top-k predicted target indices for every position of `item`.

    Args:
        model: callable invoked as ``model(inputs, labels=..., inputs_mask=...)``;
            element ``[1]`` of its result is used as the prediction scores.
        item: tensor of token ids; it is reshaped into a single-row batch.
        topk: number of highest-scoring targets to keep per position.

    Returns:
        A list with one entry per input position, each a list of `topk`
        target indices ordered from highest to lowest score.
    """
    inputs = item.view(1, -1)
    labels, mask = targets_mask(inputs, tokenizer, map_token_id_to_target)

    # Inference only: the result is converted to plain lists, so skip autograd
    # tracking instead of building a graph that is thrown away immediately.
    with torch.no_grad():
        outputs = model(inputs, labels=labels, inputs_mask=mask)

    # One row of scores per input position.
    scores = outputs[1].view(inputs.size(1), -1)
    return scores.topk(topk, dim=1)[1].tolist()

In [21]:
def ids_to_tokens(inp):
    """Convert token ids to token strings using the notebook-level `tokenizer`."""
    tokens = tokenizer.convert_ids_to_tokens(inp)
    return tokens

In [22]:
# Decoded top-k predictions next to each token of evaluation example 4.
list(zip(
    map(map_targets_to_tokens, get_preds(model, eval_dataset[4])),
    ids_to_tokens(eval_dataset[4]),
))

[('👏 ❤ 😂 😭 😍', 'I'),
 ('😂 😭 💯 😎 😌', 'Ġget'),
 ('😭 😂 😩 😔 🙄', 'Ġdistracted'),
 ('😂 😭 🙃 🙄 😅', 'Ġeasily'),
 ('😂 😭 😔 🙄 🙃', 'Ġl'),
 ('😂 😭 💀 🤣 😅', 'ma'),
 ('😂 😭 💀 🤣 😅', 'o'),
 ('😂 😭 💀 😅 🤣', 'Ġmainly'),
 ('😂 😭 😅 💀 🙄', 'Ġbecause'),
 ('😂 😭 🙃 😅 🙄', 'Ġof'),
 ('😂 😭 😅 💀 🙃', 'Ġmusic'),
 ('😂 😭 🤣 😅 💀', '😂')]

In [23]:
# Decoded top-k predictions next to each token of evaluation example 44.
data = eval_dataset[44]
list(zip(
    map(map_targets_to_tokens, get_preds(model, data)),
    ids_to_tokens(data),
))

[('👏 ❤ 😂 😭 😍', 'I'),
 ('😭 😂 😩 😔 😊', 'Ġmust'),
 ('😂 😊 😅 👍 😉', 'Ġsay'),
 ('😂 😭 😍 😁 😅', '!!'),
 ('😂 🔥 😍 👍 😭', 'ĠSo'),
 ('😂 😭 😍 😅 😩', 'oo'),
 ('😂 😭 😍 😩 😊', 'Ġfl'),
 ('😂 😭 😍 😊 🙄', 'ipp'),
 ('😂 😭 😍 😩 🙄', 'in'),
 ('😍 😭 😂 🥺 😊', 'Ġcute'),
 ('😍 😭 💜 🥺 😂', '!!!'),
 ('🌹 😘 🌼 😊 ✨', '🌹'),
 ('🌹 🌸 🌼 🌻 ✨', '🦋'),
 ('🌹 🌼 🌻 🌸 🦋', '🦊'),
 ('🌹 🌸 🌷 🌼 🌻', '🌶'),
 ('🌸 🌹 🌻 🌷 🌼', '🦄')]

In [24]:
# Decoded top-k predictions next to each token of evaluation example 124.
data = eval_dataset[124]
list(zip(
    map(map_targets_to_tokens, get_preds(model, data)),
    ids_to_tokens(data),
))

[('😂 😭 👀 👏 😍', 'Wait'),
 ('😂 🔥 😍 👇 👀', 'Ġfor'),
 ('🐐 🔥 🍆 👑 🍊', 'Ġthe'),
 ('😂 😭 😍 🔥 🙏', 'Ġmass'),
 ('😂 😷 😭 👀 😁', 'Ġvaccination'),
 ('😂 😭 👀 🙄 😬', 'Ġcampaign'),
 ('😂 🙄 😭 👍 😁', 'Ġvery'),
 ('😂 😭 👀 🙏 😷', 'Ġsoon'),
 ('👀 😂 🤣 👇 🤔', '👀'),
 ('👇 👀 🏽 🏼 🏾', '👂')]

In [25]:
# Decoded top-k predictions next to each token of evaluation example 548.
data = eval_dataset[548]
list(zip(
    map(map_targets_to_tokens, get_preds(model, data)),
    ids_to_tokens(data),
))

[('😭 😂 👏 🙏 😔', 'Rest'),
 ('🙏 😭 😔 😂 🥺', 'Ġhelps'),
 ('🙏 😭 👍 ❤ 😔', 'Ġa'),
 ('🙏 😭 💜 🥺 ❤', 'Ġlot'),
 ('🙏 😭 ❤ 💜 😊', 'Ġand'),
 ('🙏 💜 😊 ❤ 💙', 'Ġhaving'),
 ('🙏 💜 👍 💕 ✨', 'Ġquality'),
 ('🙏 💜 😊 👍 💕', 'Ġtime'),
 ('😊 🙏 💜 ❤ 💙', 'Ġjust'),
 ('😊 🙏 💕 ❤ ✨', 'Ġbeing'),
 ('💜 💕 ✨ 💙 ❤', 'Ġyourself'),
 ('❤ 💜 😊 💕 💙', 'Ġand'),
 ('❤ 😊 💜 💕 🙏', 'Ġnot'),
 ('❤ 🙏 💜 💙 💕', 'ĠM'),
 ('🙏 💕 💜 💙 ❤', 'omm'),
 ('🙏 😊 💜 ❤ 💕', 'ing'),
 ('🙏 😊 ❤ 💜 💙', '.'),
 ('🖤 ✨ 💛 \U0001f90d 💙', '🖤')]

In [30]:
# Decoded top-k predictions for the hand-written input 'good'.
word_good = torch.tensor(tokenizer.encode('good'))
list(zip(
    map(map_targets_to_tokens, get_preds(model, word_good)),
    ids_to_tokens(word_good),
))

[('👏 😍 😭 👍 🥺', 'good')]

In [33]:
# Decoded top-k predictions for the hand-written input 'you are mean'.
word_bad = torch.tensor(tokenizer.encode('you are mean'))
list(zip(
    map(map_targets_to_tokens, get_preds(model, word_bad)),
    ids_to_tokens(word_bad),
))

[('👏 ❤ 😂 😭 🥺', 'you'), ('🥺 ❤ 💖 💜 😭', 'Ġare'), ('😂 😔 😭 😳 🙄', 'Ġmean')]

In [34]:
# Raw (undecoded) top-k target indices for evaluation example 548.
data = eval_dataset[548]
get_preds(model, data)

[[1134, 806, 607, 889, 1723],
 [889, 1134, 1723, 806, 1936],
 [889, 1134, 2265, 2035, 1723],
 [889, 1134, 1982, 1936, 2035],
 [889, 1134, 2035, 1982, 2156],
 [889, 1982, 2156, 2035, 494],
 [889, 1982, 2265, 2305, 2181],
 [889, 1982, 2156, 2265, 2305],
 [2156, 889, 1982, 2035, 494],
 [2156, 889, 2305, 2035, 2181],
 [1982, 2305, 2181, 494, 2035],
 [2035, 1982, 2156, 2305, 494],
 [2035, 2156, 1982, 2305, 889],
 [2035, 889, 1982, 494, 2305],
 [889, 2305, 1982, 494, 2035],
 [889, 2156, 1982, 2035, 2305],
 [889, 2156, 2035, 1982, 494],
 [465, 2181, 2757, 2379, 494]]

In [41]:
from scipy.stats import pearsonr
def get_batch_logits(model, data):
    """Return the model's per-position scores for one example.

    `data` is reshaped into a single-row batch, labels and mask are derived
    with `targets_mask`, and element ``[1]`` of the model's result is returned
    with one row per input position. The forward pass runs with autograd
    enabled, so the returned tensor stays attached to the graph.
    """
    batch = data.view(1, -1)
    batch_labels, batch_mask = targets_mask(batch, tokenizer, map_token_id_to_target)

    result = model(batch, labels=batch_labels, inputs_mask=batch_mask)
    n_positions = batch.size(1)
    return result[1].view(n_positions, -1)

In [44]:
# Per-position scores for evaluation example 548.
data = eval_dataset[548]
get_batch_logits(model, data)

tensor([[-4.0006, -7.5493, -9.0351,  ..., -6.0885, -6.6080, -7.2215],
        [-5.3994, -9.2103, -9.7454,  ..., -6.3101, -7.9404, -8.5960],
        [-3.8682, -7.1625, -7.3015,  ..., -4.6516, -5.0730, -6.3739],
        ...,
        [-5.5149, -7.9127, -8.5802,  ..., -6.3734, -6.4540, -7.8369],
        [-4.4432, -6.7025, -7.3937,  ..., -5.1135, -5.3539, -6.7177],
        [-1.6423, -2.7787, -3.0285,  ..., -1.3170, -1.6822, -2.2014]],
       grad_fn=<ViewBackward>)