In [None]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch.nn.functional as F

In [None]:
# Load the pretrained GPT-2 tokenizer and reuse its end-of-text token as the
# padding token, so calls made with padding=True have a pad token to use.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

# Load the GPT-2 causal language model. Hidden states are requested so the
# embedding cells further down can read them; device placement is delegated
# to the library through device_map="auto".
model = AutoModelForCausalLM.from_pretrained(
    "gpt2",
    device_map="auto",
    output_hidden_states=True,
)

# Being the last expression of the cell, this displays the model's configuration
model.config

In [None]:
def next_word_prediction_probs(prompt, n=10):
    """Return the `n` most likely next tokens after `prompt` with their probabilities.

    Args:
        prompt: Text the prediction is conditioned on. If it tokenizes to
            nothing (e.g. an empty string), a single BOS token (falling back
            to EOS) is used as context instead, so the model still has a
            position to predict from.
        n: How many candidates to return; clamped to [0, vocabulary size].

    Returns:
        A list of `[token_text, probability]` pairs for the first sequence in
        the batch, ordered from most to least likely. The probabilities are
        the softmax of the logits at the last position.
    """
    device = next(model.parameters()).device
    inp_tok = tokenizer(prompt,
                        padding=True,
                        return_tensors="pt").to(device)
    input_ids = inp_tok["input_ids"]
    attention_mask = inp_tok["attention_mask"]

    # With zero tokens there is no last position to read logits from, so
    # start from the tokenizer's beginning-of-text token instead.
    if input_ids.shape[-1] == 0:
        start_id = tokenizer.bos_token_id
        if start_id is None:
            start_id = tokenizer.eos_token_id
        input_ids = torch.full((input_ids.shape[0], 1), start_id,
                               dtype=torch.long, device=device)
        attention_mask = torch.ones_like(input_ids)

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        logits = model(input_ids=input_ids,
                       attention_mask=attention_mask).logits[:, -1, :]

    # Normalise the scores so the returned numbers really are probabilities.
    probs = F.softmax(logits[0], dim=-1)
    k = max(0, min(n, probs.shape[-1]))
    top = torch.topk(probs, k=k)
    return [[tokenizer.decode(token_id), prob]
            for token_id, prob in zip(top.indices.tolist(), top.values.tolist())]

In [None]:
# Show the 10 (default n) highest-scoring next-token candidates.
# NOTE(review): the prompt is an empty string here, which presumably gives
# the model no tokens to condition on — replace it with real text.
next_word_prediction_probs("")

In [None]:
def next_tokens(prompt, n=10):
    """Greedily extend `prompt` by `n` tokens and return the decoded text.

    At every step the single highest-scoring token at the last position is
    appended to the sequence (no sampling).

    Args:
        prompt: Text to continue. If it tokenizes to nothing (e.g. an empty
            string), generation starts from a single BOS token (falling back
            to EOS).
        n: Number of tokens to append.

    Returns:
        The decoded text of the first sequence, i.e. the prompt followed by
        the generated tokens.
    """
    device = next(model.parameters()).device
    inp_tok = tokenizer(prompt,
                        padding=True,
                        return_tensors="pt").to(device)
    input_ids = inp_tok["input_ids"]

    # With zero tokens there is no last position to read logits from, so
    # start from the tokenizer's beginning-of-text token instead.
    if input_ids.shape[-1] == 0:
        start_id = tokenizer.bos_token_id
        if start_id is None:
            start_id = tokenizer.eos_token_id
        input_ids = torch.full((input_ids.shape[0], 1), start_id,
                               dtype=torch.long, device=device)

    # Inference only: without no_grad every step would keep its autograd
    # graph alive and memory would grow with n.
    with torch.no_grad():
        for _ in range(n):
            logits = model(input_ids).logits[:, -1, :]
            pid = torch.argmax(logits, dim=-1, keepdim=True)
            input_ids = torch.cat((input_ids, pid), dim=1)
    return tokenizer.decode(input_ids[0])

In [None]:
# Greedy decoding: extend the prompt by 128 tokens, always appending the
# highest-scoring next token.
# NOTE(review): the prompt is an empty string here — replace it with real text.
next_tokens("",n=128)

In [None]:
def next_tokens_mn_generation(prompt, n=10):
    """Extend `prompt` by `n` tokens drawn from the model's next-token distribution.

    At every step the logits at the last position are turned into
    probabilities with softmax and one token is drawn with
    `torch.multinomial`, so repeated calls can return different text.

    Args:
        prompt: Text to continue. If it tokenizes to nothing (e.g. an empty
            string), generation starts from a single BOS token (falling back
            to EOS).
        n: Number of tokens to append.

    Returns:
        The decoded text of the first sequence, i.e. the prompt followed by
        the sampled tokens.
    """
    device = next(model.parameters()).device
    inp_tok = tokenizer(prompt,
                        padding=True,
                        return_tensors="pt").to(device)
    input_ids = inp_tok["input_ids"]

    # With zero tokens there is no last position to read logits from, so
    # start from the tokenizer's beginning-of-text token instead.
    if input_ids.shape[-1] == 0:
        start_id = tokenizer.bos_token_id
        if start_id is None:
            start_id = tokenizer.eos_token_id
        input_ids = torch.full((input_ids.shape[0], 1), start_id,
                               dtype=torch.long, device=device)

    # Inference only: without no_grad every step would keep its autograd
    # graph alive and memory would grow with n.
    with torch.no_grad():
        for _ in range(n):
            logits = model(input_ids).logits[:, -1, :]
            probs = F.softmax(logits, dim=-1)
            pid = torch.multinomial(probs, num_samples=1)
            input_ids = torch.cat((input_ids, pid), dim=1)
    return tokenizer.decode(input_ids[0])

In [None]:
# Sample a 10-token (default n) continuation from the softmax distribution;
# the result can differ between runs because tokens are drawn at random.
# NOTE(review): the prompt is an empty string here — replace it with real text.
next_tokens_mn_generation("")

In [None]:
def next_tokens_generation(prompt, n=10,temperature=0.9):
    """Sample a continuation of `prompt` with `model.generate`.

    Args:
        prompt: Text to continue. If it tokenizes to nothing (e.g. an empty
            string), generation starts from a single BOS token (falling back
            to EOS).
        n: Number of new tokens to generate. It is passed as
            `max_new_tokens`, so the prompt length does not eat into the
            budget — matching how `n` is used by the hand-written generation
            loops in this notebook.
        temperature: Sampling temperature forwarded to `generate`.

    Returns:
        The decoded text of the first returned sequence.
    """
    device = next(model.parameters()).device
    inp_tok = tokenizer(prompt,
                        padding=True,
                        return_tensors="pt").to(device)
    input_ids = inp_tok["input_ids"]
    attention_mask = inp_tok["attention_mask"]

    # With zero tokens there is nothing to condition on, so start from the
    # tokenizer's beginning-of-text token instead.
    if input_ids.shape[-1] == 0:
        start_id = tokenizer.bos_token_id
        if start_id is None:
            start_id = tokenizer.eos_token_id
        input_ids = torch.full((input_ids.shape[0], 1), start_id,
                               dtype=torch.long, device=device)
        attention_mask = torch.ones_like(input_ids)

    output = model.generate(input_ids,
                            do_sample=True,
                            temperature=temperature,
                            attention_mask=attention_mask,
                            pad_token_id=tokenizer.eos_token_id,
                            max_new_tokens=n)
    return tokenizer.decode(output[0])

In [None]:
# Sample a continuation through model.generate (default temperature 0.9);
# n=128 bounds the length of the generated sequence.
# NOTE(review): the prompt is an empty string here — replace it with real text.
next_tokens_generation("",n=128)

In [None]:
# extract the embeddings for a sentence
def get_sentence_embeddings(sentence):
    """Compute one contextual embedding per token of a single sentence.

    Each token's embedding is the mean of its hidden states over the last
    four entries of the model's `hidden_states` output.

    Args:
        sentence: A single string. The reshaping below squeezes the batch
            axis, so a batch of several sentences is not supported.

    Returns:
        A tuple `(tokenized_text, input_ids, vectors)` where
        `tokenized_text` is the list of decoded, whitespace-stripped tokens,
        `input_ids` is the tokenizer's id tensor (still on the model's
        device), and `vectors` is a list with one NumPy array per token.
    """
    device = next(model.parameters()).device
    inp_tok = tokenizer(sentence,
                        padding=True,
                        return_tensors="pt").to(device)
    input_ids = inp_tok["input_ids"]

    # Inference only. Hidden states are requested explicitly so the function
    # does not depend on how the model happened to be configured at load time.
    with torch.no_grad():
        output = model(input_ids=input_ids,
                       attention_mask=inp_tok["attention_mask"],
                       output_hidden_states=True)

    # return tokenized text for indexing
    tokenized_text = [tokenizer.decode(token_id).strip()
                      for token_id in input_ids[0].tolist()]

    # Stack the per-layer hidden states, drop the batch axis and move the
    # token axis to the front, so iterating yields one (layers, dim) slab
    # per token.
    embs = torch.stack(output["hidden_states"], dim=0)
    embs = torch.squeeze(embs, dim=1)
    embs = embs.permute(1, 0, 2)

    # mean embeddings in the last four layers
    vectors = [torch.mean(token_layers[-4:], dim=0).detach().to('cpu').numpy()
               for token_layers in embs]

    return tokenized_text, input_ids, vectors


In [None]:
# Example sentences that each contain the word "seal" in a different context.
# The embedding cells below extract the vector at the "seal" token from every
# sentence so the contexts can be compared.
sentences = ["To carry a Line to haul some of the seal aboard",
            "As a proof, he subjoined Friedemann's letter and seal",
            "a seal, black cherry tree, balm of gilead tree",
            "His Father bad him go and fetch home two Kine to seal",
            "The house in which the salt works is carried on..is also called a seal",
            "In estimating the capacity of a tank and its corresponding holder, due allowance must be made for the height of the dip or seal"
]

In [None]:
# Tokenize " seal" (with its leading space) to inspect the token id(s) the
# tokenizer assigns to it.
tokenizer(" seal")

In [None]:
# Decode token id 13810 back to text to check which token it stands for.
tokenizer.decode([13810])

In [None]:
# Token id whose contextual embedding is collected from every sentence
# (presumably the id of " seal" — see the tokenizer lookup cells above).
tidx = 13810

# One embedding vector per sentence, taken at the first position where the
# token occurs. Looking positions up along the single sequence keeps this
# working when the token appears several times, and a missing token is
# reported with a clear error instead of an opaque indexing failure.
dists = list()
for sentence in sentences:
    _, ids, vectors = get_sentence_embeddings(sentence)
    positions = (ids[0] == tidx).nonzero(as_tuple=True)[0]
    if positions.numel() == 0:
        raise ValueError(f"token id {tidx} does not occur in sentence: {sentence!r}")
    dists.append(vectors[positions[0].item()])

In [None]:
from sklearn.decomposition import PCA
from matplotlib import pyplot as plt

# Select the style before any figure or axes exist: style settings are only
# picked up by objects created afterwards, so setting it later would leave
# the first run unstyled and only affect re-runs of the cell.
plt.style.use('ggplot')

# Project the collected embedding vectors onto their first two principal
# components.
pca = PCA(n_components = 2)
plot_data = pca.fit_transform(dists)

# extract x&y values
xs, ys = plot_data[:, 0], plot_data[:, 1]

# create labels
labels = ["seal"] * plot_data.shape[0]

# plot data
fig, ax = plt.subplots(figsize=(20, 15))
ax.set_title("Contextual Embeddings (PCA)")
ax.set_xlabel("PC 1")
ax.set_ylabel("PC 2")
ax.scatter(xs, ys, marker = '^')
for i, w in enumerate(labels):
    ax.annotate(w, xy = (xs[i], ys[i]), xytext = (3, 3), fontsize=14,
                textcoords = 'offset points',
                ha = 'left',
                va = 'top')
plt.show()