# Build a bidirectional text generator with XLNet  
In this notebook, we will
- Learn how to predict masked tokens with a pretrained XLNet model and tokenizer from Hugging Face
- Generate text bidirectionally from an original text by predicting masks on both sides  EX) '<mask\> just ate five <mask\>' -> 'I just ate five cookies'
- Use a beam-based algorithm to improve the quality of right-to-left generation

original source (https://towardsdatascience.com/build-a-bidirectional-text-generator-with-xlnet-49d9d37b48a9)


# Background Information

- GPT-2 and GPT-3 are great at text generation, but they are unidirectional (left-to-right)
- XLNet is able to generate text in both directions by predicting mask tokens

## XLNet
- Paper by Google Brain and CMU (https://arxiv.org/pdf/1906.08237.pdf)
- Easy explanation (https://towardsdatascience.com/xlnet-explained-in-simple-terms-255b9fb2c97c)
- Fixes BERT's main problems
    1. BERT corrupts the input with <mask\> tokens during pretraining, but these tokens never appear during fine-tuning
    2. BERT neglects the dependency between masked positions  
        ex) Input: "New York is a city", masked input: "(mask) (mask) is a city"  
            BERT tries to maximize log p(New | is a city) + log p(York | is a city)
            There is no dependency between "New" and "York", so it could make a weird prediction like "New Francisco is a city"

- Outperforms BERT on 20 tasks
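Point 2 is easier to see with numbers. The sketch below is a toy example that uses made-up probabilities for the two masked words in "(mask) (mask) is a city". Predicting each mask from its own marginal distribution (the BERT-style independence) can pick a pair that never occurs together. Predicting the second word conditioned on the first cannot make that mistake. The `joint` table and the `marginal` helper are hypothetical and are not taken from any model.

```python
# Toy joint distribution over the two masked words, given "is a city".
# The numbers are made up purely for illustration.
joint = {
    ("New", "York"): 0.30,
    ("New", "Orleans"): 0.25,
    ("San", "Francisco"): 0.45,
}

def marginal(joint, position):
    """Sum the joint probability over the other masked word."""
    out = {}
    for pair, p in joint.items():
        out[pair[position]] = out.get(pair[position], 0.0) + p
    return out

# BERT-style: each mask is filled independently from its own marginal
first = marginal(joint, 0)   # New: 0.55, San: 0.45
second = marginal(joint, 1)  # York: 0.30, Orleans: 0.25, Francisco: 0.45
independent = (max(first, key=first.get), max(second, key=second.get))

# Autoregressive: pick the first word, then the second conditioned on it
w1 = max(first, key=first.get)
cond = {w2: p for (a, w2), p in joint.items() if a == w1}
autoregressive = (w1, max(cond, key=cond.get))

print("independent:   ", independent)     # ('New', 'Francisco'), never seen together
print("autoregressive:", autoregressive)  # ('New', 'York')
```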


### Main Idea
- Integrates the ideas of autoregressive models (ex. GPT) and bidirectional context modeling (ex. BERT)

- Generalized **autoregressive** model, where each token is predicted from all tokens that come before it in the factorization order!

- Permutation Language Modeling (PLM): considers all permutations of the factorization order of the input sequence
    - Maximize the expected log likelihood over all possible permutations of the factorization order

    - Across permutations, each position learns to use contextual information from all positions  

    - No mask needed: just ignore the words that come after the target word in the permutation order (just like a Transformer decoder, but in permuted order)  
    
    - **Therefore captures bidirectional context without masks!**
    
    - EX) Input sentence: "New(1) York(2) is(3) a(4) city(5)", and the target is the 3rd word, "is"
        - Possible Permutations  
        12 **3** 45: P(is | New York)  
        2 **3** 145: P(is | York)  
        54 **3** 12: P(is | a city) *example of right-to-left context*  
        
        ![image](https://miro.medium.com/max/1050/1*dMgzP_YboxpR8VXuGeAg_Q.png)

    - Captures dependencies, since it is an **autoregressive** model!
        - EX) Current permutation \[is a city New York]  
        
            XLNet, being an autoregressive model, predicts the tokens in the order of this permutation  
            
            Computes: log p (New | is a city) + log p (York | New, is a city)
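Formally, the paper's PLM objective averages the autoregressive log likelihood over the set $\mathcal{Z}_T$ of all permutations $\mathbf{z}$ of a length-$T$ sequence:

$$\max_\theta\;\mathbb{E}_{\mathbf{z}\sim\mathcal{Z}_T}\left[\sum_{t=1}^{T}\log p_\theta\left(x_{z_t}\mid\mathbf{x}_{\mathbf{z}_{<t}}\right)\right]$$

The sketch below is plain Python and does not use XLNet. It lists which words each position is conditioned on under a given factorization order, and it reproduces the examples above. XLNet itself keeps the tokens in their original positions and implements the order through attention masks. This code only models the bookkeeping. The helper `contexts` is a hypothetical name used for illustration.

```python
from itertools import permutations

def contexts(tokens, order):
    """For a factorization order (0-based positions), return, for every position,
    the words it is conditioned on: those that come earlier in the order."""
    result = {}
    for step, pos in enumerate(order):
        result[pos] = [tokens[p] for p in sorted(order[:step])]
    return result

tokens = ["New", "York", "is", "a", "city"]

# The three orders from the example above ("12345", "23145", "54312"), 0-based
for order in ([0, 1, 2, 3, 4], [1, 2, 0, 3, 4], [4, 3, 2, 0, 1]):
    print(order, "-> P(is |", " ".join(contexts(tokens, order)[2]) + ")")

# The order [is a city New York] from the autoregressive example
ctx = contexts(tokens, [2, 3, 4, 0, 1])
print("New  |", ctx[0])  # ['is', 'a', 'city']
print("York |", ctx[1])  # ['New', 'is', 'a', 'city']

# Over all 5! = 120 orders, "is" gets conditioned on every subset of the other 4 words
seen = {tuple(contexts(tokens, list(o))[2]) for o in permutations(range(5))}
print(len(seen))  # 16 == 2 ** 4
```

The last line is the point of PLM. In expectation over orders, the word "is" sees context from both its left and its right.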


# 1. Install needed modules
We need the transformers library, which provides a simple interface to XLNet, plus PyTorch. XLNetTokenizer also requires the sentencepiece package

In [None]:
#!pip install transformers sentencepiece

# 2. Example of masked words prediction with XLNet
XLNet can predict several related masked words while taking their context into account  
Ex)  “<mask\> have <mask\> apples in hands” -> predict what should go in the masks!

Let's load a tokenizer, which converts raw text into token IDs  

XLNet uses the SentencePiece method

SentencePiece
- Idea: not all languages separate words with spaces (in Chinese, for example, the words in a sentence aren't separated by spaces)
- Solution
    1. Treat the input as a raw character stream, with spaces included as ordinary characters
    2. Use Byte-Pair Encoding (BPE) or a unigram language model to construct the vocabulary

- More info here (https://huggingface.co/transformers/tokenizer_summary.html)
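The sketch below shows the two steps on a tiny string. It is a toy BPE learner and not the sentencepiece library, which trains on large corpora and also supports the unigram model. Spaces become the ordinary symbol "▁", and the most frequent adjacent pair is merged repeatedly. As a simplification, this sketch never merges across a word boundary. The function name `learn_bpe` is hypothetical.

```python
from collections import Counter

def learn_bpe(text, num_merges):
    """Toy BPE on a raw character stream, SentencePiece-style:
    spaces become the ordinary symbol "▁" (U+2581) instead of separators."""
    symbols = list("▁" + text.replace(" ", "▁"))
    merges = []
    for _ in range(num_merges):
        # Count adjacent pairs, but never merge across a word boundary
        # (a piece may start with "▁" but never contain one in the middle)
        pairs = Counter(p for p in zip(symbols, symbols[1:]) if not p[1].startswith("▁"))
        if not pairs:
            break
        best = max(pairs, key=pairs.get)  # most frequent pair; first one wins ties
        merges.append(best)
        merged, i = [], 0
        while i < len(symbols):
            if i + 1 < len(symbols) and (symbols[i], symbols[i + 1]) == best:
                merged.append(symbols[i] + symbols[i + 1])
                i += 2
            else:
                merged.append(symbols[i])
                i += 1
        symbols = merged
    return merges, symbols

merges, pieces = learn_bpe("low lower lowest", num_merges=3)
print(merges)  # [('▁', 'l'), ('▁l', 'o'), ('▁lo', 'w')]
print(pieces)  # ['▁low', '▁low', 'e', 'r', '▁low', 'e', 's', 't']
# Decoding is lossless: turn "▁" back into spaces
print("".join(pieces).replace("▁", " ").strip())  # low lower lowest
```

Because the space is kept as a symbol, no language-specific word splitter is needed, and decoding is a plain string join.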


First, load the XLNet model and tokenizer from the transformers library

In [1]:
# Predict masked words in a sentence with XLNet

from transformers import XLNetTokenizer, XLNetLMHeadModel
import torch

tokenizer = XLNetTokenizer.from_pretrained('xlnet-large-cased')
model = XLNetLMHeadModel.from_pretrained('xlnet-large-cased')

To help XLNet with short prompts, we prepend a long padding text to the input (proposed by Aman Rusia [link](https://amanrusia.medium.com/xlnet-speaks-comparison-to-gpt-2-ea1a4e9ba39e))

In [2]:
# Padding text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia
# in https://github.com/rusiaaman/XLNet-gen#methodology
# and https://medium.com/@amanrusia/xlnet-speaks-comparison-to-gpt-2-ea1a4e9ba39e

PADDING_TEXT = """In 1991, the remains of Russian Tsar Nicholas II and his family
(except for Alexei and Maria) are discovered.
The voice of Nicholas's young son, Tsarevich Alexei Nikolaevich, narrates the
remainder of the story. 1883 Western Siberia,
a young Grigori Rasputin is asked by his father and a group of men to perform magic.
Rasputin has a vision and denounces one of the men as a horse thief. Although his
father initially slaps him for making such an accusation, Rasputin watches as the
man is chased outside and beaten. Twenty years later, Rasputin sees a vision of
the Virgin Mary, prompting him to become a priest. Rasputin quickly becomes famous,
with people, even a bishop, begging for his blessing. <eod> </s> <eos>"""

Predict the top 5 words for each <mask\> token. We feed the model with  

1. Tokenized text  
2. A target mapping, which marks the masked positions to predict  
3. A permutation mask (needed to prevent the input tokens from attending to the masked tokens)
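The next few cells build inputs 2 and 3 with torch. To make the shapes easier to inspect, here is a minimal NumPy sketch of the same construction on a toy 6-token sequence "<mask> have <mask> apples in hands" with no padding. It follows the Hugging Face semantics. `perm_mask[0, i, j] == 1` means token i may not attend to token j. `target_mapping[0, k, j] == 1` means the k-th prediction is for position j. The helper name `build_xlnet_masks` is hypothetical. The real model expects torch tensors, as in the cells below.

```python
import numpy as np

def build_xlnet_masks(seq_len, targets):
    """Build perm_mask and target_mapping like the torch cells in this notebook.
    targets may be negative indices counted from the end, e.g. [-6, -4]."""
    # perm_mask[0, i, j] == 1: token i may NOT attend to token j
    perm_mask = np.zeros((1, seq_len, seq_len), dtype=np.float32)
    perm_mask[0, :, targets] = 1.0  # no token may look at the <mask> positions
    # target_mapping[0, k, j] == 1: the k-th prediction is for position j
    target_mapping = np.zeros((1, len(targets), seq_len), dtype=np.float32)
    for k, pos in enumerate(targets):
        target_mapping[0, k, pos] = 1.0
    return perm_mask, target_mapping

# Toy sequence: "<mask> have <mask> apples in hands" -> masks at -6 and -4
perm_mask, target_mapping = build_xlnet_masks(6, [-6, -4])
print(perm_mask[0][0])    # [1. 0. 1. 0. 0. 0.]
print(target_mapping[0])  # one one-hot row per <mask>
```

Negative indices are convenient here. They stay correct no matter how much padding text is prepended in front of the sentence.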

Turn the padding text + input sentence into tokens with the tokenizer, then into a tensor

In [3]:
torch.manual_seed(0)
# Set up the inputs to predict the masked tokens using bidirectional context.
# PADDING_TEXT is prepended to help XLNet with the short sentence
input_ids = torch.tensor(tokenizer.encode(PADDING_TEXT + "I gave you three apples. <mask> have <mask> apples in hands", add_special_tokens=False)).unsqueeze(0)


Let's check out the input for XLNet

In [4]:
# All words from Padding Text and given text are tokenized
print("length of sentence:", len(input_ids[0]))
input_ids

length of sentence: 177


tensor([[   67,  2840,    19,    18,  1484,    20,   965, 29077,  8719,  1273,
            21,    45,   273,    17,    10, 15048,    28, 27511,    21,  4185,
            11,    41,  2444,     9,    32,  1025,    20,  8719,    26,    23,
           673,   966,    19, 29077, 20643, 27511, 20822, 20643,    19,    17,
          6616, 17511,    18,  8978,    20,    18,   777,     9, 19233,  1527,
         17669,    19,    24,   673,    17, 28756,   150, 12943,  4354,   153,
            27,   442,    37,    45,   668,    21,    24,   256,    20,   416,
            22,  2771,  4901,     9, 12943,  4354,   153,    51,    24,  3004,
            21, 28142,    23,    65,    20,    18,   416,    34,    24,  2958,
         22947,     9,  1177,    45,   668,  3097, 13768,    23,   103,    28,
           441,   148,    48, 20522,    19, 12943,  4354,   153, 12860,    34,
            18,   326,    27, 17492,   684,    21,  6709,     9,  8585,   123,
           266,    19, 12943,  4354,   153,  6872,  

Create perm_mask, which tells each input token which tokens it must not attend to (here: the masks)

In [6]:

targets = [-6, -4] # positions of the two <mask> tokens, counted from the end of "<mask> have <mask> apples in hands"

perm_mask = torch.zeros((1, input_ids.shape[1], input_ids.shape[1]), dtype=torch.float)


Let's see what perm_mask looks like

In [7]:
# perm_mask has shape (batch, seq_len, seq_len)
# row i marks the tokens that token i must NOT attend to
# 1 = hidden (mask), 0 = visible
print(perm_mask[0].shape)
perm_mask

torch.Size([177, 177])


tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]]])

Set the columns at the mask positions to 1, so that no token can attend to the masks

In [9]:
perm_mask[0, :, targets] = 1.0 # no token may attend to the <mask> positions

Row 0 now has 1s at the two mask positions

In [10]:
perm_mask[0][0] # the mask columns are now 1s

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0.])

Now we create target_mapping, which marks the positions to predict: target_mapping[0, k, j] = 1 means the k-th prediction is for position j

In [12]:

target_mapping = torch.zeros((1, len(targets), input_ids.shape[1]), dtype=torch.float)
target_mapping

tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       

The first row points to the first mask, and the second row points to the second mask

In [13]:
target_mapping[0, 0, targets[0]] = 1.0 # Our first prediction, first <mask>
target_mapping[0, 1, targets[1]] = 1.0 # Our second prediction, second <mask>
target_mapping

tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 1., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0.,

Let's move all the tensors and the model to the GPU, if one is available

In [14]:
device = "cuda" if torch.cuda.is_available() else "cpu"
input_ids_tensor = input_ids.to(device)
target_mapping_tensor = target_mapping.to(device)
perm_mask_tensor = perm_mask.to(device)

In [15]:
model.eval()
if torch.cuda.is_available():
    model.to('cuda') # if we have a GPU

Run the model under torch.no_grad(), since we're only doing inference

In [16]:
with torch.no_grad():
    outputs = model(input_ids_tensor, perm_mask=perm_mask_tensor, target_mapping=target_mapping_tensor)
next_token_logits = outputs[0] # Output has shape [target_mapping.size(0), target_mapping.size(1), config.vocab_size]
next_token_logits # scores for which word should be in <mask>!

tensor([[[-23.8164, -31.6897, -32.0603,  ..., -28.9935, -29.9855, -25.5707],
         [-38.2683, -45.4465, -45.6469,  ..., -43.8869, -42.4743, -44.8687]]],
       device='cuda:0')

Use torch.topk to find the most probable tokens for each mask!

In [17]:
for j in range(len(targets)):
    predicted_k_indexes = torch.topk(outputs[0][0][j], k = 5)
    predicted_logits_list = predicted_k_indexes[0]
    predicted_indexes_list = predicted_k_indexes[1]

    print("predicted word:",tokenizer.decode(input_ids[0][targets[j]].item()), j)
    for i,item  in enumerate(predicted_indexes_list):
        the_index = predicted_indexes_list[i].item()
        print("word and logits",tokenizer.decode(the_index),predicted_logits_list[i].item())

predicted word: <mask> 0
word and logits You -8.967314720153809
word and logits I -10.730646133422852
word and logits We -12.72852611541748
word and logits Now -14.039458274841309
word and logits They -14.771018028259277
predicted word: <mask> 1
word and logits three -23.03700828552246
word and logits the -24.339698791503906
word and logits these -25.60515022277832
word and logits two -25.806488037109375
word and logits your -25.95541000366211


The top predictions give "You have three apples in hands"... Sounds good!
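The scores printed above are raw logits, not probabilities. Softmax divides every exponentiated logit by the same normalizer, so the ratio of two tokens' probabilities depends only on the difference of their logits. The sketch below applies this to the two top logits printed for the first mask. The arithmetic is standard softmax algebra and needs no model call.

```python
import math

# Logits printed above for the first <mask>
logit_you = -8.967314720153809
logit_i = -10.730646133422852

# softmax(z)_a / softmax(z)_b = exp(z_a - z_b): the normalizer cancels out
ratio = math.exp(logit_you - logit_i)
print(f"'You' is about {ratio:.1f}x as likely as 'I'")  # about 5.8x
```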

# 3. Top-K bi-directional generation
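- Create a loop

- At each iteration, the model predicts the top-k tokens for a mask on the left and a mask on the right

- Sample one of the top-k tokens, add it to the left (even iterations) or the right (odd iterations), and repeat

- We try to generate text starting from the input sentence "text generation is cool"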

In [18]:
import random
import numpy as np

In [19]:
# Select the top-k tokens from the probability list, renormalize their
# probabilities, and sample sample_size distinct tokens from that distribution
def choose_from_top(probs, k=5, sample_size=1):
    ind = np.argpartition(probs, -k)[-k:]
    top_prob = probs[ind]
    # print(tokenizer.decode(ind))
    top_prob = top_prob / np.sum(top_prob) # normalize
    choice = np.random.choice(k, sample_size, p = top_prob, replace=False)
    token_ids = ind[choice]
    return token_ids

Let's sample from the top 10 tokens and loop 20 times (adding 10 tokens at the start and 10 at the end)

In [20]:
sent = "text generation is cool"
topk = 10
n = 20
# Lower temperatures make the model more confident in its top choices,
# while temperatures greater than 1 decrease confidence
temperature = 5

model.eval()
if torch.cuda.is_available():
    model.to('cuda') # if we have a GPU
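The generation loop below divides the logits by `temperature` before applying the softmax. Here is a small plain-Python sketch of that step, which shows why a larger temperature flattens the distribution. For illustration, it reuses the top three logits printed earlier for the first mask ("You", "I", "We") and renormalizes over just those three tokens. The real model normalizes over the whole vocabulary. The function name `softmax_with_temperature` is hypothetical.

```python
import math

def softmax_with_temperature(logits, temperature):
    """Divide logits by the temperature, then apply a numerically stable softmax."""
    scaled = [x / temperature for x in logits]
    top = max(scaled)
    exps = [math.exp(x - top) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

# Top-3 logits printed earlier for the first <mask>, rounded
logits = [-8.967, -10.731, -12.729]
for t in (1, 5):
    probs = softmax_with_temperature(logits, t)
    print(f"T={t}:", [round(p, 3) for p in probs])
```

With T=5, as in this notebook, the less likely tokens get a much larger share. This makes top-k sampling more adventurous.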

Tokenize all the input we need

In [23]:
sent_tokens = tokenizer.encode(sent, add_special_tokens=False)
mask_tokens = tokenizer.encode('<mask>', add_special_tokens=False)
padding_tokens = tokenizer.encode(PADDING_TEXT, add_special_tokens=False)

In [24]:
print(sent_tokens)
print(mask_tokens)
print(padding_tokens)

[1758, 2887, 27, 2299]
[6]
[67, 2840, 19, 18, 1484, 20, 965, 29077, 8719, 1273, 21, 45, 273, 17, 10, 15048, 28, 27511, 21, 4185, 11, 41, 2444, 9, 32, 1025, 20, 8719, 26, 23, 673, 966, 19, 29077, 20643, 27511, 20822, 20643, 19, 17, 6616, 17511, 18, 8978, 20, 18, 777, 9, 19233, 1527, 17669, 19, 24, 673, 17, 28756, 150, 12943, 4354, 153, 27, 442, 37, 45, 668, 21, 24, 256, 20, 416, 22, 2771, 4901, 9, 12943, 4354, 153, 51, 24, 3004, 21, 28142, 23, 65, 20, 18, 416, 34, 24, 2958, 22947, 9, 1177, 45, 668, 3097, 13768, 23, 103, 28, 441, 148, 48, 20522, 19, 12943, 4354, 153, 12860, 34, 18, 326, 27, 17492, 684, 21, 6709, 9, 8585, 123, 266, 19, 12943, 4354, 153, 6872, 24, 3004, 20, 18, 9225, 2198, 19, 12717, 103, 22, 401, 24, 6348, 9, 12943, 4354, 153, 1068, 2768, 2286, 19, 33, 104, 19, 176, 24, 9313, 19, 20086, 28, 45, 10292, 9, 7, 2, 7739, 6122, 23, 3151]


Loop 20 times and generate text using methods described above!

In [25]:
for i in range(n):
    input = mask_tokens + sent_tokens + mask_tokens
    target_id1 = -len(input) # mask in the beginning
    target_id2 = -1 # at the end

    input_ids = torch.tensor(padding_tokens + input).unsqueeze(0) # We will predict masked token

    perm_mask = torch.zeros((1, input_ids.shape[1], input_ids.shape[1]), dtype=torch.float)
    perm_mask[0, :, [target_id1, target_id2]] = 1.0 # set mask index to 1 (don't attend!)

    target_mapping = torch.zeros((1, 2, input_ids.shape[1]), dtype=torch.float)
    target_mapping[0, 0, target_id1] = 1.0 # Our first prediction
    target_mapping[0, 1, target_id2] = 1.0 # Our second prediction

    device = "cuda" if torch.cuda.is_available() else "cpu"
    input_ids_tensor = input_ids.to(device)
    target_mapping_tensor = target_mapping.to(device)
    perm_mask_tensor = perm_mask.to(device)

    with torch.no_grad(): # no need to learn anything, just inference
        outputs = model(input_ids_tensor, perm_mask=perm_mask_tensor, target_mapping=target_mapping_tensor)
    
    predicted_tokens = []

    for j in range(2):
        probs = torch.nn.functional.softmax(outputs[0][0][j]/temperature, dim = 0).to('cpu').numpy()
        predicted_tokens.append(choose_from_top(probs, k=topk, sample_size=1))
    
    if i % 2 == 0: # add to left if iteration number is even
        tok = predicted_tokens[0][0]
        sent_tokens = [tok] + sent_tokens
        print('left: ', tokenizer.decode(sent_tokens))
    else: # add to right if iteration number is odd
        tok = predicted_tokens[1][0]
        sent_tokens = sent_tokens + [tok]
        print("right: ", tokenizer.decode(sent_tokens))

left:  The text generation is cool
right:  The text generation is cool.
left:  .The text generation is cool.
right:  .The text generation is cool. And
left:  2.The text generation is cool. And
right:  2.The text generation is cool. And you
left:  ? 2.The text generation is cool. And you
right:  ? 2.The text generation is cool. And you do
left:  What? 2.The text generation is cool. And you do
right:  What? 2.The text generation is cool. And you do see
left:  Like What? 2.The text generation is cool. And you do see
right:  Like What? 2.The text generation is cool. And you do see it
left:  :Like What? 2.The text generation is cool. And you do see it
right:  :Like What? 2.The text generation is cool. And you do see it,
left:  1:Like What? 2.The text generation is cool. And you do see it,
right:  1:Like What? 2.The text generation is cool. And you do see it, but
left:  Question 1:Like What? 2.The text generation is cool. And you do see it, but
right:  Question 1:Like What? 2.The text genera

Not too impressive... the left side in particular fills up with punctuation and fragments.

# 4. Top-K-beam bi-directional text generation
It is difficult for the model to generate text right-to-left  
Let's increase the chance of finding coherent word sequences by building a number of beams and choosing a probable one  
This example uses a beam size of 2

![image](https://miro.medium.com/max/1500/1*F8pqfzFBsyZwPxWIlhyOjQ.jpeg)

1. Generate right-to-left up to a certain length (at each stage, select the next token candidates with top-K sampling)  

2. Take a random beam from the most probable beams (the top fifth, in the code below) and add it to the start phrase  

3. Generate left-to-right beams from the new start phrase

4. Take a random beam from the most probable beams, add it to the phrase, then iterate!
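The steps above can be sketched without XLNet by replacing the model with a hypothetical lookup table, `toy_next_probs`. The sketch works as follows:

- Each beam is expanded with its most probable continuations.
- The final beams are ranked by the sum of log-probabilities.
- A beam is picked at random from the top fifth, using the same `len // 5` rule as `bi_generator`.

One difference from the notebook: the candidates here are chosen deterministically (top `width`) instead of by top-k sampling, so that the result is reproducible. The table, `beam_search`, and its parameters are illustrative assumptions.

```python
import math
import random

def toy_next_probs(prefix):
    """Hypothetical stand-in for the XLNet call: {next_token: probability}."""
    table = {
        (): {"I": 0.6, "you": 0.4},
        ("I",): {"like": 0.7, "ran": 0.3},
        ("you",): {"like": 0.5, "ran": 0.5},
    }
    return table.get(tuple(prefix), {".": 1.0})

def beam_search(next_probs, depth, width, seed=0):
    beams = [([], [])]  # (tokens, per-token probabilities)
    for _ in range(depth):
        expanded = []
        for tokens, probs in beams:
            # Keep the `width` most probable continuations of each beam
            options = sorted(next_probs(tokens).items(), key=lambda kv: kv[1], reverse=True)[:width]
            for tok, p in options:
                expanded.append((tokens + [tok], probs + [p]))
        beams = expanded
    # Rank by the sum of log-probabilities (= log of the product of probabilities)
    ranked = sorted(beams, key=lambda b: sum(math.log(p) for p in b[1]), reverse=True)
    # Pick randomly among the top fifth, like bi_generator does
    top_n = len(ranked) // 5 if len(ranked) > 4 else len(ranked)
    return ranked, random.Random(seed).choice(ranked[:top_n])

ranked, chosen = beam_search(toy_next_probs, depth=2, width=2)
for tokens, probs in ranked:
    print(" ".join(tokens), round(math.prod(probs), 2))
print("chosen:", " ".join(chosen[0]))
```

Scoring with log-probabilities rather than raw products avoids numerical underflow for long beams and gives the same ranking.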

In [26]:
# create a combination of beam and top-k generation to generate sequences of n tokens from both sides
padding_tokens = tokenizer.encode(PADDING_TEXT, add_special_tokens=False)
mask_tokens = tokenizer.encode('<mask>', add_special_tokens=False)

model.eval()
if torch.cuda.is_available():
    model.to('cuda')


Create the function candidates_gen, which takes  
- The tokenized start sentence
- A candidate: a token sequence generated so far, with its probabilities  

Then it generates **n_candidates** probable continuations on the right or left side

This function is used iteratively, so the token sequences generated in one step become the input of the next

Returns the candidates for the mask (ex. input: "five apples", output: [for, or, of\])
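The key detail in candidates_gen is where the mask goes. For right-side generation, it is appended after the sentence and the tokens generated so far. For left-side generation, it is placed in front of them. The plain-Python sketch below mirrors that layout with words instead of token IDs and omits the padding prefix. The helper `build_input` is hypothetical.

```python
def build_input(sent, cand, direction, mask="<mask>"):
    """Mirror candidates_gen's input layout (padding omitted).
    Returns the sequence and the mask's negative index."""
    if direction == "right":
        seq = sent + cand + [mask]  # left-to-right: mask goes last
        return seq, -1
    seq = [mask] + cand + sent      # right-to-left: mask goes first
    return seq, -len(seq)

sent = ["five", "apples"]
seq, target = build_input(sent, ["for"], "right")
print(seq, target)  # ['five', 'apples', 'for', '<mask>'] -1
seq, target = build_input(sent, ["ate"], "left")
print(seq, target)  # ['<mask>', 'ate', 'five', 'apples'] -4

# Prepending padding does not move a negative index
padded = ["pad"] * 3 + seq
print(padded[target])  # <mask>
```

Because the target index is negative, the same index still points at the mask after PADDING_TEXT is prepended.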

In [27]:
def candidates_gen(sent_tokens, candidate=([], 1, []), d='left', n_candidates=5, topk=20, temperature=5):
    branch_candidates = []  
    cand_tokens = candidate[0]
    
    # First prepare input depending on direction. ex) five apples
    if d == 'right':    
        input = sent_tokens + cand_tokens + mask_tokens     
        
        target_id = -1 # mask is at the end, since left-to-right generation
        input_ids = torch.tensor(padding_tokens + input).unsqueeze(0)  

        perm_mask = torch.zeros((1, input_ids.shape[1], input_ids.shape[1]), dtype=torch.float)
        perm_mask[0, :, target_id] = 1.0  # Previous tokens don't see last token
    else:        
        input = mask_tokens + cand_tokens + sent_tokens    
        
        target_id = -len(input) # mask at the front, since right-to-left generation
        input_ids = torch.tensor(padding_tokens + input).unsqueeze(0)  

        perm_mask = torch.zeros((1, input_ids.shape[1], input_ids.shape[1]), dtype=torch.float)
        # Hide the mask and the 99 padding tokens just before it from every token
        # (a heuristic to improve left-side generation)
        perm_mask[0, :, [target_id - i for i in range(100)]] = 1.0
    
    # We will predict masked tokens
    target_mapping = torch.zeros((1, 1, input_ids.shape[1]), dtype=torch.float)
    target_mapping[0, 0, target_id] = 1.0 # the single position to predict (left or right mask)

    if torch.cuda.is_available():
        input_ids_tensor = input_ids.to("cuda")
        target_mapping_tensor = target_mapping.to("cuda")
        perm_mask_tensor = perm_mask.to("cuda")
    else:
        input_ids_tensor = input_ids
        target_mapping_tensor = target_mapping
        perm_mask_tensor = perm_mask

    # Predict mask
    with torch.no_grad():
        outputs = model(input_ids_tensor, perm_mask=perm_mask_tensor, target_mapping=target_mapping_tensor)

    probs = torch.nn.functional.softmax(outputs[0][0][0]/temperature, dim = 0)
    selected_indexes = choose_from_top(probs.to('cpu').numpy(), k=topk, sample_size=n_candidates)
    selected_probs = probs[selected_indexes]

    # Add possible results to branch_candidates. ex) if right, "five apples for", "five apples or"
    for i, item in enumerate(selected_indexes):
        the_index = item.item()
        if d == 'right':
            new_sent = cand_tokens + [the_index]
        elif d == "left":
            new_sent = [the_index] + cand_tokens

        prob = selected_probs[i].item()
        # add word combinations to branch_candidates in format [sentence, cumulative probability, all probs]
        branch_candidates.append((new_sent, candidate[1] * prob, candidate[2] + [prob]))
    
    return branch_candidates




The beam_gen function expands the candidates from candidates_gen for depth rounds, forming beams

output: ["for you", "for me", ...\]

In [31]:
def beam_gen(sent_tokens, candidates, depth=5, d='right', sample_size=2, topk=10, temperature=5):
    beams = candidates[:]
    new_candidates = candidates[:]

    # expand every current beam, depth rounds in total
    while depth > 0:
        new_candidates = []
        for candidate in candidates:
            for new_candidate in candidates_gen(sent_tokens, candidate, d, sample_size, topk, temperature):
                beams.append(new_candidate)
                new_candidates.append(new_candidate)
        print("number of beams:", len(new_candidates))
        candidates = new_candidates[:]
        depth -= 1
        
    # sort the final beams by the sum of the logarithms of their word probabilities, which ranks them the same as the product of probabilities
    sorted_beams = sorted(new_candidates, key=lambda tup: np.sum(np.log10(tup[2])), reverse=True)
    return beams, sorted_beams

Finally, the bi_generator function

- If direction is "both"
    - alternately generate n_tokens on the left side and n_tokens on the right side, for the given number of iterations

- first_sample_size: the number of candidates in the first stage of beam search

- sample_size: number of candidates in the next stages

In [32]:
def bi_generator(sent, direction, first_sample_size, sample_size, n_tokens, topk, iterations, temperature):
    sent_tokens = tokenizer.encode(sent, add_special_tokens=False)

    for i in range(iterations):
        if (i % 2 == 0 and direction == 'both') or direction == 'left':
            print('>> left side generation')
            candidates = candidates_gen(sent_tokens=sent_tokens, d='left', n_candidates=first_sample_size, topk=topk, temperature=temperature)
            beams, sorted_beams = beam_gen(sent_tokens, candidates, n_tokens-1, 'left', sample_size, topk, temperature=temperature)
            topn = len(sorted_beams)//5 if len(sorted_beams) > 4 else len(sorted_beams)
            selected_candidate = random.choice(sorted_beams[:topn])
            sent_tokens = selected_candidate[0] + sent_tokens
            print(tokenizer.decode(sent_tokens))
        
        if (i % 2 != 0 and direction == 'both') or direction == 'right':
            print('>> right side generation')
            candidates = candidates_gen(sent_tokens=sent_tokens, d='right', n_candidates=first_sample_size, topk=topk, temperature=temperature)
            beams, sorted_beams = beam_gen(sent_tokens, candidates, n_tokens-1, 'right', sample_size, topk, temperature=temperature)
            topn = len(sorted_beams)//5 if len(sorted_beams) > 4 else len(sorted_beams)
            selected_candidate = random.choice(sorted_beams[:topn])
            sent_tokens = sent_tokens + selected_candidate[0]
            print(tokenizer.decode(sent_tokens))
    return tokenizer.decode(sent_tokens)

Let's try it out!

- Start sentence: text generation is cool

- each beam length: 4

- iterate 6 times (3 times on each side)
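With these settings, the number of beams is easy to predict. candidates_gen first returns first_sample_size candidates. Each of the n_tokens - 1 rounds in beam_gen then multiplies the count by sample_size, because choose_from_top samples sample_size distinct tokens per beam. The arithmetic below assumes exactly that behavior.

```python
first_sample_size, sample_size, n_tokens = 4, 2, 4

# One multiplication by sample_size per beam_gen round (n_tokens - 1 rounds)
counts = [first_sample_size * sample_size ** d for d in range(1, n_tokens)]
print(counts)  # [8, 16, 32], the "number of beams" lines printed below

final = counts[-1]
topn = final // 5 if final > 4 else final
print(topn)  # 6: one of the 6 highest-scoring beams is picked at random
```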

In [33]:
sent = "text generation is cool"
first_sample_size = 4
sample_size = 2
n_tokens = 4
topk = 20
iterations = 6
temperature = 4
direction = "both"

bi_generator(sent, direction, first_sample_size, sample_size, n_tokens, topk, iterations, temperature)

>> left side generation
number of beams: 8
number of beams: 16
number of beams: 32
creating the environment where text generation is cool
>> right side generation
number of beams: 8
number of beams: 16
number of beams: 32
creating the environment where text generation is cool is something I have
>> left side generation
number of beams: 8
number of beams: 16
number of beams: 32
<eod> For me, creating the environment where text generation is cool is something I have
>> right side generation
number of beams: 8
number of beams: 16
number of beams: 32
<eod> For me, creating the environment where text generation is cool is something I have spent an enormous amount
>> left side generation
number of beams: 8
number of beams: 16
number of beams: 32
of them.<eop><eod> For me, creating the environment where text generation is cool is something I have spent an enormous amount
>> right side generation
number of beams: 8
number of beams: 16
number of beams: 32
of them.<eop><eod> For me, creating the

'of them.<eop><eod> For me, creating the environment where text generation is cool is something I have spent an enormous amount of effort over a'

# Conclusion
We built a bidirectional text generator on top of a pretrained XLNet model from the transformers library!