In [8]:
import torch
from transformers import T5Tokenizer, AutoModel

In [9]:
# Hugging Face Hub identifier of the checkpoint whose embeddings are inspected.
model_name = "rinna/japanese-gpt2-medium"

# AutoModel resolves to the bare GPT2Model (no LM head), so the
# "lm_head.weight was not used" warning printed below is expected.
model = AutoModel.from_pretrained(model_name)
tokenizer = T5Tokenizer.from_pretrained(model_name)


Some weights of the model checkpoint at rinna/japanese-gpt2-medium were not used when initializing GPT2Model: ['lm_head.weight']
- This IS expected if you are initializing GPT2Model from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing GPT2Model from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [10]:
# Token string -> token id mapping from the tokenizer.
vocab = tokenizer.get_vocab()

# Work on a copy of the input embedding matrix so the model weights stay untouched.
token_embeddings = model.wte.weight.clone()

# Container for the per-token-id embedding vectors collected later.
vectors = dict()

In [20]:
# Build one embedding vector per unique token id, ordered by ascending id.
#
# This cell is written to be idempotent: it derives `vectors` only from
# `vocab` and `token_embeddings`, so it can be re-run safely (the previous
# version filled a dict and then rebound the same name to a list, which
# made a second run fail on `vectors.items()`).
token_ids = sorted(set(vocab.values()))

# detach() drops the autograd graph so the tensor can be viewed as a NumPy array.
embedding_matrix = token_embeddings.detach().numpy()

# copy() gives every row its own buffer, independent of the torch tensor.
vectors = [embedding_matrix[idx].copy() for idx in token_ids]

[(0, array([ 0.05126814, -0.05149134, -0.01524023, ..., -0.0925899 ,
        0.06515305,  0.07455667], dtype=float32)), (1, array([ 0.02486288, -0.0375488 ,  0.05762729, ..., -0.02145918,
       -0.00549667,  0.04309609], dtype=float32)), (2, array([-0.08654413,  0.02108668, -0.09448435, ..., -0.05984586,
        0.02151007,  0.07769898], dtype=float32)), (3, array([ 0.006766  , -0.00886511, -0.01670182, ...,  0.06467898,
       -0.03085513, -0.0547839 ], dtype=float32))]


In [12]:
from sklearn.manifold import TSNE

# Project onto two dimensions for plotting; the seed is fixed so that
# repeated runs use the same random state.
tsne = TSNE(random_state=0, n_components=2)

In [None]:
import holoviews as hv
from holoviews import opts

# Select the Plotly rendering backend for all HoloViews objects below.
hv.extension("plotly")

In [None]:
def clean(text: str) -> str:
    """Return ``text`` without surrounding whitespace or word-boundary markers.

    Subword tokenizers prefix tokens with a marker for "a word starts here":
    SentencePiece uses U+2581 and byte-level BPE uses U+0120. Both are
    removed from either end of the token so plot labels show only the text.

    The markers are written as escapes on purpose: the previous literal had
    been mangled by an encoding round-trip and no longer matched the marker
    that appears in this tokenizer's vocabulary.

    Args:
        text: A raw token string from the tokenizer vocabulary.

    Returns:
        The token with leading/trailing whitespace and markers removed.
        Markers in the middle of the token are left untouched.
    """
    # U+2581 = SentencePiece marker, U+0120 = byte-level BPE marker.
    boundary_markers = "\u2581\u0120"
    text = text.strip()
    text = text.strip(boundary_markers)
    # A marker may have been followed by whitespace; trim once more.
    return text.strip()

In [42]:
# Fit t-SNE on the embedding vectors and keep the low-dimensional coordinates.
reduced_vectors = tsne.fit_transform(X=vectors)

In [43]:
# Order the tokens by token id so that the i-th label describes the i-th row
# of `reduced_vectors` (the vectors were collected in ascending-id order).
# Relying on the dict's iteration order would only be correct if the
# tokenizer happened to return its vocabulary already sorted by id.
token_by_id = {idx: token for token, idx in vocab.items()}
ordered_tokens = [token_by_id[idx] for idx in sorted(token_by_id)]

# zip() keeps the label list no longer than the number of projected points.
token_labels = [clean(token) for token, _ in zip(ordered_tokens, reduced_vectors)]

points = hv.Points(reduced_vectors)
labels = hv.Labels(
    {('x', 'y'): reduced_vectors, 'text': token_labels},
    ['x', 'y'],
    'text',
)

# Overlay the text labels on the scatter of projected embeddings.
(points * labels).opts(
    opts.Labels(size=6, yoffset=1, width=1500, height=1000),
    opts.Points(color='black', marker='x', size=3),
)

In [44]:
# Inspect the token -> id mapping returned by tokenizer.get_vocab().
# NOTE(review): this displays the entire dict; consider previewing a slice instead.
vocab

{'<unk>': 0,
 '<s>': 1,
 '</s>': 2,
 '[PAD]': 3,
 '[CLS]': 4,
 '[SEP]': 5,
 '[MASK]': 6,
 'ã': 7,
 'ã': 8,
 'â': 9,
 'ã®': 10,
 'ã¯': 11,
 'ã': 12,
 'ã»': 13,
 ')': 14,
 '(': 15,
 'å¹´': 16,
 'ã«': 17,
 'ã': 18,
 'ã§': 19,
 'ã¨': 20,
 'ã': 21,
 'æ': 22,
 'ã': 23,
 '1': 24,
 '2': 25,
 'ã': 26,
 'ã§ãã': 27,
 'ãã': 28,
 'ãã': 29,
 'ã': 30,
 '3': 31,
 'ã': 32,
 'æ¥': 33,
 'ã¨ãã¦': 34,
 'ãã': 35,
 'ã': 36,
 '4': 37,
 'ã§ã¯': 38,
 'ã': 39,
 'ã': 40,
 '5': 41,
 'ã«ã¯': 42,
 '6': 43,
 'å¹´ã«': 44,
 'ãªã©': 45,
 '7': 46,
 '10': 47,
 'âã¾ã': 48,
 'ã¨ãã': 49,
 '8': 50,
 'â-': 51,
 '9': 52,
 'ããã': 53,
 'ãã¦ãã': 54,
 'ãã¦': 55,
 'ã': 56,
 'ãª': 57,
 'ã¦': 58,
 'ç¬¬': 59,
 'ã¹': 60,
 '-': 61,
 'å¤§': 62,
 'äºº': 63,
 '"': 64,
 'ãã®': 65,
 '12': 66,
 'æ¥ã«': 67,
 'ã¦ãã': 68,
 'å¸': 69,
 '11': 70,
 'ã«ãã': 71,
 'ã¨ãªã£ã': 72,
 'ã«ãã£ã¦': 73,
 'ã«ãã': 74,
 'ã§ãã': 75,
 ':