In [2]:
# %matplotlib inline

import os
import sys
sys.path.append('../examples')

from tqdm import trange

import torch
import torch.nn.functional as F
import torch.optim as optim
import numpy as np

from transformers import GPT2LMHeadModel, GPT2Tokenizer, GPT2Config

from generate_with_calibration import get_lookahead_entropies, sample_sequence_calibrated

To use data.metrics please install scikit-learn. See https://scikit-learn.org/stable/index.html


In [3]:
def set_seed(seed=42, n_gpu=0):
    np.random.seed(seed)
    torch.manual_seed(seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

In [4]:
# setup cell 

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
n_gpus = torch.cuda.device_count()

set_seed()
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2')
model.to(device)
model.eval()

MAX_LENGTH = int(10000)
length = 100

if length < 0 and model.config.max_position_embeddings > 0:
    length = model.config.max_position_embeddings
elif 0 < model.config.max_position_embeddings < length:
    length = model.config.max_position_embeddings  # No generation bigger than model size 
elif length < 0:
    length = MAX_LENGTH 

vocab_size = tokenizer.vocab_size

print(f'VOCAB SIZE: {vocab_size}')
print(f'DEVICE: {device}')
print(f'N GPUS: {n_gpus}')

# variables set: model, vocab_size, length, device, n_gpus, context_tokens

11/29/2019 21:11:43 - INFO - transformers.tokenization_utils -   loading file https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json from cache at /u/myhu/.cache/torch/transformers/f2808208f9bec2320371a9f5f891c184ae0b674ef866b79c58177067d15732dd.1512018be4ba4e8726e41b9145129dc30651ea4fec86aa61f4b9f40bf94eac71
11/29/2019 21:11:43 - INFO - transformers.tokenization_utils -   loading file https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt from cache at /u/myhu/.cache/torch/transformers/d629f792e430b3c76a1291bb2766b0a047e36fae0588f9dbc1ae51decdff691b.70bec105b4158ed9a1747fea67a43f5dee97855c64d62b6ec3742f4cfdb5feda
11/29/2019 21:11:49 - INFO - transformers.configuration_utils -   loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json from cache at /u/myhu/.cache/torch/transformers/4be02c5697d91738003fb1685c9872f284166aa32e061576bbe6aaeb95649fcf.085d5f6a8e7812ea05ff0e6ed0645ab2e75d80387ad55c1ad9806ee70d272f80
11/29/2019 2

VOCAB SIZE: 50257
DEVICE: cuda
N GPUS: 2


In [17]:
# raw_text = 'In a shocking finding, scientist discovered a herd of unicorns living in a remote, previously unexplored valley, in the Andes Mountains. Even more surprising to the researchers was the fact that the unicorns spoke perfect English.'
# raw_text = 'I like.'

context_tokens = tokenizer.encode(raw_text)



In [18]:
def calibrate(model, context, vocab_size, batch_size = 512, top_k=0, iters=100000,device='cpu'):
    N = len(context)
    Pr = torch.zeros(N-1)
    H = torch.zeros(N-1)
    S = torch.zeros(N-1)

    alpha = torch.randn(1, requires_grad=True)
    optimizer = optim.SGD([alpha], lr=0.001, momentum=0.9)

    def init_CEL():
        # TODO: missing 1st word
        for i in trange(1, N):
            context_i = torch.tensor(context[:i], dtype = torch.long, device=device).unsqueeze(0)
            lookahead_entropies = get_lookahead_entropies(
                model = model,
                context = context_i[0],
                batch_size = batch_size,
                vocab_size = vocab_size,
                device = device
            )

            inputs = {'input_ids': context_i}
            outputs = model(**inputs)
            next_logits = outputs[0][:, -1, :]
            next_probs = F.softmax(next_logits, dim=-1)[0]
    
            # cache useful values
            next_word = context[i]
            Pr[i-1] = next_probs[next_word]
            H[i-1] = lookahead_entropies[next_word]
            S[i-1] = torch.dot(next_probs, torch.exp(lookahead_entropies))

    def CEL(a):
        Za = S * torch.exp(-a)
        return torch.sum(Pr * torch.exp(-a * H) / Za)
    
    init_CEL()
    loss = CEL(alpha)
    loss.backward()
    
    lastloss = loss.clone().item()
    for i in range(iters):
        optimizer.step()
        loss = CEL(alpha)
        
        # print(loss)
        # print(lastloss)
        
        if i % 10000 == 9999:    
            print(f'Loss at iter {i}: {loss}. Alpha: {alpha}.')
        
        if loss.item() - lastloss > 0:
            print(f'Stopping at iteration {i}. Alpha: {alpha}.')
            break
        
        lastloss = loss.clone().item()
         
    return alpha

In [19]:
alpha = calibrate(
    model=model, 
    context=context_tokens, 
    batch_size=128, 
    vocab_size=vocab_size,
    iters=100000,
    device=device)

100%|██████████| 44/44 [25:22<00:00, 60.79s/it]


Stopping at iteration 1862. Alpha: tensor([1.2687], requires_grad=True).


In [12]:
alpha

tensor([1.4898], requires_grad=True)