<a href="https://colab.research.google.com/github/madscience101/creative-gpt2/blob/master/snippets/GPT2_ban_tokens_(by_index_or_letter).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Setup

In [0]:
!pip3 install pytorch-pretrained-bert

In [0]:
from pytorch_pretrained_bert import GPT2Tokenizer,GPT2LMHeadModel

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
# The GPT-2 weights are loaded once, in the "Init the model" section below.
# Loading them here as well only built a second copy of the model that the
# later `model = ...` assignment immediately replaced.

## Imports

In [0]:
!export PYTHONIOENCODING=UTF-8

In [0]:
import torch
from torch.autograd import Variable
import torch.nn.functional as F

from pytorch_pretrained_bert import GPT2LMHeadModel, GPT2Tokenizer, GPT2Model

import logging

import argparse
import logging
from tqdm import trange

import numpy as np
from collections import Counter

# Init the model

In [0]:
# Prefer the GPU when torch can see one; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Tokenizer for the pretrained 'gpt2' checkpoint: encodes prompts and decodes samples.
enc = GPT2Tokenizer.from_pretrained('gpt2')

In [0]:
# Load the pretrained GPT-2 LM head, move it to `device`, and switch it to eval
# mode. nn.Module.to() and .eval() both return the module, so the calls chain.
# Because the cell now ends in an assignment, Jupyter no longer prints the very
# long module repr that a bare trailing `model.eval()` produced.
model = GPT2LMHeadModel.from_pretrained('gpt2').to(device).eval()

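# Banning tokens

Before each sampling step, the logits of banned token ids are overwritten with -1e10. As long as some allowed token keeps a normal score, softmax then gives the banned ids (numerically) zero probability.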

In [0]:
def top_k_logits(logits, k):
    """
    Keep only the ``k`` largest entries along the last dimension of ``logits``.

    Every other entry is replaced by -1e10, a large negative stand-in for
    -infinity, so that after softmax (e^-1e10 -> 0) it contributes nothing to
    the denominator.

    Args:
        logits: float tensor of scores; top-k is taken over the last dimension
            (any number of leading dimensions is allowed).
        k: number of entries to keep. ``0`` disables filtering; a value larger
            than the last dimension keeps every entry.

    Returns:
        A new tensor shaped like ``logits`` (or ``logits`` itself when
        ``k == 0``). Entries tied with the k-th largest value are kept, since
        only values strictly below it are masked.
    """
    if k == 0:
        return logits
    # torch.topk raises if k exceeds the dimension size; clamp instead.
    k = min(k, logits.size(-1))
    # k-th largest value per row, kept as a size-1 last dim so it broadcasts.
    kth_values = torch.topk(logits, k)[0][..., -1:]
    return torch.where(logits < kth_values, torch.full_like(logits, -1e10), logits)

In [0]:
def get_sequence_masked(model, length, banned, context=None, temperature=1, top_k=50, device='cuda'):
    """
    Sample ``length`` tokens after ``context`` without ever emitting a banned id.

    Args:
        model: GPT-2 LM head, called as ``model(input_ids, past=past)`` and
            returning ``(logits, past)`` with logits indexed as
            (batch, position, vocab).
        length: number of new tokens to sample.
        banned: token ids to forbid -- a LongTensor, list, or 1-D int array.
        context: list of prompt token ids (required).
        temperature: logits are divided by this before masking/filtering.
        top_k: sample among the ``top_k`` highest-scoring *allowed* tokens
            (0 disables top-k filtering).
        device: device for the context and banned-id tensors; presumably the
            device ``model`` lives on -- confirm at the call site.

    Returns:
        LongTensor of shape (1, len(context) + length): the context ids
        followed by the sampled ids.

    Raises:
        AssertionError: if ``context`` is None.
    """
    assert context is not None, 'Specify the context (a list of prompt token ids)!'
    context = torch.tensor(context, device=device, dtype=torch.long).unsqueeze(0)
    banned = torch.as_tensor(banned, dtype=torch.long, device=device)

    prev = context
    output = context
    past = None
    with torch.no_grad():
        for _ in range(length):
            logits, past = model(prev, past=past)
            logits = logits[:, -1, :] / temperature

            # Ban *before* top-k. The other order breaks when every top-k
            # candidate is banned: all logits end up at -1e10, and softmax turns
            # into a uniform distribution over the whole vocabulary, banned
            # tokens included.
            logits = logits.index_fill(-1, banned, -1e10)
            logits = top_k_logits(logits, k=top_k)

            probs = F.softmax(logits, dim=-1)
            prev = torch.multinomial(probs, num_samples=1)
            output = torch.cat((output, prev), dim=1)
    return output

In [0]:
def generate_samples_masked(prompt, model, device, enc, banned_tokens=(50256,),
                            nsamples=1, length=200, temperature=1,
                            top_k=40, seed=0):
    """
    Print ``nsamples`` continuations of ``prompt`` with ``banned_tokens`` masked out.

    Args:
        prompt: text to condition on; must encode to at least one token.
        model: GPT-2 LM head, forwarded to ``get_sequence_masked``.
        device: torch device for the prompt and banned-id tensors.
        enc: tokenizer providing ``encode`` and ``decode``.
        banned_tokens: token ids to forbid -- any sequence, numpy array or
            tensor of ints (default ``(50256,)``).
        nsamples: number of samples to generate and print.
        length: number of tokens generated per sample.
        temperature, top_k: forwarded to ``get_sequence_masked``.
        seed: seeds numpy and torch (CPU and CUDA) once before sampling. The
            whole run is reproducible, while the individual samples still
            differ from each other.

    Raises:
        ValueError: if ``prompt`` encodes to no tokens.
    """
    # as_tensor accepts lists, tuples and int64 numpy arrays (e.g. the output of
    # tokens_with_letter) without going through the legacy LongTensor constructor.
    banned_idx = torch.as_tensor(banned_tokens, dtype=torch.long, device=device)

    np.random.seed(seed)
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    context_tokens = enc.encode(prompt)
    if not context_tokens:
        # An empty context would reach the model as a (1, 0) tensor.
        raise ValueError('prompt must encode to at least one token')

    for sample_num in range(1, nsamples + 1):
        out = get_sequence_masked(
            model=model, length=length,
            context=context_tokens,
            banned=banned_idx,
            temperature=temperature, top_k=top_k, device=device,
        )
        # Drop the echoed prompt; keep only the newly generated ids.
        new_tokens = out[0, len(context_tokens):].tolist()
        print("=" * 40 + " SAMPLE " + str(sample_num) + " " + "=" * 40)
        print(enc.decode(new_tokens))
    print("=" * 80)

## Test: ban tokens by id

In [0]:
# Token 50256: EndOfText.
global_banned_tokens = [50256]
# Tokens 198 and 628: "\n" and "\n\n".
local_banned_tokens = [198, 628]

In [0]:
# One new list holding both the global and the local bans, in that order.
banned_tokens = [*global_banned_tokens, *local_banned_tokens]

In [0]:
prompt = 'In the beginning'

# Three 30-token samples with the EndOfText and newline ids banned.
generate_samples_masked(
    prompt, model, device, enc,
    banned_tokens=banned_tokens,
    nsamples=3,
    length=30,
    seed=0,
)

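# Ban tokens with a specific letter

Decode every vocabulary id to its text, lowercase it, and ban each id whose text contains the chosen letter.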

In [0]:
# Decode every vocabulary id on its own so each token's text can be searched.
NUM_TOKENS = len(enc.encoder)
TOKENS = np.array([enc.decode([token_id]) for token_id in range(NUM_TOKENS)])
# Sanity check: one decoded string per id.
(len(TOKENS), NUM_TOKENS)

In [0]:
# Lowercased token texts, indexed by token id, for case-insensitive search.
TOKENS_LOWER = list(map(str.lower, TOKENS))

In [0]:
def tokens_with_letter(letter, tokens=None):
  """Return the ids of all vocabulary tokens whose text contains ``letter``.

  Matching is case-insensitive: both ``letter`` and each token are lowercased.

  Args:
    letter: non-empty string to search for (a single letter, or any substring).
    tokens: optional sequence of token strings indexed by token id; defaults
      to ``TOKENS_LOWER``.

  Returns:
    1-D numpy array of the matching token ids, in increasing order.

  Raises:
    ValueError: if ``letter`` is empty (it would match, and so ban, every token).
  """
  if not letter:
    raise ValueError('letter must be a non-empty string')
  if tokens is None:
    tokens = TOKENS_LOWER
  # Lowercase the needle too; otherwise 'T' never matches the lowercased tokens.
  needle = letter.lower()
  return np.flatnonzero([needle in token.lower() for token in tokens])

In [0]:
letter = 't'
banned_by_letter = tokens_with_letter(letter)

# Three 30-token samples with every token containing the letter banned.
# Only generated tokens are masked; the prompt itself may contain the letter.
prompt = 'In the beginning,'
generate_samples_masked(
    prompt, model, device, enc,
    banned_tokens=banned_by_letter,
    nsamples=3, length=30, seed=0,
)