In [1]:
import os
import sys
sys.path.append('../examples')
sys.path.append('../jobs')
sys.path.append('../training_data')

from tqdm import trange

import torch
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt

from transformers import GPT2LMHeadModel, GPT2Tokenizer, GPT2Config
from generate_with_calibration import get_lookahead_entropies
from generate_with_entropy import sample_sequence, sample_sequence_batch

import logging
logging.getLogger('transformers.tokenization_utils').setLevel(logging.ERROR)

In [2]:
# setup cell

def set_seed(seed=42, n_gpu=0):
    np.random.seed(seed)
    torch.manual_seed(seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
n_gpus = torch.cuda.device_count()

set_seed()
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2')
model.to(device)
model.eval()

vocab_size = tokenizer.vocab_size

01/21/2020 22:35:12 - INFO - transformers.configuration_utils -   loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json from cache at /u/myhu/.cache/torch/transformers/4be02c5697d91738003fb1685c9872f284166aa32e061576bbe6aaeb95649fcf.085d5f6a8e7812ea05ff0e6ed0645ab2e75d80387ad55c1ad9806ee70d272f80
01/21/2020 22:35:12 - INFO - transformers.configuration_utils -   Model config {
  "attn_pdrop": 0.1,
  "embd_pdrop": 0.1,
  "finetuning_task": null,
  "initializer_range": 0.02,
  "is_decoder": false,
  "layer_norm_epsilon": 1e-05,
  "n_ctx": 1024,
  "n_embd": 768,
  "n_head": 12,
  "n_layer": 12,
  "n_positions": 1024,
  "num_labels": 1,
  "output_attentions": false,
  "output_hidden_states": false,
  "output_past": true,
  "pruned_heads": {},
  "resid_pdrop": 0.1,
  "summary_activation": null,
  "summary_first_dropout": 0.1,
  "summary_proj_to_labels": true,
  "summary_type": "cls_index",
  "summary_use_proj": true,
  "torchscript": false,
  "use_bf

In [5]:
def calibrate(model, tokenizer, path, save_path, vocab_size, batch_size=512, alpha=0.0, top_k=0, iters=10, threshold=1e-5, device='cpu'):
    alpha = torch.tensor([alpha], requires_grad=True)
    total_loss = CEL(model, tokenizer, path, alpha, vocab_size, batch_size, top_k, device)
    print(f'Total loss: {total_loss.item()}. Alpha: {alpha.item()}')
    last_alpha = alpha.item()
    
    for _ in range(iters):
        grad_a = torch.autograd.grad(total_loss, alpha, create_graph=True)
        grad2_a = torch.autograd.grad(grad_a, alpha)
        alpha.data -= (grad_a[0] / grad2_a[0]).data
        np.savez(save_path, alpha=alpha.item())
        
        total_loss = CEL(model, tokenizer, path, alpha, vocab_size, batch_size, top_k, device)
        print(f'Total loss: {total_loss.item()}. Alpha: {alpha.item()}')
        
        if abs(alpha.data - last_alpha) < threshold:
            break
            
        last_alpha = alpha.item()
    
    return alpha

def CEL(model, tokenizer, path, alpha, vocab_size, batch_size=512, top_k=0, device='cpu'):   
    # calculates the CEL on a single context.
    def CELHelper(context):
        N = len(context)
        context_CEL = torch.tensor([0.0])

        for i in range(1, N):
            with torch.no_grad():
                context_i = torch.tensor(context[:i], dtype = torch.long, device=device).unsqueeze(0)

                inputs = {'input_ids': context_i}
                next_logits = model(**inputs)[0][:, -1, :].detach().cpu()

                if top_k == 0:
                    candidates = None  
                else:
                    candidates = torch.argsort(next_logits[0], descending=True,)[:top_k]
                
                lookahead_ents = get_lookahead_entropies(
                    model = model,
                    context = context_i[0],
                    batch_size = batch_size,
                    vocab_size = vocab_size,
                    candidates = candidates,
                    device = device
                ).cpu()
                
                next_probs = F.softmax(next_logits, dim=-1)[0]
                
                if top_k != 0:
                    # replace uncomputed entropies with average (for centered adjustment)
                    next_probs = next_probs[lookahead_ents != -1]
                    top_average_ent = (lookahead_ents[lookahead_ents != -1] * next_probs / next_probs.sum()).sum()
                    lookahead_ents[lookahead_ents != -1] = top_average_ent
                    print(top_average_ent)
            
            # context[i] is the next word
            context_CEL -= torch.log(
                F.softmax(next_logits - alpha * lookahead_ents, dim=-1)[0][context[i]]
            )
        return context_CEL
    
    total_CEL = torch.tensor([0.0])

    with open(path) as fp:
        for line in fp:
            context = tokenizer.encode(line)
            total_CEL += CELHelper(context)
            
    return total_CEL

In [6]:
calibrate(model = model, 
          tokenizer = tokenizer, 
          path = '../training_data/gbw/training/news1-head100', 
          save_path = 'yeet.npz', 
          vocab_size = vocab_size, 
          batch_size=64, 
          alpha=0.0, top_k=64, iters=10, threshold=1e-5, device=device)

tensor(6.4961)
tensor(6.6620)
tensor(6.7481)
tensor(8.1192)
tensor(3.4060)
tensor(4.8028)
tensor(1.9325)
tensor(2.8117)
tensor(5.1670)
tensor(2.4363)
tensor(4.0686)
tensor(5.4590)
tensor(2.4045)
tensor(1.9969)
tensor(2.3931)
tensor(5.7640)
tensor(4.6259)
tensor(3.8761)
tensor(3.3797)
tensor(4.6152)
tensor(4.0812)
tensor(4.3290)
tensor(3.8643)
tensor(3.7930)
tensor(3.6306)
tensor(4.5981)
tensor(2.6192)
tensor(2.5462)
tensor(4.6567)
tensor(3.1148)
tensor(5.4159)
tensor(4.0128)
tensor(3.4288)
tensor(3.4389)
tensor(3.2748)
tensor(2.5306)
tensor(3.0973)
tensor(3.1128)
tensor(2.9960)
tensor(3.6017)
tensor(4.0000)
tensor(3.0622)
tensor(3.7164)
tensor(4.3095)
tensor(5.5681)
tensor(2.1556)
tensor(3.0182)
tensor(4.9075)


KeyboardInterrupt: 

In [3]:
def getTemp(model, tokenizer, path, vocab_size, batch_size=512, alpha=-0.0298, device='cpu'):
    
    def tempHelper(context):
        N = len(context)
        ret = []

        for i in range(1, N):
            context_i = torch.tensor(context[:i], dtype = torch.long, device=device).unsqueeze(0)

            inputs = {'input_ids': context_i}
            logits = model(**inputs)[0][:, -1, :].detach().cpu()
            
            lookahead_ents = get_lookahead_entropies(
                model = model,
                context = context_i[0],
                batch_size = batch_size,
                vocab_size = vocab_size,
                candidates = None,
                device = device
            ).cpu()
            
            temps = logits / (logits - alpha * lookahead_ents)
            next_probs = F.softmax(logits, dim=-1)

            tmp = np.average(temps, weights=next_probs)
            ret.append(tmp)
        
        print(f'TEMPS ON SUBCONTEXTS: {ret}')
        return np.mean(ret)
    
    temp = []
    with open(path) as fp:
        for line in fp:
            context = tokenizer.encode(line)
            temp.append(tempHelper(context))
            print(f'TEMPS: {temp}')
            np.savez('temps_cache', temp=temp)
    
    return np.mean(temp)

In [None]:
avg_temp = getTemp(model, 
        tokenizer, 
        path='../training_data/gbw/test/100_lines.txt', 
        vocab_size=vocab_size,
        alpha=0.0339,       
        batch_size=128, device=device)
print(avg_temp)

TEMPS ON SUBCONTEXTS: [0.9952983, 0.99895227, 0.99877185, 0.998509, 0.9987402, 0.99738127, 0.99888325, 0.9974514, 0.99884814, 0.99861395]
TEMPS: [0.998145]
TEMPS ON SUBCONTEXTS: [0.9936589, 0.99658066, 0.99761313, 0.99872345, 0.9989922, 0.99835134, 0.99789286, 0.9980367, 0.9981012, 0.99868804, 0.9981411, 0.99833083, 0.9986163, 0.9981435, 0.99874216, 0.9987142, 0.99827087, 0.99932826, 0.98206943, 0.99830616, 0.99862796, 0.9974344, 0.99855417, 0.9980082, 0.9987525, 0.9987006, 0.9979091, 0.9984489, 0.9988872, 0.9993035, 0.9990976, 0.99802727, 0.9993804, 0.9981861, 0.997812, 0.9996179, 0.9989182, 0.9982746, 0.99906623, 0.99859756, 0.9993165, 0.99834657, 0.99896353, 0.9984059, 0.99775684, 0.9990283, 0.99822646]
TEMPS: [0.998145, 0.99802023]
TEMPS ON SUBCONTEXTS: [0.9938946, 0.9975625, 0.99831814, 0.998715, 0.99805033, 0.9976702, 0.9979383, 0.99846315, 0.99825966]
TEMPS: [0.998145, 0.99802023, 0.99765235]
TEMPS ON SUBCONTEXTS: [0.9952983, 0.9983457, 0.99867594, 0.9977489, 0.9990582, 0.999298

In [2]:
temps = np.load('temps_cache.npz')['temp']

In [3]:
np.average(temps)

0.99814785