<a href="https://colab.research.google.com/github/bimarshak7/chat-bots/blob/main/General_Convo_chatbot.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
#snippet to mount drive and copy kaggle API keys
# palge you kaggle keys at MyDrive/Colab Notebooks/ as kaggle.json
from google.colab import drive
drive.mount('/content/drive')
! mkdir ~/.kaggle
!cp "/content/drive/MyDrive/Colab Notebooks/kaggle.json" ~/.kaggle/kaggle.json
! chmod 600 ~/.kaggle/kaggle.json

Mounted at /content/drive


In [2]:
!kaggle datasets download -d kreeshrajani/3k-conversations-dataset-for-chatbot

Downloading 3k-conversations-dataset-for-chatbot.zip to /content
  0% 0.00/67.1k [00:00<?, ?B/s]
100% 67.1k/67.1k [00:00<00:00, 68.8MB/s]


In [3]:
!unzip -q *.zip -d data/

In [4]:
import json
import pandas as pd
from collections import Counter
import torch
import random
from torch import nn
import torch.nn.functional as F
from torchtext.vocab import vocab
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
from tqdm import tqdm
import math
import string

In [5]:
DATA_PATH = "/content/data/Conversation.csv"

In [6]:
df = pd.read_csv(DATA_PATH)
df.sample(5)

Unnamed: 0.1,Unnamed: 0,question,answer
3314,3314,i don't even know why we need to fix it.,in case we have visitors.
2708,2708,did you go swimming?,i went to the beach every day.
2155,2155,"if more people donate money, pbs could offer n...",who wants to donate? public tv should be free.
3467,3467,no. they said there are some things you can't ...,so are they going to hold another election?
3330,3330,it sure is.,"in fact, it's chilly in the apartment, too."


In [7]:
df["question"] = df["question"].apply(lambda x: x.lower())
df["answer"] = df["answer"].apply(lambda x: x.lower())

In [8]:
df["question"] = df["question"].apply(lambda x: x.translate(str.maketrans('', '', string.punctuation)))
df["answer"] = df["answer"].apply(lambda x: x.translate(str.maketrans('', '', string.punctuation)))

In [9]:
counter = Counter()
for i in range(len(df)):
    row = df.iloc[i]
    counter.update(row["question"].split())
    counter.update(row["answer"].split())

In [10]:
vocab_en = vocab(counter, min_freq=5, specials=('<UNK>', '<BOS>', '<EOS>', '<PAD>'))

In [11]:
vocab_en.set_default_index(vocab_en['<UNK>'])

In [12]:
df["question"] = df["question"].apply(lambda x: "<BOS> "+x+" <EOS>")
df["answer"] = df["answer"].apply(lambda x: "<BOS> "+x+" <EOS>")

In [13]:
MAX_LEN_QN = df["question"].apply(lambda x: len(x.split())).max()
MAX_LEN_ANS = df["answer"].apply(lambda x: len(x.split())).max()
MAX_LEN_QN,MAX_LEN_ANS

(21, 21)

In [14]:
class MyDataset(Dataset):
    def __init__(self, df):
        self.X = df["question"]
        self.y = df["answer"]

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        qn = self.X.iloc[idx]
        ans = self.y.iloc[idx]

        qn_indices = [vocab_en[word] for word in qn.split()]

        ans_indices = [vocab_en[word] for word in ans.split()]

        return torch.tensor(qn_indices), len(qn_indices), torch.tensor(ans_indices)


In [15]:
ds = MyDataset(df)

In [16]:
train_ds,test_ds = torch.utils.data.random_split(ds, [0.8, 0.2])

In [17]:
BATCH_SIZE = 32

In [18]:
def my_collate(batch):
    # Extract sequences and targets
    qns = [item[0] for item in batch]
    src_lens = [item[1] for item in batch]
    ans = [item[2] for item in batch]
    # Pad sequences
    padded_qn = pad_sequence(qns, batch_first=True, padding_value=vocab_en["<PAD>"])
    padded_ans = pad_sequence(ans, batch_first=True, padding_value=vocab_en["<PAD>"])

    # Return padded sequences and targets
    return padded_qn, torch.tensor(src_lens), padded_ans

In [19]:
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE,collate_fn=my_collate)
test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE,collate_fn=my_collate)

In [74]:
class Encoder(nn.Module):
    def __init__(self, input_dim, emb_dim, enc_hid_dim, dec_hid_dim, dropout):
        super().__init__()

        self.embedding = nn.Embedding(input_dim, emb_dim)

        self.rnn = nn.GRU(emb_dim, enc_hid_dim,2, bidirectional = True)

        self.fc = nn.Linear(enc_hid_dim * 2, dec_hid_dim)

        self.dropout = nn.Dropout(dropout)

    def forward(self, src, src_len):

        #src = [src len, batch size]
        #src_len = [batch size]

        embedded = self.dropout(self.embedding(src))
        #embedded = [src len, batch size, emb dim]

        #need to explicitly put lengths on cpu!
        packed_embedded = nn.utils.rnn.pack_padded_sequence(embedded, src_len.to('cpu'),enforce_sorted=False)

        packed_outputs, hidden = self.rnn(packed_embedded)

        #packed_outputs is a packed sequence containing all hidden states
        #hidden is now from the final non-padded element in the batch

        outputs, _ = nn.utils.rnn.pad_packed_sequence(packed_outputs)

        #outputs is now a non-packed sequence, all hidden states obtained
        #  when the input is a pad token are all zeros

        #outputs = [src len, batch size, hid dim * num directions]
        #hidden = [n layers * num directions, batch size, hid dim]

        #hidden is stacked [forward_1, backward_1, forward_2, backward_2, ...]
        #outputs are always from the last layer

        #hidden [-2, :, : ] is the last of the forwards RNN
        #hidden [-1, :, : ] is the last of the backwards RNN

        #initial decoder hidden is final hidden state of the forwards and backwards
        #  encoder RNNs fed through a linear layer
        hidden = torch.tanh(self.fc(torch.cat((hidden[-2,:,:], hidden[-1,:,:]), dim = 1)))

        #outputs = [src len, batch size, enc hid dim * 2]
        #hidden = [batch size, dec hid dim]

        return outputs, hidden

In [21]:
class Attention(nn.Module):
    def __init__(self, enc_hid_dim, dec_hid_dim):
        super().__init__()

        self.attn = nn.Linear((enc_hid_dim * 2) + dec_hid_dim, dec_hid_dim)
        self.v = nn.Linear(dec_hid_dim, 1, bias = False)

    def forward(self, hidden, encoder_outputs, mask):

        #hidden = [batch size, dec hid dim]
        #encoder_outputs = [src len, batch size, enc hid dim * 2]

        batch_size = encoder_outputs.shape[1]
        src_len = encoder_outputs.shape[0]

        #repeat decoder hidden state src_len times
        hidden = hidden.unsqueeze(1).repeat(1, src_len, 1)

        encoder_outputs = encoder_outputs.permute(1, 0, 2)

        #hidden = [batch size, src len, dec hid dim]
        #encoder_outputs = [batch size, src len, enc hid dim * 2]

        energy = torch.tanh(self.attn(torch.cat((hidden, encoder_outputs), dim = 2)))

        #energy = [batch size, src len, dec hid dim]

        attention = self.v(energy).squeeze(2)

        #attention = [batch size, src len]

        attention = attention.masked_fill(mask == 0, -1e10)

        return F.softmax(attention, dim = 1)

In [75]:
class Decoder(nn.Module):
    def __init__(self, output_dim, emb_dim, enc_hid_dim, dec_hid_dim, dropout, attention):
        super().__init__()

        self.output_dim = output_dim
        self.attention = attention

        self.embedding = nn.Embedding(output_dim, emb_dim)

        self.rnn = nn.GRU((enc_hid_dim * 2) + emb_dim, dec_hid_dim)

        self.fc_out = nn.Linear((enc_hid_dim * 2) + dec_hid_dim + emb_dim, output_dim)

        self.dropout = nn.Dropout(dropout)

    def forward(self, input, hidden, encoder_outputs, mask):

        #input = [batch size]
        #hidden = [batch size, dec hid dim]
        #encoder_outputs = [src len, batch size, enc hid dim * 2]
        #mask = [batch size, src len]

        input = input.unsqueeze(0)

        #input = [1, batch size]

        embedded = self.dropout(self.embedding(input))

        #embedded = [1, batch size, emb dim]

        a = self.attention(hidden, encoder_outputs, mask)

        #a = [batch size, src len]

        a = a.unsqueeze(1)

        #a = [batch size, 1, src len]

        encoder_outputs = encoder_outputs.permute(1, 0, 2)

        #encoder_outputs = [batch size, src len, enc hid dim * 2]

        weighted = torch.bmm(a, encoder_outputs)

        #weighted = [batch size, 1, enc hid dim * 2]

        weighted = weighted.permute(1, 0, 2)

        #weighted = [1, batch size, enc hid dim * 2]

        rnn_input = torch.cat((embedded, weighted), dim = 2)

        #rnn_input = [1, batch size, (enc hid dim * 2) + emb dim]

        output, hidden = self.rnn(rnn_input, hidden.unsqueeze(0))

        #output = [seq len, batch size, dec hid dim * n directions]
        #hidden = [n layers * n directions, batch size, dec hid dim]

        #seq len, n layers and n directions will always be 1 in this decoder, therefore:
        #output = [1, batch size, dec hid dim]
        #hidden = [1, batch size, dec hid dim]
        #this also means that output == hidden
        assert (output == hidden).all()

        embedded = embedded.squeeze(0)
        output = output.squeeze(0)
        weighted = weighted.squeeze(0)

        prediction = self.fc_out(torch.cat((output, weighted, embedded), dim = 1))

        #prediction = [batch size, output dim]

        return prediction, hidden.squeeze(0), a.squeeze(1)

In [60]:
class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder, src_pad_idx, device):
        super().__init__()

        self.encoder = encoder
        self.decoder = decoder
        self.src_pad_idx = src_pad_idx
        self.device = device

    def create_mask(self, src):
        mask = (src != self.src_pad_idx).permute(1, 0)
        return mask

    def forward(self, src, src_len, trg, teacher_forcing_ratio = 0.5):

        #src = [src len, batch size]
        #src_len = [batch size]
        #trg = [trg len, batch size]
        #teacher_forcing_ratio is probability to use teacher forcing
        #e.g. if teacher_forcing_ratio is 0.75 we use teacher forcing 75% of the time
        src = src.permute(1,0)
        trg = trg.permute(1,0)

        # print("SRC SHAPE",src.shape)
        # print("TRG SHAPE",trg.shape)

        batch_size = src.shape[1]
        trg_len = trg.shape[0]
        trg_vocab_size = self.decoder.output_dim

        #tensor to store decoder outputs
        outputs = torch.zeros(trg_len, batch_size, trg_vocab_size).to(self.device)

        #encoder_outputs is all hidden states of the input sequence, back and forwards
        #hidden is the final forward and backward hidden states, passed through a linear layer
        encoder_outputs, hidden = self.encoder(src, src_len)
        #first input to the decoder is the <sos> tokens
        input = trg[0,:]

        mask = self.create_mask(src)

        #mask = [batch size, src len]

        for t in range(1, trg_len):

            #insert input token embedding, previous hidden state, all encoder hidden states
            #  and mask
            #receive output tensor (predictions) and new hidden state
            output, hidden, _ = self.decoder(input, hidden, encoder_outputs, mask)

            #place predictions in a tensor holding predictions for each token
            outputs[t] = output

            #decide if we are going to use teacher forcing or not
            teacher_force = random.random() < teacher_forcing_ratio

            #get the highest predicted token from our predictions
            top1 = output.argmax(1)

            #if teacher forcing, use actual next token as next input
            #if not, use predicted token
            input = trg[t] if teacher_force else top1

        return outputs

In [61]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [94]:
INPUT_DIM = len(vocab_en)
OUTPUT_DIM = len(vocab_en)
ENC_EMB_DIM = 256
DEC_EMB_DIM = 256
ENC_HID_DIM = 512
DEC_HID_DIM = 512
ENC_DROPOUT = 0.4
DEC_DROPOUT = 0.3
SRC_PAD_IDX = vocab_en["<PAD>"]
ENCODER_N_LAYERS = 2

In [95]:
attn = Attention(ENC_HID_DIM, DEC_HID_DIM)
enc = Encoder(INPUT_DIM, ENC_EMB_DIM, ENC_HID_DIM, DEC_HID_DIM, ENC_DROPOUT)
dec = Decoder(OUTPUT_DIM, DEC_EMB_DIM, ENC_HID_DIM, DEC_HID_DIM, DEC_DROPOUT, attn)
model = Seq2Seq(enc, dec, SRC_PAD_IDX, device).to(device)

In [96]:
criterion = nn.CrossEntropyLoss(ignore_index = vocab_en["<PAD>"])

In [97]:
# def init_weights(m):
#     for name, param in m.named_parameters():
#         if 'weight' in name:
#             nn.init.normal_(param.data, mean=0, std=0.01)
#         else:
#             nn.init.constant_(param.data, 0)

# model.apply(init_weights)

In [98]:
# model.load_state_dict(torch.load("/kaggle/input/seq2seqmental/seq2seq(mental).pt"))

In [99]:
optimizer = torch.optim.Adam(model.parameters(),lr=2e-4)

In [100]:
criterion = nn.CrossEntropyLoss(ignore_index = vocab_en["<PAD>"])

In [101]:
def train(model, iterator, optimizer, criterion, clip):

    model.train()

    epoch_loss = 0
    pbar = tqdm(iterator)
    for i, batch in enumerate(pbar):

        src = batch[0].to(device)
        src_lens = batch[1]
        trg = batch[2].to(device)
        optimizer.zero_grad()

        output = model(src, src_lens, trg)

        #trg = [trg len, batch size]
        #output = [trg len, batch size, output dim]

        output_dim = output.shape[-1]

        output = output[1:].view(-1, output_dim)
        trg = trg[:,1:].reshape(-1)

        #trg = [(trg len - 1) * batch size]
        #output = [(trg len - 1) * batch size, output dim]
#         print("OP shape ",output.shape)
#         print("TRG Shape ",trg.shape)
        loss = criterion(output, trg)

        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)

        optimizer.step()

        epoch_loss += loss.item()
#         if i%500==0:print("Batch:",epoch_loss)

    return epoch_loss / len(iterator)

In [102]:
def evaluate(model, iterator, criterion):

    model.eval()

    epoch_loss = 0

    with torch.no_grad():
        pbar = tqdm(iterator)
        for i, batch in enumerate(pbar):

            src = batch[0].to(device)
            src_lens = batch[1]
            trg = batch[2].to(device)
            output = model(src, src_lens, trg, 0) #turn off teacher forcing

            #trg = [trg len, batch size]
            #output = [trg len, batch size, output dim]

            output_dim = output.shape[-1]

            output = output[1:].view(-1, output_dim)
            trg = trg[:,1:].reshape(-1)

            #trg = [(trg len - 1) * batch size]
            #output = [(trg len - 1) * batch size, output dim]

            loss = criterion(output, trg)

            epoch_loss += loss.item()

    return epoch_loss / len(iterator)

In [103]:
model = model.to(device)

In [104]:
N_EPOCHS = 100
CLIP = 20.0

In [105]:
best_valid_loss = float('inf')

for epoch in range(N_EPOCHS):
    train_loss = train(model, train_loader, optimizer, criterion, CLIP)
    valid_loss = evaluate(model, test_loader, criterion)

    torch.save(model.state_dict(), "seq2seq.pt")

    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
    print(f'\tEPOCH:{epoch+1}/{N_EPOCHS}\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f} \t Val. Loss: {valid_loss:.3f} |  Val. PPL: {math.exp(valid_loss):7.3f}')

100%|██████████| 94/94 [00:05<00:00, 18.51it/s]
100%|██████████| 24/24 [00:00<00:00, 42.72it/s]


	EPOCH:1/100	Train Loss: 5.464 | Train PPL: 236.026 	 Val. Loss: 5.330 |  Val. PPL: 206.509


100%|██████████| 94/94 [00:05<00:00, 18.52it/s]
100%|██████████| 24/24 [00:00<00:00, 57.55it/s]


	EPOCH:2/100	Train Loss: 5.188 | Train PPL: 179.125 	 Val. Loss: 5.334 |  Val. PPL: 207.174


100%|██████████| 94/94 [00:04<00:00, 19.81it/s]
100%|██████████| 24/24 [00:00<00:00, 56.01it/s]


	EPOCH:3/100	Train Loss: 5.138 | Train PPL: 170.385 	 Val. Loss: 5.339 |  Val. PPL: 208.237


100%|██████████| 94/94 [00:05<00:00, 17.59it/s]
100%|██████████| 24/24 [00:00<00:00, 57.29it/s]


	EPOCH:4/100	Train Loss: 5.059 | Train PPL: 157.449 	 Val. Loss: 5.373 |  Val. PPL: 215.492


100%|██████████| 94/94 [00:04<00:00, 20.62it/s]
100%|██████████| 24/24 [00:00<00:00, 54.73it/s]


	EPOCH:5/100	Train Loss: 4.963 | Train PPL: 143.064 	 Val. Loss: 5.424 |  Val. PPL: 226.842


100%|██████████| 94/94 [00:04<00:00, 19.38it/s]
100%|██████████| 24/24 [00:00<00:00, 44.69it/s]


	EPOCH:6/100	Train Loss: 4.831 | Train PPL: 125.350 	 Val. Loss: 5.490 |  Val. PPL: 242.330


100%|██████████| 94/94 [00:05<00:00, 18.62it/s]
100%|██████████| 24/24 [00:00<00:00, 55.28it/s]


	EPOCH:7/100	Train Loss: 4.674 | Train PPL: 107.143 	 Val. Loss: 5.567 |  Val. PPL: 261.569


100%|██████████| 94/94 [00:04<00:00, 20.63it/s]
100%|██████████| 24/24 [00:00<00:00, 55.83it/s]


	EPOCH:8/100	Train Loss: 4.502 | Train PPL:  90.203 	 Val. Loss: 5.635 |  Val. PPL: 280.128


100%|██████████| 94/94 [00:05<00:00, 17.98it/s]
100%|██████████| 24/24 [00:00<00:00, 39.99it/s]


	EPOCH:9/100	Train Loss: 4.314 | Train PPL:  74.740 	 Val. Loss: 5.704 |  Val. PPL: 299.965


100%|██████████| 94/94 [00:04<00:00, 20.64it/s]
100%|██████████| 24/24 [00:00<00:00, 57.10it/s]


	EPOCH:10/100	Train Loss: 4.162 | Train PPL:  64.172 	 Val. Loss: 5.779 |  Val. PPL: 323.516


100%|██████████| 94/94 [00:04<00:00, 20.73it/s]
100%|██████████| 24/24 [00:00<00:00, 47.75it/s]


	EPOCH:11/100	Train Loss: 3.999 | Train PPL:  54.532 	 Val. Loss: 5.837 |  Val. PPL: 342.632


100%|██████████| 94/94 [00:05<00:00, 17.83it/s]
100%|██████████| 24/24 [00:00<00:00, 57.31it/s]


	EPOCH:12/100	Train Loss: 3.813 | Train PPL:  45.300 	 Val. Loss: 5.907 |  Val. PPL: 367.550


100%|██████████| 94/94 [00:04<00:00, 20.65it/s]
100%|██████████| 24/24 [00:00<00:00, 57.43it/s]


	EPOCH:13/100	Train Loss: 3.673 | Train PPL:  39.353 	 Val. Loss: 5.978 |  Val. PPL: 394.776


100%|██████████| 94/94 [00:04<00:00, 19.17it/s]
100%|██████████| 24/24 [00:00<00:00, 44.84it/s]


	EPOCH:14/100	Train Loss: 3.523 | Train PPL:  33.890 	 Val. Loss: 6.036 |  Val. PPL: 418.262


100%|██████████| 94/94 [00:04<00:00, 19.36it/s]
100%|██████████| 24/24 [00:00<00:00, 57.44it/s]


	EPOCH:15/100	Train Loss: 3.371 | Train PPL:  29.111 	 Val. Loss: 6.101 |  Val. PPL: 446.383


100%|██████████| 94/94 [00:04<00:00, 20.54it/s]
100%|██████████| 24/24 [00:00<00:00, 58.15it/s]


	EPOCH:16/100	Train Loss: 3.230 | Train PPL:  25.270 	 Val. Loss: 6.172 |  Val. PPL: 479.166


100%|██████████| 94/94 [00:05<00:00, 17.07it/s]
100%|██████████| 24/24 [00:00<00:00, 53.72it/s]


	EPOCH:17/100	Train Loss: 3.093 | Train PPL:  22.051 	 Val. Loss: 6.229 |  Val. PPL: 507.141


100%|██████████| 94/94 [00:04<00:00, 20.37it/s]
100%|██████████| 24/24 [00:00<00:00, 55.22it/s]


	EPOCH:18/100	Train Loss: 2.944 | Train PPL:  18.989 	 Val. Loss: 6.289 |  Val. PPL: 538.586


100%|██████████| 94/94 [00:04<00:00, 19.86it/s]
100%|██████████| 24/24 [00:00<00:00, 44.74it/s]


	EPOCH:19/100	Train Loss: 2.813 | Train PPL:  16.657 	 Val. Loss: 6.366 |  Val. PPL: 581.741


100%|██████████| 94/94 [00:05<00:00, 17.44it/s]
100%|██████████| 24/24 [00:00<00:00, 55.28it/s]


	EPOCH:20/100	Train Loss: 2.697 | Train PPL:  14.840 	 Val. Loss: 6.438 |  Val. PPL: 625.075


100%|██████████| 94/94 [00:04<00:00, 20.41it/s]
100%|██████████| 24/24 [00:00<00:00, 57.07it/s]


	EPOCH:21/100	Train Loss: 2.569 | Train PPL:  13.056 	 Val. Loss: 6.511 |  Val. PPL: 672.696


100%|██████████| 94/94 [00:05<00:00, 18.11it/s]
100%|██████████| 24/24 [00:00<00:00, 43.79it/s]


	EPOCH:22/100	Train Loss: 2.454 | Train PPL:  11.636 	 Val. Loss: 6.577 |  Val. PPL: 718.526


100%|██████████| 94/94 [00:04<00:00, 19.99it/s]
100%|██████████| 24/24 [00:00<00:00, 55.67it/s]


	EPOCH:23/100	Train Loss: 2.349 | Train PPL:  10.473 	 Val. Loss: 6.634 |  Val. PPL: 760.786


100%|██████████| 94/94 [00:04<00:00, 20.29it/s]
100%|██████████| 24/24 [00:00<00:00, 53.54it/s]


	EPOCH:24/100	Train Loss: 2.252 | Train PPL:   9.505 	 Val. Loss: 6.723 |  Val. PPL: 831.320


100%|██████████| 94/94 [00:05<00:00, 17.16it/s]
100%|██████████| 24/24 [00:00<00:00, 54.17it/s]


	EPOCH:25/100	Train Loss: 2.162 | Train PPL:   8.689 	 Val. Loss: 6.761 |  Val. PPL: 863.382


100%|██████████| 94/94 [00:04<00:00, 20.34it/s]
100%|██████████| 24/24 [00:00<00:00, 55.09it/s]


	EPOCH:26/100	Train Loss: 2.084 | Train PPL:   8.033 	 Val. Loss: 6.848 |  Val. PPL: 942.293


100%|██████████| 94/94 [00:04<00:00, 18.97it/s]
100%|██████████| 24/24 [00:00<00:00, 44.94it/s]


	EPOCH:27/100	Train Loss: 1.996 | Train PPL:   7.362 	 Val. Loss: 6.906 |  Val. PPL: 998.608


100%|██████████| 94/94 [00:04<00:00, 19.16it/s]
100%|██████████| 24/24 [00:00<00:00, 55.40it/s]


	EPOCH:28/100	Train Loss: 1.916 | Train PPL:   6.791 	 Val. Loss: 6.934 |  Val. PPL: 1026.315


100%|██████████| 94/94 [00:04<00:00, 20.52it/s]
100%|██████████| 24/24 [00:00<00:00, 56.44it/s]


	EPOCH:29/100	Train Loss: 1.838 | Train PPL:   6.283 	 Val. Loss: 7.000 |  Val. PPL: 1096.656


100%|██████████| 94/94 [00:05<00:00, 17.62it/s]
100%|██████████| 24/24 [00:00<00:00, 57.60it/s]


	EPOCH:30/100	Train Loss: 1.773 | Train PPL:   5.889 	 Val. Loss: 7.040 |  Val. PPL: 1141.607


100%|██████████| 94/94 [00:04<00:00, 20.55it/s]
100%|██████████| 24/24 [00:00<00:00, 58.20it/s]


	EPOCH:31/100	Train Loss: 1.716 | Train PPL:   5.562 	 Val. Loss: 7.087 |  Val. PPL: 1196.215


100%|██████████| 94/94 [00:04<00:00, 20.12it/s]
100%|██████████| 24/24 [00:00<00:00, 44.26it/s]


	EPOCH:32/100	Train Loss: 1.656 | Train PPL:   5.238 	 Val. Loss: 7.152 |  Val. PPL: 1276.579


100%|██████████| 94/94 [00:05<00:00, 16.79it/s]
100%|██████████| 24/24 [00:00<00:00, 42.19it/s]


	EPOCH:33/100	Train Loss: 1.596 | Train PPL:   4.934 	 Val. Loss: 7.197 |  Val. PPL: 1335.716


100%|██████████| 94/94 [00:04<00:00, 19.49it/s]
100%|██████████| 24/24 [00:00<00:00, 57.54it/s]


	EPOCH:34/100	Train Loss: 1.566 | Train PPL:   4.786 	 Val. Loss: 7.244 |  Val. PPL: 1400.028


100%|██████████| 94/94 [00:05<00:00, 17.83it/s]
100%|██████████| 24/24 [00:00<00:00, 42.26it/s]


	EPOCH:35/100	Train Loss: 1.519 | Train PPL:   4.566 	 Val. Loss: 7.248 |  Val. PPL: 1405.034


100%|██████████| 94/94 [00:04<00:00, 20.57it/s]
100%|██████████| 24/24 [00:00<00:00, 56.79it/s]


	EPOCH:36/100	Train Loss: 1.467 | Train PPL:   4.335 	 Val. Loss: 7.279 |  Val. PPL: 1449.020


100%|██████████| 94/94 [00:04<00:00, 20.45it/s]
100%|██████████| 24/24 [00:00<00:00, 42.76it/s]


	EPOCH:37/100	Train Loss: 1.410 | Train PPL:   4.098 	 Val. Loss: 7.345 |  Val. PPL: 1548.730


100%|██████████| 94/94 [00:05<00:00, 17.97it/s]
100%|██████████| 24/24 [00:00<00:00, 53.36it/s]


	EPOCH:38/100	Train Loss: 1.382 | Train PPL:   3.984 	 Val. Loss: 7.391 |  Val. PPL: 1620.719


100%|██████████| 94/94 [00:04<00:00, 20.60it/s]
100%|██████████| 24/24 [00:00<00:00, 57.16it/s]


	EPOCH:39/100	Train Loss: 1.358 | Train PPL:   3.890 	 Val. Loss: 7.412 |  Val. PPL: 1656.209


100%|██████████| 94/94 [00:05<00:00, 18.59it/s]
100%|██████████| 24/24 [00:00<00:00, 43.85it/s]


	EPOCH:40/100	Train Loss: 1.305 | Train PPL:   3.689 	 Val. Loss: 7.429 |  Val. PPL: 1684.252


100%|██████████| 94/94 [00:04<00:00, 19.60it/s]
100%|██████████| 24/24 [00:00<00:00, 56.94it/s]


	EPOCH:41/100	Train Loss: 1.276 | Train PPL:   3.582 	 Val. Loss: 7.443 |  Val. PPL: 1708.032


100%|██████████| 94/94 [00:04<00:00, 20.61it/s]
100%|██████████| 24/24 [00:00<00:00, 56.08it/s]


	EPOCH:42/100	Train Loss: 1.239 | Train PPL:   3.454 	 Val. Loss: 7.493 |  Val. PPL: 1796.221


100%|██████████| 94/94 [00:05<00:00, 17.46it/s]
100%|██████████| 24/24 [00:00<00:00, 55.85it/s]


	EPOCH:43/100	Train Loss: 1.209 | Train PPL:   3.352 	 Val. Loss: 7.569 |  Val. PPL: 1937.807


100%|██████████| 94/94 [00:04<00:00, 20.44it/s]
100%|██████████| 24/24 [00:00<00:00, 56.21it/s]


	EPOCH:44/100	Train Loss: 1.186 | Train PPL:   3.275 	 Val. Loss: 7.709 |  Val. PPL: 2227.258


100%|██████████| 94/94 [00:04<00:00, 19.79it/s]
100%|██████████| 24/24 [00:00<00:00, 45.39it/s]


	EPOCH:45/100	Train Loss: 1.150 | Train PPL:   3.159 	 Val. Loss: 7.659 |  Val. PPL: 2119.676


100%|██████████| 94/94 [00:05<00:00, 18.33it/s]
100%|██████████| 24/24 [00:00<00:00, 55.75it/s]


	EPOCH:46/100	Train Loss: 1.147 | Train PPL:   3.147 	 Val. Loss: 7.604 |  Val. PPL: 2007.208


100%|██████████| 94/94 [00:04<00:00, 20.63it/s]
100%|██████████| 24/24 [00:00<00:00, 56.10it/s]


	EPOCH:47/100	Train Loss: 1.108 | Train PPL:   3.030 	 Val. Loss: 7.617 |  Val. PPL: 2032.627


100%|██████████| 94/94 [00:05<00:00, 18.23it/s]
100%|██████████| 24/24 [00:00<00:00, 44.29it/s]


	EPOCH:48/100	Train Loss: 1.073 | Train PPL:   2.926 	 Val. Loss: 7.697 |  Val. PPL: 2202.496


100%|██████████| 94/94 [00:04<00:00, 20.11it/s]
100%|██████████| 24/24 [00:00<00:00, 52.76it/s]


	EPOCH:49/100	Train Loss: 1.057 | Train PPL:   2.877 	 Val. Loss: 7.746 |  Val. PPL: 2313.405


100%|██████████| 94/94 [00:04<00:00, 20.61it/s]
100%|██████████| 24/24 [00:00<00:00, 55.26it/s]


	EPOCH:50/100	Train Loss: 1.038 | Train PPL:   2.823 	 Val. Loss: 7.766 |  Val. PPL: 2359.936


100%|██████████| 94/94 [00:05<00:00, 17.46it/s]
100%|██████████| 24/24 [00:00<00:00, 56.60it/s]


	EPOCH:51/100	Train Loss: 1.015 | Train PPL:   2.760 	 Val. Loss: 7.791 |  Val. PPL: 2418.780


100%|██████████| 94/94 [00:04<00:00, 20.61it/s]
100%|██████████| 24/24 [00:00<00:00, 56.72it/s]


	EPOCH:52/100	Train Loss: 1.012 | Train PPL:   2.750 	 Val. Loss: 7.797 |  Val. PPL: 2433.861


100%|██████████| 94/94 [00:04<00:00, 19.06it/s]
100%|██████████| 24/24 [00:00<00:00, 43.55it/s]


	EPOCH:53/100	Train Loss: 1.012 | Train PPL:   2.752 	 Val. Loss: 7.856 |  Val. PPL: 2582.190


100%|██████████| 94/94 [00:05<00:00, 18.74it/s]
100%|██████████| 24/24 [00:00<00:00, 56.12it/s]


	EPOCH:54/100	Train Loss: 1.016 | Train PPL:   2.762 	 Val. Loss: 7.850 |  Val. PPL: 2564.682


100%|██████████| 94/94 [00:04<00:00, 20.15it/s]
100%|██████████| 24/24 [00:00<00:00, 54.83it/s]


	EPOCH:55/100	Train Loss: 0.965 | Train PPL:   2.625 	 Val. Loss: 7.862 |  Val. PPL: 2596.250


100%|██████████| 94/94 [00:05<00:00, 17.05it/s]
100%|██████████| 24/24 [00:00<00:00, 56.79it/s]


	EPOCH:56/100	Train Loss: 0.933 | Train PPL:   2.541 	 Val. Loss: 7.866 |  Val. PPL: 2606.534


100%|██████████| 94/94 [00:04<00:00, 20.13it/s]
100%|██████████| 24/24 [00:00<00:00, 55.94it/s]


	EPOCH:57/100	Train Loss: 0.897 | Train PPL:   2.451 	 Val. Loss: 7.912 |  Val. PPL: 2729.629


100%|██████████| 94/94 [00:04<00:00, 19.61it/s]
100%|██████████| 24/24 [00:00<00:00, 41.86it/s]


	EPOCH:58/100	Train Loss: 0.881 | Train PPL:   2.412 	 Val. Loss: 7.955 |  Val. PPL: 2851.212


100%|██████████| 94/94 [00:05<00:00, 18.15it/s]
100%|██████████| 24/24 [00:00<00:00, 56.64it/s]


	EPOCH:59/100	Train Loss: 0.862 | Train PPL:   2.367 	 Val. Loss: 7.954 |  Val. PPL: 2846.464


100%|██████████| 94/94 [00:04<00:00, 20.30it/s]
100%|██████████| 24/24 [00:00<00:00, 54.46it/s]


	EPOCH:60/100	Train Loss: 0.836 | Train PPL:   2.306 	 Val. Loss: 7.991 |  Val. PPL: 2953.949


100%|██████████| 94/94 [00:05<00:00, 18.26it/s]
100%|██████████| 24/24 [00:00<00:00, 42.41it/s]


	EPOCH:61/100	Train Loss: 0.827 | Train PPL:   2.287 	 Val. Loss: 8.020 |  Val. PPL: 3042.121


100%|██████████| 94/94 [00:04<00:00, 19.68it/s]
100%|██████████| 24/24 [00:00<00:00, 56.51it/s]


	EPOCH:62/100	Train Loss: 0.813 | Train PPL:   2.255 	 Val. Loss: 8.013 |  Val. PPL: 3019.919


100%|██████████| 94/94 [00:04<00:00, 20.17it/s]
100%|██████████| 24/24 [00:00<00:00, 55.98it/s]


	EPOCH:63/100	Train Loss: 0.820 | Train PPL:   2.271 	 Val. Loss: 8.043 |  Val. PPL: 3112.279


100%|██████████| 94/94 [00:05<00:00, 17.43it/s]
100%|██████████| 24/24 [00:00<00:00, 56.96it/s]


	EPOCH:64/100	Train Loss: 0.815 | Train PPL:   2.259 	 Val. Loss: 8.070 |  Val. PPL: 3196.401


100%|██████████| 94/94 [00:04<00:00, 20.51it/s]
100%|██████████| 24/24 [00:00<00:00, 56.80it/s]


	EPOCH:65/100	Train Loss: 0.804 | Train PPL:   2.234 	 Val. Loss: 8.087 |  Val. PPL: 3251.994


100%|██████████| 94/94 [00:04<00:00, 19.01it/s]
100%|██████████| 24/24 [00:00<00:00, 42.19it/s]


	EPOCH:66/100	Train Loss: 0.790 | Train PPL:   2.204 	 Val. Loss: 8.133 |  Val. PPL: 3404.372


100%|██████████| 94/94 [00:04<00:00, 19.03it/s]
100%|██████████| 24/24 [00:00<00:00, 55.63it/s]


	EPOCH:67/100	Train Loss: 0.786 | Train PPL:   2.196 	 Val. Loss: 8.108 |  Val. PPL: 3319.783


100%|██████████| 94/94 [00:04<00:00, 20.39it/s]
100%|██████████| 24/24 [00:00<00:00, 57.05it/s]


	EPOCH:68/100	Train Loss: 0.762 | Train PPL:   2.142 	 Val. Loss: 8.122 |  Val. PPL: 3368.564


100%|██████████| 94/94 [00:05<00:00, 17.47it/s]
100%|██████████| 24/24 [00:00<00:00, 57.10it/s]


	EPOCH:69/100	Train Loss: 0.747 | Train PPL:   2.112 	 Val. Loss: 8.154 |  Val. PPL: 3478.690


100%|██████████| 94/94 [00:04<00:00, 20.21it/s]
100%|██████████| 24/24 [00:00<00:00, 56.68it/s]


	EPOCH:70/100	Train Loss: 0.716 | Train PPL:   2.046 	 Val. Loss: 8.199 |  Val. PPL: 3638.118


100%|██████████| 94/94 [00:04<00:00, 19.58it/s]
100%|██████████| 24/24 [00:00<00:00, 41.90it/s]


	EPOCH:71/100	Train Loss: 0.733 | Train PPL:   2.081 	 Val. Loss: 8.255 |  Val. PPL: 3846.855


100%|██████████| 94/94 [00:05<00:00, 17.86it/s]
100%|██████████| 24/24 [00:00<00:00, 54.51it/s]


	EPOCH:72/100	Train Loss: 0.719 | Train PPL:   2.053 	 Val. Loss: 8.247 |  Val. PPL: 3817.441


100%|██████████| 94/94 [00:04<00:00, 20.09it/s]
100%|██████████| 24/24 [00:00<00:00, 54.96it/s]


	EPOCH:73/100	Train Loss: 0.705 | Train PPL:   2.024 	 Val. Loss: 8.235 |  Val. PPL: 3769.449


100%|██████████| 94/94 [00:05<00:00, 17.64it/s]
100%|██████████| 24/24 [00:00<00:00, 40.60it/s]


	EPOCH:74/100	Train Loss: 0.678 | Train PPL:   1.970 	 Val. Loss: 8.281 |  Val. PPL: 3947.750


100%|██████████| 94/94 [00:04<00:00, 20.39it/s]
100%|██████████| 24/24 [00:00<00:00, 56.80it/s]


	EPOCH:75/100	Train Loss: 0.683 | Train PPL:   1.980 	 Val. Loss: 8.300 |  Val. PPL: 4023.876


100%|██████████| 94/94 [00:04<00:00, 20.51it/s]
100%|██████████| 24/24 [00:00<00:00, 48.98it/s]


	EPOCH:76/100	Train Loss: 0.658 | Train PPL:   1.930 	 Val. Loss: 8.381 |  Val. PPL: 4365.382


100%|██████████| 94/94 [00:06<00:00, 14.85it/s]
100%|██████████| 24/24 [00:00<00:00, 43.23it/s]


	EPOCH:77/100	Train Loss: 0.656 | Train PPL:   1.927 	 Val. Loss: 8.396 |  Val. PPL: 4428.124


100%|██████████| 94/94 [00:04<00:00, 19.68it/s]
100%|██████████| 24/24 [00:00<00:00, 52.59it/s]


	EPOCH:78/100	Train Loss: 0.648 | Train PPL:   1.912 	 Val. Loss: 8.392 |  Val. PPL: 4411.105


100%|██████████| 94/94 [00:04<00:00, 18.82it/s]
100%|██████████| 24/24 [00:00<00:00, 42.67it/s]


	EPOCH:79/100	Train Loss: 0.651 | Train PPL:   1.918 	 Val. Loss: 8.408 |  Val. PPL: 4484.914


100%|██████████| 94/94 [00:04<00:00, 19.26it/s]
100%|██████████| 24/24 [00:00<00:00, 56.85it/s]


	EPOCH:80/100	Train Loss: 0.652 | Train PPL:   1.919 	 Val. Loss: 8.491 |  Val. PPL: 4870.522


100%|██████████| 94/94 [00:04<00:00, 20.43it/s]
100%|██████████| 24/24 [00:00<00:00, 56.62it/s]


	EPOCH:81/100	Train Loss: 0.636 | Train PPL:   1.889 	 Val. Loss: 8.466 |  Val. PPL: 4750.907


100%|██████████| 94/94 [00:05<00:00, 17.22it/s]
100%|██████████| 24/24 [00:00<00:00, 53.15it/s]


	EPOCH:82/100	Train Loss: 0.617 | Train PPL:   1.854 	 Val. Loss: 8.447 |  Val. PPL: 4662.227


100%|██████████| 94/94 [00:04<00:00, 20.54it/s]
100%|██████████| 24/24 [00:00<00:00, 56.01it/s]


	EPOCH:83/100	Train Loss: 0.620 | Train PPL:   1.858 	 Val. Loss: 8.473 |  Val. PPL: 4785.382


100%|██████████| 94/94 [00:04<00:00, 19.58it/s]
100%|██████████| 24/24 [00:00<00:00, 43.14it/s]


	EPOCH:84/100	Train Loss: 0.616 | Train PPL:   1.852 	 Val. Loss: 8.479 |  Val. PPL: 4811.132


100%|██████████| 94/94 [00:05<00:00, 18.45it/s]
100%|██████████| 24/24 [00:00<00:00, 55.59it/s]


	EPOCH:85/100	Train Loss: 0.616 | Train PPL:   1.851 	 Val. Loss: 8.484 |  Val. PPL: 4835.712


100%|██████████| 94/94 [00:04<00:00, 20.31it/s]
100%|██████████| 24/24 [00:00<00:00, 55.00it/s]


	EPOCH:86/100	Train Loss: 0.626 | Train PPL:   1.871 	 Val. Loss: 8.518 |  Val. PPL: 5004.293


100%|██████████| 94/94 [00:05<00:00, 17.94it/s]
100%|██████████| 24/24 [00:00<00:00, 43.44it/s]


	EPOCH:87/100	Train Loss: 0.614 | Train PPL:   1.847 	 Val. Loss: 8.512 |  Val. PPL: 4973.581


100%|██████████| 94/94 [00:04<00:00, 20.34it/s]
100%|██████████| 24/24 [00:00<00:00, 55.25it/s]


	EPOCH:88/100	Train Loss: 0.613 | Train PPL:   1.846 	 Val. Loss: 8.574 |  Val. PPL: 5294.735


100%|██████████| 94/94 [00:04<00:00, 20.37it/s]
100%|██████████| 24/24 [00:00<00:00, 50.96it/s]


	EPOCH:89/100	Train Loss: 0.597 | Train PPL:   1.816 	 Val. Loss: 8.609 |  Val. PPL: 5480.186


100%|██████████| 94/94 [00:05<00:00, 17.73it/s]
100%|██████████| 24/24 [00:00<00:00, 55.09it/s]


	EPOCH:90/100	Train Loss: 0.603 | Train PPL:   1.828 	 Val. Loss: 8.639 |  Val. PPL: 5648.134


100%|██████████| 94/94 [00:04<00:00, 20.28it/s]
100%|██████████| 24/24 [00:00<00:00, 56.51it/s]


	EPOCH:91/100	Train Loss: 0.602 | Train PPL:   1.827 	 Val. Loss: 8.641 |  Val. PPL: 5659.072


100%|██████████| 94/94 [00:05<00:00, 18.72it/s]
100%|██████████| 24/24 [00:00<00:00, 44.81it/s]


	EPOCH:92/100	Train Loss: 0.608 | Train PPL:   1.836 	 Val. Loss: 8.713 |  Val. PPL: 6079.508


100%|██████████| 94/94 [00:04<00:00, 18.84it/s]
100%|██████████| 24/24 [00:00<00:00, 55.06it/s]


	EPOCH:93/100	Train Loss: 0.585 | Train PPL:   1.796 	 Val. Loss: 8.627 |  Val. PPL: 5579.826


100%|██████████| 94/94 [00:04<00:00, 20.19it/s]
100%|██████████| 24/24 [00:00<00:00, 55.94it/s]


	EPOCH:94/100	Train Loss: 0.593 | Train PPL:   1.809 	 Val. Loss: 8.675 |  Val. PPL: 5852.163


100%|██████████| 94/94 [00:05<00:00, 16.82it/s]
100%|██████████| 24/24 [00:00<00:00, 54.80it/s]


	EPOCH:95/100	Train Loss: 0.596 | Train PPL:   1.815 	 Val. Loss: 8.636 |  Val. PPL: 5631.151


100%|██████████| 94/94 [00:04<00:00, 20.28it/s]
100%|██████████| 24/24 [00:00<00:00, 55.53it/s]


	EPOCH:96/100	Train Loss: 0.572 | Train PPL:   1.772 	 Val. Loss: 8.654 |  Val. PPL: 5731.100


100%|██████████| 94/94 [00:04<00:00, 19.13it/s]
100%|██████████| 24/24 [00:00<00:00, 42.32it/s]


	EPOCH:97/100	Train Loss: 0.571 | Train PPL:   1.769 	 Val. Loss: 8.666 |  Val. PPL: 5800.517


100%|██████████| 94/94 [00:05<00:00, 18.33it/s]
100%|██████████| 24/24 [00:00<00:00, 54.67it/s]


	EPOCH:98/100	Train Loss: 0.545 | Train PPL:   1.725 	 Val. Loss: 8.679 |  Val. PPL: 5879.798


100%|██████████| 94/94 [00:04<00:00, 20.44it/s]
100%|██████████| 24/24 [00:00<00:00, 52.44it/s]


	EPOCH:99/100	Train Loss: 0.545 | Train PPL:   1.725 	 Val. Loss: 8.716 |  Val. PPL: 6100.249


100%|██████████| 94/94 [00:05<00:00, 17.09it/s]
100%|██████████| 24/24 [00:00<00:00, 42.61it/s]


	EPOCH:100/100	Train Loss: 0.523 | Train PPL:   1.687 	 Val. Loss: 8.762 |  Val. PPL: 6383.985


In [106]:
# torch.save(model.state_dict(), "seq2seq.pt")

In [107]:
# model.load_state_dict(torch.load("/kaggle/input/mentalhealth-chatbot/seq2seq.pt",map_location=torch.device('cpu')))

In [108]:
model = model.eval()

In [109]:
def tokenize(qn):
    qn = qn.lower()
    qn = qn.translate(str.maketrans('', '', string.punctuation))
    indices = [vocab_en[word] for word in qn.split()]
    indices.insert(0,vocab_en["<BOS>"])
    indices.append(vocab_en["<EOS>"])
    ip_tensor = torch.tensor(indices).unsqueeze(0).permute(1,0)

    return ip_tensor

In [110]:
def make_pred(qn,trg_len=12):
    src = tokenize(qn).to(device)
    encoder_outputs, hidden = model.encoder(src, torch.tensor([len(src)]))
    trg = torch.tensor([[vocab_en["<BOS>"]]])

    with torch.no_grad():
        trg_ip = trg[0,:].to(device)
        trg_vocab_size = model.decoder.output_dim
        outputs = torch.zeros(trg_len, 1, trg_vocab_size).to(device)
        mask = model.create_mask(src).to(device)

        for t in range(1, trg_len):
            #insert input token embedding, previous hidden state and all encoder hidden states
            #receive output tensor (predictions) and new hidden state
            output, hidden, _ = model.decoder(trg_ip, hidden, encoder_outputs, mask)

            #place predictions in a tensor holding predictions for each token
            outputs[t] = output

            #get the highest predicted token from our predictions
            trg_ip = output.argmax(1)

        outputs = outputs.squeeze(1)
        return " ".join([vocab_en.get_itos()[i] for i in outputs.argmax(1) if i not in [0,1,2]])

In [113]:
while True:
    you = input("YOU > ")
    if you=="q" or you=="quit":
        break

    bot = make_pred(you)
    print(f"BOT > {bot}")
print(f"BOT > Bye! Have a good day dear.")

YOU > hello dear
BOT > the got got every did did lot lot lot
YOU > lot lot what
BOT > i course course course
YOU > no course
BOT > i
YOU > i 
BOT > but do the the
YOU > do what
BOT > dont i her by
YOU > who is her
BOT > a a did did a a with so
YOU > did what
BOT > then then when when rude it
YOU > she rude?
BOT > you i i you
YOU > me rude?
BOT > you i i i i go
YOU > no i am not
BOT > are to to your my you but you you
YOU > yes yes yes
BOT > last date i the dirty shop me me
YOU > when did it happen
BOT > whats and
YOU > and what
BOT > dont do do do plus want want
YOU > bye
BOT > good the move it job before before
YOU > q
BOT > Bye! Have a good day dear.
