# Sequence to Sequence Chatbot

In [1]:
from __future__ import unicode_literals, print_function, division
from io import open
import unicodedata
import re
import random

import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F

import numpy as np
from torch.utils.data import TensorDataset, DataLoader, RandomSampler

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Special Tokens

In [2]:
SOS_token = 0
EOS_token = 1
UNK_TOKEN = 2

## Tokenizer

In [3]:
class QuestionAnswer:
    def __init__(self, name):
        self.name = name
        self.word2index = {"<SOS>": 0, "<EOS>": 1, "<UNK>": 2}
        self.word2count = {}
        self.index2word = {0: "<SOS>", 1: "<EOS>", 2: "<UNK>"}
        self.n_words = 3
        
    def addSentence(self, sentence):
        for word in sentence.split():
            self.addWord(word)
            
    def addWord(self, word):
        if word not in self.word2index:
            self.word2index[word] = self.n_words # for new word assign index starting from 2
            self.word2count[word] = 1 # count the word
            self.index2word[self.n_words] = word # assign word to index for mapping
            self.n_words += 1 # increase the count of words
        else:
            self.word2count[word] += 1 # if word already exists, increase the count

## PreProcessing

In [4]:
def unicodeToAscii(s):
    """Converts the unicode string to ascii"""
    return ''.join(
        c for c in unicodedata.normalize('NFD', s)
        if unicodedata.category(c) != 'Mn'
    )

# Lowercase, trim, and remove non-letter characters
def normalizeString(s):
    """lowercase, trim and remote non letter characters"""
    s = unicodeToAscii(s.lower().strip())
    s = re.sub(r"([.!?])", r" \1", s)
    s = re.sub(r"[^a-zA-Z!?]+", r" ", s)
    return s.strip()

In [5]:
normalizeString("aa123!s'ss?")

'aa !s ss ?'

## Dataset

In [6]:
import pandas as pd

data = pd.read_csv("./greetings.csv")

data.sample(10)

Unnamed: 0,id,question,answer,intent
480,481,Hi,Hi! Ke chahincha? Ma madat garna tayar chhu.,greetings
209,210,"Hello, tapai kasto chha?","Hello! Ma thikai chhu, tapai lai kasari madat ...",greetings
55,56,"Hello, tapai kasto chha?","Hello! Ma thikai chhu, tapai lai kasari madat ...",greetings
484,485,Namaste,Namaste! Ma tapai ko madat kasari garna sakchu?,greetings
94,95,Tapai ko din kasto chha?,"Din ramro chha, tapai lai kasari madat garna s...",greetings
33,34,"Good afternoon, kasari chha?","Good afternoon! Ma sanchai chhu, tapai lai ke ...",greetings
390,391,"Namaste, sab thikai chha?","Namaste! Sab thikai chha, tapai lai kasari mad...",greetings
134,135,K cha tapai?,"Sab thikai chha, tapai lai kasari madat garna ...",greetings
234,235,Tapai ko din kasto chha?,"Din ramro chha, tapai lai kasari madat garna s...",greetings
131,132,"Namaste, sab thikai chha?","Namaste! Sab thikai chha, tapai lai kasari mad...",greetings


In [7]:
questions = data["question"].values
answers = data["answer"].values

In [8]:
questions[:10], answers[:10]

(array(['Namaste', 'Hello', 'K cha?', 'Kasto chha?', 'Hi', 'Good morning',
        'Good afternoon', 'Good evening', 'Sanchai?', 'K chha?'],
       dtype=object),
 array(['Namaste! Ma tapai lai kasari madat garna sakchu?',
        'Hello! Ma tapai ko lagi ke garna sakchu?',
        'Thikai chha, tapai lai kasari madat garna sakchu?',
        'Ma thikai chhu, tapai lai ke chahincha?',
        'Hi! Tapai lai ma ke madat garna sakchu?',
        'Good morning! Tapai ko lagi ma ke garna sakchu?',
        'Good afternoon! Ma tapai lai kasari madat garna sakchu?',
        'Good evening! Tapai ko lagi ma ke garna sakchu?',
        'Sanchai chhu, tapai lai kasari madat garna sakchu?',
        'Sab thikai chha, tapai lai kasari madat garna sakchu?'],
       dtype=object))

In [9]:
question_data_class = QuestionAnswer("question")
answer_data_class = QuestionAnswer("answer")

In [10]:
question_data_class.n_words

3

In [11]:
answer_data_class.n_words

3

In [12]:
def prepareData(question: list[str], answer: list[str]):
    pairs = []
    for q, a in zip(question, answer):
        # normalize first
        ques = normalizeString(q)
        ans = normalizeString(a)
        question_data_class.addSentence(ques)
        answer_data_class.addSentence(ans)
        pairs.append([q, a])
    print("Question and Answer", question_data_class.n_words, answer_data_class.n_words)
    return question_data_class, answer_data_class, pairs

In [13]:
question_data_class, answer_data_class, pairs = prepareData(questions, answers)

Question and Answer 39 39


In [14]:
pairs

[['Namaste', 'Namaste! Ma tapai lai kasari madat garna sakchu?'],
 ['Hello', 'Hello! Ma tapai ko lagi ke garna sakchu?'],
 ['K cha?', 'Thikai chha, tapai lai kasari madat garna sakchu?'],
 ['Kasto chha?', 'Ma thikai chhu, tapai lai ke chahincha?'],
 ['Hi', 'Hi! Tapai lai ma ke madat garna sakchu?'],
 ['Good morning', 'Good morning! Tapai ko lagi ma ke garna sakchu?'],
 ['Good afternoon', 'Good afternoon! Ma tapai lai kasari madat garna sakchu?'],
 ['Good evening', 'Good evening! Tapai ko lagi ma ke garna sakchu?'],
 ['Sanchai?', 'Sanchai chhu, tapai lai kasari madat garna sakchu?'],
 ['K chha?', 'Sab thikai chha, tapai lai kasari madat garna sakchu?'],
 ['Ma tapai lai kasari madat garna sakchu?',
  'Namaste! Tapai ko lagi ma ke garna sakchu?'],
 ['Ke chha k k?', 'Sab thikai chha, tapai lai kasari madat garna sakchu?'],
 ['Tapai lai ke cha?', 'Sab thikai chha, tapai ko lagi ma ke garna sakchu?'],
 ['Tapai lai kasari madat garna sakchu?', 'Tapai lai k chahincha?'],
 ['Sanchai hunuhunchha

In [15]:
question_data_class.word2index

{'<SOS>': 0,
 '<EOS>': 1,
 '<UNK>': 2,
 'namaste': 3,
 'hello': 4,
 'k': 5,
 'cha': 6,
 '?': 7,
 'kasto': 8,
 'chha': 9,
 'hi': 10,
 'good': 11,
 'morning': 12,
 'afternoon': 13,
 'evening': 14,
 'sanchai': 15,
 'ma': 16,
 'tapai': 17,
 'lai': 18,
 'kasari': 19,
 'madat': 20,
 'garna': 21,
 'sakchu': 22,
 'ke': 23,
 'hunuhunchha': 24,
 'chhu': 25,
 'sab': 26,
 'thikai': 27,
 'chahincha': 28,
 'namaskar': 29,
 'din': 30,
 'bhayo': 31,
 'ko': 32,
 'day': 33,
 'hunu': 34,
 'hunchha': 35,
 'sabai': 36,
 '!': 37,
 'aaj': 38}

In [16]:
answer_data_class.word2index

{'<SOS>': 0,
 '<EOS>': 1,
 '<UNK>': 2,
 'namaste': 3,
 '!': 4,
 'ma': 5,
 'tapai': 6,
 'lai': 7,
 'kasari': 8,
 'madat': 9,
 'garna': 10,
 'sakchu': 11,
 '?': 12,
 'hello': 13,
 'ko': 14,
 'lagi': 15,
 'ke': 16,
 'thikai': 17,
 'chha': 18,
 'chhu': 19,
 'chahincha': 20,
 'hi': 21,
 'good': 22,
 'morning': 23,
 'afternoon': 24,
 'evening': 25,
 'sanchai': 26,
 'sab': 27,
 'k': 28,
 'pani': 29,
 'namaskar': 30,
 'din': 31,
 'ramro': 32,
 'bhayo': 33,
 'mero': 34,
 'day': 35,
 'sahayog': 36,
 'tayar': 37,
 'hola': 38}

In [17]:
question_data_class.index2word

{0: '<SOS>',
 1: '<EOS>',
 2: '<UNK>',
 3: 'namaste',
 4: 'hello',
 5: 'k',
 6: 'cha',
 7: '?',
 8: 'kasto',
 9: 'chha',
 10: 'hi',
 11: 'good',
 12: 'morning',
 13: 'afternoon',
 14: 'evening',
 15: 'sanchai',
 16: 'ma',
 17: 'tapai',
 18: 'lai',
 19: 'kasari',
 20: 'madat',
 21: 'garna',
 22: 'sakchu',
 23: 'ke',
 24: 'hunuhunchha',
 25: 'chhu',
 26: 'sab',
 27: 'thikai',
 28: 'chahincha',
 29: 'namaskar',
 30: 'din',
 31: 'bhayo',
 32: 'ko',
 33: 'day',
 34: 'hunu',
 35: 'hunchha',
 36: 'sabai',
 37: '!',
 38: 'aaj'}

In [18]:
answer_data_class.index2word

{0: '<SOS>',
 1: '<EOS>',
 2: '<UNK>',
 3: 'namaste',
 4: '!',
 5: 'ma',
 6: 'tapai',
 7: 'lai',
 8: 'kasari',
 9: 'madat',
 10: 'garna',
 11: 'sakchu',
 12: '?',
 13: 'hello',
 14: 'ko',
 15: 'lagi',
 16: 'ke',
 17: 'thikai',
 18: 'chha',
 19: 'chhu',
 20: 'chahincha',
 21: 'hi',
 22: 'good',
 23: 'morning',
 24: 'afternoon',
 25: 'evening',
 26: 'sanchai',
 27: 'sab',
 28: 'k',
 29: 'pani',
 30: 'namaskar',
 31: 'din',
 32: 'ramro',
 33: 'bhayo',
 34: 'mero',
 35: 'day',
 36: 'sahayog',
 37: 'tayar',
 38: 'hola'}

## Sequence to Sequence Model Architecture

#### Encoder

In [19]:
class Encoder(torch.nn.Module):
    def __init__(self, input_size, hidden_size, dropout_p=0.1):
        super().__init__()
        self.hidden_size = hidden_size
        
        self.embedding = torch.nn.Embedding(input_size, hidden_size) # vocab_size, embedding_dim
        self.gru = torch.nn.GRU(hidden_size, hidden_size, batch_first=True)
        self.dropout = torch.nn.Dropout(dropout_p)
        
    def forward(self, input):
        embedded = self.dropout(self.embedding(input))
        output, hidden = self.gru(embedded)
        return output, hidden

In [20]:
encoder_ = Encoder(question_data_class.n_words, 256)
encoder_ = encoder_.to(device)

output, hidden = encoder_(torch.tensor([[1, 2, 3, 4, 5]]).to(device))

In [21]:
output.shape, hidden.shape

(torch.Size([1, 5, 256]), torch.Size([1, 1, 256]))

#### Decoder

In [22]:
torch.empty(4, 1).fill_(99)

tensor([[99.],
        [99.],
        [99.],
        [99.]])

In [23]:
MAX_SEQUENCE_LENGTH = 30

In [24]:
class Decoder(torch.nn.Module):
    """
    Decoder architecture
    
    Args:
    hidden_size: int, hidden size of the model
    output_size: int, output size of the model
    """
    def __init__(self, hidden_size, output_size):
        super().__init__()
        self.embedding = torch.nn.Embedding(output_size, hidden_size)
        self.gru = torch.nn.GRU(hidden_size, hidden_size, batch_first=True)
        self.out = torch.nn.Linear(hidden_size, output_size)
        
    def forward(self, encoder_outputs, encoder_hidden, target_tensor=None):
        batch_size = encoder_outputs.size(0)
        decoder_input = torch.empty(batch_size, 1, dtype=torch.long, device=device).fill_(SOS_token)
        decoder_hidden = encoder_hidden
        decoder_outputs = []
        
        for i in range(MAX_SEQUENCE_LENGTH):
            decoder_output, decoder_hidden = self.forward_step(decoder_input, decoder_hidden)
            decoder_outputs.append(decoder_output)
            
            if target_tensor is not None:
                # teacher forcing
                decoder_input = target_tensor[:, i].unsqueeze(1)
            else:
                _, topi = decoder_output.topk(1)
                decoder_input = topi.squeeze(-1).detach()
                
        decoder_outputs = torch.cat(decoder_outputs, dim=1)
        decoder_outputs = F.log_softmax(decoder_outputs, dim=-1)
        
        return decoder_outputs, decoder_hidden, None # return non for consistency
    
    def forward_step(self, input, hidden):
        output = self.embedding(input)
        output = F.relu(output)
        output, hidden = self.gru(output, hidden)
        output = self.out(output)
        
        return output, hidden
        

## Attention

In [25]:
class BahdanauAttention(nn.Module):
    def __init__(self, hidden_size):
        super(BahdanauAttention, self).__init__()
        self.Wa = nn.Linear(hidden_size, hidden_size)
        self.Ua = nn.Linear(hidden_size, hidden_size)
        self.Va = nn.Linear(hidden_size, 1)

    def forward(self, query, keys):
        scores = self.Va(torch.tanh(self.Wa(query) + self.Ua(keys)))
        scores = scores.squeeze(2).unsqueeze(1)

        weights = F.softmax(scores, dim=-1)
        context = torch.bmm(weights, keys)

        return context, weights

class AttnDecoderRNN(nn.Module):
    def __init__(self, hidden_size, output_size, dropout_p=0.1):
        super(AttnDecoderRNN, self).__init__()
        self.embedding = nn.Embedding(output_size, hidden_size)
        self.attention = BahdanauAttention(hidden_size)
        self.gru = nn.GRU(2 * hidden_size, hidden_size, batch_first=True)
        self.out = nn.Linear(hidden_size, output_size)
        self.dropout = nn.Dropout(dropout_p)

    def forward(self, encoder_outputs, encoder_hidden, target_tensor=None):
        batch_size = encoder_outputs.size(0)
        decoder_input = torch.empty(batch_size, 1, dtype=torch.long, device=device).fill_(SOS_token)
        decoder_hidden = encoder_hidden
        decoder_outputs = []
        attentions = []

        for i in range(MAX_SEQUENCE_LENGTH):
            decoder_output, decoder_hidden, attn_weights = self.forward_step(
                decoder_input, decoder_hidden, encoder_outputs
            )
            decoder_outputs.append(decoder_output)
            attentions.append(attn_weights)

            if target_tensor is not None:
                # Teacher forcing: Feed the target as the next input
                decoder_input = target_tensor[:, i].unsqueeze(1) # Teacher forcing
            else:
                # Without teacher forcing: use its own predictions as the next input
                _, topi = decoder_output.topk(1)
                decoder_input = topi.squeeze(-1).detach()  # detach from history as input

        decoder_outputs = torch.cat(decoder_outputs, dim=1)
        decoder_outputs = F.log_softmax(decoder_outputs, dim=-1)
        attentions = torch.cat(attentions, dim=1)

        return decoder_outputs, decoder_hidden, attentions


    def forward_step(self, input, hidden, encoder_outputs):
        embedded =  self.dropout(self.embedding(input))

        query = hidden.permute(1, 0, 2)
        context, attn_weights = self.attention(query, encoder_outputs)
        input_gru = torch.cat((embedded, context), dim=2)

        output, hidden = self.gru(input_gru, hidden)
        output = self.out(output)

        return output, hidden, attn_weights

## Training

In [26]:
question_data_class.word2index

{'<SOS>': 0,
 '<EOS>': 1,
 '<UNK>': 2,
 'namaste': 3,
 'hello': 4,
 'k': 5,
 'cha': 6,
 '?': 7,
 'kasto': 8,
 'chha': 9,
 'hi': 10,
 'good': 11,
 'morning': 12,
 'afternoon': 13,
 'evening': 14,
 'sanchai': 15,
 'ma': 16,
 'tapai': 17,
 'lai': 18,
 'kasari': 19,
 'madat': 20,
 'garna': 21,
 'sakchu': 22,
 'ke': 23,
 'hunuhunchha': 24,
 'chhu': 25,
 'sab': 26,
 'thikai': 27,
 'chahincha': 28,
 'namaskar': 29,
 'din': 30,
 'bhayo': 31,
 'ko': 32,
 'day': 33,
 'hunu': 34,
 'hunchha': 35,
 'sabai': 36,
 '!': 37,
 'aaj': 38}

In [27]:
def convert_to_ids(type: QuestionAnswer, sentence: str):
    ids = []
    for word in sentence.split():
        try:
            ids.append(type.word2index[word])
        except KeyError:
            ids.append(type.word2index["<UNK>"])
            
    return ids

In [28]:
convert_to_ids(question_data_class, "Hello how are you")

[2, 2, 2, 2]

In [29]:
def tensor_from_sentence(type: QuestionAnswer, sentence: str):
    ids = convert_to_ids(type, sentence)
    ids.append(EOS_token)
    return torch.tensor(ids, dtype=torch.long, device=device).view(1, -1)

In [30]:
tensor_from_sentence(question_data_class, "what you do?")

tensor([[2, 2, 2, 1]])

In [31]:
def tensor_from_pair(pair):
    """
    Takes question answer pair and return ids
    """
    input_tensor = tensor_from_sentence(question_data_class, pair[0])
    target_tensor = tensor_from_sentence(answer_data_class, pair[1])
    
    return (input_tensor, target_tensor)

In [32]:
tensor_from_pair(["Hello how are you?", "I am fine"])

(tensor([[2, 2, 2, 2, 1]]), tensor([[2, 2, 2, 1]]))

In [33]:
def get_dataloader(batch_size):
    question_data_class, output_data_class, pairs = prepareData(question=questions, answer=answers)
    
    n = len(pairs)
    input_ids = np.zeros((n, MAX_SEQUENCE_LENGTH), dtype=np.int32)
    target_ids = np.zeros((n, MAX_SEQUENCE_LENGTH), dtype=np.int32)
    
    for idx, (ipt, tgt) in enumerate(pairs):
        inp_ids = convert_to_ids(question_data_class, ipt)
        tar_ids = convert_to_ids(answer_data_class, tgt)
        inp_ids.append(EOS_token)
        tar_ids.append(EOS_token)
        input_ids[idx, :len(inp_ids)] = inp_ids
        target_ids[idx, :len(tar_ids)] = tar_ids
        
    print(input_ids.shape)
    print(target_ids.shape)
        
    train_data = TensorDataset(
        torch.LongTensor(input_ids).to(device),
        torch.LongTensor(target_ids).to(device)
    )
    train_sampler = RandomSampler(train_data)
    
    train_data_loader = DataLoader(train_data, sampler=train_sampler, batch_size=batch_size)
    return question_data_class, answer_data_class, train_data_loader

In [34]:
q, a, t = get_dataloader(2)

Question and Answer 39 39
(500, 30)
(500, 30)


In [35]:
i_, t_ = next(iter(t))

In [36]:
i_

tensor([[ 2,  8,  2,  1,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
        [ 2, 26, 27,  2,  1,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0]])

In [37]:
t_

tensor([[ 2, 26,  2,  6,  7,  8,  9, 10,  2,  1,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
        [ 2,  2, 17,  2,  6,  7,  8,  9, 10,  2,  1,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0]])

In [38]:
def train_epoch(dataloader, encoder, decoder, encoder_optimizer, decoder_optimizer, criterion):
    total_loss = 0
    for data in dataloader:
        input_tensor, target_tensor = data
        
        encoder_optimizer.zero_grad()
        decoder_optimizer.zero_grad()
        
        encoder_outputs, encoder_hidden = encoder(input_tensor)
        decoder_outputs, _, _ = decoder(encoder_outputs, encoder_hidden, target_tensor)
        
        loss = criterion(
            decoder_outputs.view(-1, decoder_outputs.size(-1)),
            target_tensor.view(-1)
        )
        loss.backward()

        encoder_optimizer.step()
        decoder_optimizer.step()

        total_loss += loss.item()

    return total_loss / len(dataloader)

In [39]:
import time
import math

def asMinutes(s):
    m = math.floor(s / 60)
    s -= m * 60
    return '%dm %ds' % (m, s)

def timeSince(since, percent):
    now = time.time()
    s = now - since
    es = s / (percent)
    rs = es - s
    return '%s (- %s)' % (asMinutes(s), asMinutes(rs))

## training loop

In [40]:
def train(train_dataloader, encoder, decoder, n_epochs, learning_rate=0.001,
               print_every=100, plot_every=100):
    start = time.time()
    plot_losses = []
    print_loss_total = 0  # Reset every print_every
    plot_loss_total = 0  # Reset every plot_every

    encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)
    decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate)
    criterion = nn.NLLLoss()

    for epoch in range(1, n_epochs + 1):
        loss = train_epoch(train_dataloader, encoder, decoder, encoder_optimizer, decoder_optimizer, criterion)
        print_loss_total += loss
        plot_loss_total += loss

        if epoch % print_every == 0:
            print_loss_avg = print_loss_total / print_every
            print_loss_total = 0
            print('%s (%d %d%%) %.4f' % (timeSince(start, epoch / n_epochs),
                                        epoch, epoch / n_epochs * 100, print_loss_avg))

        if epoch % plot_every == 0:
            plot_loss_avg = plot_loss_total / plot_every
            plot_losses.append(plot_loss_avg)
            plot_loss_total = 0

## Plot

In [41]:
import matplotlib.pyplot as plt
plt.switch_backend('agg')
import matplotlib.ticker as ticker
import numpy as np

def showPlot(points):
    plt.figure()
    fig, ax = plt.subplots()
    # this locator puts ticks at regular intervals
    loc = ticker.MultipleLocator(base=0.2)
    ax.yaxis.set_major_locator(loc)
    plt.plot(points)

In [42]:
hidden_size = 128
batch_size = 4

question_data_class, answer_data_class, train_data_loader = get_dataloader(batch_size)

encoder = Encoder(question_data_class.n_words, hidden_size).to(device)
decoder = AttnDecoderRNN(hidden_size, answer_data_class.n_words).to(device)

Question and Answer 39 39
(500, 30)
(500, 30)


In [43]:
train(train_dataloader=train_data_loader,n_epochs=10, encoder=encoder, decoder=decoder, print_every=1, plot_every=1)

0m 7s (- 1m 9s) (1 10%) 0.4749
0m 14s (- 0m 56s) (2 20%) 0.1435
0m 20s (- 0m 47s) (3 30%) 0.0952
0m 26s (- 0m 39s) (4 40%) 0.0557
0m 32s (- 0m 32s) (5 50%) 0.0409
0m 39s (- 0m 26s) (6 60%) 0.0324
0m 45s (- 0m 19s) (7 70%) 0.0298
0m 51s (- 0m 12s) (8 80%) 0.0267
0m 58s (- 0m 6s) (9 90%) 0.0248
1m 4s (- 0m 0s) (10 100%) 0.0225


## Evaluate

In [44]:
encoder.eval()
decoder.eval()

AttnDecoderRNN(
  (embedding): Embedding(39, 128)
  (attention): BahdanauAttention(
    (Wa): Linear(in_features=128, out_features=128, bias=True)
    (Ua): Linear(in_features=128, out_features=128, bias=True)
    (Va): Linear(in_features=128, out_features=1, bias=True)
  )
  (gru): GRU(256, 128, batch_first=True)
  (out): Linear(in_features=128, out_features=39, bias=True)
  (dropout): Dropout(p=0.1, inplace=False)
)

In [45]:
input_question = "hello"
normalized_input_question = normalizeString(input_question)
question_ids_tensor = tensor_from_sentence(question_data_class,normalized_input_question)

question_ids_tensor

tensor([[4, 1]])

In [46]:
encoder_outputs, encoder_hidden = encoder(question_ids_tensor)

In [47]:
encoder_outputs.shape, encoder_hidden.shape

(torch.Size([1, 2, 128]), torch.Size([1, 1, 128]))

In [48]:
decoder_outputs, decoder_hidden, decoder_attention = decoder(encoder_outputs, encoder_hidden)

In [49]:
decoder_outputs.shape, decoder_hidden.shape, decoder_attention.shape

(torch.Size([1, 30, 39]), torch.Size([1, 1, 128]), torch.Size([1, 30, 2]))

In [50]:
_, topi = decoder_outputs.topk(1)
output_ids = topi.squeeze()

output_ids

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0])

In [51]:
decoded_words = []
for idx in output_ids:
    if idx.item() == EOS_token:
        decoded_words.append("EOS")
        break
    decoded_words.append(answer_data_class.index2word[idx.item()])

" ".join(decoded_words)

'<SOS> <SOS> <SOS> <SOS> <SOS> <SOS> <SOS> <SOS> <SOS> <SOS> <SOS> <SOS> <SOS> <SOS> <SOS> <SOS> <SOS> <SOS> <SOS> <SOS> <SOS> <SOS> <SOS> <SOS> <SOS> <SOS> <SOS> <SOS> <SOS> <SOS>'

In [52]:
def evaluate(encoder, decoder, sentence, question_class, answer_class):
    with torch.no_grad():
        input_tensor = tensor_from_sentence(question_class, sentence)

        encoder_outputs, encoder_hidden = encoder(input_tensor)
        decoder_outputs, decoder_hidden, decoder_attn = decoder(encoder_outputs, encoder_hidden)

        _, topi = decoder_outputs.topk(1)
        decoded_ids = topi.squeeze()

        decoded_words = []
        for idx in decoded_ids:
            if idx.item() == EOS_token:
                decoded_words.append('<EOS>')
                break
            decoded_words.append(answer_class.index2word[idx.item()])
    return decoded_words, decoder_attn

In [53]:
def chat(sentence):
    answer, attention = evaluate(encoder, decoder, sentence, question_data_class, answer_data_class)
    print(" ".join(answer))

In [56]:
chat("hi")

<SOS> <SOS> <SOS> <SOS> <SOS> <SOS> <SOS> <SOS> <SOS> <SOS> <SOS> <SOS> <SOS> <SOS> <SOS> <SOS> <SOS> <SOS> <SOS> <SOS> <SOS> <SOS> <SOS> <SOS> <SOS> <SOS> <SOS> <SOS> <SOS> <SOS>
