# Transformer and Language Models

In [1]:
import torch
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
from torch.nn.utils.rnn import pad_sequence

import torchtext, datasets, math
from tqdm import tqdm

from queue import PriorityQueue
import operator

In [2]:
device = torch.device('cuda:2' if torch.cuda.is_available() else 'cpu')
print(device)

#make our work comparable if restarted the kernel
SEED = 1234
# torch.manual_seed(SEED)
# torch.backends.cudnn.deterministic = True

torch.cuda.get_device_name(0)

cuda:2


'NVIDIA GeForce RTX 2080 Ti'

## 1. Load data - Wiki Text

We will be using wikitext which contains a large corpus of text, perfect for language modeling task.  This time, we will use the `datasets` library from HuggingFace to load.

In [3]:
import os
os.environ['http_proxy']  = 'http://192.41.170.23:3128'
os.environ['https_proxy'] = 'http://192.41.170.23:3128'

#there are raw and preprocessed version; we used the raw one and preprocessed ourselves for fun
dataset = datasets.load_dataset('ptb_text_only')
print(dataset)

Found cached dataset ptb_text_only (/home/st122934/.cache/huggingface/datasets/ptb_text_only/penn_treebank/1.1.0/8d1b97746fb9765d140e569ec5ddd35e20af4d37761f5e1bf357ea0b081f2c1f)


  0%|          | 0/3 [00:00<?, ?it/s]

DatasetDict({
    train: Dataset({
        features: ['sentence'],
        num_rows: 42068
    })
    test: Dataset({
        features: ['sentence'],
        num_rows: 3761
    })
    validation: Dataset({
        features: ['sentence'],
        num_rows: 3370
    })
})


In [4]:
print(dataset['train'][333]['sentence'])

'''
If you try to change the index you might notice that sometimes there is no paragraph 
and rather an empty string so we will have to care of that later.
'''

behind all the <unk> is some <unk> competition


'\nIf you try to change the index you might notice that sometimes there is no paragraph \nand rather an empty string so we will have to care of that later.\n'

## 2. Preprocessing

### Tokenizing

Simply tokenize the given text to tokens.

In [5]:
tokenizer = torchtext.data.utils.get_tokenizer('basic_english')

#function to tokenize
tokenize_data = lambda example, tokenizer: {'tokens': tokenizer(example['sentence'])}  

#map the function to each example
tokenized_dataset = dataset.map(tokenize_data, remove_columns=['sentence'], fn_kwargs={'tokenizer': tokenizer})
print(tokenized_dataset['train'][333]['tokens'])

Loading cached processed dataset at /home/st122934/.cache/huggingface/datasets/ptb_text_only/penn_treebank/1.1.0/8d1b97746fb9765d140e569ec5ddd35e20af4d37761f5e1bf357ea0b081f2c1f/cache-6d1e5992cb800354.arrow
Loading cached processed dataset at /home/st122934/.cache/huggingface/datasets/ptb_text_only/penn_treebank/1.1.0/8d1b97746fb9765d140e569ec5ddd35e20af4d37761f5e1bf357ea0b081f2c1f/cache-6fd08b61fe272feb.arrow
Loading cached processed dataset at /home/st122934/.cache/huggingface/datasets/ptb_text_only/penn_treebank/1.1.0/8d1b97746fb9765d140e569ec5ddd35e20af4d37761f5e1bf357ea0b081f2c1f/cache-b98755e1d2264d31.arrow


['behind', 'all', 'the', '<unk>', 'is', 'some', '<unk>', 'competition']


### Numericalizing

We will tell torchtext to add any word that has occurred at least three times in the dataset to the vocabulary because otherwise it would be too big.

In [6]:
## numericalizing

# Define special symbols and indices
UNK_IDX, PAD_IDX, SOS_IDX, EOS_IDX = 0, 1, 2, 3
# Make sure the tokens are in order of their indices to properly insert them in vocab
special_symbols = ['<unk>', '<pad>', '<sos>', '<eos>']

vocab = torchtext.vocab.build_vocab_from_iterator(tokenized_dataset['train']['tokens'], min_freq=3, specials=special_symbols)   

vocab.set_default_index(vocab['<unk>'])   
print(len(vocab))                         
print(vocab.get_itos()[:100])       

9881
['<unk>', '<pad>', '<sos>', '<eos>', 'the', 'n', 'of', 'to', 'a', 'in', 'and', '.', "'", 's', 'that', 'for', '$', 'is', 'it', 'said', 'on', 'by', 'at', 'as', 'from', 'million', 'with', 'mr', 'was', 'be', 'are', 'its', 'he', 'but', 't', 'has', 'an', 'will', 'have', 'new', 'or', 'company', 'they', 'this', 'year', 'which', 'would', 'about', 'says', 'more', 'were', 'market', 'u', 'billion', 'his', 'had', 'their', 'up', 'one', 'than', 'who', 'some', 'been', 'also', 'stock', 'other', 'share', 'corp', 'not', 'we', 'inc', 'i', 'if', 'when', 'last', 'president', 'shares', 'years', 'all', 'first', 'two', 'because', 'trading', 'after', 'could', 'co', 'sales', '&', 'there', 'out', 'business', 'only', 'do', 'such', 'can', 'most', 'into', 'york', 'may', 'over']


In [24]:
import pickle
with open('vocab_beam.pkl', 'wb') as file:
      
    # A new file will be created
    pickle.dump(vocab, file)

## 3. Prepare the batch loader

### Prepare data

Given "Chaky loves eating at AIT", and "I really love deep learning", and given batch size = 3, we will get three batches of data "Chaky loves eating at", "AIT `<eos>` I really", "love deep learning `<eos>`".  

In [7]:
def get_data(dataset, vocab, batch_size):
    data = []
    # example = []
    for example in dataset:
        if example['tokens']:         
            #appends eos so we know it ends....so model learn how to end...                             
            tokens = example['tokens'].append('<eos>')   
            #numericalize          
            tokens = [vocab[token] for token in example['tokens']] 
            data.extend(tokens)                                    
    data = torch.LongTensor(data)                                 
    num_batches = data.shape[0] // batch_size #get the int number of batches...
    data = data[:num_batches * batch_size] #make the batch evenly, and cut out any remaining                      
    data = data.view(batch_size, num_batches)          
    return data #[batch size, bunch of tokens]

In [8]:
batch_size = 128
train_data = get_data(tokenized_dataset['train'], vocab, batch_size)
valid_data = get_data(tokenized_dataset['validation'], vocab, batch_size)
test_data  = get_data(tokenized_dataset['test'], vocab, batch_size)

## 4. Modeling 

In [9]:
class MultiHeadAttentionLayer(nn.Module):
    def __init__(self, hid_dim, n_heads, dropout, device):
        super().__init__()
        
        assert hid_dim % n_heads == 0
        
        self.hid_dim = hid_dim
        self.n_heads = n_heads
        self.head_dim = hid_dim // n_heads
        
        self.fc_q = nn.Linear(hid_dim, hid_dim)
        self.fc_k = nn.Linear(hid_dim, hid_dim)
        self.fc_v = nn.Linear(hid_dim, hid_dim)
        
        self.fc_o = nn.Linear(hid_dim, hid_dim)
        
        self.dropout = nn.Dropout(dropout)
        
        self.scale = torch.sqrt(torch.FloatTensor([self.head_dim])).to(device)
        
    def forward(self, query, key, value, mask = None):
        
        batch_size = query.shape[0]
        
        #query = [batch size, query len, hid dim]
        #key = [batch size, key len, hid dim]
        #value = [batch size, value len, hid dim]
                
        Q = self.fc_q(query)
        K = self.fc_k(key)
        V = self.fc_v(value)
        #Q = [batch size, query len, hid dim]
        #K = [batch size, key len, hid dim]
        #V = [batch size, value len, hid dim]
                
        Q = Q.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        K = K.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        V = V.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        #Q = [batch size, n heads, query len, head dim]
        #K = [batch size, n heads, key len, head dim]
        #V = [batch size, n heads, value len, head dim]
                
        energy = torch.matmul(Q, K.permute(0, 1, 3, 2)) / self.scale
        #energy = [batch size, n heads, query len, key len]
        
        if mask is not None:
            energy = energy.masked_fill(mask == 0, -1e10)
        
        attention = torch.softmax(energy, dim = -1)
        #attention = [batch size, n heads, query len, key len]
                
        x = torch.matmul(self.dropout(attention), V)
        #x = [batch size, n heads, query len, head dim]
        
        x = x.permute(0, 2, 1, 3).contiguous()
        #x = [batch size, query len, n heads, head dim]
        
        x = x.view(batch_size, -1, self.hid_dim)
        #x = [batch size, query len, hid dim]
        
        x = self.fc_o(x)
        #x = [batch size, query len, hid dim]
        
        return x, attention

In [10]:
class PositionwiseFeedforwardLayer(nn.Module):
    def __init__(self, hid_dim, pf_dim, dropout):
        super().__init__()
        
        self.fc_1 = nn.Linear(hid_dim, pf_dim)
        self.fc_2 = nn.Linear(pf_dim, hid_dim)
        
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x):
        
        #x = [batch size, seq len, hid dim]
        
        x = self.dropout(torch.relu(self.fc_1(x)))
        #x = [batch size, seq len, pf dim]
        
        x = self.fc_2(x)
        #x = [batch size, seq len, hid dim]
        
        return x

In [11]:
class BeamSearchNode(object):
    def __init__(self, previousNode, wordId, logProb, length):
        self.prevNode = previousNode  #where does it come from
        self.wordid   = wordId  #the numericalized integer of the word
        self.logp     = logProb  #the log probability
        self.len      = length  #the current length; first word starts at 1

    def eval(self, alpha=0.7):
        # the score will be simply the log probability penaltized by the length 
        # we add some small number to avoid division error
        # read https://arxiv.org/abs/1808.10006 to understand how alpha is selected
        return self.logp / float(self.len + 1e-6) ** (alpha)
    
    #this is the function for comparing between two beamsearchnodes, whether which one is better
    #it is called when you called "put"
    def __lt__(self, other):
        return self.len < other.len

    def __gt__(self, other):
        return self.len > other.len

In [12]:
class Decoder(nn.Module):
    def __init__(self, output_dim, hid_dim, n_layers, n_heads, 
                 pf_dim, dropout, device, pad_idx, max_length = 100):
                
        super().__init__()
        
        self.device = device
        
        self.tok_embedding = nn.Embedding(output_dim, hid_dim)
        self.pos_embedding = nn.Embedding(max_length, hid_dim)
        
        self.layers = nn.ModuleList([DecoderLayer(hid_dim, 
                                                  n_heads, 
                                                  pf_dim, 
                                                  dropout, 
                                                  device)
                                     for _ in range(n_layers)])
        
        self.fc_out = nn.Linear(hid_dim, output_dim)
        
        self.dropout = nn.Dropout(dropout)
        
        self.scale = torch.sqrt(torch.FloatTensor([hid_dim])).to(device)
    
        self.pad_idx = pad_idx
    
    def make_mask(self, x):
        
        #x = [batch size, len]
        
        pad_mask = (x != self.pad_idx).unsqueeze(1).unsqueeze(2)
        #pad_mask = [batch size, 1, 1, len]
        
        x_len = x.shape[1]
        
        sub_mask = torch.tril(torch.ones((x_len, x_len), device = self.device)).bool()
        #sub_mask = [len, len]
            
        mask = pad_mask & sub_mask
        #mask = [batch size, 1, len, len]
        
        return mask 
    
    def forward(self, x):
        
        #x = [batch size, len]
                
        batch_size = x.shape[0]
        x_len    = x.shape[1]
        
        #get mask here since we remove seq2seq class
        mask   = self.make_mask(x)
        #mask = [batch size, 1, len, len]

        pos = torch.arange(0, x_len).unsqueeze(0).repeat(batch_size, 1).to(self.device)          
            
        x = self.dropout((self.tok_embedding(x) * self.scale) + self.pos_embedding(pos))
        #x = [batch size, len, hid dim]
        
        for layer in self.layers:
            x, attention = layer(x, mask)
        
        #x = [batch size, len, hid dim]
        #attention = [batch size, n heads, len, len]
        
        output = self.fc_out(x)
        #output = [batch size, len, output dim]
            
        return output, attention

    def beam_decode(self, src_tensor, method='beam-search'):
        
        #src_tensor      = [batch size, src len]
        src_len = src_tensor.shape[1]
        
        #how many parallel searches
        beam_width = 3
        
        #how many sentence do you want to generate
        topk = 1  
        
        #final generated sentence
        decoded_batch = []
                                        
        # Start with the start of the sentence token
        decoder_input = torch.LongTensor([SOS_IDX]).to(device)

        # Number of sentence to generate
        endnodes = []  #hold the nodes of EOS, so we can backtrack
        number_required = min((topk + 1), topk - len(endnodes))

        # starting node -  hidden vector, previous node, word id, logp, length
        node = BeamSearchNode(None, decoder_input, 0, 1)
        nodes = PriorityQueue()  #this is a min-heap

        # start the queue
        nodes.put((-node.eval(), node))  #we need to put - because PriorityQueue is a min-heap
        qsize = 1

        # start beam search
        while True:
            # give up when decoding takes too long
            if qsize > 100: break
            
            # print(f"{nodes.queue=}")

            # fetch the best node
            # score is log p divides by the length scaled by some constants
            score, n       = nodes.get()
            decoder_input  = n.wordid

            #get all the previous nodes, so to construct a complete decoder input
            #because Transformer decoder expects the whole sentence
            prevNode = n.prevNode
            while prevNode != None:
                prev_word = torch.LongTensor([prevNode.wordid]).to(device)
                # print(f"{prev_word=}")
                decoder_input = torch.cat((decoder_input, prev_word))
                prevNode = prevNode.prevNode

            inv_idx       = torch.arange(decoder_input.size(0)-1, -1, -1).long()
            decoder_input = decoder_input[inv_idx]

            # wordid is simply the numercalized integer of the word
            current_len    = n.len

            decoder_input  = decoder_input.unsqueeze(0)
            #decoder_inpput: batch_size, src_len

            if n.wordid.item() == EOS_IDX and n.prevNode != None:
                endnodes.append((score, n))
                # if we reached maximum # of sentences required
                if len(endnodes) >= number_required:
                    break
                else:
                    continue

            # decode for one step using decoder
            # decoder_input = SOS_IDX
            # mask = [1, src len]
            decoder_input = F.pad(decoder_input, pad=(0, src_len), mode='constant', value=PAD_IDX)
            #pad because our decoder expects a whole sentence, not one token by token....

#             print(f"{current_len=}")
#             print(f"{decoder_input=}")

            prediction, _ = self.forward(decoder_input)
            #prediction   = [batch size, src len, output dim]

            prediction = prediction[:, current_len, :] #get only the next word, but ignoring the padding
            #prediction   = [batch size, output dim]

            #so basically prediction is probabilities across all possible vocab
            #we gonna retrieve k top probabilities (which is defined by beam_width) and their indexes
            #recall that beam_width defines how many parallel searches we want
            log_prob, indexes = torch.topk(prediction, beam_width)
            # log_prob      = (1, beam width)
            # indexes       = (1, beam width)
            
            # print(f"{log_prob.shape}")
            # print(f"{indexes.shape}")

            nextnodes = []  #the next possible node you can move to

            # we only select beam_width amount of nextnodes
            for top in range(beam_width):
                pred_t = indexes[0, top].reshape(-1)  #reshape because wordid is assume to be []; see when we define SOS
                log_p  = log_prob[0, top].item()

                #decoder previous node, current node, prob, length
                node = BeamSearchNode(n, pred_t, n.logp + log_p, n.len + 1)
                score = -node.eval()
                nextnodes.append((score, node))

            # put them into queue
            for i in range(len(nextnodes)):
                score, nn = nextnodes[i]
                nodes.put((score, nn))
                # increase qsize
            qsize += len(nextnodes) - 1


        # Once everything is finished, choose nbest paths, back trace them

        ## in case it does not finish, we simply get couple of nodes with highest probability
        if len(endnodes) == 0:
            endnodes = [nodes.get() for _ in range(topk)]

        #look from the end and go back....
        utterances = []
        for score, n in sorted(endnodes, key=operator.itemgetter(0)):
            utterance = []
            utterance.append(n.wordid)
            # back trace by looking at the previous nodes.....
            while n.prevNode != None:
                n = n.prevNode
                utterance.append(n.wordid)

            utterance = utterance[::-1]  #reverse it....
            utterances.append(utterance) #append to the list of sentences....

        decoded_batch.append(utterances)

        return decoded_batch  #(batch size, length)

In [13]:
class DecoderLayer(nn.Module):
    def __init__(self, hid_dim, n_heads, pf_dim, dropout, device):
        super().__init__()
        
        self.self_attn_layer_norm = nn.LayerNorm(hid_dim)
        self.ff_layer_norm = nn.LayerNorm(hid_dim)
        
        self.self_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout, device)        
        self.positionwise_feedforward = PositionwiseFeedforwardLayer(hid_dim, pf_dim, dropout)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, mask):
        
        #x = [batch size, len, hid dim]
        #mask = [batch size, 1, len, len]
        
        #multi attention, skip and then norm
        _x, attention = self.self_attention(x, x, x, mask)
        x = self.self_attn_layer_norm(x + self.dropout(_x))
        #x = [batch size, len, hid dim]
        #attention = [batch size, n heads, len, len]
    
        #positionwise feedforward
        _x = self.positionwise_feedforward(x)
        x = self.ff_layer_norm(x + self.dropout(_x))
        #x = [batch size, len, hid dim]
        
        return x, attention

## 5. Training 

In [14]:
vocab_size = len(vocab)
hid_dim    = 256                
dec_layers = 3               
dec_heads  = 8
dec_pf_dim = 512
dec_dropout = 0.1     
lr = 1e-3                     

In [15]:
model = Decoder(vocab_size, hid_dim, dec_layers, dec_heads, dec_pf_dim, dec_dropout, device, PAD_IDX).to(device)

optimizer = optim.Adam(model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'The model has {num_params:,} trainable parameters')

The model has 6,675,865 trainable parameters


In [16]:
def get_batch(data, seq_len, idx):
    #data #[batch size, bunch of tokens]
    src    = data[:, idx:idx+seq_len]                   
    target = data[:, idx+1:idx+seq_len+1]  #target simply is ahead of src by 1            
    return src, target

In [17]:
def train(model, data, optimizer, criterion, batch_size, seq_len, clip, device):
    
    epoch_loss = 0
    model.train()
    # drop all batches that are not a multiple of seq_len
    # data #[batch size, bunch of tokens]
    num_batches = data.shape[-1]
    data = data[:, :num_batches - (num_batches -1) % seq_len]  #we need to -1 because we start at 0
    num_batches = data.shape[-1]
        
    for idx in tqdm(range(0, num_batches - 1, seq_len), desc='Training: ',leave=False):
        optimizer.zero_grad()
        
        src, target = get_batch(data, seq_len, idx) #src, target: [batch size, seq len]
        src, target = src.to(device), target.to(device)
        batch_size = src.shape[0]
        prediction, _ = model(src)               

        #need to reshape because criterion expects pred to be 2d and target to be 1d
        prediction = prediction.reshape(batch_size * seq_len, -1)  #prediction: [batch size * seq len, vocab size]  
        target = target.reshape(-1)
        loss = criterion(prediction, target)
        
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()
        epoch_loss += loss.item() * seq_len
    return epoch_loss / num_batches

In [18]:
def evaluate(model, data, criterion, batch_size, seq_len, device):

    epoch_loss = 0
    model.eval()
    num_batches = data.shape[-1]
    data = data[:, :num_batches - (num_batches -1) % seq_len]
    num_batches = data.shape[-1]
    
    decoded_batch_list = []

    with torch.no_grad():
        for idx in range(0, num_batches - 1, seq_len):
            src, target = get_batch(data, seq_len, idx)
            src, target = src.to(device), target.to(device)
            
            #tareget = [batch size, dec len]

            batch_size= src.shape[0]

            prediction, _ = model(src)
            #prediction = [batch size, dec len, output_dim]
            
            #decoding using beam_search as example (you don't need to put here, because beam_search is for intference)
            decoded_batch = model.beam_decode(src, method='beam-search')
            
            #len(decoded_batch) = 64
            #len(decoded_batch[0]) = 1 = number of sentence generated, i.e., topk            
            decoded_batch_list.append(decoded_batch)
            
            prediction = prediction.reshape(batch_size * seq_len, -1)
            target = target.reshape(-1)

            loss = criterion(prediction, target)
            epoch_loss += loss.item() * seq_len
            
        #this is optional; you don't have to; printing first three samples of the first batch
        # print("print samples from first decode batch")
        # for sentence_index in decoded_batch_list[0][:3]:
        #     decode_text_arr = [vocab.lookup_token(i) for i in sentence_index[0]]
        #     decode_sentence = " ".join(decode_text_arr)
        #     print("pred target : {}".format(decode_sentence))
            
    return epoch_loss / num_batches

Here we will be using a `ReduceLROnPlateau` learning scheduler which decreases the learning rate by a factor, if the loss don't improve by a certain epoch.

In [19]:
n_epochs = 15
seq_len  = 25 #<----decoding length
clip    = 0.25

lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=0)

best_valid_loss = float('inf')

for epoch in range(n_epochs):
    train_loss = train(model, train_data, optimizer, criterion, 
                batch_size, seq_len, clip, device)
    valid_loss = evaluate(model, valid_data, criterion, batch_size, 
                seq_len, device)

    lr_scheduler.step(valid_loss)

    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'best-val-tr_lm.pt')

    print(f'\tTrain Perplexity: {math.exp(train_loss):.3f}')
    print(f'\tValid Perplexity: {math.exp(valid_loss):.3f}')

                                                           

	Train Perplexity: 306.775
	Valid Perplexity: 182.751


                                                           

	Train Perplexity: 159.675
	Valid Perplexity: 148.775


                                                           

	Train Perplexity: 124.548
	Valid Perplexity: 136.654


                                                           

	Train Perplexity: 105.177
	Valid Perplexity: 130.914


                                                           

	Train Perplexity: 92.390
	Valid Perplexity: 128.211


                                                           

	Train Perplexity: 83.151
	Valid Perplexity: 128.134


                                                           

	Train Perplexity: 76.190
	Valid Perplexity: 127.953


                                                           

	Train Perplexity: 70.593
	Valid Perplexity: 128.927


                                                           

	Train Perplexity: 61.219
	Valid Perplexity: 125.991


                                                           

	Train Perplexity: 57.564
	Valid Perplexity: 127.286


                                                           

	Train Perplexity: 52.829
	Valid Perplexity: 126.459


                                                           

	Train Perplexity: 50.189
	Valid Perplexity: 124.935


                                                           

	Train Perplexity: 49.137
	Valid Perplexity: 125.309


                                                           

	Train Perplexity: 47.952
	Valid Perplexity: 124.797


                                                           

	Train Perplexity: 47.434
	Valid Perplexity: 124.984


## 6. Testing

In [20]:
model.load_state_dict(torch.load('best-val-tr_lm.pt',  map_location=device))
test_loss = evaluate(model, test_data, criterion, batch_size, seq_len, device)
print(f'Test Perplexity: {math.exp(test_loss):.3f}')

Test Perplexity: 123.293


## 7. Real-world inference

Here I only use pure sampling.  You may want to put the beam search here and compare.  I will leave them as your practice.

In [21]:
def generate(prompt, max_seq_len, temperature, model, tokenizer, vocab, device, seed=None):
    if seed is not None:
        torch.manual_seed(seed)
    model.eval()
    tokens = tokenizer(prompt)
    indices = [vocab[t] for t in tokens]
    batch_size = 1
    with torch.no_grad():
        for i in range(max_seq_len):
            src = torch.LongTensor([indices]).to(device)
            prediction, _ = model(src)
            
            #prediction: [batch size, seq len, vocab size]
            #prediction[:, -1]: [batch size, vocab size] #probability of last vocab
            
            probs = torch.softmax(prediction[:, -1] / temperature, dim=-1)  
            prediction = torch.multinomial(probs, num_samples=1).item()    
            
            while prediction == vocab['<unk>']: #if it is unk, we sample again
                prediction = torch.multinomial(probs, num_samples=1).item()

            if prediction == vocab['<eos>']:    #if it is eos, we stop
                break

            indices.append(prediction) #autoregressive, thus output becomes input
            
            #####################################################################
            #I only do pure sampling....
            #you may want to compare here with top-k, top-p, and beam search here
            #####################################################################

    itos = vocab.get_itos()
    tokens = [itos[i] for i in indices]
    return tokens

In [23]:
prompt = 'it is the '
max_seq_len = 30
seed = 0

#smaller the temperature, more diverse tokens but comes 
#with a tradeoff of less-make-sense sentence
temperatures = [0.5, 0.7, 0.75, 0.8, 1.0]
for temperature in temperatures:
    generation = generate(prompt, max_seq_len, temperature, model, tokenizer, 
                          vocab, device, seed)
    print(str(temperature)+'\n'+' '.join(generation)+'\n')

0.5
it is the way to fulfill the potential buyers of the size of the country ' s fund ' s n billion mark

0.7
it is the first time to fulfill the bill

0.75
it is the first time to fulfill by the u . s .

0.8
it is the third quarter

1.0
it is the third quarter ended sept . n by a net loss of $ n million or eight cents a share from $ n million or n cents a share last year

