In [24]:
# %matplotlib inline

import os
import sys
sys.path.append('../examples')

from tqdm import trange

import torch
import torch.nn.functional as F
import torch.optim as optim
import numpy as np

from transformers import GPT2LMHeadModel, GPT2Tokenizer, GPT2Config

from generate_with_calibration import get_lookahead_entropies
# import calibrate as cal

In [3]:
def set_seed(seed=42, n_gpu=0):
    """Seed the NumPy and PyTorch random number generators.

    Args:
        seed: Integer seed applied to every generator.
        n_gpu: Number of CUDA devices in use. When positive, the CUDA
            generators are seeded as well.
    """
    np.random.seed(seed)
    torch.manual_seed(seed)
    if n_gpu > 0:
        # Use the `seed` argument; the earlier `args.seed` referred to a name
        # that is not defined in this function and raised NameError.
        torch.cuda.manual_seed_all(seed)

In [4]:
# Setup cell: choose a device, seed the RNGs, load GPT-2 and encode the prompt.

# Prefer CUDA when it is available; also record how many GPUs are visible.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
n_gpus = torch.cuda.device_count()

set_seed()

# Pretrained GPT-2 tokenizer and language model, moved to the chosen device
# and switched to evaluation mode.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2')
model.to(device)
model.eval()

MAX_LENGTH = 10000  # fallback cap used when the model reports no positional limit
length = 100

if length < 0:
    # A negative length means "as long as allowed": use the model's limit when
    # it reports one, otherwise fall back to MAX_LENGTH.
    if model.config.max_position_embeddings > 0:
        length = model.config.max_position_embeddings
    else:
        length = MAX_LENGTH
elif 0 < model.config.max_position_embeddings < length:
    # Never ask for more tokens than the model has positions for.
    length = model.config.max_position_embeddings

vocab_size = tokenizer.vocab_size

# Prompt used as the calibration context.
raw_text = (
    'In a shocking finding, scientist discovered a herd of unicorns living in a remote, '
    'previously unexplored valley, in the Andes Mountains. '
    'Even more surprising to the researchers was the fact that the unicorns spoke perfect English.'
)
context_tokens = tokenizer.encode(raw_text)

print(f'VOCAB SIZE: {vocab_size}')
print(f'DEVICE: {device}')
print(f'N GPUS: {n_gpus}')

# Names defined for later cells: model, vocab_size, length, device, n_gpus, context_tokens

11/28/2019 21:13:25 - INFO - transformers.tokenization_utils -   loading file https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json from cache at /home/michael/.cache/torch/transformers/f2808208f9bec2320371a9f5f891c184ae0b674ef866b79c58177067d15732dd.1512018be4ba4e8726e41b9145129dc30651ea4fec86aa61f4b9f40bf94eac71
11/28/2019 21:13:25 - INFO - transformers.tokenization_utils -   loading file https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt from cache at /home/michael/.cache/torch/transformers/d629f792e430b3c76a1291bb2766b0a047e36fae0588f9dbc1ae51decdff691b.70bec105b4158ed9a1747fea67a43f5dee97855c64d62b6ec3742f4cfdb5feda
11/28/2019 21:13:26 - INFO - transformers.configuration_utils -   loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json from cache at /home/michael/.cache/torch/transformers/4be02c5697d91738003fb1685c9872f284166aa32e061576bbe6aaeb95649fcf.085d5f6a8e7812ea05ff0e6ed0645ab2e75d80387ad55c1ad9806ee70d2

VOCAB SIZE: 50257
DEVICE: cpu
N GPUS: 0


In [22]:
def calibrate(model, context, batch_size, vocab_size, top_k=0, iters=1000,device='cpu'):
    """Fit the scalar calibration strength ``alpha`` on a token sequence.

    For every non-empty prefix ``context[:i]`` (``i = 1 .. N-1``) the model's
    next-token log-probabilities and the vector returned by
    ``get_lookahead_entropies`` are cached once, without autograd. The
    calibrated next-token distribution is

        p_alpha(w) = p(w) * exp(-alpha * H(w)) / Z(alpha),
        Z(alpha)   = sum_w p(w) * exp(-alpha * H(w)),

    and ``alpha`` is fit with SGD by minimising the negative log-likelihood of
    the observed next tokens ``context[i]`` under ``p_alpha``.

    NOTE(review): this assumes ``get_lookahead_entropies`` returns a 1-D tensor
    with one entropy per vocabulary entry, aligned with the model's logits --
    confirm against ``generate_with_calibration``.

    Two ``(N-1, vocab)`` float tensors are kept in CPU memory, because
    ``Z(alpha)`` must be re-evaluated over the whole vocabulary at every step.

    Args:
        model: Language model called as ``model(input_ids=...)``; its first
            output is indexed as ``[batch, position, vocab]`` logits.
        context: Sequence of token ids (must support slicing).
        batch_size: Batch size forwarded to ``get_lookahead_entropies``.
        vocab_size: Vocabulary size forwarded to ``get_lookahead_entropies``.
        top_k: Currently unused; kept so existing callers keep working.
        iters: Number of SGD steps.
        device: Device on which the model is evaluated.

    Returns:
        The optimised ``alpha`` as a CPU tensor of shape ``(1,)`` with
        ``requires_grad=True``.

    Raises:
        ValueError: If ``context`` has fewer than two tokens, so there is no
            (prefix, next token) pair to fit on.
    """
    N = len(context)
    if N < 2:
        raise ValueError('calibrate needs at least two context tokens')

    alpha = torch.randn(1, requires_grad=True)
    optimizer = optim.SGD([alpha], lr=0.001, momentum=0.9)

    def init_CEL():
        """Cache per-position next-token log-probs and lookahead entropies."""
        log_prob_rows = []
        entropy_rows = []
        # The cached values are constants with respect to alpha. Building them
        # without autograd keeps the model graph out of the loss, so
        # backward() can be called on every iteration.
        with torch.no_grad():
            # Start at 1: an empty prefix gives the model nothing to condition on.
            for i in trange(1, N):
                context_i = torch.tensor(context[:i], dtype = torch.long, device=device).unsqueeze(0)

                lookahead_entropies = get_lookahead_entropies(
                    model = model,
                    context = context_i[0],
                    batch_size = batch_size,
                    vocab_size = vocab_size,
                    device = device
                )

                inputs = {'input_ids': context_i}
                outputs = model(**inputs)
                # Drop the batch axis so the logits are a 1-D vocab vector.
                next_logits = outputs[0][0, -1, :]

                log_prob_rows.append(F.log_softmax(next_logits, dim=-1).float().cpu())
                entropy_rows.append(lookahead_entropies.detach().float().cpu())
        return torch.stack(log_prob_rows), torch.stack(entropy_rows)

    log_probs, entropies = init_CEL()

    # Row r of the caches was computed from context[:r + 1], so the token it
    # has to predict is context[r + 1].
    targets = torch.tensor(context[1:], dtype=torch.long).unsqueeze(1)
    target_log_probs = log_probs.gather(1, targets).squeeze(1)
    target_entropies = entropies.gather(1, targets).squeeze(1)

    def CEL():
        """Negative log-likelihood of the observed tokens under p_alpha."""
        # log Z(alpha), evaluated stably in log space over the full vocabulary.
        log_Za = torch.logsumexp(log_probs - alpha * entropies, dim=-1)
        return -torch.sum(target_log_probs - alpha * target_entropies - log_Za)

    for i in trange(iters):
        optimizer.zero_grad()
        loss = CEL()
        loss.backward()
        optimizer.step()

        # print statistics
        if i % 100 == 99:
            print(f'Loss at iter {i}: {loss.item():.4f}')

    return alpha

In [25]:
alpha = calibrate(
    model=model, 
    context=context_tokens, 
    batch_size=128, 
    vocab_size=vocab_size, 
    device=device)

  0%|          | 0/45 [00:00<?, ?it/s]

KeyboardInterrupt: 