In [1]:
import os
import sys
sys.path.append('../examples')
sys.path.append('../jobs')
sys.path.append('../training_data')

from tqdm import trange

import torch
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt

from transformers import GPT2LMHeadModel, GPT2Tokenizer, GPT2Config
from generate_with_calibration import get_lookahead_entropies
from generate_with_entropy import sample_sequence, sample_sequence_batch

import logging
logging.getLogger('transformers.tokenization_utils').setLevel(logging.ERROR)

To use data.metrics please install scikit-learn. See https://scikit-learn.org/stable/index.html


In [2]:
# setup cell

def set_seed(seed=42, n_gpu=0):
    np.random.seed(seed)
    torch.manual_seed(seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
n_gpus = torch.cuda.device_count()

set_seed()
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2')
model.to(device)
model.eval()

vocab_size = tokenizer.vocab_size

12/29/2019 21:48:04 - INFO - transformers.configuration_utils -   loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json from cache at /u/myhu/.cache/torch/transformers/4be02c5697d91738003fb1685c9872f284166aa32e061576bbe6aaeb95649fcf.085d5f6a8e7812ea05ff0e6ed0645ab2e75d80387ad55c1ad9806ee70d272f80
12/29/2019 21:48:04 - INFO - transformers.configuration_utils -   Model config {
  "attn_pdrop": 0.1,
  "embd_pdrop": 0.1,
  "finetuning_task": null,
  "initializer_range": 0.02,
  "is_decoder": false,
  "layer_norm_epsilon": 1e-05,
  "n_ctx": 1024,
  "n_embd": 768,
  "n_head": 12,
  "n_layer": 12,
  "n_positions": 1024,
  "num_labels": 1,
  "output_attentions": false,
  "output_hidden_states": false,
  "output_past": true,
  "pruned_heads": {},
  "resid_pdrop": 0.1,
  "summary_activation": null,
  "summary_first_dropout": 0.1,
  "summary_proj_to_labels": true,
  "summary_type": "cls_index",
  "summary_use_proj": true,
  "torchscript": false,
  "use_bf

In [3]:
def calibrate(model, tokenizer, path, vocab_size, batch_size=512, alpha=0.0, top_k=0, iters=10, threshold=1e-5, device='cpu'):
    alpha = torch.tensor([alpha], requires_grad=True)
    total_loss = CEL(model, tokenizer, path, alpha, vocab_size, batch_size, top_k, device)
    print(f'Total loss: {total_loss.item()}. Alpha: {alpha.item()}')
    last_alpha = alpha.item()
    
    for _ in range(iters):
        grad_a = torch.autograd.grad(total_loss, alpha, create_graph=True)
        grad2_a = torch.autograd.grad(grad_a, alpha)
        alpha.data -= (grad_a[0] / grad2_a[0]).data
        
        total_loss = CEL(model, tokenizer, path, alpha, vocab_size, batch_size, top_k, device)
        print(f'Total loss: {total_loss.item()}. Alpha: {alpha.item()}')
        
        if abs(alpha.data - last_alpha) < threshold:
            break
            
        last_alpha = alpha.item()
    
    return alpha

def CEL(model, tokenizer, path, alpha, vocab_size, batch_size=512, top_k=0, device='cpu'):   
    # calculates the CEL on a single context.
    def CELHelper(context):
        N = len(context)
        context_CEL = torch.tensor([0.0])

        for i in range(1, N):
            context_i = torch.tensor(context[:i], dtype = torch.long, device=device).unsqueeze(0)

            inputs = {'input_ids': context_i}
            next_logits = model(**inputs)[0][:, -1, :].detach().cpu()

            if top_k == 0:
                candidates = None  
            else:
                candidates = torch.argsort(next_logits[0], descending=True,)[:top_k]
            
            lookahead_ents = get_lookahead_entropies(
                model = model,
                context = context_i[0],
                batch_size = batch_size,
                vocab_size = vocab_size,
                candidates = candidates,
                device = device
            ).cpu()
            
            if top_k != 0:
                # replace uncomputed entropies with average (for centered adjustment)
                top_average_ent = lookahead_ents[lookahead_ents != -1].mean()
                lookahead_ents[lookahead_ents != -1] = top_average_ent
            
            next_word = context[i]
            next_probs = F.softmax(next_logits, dim=-1)[0]
            Pr = next_probs[next_word]
            H = lookahead_ents[next_word]
            Za = torch.dot(next_probs, torch.exp(-alpha * lookahead_ents))

            context_CEL -= torch.log(Pr * torch.exp(-alpha * H) / Za)
            
        return context_CEL
    
    total_CEL = torch.tensor([0.0])

    with open(path) as fp:
        for line in fp:
            context = tokenizer.encode(line)
            # one way to fix memory issues: uncomment the below
            # if (len(context) > 100):
            #    continue
            total_CEL += CELHelper(context)
            
    return total_CEL

In [4]:
calibrate(model, 
          tokenizer, 
          path='../training_data/gbw/test/five_lines.txt', 
          vocab_size=vocab_size, 
          batch_size=128, top_k=128, iters=10, threshold=1e-6, device=device)

Total loss: 567.3639526367188. Alpha: 0.0
Total loss: 565.7002563476562. Alpha: 0.11386551707983017
tensor([0.1139])
Total loss: 565.6583251953125. Alpha: 0.13583789765834808
tensor([0.0220])
Total loss: 565.6582641601562. Alpha: 0.13663947582244873
tensor([0.0008])
Total loss: 565.6583251953125. Alpha: 0.13664056360721588
tensor([1.0878e-06])
Total loss: 565.6583251953125. Alpha: 0.1366405040025711
tensor([5.9605e-08])


tensor([0.1366], requires_grad=True)