# BERT as a conditional generative model

In [2]:
!pip install pytorch_pretrained_bert #installation of bert model



In [3]:
import numpy as np
import torch
from pytorch_pretrained_bert import BertTokenizer, BertModel, BertForMaskedLM

import nltk
nltk.download('wordnet')
from nltk.corpus import wordnet 

[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


Load the pre-trained BERT model

In [0]:
# Checkpoint name and GPU availability, decided up front.
model_version = 'bert-base-uncased'
cuda = torch.cuda.is_available()

# Pre-trained masked-LM weights, put into evaluation mode.
model = BertForMaskedLM.from_pretrained(model_version)
model.eval()
if cuda:
    # Move the model to the GPU when one is available.
    model = model.cuda()

# Matching tokenizer (vocabulary); lower-casing is enabled only for "uncased" checkpoints.
tokenizer = BertTokenizer.from_pretrained(
    model_version,
    do_lower_case=model_version.endswith("uncased"),
)

def tokenize_batch(batch):
    """Convert every token sequence in `batch` to ids with the module-level tokenizer."""
    id_batch = []
    for tokens in batch:
        id_batch.append(tokenizer.convert_tokens_to_ids(tokens))
    return id_batch

def untokenize_batch(batch):
    """Convert every id sequence in `batch` back to tokens with the module-level tokenizer."""
    token_batch = []
    for ids in batch:
        token_batch.append(tokenizer.convert_ids_to_tokens(ids))
    return token_batch

def detokenize(sent):
    """ Roughly detokenizes (mainly undoes wordpiece)

    A token starting with "##" is appended, without the marker, to the token
    that precedes it; every other token starts a new word.

    args:
        - sent (list of str): tokens, possibly containing "##"-prefixed pieces

    returns:
        a new list of merged tokens; `sent` itself is not modified
    """
    new_sent = []
    for tok in sent:
        if tok.startswith("##"):
            piece = tok[2:]
            if new_sent:
                new_sent[-1] += piece
            else:
                # A continuation piece at the very start has nothing to attach
                # to; keep it as its own word instead of raising IndexError.
                new_sent.append(piece)
        else:
            new_sent.append(tok)
    return new_sent

# Special marker tokens and their vocabulary ids (looked up in a single call).
CLS = '[CLS]'
SEP = '[SEP]'
MASK = '[MASK]'
mask_id, sep_id, cls_id = tokenizer.convert_tokens_to_ids([MASK, SEP, CLS])

Basic generation step and printer (built-in)

Generating the initial masked string with the given word at a random position (modified from the given function)

In [0]:
def generate_step(out, gen_idx, temperature=None, top_k=0, sample=False, return_list=True):
    """ Generate a word from out[gen_idx]

    args:
        - out (torch.Tensor): tensor of logits of size batch_size x seq_len x vocab_size
        - gen_idx (int): location for which to generate for
        - temperature (float or None): if not None, the logits are divided by it before selection
        - top_k (int): if >0, only sample from the top k most probable words
        - sample (Bool): if True, sample from full distribution. Overridden by top_k
        - return_list (Bool): if True return a Python list, otherwise return the index tensor

    returns:
        one selected vocabulary index per batch element (always batch_size entries)
    """
    logits = out[:, gen_idx]
    if temperature is not None:
        logits = logits / temperature
    if top_k > 0:
        kth_vals, kth_idx = logits.topk(top_k, dim=-1)
        dist = torch.distributions.categorical.Categorical(logits=kth_vals)
        # The sample is a position inside the top-k list; map it back to a vocabulary id.
        idx = kth_idx.gather(dim=1, index=dist.sample().unsqueeze(-1)).squeeze(-1)
    elif sample:
        dist = torch.distributions.categorical.Categorical(logits=logits)
        # sample() already has shape (batch_size,). Squeezing it would collapse a
        # batch of one into a 0-dim tensor, making tolist() return a bare int.
        idx = dist.sample()
    else:
        idx = torch.argmax(logits, dim=-1)
    return idx.tolist() if return_list else idx
  
  
def get_init_text(seed_text,rand_kk, max_len,conditional_word, batch_size = 1, rand_init=False):
    """ Get initial sentences: seed_text, then max_len [MASK] slots with conditional_word in one of them, then [SEP]

    args:
        - seed_text (list of str): prefix tokens placed at the start of every sentence
        - rand_kk (int): offset after the seed (0 <= rand_kk < max_len) at which conditional_word is placed
        - max_len (int): number of positions that follow the seed
        - conditional_word (str): token the generated sentence is built around
        - batch_size (int): number of identical initial sentences to build
        - rand_init (Bool): currently unused; kept so existing callers keep working

    returns:
        the batch converted to ids by tokenize_batch

    raises:
        ValueError: if rand_kk is outside [0, max_len)
    """
    # An out-of-range offset would overwrite [SEP] (rand_kk == max_len) or wrap
    # around into the seed (negative), so reject it explicitly.
    if not 0 <= rand_kk < max_len:
        raise ValueError("rand_kk must satisfy 0 <= rand_kk < max_len, got %r (max_len=%r)" % (rand_kk, max_len))
    seed_len = len(seed_text)
    sentence = seed_text + [MASK] * max_len + [SEP]
    sentence[seed_len + rand_kk] = conditional_word  # fix the conditioning word in its slot
    # Each batch element gets its own copy so later in-place edits stay independent.
    batch = [list(sentence) for _ in range(batch_size)]
    return tokenize_batch(batch)

def printer(sent, should_detokenize=True):
    """Print `sent` as one space-joined line.

    With should_detokenize set, the tokens are first passed through
    detokenize and the first and last tokens are dropped.
    """
    tokens = detokenize(sent)[1:-1] if should_detokenize else sent
    print(" ".join(tokens))


Below is the modified parallel-sequential generator:
- the initial sequence is generated by calling `get_init_text`
- masked words are only resampled at positions other than the one holding our conditional word
- the position of the conditional word is never re-masked, so the word stays in place through every iteration

In [0]:
import math
import time

def parallel_sequential_generation(seed_text,conditional_word, batch_size=10, max_len=15, top_k=0, temperature=None, max_iter=300, burnin=200,
                                   cuda=False, print_every=10, verbose=True):
    """ Generate for one random position at a timestep, keeping conditional_word fixed

    args:
        - seed_text (list of str): prefix tokens at the start of every sentence
        - conditional_word (str): token placed at one random position and never resampled
        - batch_size (int): number of sentences generated in parallel
        - max_len (int): number of positions after the seed
        - top_k (int): passed to generate_step once the burn-in period is over
        - temperature (float or None): passed to generate_step
        - max_iter (int): number of iterations; iterations that draw the conditional word's position do nothing
        - burnin: during burn-in period, sample from full distribution; afterwards sample
          from the top_k candidates (argmax when top_k is 0)
        - cuda (Bool): if True, move the input tensor to the GPU before the forward pass
        - print_every (int): with verbose, print the first sentence every print_every iterations
        - verbose (Bool): enable the progress printing

    returns:
        the generated batch converted back to tokens by untokenize_batch
    """
    rand_kk = np.random.randint(0,max_len)
    seed_len = len(seed_text)
    batch = get_init_text(seed_text,rand_kk, max_len,conditional_word, batch_size)

    # Pure inference: model.eval() does not switch off autograd, so disable it
    # here to avoid building a graph on every forward pass.
    with torch.no_grad():
        for ii in range(max_iter):
            kk = np.random.randint(0, max_len)
            if kk == rand_kk:
                # This slot holds the conditional word; leave it untouched.
                continue
            pos = seed_len + kk
            for jj in range(batch_size):
                batch[jj][pos] = mask_id
            inp = torch.tensor(batch).cuda() if cuda else torch.tensor(batch)
            out = model(inp)
            topk = top_k if (ii >= burnin) else 0
            idxs = generate_step(out, gen_idx=pos, top_k=topk, temperature=temperature, sample=(ii < burnin))
            for jj in range(batch_size):
                batch[jj][pos] = idxs[jj]

            if verbose and np.mod(ii+1, print_every) == 0:
                for_print = tokenizer.convert_ids_to_tokens(batch[0])
                # '(*)' marks the position that was just resampled.
                for_print = for_print[:pos+1] + ['(*)'] + for_print[pos+1:]
                print("iter", ii+1, " ".join(for_print))

    return untokenize_batch(batch)

def generate(n_samples,conditional_word, seed_text="[CLS]", batch_size=10, max_len=25, 
             generation_mode="parallel-sequential",
             sample=True, top_k=100, temperature=1.0, burnin=200, max_iter=500,
             cuda=False, print_every=1):
    """ Main generation function to call: generate sentences batch by batch

    args:
        - n_samples (int): requested number of sentences; ceil(n_samples / batch_size) full
          batches are generated, so more than n_samples sentences may be returned
        - conditional_word (str): token every sentence is built around
        - seed_text: prefix passed through to parallel_sequential_generation
        - batch_size (int): sentences generated per batch
        - max_len (int): number of positions after the seed
        - generation_mode (str): only "parallel-sequential" is supported
        - sample (Bool): currently unused; kept for interface compatibility
        - top_k, temperature, burnin, max_iter, cuda: passed through to
          parallel_sequential_generation
        - print_every (int): print the elapsed time every print_every batches

    returns:
        list with the sentences of all generated batches

    raises:
        ValueError: if generation_mode is not supported
    """
    # Fail early and clearly; otherwise `batch` would be unbound inside the loop.
    if generation_mode != "parallel-sequential":
        raise ValueError("unsupported generation_mode: %r" % (generation_mode,))

    sentences = []
    n_batches = math.ceil(n_samples / batch_size)
    start_time = time.time()
    for batch_n in range(n_batches):
        batch = parallel_sequential_generation(seed_text,conditional_word, batch_size=batch_size, max_len=max_len, top_k=top_k,
                                               temperature=temperature, burnin=burnin, max_iter=max_iter, 
                                               cuda=cuda, verbose=False)
        if (batch_n + 1) % print_every == 0:
            print("Finished batch %d in %.3fs" % (batch_n + 1, time.time() - start_time))
            start_time = time.time()

        sentences += batch
    return sentences

Generating statements of length 10 (you can modify the parameters according to your requirements).
Some of the parameters are:
- batch_size : number of sentences generated per batch for every conditional word
- top_k : after burn-in, each word is sampled from the top_k most likely candidates
- max_iter : number of sampling iterations run through BERT

In [0]:
def generate_statements(synonyms, n_samples=1, batch_size=2, max_len=10, top_k=10,
                        temperature=1.0, burnin=250, max_iter=500):
  """ Generate sentences around every word of `synonyms` that is a vocabulary token

  args:
      - synonyms (iterable of str): candidate conditional words
      - n_samples, batch_size, max_len, top_k, temperature, burnin, max_iter:
        passed through to generate; the defaults are the previously hard-coded values

  returns:
      list with one entry per usable word, each entry being the list of
      sentences returned by generate for that word
  """
  generation_mode = "parallel-sequential"
  sample = True
  # Choose the prefix context
  seed_text = "[CLS]".split()

  sents = []
  for conditional_word in synonyms:
    # Only a word that is itself a vocabulary token can be placed in the sequence.
    # Checking this up front replaces the former bare `except:`, which reported
    # every failure, whatever its cause, as a missing token.
    if conditional_word not in tokenizer.vocab:
      print("no token for word",conditional_word)
      continue
    bert_sents = generate(n_samples,conditional_word, seed_text=seed_text, batch_size=batch_size, max_len=max_len,
                          generation_mode=generation_mode,
                          sample=sample, top_k=top_k, temperature=temperature, burnin=burnin, max_iter=max_iter,
                          cuda=cuda)
    sents.append(bert_sents)

  return sents

How to run:
- provide an input word for which you want to generate statements
- the notebook first finds the synonyms of that word
- then, for each synonym, it generates sentences

In [8]:
# Read the word of interest and gather the distinct lemma names of all its synsets.
input_word = input()
synonyms = list({
    lemma.name()
    for synset in wordnet.synsets(input_word)
    for lemma in synset.lemmas()
})
print(set(synonyms))

# Generate sentences for every collected synonym.
sents = generate_statements(synonyms)

sadness
{'sadness', 'sorrow', 'lugubriousness', 'unhappiness', 'sorrowfulness', 'gloominess'}
Finished batch 1 in 8.262s
Finished batch 1 in 7.993s
no token for word lugubriousness
no token for word unhappiness
no token for word sorrowfulness
no token for word gloominess


Printing of generated sentences

In [9]:
# Header, then every generated sentence with a running number in front.
print("generated sentences for word",input_word)
print()
i = 1
for sent in (s for bert_sents in sents for s in bert_sents):
    print(i, ")", sep="", end=" ")
    printer(sent, should_detokenize=True)
    i += 1

generated sentences for word sadness

1) he now was having a moment of deep sadness .
2) baer ' s brown eyes held deep sadness .
3) death and sorrow . death and sorrow . sorrow .
4) the monk met gray ' s sorrowful look .


Some more output statements:

1. generated sentences for good
    - " but are they beneficial ? " he said .
    - " not exactly mutually beneficial , " he said .
    - the award - winning book gained critical acclaim as well
    - it could never work out like this again . well
    - but then again , it was only partially effective .
    - very strong , very agile , and very effective .


2. generated sentences for happy

    - well , glad to hear what was going on .
    - " very glad of it , " she said .
    - this was his life . hell , he was happy
    - " put your clothes back on , happy and happy
    
    
3. generated sentences for angry

    - chapter 15 : a raging storm . gunfire erupted somewhere - between the good guys and the bad guys .
    - he was still a raging monster , but now an image on the wall of monitors across the room .
    - meg starred in her first solo television series , with robert altman and pauline kael . born furious
    - in 2014 there were performances of the song in copenhagen ( september ) and london ( february 2014 ) furious
    - all right , i was kissing him , wild , wild kisses , those kisses , all untold .
    - " anything about that ... ? anything about killing any wild animals in the woods ? " i asked .
    - he heard her coming up for air and mashing her eyes shut . but he was so angry .
    - " hello " | 1981 hello world | " hello " | " hello world " | 1979 angry !

More examples are in the README file, or you can try giving different inputs such as
(fear, horror, envy, beautiful, etc.)