In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim

import os
from io import open
import re
import unicodedata
import itertools
import random

In [2]:
USE_CUDA = torch.cuda.is_available()
device = torch.device("cuda" if USE_CUDA else "cpu")

# Data Preprocessing

In [3]:
# display the movie lines
def print_file_entry(file, n=5):
    with open(file, 'rb') as moviefile:
        entries = moviefile.readlines()
    for entry in entries[:n]:
        print(entry)
        
print_file_entry("movie_data/movie_lines.txt")

b'L1045 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ They do not!\n'
b'L1044 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ They do to!\n'
b'L985 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ I hope so.\n'
b'L984 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ She okay?\n'
b"L925 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ Let's go.\n"


In [4]:
print_file_entry("movie_data/movie_conversations.txt")

b"u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L194', 'L195', 'L196', 'L197']\n"
b"u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L198', 'L199']\n"
b"u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L200', 'L201', 'L202', 'L203']\n"
b"u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L204', 'L205', 'L206']\n"
b"u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L207', 'L208']\n"


In [5]:
# split lines to organize lines with their corresponding fields
def load_lines(file, fields):
    lines = {}
    with open(file, 'r', encoding='iso-8859-1') as f:
        for line in f:
            values = line.split(' +++$+++ ')
            lineObj = {}
            for i, field in enumerate(fields):
                lineObj[field] = values[i]
            lines[lineObj['lineID']] = lineObj
    return lines

# split conversations to organize conversations into their corresponding fields
def load_conversations(file, lines, fields):
    conversations = []
    with open(file, 'r', encoding='iso-8859-1') as f:
        for line in f:
            values = line.split(" +++$+++ ")

            conversObj = {}
            for i, field in enumerate(fields):
                conversObj[field] = values[i]

            line_id_pattern = re.compile('L[0-9]+')
            line_ids = line_id_pattern.findall(conversObj['lineIDs'])

            conversObj['lines'] = []
            for lineId in line_ids:
                conversObj['lines'].append(lines[lineId])
            conversations.append(conversObj)
    return conversations

# get the sentence pairs to use as inputs and targets
def extract_sentence_pairs(conversations):
    qr_pairs = []
    for conversation in conversations:
        for i in range(len(conversation['lines'])-1):
            input_line = conversation['lines'][i]["text"].strip()
            target_line = conversation['lines'][i+1]['text'].strip()
            if input_line and target_line:
                qr_pairs.append([input_line, target_line])
    return qr_pairs

In [6]:
lines = {}
conversations = []

# fields corresponding to the order of the splits for lines and conversations
lines_fields = ['lineID', 'characterID', 'movieID', 'character', 'text'] 
conversation_fields = ['character1ID', 'character2ID', 'movieID', 'lineIDs']

In [7]:
lines = load_lines("movie_data/movie_lines.txt", lines_fields)

conversations = load_conversations("movie_data/movie_conversations.txt", lines, conversation_fields)

In [8]:
lines['L1045']

{'lineID': 'L1045',
 'characterID': 'u0',
 'movieID': 'm0',
 'character': 'BIANCA',
 'text': 'They do not!\n'}

In [9]:
conversations[0]

{'character1ID': 'u0',
 'character2ID': 'u2',
 'movieID': 'm0',
 'lineIDs': "['L194', 'L195', 'L196', 'L197']\n",
 'lines': [{'lineID': 'L194',
   'characterID': 'u0',
   'movieID': 'm0',
   'character': 'BIANCA',
   'text': 'Can we make this quick?  Roxanne Korrine and Andrew Barrett are having an incredibly horrendous public break- up on the quad.  Again.\n'},
  {'lineID': 'L195',
   'characterID': 'u2',
   'movieID': 'm0',
   'character': 'CAMERON',
   'text': "Well, I thought we'd start with pronunciation, if that's okay with you.\n"},
  {'lineID': 'L196',
   'characterID': 'u0',
   'movieID': 'm0',
   'character': 'BIANCA',
   'text': 'Not the hacking and gagging and spitting part.  Please.\n'},
  {'lineID': 'L197',
   'characterID': 'u2',
   'movieID': 'm0',
   'character': 'CAMERON',
   'text': "Okay... then how 'bout we try out some French cuisine.  Saturday?  Night?\n"}]}

In [10]:
qr_pairs = extract_sentence_pairs(conversations)

In [11]:
qr_pairs[0]

['Can we make this quick?  Roxanne Korrine and Andrew Barrett are having an incredibly horrendous public break- up on the quad.  Again.',
 "Well, I thought we'd start with pronunciation, if that's okay with you."]

In [12]:
preset_tokens = ['<PAD>', '<EOS>', '<OUT>', '<SOS>']

# vocabulary class
class Vocab:
    def __init__(self, preset_tokens):
        self.lines = lines
        self.conversations = conversations
        self.trimmed = False
        self.preset_tokens = preset_tokens
        self.index_to_word = {i: word for i, word in enumerate(preset_tokens)} # initializing the preset tokens
        self.word_to_index = {}
        self.word_to_count = {}
        self.num_words = len(preset_tokens)
        
    # iterate through words in a sentence to add them to the vocab
    def add_sentence(self, sentence):
        for word in sentence.split(' '):
            self.add_word(word)
    
    # handle a new word wrt the vocab
    def add_word(self, word):
        if word not in self.word_to_index:
            self.index_to_word[self.num_words] = word
            self.word_to_index[word] = self.num_words
            self.word_to_count[word] = 1
            self.num_words += 1
        else:
            self.word_to_count[word] += 1
    
    # get vocab for words above a certain minimum count threshold
    def trim(self, min_count, preset_tokens):
        if self.trimmed:
            return
        self.trimmed = True
        
        self.index_to_word = {i: word for i, word in enumerate(preset_tokens)} # initializing the preset tokens
        self.word_to_index = {'<OUT>': preset_tokens.index('<OUT>')}
        previous_word_to_count = self.word_to_count # storing this before reseting
        self.word_to_count = {}
        self.num_words = len(preset_tokens)
        
        for word, count in previous_word_to_count.items():
            if count >= min_count:
                self.add_word(word)
            

In [24]:
# convert unicode string to ascii string
def unicode_to_ascii(s):
    return ''.join(c for c in unicodedata.normalize('NFD', s)
                   if unicodedata.category(c) != 'Mn')

# creating function to tidy the text a bit
def fix_contractions(text):
    text = re.sub(r"i'm", "i am", text)
    text = re.sub(r"he's", "he is", text)
    text = re.sub(r"she's", "she is", text)
    text = re.sub(r"that's", "that is", text)
    text = re.sub(r"what's", "what is", text)
    text = re.sub(r"where's", "where is", text)
    text = re.sub(r"how's", "how is", text)
    text = re.sub(r"\'ll", " will", text)
    text = re.sub(r"\'ve'", " have", text)
    text = re.sub(r"\'re'", " are", text)
    text = re.sub(r"\'d", " would", text)
    text = re.sub(r"won't", "will not", text)
    text = re.sub(r"can't", "cannot", text)
    text = re.sub(r"it's", "it is", text)
    text = re.sub(r"don't", "do not", text)
    return text

# cleaning the strings
def normalize_string(s):
    s = unicode_to_ascii(s.lower().strip())
    s = re.sub(r"[^A-Za-z.!?,']", r" ", s)
    s = re.sub(r"[.]", r" . ", s)
    s = re.sub(r"[!]", r" ! ", s)
    s = re.sub(r"[?]", r" ? ", s)
    s = re.sub(r"[,]", r" , ", s)
    s = re.sub(r"\s+", r" ", s)
    s = fix_contractions(s)
    s = s.rstrip()
    return s

# construct covab and normalize strings
def read_vocs(qr_pairs):
    qr_pairs_normed = [[normalize_string(s) for s in pair] for pair in qr_pairs]
    vocab = Vocab(preset_tokens)
    return vocab, qr_pairs_normed
    
max_line_length = 20

# check if the question and response are both below the mi
def filter_pair(pair):
    return len(pair[0].split(' ')) < max_line_length and len(pair[1].split(' ')) < max_line_length

# filtering all the question and response
def filter_qr_pairs(qr_pairs):
    return [pair for pair in qr_pairs if filter_pair(pair)]

# high-level utilize other functions to prepare the data
def load_and_prepare_data(qr_pairs):
    vocab, pairs = read_vocs(qr_pairs)
    
    pairs = filter_qr_pairs(pairs)
    
    for pair in pairs:
        vocab.add_sentence(pair[0])
        vocab.add_sentence(pair[1])
    
    return vocab, pairs

In [25]:
vocab, pairs = load_and_prepare_data(qr_pairs)
for pair in pairs[:10]:
    print(pair)

['well , i thought we would start with pronunciation , if that is okay with you .', 'not the hacking and gagging and spitting part . please .']
['not the hacking and gagging and spitting part . please .', "okay . . . then how 'bout we try out some french cuisine . saturday ? night ?"]
["you're asking me out . that is so cute . what is your name again ?", 'forget it .']
["no , no , it is my fault we didn't have a proper introduction", 'cameron .']
['gosh , if only we could find kat a boyfriend . . .', 'let me see what i can do .']
["c'esc ma tete . this is my head", "right . see ? you're ready for the quiz ."]
['that is because it is such a nice one .', 'forget french .']
['how is our little find the wench a date plan progressing ?', "well , there's someone i think might be"]
['there .', 'where ?']
['you have my word . as a gentleman', "you're sweet ."]


In [26]:
# trim the words
def trim_rare_words(vocab, pairs, min_word_occurence=6):
    vocab.trim(min_word_occurence, preset_tokens)
    
    for pair in pairs:
        for i, word in enumerate(pair[0].split(' ')):
            line = pair[0].split(' ')
            if word not in vocab.word_to_index:
                line[i] = '<OUT>'
            pair[0] = ' '.join(line)    
        for i, word in enumerate(pair[1].split(' ')):
            line = pair[1].split(' ')
            if word not in vocab.word_to_index:
                line[i] = '<OUT>'
            pair[1] = ' '.join(line)  
    
    return pairs
            

In [27]:
pairs = trim_rare_words(vocab, pairs)

In [28]:
pairs[1][1]

"okay . . . then how 'bout we try out some french <OUT> . saturday ? night ?"

In [37]:
# return indices for words in the sentence in addition to EOS
def indices_from_sentence(vocab, sentence):
    return [vocab.word_to_index[word] for word in sentence.split(' ')] + [preset_tokens.index('<EOS>')]

# adds zero padding to batches
def zero_padding(indices_batch):
    return list(itertools.zip_longest(*indices_batch, fillvalue = preset_tokens.index('<PAD>')))

# 
def binary_matrix(batch):
    mat = []
    for i, seq in enumerate(batch):
        mat.append([])
        for token in seq:
            if token == preset_tokens.index('<PAD>'):
                mat[i].append(0)
            else:
                mat[i].append(1)
    return mat

# preparing input batch for input into the model
def input_var(input_batch, vocab):
    indices_batch = [indices_from_sentence(vocab, sentence) for sentence in input_batch]
    lengths = torch.tensor([len(indices) for indices in indices_batch])
    padded_batch = zero_padding(indices_batch)
    padded_var = torch.LongTensor(padded_batch)
    return padded_var, lengths

# preparing output
def output_var(output_batch, vocab):
    indices_batch = [indices_from_sentence(vocab, sentence) for sentence in output_batch]
    max_target_length = max([len(indices) for indices in indices_batch])
    padded_list = zero_padding(indices_batch)
    mask = binary_matrix(padded_list)
    mask = torch.BoolTensor(mask)
    padded_var = torch.LongTensor(padded_list)
    return padded_var, mask, max_target_length

# prepare the pairs data into batches to use as training data
def batch_to_train_data(vocab, pair_batch):
    pair_batch.sort(key=lambda x: len(x[0].split(' ')), reverse=True) # sorting the pairs by the length of the question
    input_batch, output_batch = [], []
    for pair in pair_batch:
        input_batch.append(pair[0])
        output_batch.append(pair[1])
    inp, lengths = input_var(input_batch, vocab)
    output, mask, max_target_length = output_var(output_batch, vocab)
    return inp, lengths, output, mask, max_target_length
    
test_batch_size = 5
batches = batch_to_train_data(vocab, [random.choice(pairs) for _ in range(test_batch_size)])
input_variable, lengths, target_variable, mask, max_target_length = batches

print("input_variable:", input_variable)
print("lengths:", lengths)
print("target_variable:", target_variable)
print("mask:", mask)
print("max_target_length:", max_target_length)

input_variable: tensor([[6453,  669,   46,   40,   13],
        [2010,    5, 1129,  156,   14],
        [ 106,   95,   17,   19, 4548],
        [  62,   19,   26, 3930,   17],
        [  16, 4954,    2,   33,    1],
        [  21,   83,   17,    1,    0],
        [ 153,   84,    1,    0,    0],
        [  40,    1,    0,    0,    0],
        [ 245,    0,    0,    0,    0],
        [ 203,    0,    0,    0,    0],
        [ 106,    0,    0,    0,    0],
        [ 136,    0,    0,    0,    0],
        [  17,    0,    0,    0,    0],
        [   1,    0,    0,    0,    0]])
lengths: tensor([14,  8,  7,  6,  5])
target_variable: tensor([[   6,    2,    6,   47,  265],
        [  64,  152, 1717,  771,   17],
        [  18,    1,   12,    5,   17],
        [7054,    0,    6,   47,   17],
        [ 532,    0,   63, 3930,   33],
        [   4,    0,  158,   17,    1],
        [ 348,    0,   51,    1,    0],
        [  51,    0, 1878,    0,    0],
        [   2,    0,   11,    0,    0],
        

In [38]:
target_word = [vocab.index_to_word[target.item()] for target in target_variable[:, 1]]
target_word

['<OUT>',
 '!',
 '<EOS>',
 '<PAD>',
 '<PAD>',
 '<PAD>',
 '<PAD>',
 '<PAD>',
 '<PAD>',
 '<PAD>',
 '<PAD>',
 '<PAD>',
 '<PAD>',
 '<PAD>',
 '<PAD>',
 '<PAD>']

# Seq2Seq Model

Here we will use a sequence to sequence model with an encoder and decoder RNN utilizing bi-directional Gated Recurrent Units (GRUs) and an attention mechanism

In [39]:
class EncoderRNN(nn.Module):
    def __init__(self, hidden_size, embedding, n_layers=1, dropout=0):
        super(EncoderRNN, self).__init__()
        self.n_layers = n_layers
        self.hidden_size = hidden_size
        self.embedding = embedding
        self.gru = nn.GRU(hidden_size, hidden_size, n_layers,
                          dropout=(0 if n_layers==1 else dropout),
                          bidirectional=True)
        
    # forward pass through the model
    def forward(self, input_seq, input_lengths, hidden=None):
        embedded = self.embedding(input_seq) # convert word indices to embeddings
        packed = nn.utils.rnn.pack_padded_sequence(embedded, input_lengths) # pack padded batch sequences for RNN module
        outputs, hidden = self.gru(packed, hidden) # forward pass through GRU
        outputs, _ = nn.utils.rnn.pad_packed_sequence(outputs) # unpack padding
        outputs = outputs[:, :, :self.hidden_size] + outputs[:, :, self.hidden_size:] # summing both of the bidirection GRU outputs
        return outputs, hidden

In [40]:
class Attention(nn.Module):
    def __init__(self, method, hidden_size):
        super(Attention, self).__init__()
        self.method = method
        
        # account for invalid attention score function
        if self.method not in ['dot', 'general', 'concat']:
            raise ValueError(self.method, "is not a valid attention score function")
        self.hidden_size = hidden_size
        if self.method == 'general':
            self.attn = nn.Linear(self.hidden_size, hidden_size)
        elif self.method == 'concat':
            self.attn = nn.Linear(self.hidden_size * 2, hidden_size)
            self.v = nn.Parameter(torch.FloatTensor(hidden_size))
            
    def dot_score(self, hidden, encoder_output):
        return torch.sum(hidden * encoder_output, dim=2)
    
    def general_score(self, hidden, encoder_output):
        energy = self.attn(encoder_output)
        return torch.sum(hidden * energy, dim=2)
    
    def concat_score(self, hidden, encoder_output):
        energy = self.attn(torch.cat((hidden.expand(encoder_output.size(0), -1, -1), encoder_output), 2)).tanh()
        return torch.sum(self.v * energy, dim=2)
    
    def forward(self, hidden, encoder_outputs):
        # calculate the attention weight with the given method
        if self.method == 'general':
            attn_energies = self.general_score(hidden, encoder_outputs)
        if self.method == 'concat':
            attn_energies = self.concat_score(hidden, encoder_outputs)
        if self.method == 'dot':
            attn_energies = self.dot_score(hidden, encoder_outputs)
            
        return F.softmax(attn_energies.t(), dim=1).unsqueeze(1)

In [41]:
class AttnDecoderRNN(nn.Module):
    def __init__(self, attn_model, embedding, hidden_size, output_size, n_layers=1, dropout=0.1):
        super(AttnDecoderRNN, self).__init__()
        
        self.attn_model = attn_model
        self.embedding = embedding
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.n_layers = n_layers
        self.dropout = dropout

        self.embedding_dropout = nn.Dropout(dropout)
        self.gru = nn.GRU(hidden_size, hidden_size, n_layers, dropout(0 if n_layers==1 else dropout))
        self.concat = nn.Linear(hidden_size * 2, hidden_size)
        self.out = nn.Linear(hidden_size, output_size)
        
        self.attn = Attention(attn_model, hidden_size)
    
    def forward(self, input_step, last_hidden, encoder_outputs):
        embedded = self.embedding(input_step)
        embedded = self.embedding_dropout(embedded)
        
        rnn_output, hidden = self.gru(embedded, last_hidden) # forward through unidirectional gru
        
        attn_weights = self.attn(rnn_output, encoder_outputs) # calculate the attention weights from last gru ouput
        
        context = attn_weights.bmm(encoder_outputs.transpose(0, 1)) # get weighted sum as context layer
        
        rnn_output = rnn_output.squeeze(0)
        context = context.squeeze(1)
        concat_input = torch.cat((rnn_output, context), 1)
        concat_output = torch.tanh(self.concat(concat_input))
        
        output = self.out(concat_output)
        output = F.softmax(output, dim=1)
        
        return output, hidden

# Training the Model

In [42]:
# caculating the negative log loss of padded sequences
def mask_nll_loss(inp, target, mask):
    n_total = mask.sum()
    cross_entropy = -torch.log(torch.gather(inp, 1, target.view(-1, 1)).squeeze(1))
    loss = cross_entropy.masked_select(mask).mean() # applying the mask to account for padding
    loss = loss.to(device)
    return loss, n_total.item()

In [49]:
# train the model
def train(input_variable, lengths, target_variable, mask, max_target_length, encoder,
          decoder, embedding, encoder_optimizer, decoder_optimizer, batch_size, clip, max_length=max_line_length):
    
    # clear the gradients
    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()
    
    # set variables to cuda
    input_variable = input_variable.to(device)
    target_variable = target_variable.to(device)
    mask = mask.to(device)
    
    lengths = lengths.to('cpu')
    
    loss = 0
    losses = []
    n_totals = 0
    
    # forward pass through the encoder
    encoder_outputs, encoder_hidden = encoder(input_variable, lengths)
    
    # create initial decoder input
    decoder_input = torch.LongTensor([[preset_tokens.index('<SOS>') for _ in range(batch_size)]])
    decoder_input = decoder_input.to(device)
    
    # set initial decoder hidden state to encoder's final hidden state
    decoder_hidden = encoder_hidden[:decoder.n_layers]
    
    # decide on using teacher forcing in this iteration
    use_teacher_forcing = True if random.random() < tracher_forcing_ratio else False
    
    # forward batch of sequences through the decoder one step at a time
    if use_teacher_forcing:
        for target in range(max_target_lengths):
            decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden, encoder_outputs)
            decoder_input = target_variable[target].view(1, -1) # use teacher forcing
            mask_loss, n_total = mask_nll_loss(decoder_output, target_variable[target], mask[target])
            loss += mask_loss
            losses.append(mask_loss.item() * n_total)
            n_totals += n_total
    else:
        for target in range(max_target_lengths):
            decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden, encoder_outputs)
            _, topi = decoder_output.topk(1)
            decoder_input = torch.LongTensor([[topi[i][0] for i in range(batch_size)]])
            decoder_input = decoder_input.to(device)
            mask_loss, n_total = mask_nll_loss(decoder_output, target_variable[target], mask[target])
            loss += mask_loss
            losses.append(mask_loss.item() * n_total)
            n_totals += n_total
    
    loss.backward() # backward propogation
    
    # gradient clipping
    _ = nn.utils.clip_grad_norm_(encoder.parameters(), clip)
    _ = nn.utils.clip_grad_norm_(decoder.parameters(), clip)
    
    # adjust the model weights
    encoder_optimizer.step()
    decoder_optimizer.step()
    
    return sum(losses) / n_totals

In [48]:
def run_train_iters(model_name, vocab, pairs, encoder, decoder, encoder_optimizer, decoder_optimizer,
                    embedding, encoder_n_layers, decoder_n_layers, save_dir, n_iteration, batch_size, 
                    print_every, save_every, clip, load_filename):
    
    # load batches for each iteration
    training_batches = [batch_to_train_data(vocab, [random.choice(pairs) for _ in range(batch_size)])
                        for _ in range(n_iteration)]
    
    start_iteration = 1
    losses = 0
    if load_filename:
        start_iteration = checkpoint['iteration'] + 1
        
    # training loop
    print('Training...')
    for iteration in range(start_iteration, n_iteration + 1):
        training_batch = training_batches[iteration - 1]
        input_variable, lengths, target_variable, mask, max_target_length = training_batch
        
        loss = train(input_variable, lengths, target_variable, mask, max_target_length, encoder,
                     decoder, embedding, encoder_optimizer, decoder_optimizer, batch_size, clip)
        losses += loss
        
        # print training progress
        if iteration % print_every == 0:
            print_loss_avg = losses / print_every
            print("Iteration: {}; Percent complete: {:.1f}%; Average loss: {:.4f}".format(iteration, iteration/n_iteration*100, print_loss_avg))
            losses = 0
            
        # save checkpoint
        if iteration % save_every == 0:
            save_directory = os.path.join(save_dir, '', '{}-{}-{}'.format(encoder_n_layers, decoder_n_layers, hidden_size))
            if not os.path.exists(directory):
                os.makedirs(directory)
            torch.save({
                'iteration': iteration,
                'en': encoder.state_dict(),
                'de': decoder.state_dict(),
                'en_opt': encoder_optimizer.state_dict(),
                'de_opt': decoder_optimizer.sate_dict(),
                'loss': loss,
                'vocab_dict': vocab.__dict__,
                'embedding': embedding.state_dict()
            }, os.path.join(directory, '{}_checkpoint.tar'.format(iteration)))

In [52]:
class GreedySearchDecoder(nn.Module):
    def __init__(self, encoder, decoder):
        super(GreedySearchDecoder, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        
    def forward(self, input_seq, input_length, max_length):
        # forward input through encoder
        encoder_outputs, encoder_hidden = self.encoder(input_seq, input_length)
        
        decoder_hidden = encoder_hidden[:decoder.n_layers]
        
        decoder_input = torch.ones(1, 1, device=device, dtype=torch.long) * preset_tokens.index('<SOS>')
        
        all_tokens = torch.zeros([0], device=device, dtype=torch.long)
        all_score = token.zeros([0], device)
        
        for _ in range(max_length):
            decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden,
                                                          encoder_outputs)
            decoder_scores, decoder_input = torch.max(decoder_output, dim=1)
            
            all_tokens = torch.cat((all_tokens, decoder_input), dim=0)
            all_scores = torch.cat((all_scores, decoder_scores), dim=0)
            
            decoder_input = torch.unsqueeze(decoder_input, 0)
            
        return all_tokens, all_scores