In [None]:
!pip install transformers

In [41]:
import torch
import numpy as np
import math

from transformers import GPT2Tokenizer, GPT2LMHeadModel

# Use the GPU when torch can see one; otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [5]:
# Locations of the raw WikiText-2 train/test splits.
# NOTE(review): the paths live under /content/drive, but no cell visible here mounts that location — confirm.
TRAIN_FILE_PATH = "/content/drive/My Drive/wikitext-2-raw/wiki.train.raw"
TEST_FILE_PATH = "/content/drive/My Drive/wikitext-2-raw/wiki.test.raw"

# Number of leading characters kept in each truncated ".short" copy.
TRAIN_SHORT_CHARS = 1000000
TEST_SHORT_CHARS = 500000

# Read inside `with` blocks so the file handles are closed immediately, and fix the
# encoding explicitly so the result does not depend on the platform default.
with open(TRAIN_FILE_PATH, 'r', encoding='utf-8') as f:
    text_train = f.read()

with open(TEST_FILE_PATH, 'r', encoding='utf-8') as f:
    text_test = f.read()

# Write the truncated copies next to the originals.
with open(TRAIN_FILE_PATH + ".short", "w", encoding='utf-8') as f:
    f.write(text_train[:TRAIN_SHORT_CHARS])

with open(TEST_FILE_PATH + ".short", "w", encoding='utf-8') as f:
    f.write(text_test[:TEST_SHORT_CHARS])

In [None]:
# Load the pretrained 'gpt2-medium' tokenizer and LM-head model; the model is moved to `device` right away.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2-medium')
model = GPT2LMHeadModel.from_pretrained('gpt2-medium').to(device)

In [None]:
# Download the Hugging Face language-modeling example script into the working directory,
# then list the local .py files to confirm it arrived.
# NOTE(review): the URL tracks the `master` branch, so the script is not pinned to a release — consider a tagged URL.
!wget https://raw.githubusercontent.com/huggingface/transformers/master/examples/language-modeling/run_language_modeling.py
!ls -l *.py

In [None]:
!python run_language_modeling.py \
    --output_dir=output \
    --model_type=gpt2 \
    --model_name_or_path=gpt2 \
    --do_train \
    --train_data_file=$"/content/drive/My Drive/wikitext-2-raw/wiki.train.raw.short" \
    --do_eval \
    --eval_data_file=$"/content/drive/My Drive/wikitext-2-raw/wiki.test.raw.short"

In [None]:
def choose_from_top(probs, n=5):
    """Sample one token id from the `n` highest-probability entries of `probs`.

    The `n` largest values are selected, renormalised so they sum to 1, and a
    single entry is drawn from that truncated distribution with
    `np.random.choice` (global NumPy random state).

    Args:
        probs: array-like of per-token probabilities (assumed 1-D, indexed by token id).
        n: how many top candidates to sample from. Values larger than
            `len(probs)` are clamped instead of raising inside `argpartition`.

    Returns:
        A tuple `(token_id, prob)`: `token_id` is a Python int, `prob` is a
        length-1 array holding the renormalised probability of the drawn token.

    Raises:
        ValueError: if `probs` is empty or `n` is smaller than 1.
    """
    probs = np.asarray(probs)
    n = min(n, len(probs))  # argpartition rejects an n larger than the array
    if n < 1:
        raise ValueError("choose_from_top needs a non-empty probs array and n >= 1")

    # argpartition places the n largest entries (in no particular order) in the last n slots.
    ind = np.argpartition(probs, -n)[-n:]
    top_prob = probs[ind]
    top_prob = top_prob / np.sum(top_prob)  # renormalise over the kept candidates
    choice = np.random.choice(n, 1, p=top_prob)
    token_id = ind[choice][0]
    return int(token_id), top_prob[choice]

In [None]:
def generate_some_text(input_str, text_len = 100):
    """Generate a continuation of `input_str` by top-5 sampling and print it.

    Each step runs the global `model` on all tokens so far, takes the softmax
    of the last position's logits, and draws the next token with
    `choose_from_top(..., n=5)`. Generation stops after `text_len` sampled
    tokens or as soon as the tokenizer's end-of-text token is drawn (that
    token is not appended to the output).

    Two lines are printed: the decoded text, and `perplexity= <value>`. The
    value is 2 ** (-mean log2 p) over the sampled tokens, where p is the
    probability returned by `choose_from_top`. It therefore describes the
    truncated sampling distribution, not the model's full softmax. `nan` is
    printed when no token was sampled (`text_len` of 0).

    Args:
        input_str: prompt text; must encode to at least one token.
        text_len: maximum number of tokens to sample.

    Returns:
        None; the results are printed.

    Raises:
        ValueError: if `input_str` encodes to no tokens.
    """
    token_ids = list(tokenizer.encode(input_str))
    if not token_ids:
        raise ValueError("input_str must encode to at least one token")

    eos_id = tokenizer.eos_token_id  # looked up once instead of re-encoding the EOS string every step
    sampled = 0
    log2_prob_sum = 0.0

    model.eval()
    with torch.no_grad():
        for _ in range(text_len):
            cur_ids = torch.LongTensor(token_ids).to(device)
            # No `labels` argument: only the logits are needed, so the unused loss is not computed.
            logits = model(cur_ids)[0]
            # assumes the logits for this unbatched 1-D input are (seq_len, vocab), so [-1] is the last position
            probs = torch.softmax(logits[-1], dim=0)
            next_token_id, prob = choose_from_top(probs.cpu().numpy(), n=5)
            sampled += 1
            # choose_from_top returns a length-1 array; reduce it to a scalar so the total stays a float
            log2_prob_sum += float(np.log2(np.asarray(prob).reshape(-1)[0]))
            if next_token_id == eos_id:  # end-of-text token drawn: stop
                break
            token_ids.append(next_token_id)

    print(tokenizer.decode(token_ids))
    perplexity = 2.0 ** (-log2_prob_sum / sampled) if sampled else float('nan')
    print('perplexity=', perplexity)

In [None]:
# Sample a continuation of this prompt; generate_some_text prints the text followed by a perplexity value.
generate_some_text("The rain was unexpectedly warm")

In [25]:
def count_perplexity(encodings):
    """Return the perplexity the global `model` assigns to an encoded text.

    The model is called with the input ids as both input and labels, and the
    first element of its output (the loss) is exponentiated with `math.exp`.

    Args:
        encodings: object exposing an `input_ids` tensor, e.g. the result of
            `tokenizer(text, return_tensors='pt')`.

    Returns:
        The perplexity as a Python float.
    """
    input_ids = encodings.input_ids.to(device)

    model.eval()  # same inference-mode setup as generate_some_text
    with torch.no_grad():
        loss = model(input_ids, labels=input_ids)[0]

    # .item() makes the tensor -> float conversion explicit before exponentiating.
    return math.exp(loss.item())

In [46]:
# Groups of near-identical sentences that differ in a few words; the loop prints the model perplexity of each one.
# The list and loop variable were previously named `input` and `str`, which shadowed the builtins for every later cell.
# NOTE(review): 'winte.' in the "birds" sentence looks like a typo for 'winter.' — it is left as-is so the
# results stay comparable with earlier runs; confirm whether it is intentional.
sentences = [
    'The moon is made of chocolate.',
    'The moon is made of cheese.',
    'The moon is made of oxygen and silicon.',
    'Lions live in forests and eat berries.',
    'Lions live in cities and eat hoofed mammals.',
    'Lions live in savannas and eat hoofed mammals.',
    'All summer and autumn birds store fat to hibernate for the winte.',
    'All summer and autumn bears store fat to hibernate for the winter.',
    'Caterpillars run fast and see in the dark.',
    'Cats run fast and see in the dark.',
    'Dolphins are predators living in meadows.',
    'Dolphins are predators living in seas.',
    'Dolphins are predators living in the mountains.',
]
for sentence in sentences:
    tokens = tokenizer(sentence, return_tensors='pt')
    result = count_perplexity(tokens)
    print('perplexity of (', sentence, ') =', result)

perplexity of ( The moon is made of chocolate. ) = 70.63049732650195
perplexity of ( The moon is made of cheese. ) = 49.780604325072204
perplexity of ( The moon is made of oxygen and silicon. ) = 42.888554838426735
perplexity of ( Lions live in forests and eat berries. ) = 178.15567000019414
perplexity of ( Lions live in cities and eat hoofed mammals. ) = 115.49040691303497
perplexity of ( Lions live in savannas and eat hoofed mammals. ) = 56.169676999696016
perplexity of ( All summer and autumn birds store fat to hibernate for the winte. ) = 185.25426346448492
perplexity of ( All summer and autumn bears store fat to hibernate for the winter. ) = 67.57462327582572
perplexity of ( Caterpillars run fast and see in the dark. ) = 34.48157288879732
perplexity of ( Cats run fast and see in the dark. ) = 67.93385217505953
perplexity of ( Dolphins are predators living in meadows. ) = 217.17167948787576
perplexity of ( Dolphins are predators living in seas. ) = 304.0794363171705
perplexity of (