In [8]:
import torch
from transformers import T5Tokenizer, AutoModel

In [9]:
model_name = "rinna/japanese-gpt2-medium"
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)


Some weights of the model checkpoint at rinna/japanese-gpt2-medium were not used when initializing GPT2Model: ['lm_head.weight']
- This IS expected if you are initializing GPT2Model from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing GPT2Model from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [10]:
token_embeddings = model.wte.weight.clone()
vocab = tokenizer.get_vocab()
vectors = {}

In [20]:
for idx in vocab.values():
    vectors[idx] = token_embeddings[idx].detach().numpy().copy()

vectors = sorted(vectors.items(), key=lambda x: x[0])
vectors = [v for idx, v in vectors]

[(0, array([ 0.05126814, -0.05149134, -0.01524023, ..., -0.0925899 ,
        0.06515305,  0.07455667], dtype=float32)), (1, array([ 0.02486288, -0.0375488 ,  0.05762729, ..., -0.02145918,
       -0.00549667,  0.04309609], dtype=float32)), (2, array([-0.08654413,  0.02108668, -0.09448435, ..., -0.05984586,
        0.02151007,  0.07769898], dtype=float32)), (3, array([ 0.006766  , -0.00886511, -0.01670182, ...,  0.06467898,
       -0.03085513, -0.0547839 ], dtype=float32))]


In [12]:
from sklearn.manifold import TSNE
tsne = TSNE(n_components=2, random_state=0)

In [None]:
import holoviews as hv
from holoviews import opts

hv.extension('plotly')

In [None]:
def clean(text):
    text = text.strip()
    text = text.strip("Ġ")
    return text

In [42]:
reduced_vectors = tsne.fit_transform(vectors)

In [43]:
points = hv.Points(reduced_vectors)
labels = hv.Labels({('x', 'y'): reduced_vectors, 'text': [clean(token) for token, _ in zip(vocab, reduced_vectors)]}, ['x', 'y'], 'text')

(points * labels).opts(
    # opts.Labels(xoffset=0.05, yoffset=0.05, size=6, padding=0.2, width=1500, height=1000),
    opts.Labels(size=6, yoffset=1, width=1500, height=1000),
    opts.Points(color='black', marker='x', size=3),
)

In [44]:
vocab

{'<unk>': 0,
 '<s>': 1,
 '</s>': 2,
 '[PAD]': 3,
 '[CLS]': 4,
 '[SEP]': 5,
 '[MASK]': 6,
 '、': 7,
 '。': 8,
 '▁': 9,
 'の': 10,
 'は': 11,
 'が': 12,
 '・': 13,
 ')': 14,
 '(': 15,
 '年': 16,
 'に': 17,
 'を': 18,
 'で': 19,
 'と': 20,
 '」': 21,
 '月': 22,
 '「': 23,
 '1': 24,
 '2': 25,
 'や': 26,
 'である': 27,
 'から': 28,
 'した': 29,
 'も': 30,
 '3': 31,
 'し': 32,
 '日': 33,
 'として': 34,
 'する': 35,
 '』': 36,
 '4': 37,
 'では': 38,
 '『': 39,
 'た': 40,
 '5': 41,
 'には': 42,
 '6': 43,
 '年に': 44,
 'など': 45,
 '7': 46,
 '10': 47,
 '▁また': 48,
 'という': 49,
 '8': 50,
 '▁-': 51,
 '9': 52,
 'された': 53,
 'している': 54,
 'して': 55,
 'る': 56,
 'な': 57,
 'て': 58,
 '第': 59,
 'ス': 60,
 '-': 61,
 '大': 62,
 '人': 63,
 '"': 64,
 'その': 65,
 '12': 66,
 '日に': 67,
 'ている': 68,
 '市': 69,
 '11': 70,
 'による': 71,
 'となった': 72,
 'によって': 73,
 'により': 74,
 'であり': 75,
 ':': 76,
 '▁この': 77,
 'であった': 78,
 'され': 79,
 'この': 80,
 's': 81,
 '中': 82,
 ',': 83,
 '山': 84,
 'm': 85,
 '.': 86,
 '町': 87,
 'ア': 88,
 'となる': 89,
 'へ': 90,
 'がある': 91,
 '一': 92,
 '