In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Print the full path of every file found under the Kaggle input directory.
for root_dir, _, names_in_dir in os.walk('/kaggle/input'):
    for entry in names_in_dir:
        full_path = os.path.join(root_dir, entry)
        print(full_path)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [32]:
!pip install -q datasets

In [None]:
import os
import random
import string
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import Dataset, DataLoader
from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction
from tqdm import tqdm
from datasets import load_dataset
import numpy as np
from torch.utils.tensorboard import SummaryWriter
import uuid

# -----------------------
# Reproducibility
# -----------------------
def set_seed(seed=42):
    """Seed Python, NumPy and PyTorch (CPU and all CUDA devices) with ``seed``.

    Also switches cuDNN to deterministic mode and turns its benchmark
    auto-tuner off.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

set_seed(42)

# -----------------------
# Command-Line Arguments (with Jupyter compatibility)
# -----------------------
def is_jupyter():
    """Return True when the name ``get_ipython`` resolves and can be called.

    A ``NameError`` (the name is not defined, as in a plain ``python`` process)
    yields False.
    """
    try:
        get_ipython()
    except NameError:
        return False
    else:
        return True

# Build `args`: parsed from the command line when run as a script, or a fixed
# set of defaults when running inside a notebook (where sys.argv belongs to
# the kernel and must not be parsed).
if not is_jupyter():
    parser = argparse.ArgumentParser(description="Seq2Seq Chatbot Training")
    for flag, options in (
        ('--model-dir', dict(default='models', help='Directory to save models')),
        ('--checkpoint-dir', dict(default='checkpoints', help='Directory to save checkpoints')),
        ('--log-dir', dict(default='logs', help='Directory for TensorBoard logs')),
        ('--epochs', dict(type=int, default=10, help='Number of training epochs')),
        ('--batch-size', dict(type=int, default=64, help='Batch size')),
    ):
        parser.add_argument(flag, **options)
    args = parser.parse_args()
else:
    args = argparse.Namespace(
        model_dir='/content/drive/MyDrive/models',
        checkpoint_dir='/content/drive/MyDrive/checkpoints',
        log_dir='logs',
        epochs=10,
        batch_size=64,
    )

# -----------------------
# Constants & Hyperparameters
# -----------------------
# Special tokens that occupy the first vocabulary slots.
SOS_TOKEN = "<sos>"
EOS_TOKEN = "<eos>"
PAD_TOKEN = "<pad>"
UNK_TOKEN = "<unk>"
# Vocabulary state: placeholders here, (re)bound as module globals every time a
# ChatbotDataset is constructed.
VOCAB = None
VOCAB_SIZE = None
char2idx = None
idx2char = None

# Model size. The decoder attends with 8 heads over HIDDEN_SIZE features, so
# HIDDEN_SIZE should stay divisible by 8.
EMBEDDING_DIM = 256
HIDDEN_SIZE = 768
NUM_LAYERS = 3
DROPOUT = 0.3
# Training schedule (batch size and epoch count come from `args`).
BATCH_SIZE = args.batch_size
NUM_EPOCHS = args.epochs
# Dialogue turns longer than this many characters are dropped by ChatbotDataset.
MAX_SEQ_LEN = 100
# Teacher-forcing ratio bounds used by current_tf(): starting value and floor.
INITIAL_TEACHER_FORCING = 0.7
MIN_TEACHER_FORCING = 0.3
LEARNING_RATE = 5e-4
# Default number of hypotheses kept per input during beam search.
BEAM_WIDTH = 5

# Use the GPU when one is available; TensorBoard events go to args.log_dir.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
writer = SummaryWriter(log_dir=args.log_dir)

# -----------------------
# Teacher Forcing Scheduler
# -----------------------
def current_tf(epoch):
    """Teacher-forcing ratio for a 1-based ``epoch``.

    Starts at INITIAL_TEACHER_FORCING and drops linearly by
    (INITIAL - MIN) / NUM_EPOCHS per epoch, never going below
    MIN_TEACHER_FORCING.
    """
    span = INITIAL_TEACHER_FORCING - MIN_TEACHER_FORCING
    decayed = INITIAL_TEACHER_FORCING - ((epoch - 1) * span / NUM_EPOCHS)
    return max(MIN_TEACHER_FORCING, decayed)

# -----------------------
# Dataset: DailyDialog pairs
# -----------------------
class ChatbotDataset(Dataset):
    """Character-level (source, reply) pairs built from consecutive DailyDialog turns.

    Constructing a dataset also (re)binds the module-level ``VOCAB``,
    ``VOCAB_SIZE``, ``char2idx`` and ``idx2char``, which ``__getitem__`` and the
    rest of the file read.

    Args:
        split: split name passed to ``load_dataset('daily_dialog', split=...)``.
        max_seq_len: a pair is kept only if both sides have at most this many
            characters.
        vocab: existing vocabulary list to reuse. When ``None`` a vocabulary is
            built from the characters of this split (capped at 1000 characters).
        augment: whether to append character-noised copies of roughly 10% of
            the pairs. ``None`` (the default) enables augmentation only for
            splits whose name starts with ``'train'``, so evaluation splits are
            scored on clean text.
    """

    def __init__(self, split='train', max_seq_len=MAX_SEQ_LEN, vocab=None, augment=None):
        self.max_seq_len = max_seq_len
        raw = load_dataset('daily_dialog', split=split)
        pairs = []
        chars = set()

        # Every adjacent pair of turns becomes one (source, reply) example.
        for dialog in raw['dialog']:
            for i in range(len(dialog) - 1):
                src, tgt = dialog[i].lower(), dialog[i+1].lower()
                if len(src) <= max_seq_len and len(tgt) <= max_seq_len:
                    pairs.append((src, tgt))
                    chars.update(src + tgt)

        global VOCAB, VOCAB_SIZE, char2idx, idx2char
        special_tokens = [PAD_TOKEN, SOS_TOKEN, EOS_TOKEN, UNK_TOKEN]
        if vocab is None:
            VOCAB = special_tokens + sorted(chars)[:1000]
        else:
            VOCAB = vocab
        VOCAB_SIZE = len(VOCAB)
        char2idx = {c: i for i, c in enumerate(VOCAB)}
        idx2char = {i: c for i, c in enumerate(VOCAB)}

        if augment is None:
            # Noise belongs in training data only; corrupting validation pairs
            # would distort the reported loss and BLEU.
            augment = str(split).startswith('train')

        # Replacement characters are drawn from the vocabulary in its fixed
        # order (never from an unordered set), so the noise is reproducible
        # under a fixed random seed.
        noise_chars = [c for c in VOCAB if c not in special_tokens]

        augmented_pairs = []
        for src, tgt in pairs:
            augmented_pairs.append((src, tgt))
            if augment and random.random() < 0.1:
                # Each character is independently replaced with probability 0.05.
                src_aug = ''.join(c if random.random() > 0.05 else random.choice(noise_chars) for c in src)
                tgt_aug = ''.join(c if random.random() > 0.05 else random.choice(noise_chars) for c in tgt)
                augmented_pairs.append((src_aug, tgt_aug))

        random.shuffle(augmented_pairs)
        self.pairs = augmented_pairs
        print(f"Loaded {len(self.pairs)} pairs from DailyDialog {split}, VOCAB_SIZE={VOCAB_SIZE}")

    def __len__(self):
        """Number of (source, reply) pairs, including augmented copies."""
        return len(self.pairs)

    def __getitem__(self, idx):
        """Return ``(src_ids, tgt_ids)`` as 1-D long tensors.

        The source is prefixed with the SOS id and the target is suffixed with
        the EOS id; characters missing from the vocabulary map to the UNK id.
        """
        src, tgt = self.pairs[idx]
        src_ids = [char2idx[SOS_TOKEN]] + [char2idx.get(ch, char2idx[UNK_TOKEN]) for ch in src]
        tgt_ids = [char2idx.get(ch, char2idx[UNK_TOKEN]) for ch in tgt] + [char2idx[EOS_TOKEN]]
        return torch.tensor(src_ids, dtype=torch.long), torch.tensor(tgt_ids, dtype=torch.long)

def collate_fn(batch):
    """Pad a list of ``(src_ids, tgt_ids)`` tensor pairs into batch tensors.

    Returns ``(src_padded, src_lengths, tgt_padded, tgt_lengths)``; both padded
    tensors are batch-first and filled with the PAD id.
    """
    sources, targets = zip(*batch)
    src_lengths = torch.tensor([len(seq) for seq in sources])
    tgt_lengths = torch.tensor([len(seq) for seq in targets])
    pad_id = char2idx[PAD_TOKEN]

    def _pad(seqs):
        return nn.utils.rnn.pad_sequence(seqs, batch_first=True, padding_value=pad_id)

    return _pad(sources), src_lengths, _pad(targets), tgt_lengths

# Order matters: building the training set binds the module-level VOCAB, which
# is then handed to the validation set so both splits share one index mapping.
train_ds = ChatbotDataset(split='train')
vocab = VOCAB
val_ds = ChatbotDataset(split='validation', vocab=vocab)
# Training batches are reshuffled by the loader; validation order is kept fixed.
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn)

# -----------------------
# Model Definitions
# -----------------------
class Encoder(nn.Module):
    """Bidirectional multi-layer LSTM encoder over embedded token ids.

    Args:
        vs: vocabulary size.
        ed: embedding dimension.
        hs: LSTM hidden size per direction.
        nl: number of stacked LSTM layers.
        do: dropout between LSTM layers (forced to 0 when ``nl`` is 1).
    """

    def __init__(self, vs, ed, hs, nl, do):
        super().__init__()
        self.embedding = nn.Embedding(vs, ed)
        self.lstm = nn.LSTM(
            ed,
            hs,
            nl,
            batch_first=True,
            dropout=do if nl > 1 else 0,
            bidirectional=True,
        )
        self.norm = nn.LayerNorm(hs * 2)

    def _merge_directions(self, state):
        """Sum the forward and backward final states of each layer."""
        n_layers, hidden = self.lstm.num_layers, self.lstm.hidden_size
        return state.view(n_layers, 2, -1, hidden).sum(dim=1)

    def forward(self, x, lengths):
        """Encode padded ids ``x`` whose true lengths are given by ``lengths``.

        Returns ``(outputs, h, c)``: layer-normalised per-step outputs with
        ``2 * hs`` features, and per-layer final hidden/cell states with the
        two directions summed.
        """
        packed_in = nn.utils.rnn.pack_padded_sequence(
            self.embedding(x), lengths.cpu(), batch_first=True, enforce_sorted=False
        )
        packed_out, (h_n, c_n) = self.lstm(packed_in)
        padded_out, _ = nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True)
        return self.norm(padded_out), self._merge_directions(h_n), self._merge_directions(c_n)

class Decoder(nn.Module):
    """Single-step LSTM decoder with multi-head attention over encoder outputs.

    Args:
        vs: vocabulary size (also the size of the output logits).
        ed: embedding dimension.
        hs: LSTM hidden size; also the attention width.
        nl: number of stacked LSTM layers.
        do: dropout rate (between LSTM layers when ``nl`` > 1, inside the
            attention, and before the second layer norm).
    """

    def __init__(self, vs, ed, hs, nl, do):
        super().__init__()
        self.embedding = nn.Embedding(vs, ed)
        self.lstm = nn.LSTM(ed, hs, nl, batch_first=True, dropout=(do if nl > 1 else 0))
        # Encoder outputs carry 2 * hs features (two directions); project to hs.
        self.enc_proj = nn.Linear(hs * 2, hs)
        self.attention = nn.MultiheadAttention(hs, num_heads=8, dropout=do)
        self.norm1 = nn.LayerNorm(hs)
        self.norm2 = nn.LayerNorm(hs)
        self.fc = nn.Linear(hs, vs)
        self.dropout = nn.Dropout(do)

    def forward(self, token, h, c, enc_out):
        """Advance decoding by one token.

        ``token`` holds one id per batch row, ``h``/``c`` are the current LSTM
        states and ``enc_out`` the encoder outputs. Returns ``(logits, h, c)``
        with logits of shape (batch, vs) and the updated states.
        """
        step_emb = self.embedding(token.unsqueeze(1))       # (batch, 1, ed)
        rnn_out, (h, c) = self.lstm(step_emb, (h, c))       # (batch, 1, hs)

        # The attention module is sequence-first, hence the transposes.
        memory = self.enc_proj(enc_out).transpose(0, 1)     # (src_len, batch, hs)
        query = rnn_out.transpose(0, 1)                     # (1, batch, hs)
        context, _ = self.attention(query, memory, memory)  # (1, batch, hs)

        fused = self.norm1(rnn_out + context.transpose(0, 1))
        fused = self.norm2(self.dropout(torch.relu(fused)))
        logits = self.fc(fused.view(-1, fused.size(-1)))    # (batch, vs)
        return logits, h, c

class Seq2Seq(nn.Module):
    """Encoder/decoder wrapper that decodes step by step with scheduled teacher forcing."""

    def __init__(self, enc, dec):
        super().__init__()
        self.encoder = enc
        self.decoder = dec

    def forward(self, src, lengths, tgt, tf_ratio):
        """Return logits of shape (batch, tgt_len, VOCAB_SIZE).

        Decoding starts from the SOS id. After each step the next input is the
        ground-truth token ``tgt[:, t]`` with probability ``tf_ratio`` (one
        draw per step, shared by the whole batch), otherwise the decoder's own
        arg-max prediction.
        """
        n_rows, n_steps = tgt.size()
        dev = src.device
        logits = torch.zeros(n_rows, n_steps, VOCAB_SIZE, device=dev)
        enc_out, h, c = self.encoder(src, lengths)
        inp = torch.full((n_rows,), char2idx[SOS_TOKEN], dtype=torch.long, device=dev)
        for t in range(n_steps):
            step_logits, h, c = self.decoder(inp, h, c, enc_out)
            logits[:, t, :] = step_logits
            if random.random() < tf_ratio:
                inp = tgt[:, t]
            else:
                inp = step_logits.argmax(1)
        return logits

# -----------------------
# Initialize Everything
# -----------------------
# Build the networks on the selected device. VOCAB_SIZE must already have been
# set by constructing a ChatbotDataset.
encoder = Encoder(VOCAB_SIZE, EMBEDDING_DIM, HIDDEN_SIZE, NUM_LAYERS, DROPOUT).to(device)
decoder = Decoder(VOCAB_SIZE, EMBEDDING_DIM, HIDDEN_SIZE, NUM_LAYERS, DROPOUT).to(device)
model = Seq2Seq(encoder, decoder).to(device)
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-5)
# Halve the learning rate once the monitored value (the validation loss passed
# to scheduler.step) has not improved for more than 2 consecutive epochs.
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2)
# Padding positions are excluded from the loss; targets are label-smoothed.
criterion = nn.CrossEntropyLoss(label_smoothing=0.1, ignore_index=char2idx[PAD_TOKEN])

# -----------------------
# Batch Beam Search Decode
# -----------------------
def batch_beam_search_decode(model, enc_out, h, c, beam_width=BEAM_WIDTH, max_len=100):
    """Beam-search decode each item of an encoded batch into a string.

    Hypotheses accumulate the raw sum of token log-probabilities; the length
    penalty ``((5 + len) / 6) ** 0.7`` is applied only when hypotheses are
    ranked. (Dividing the running score by the penalty at every step would
    re-normalise an already normalised value and distort the ranking.)

    Args:
        model: object exposing ``decoder(token, h, c, enc_out)``.
        enc_out: encoder outputs, batch-first.
        h, c: decoder initial states with the batch on dimension 1.
        beam_width: number of hypotheses kept per item.
        max_len: maximum number of decoding steps.

    Returns:
        One decoded string per batch item, with SOS/EOS removed. Token ids are
        mapped through the module-level ``idx2char``.
    """
    sos_idx = char2idx[SOS_TOKEN]
    eos_idx = char2idx[EOS_TOKEN]

    def _ranked(raw_logp, length):
        # Length-normalised score used for ranking only.
        return raw_logp / (((5 + length) / 6) ** 0.7)

    batch_size = enc_out.size(0)
    # Slicing a single batch column gives a non-contiguous view when
    # batch_size > 1; make it contiguous before handing it to the LSTM.
    beams = [
        [([sos_idx], 0.0, h[:, i:i+1].contiguous(), c[:, i:i+1].contiguous())]
        for i in range(batch_size)
    ]
    completed = [[] for _ in range(batch_size)]

    for _ in range(max_len):
        for b in range(batch_size):
            if len(completed[b]) >= beam_width:
                continue
            candidates = []
            for seq, raw, h1, c1 in beams[b]:
                inp = torch.tensor([seq[-1]], device=h1.device, dtype=torch.long)
                out, h2, c2 = model.decoder(inp, h1, c1, enc_out[b:b+1])
                logp = torch.log_softmax(out, dim=1)  # (1, vocab)
                # topk cannot return more entries than the vocabulary holds.
                k = min(beam_width, logp.size(1))
                top_lp, top_idx = torch.topk(logp, k, dim=1)
                for tok, lp in zip(top_idx[0].tolist(), top_lp[0].tolist()):
                    new_seq = seq + [tok]
                    new_raw = raw + lp
                    if tok == eos_idx:
                        completed[b].append((new_seq, _ranked(new_raw, len(new_seq))))
                    else:
                        candidates.append((new_seq, new_raw, h2, c2))
            candidates.sort(key=lambda cand: _ranked(cand[1], len(cand[0])), reverse=True)
            beams[b] = candidates[:beam_width]
        # Stop once every item has enough finished hypotheses or nothing left
        # to expand.
        if all(len(completed[b]) >= beam_width or not beams[b] for b in range(batch_size)):
            break

    outputs = []
    for b in range(batch_size):
        # Finished and still-open hypotheses compete on the same normalised scale.
        finalists = completed[b] + [(seq, _ranked(raw, len(seq))) for seq, raw, _, _ in beams[b]]
        best_seq = max(finalists, key=lambda item: item[1])[0] if finalists else []
        outputs.append(''.join(idx2char[i] for i in best_seq if i not in (sos_idx, eos_idx)))
    return outputs

# -----------------------
# Training & Validation
# -----------------------
def train_epoch(epoch):
    """Train for one pass over ``train_loader`` and return the mean batch loss.

    Uses the teacher-forcing ratio scheduled for ``epoch``, clips the gradient
    norm to 1.0, and logs the mean loss and the current learning rate to
    TensorBoard under step ``epoch``.
    """
    model.train()
    running = 0
    tf = current_tf(epoch)
    progress = tqdm(train_loader, desc=f"Train {epoch}")
    for src, slen, tgt, tlen in progress:
        src = src.to(device)
        tgt = tgt.to(device)

        optimizer.zero_grad()
        logits = model(src, slen, tgt, tf)
        # Flatten (batch, time) so every position is one classification example.
        batch_loss = criterion(logits.view(-1, VOCAB_SIZE), tgt.view(-1))
        batch_loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        running += batch_loss.item()

    avg_loss = running / len(train_loader)
    current_lr = optimizer.param_groups[0]['lr']
    writer.add_scalar('Loss/Train', avg_loss, epoch)
    writer.add_scalar('Learning_Rate', current_lr, epoch)
    return avg_loss

def validate(epoch=None):
    """Evaluate on ``val_loader``; return ``(avg_loss, bleu, perplexity)``.

    Args:
        epoch: TensorBoard step for the logged scalars. When omitted, the
            module-level ``epoch`` (set by the training loop) is used if it
            exists; otherwise the scalars are logged without a step.

    The loss is computed with teacher forcing disabled (``tf_ratio=0``). BLEU
    is character-level, comparing beam-search output with the reference after
    dropping PAD/EOS. ``perplexity`` is ``exp(avg_loss)`` as a 0-d tensor;
    because the criterion applies label smoothing, treat it as an indicator
    rather than an exact perplexity.
    """
    # Explicit argument wins; fall back to the loop variable for callers that
    # still invoke validate() with no arguments.
    step = epoch if epoch is not None else globals().get('epoch')
    skip_ids = (char2idx[PAD_TOKEN], char2idx[EOS_TOKEN])

    model.eval()
    total_loss = 0
    references = []
    hypotheses = []
    with torch.no_grad():
        for src, slen, tgt, tlen in tqdm(val_loader, desc="Val"):
            src, tgt = src.to(device), tgt.to(device)
            out = model(src, slen, tgt, tf_ratio=0)
            loss = criterion(out.view(-1, VOCAB_SIZE), tgt.view(-1))
            total_loss += loss.item()

            enc_out, h, c = model.encoder(src, slen)
            preds = batch_beam_search_decode(model, enc_out, h, c)
            # Convert once to plain ints instead of comparing 0-d tensors.
            for pred, ref_ids in zip(preds, tgt.tolist()):
                ref_chars = [idx2char[i] for i in ref_ids if i not in skip_ids]
                references.append([ref_chars])
                hypotheses.append(list(pred))

    avg_loss = total_loss / len(val_loader)
    bleu = corpus_bleu(references, hypotheses, smoothing_function=SmoothingFunction().method1)
    perplexity = torch.exp(torch.tensor(avg_loss))
    writer.add_scalar('Loss/Val', avg_loss, step)
    writer.add_scalar('BLEU/Val', bleu, step)
    writer.add_scalar('Perplexity/Val', perplexity, step)
    return avg_loss, bleu, perplexity

def evaluate_sentence(sent):
    """Encode one lower-cased sentence and return the beam-search reply string."""
    model.eval()
    with torch.no_grad():
        token_ids = [char2idx[SOS_TOKEN]]
        for ch in sent.lower():
            # Characters outside the vocabulary fall back to the UNK id.
            token_ids.append(char2idx.get(ch, char2idx[UNK_TOKEN]))
        batch = torch.tensor(token_ids, dtype=torch.long).unsqueeze(0).to(device)
        batch_len = torch.tensor([len(token_ids)])
        enc_out, h, c = model.encoder(batch, batch_len)
        replies = batch_beam_search_decode(model, enc_out, h, c)
        return replies[0]

# -----------------------
# Main Loop
# -----------------------
if __name__ == "__main__":
    os.makedirs(args.model_dir, exist_ok=True)
    os.makedirs(args.checkpoint_dir, exist_ok=True)

    # "{epoch}" is filled in when a checkpoint is written or read.
    CHECKPOINT_PATH = os.path.join(args.checkpoint_dir, "seq2seq_epoch_{epoch}.pth")
    MODEL_PATH = os.path.join(args.model_dir, "seq2seq_daily_dialog.pth")

    # Epoch number of the checkpoint to resume from; 0 means start fresh.
    # Checkpoints are only written on even epochs (see the loop below).
    RESUME_FROM = 0

    def save_checkpoint(epoch, model, optimizer, scheduler):
        """Write model, optimizer and scheduler state for ``epoch`` to CHECKPOINT_PATH."""
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict()
        }, CHECKPOINT_PATH.format(epoch=epoch))

    def load_checkpoint(path):
        """Restore model, optimizer and scheduler from ``path``; return the saved epoch."""
        checkpoint = torch.load(path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        return checkpoint['epoch']

    # A fresh run starts at epoch 1. Starting later without loading a
    # checkpoint would skip epochs and begin the teacher-forcing schedule
    # part-way through on an untrained model.
    start_epoch = 1
    if RESUME_FROM:
        print(f"Resuming training from epoch {RESUME_FROM}...")
        resume_path = CHECKPOINT_PATH.format(epoch=RESUME_FROM)
        start_epoch = load_checkpoint(resume_path) + 1

    for epoch in range(start_epoch, NUM_EPOCHS + 1):
        tr_loss = train_epoch(epoch)
        val_loss, val_bleu, val_perplexity = validate()
        # The plateau scheduler monitors the validation loss.
        scheduler.step(val_loss)
        print(f"Epoch {epoch} | Train Loss: {tr_loss:.4f} | Val Loss: {val_loss:.4f} | BLEU: {val_bleu:.4f} | Perplexity: {val_perplexity:.2f}")

        if epoch % 2 == 0:
            save_checkpoint(epoch, model, optimizer, scheduler)

        # Qualitative check: decode a reply for one random training prompt.
        inp = random.choice(train_ds.pairs)[0]
        out = evaluate_sentence(inp)
        print(f"Sample In: {inp}\nOut: {out}")

    torch.save(model.state_dict(), MODEL_PATH)
    print(f"Final model saved to {MODEL_PATH}")
    writer.close()

In [35]:
!mkdir model

In [36]:
!mkdir checkpoint