In [14]:
import random

import torch
import torch.nn as nn

In [6]:
def maskNLLLoss(inp, target, mask):
    """Average negative log-likelihood over the unmasked positions.

    Args:
        inp: 2-D tensor; column ``target[i]`` of row ``i`` is read as the
            probability assigned to the target class (the value is passed
            straight to ``torch.log``, so it is presumably already normalised
            -- TODO confirm against the decoder's output layer).
        target: tensor of class indices, flattened to one index per row.
        mask: per-row mask; truthy entries mark rows that count toward the loss.

    Returns:
        ``(loss, n_total)`` where ``loss`` is a scalar tensor (on the same
        device as ``inp``) and ``n_total`` is the number of unmasked rows as a
        Python number. If no row is selected the mean of an empty tensor is
        NaN, exactly as ``Tensor.mean`` defines it.
    """
    nTotal = mask.sum()
    # Pick, for every row, the probability of its target class.
    picked = torch.gather(inp, 1, target.view(-1, 1)).squeeze(1)
    crossEntropy = -torch.log(picked)
    # masked_select requires a boolean mask; coerce so byte/int masks also work.
    loss = crossEntropy.masked_select(mask.bool()).mean()
    # The result already lives on inp's device, so no explicit move is needed
    # (the previous version moved it to a notebook-global `device`).
    return loss, nTotal.item()

In [7]:
MAX_LENGTH=10
def train(input_variable,lengths,target_variable,mask,max_target_len,encoder,decoder,embedding,
          encoder_optimizer,decoder_optimizer,batch_size,clip,max_length=MAX_LENGTH):
    """Run one optimisation step on a single batch and return its average loss.

    The encoder is run once over the batch, then the decoder is stepped
    ``max_target_len`` times. With teacher forcing the next decoder input is
    the ground-truth token for the current step; otherwise it is the decoder's
    own top-1 prediction.

    Args:
        input_variable, lengths: passed to ``encoder(input_variable, lengths)``.
        target_variable, mask: indexed per time step ``t`` (``[t]``) to get the
            targets and loss mask for that step.
        max_target_len: number of decoder steps to run.
        encoder, decoder: the models; ``decoder.n_layers`` selects how many
            leading entries of the encoder hidden state seed the decoder.
        embedding: unused here; kept for interface compatibility.
        encoder_optimizer, decoder_optimizer: optimisers stepped once each.
        batch_size: number of sequences in the batch.
        clip: max gradient norm for clipping.
        max_length: unused here; kept for interface compatibility.

    Returns:
        Sum of per-step losses weighted by their unmasked-token counts,
        divided by the total unmasked-token count.

    Relies on notebook globals ``device``, ``START_Token``,
    ``teacher_forcing_ration`` and ``maskNLLLoss``.
    """
    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()

    input_variable = input_variable.to(device)
    lengths = lengths.to(device)
    target_variable = target_variable.to(device)
    mask = mask.to(device)

    loss = 0
    print_losses = []
    n_totals = 0

    encoder_outputs, encoder_hidden = encoder(input_variable, lengths)

    # One start token per batch element, laid out as a single time step.
    # NOTE(review): GreedySearchDecoder spells this global `Start_Token`;
    # confirm which spelling is actually defined.
    decoder_input = torch.LongTensor([[START_Token for _ in range(batch_size)]])
    decoder_input = decoder_input.to(device)

    # random.random must be *called*; comparing the function object itself to
    # a number raises TypeError.
    # NOTE(review): `teacher_forcing_ration` looks like a typo for
    # `teacher_forcing_ratio` -- confirm against the cell that defines it.
    use_teacher_forcing = random.random() < teacher_forcing_ration

    decoder_hidden = encoder_hidden[:decoder.n_layers]

    for t in range(max_target_len):
        decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden, encoder_outputs)

        if use_teacher_forcing:
            # Feed the ground-truth token for this step as the next input.
            decoder_input = target_variable[t].view(1, -1)
        else:
            # Feed the decoder's own best guess. The indices are sliced on
            # their current device instead of being rebuilt element by element.
            _, topi = decoder_output.topk(1)
            decoder_input = topi[:batch_size, 0].detach().reshape(1, -1)

        mask_loss, nTotal = maskNLLLoss(decoder_output, target_variable[t], mask[t])
        loss += mask_loss
        print_losses.append(mask_loss.item() * nTotal)
        n_totals += nTotal

    loss.backward()

    # clip_grad_norm_ is the in-place, non-deprecated spelling.
    nn.utils.clip_grad_norm_(encoder.parameters(), clip)
    nn.utils.clip_grad_norm_(decoder.parameters(), clip)

    encoder_optimizer.step()
    decoder_optimizer.step()

    return sum(print_losses) / n_totals
    


In [8]:
def trainiters(model_name,voc,pairs,encoder,decoder,encoder_optimizer,decoder_optimizer,
               embedding,encoder_n_layers,decoder_n_layers,save_dir,n_iteration,batch_size,print_every,
               save_every,corpus_name,loadFileName):
    """Run ``train`` for ``n_iteration`` iterations on randomly sampled batches.

    One batch per iteration is pre-built by sampling ``batch_size`` pairs (with
    replacement) from ``pairs`` and converting them with ``batch2TrainData``.
    Every ``print_every`` iterations the average loss since the last report is
    printed.

    Args:
        voc, pairs: vocabulary and sentence pairs handed to ``batch2TrainData``.
        encoder, decoder, encoder_optimizer, decoder_optimizer, embedding:
            forwarded to ``train``.
        n_iteration: index of the last iteration to run (iterations are
            1-based and inclusive).
        batch_size: number of pairs per batch.
        print_every: reporting interval, in iterations.
        loadFileName: if truthy, training resumes after the iteration stored
            in the notebook-global ``checkpoint``.
        model_name, encoder_n_layers, decoder_n_layers, save_dir, save_every,
            corpus_name: unused here; kept for interface compatibility.

    Relies on notebook globals ``clip``, ``checkpoint`` (only when resuming),
    ``batch2TrainData`` and ``train``.
    """
    # Use the `n_iteration` parameter here; `n_iterations` was an undefined name.
    training_batches = [
        batch2TrainData(voc, [random.choice(pairs) for _ in range(batch_size)])
        for _ in range(n_iteration)
    ]

    start_iteration = 1
    print_loss = 0

    if loadFileName:
        # NOTE(review): `checkpoint` is not created in this function; it must
        # already exist as a notebook global when resuming.
        start_iteration = checkpoint['iteration'] + 1

    # Iterations are 1-based, so the upper bound is inclusive; stopping at
    # `n_iteration` would silently drop the final batch.
    for iteration in range(start_iteration, n_iteration + 1):
        training_batch = training_batches[iteration - 1]

        input_variable, lengths, target_variable, mask, max_target_len = training_batch

        # NOTE(review): `clip` is read from notebook globals, not a parameter.
        loss = train(input_variable, lengths, target_variable, mask, max_target_len,
                     encoder, decoder, embedding,
                     encoder_optimizer, decoder_optimizer, batch_size, clip)

        print_loss += loss

        if iteration % print_every == 0:
            print_loss_avg = print_loss / print_every
            print("Iteration: "+str(iteration)+"Loss: "+str(print_loss_avg))
            print_loss = 0
        

In [12]:
# Scratch check: build a zero-length 1-D tensor and show how it prints.
x = torch.zeros(0)
print(x)

tensor([])


In [16]:
class GreedySearchDecoder(nn.Module):
    """Greedy decoding wrapper around an encoder/decoder pair.

    At every step the highest-scoring token of the decoder output is chosen
    and fed back as the next decoder input.
    """

    def __init__(self,encoder,decoder):
        """Store the encoder and decoder modules used by ``forward``."""
        super().__init__()

        self.encoder=encoder
        self.decoder=decoder

    def forward(self,input_seq,input_length,max_length):
        """Decode ``max_length`` tokens greedily for one input sequence.

        Args:
            input_seq: input passed to the encoder; its device is used for
                every tensor created here.
            input_length: sequence lengths passed to the encoder.
            max_length: number of decoding steps to run.

        Returns:
            ``(all_tokens, all_scores)``: 1-D tensors holding the chosen token
            index and its score for every step.
        """
        run_device = input_seq.device

        # Pass the `input_length` parameter; `seq_length` was an undefined name.
        encoder_outputs,encoder_hidden=self.encoder(input_seq,input_length)

        decoder_hidden=encoder_hidden[:self.decoder.n_layers]
        # NOTE(review): `train` spells this global `START_Token`; confirm which
        # spelling is actually defined in the notebook.
        decoder_input=torch.ones(1,1,device=run_device,dtype=torch.long)*Start_Token

        all_tokens=torch.zeros([0],device=run_device,dtype=torch.long)
        all_scores=torch.zeros([0],device=run_device)

        for _ in range(max_length):

            decoder_output,decoder_hidden=self.decoder(decoder_input,decoder_hidden,encoder_outputs)

            # Best token and its score along dim 1 (assumes decoder_output is
            # (1, vocab) -- TODO confirm against the decoder).
            decoder_scores,decoder_input=torch.max(decoder_output,dim=1)
            all_scores=torch.cat((all_scores,decoder_scores),dim=0)
            all_tokens=torch.cat((all_tokens,decoder_input),dim=0)

            # torch.unsqueeze is not in-place: keep the result so the next
            # decoder input is 2-D again, like the initial start-token input.
            decoder_input=torch.unsqueeze(decoder_input,0)

        return all_tokens, all_scores
        
        

In [18]:
def evaluate(encoder, decoder, searcher,voc,sentence,max_length=MAX_LENGTH):
    """Decode a reply for one sentence and return it as a list of words.

    Args:
        encoder, decoder: unused here (the ``searcher`` does the decoding);
            kept for interface compatibility.
        searcher: callable invoked as ``searcher(input_batch, lengths,
            max_length)`` that returns ``(tokens, scores)``.
        voc: vocabulary; ``voc.index2word`` maps token indices to words and
            ``indexesFromSentence(voc, sentence)`` maps the sentence to indices.
        sentence: the input sentence.
        max_length: number of decoding steps requested from the searcher.

    Returns:
        List of decoded words, one per returned token.

    Relies on notebook globals ``device`` and ``indexesFromSentence``.
    """
    # A batch containing the single indexed sentence.
    index_batch=[indexesFromSentence(voc,sentence)]
    lengths=torch.tensor([len(index) for index in index_batch])
    # Swap the two axes so the sequence axis comes first.
    input_batch=torch.LongTensor(index_batch).transpose(0,1)

    input_batch=input_batch.to(device)
    lengths=lengths.to(device)

    tokens, scores=searcher(input_batch,lengths,max_length)
    decoded_words=[voc.index2word[token.item()] for token in tokens]
    # Return the list built above; `decoder_words` was an undefined name.
    return decoded_words

def evaluateInput(encoder,decoder,searcher,voc):
    """Interactive chat loop: read a line, decode a reply, print it.

    Typing ``q`` or ``quit`` ends the loop. Each input is normalised with
    ``normalizeString`` and decoded with ``evaluate``; ``PAD`` and ``EOS``
    entries are dropped from the reply before printing. A ``KeyError`` raised
    while handling an input is reported as "Unknown Word" and the loop
    continues.
    """
    quit_commands = ('q', 'quit')
    skipped_words = ("PAD", "EOS")

    while True:
        try:
            input_sentence = input('Human> ')

            if input_sentence in quit_commands:
                break

            cleaned = normalizeString(input_sentence)
            reply = evaluate(encoder, decoder, searcher, voc, cleaned)
            # Strip the padding / end-of-sentence markers from the reply.
            reply = [word for word in reply if word not in skipped_words]
            print("Bot:", " ".join(reply))

        except KeyError:
            print("Unknown Word")
            
            
    