In [1]:
import torch
from transformers import AutoModel, AutoTokenizer

In [2]:
# Pretrained checkpoint whose input-embedding table is visualised below.
# Alternative checkpoints that can be swapped in for `model_name`:
#   "cl-tohoku/bert-base-japanese-whole-word-masking"
#   "bert-base-uncased"
model_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)

Some weights of GPT2Model were not initialized from the model checkpoint at gpt2 and are newly initialized: ['h.8.attn.masked_bias', 'h.1.attn.masked_bias', 'h.4.attn.masked_bias', 'h.3.attn.masked_bias', 'h.11.attn.masked_bias', 'h.7.attn.masked_bias', 'h.10.attn.masked_bias', 'h.5.attn.masked_bias', 'h.2.attn.masked_bias', 'h.6.attn.masked_bias', 'h.0.attn.masked_bias', 'h.9.attn.masked_bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [3]:
# Token-string -> token-id mapping from the tokenizer.
vocab = tokenizer.get_vocab()
# Copy of the model's input-embedding weight matrix (one row per token id).
token_embeddings = model.get_input_embeddings().weight.clone()
vectors = dict()

In [4]:
# Collect one embedding row per distinct token id, in ascending id order, so
# that position i of `vectors` lines up with the i-th smallest token id.
# Built from scratch (rather than filling a dict from another cell and then
# rebinding the name to a list) so this cell can be re-run safely.
embedding_matrix = token_embeddings.detach().numpy()
vectors = [embedding_matrix[idx].copy() for idx in sorted(set(vocab.values()))]

In [5]:
from sklearn.manifold import TSNE
import holoviews as hv
from holoviews import opts

# 2-D t-SNE projector; the fixed random_state keeps the layout reproducible.
tsne = TSNE(n_components=2, random_state=0)

# Render HoloViews figures through the plotly backend.
hv.extension('plotly')

def clean(text):
    """Return a display-friendly version of a token string.

    Surrounding whitespace is removed first. Then any "Ġ" characters at either
    end are removed. In the GPT-2 vocabulary (e.g. 'Ġtonight'), "Ġ" marks a
    preceding space. Characters in the interior of the string are left alone.
    """
    return text.strip().strip("Ġ")

In [16]:
# Fit t-SNE on the id-ordered embedding rows; row i of the result is the
# 2-D position of vectors[i].
reduced_vectors = tsne.fit_transform(X=vectors)

In [17]:
# `reduced_vectors` rows follow ascending token-id order (the embeddings were
# sorted by id before t-SNE), but iterating `vocab` yields tokens in the dict's
# own, non-id order (see the `vocab` output below). Zipping them directly would
# attach the wrong label to almost every point, so order the tokens by id first.
id_to_token = {idx: token for token, idx in vocab.items()}
tokens_by_id = [id_to_token[idx] for idx in sorted(id_to_token)]

points = hv.Points(reduced_vectors)
labels = hv.Labels(
    {
        ('x', 'y'): reduced_vectors,
        # zip keeps the label list exactly as long as the point array.
        'text': [clean(token) for token, _ in zip(tokens_by_id, reduced_vectors)],
    },
    ['x', 'y'],
    'text',
)

(points * labels).opts(
    opts.Labels(size=6, yoffset=1, width=1500, height=1000),
    opts.Points(color='black', marker='x', size=3),
)

In [8]:
# Preview only: the full vocabulary has tens of thousands of entries, and dumping
# it all bloats the saved notebook. Note that iteration order is NOT by token id.
dict(list(vocab.items())[:50])

{'Ġpunitive': 32952,
 'inness': 32990,
 'ichita': 41940,
 'Ġtonight': 9975,
 'Ġstriker': 19099,
 'prev': 47050,
 'ĠEP': 14724,
 'Ġrotating': 24012,
 'ĠKar': 9375,
 'Ġdead': 2636,
 'ĠBanner': 27414,
 'oreAnd': 40219,
 'Mut': 41603,
 'Ġweren': 6304,
 'ĠADHD': 22822,
 'serious': 34009,
 '204': 18638,
 'scene': 29734,
 'humane': 44766,
 'osure': 4567,
 'Ġvolunteers': 11661,
 'Ġward': 15305,
 'Ġexpend': 12220,
 'Ġabduction': 39630,
 'irez': 31762,
 'ĠPerspective': 42051,
 'ĠPhen': 34828,
 'Ġcalam': 35765,
 'Ġreconsider': 26540,
 'VB': 44526,
 'Ġcasino': 21507,
 'Ġabove': 2029,
 'Ġscratch': 12692,
 'Ġtremend': 11039,
 'Prof': 15404,
 '993': 44821,
 'Ġransomware': 49134,
 'eenth': 28117,
 'Ġwhale': 22206,
 'Ġhostile': 12524,
 'XXXX': 24376,
 'ynasty': 19488,
 'ĠNES': 31925,
 'Ast': 33751,
 'Ġprick': 41409,
 'Ġgrasping': 44787,
 'Ġcolonists': 43430,
 'Ġterrit': 5771,
 'Har': 13587,
 'ieties': 9545,
 'Ġtender': 15403,
 'stay': 31712,
 'Ġabolished': 32424,
 'ĠSupports': 45267,
 'Ġoffense': 6907,