In [16]:
import re
import unicodedata
import itertools
import torch
import numpy as np
import torch.nn as nn

# Default word tokens
PAD_token = 0  # Used for padding short sentences
SOS_token = 1  # Start-of-sentence token
EOS_token = 2  # End-of-sentence token

class Voc:
    def __init__(self, name):
        self.name = name
        self.trimmed = False
        self.word2index = {}
        self.word2count = {}
        self.index2word = {PAD_token: "PAD", SOS_token: "SOS", EOS_token: "EOS"}
        self.num_words = 3  # Count default tokens

    def addSentence(self, sentence):
        for word in sentence.split(' '):
            self.addWord(word)

    def addWord(self, word):
        if word not in self.word2index:
            self.word2index[word] = self.num_words
            self.word2count[word] = 1
            self.index2word[self.num_words] = word
            self.num_words += 1
        else:
            self.word2count[word] += 1

    def trim(self, min_count):
        if self.trimmed:
            return
        self.trimmed = True

        keep_words = [k for k, v in self.word2count.items() if v >= min_count]

        print('keep_words {} / {} = {:.4f}'.format(
            len(keep_words), len(self.word2index), len(keep_words) / len(self.word2index)
        ))

        self.word2index = {}
        self.word2count = {}
        self.index2word = {PAD_token: "PAD", SOS_token: "SOS", EOS_token: "EOS"}
        self.num_words = 3

        for word in keep_words:
            self.addWord(word)

def unicodeToAscii(s):
    return ''.join(
        c for c in unicodedata.normalize('NFD', s)
        if unicodedata.category(c) != 'Mn'
    )

def normalizeString(s):
    s = unicodeToAscii(s.lower().strip())
    s = re.sub(r"([.!?])", r" \1", s)
    s = re.sub(r"[^a-zA-Z.!?]+", r" ", s)
    s = re.sub(r"\s+", r" ", s).strip()
    return s

def indexesFromSentence(voc, sentence):
    return [voc.word2index[word] for word in sentence.split(' ')] + [EOS_token]

def zeroPadding(l, fillvalue=PAD_token):
    return list(itertools.zip_longest(*l, fillvalue=fillvalue))

def binaryMatrix(l, value=PAD_token):
    m = []
    for seq in l:
        m.append([0 if token == PAD_token else 1 for token in seq])
    return m

def inputVar(l, voc):
    indexes_batch = [indexesFromSentence(voc, sentence) for sentence in l]
    lengths = torch.tensor([len(indexes) for indexes in indexes_batch])
    padList = zeroPadding(indexes_batch)
    padVar = torch.LongTensor(padList)
    return padVar, lengths

def outputVar(l, voc):
    indexes_batch = [indexesFromSentence(voc, sentence) for sentence in l]
    max_target_len = max([len(indexes) for indexes in indexes_batch])
    padList = zeroPadding(indexes_batch)
    mask = binaryMatrix(padList)
    mask = torch.BoolTensor(mask)
    padVar = torch.LongTensor(padList)
    return padVar, mask, max_target_len

def batch2TrainData(voc, pair_batch):
    pair_batch.sort(key=lambda x: len(x[0].split(" ")), reverse=True)
    input_batch, output_batch = zip(*pair_batch)
    inp, lengths = inputVar(input_batch, voc)
    output, mask, max_target_len = outputVar(output_batch, voc)
    return inp, lengths, output, mask, max_target_len

# GloVe integration
def load_glove_embeddings(voc, glove_path, embedding_dim=100, freeze=False):
    print("Loading GloVe embeddings...")
    glove = {}
    with open(glove_path, encoding="utf-8") as f:
        for line in f:
            parts = line.strip().split()
            word = parts[0]
            vector = np.array(parts[1:], dtype=np.float32)
            glove[word] = vector

    embedding_matrix = np.random.normal(0, 1, (voc.num_words, embedding_dim)).astype(np.float32)
    found = 0
    for word, idx in voc.word2index.items():
        if word in glove:
            embedding_matrix[idx] = glove[word]
            found += 1

    print(f"Found {found}/{voc.num_words} words in GloVe.")
    tensor = torch.tensor(embedding_matrix)
    return nn.Embedding.from_pretrained(tensor, freeze=freeze)


In [18]:
from sklearn.model_selection import train_test_split
import os
import random
import torch
from torch.utils.data import Dataset, DataLoader
#from utils import Voc, normalizeString, batch2TrainData, load_glove_embeddings

MAX_LENGTH = 15

class ChatDataset(Dataset):
    def __init__(self, pairs, voc):
        self.pairs = pairs
        self.voc = voc

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        return self.pairs[idx]

def readVocs(datafile, corpus_name):
    print("Reading lines...")
    lines = open(datafile, encoding='utf-8').read().strip().split('\n')
    pairs = []
    for l in lines:
        parts = l.split('\t')
        if len(parts) == 2:
            pairs.append([normalizeString(parts[0]), normalizeString(parts[1])])

    voc = Voc(corpus_name)
    return voc, pairs

def filterPair(p):
    return len(p[0].split(' ')) < MAX_LENGTH and len(p[1].split(' ')) < MAX_LENGTH

def filterPairs(pairs):
    return [pair for pair in pairs if filterPair(pair)]

def loadPrepareData(corpus, corpus_name, datafile, save_dir):
    print("Start preparing training data ...")
    voc, pairs = readVocs(datafile, corpus_name)
    print("Read {!s} sentence pairs".format(len(pairs)))
    pairs = filterPairs(pairs)
    print("Trimmed to {!s} sentence pairs".format(len(pairs)))
    print("Counting words...")
    for pair in pairs:
        voc.addSentence(pair[0])
        voc.addSentence(pair[1])
    print("Counted words:", voc.num_words)
    return voc, pairs

def trimRareWords(voc, pairs, MIN_COUNT):
    voc.trim(MIN_COUNT)
    keep_pairs = []
    for pair in pairs:
        input_sentence, output_sentence = pair
        keep_input = all(word in voc.word2index for word in input_sentence.split(' '))
        keep_output = all(word in voc.word2index for word in output_sentence.split(' '))
        if keep_input and keep_output:
            keep_pairs.append(pair)

    print("Trimmed from {} pairs to {}, {:.4f} of total".format(len(pairs), len(keep_pairs), len(keep_pairs) / len(pairs)))
    return keep_pairs

def collate_fn(batch, voc):
    return batch2TrainData(voc, batch)

def split_dataset(pairs, test_size=0.1, random_state=42):
    train_pairs, val_pairs = train_test_split(pairs, test_size=test_size, random_state=random_state)
    print(f"Split {len(pairs)} pairs into {len(train_pairs)} train and {len(val_pairs)} validation")
    return train_pairs, val_pairs

def get_dataloader(pairs, voc, batch_size=64, shuffle=True):
    dataset = ChatDataset(pairs, voc)
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=lambda x: collate_fn(x, voc))

In [44]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Encoder with LSTM
class LSTMEncoder(nn.Module):
    def __init__(self, hidden_size, embedding, n_layers=1, dropout=0):
        super(LSTMEncoder, self).__init__()
        self.n_layers = n_layers
        self.hidden_size = hidden_size
        self.embedding = embedding
        self.lstm = nn.LSTM(embedding.embedding_dim, hidden_size, n_layers,
                            dropout=(0 if n_layers == 1 else dropout),
                            bidirectional=True)

    def forward(self, input_seq, input_lengths, hidden=None):
        embedded = self.embedding(input_seq)
        self.lstm.flatten_parameters()  # Fix for RNN warning
        packed = nn.utils.rnn.pack_padded_sequence(embedded, input_lengths, enforce_sorted=False)
        outputs, hidden = self.lstm(packed, hidden)
        outputs, _ = nn.utils.rnn.pad_packed_sequence(outputs)
        outputs = outputs[:, :, :self.hidden_size] + outputs[:, :, self.hidden_size:]
        return outputs, hidden

# Attention mechanism
class Attn(nn.Module):
    def __init__(self, method, hidden_size):
        super(Attn, self).__init__()
        self.method = method
        self.hidden_size = hidden_size

        if self.method not in ['dot', 'general', 'concat']:
            raise ValueError(f"{self.method} is not a valid attention method.")

        if self.method == 'general':
            self.attn = nn.Linear(hidden_size, hidden_size)
        elif self.method == 'concat':
            self.attn = nn.Linear(hidden_size * 2, hidden_size)
            self.v = nn.Parameter(torch.FloatTensor(hidden_size))

    def dot_score(self, hidden, encoder_output):
        return torch.sum(hidden * encoder_output, dim=2)

    def general_score(self, hidden, encoder_output):
        energy = self.attn(encoder_output)
        return torch.sum(hidden * energy, dim=2)

    def concat_score(self, hidden, encoder_output):
        energy = self.attn(torch.cat((hidden.expand(encoder_output.size(0), -1, -1),
                                      encoder_output), 2)).tanh()
        return torch.sum(self.v * energy, dim=2)

    def forward(self, hidden, encoder_outputs):
        if self.method == 'general':
            attn_energies = self.general_score(hidden, encoder_outputs)
        elif self.method == 'concat':
            attn_energies = self.concat_score(hidden, encoder_outputs)
        else:  # dot
            attn_energies = self.dot_score(hidden, encoder_outputs)

        attn_energies = attn_energies.t()
        return F.softmax(attn_energies, dim=1).unsqueeze(1)

# Decoder with Attention and LSTM
class LSTMAttnDecoder(nn.Module):
    def __init__(self, attn_model, embedding, hidden_size, output_size, n_layers=1, dropout=0.1):
        super(LSTMAttnDecoder, self).__init__()
        self.attn_model = attn_model
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.n_layers = n_layers
        self.embedding = embedding

        self.embedding_dropout = nn.Dropout(dropout)
        self.lstm = nn.LSTM(embedding.embedding_dim, hidden_size, n_layers,
                            dropout=(0 if n_layers == 1 else dropout))
        self.concat = nn.Linear(hidden_size * 2, hidden_size)
        self.out = nn.Linear(hidden_size, output_size)
        self.attn = Attn(attn_model, hidden_size)

    def forward(self, input_step, last_hidden, encoder_outputs):
        embedded = self.embedding(input_step)
        embedded = self.embedding_dropout(embedded)
        self.lstm.flatten_parameters()  # Fix for RNN warning
        rnn_output, hidden = self.lstm(embedded, last_hidden)
        attn_weights = self.attn(rnn_output, encoder_outputs)
        context = attn_weights.bmm(encoder_outputs.transpose(0, 1))
        rnn_output = rnn_output.squeeze(0)
        context = context.squeeze(1)
        concat_input = torch.cat((rnn_output, context), 1)
        concat_output = torch.tanh(self.concat(concat_input))
        output = self.out(concat_output)
        output = F.softmax(output, dim=1)
        return output, hidden


In [46]:
## Initializes wandb api key = "delete"

def init_wandb(project_name="seq2seqLSTM", config=None):
    if config is None:
        config = {
            "model_name": "deep_lstm_seq2seq",
            "attn_model": "dot",
            "embedding": "glove.6B.300d",
            "embedding_dim": 300,  # match paper
            "freeze_embeddings": False,
            "hidden_size": 1000,  # deep hidden state
            "encoder_n_layers": 4,
            "decoder_n_layers": 4,
            "dropout": 0.1,
            "batch_size": 128,  # match paper
            "learning_rate": 0.0001,  # fixed LR
            "decoder_learning_ratio": 1.0,  # same as encoder
            "teacher_forcing_ratio": 1.0,
            "clip": 5.0,  # gradient norm clipping
            "n_iteration": 10000,  # longer run
            "print_every": 20,
            "save_every": 1000
        }

    wandb.init(project=project_name, config=config)
    return wandb.config




In [30]:
!wget http://nlp.stanford.edu/data/glove.6B.zip
!unzip glove.6B.zip


--2025-04-18 10:32:38--  http://nlp.stanford.edu/data/glove.6B.zip
Resolving nlp.stanford.edu (nlp.stanford.edu)... 171.64.67.140
Connecting to nlp.stanford.edu (nlp.stanford.edu)|171.64.67.140|:80... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://nlp.stanford.edu/data/glove.6B.zip [following]
--2025-04-18 10:32:38--  https://nlp.stanford.edu/data/glove.6B.zip
Connecting to nlp.stanford.edu (nlp.stanford.edu)|171.64.67.140|:443... connected.
HTTP request sent, awaiting response... 301 Moved Permanently
Location: https://downloads.cs.stanford.edu/nlp/data/glove.6B.zip [following]
--2025-04-18 10:32:39--  https://downloads.cs.stanford.edu/nlp/data/glove.6B.zip
Resolving downloads.cs.stanford.edu (downloads.cs.stanford.edu)... 171.64.64.22
Connecting to downloads.cs.stanford.edu (downloads.cs.stanford.edu)|171.64.64.22|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 862182613 (822M) [application/zip]
Saving to: ‘glove.6B.zip’


202

In [47]:
import os
import math
import pickle
import torch
import torch.nn as nn
import torch.optim as optim
import wandb
#from utils import Voc, load_glove_embeddings
#from dataset import loadPrepareData, trimRareWords, split_dataset, get_dataloader
#from models_seq2seq import LSTMEncoder, LSTMAttnDecoder
#from wandb_config import init_wandb
from datetime import datetime

dev_config = init_wandb()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

datafile = "/content/formatted_pairs.txt"
save_dir = os.path.join("data", "save")
corpus_name = "custom"
voc, pairs = loadPrepareData("data", corpus_name, datafile, save_dir)
pairs = trimRareWords(voc, pairs, MIN_COUNT=dev_config.MIN_COUNT if "MIN_COUNT" in dev_config else 3)

train_pairs, val_pairs = split_dataset(pairs, test_size=0.1)

train_loader = get_dataloader(train_pairs, voc, batch_size=dev_config.batch_size)
val_loader = get_dataloader(val_pairs, voc, batch_size=dev_config.batch_size)

embedding = load_glove_embeddings(
    voc,
    glove_path="/content/glove.6B.300d.txt",
    embedding_dim=dev_config.embedding_dim,
    freeze=dev_config.freeze_embeddings
)

encoder = LSTMEncoder(dev_config.hidden_size, embedding, dev_config.encoder_n_layers, dev_config.dropout).to(device)
decoder = LSTMAttnDecoder(dev_config.attn_model, embedding, dev_config.hidden_size, voc.num_words,
                          dev_config.decoder_n_layers, dev_config.dropout).to(device)

encoder_optimizer = optim.Adam(encoder.parameters(), lr=dev_config.learning_rate)
decoder_optimizer = optim.Adam(decoder.parameters(), lr=dev_config.learning_rate * dev_config.decoder_learning_ratio)

def save_checkpoint_tar(voc, encoder, decoder, embedding, encoder_optimizer, decoder_optimizer, iteration, loss, save_path="checkpoint.tar"):
    checkpoint = {
        'iteration': iteration,
        'encoder_state': encoder.state_dict(),
        'decoder_state': decoder.state_dict(),
        'embedding_state': embedding.state_dict(),
        'encoder_optimizer_state': encoder_optimizer.state_dict(),
        'decoder_optimizer_state': decoder_optimizer.state_dict(),
        'voc_dict': voc.__dict__,
        'loss': loss
    }
    torch.save(checkpoint, save_path)
    with open("voc.pkl", "wb") as f:
        pickle.dump(voc, f)

def log_artifacts_to_wandb(tar_path="checkpoint.tar", voc_path="voc.pkl", artifact_name="chatbot_model"):
    artifact = wandb.Artifact(artifact_name, type="model")
    artifact.add_file(tar_path)
    artifact.add_file(voc_path)
    wandb.log_artifact(artifact)

def maskNLLLoss(inp, target, mask):
    nTotal = mask.sum()
    crossEntropy = -torch.log(torch.gather(inp, 1, target.view(-1, 1)).squeeze(1))
    loss = crossEntropy.masked_select(mask).mean()
    return loss, nTotal.item()

def train(input_variable, lengths, target_variable, mask, max_target_len,
          encoder, decoder, embedding, encoder_optimizer, decoder_optimizer, clip):

    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()

    input_variable = input_variable.to(device)
    target_variable = target_variable.to(device)
    mask = mask.to(device)
    lengths = lengths.to("cpu")

    current_batch_size = input_variable.size(1)

    loss = 0
    print_losses = []
    n_totals = 0

    encoder_outputs, (encoder_hidden, encoder_cell) = encoder(input_variable, lengths)
    decoder_input = torch.LongTensor([[1 for _ in range(current_batch_size)]]).to(device)
    decoder_hidden = (encoder_hidden[:decoder.n_layers], encoder_cell[:decoder.n_layers])

    use_teacher_forcing = True if torch.rand(1).item() < dev_config.teacher_forcing_ratio else False

    if use_teacher_forcing:
        for t in range(max_target_len):
            decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden, encoder_outputs)
            decoder_input = target_variable[t].view(1, -1)
            mask_loss, nTotal = maskNLLLoss(decoder_output, target_variable[t], mask[t])
            loss += mask_loss
            print_losses.append(mask_loss.item() * nTotal)
            n_totals += nTotal
    else:
        for t in range(max_target_len):
            decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden, encoder_outputs)
            _, topi = decoder_output.topk(1)
            decoder_input = torch.LongTensor([[topi[i][0] for i in range(current_batch_size)]]).to(device)
            mask_loss, nTotal = maskNLLLoss(decoder_output, target_variable[t], mask[t])
            loss += mask_loss
            print_losses.append(mask_loss.item() * nTotal)
            n_totals += nTotal

    loss.backward()
    _ = nn.utils.clip_grad_norm_(encoder.parameters(), dev_config.clip)
    _ = nn.utils.clip_grad_norm_(decoder.parameters(), dev_config.clip)
    encoder_optimizer.step()
    decoder_optimizer.step()

    return sum(print_losses) / n_totals

def evaluate_loss(val_loader, encoder, decoder, embedding):
    encoder.eval()
    decoder.eval()
    total_loss = 0
    total_count = 0
    with torch.no_grad():
        for input_variable, lengths, target_variable, mask, max_target_len in val_loader:
            input_variable = input_variable.to(device)
            target_variable = target_variable.to(device)
            mask = mask.to(device)
            lengths = lengths.to("cpu")

            current_batch_size = input_variable.size(1)

            encoder_outputs, (encoder_hidden, encoder_cell) = encoder(input_variable, lengths)
            decoder_input = torch.LongTensor([[1 for _ in range(current_batch_size)]]).to(device)
            decoder_hidden = (encoder_hidden[:decoder.n_layers], encoder_cell[:decoder.n_layers])

            for t in range(max_target_len):
                decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden, encoder_outputs)
                decoder_input = target_variable[t].view(1, -1)
                mask_loss, nTotal = maskNLLLoss(decoder_output, target_variable[t], mask[t])
                total_loss += mask_loss.item() * nTotal
                total_count += nTotal
    encoder.train()
    decoder.train()
    return total_loss / total_count

print("\nStarting training...")
train_iter = iter(train_loader)
for iteration in range(1, dev_config.n_iteration + 1):
    try:
        batch = next(train_iter)
    except StopIteration:
        train_iter = iter(train_loader)
        batch = next(train_iter)

    input_variable, lengths, target_variable, mask, max_target_len = batch
    train_loss = train(input_variable, lengths, target_variable, mask, max_target_len,
                       encoder, decoder, embedding, encoder_optimizer, decoder_optimizer,
                       dev_config.clip)

    perplexity = math.exp(train_loss)
    wandb.log({
        "train_loss": train_loss,
        "train_perplexity": perplexity,
        "iteration": iteration
    })

    if iteration % dev_config.print_every == 0:
        print("Iteration: {}; Train Loss: {:.4f} | Perplexity: {:.4f}".format(iteration, train_loss, perplexity))

    if iteration % dev_config.save_every == 0:
        val_loss = evaluate_loss(val_loader, encoder, decoder, embedding)
        val_perplexity = math.exp(val_loss)
        wandb.log({
            "val_loss": val_loss,
            "val_perplexity": val_perplexity
        })
        print("Validation Loss: {:.4f} | Perplexity: {:.4f}".format(val_loss, val_perplexity))

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        save_checkpoint_tar(voc, encoder, decoder, embedding, encoder_optimizer, decoder_optimizer, iteration, val_loss, f"checkpoint_{timestamp}.tar")
        log_artifacts_to_wandb(f"checkpoint_{timestamp}.tar", "voc.pkl", f"chatbot_checkpoint_{iteration}")

save_checkpoint_tar(voc, encoder, decoder, embedding, encoder_optimizer, decoder_optimizer, iteration, train_loss, "final_checkpoint.tar")
log_artifacts_to_wandb("final_checkpoint.tar", "voc.pkl", "final_chatbot_checkpoint")
print("Final checkpoint saved to W&B.")


0,1
iteration,▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▄▅▅▅▅▆▆▆▆▆▆▇▇▇▇██████
train_loss,▁

0,1
iteration,221.0
train_loss,
train_perplexity,


Start preparing training data ...
Reading lines...
Read 129656 sentence pairs
Trimmed to 43827 sentence pairs
Counting words...
Counted words: 16561
keep_words 9888 / 16558 = 0.5972
Trimmed from 43827 pairs to 34583, 0.7891 of total
Split 34583 pairs into 31124 train and 3459 validation
Loading GloVe embeddings...
Found 9680/9891 words in GloVe.

Starting training...
Iteration: 20; Train Loss: 7.4572 | Perplexity: 1732.2342
Iteration: 40; Train Loss: 6.2775 | Perplexity: 532.4702
Iteration: 60; Train Loss: 6.1348 | Perplexity: 461.6549
Iteration: 80; Train Loss: 6.0247 | Perplexity: 413.5223
Iteration: 100; Train Loss: 6.0302 | Perplexity: 415.7794
Iteration: 120; Train Loss: 6.1636 | Perplexity: 475.1329
Iteration: 140; Train Loss: 6.1976 | Perplexity: 491.5755
Iteration: 160; Train Loss: 6.0156 | Perplexity: 409.7908
Iteration: 180; Train Loss: 5.8294 | Perplexity: 340.1674
Iteration: 200; Train Loss: 6.0477 | Perplexity: 423.1266
Iteration: 220; Train Loss: 6.0626 | Perplexity: 429.

In [58]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
#from utils import indexesFromSentence, normalizeString
#from config import MAX_LENGTH

# Helper to move tensor to correct device
def to_device(tensor):
    return tensor.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))

# Safe version that logs unknown words
def safe_indexesFromSentence(voc, sentence):
    missing = []
    indexes = []
    for word in sentence.split(" "):
        if word in voc.word2index:
            indexes.append(voc.word2index[word])
        else:
            missing.append(word)
    if missing:
        print(f"Missing words in vocab: {missing}")
    return indexes + [2]  # EOS_token

# Greedy decoder
class GreedySearchDecoder(nn.Module):
    def __init__(self, encoder, decoder):
        super(GreedySearchDecoder, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, input_seq, input_length, max_length):
        encoder_outputs, encoder_hidden = self.encoder(input_seq, input_length)

        # For LSTM: encoder_hidden is a tuple (h, c)
        decoder_hidden = (encoder_hidden[0][:self.decoder.n_layers], encoder_hidden[1][:self.decoder.n_layers])
        decoder_input = torch.ones(1, 1, device=input_seq.device, dtype=torch.long) * 1  # SOS_token

        all_tokens = torch.zeros([0], device=input_seq.device, dtype=torch.long)
        all_scores = torch.zeros([0], device=input_seq.device)

        for _ in range(max_length):
            decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden, encoder_outputs)
            decoder_scores, decoder_input = torch.max(decoder_output, dim=1)
            all_tokens = torch.cat((all_tokens, decoder_input), dim=0)
            all_scores = torch.cat((all_scores, decoder_scores), dim=0)
            decoder_input = torch.unsqueeze(decoder_input, 0)

        return all_tokens, all_scores

# Final evaluation wrapper
def evaluate(encoder, decoder, searcher, voc, sentence, max_length=MAX_LENGTH):
    sentence = normalizeString(sentence)
    indexes_batch = [safe_indexesFromSentence(voc, sentence)]
    lengths = torch.tensor([len(indexes) for indexes in indexes_batch])
    input_batch = torch.LongTensor(indexes_batch).transpose(0, 1)
    input_batch = to_device(input_batch)
    lengths = lengths.to("cpu")

    with torch.no_grad():
        tokens, scores = searcher(input_batch, lengths, max_length)
    decoded_words = [voc.index2word[token.item()] for token in tokens]
    return decoded_words

In [50]:
import wandb

artifact = wandb.use_artifact("abhi1199-city-university-of-london/seq2seqLSTM/final_chatbot_checkpoint:v0", type="model")
artifact_dir = artifact.download()

checkpoint_path = os.path.join(artifact_dir, "final_checkpoint.tar")
voc_path = os.path.join(artifact_dir, "voc.pkl")


checkpoint = torch.load(checkpoint_path, map_location=device)

with open(voc_path, "rb") as f:
    voc = pickle.load(f)


[34m[1mwandb[0m: Downloading large artifact final_chatbot_checkpoint:v0, 1471.71MB. 2 files... 
[34m[1mwandb[0m:   2 of 2 files downloaded.  
Done. 0:0:10.4


In [59]:
embedding = load_glove_embeddings(
    voc,
    glove_path="/content/glove.6B.300d.txt",
    embedding_dim=300,
    freeze=False
)



encoder = LSTMEncoder(
    hidden_size=1000,
    embedding=embedding,
    n_layers=4,
    dropout=0.1
).to(device)

decoder = LSTMAttnDecoder(
    attn_model="dot",
    embedding=embedding,
    hidden_size=1000,
    output_size=voc.num_words,
    n_layers=4,
    dropout=0.1
).to(device)


encoder.load_state_dict(checkpoint["encoder_state"])
decoder.load_state_dict(checkpoint["decoder_state"])
embedding.load_state_dict(checkpoint["embedding_state"])


Loading GloVe embeddings...
Found 9680/9891 words in GloVe.


<All keys matched successfully>

In [60]:
#from evaluate import GreedySearchDecoder, evaluate

encoder.eval()
decoder.eval()
searcher = GreedySearchDecoder(encoder, decoder)


In [61]:
def chat():
    print("FinanceBot is ready! Type 'quit' to exit.")
    while True:
        input_sentence = input("> ")
        if input_sentence.lower() in ["quit", "q"]:
            break
        try:
            output_words = evaluate(encoder, decoder, searcher, voc, input_sentence)
            output_words = [w for w in output_words if w not in ["EOS", "PAD"]]
            print("Bot:", " ".join(output_words))
        except KeyError:
            print("Oops! Encountered unknown word.")

chat()


🤖 FinanceBot is ready! Type 'quit' to exit.
> car loan
Bot: . price . . .
> housing loan enquiry
❌ Missing words in vocab: ['enquiry']
Bot: the total salary is . . . . .
> loan enquiry
❌ Missing words in vocab: ['enquiry']
Bot:  . . . .
> what is inflation
Bot: this is change is . . . . .
> Can a company block a specific person from buying its stock?
Bot: dear spam plan . . .
> Do I have to pay a capital gains tax if I rebuy the same stock within 30 days?
❌ Missing words in vocab: ['gains', 'rebuy']
Bot: yes what did find make the market of start for week ?
> Can a credit card company raise my rates for making a large payment?
Bot: dear customer . . .
> How to motivate young people to save money
Bot: money plan to help daily members to save their poor results .
> How much should a new graduate with new job put towards a car?
Bot: we s learning feedback for break a task at this product .
> What are my investment options in real estate?
Bot: here are reasons that have have equity that of

In [None]:
 # i want to buy a car. should i buy or not

 # What percent of my salary should I save?

 # Is it wise to switch investment strategy frequently?

 # The best investment at this stage is a good, easy to understand but thorough book on finance

 # How to motivate young people to save money

 # How much should a new graduate with new job put towards a car?

 # What are my investment options in real estate?

 # Is it ever a good idea to close credit cards?


 # Would I need to keep track of 1099s?

 # Will I be paid dividends if I own shares?
