#  Install Dependencies

In [1]:
# Install all required packages (removed nlpaug)
!pip install tokenizers sacrebleu rouge-score streamlit -q

print("✓ All dependencies installed successfully!")

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m51.8/51.8 kB[0m [31m2.0 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m104.1/104.1 kB[0m [31m3.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.1/10.1 MB[0m [31m79.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m564.3/564.3 kB[0m [31m24.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.9/6.9 MB[0m [31m104.7 MB/s[0m eta [36m0:00:00[0m
[?25h  Building wheel for rouge-score (setup.py) ... [?25l[?25hdone
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
datasets 4.1.1 requires pyarrow>=21.0.0, but you have pyarrow 19.0.1 which is incomp

# Import All Libraries & Setup

In [2]:
# Core PyTorch and Deep Learning
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence

# Data processing
import pandas as pd
import numpy as np
import math
import re
import json
import random
from collections import Counter

# Tokenization
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import Whitespace

# Evaluation metrics
from sacrebleu.metrics import BLEU, CHRF
from rouge_score import rouge_scorer

# Utilities
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)

# GPU setup - CUDA optimization for Kaggle
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"{'='*70}")
print(f"🚀 DEVICE SETUP")
print(f"{'='*70}")
print(f"✓ Using device: {device}")
if torch.cuda.is_available():
    print(f"✓ GPU: {torch.cuda.get_device_name(0)}")
    print(f"✓ Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
    print(f"✓ CUDA Version: {torch.version.cuda}")
    # Enable cuDNN benchmarking for faster training
    torch.backends.cudnn.benchmark = True
else:
    print("⚠ No GPU available - using CPU")
print(f"{'='*70}\n")

print("🎯 IMPROVEMENTS LOADED:")
print("   ✅ Repetition penalty mechanism")
print("   ✅ Nucleus sampling (Top-p)")
print("   ✅ Extended training schedule support")
print("   ✅ Multiple decoding strategies")
print()

🚀 DEVICE SETUP
✓ Using device: cuda
✓ GPU: Tesla T4
✓ Memory: 15.83 GB
✓ CUDA Version: 12.4

🎯 IMPROVEMENTS LOADED:
   ✅ Repetition penalty mechanism
   ✅ Nucleus sampling (Top-p)
   ✅ Extended training schedule support
   ✅ Multiple decoding strategies



#  Load Dataset & Preprocessing

In [3]:
print(f"{'='*70}")
print(f"📊 DATASET LOADING & PREPROCESSING")
print(f"{'='*70}\n")

# Load the dataset
df = pd.read_csv('/kaggle/input/empathetic-dialogues-facebook-ai/emotion-emotion_69k.csv')

print(f"✓ Dataset loaded successfully!")
print(f"📊 Shape: {df.shape}")
print(f"🧩 Columns: {list(df.columns)}\n")

# Display sample
print("📋 First few rows:")
print(df.head(3))
print()

# Dataset overview
print(f"{'='*70}")
print(f"📈 DATASET OVERVIEW")
print(f"{'='*70}")
print(f"Total conversations: {len(df)}")
print(f"Unique situations: {df['Situation'].nunique()}")
print(f"Unique emotions: {df['emotion'].nunique()}")
print(f"\n🎭 Emotion distribution:")
print(df['emotion'].value_counts().head(10))
print()

# Text normalization function
def normalize_text(text):
    """Normalize text: lowercase, clean whitespace, normalize punctuation"""
    if pd.isna(text) or text is None:
        return ""
    text = str(text).lower()
    text = re.sub(r'\s+', ' ', text)
    text = re.sub(r'([.!?,;:])', r' \1 ', text)
    text = re.sub(r'\s+', ' ', text).strip()
    return text

# Extract customer utterance from empathetic_dialogues
def extract_customer_utterance(dialogue_text):
    """Extract the last customer utterance from dialogue"""
    if pd.isna(dialogue_text):
        return ""
    
    parts = re.split(r'(Customer :|Agent :)', str(dialogue_text))
    customer_utterances = []
    for i in range(len(parts)):
        if parts[i].strip() == 'Customer :' and i + 1 < len(parts):
            customer_utterances.append(parts[i + 1].strip())
    
    if customer_utterances:
        last_utterance = customer_utterances[-1]
        last_utterance = re.split(r'Agent :', last_utterance)[0].strip()
        last_utterance = last_utterance.replace('\\n', ' ').strip()
        return last_utterance
    
    return str(dialogue_text).strip()

print(f"{'='*70}")
print(f"🔧 PROCESSING DATA PAIRS (No Augmentation)")
print(f"{'='*70}\n")

# Create input-output pairs WITHOUT augmentation
data_pairs = []
skipped = 0

for idx, row in tqdm(df.iterrows(), total=len(df), desc="Processing rows"):
    emotion = row['emotion']
    situation = row['Situation']
    dialogue = row['empathetic_dialogues']
    agent_reply = row['labels']
    
    if pd.isna(agent_reply) or pd.isna(situation):
        skipped += 1
        continue
    
    customer_utt = extract_customer_utterance(dialogue)
    
    if not customer_utt:
        skipped += 1
        continue
    
    # Create input-output pair
    input_text = f"emotion: {emotion} | situation: {situation} | customer: {customer_utt} agent:"
    target_text = agent_reply
    
    data_pairs.append({
        'input': normalize_text(input_text),
        'target': normalize_text(target_text),
        'emotion': emotion
    })

print(f"\n✓ Total valid pairs created: {len(data_pairs)}")
print(f"⚠ Skipped rows (missing data): {skipped}")

# Display sample pairs
print(f"\n{'='*70}")
print(f"📝 SAMPLE INPUT-OUTPUT PAIRS")
print(f"{'='*70}")
for i in range(3):
    print(f"\n--- Example {i+1} ---")
    print(f"INPUT: {data_pairs[i]['input'][:150]}...")
    print(f"TARGET: {data_pairs[i]['target']}")

# Split dataset: 80% train, 10% val, 10% test
random.shuffle(data_pairs)
total_size = len(data_pairs)
train_size = int(0.8 * total_size)
val_size = int(0.1 * total_size)

train_data = data_pairs[:train_size]
val_data = data_pairs[train_size:train_size + val_size]
test_data = data_pairs[train_size + val_size:]

print(f"\n{'='*70}")
print(f"📊 DATASET SPLIT (80-10-10)")
print(f"{'='*70}")
print(f"✓ Train: {len(train_data)} ({len(train_data)/total_size*100:.1f}%)")
print(f"✓ Validation: {len(val_data)} ({len(val_data)/total_size*100:.1f}%)")
print(f"✓ Test: {len(test_data)} ({len(test_data)/total_size*100:.1f}%)")
print(f"✓ Total: {total_size}")
print(f"{'='*70}\n")

📊 DATASET LOADING & PREPROCESSING

✓ Dataset loaded successfully!
📊 Shape: (64636, 7)
🧩 Columns: ['Unnamed: 0', 'Situation', 'emotion', 'empathetic_dialogues', 'labels', 'Unnamed: 5', 'Unnamed: 6']

📋 First few rows:
   Unnamed: 0                                          Situation      emotion  \
0           0  I remember going to the fireworks with my best...  sentimental   
1           1  I remember going to the fireworks with my best...  sentimental   
2           2  I remember going to the fireworks with my best...  sentimental   

                                empathetic_dialogues  \
0  Customer :I remember going to see the firework...   
1  Customer :This was a best friend. I miss her.\...   
2              Customer :We no longer talk.\nAgent :   

                                              labels Unnamed: 5 Unnamed: 6  
0  Was this a friend you were in love with, or ju...        NaN        NaN  
1                                Where has she gone?        NaN        NaN  
2 

Processing rows: 100%|██████████| 64636/64636 [00:07<00:00, 8672.00it/s]


✓ Total valid pairs created: 64636
⚠ Skipped rows (missing data): 0

📝 SAMPLE INPUT-OUTPUT PAIRS

--- Example 1 ---
INPUT: emotion : sentimental | situation : i remember going to the fireworks with my best friend . there was a lot of people , but it only felt like us in th...
TARGET: was this a friend you were in love with , or just a best friend ?

--- Example 2 ---
INPUT: emotion : sentimental | situation : i remember going to the fireworks with my best friend . there was a lot of people , but it only felt like us in th...
TARGET: where has she gone ?

--- Example 3 ---
INPUT: emotion : sentimental | situation : i remember going to the fireworks with my best friend . there was a lot of people , but it only felt like us in th...
TARGET: oh was this something that happened because of an argument ?

📊 DATASET SPLIT (80-10-10)
✓ Train: 51708 (80.0%)
✓ Validation: 6463 (10.0%)
✓ Test: 6465 (10.0%)
✓ Total: 64636






# Build Tokenizer & Create Datasets

In [4]:
print(f"{'='*70}")
print(f"🔤 TOKENIZER TRAINING (using 'tokenizers' library)")
print(f"{'='*70}\n")

# Create tokenizer using tokenizers library (Hugging Face)
tokenizer = Tokenizer(BPE(unk_token="<unk>"))
tokenizer.pre_tokenizer = Whitespace()

# Prepare training texts (from TRAIN split only - as per requirements)
train_texts = []
for item in train_data:
    train_texts.append(item['input'])
    train_texts.append(item['target'])

# Save to temporary file for tokenizer training
with open('train_texts.txt', 'w', encoding='utf-8') as f:
    for text in train_texts:
        f.write(text + '\n')

# Train BPE tokenizer with special tokens
trainer = BpeTrainer(
    special_tokens=["<pad>", "<bos>", "<eos>", "<unk>"],
    vocab_size=10000,
    min_frequency=2
)

print("🔄 Training BPE tokenizer on training data only...")
tokenizer.train(['train_texts.txt'], trainer)
tokenizer.save("empathetic_tokenizer.json")

vocab_size = tokenizer.get_vocab_size()
PAD_IDX = tokenizer.token_to_id("<pad>")
BOS_IDX = tokenizer.token_to_id("<bos>")
EOS_IDX = tokenizer.token_to_id("<eos>")
UNK_IDX = tokenizer.token_to_id("<unk>")

print(f"\n✓ Tokenizer trained successfully!")
print(f"✓ Vocabulary size: {vocab_size}")
print(f"✓ Special tokens:")
print(f"   - PAD: {PAD_IDX}")
print(f"   - BOS: {BOS_IDX}")
print(f"   - EOS: {EOS_IDX}")
print(f"   - UNK: {UNK_IDX}")

# Test tokenizer
test_text = "i am feeling happy today"
encoded = tokenizer.encode(test_text)
print(f"\n📝 Tokenizer test:")
print(f"   Input: '{test_text}'")
print(f"   Tokens: {encoded.tokens[:10]}")
print(f"   IDs: {encoded.ids[:10]}")

print(f"\n{'='*70}")
print(f"📦 CREATING PYTORCH DATASETS & DATALOADERS")
print(f"{'='*70}\n")

# PyTorch Dataset class
class EmpatheticDataset(Dataset):
    def __init__(self, data_pairs, tokenizer, max_len=128):
        self.data = data_pairs
        self.tokenizer = tokenizer
        self.max_len = max_len
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        item = self.data[idx]
        
        # Tokenize input
        input_enc = self.tokenizer.encode(item['input'])
        input_ids = [BOS_IDX] + input_enc.ids[:self.max_len-2] + [EOS_IDX]
        
        # Tokenize target
        target_enc = self.tokenizer.encode(item['target'])
        target_ids = [BOS_IDX] + target_enc.ids[:self.max_len-2] + [EOS_IDX]
        
        return {
            'input_ids': torch.tensor(input_ids, dtype=torch.long),
            'target_ids': torch.tensor(target_ids, dtype=torch.long),
        }

# Collate function for batching with padding
def collate_fn(batch):
    """Pad sequences in batch to same length"""
    input_ids = [item['input_ids'] for item in batch]
    target_ids = [item['target_ids'] for item in batch]
    
    input_ids_padded = pad_sequence(input_ids, batch_first=True, padding_value=PAD_IDX)
    target_ids_padded = pad_sequence(target_ids, batch_first=True, padding_value=PAD_IDX)
    
    return {
        'input_ids': input_ids_padded,
        'target_ids': target_ids_padded,
    }

# Create datasets
train_dataset = EmpatheticDataset(train_data, tokenizer)
val_dataset = EmpatheticDataset(val_data, tokenizer)
test_dataset = EmpatheticDataset(test_data, tokenizer)

# Create dataloaders with GPU pinning for faster transfer
BATCH_SIZE = 32
train_loader = DataLoader(
    train_dataset, 
    batch_size=BATCH_SIZE, 
    shuffle=True, 
    collate_fn=collate_fn,
    pin_memory=True if torch.cuda.is_available() else False,
    num_workers=2
)
val_loader = DataLoader(
    val_dataset, 
    batch_size=BATCH_SIZE, 
    shuffle=False, 
    collate_fn=collate_fn,
    pin_memory=True if torch.cuda.is_available() else False,
    num_workers=2
)
test_loader = DataLoader(
    test_dataset, 
    batch_size=BATCH_SIZE, 
    shuffle=False, 
    collate_fn=collate_fn,
    pin_memory=True if torch.cuda.is_available() else False,
    num_workers=2
)

print(f"✓ Datasets created:")
print(f"   - Train batches: {len(train_loader)} (batch_size={BATCH_SIZE})")
print(f"   - Val batches: {len(val_loader)}")
print(f"   - Test batches: {len(test_loader)}")
print(f"\n✓ GPU memory pinning: {'Enabled' if torch.cuda.is_available() else 'Disabled'}")
print(f"{'='*70}\n")

🔤 TOKENIZER TRAINING (using 'tokenizers' library)

🔄 Training BPE tokenizer on training data only...




✓ Tokenizer trained successfully!
✓ Vocabulary size: 10000
✓ Special tokens:
   - PAD: 0
   - BOS: 1
   - EOS: 2
   - UNK: 3

📝 Tokenizer test:
   Input: 'i am feeling happy today'
   Tokens: ['i', 'am', 'feeling', 'happy', 'today']
   IDs: [49, 139, 535, 324, 362]

📦 CREATING PYTORCH DATASETS & DATALOADERS

✓ Datasets created:
   - Train batches: 1616 (batch_size=32)
   - Val batches: 202
   - Test batches: 203

✓ GPU memory pinning: Enabled



# Complete Transformer Model (Built from Scratch)

In [5]:
print(f"{'='*70}")
print(f"🏗️ BUILDING TRANSFORMER MODEL FROM SCRATCH")
print(f"{'='*70}\n")

# ============================================================================
# POSITIONAL ENCODING
# ============================================================================
class PositionalEncoding(nn.Module):
    """Adds positional information to embeddings using sin/cos functions"""
    def __init__(self, d_model, max_len=5000, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        
        self.register_buffer('pe', pe.unsqueeze(0))
    
    def forward(self, x):
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)

# ============================================================================
# MULTI-HEAD ATTENTION
# ============================================================================
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention mechanism"""
    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        assert d_model % num_heads == 0
        
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        
        self.dropout = nn.Dropout(dropout)
    
    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        
        if mask is not None:
            attn_scores = attn_scores.masked_fill(mask == 0, -1e9)
        
        attn_probs = F.softmax(attn_scores, dim=-1)
        attn_probs = self.dropout(attn_probs)
        
        output = torch.matmul(attn_probs, V)
        return output, attn_probs
    
    def split_heads(self, x):
        batch_size, seq_len, _ = x.size()
        return x.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
    
    def combine_heads(self, x):
        batch_size, _, seq_len, _ = x.size()
        return x.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
    
    def forward(self, Q, K, V, mask=None):
        Q = self.split_heads(self.W_q(Q))
        K = self.split_heads(self.W_k(K))
        V = self.split_heads(self.W_v(V))
        
        attn_output, attn_probs = self.scaled_dot_product_attention(Q, K, V, mask)
        output = self.W_o(self.combine_heads(attn_output))
        return output, attn_probs

# ============================================================================
# POSITION-WISE FEED-FORWARD NETWORK
# ============================================================================
class PositionWiseFeedForward(nn.Module):
    """Two-layer feed-forward network with ReLU activation"""
    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x):
        return self.fc2(self.dropout(F.relu(self.fc1(x))))

# ============================================================================
# ENCODER LAYER
# ============================================================================
class EncoderLayer(nn.Module):
    """Single Transformer encoder layer"""
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.feed_forward = PositionWiseFeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x, mask=None):
        attn_output, _ = self.self_attn(x, x, x, mask)
        x = self.norm1(x + self.dropout(attn_output))
        ff_output = self.feed_forward(x)
        x = self.norm2(x + self.dropout(ff_output))
        return x

# ============================================================================
# DECODER LAYER
# ============================================================================
class DecoderLayer(nn.Module):
    """Single Transformer decoder layer"""
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.cross_attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.feed_forward = PositionWiseFeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x, enc_output, src_mask=None, tgt_mask=None):
        attn_output, _ = self.self_attn(x, x, x, tgt_mask)
        x = self.norm1(x + self.dropout(attn_output))
        attn_output, _ = self.cross_attn(x, enc_output, enc_output, src_mask)
        x = self.norm2(x + self.dropout(attn_output))
        ff_output = self.feed_forward(x)
        x = self.norm3(x + self.dropout(ff_output))
        return x

# ============================================================================
# COMPLETE TRANSFORMER CHATBOT MODEL
# ============================================================================
class TransformerChatbot(nn.Module):
    """Complete Transformer Encoder-Decoder for empathetic dialogue generation"""
    def __init__(self, vocab_size, d_model=256, num_heads=2, num_encoder_layers=2, 
                 num_decoder_layers=2, d_ff=1024, max_seq_len=128, dropout=0.1):
        super().__init__()
        
        self.d_model = d_model
        self.num_encoder_layers = num_encoder_layers
        self.num_decoder_layers = num_decoder_layers
        
        # Embedding layers
        self.encoder_embedding = nn.Embedding(vocab_size, d_model, padding_idx=PAD_IDX)
        self.decoder_embedding = nn.Embedding(vocab_size, d_model, padding_idx=PAD_IDX)
        self.positional_encoding = PositionalEncoding(d_model, max_seq_len, dropout)
        
        # Encoder stack (2 layers as per specification)
        self.encoder_layers = nn.ModuleList([
            EncoderLayer(d_model, num_heads, d_ff, dropout) 
            for _ in range(num_encoder_layers)
        ])
        
        # Decoder stack (2 layers as per specification)
        self.decoder_layers = nn.ModuleList([
            DecoderLayer(d_model, num_heads, d_ff, dropout) 
            for _ in range(num_decoder_layers)
        ])
        
        # Output projection to vocabulary
        self.output_projection = nn.Linear(d_model, vocab_size)
    
    def generate_mask(self, src, tgt):
        """Generate source and target masks"""
        src_mask = (src != PAD_IDX).unsqueeze(1).unsqueeze(2)
        tgt_len = tgt.size(1)
        tgt_pad_mask = (tgt != PAD_IDX).unsqueeze(1).unsqueeze(2)
        tgt_sub_mask = torch.tril(torch.ones((tgt_len, tgt_len), device=tgt.device)).bool()
        tgt_mask = tgt_pad_mask & tgt_sub_mask
        return src_mask, tgt_mask
    
    def encode(self, src, src_mask):
        """Encode source sequence"""
        x = self.encoder_embedding(src) * math.sqrt(self.d_model)
        x = self.positional_encoding(x)
        for layer in self.encoder_layers:
            x = layer(x, src_mask)
        return x
    
    def decode(self, tgt, enc_output, src_mask, tgt_mask):
        """Decode target sequence"""
        x = self.decoder_embedding(tgt) * math.sqrt(self.d_model)
        x = self.positional_encoding(x)
        for layer in self.decoder_layers:
            x = layer(x, enc_output, src_mask, tgt_mask)
        return x
    
    def forward(self, src, tgt):
        """Forward pass through encoder-decoder"""
        src_mask, tgt_mask = self.generate_mask(src, tgt)
        enc_output = self.encode(src, src_mask)
        dec_output = self.decode(tgt, enc_output, src_mask, tgt_mask)
        output = self.output_projection(dec_output)
        return output

# ============================================================================
# INITIALIZE MODEL
# ============================================================================
print("🔧 Initializing Transformer model...")
model = TransformerChatbot(
    vocab_size=vocab_size,
    d_model=256,
    num_heads=2,
    num_encoder_layers=2,
    num_decoder_layers=2,
    d_ff=1024,
    dropout=0.1
).to(device)

total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"\n✓ Model initialized successfully!")
print(f"\n📊 MODEL STATISTICS:")
print(f"   - Total parameters: {total_params:,}")
print(f"   - Trainable parameters: {trainable_params:,}")
print(f"   - Model size: ~{total_params * 4 / 1024 / 1024:.2f} MB")
print(f"\n🏗️ ARCHITECTURE:")
print(f"   - Encoder layers: {model.num_encoder_layers}")
print(f"   - Decoder layers: {model.num_decoder_layers}")
print(f"   - Attention heads per layer: 2")
print(f"   - Embedding dimension: {model.d_model}")
print(f"   - Feed-forward dimension: 1024")
print(f"   - Vocabulary size: {vocab_size}")
print(f"{'='*70}\n")

🏗️ BUILDING TRANSFORMER MODEL FROM SCRATCH

🔧 Initializing Transformer model...

✓ Model initialized successfully!

📊 MODEL STATISTICS:
   - Total parameters: 11,376,400
   - Trainable parameters: 11,376,400
   - Model size: ~43.40 MB

🏗️ ARCHITECTURE:
   - Encoder layers: 2
   - Decoder layers: 2
   - Attention heads per layer: 2
   - Embedding dimension: 256
   - Feed-forward dimension: 1024
   - Vocabulary size: 10000



# Training Setup & Extended Training with Improvements

In [6]:
print(f"{'='*70}")
print(f"🎯 TRAINING SETUP WITH IMPROVEMENTS")
print(f"{'='*70}\n")

# Loss function
criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)

# Adam optimizer
optimizer = torch.optim.Adam(
    model.parameters(), 
    lr=1e-4,
    betas=(0.9, 0.98),
    eps=1e-9
)

# Noam learning rate scheduler
class NoamScheduler:
    """Learning rate scheduler with warmup"""
    def __init__(self, optimizer, d_model, warmup_steps=4000):
        self.optimizer = optimizer
        self.d_model = d_model
        self.warmup_steps = warmup_steps
        self.current_step = 0
    
    def step(self):
        self.current_step += 1
        lr = self.d_model ** (-0.5) * min(
            self.current_step ** (-0.5), 
            self.current_step * self.warmup_steps ** (-1.5)
        )
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr
        return lr
    
    def get_lr(self):
        return self.optimizer.param_groups[0]['lr']

scheduler = NoamScheduler(optimizer, d_model=256, warmup_steps=4000)

print(f"✓ Loss function: CrossEntropyLoss (ignore_index={PAD_IDX})")
print(f"✓ Optimizer: Adam (lr=1e-4, betas=(0.9, 0.98))")
print(f"✓ Scheduler: Noam (warmup=4000 steps)")

# ============================================================================
# FIX TOKENIZER PARALLELISM WARNING
# ============================================================================
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
print(f"✓ Tokenizer parallelism: Disabled (fixes multiprocessing conflicts)")

# ============================================================================
# TRAINING FUNCTION
# ============================================================================
def train_epoch(model, dataloader, criterion, optimizer, scheduler, device, use_scheduler=True):
    """Train for one epoch with teacher forcing"""
    model.train()
    total_loss = 0
    progress_bar = tqdm(dataloader, desc="Training")
    
    for batch_idx, batch in enumerate(progress_bar):
        src = batch['input_ids'].to(device, non_blocking=True)
        tgt = batch['target_ids'].to(device, non_blocking=True)
        
        # Teacher forcing
        tgt_input = tgt[:, :-1]
        tgt_output = tgt[:, 1:]
        
        # Forward pass
        optimizer.zero_grad()
        output = model(src, tgt_input)
        
        # Calculate loss
        loss = criterion(output.reshape(-1, vocab_size), tgt_output.reshape(-1))
        
        # Backward pass
        loss.backward()
        
        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        
        optimizer.step()
        
        if use_scheduler:
            scheduler.step()
        
        total_loss += loss.item()
        
        # Update progress bar
        if use_scheduler:
            current_lr = scheduler.get_lr()
        else:
            current_lr = optimizer.param_groups[0]['lr']
        progress_bar.set_postfix({
            'loss': f'{loss.item():.4f}',
            'lr': f'{current_lr:.6f}'
        })
    
    return total_loss / len(dataloader)

# ============================================================================
# VALIDATION FUNCTION
# ============================================================================
def validate(model, dataloader, criterion, device):
    """Validate model"""
    model.eval()
    total_loss = 0
    
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Validation"):
            src = batch['input_ids'].to(device, non_blocking=True)
            tgt = batch['target_ids'].to(device, non_blocking=True)
            
            tgt_input = tgt[:, :-1]
            tgt_output = tgt[:, 1:]
            
            output = model(src, tgt_input)
            loss = criterion(output.reshape(-1, vocab_size), tgt_output.reshape(-1))
            
            total_loss += loss.item()
    
    return total_loss / len(dataloader)

print(f"\n{'='*70}")
print(f"🔬 SANITY CHECK: Overfitting on 10 samples")
print(f"{'='*70}\n")

# Sanity check
small_data = train_data[:10]
small_dataset = EmpatheticDataset(small_data, tokenizer)
small_loader = DataLoader(small_dataset, batch_size=2, shuffle=True, collate_fn=collate_fn, num_workers=0)

# Higher learning rate for sanity check
sanity_optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.98), eps=1e-9)
sanity_scheduler = NoamScheduler(sanity_optimizer, d_model=256, warmup_steps=100)

print("Testing if model can learn (overfitting on 10 samples)...")
for epoch in range(50):
    model.train()
    epoch_loss = 0
    for batch in small_loader:
        src = batch['input_ids'].to(device)
        tgt = batch['target_ids'].to(device)
        
        tgt_input = tgt[:, :-1]
        tgt_output = tgt[:, 1:]
        
        sanity_optimizer.zero_grad()
        output = model(src, tgt_input)
        loss = criterion(output.reshape(-1, vocab_size), tgt_output.reshape(-1))
        loss.backward()
        
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        sanity_optimizer.step()
        sanity_scheduler.step()
        
        epoch_loss += loss.item()
    
    avg_loss = epoch_loss / len(small_loader)
    
    if (epoch + 1) % 10 == 0:
        print(f"Epoch {epoch+1}/50, Loss: {avg_loss:.4f}")
    
    if avg_loss < 0.5:
        print(f"✓ Loss dropped below 0.5 at epoch {epoch+1}!")
        break

if avg_loss < 1.0:
    print(f"\n✓ Sanity check PASSED! Model can learn (final loss: {avg_loss:.4f})")
else:
    print(f"\n⚠ Warning: Model struggling to overfit (final loss: {avg_loss:.4f})")

print(f"\n{'='*70}")
print(f"🚀 FULL TRAINING WITH 3-PHASE SCHEDULE")
print(f"{'='*70}\n")

# Reinitialize model for full training
model = TransformerChatbot(
    vocab_size=vocab_size,
    d_model=256,
    num_heads=2,
    num_encoder_layers=2,
    num_decoder_layers=2,
    d_ff=1024,
    dropout=0.1
).to(device)

# MULTI-PHASE TRAINING SCHEDULE
PHASE1_EPOCHS = 10
PHASE1_LR = 1e-4

PHASE2_EPOCHS = 10
PHASE2_LR = 5e-5

PHASE3_EPOCHS = 5
PHASE3_LR = 1e-5

print(f"📅 TRAINING SCHEDULE:")
print(f"   Phase 1: {PHASE1_EPOCHS} epochs @ lr={PHASE1_LR} (Fast learning)")
print(f"   Phase 2: {PHASE2_EPOCHS} epochs @ lr={PHASE2_LR} (Fine-tuning)")
print(f"   Phase 3: {PHASE3_EPOCHS} epochs @ lr={PHASE3_LR} (Polish)")
print(f"   Total: {PHASE1_EPOCHS + PHASE2_EPOCHS + PHASE3_EPOCHS} epochs\n")

# Recreate dataloaders WITHOUT num_workers
print("📦 Recreating dataloaders (fixing multiprocessing)...")
train_loader = DataLoader(
    train_dataset, 
    batch_size=BATCH_SIZE, 
    shuffle=True, 
    collate_fn=collate_fn,
    pin_memory=True if torch.cuda.is_available() else False,
    num_workers=0
)
val_loader = DataLoader(
    val_dataset, 
    batch_size=BATCH_SIZE, 
    shuffle=False, 
    collate_fn=collate_fn,
    pin_memory=True if torch.cuda.is_available() else False,
    num_workers=0
)
test_loader = DataLoader(
    test_dataset, 
    batch_size=BATCH_SIZE, 
    shuffle=False, 
    collate_fn=collate_fn,
    pin_memory=True if torch.cuda.is_available() else False,
    num_workers=0
)
print("✓ Dataloaders recreated with num_workers=0\n")

# Training history
history = {
    'train_loss': [], 
    'val_loss': [], 
    'train_ppl': [], 
    'val_ppl': [],
    'learning_rates': [],
    'phases': []
}

best_val_loss = float('inf')
all_epochs = PHASE1_EPOCHS + PHASE2_EPOCHS + PHASE3_EPOCHS
current_epoch = 0

# ============================================================================
# PHASE 1: FAST LEARNING
# ============================================================================
print(f"\n{'='*70}")
print(f"📍 PHASE 1: FAST LEARNING ({PHASE1_EPOCHS} epochs)")
print(f"{'='*70}\n")

optimizer = torch.optim.Adam(model.parameters(), lr=PHASE1_LR, betas=(0.9, 0.98), eps=1e-9)
scheduler = NoamScheduler(optimizer, d_model=256, warmup_steps=4000)

for epoch in range(PHASE1_EPOCHS):
    current_epoch += 1
    print(f"\n{'='*70}")
    print(f"Epoch {current_epoch}/{all_epochs} (Phase 1)")
    print(f"{'='*70}")
    
    # Train
    train_loss = train_epoch(model, train_loader, criterion, optimizer, scheduler, device, use_scheduler=True)
    
    # Validate
    val_loss = validate(model, val_loader, criterion, device)
    
    # Calculate perplexity
    train_ppl = math.exp(min(train_loss, 10))
    val_ppl = math.exp(min(val_loss, 10))
    
    # Store history
    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)
    history['train_ppl'].append(train_ppl)
    history['val_ppl'].append(val_ppl)
    history['learning_rates'].append(scheduler.get_lr())
    history['phases'].append('Phase 1')
    
    # Print metrics
    print(f"\n📊 Epoch {current_epoch} Results:")
    print(f"   Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")
    print(f"   Train Perplexity: {train_ppl:.2f} | Val Perplexity: {val_ppl:.2f}")
    print(f"   Learning Rate: {scheduler.get_lr():.6f}")
    
    # Save best model
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save({
            'epoch': current_epoch,
            'phase': 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_loss': val_loss,
            'train_loss': train_loss,
        }, 'best_model.pt')
        print(f"   ✓ Best model saved! (Val Loss: {val_loss:.4f})")
    
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

# ============================================================================
# PHASE 2: FINE-TUNING
# ============================================================================
print(f"\n{'='*70}")
print(f"📍 PHASE 2: FINE-TUNING ({PHASE2_EPOCHS} epochs)")
print(f"{'='*70}\n")

# Reduce learning rate
optimizer = torch.optim.Adam(model.parameters(), lr=PHASE2_LR, betas=(0.9, 0.98), eps=1e-9)

for epoch in range(PHASE2_EPOCHS):
    current_epoch += 1
    print(f"\n{'='*70}")
    print(f"Epoch {current_epoch}/{all_epochs} (Phase 2)")
    print(f"{'='*70}")
    
    train_loss = train_epoch(model, train_loader, criterion, optimizer, None, device, use_scheduler=False)
    val_loss = validate(model, val_loader, criterion, device)
    
    train_ppl = math.exp(min(train_loss, 10))
    val_ppl = math.exp(min(val_loss, 10))
    
    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)
    history['train_ppl'].append(train_ppl)
    history['val_ppl'].append(val_ppl)
    history['learning_rates'].append(optimizer.param_groups[0]['lr'])
    history['phases'].append('Phase 2')
    
    print(f"\n📊 Epoch {current_epoch} Results:")
    print(f"   Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")
    print(f"   Train Perplexity: {train_ppl:.2f} | Val Perplexity: {val_ppl:.2f}")
    print(f"   Learning Rate: {optimizer.param_groups[0]['lr']:.6f}")
    
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save({
            'epoch': current_epoch,
            'phase': 2,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_loss': val_loss,
            'train_loss': train_loss,
        }, 'best_model.pt')
        print(f"   ✓ Best model saved! (Val Loss: {val_loss:.4f})")
    
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

# ============================================================================
# PHASE 3: POLISH
# ============================================================================
print(f"\n{'='*70}")
print(f"📍 PHASE 3: POLISH ({PHASE3_EPOCHS} epochs)")
print(f"{'='*70}\n")

# Further reduce learning rate
optimizer = torch.optim.Adam(model.parameters(), lr=PHASE3_LR, betas=(0.9, 0.98), eps=1e-9)

for epoch in range(PHASE3_EPOCHS):
    current_epoch += 1
    print(f"\n{'='*70}")
    print(f"Epoch {current_epoch}/{all_epochs} (Phase 3)")
    print(f"{'='*70}")
    
    train_loss = train_epoch(model, train_loader, criterion, optimizer, None, device, use_scheduler=False)
    val_loss = validate(model, val_loader, criterion, device)
    
    train_ppl = math.exp(min(train_loss, 10))
    val_ppl = math.exp(min(val_loss, 10))
    
    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)
    history['train_ppl'].append(train_ppl)
    history['val_ppl'].append(val_ppl)
    history['learning_rates'].append(optimizer.param_groups[0]['lr'])
    history['phases'].append('Phase 3')
    
    print(f"\n📊 Epoch {current_epoch} Results:")
    print(f"   Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")
    print(f"   Train Perplexity: {train_ppl:.2f} | Val Perplexity: {val_ppl:.2f}")
    print(f"   Learning Rate: {optimizer.param_groups[0]['lr']:.6f}")
    
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save({
            'epoch': current_epoch,
            'phase': 3,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_loss': val_loss,
            'train_loss': train_loss,
        }, 'best_model.pt')
        print(f"   ✓ Best model saved! (Val Loss: {val_loss:.4f})")
    
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

print(f"\n{'='*70}")
print(f"✓ TRAINING COMPLETE!")
print(f"{'='*70}")
print(f"Total epochs trained: {current_epoch}")
print(f"Best validation loss: {best_val_loss:.4f}")
print(f"Best validation perplexity: {math.exp(best_val_loss):.2f}")

# Load best model
print(f"\n📥 Loading best model...")
checkpoint = torch.load('best_model.pt')
model.load_state_dict(checkpoint['model_state_dict'])
print(f"✓ Loaded best model from epoch {checkpoint['epoch']} (Phase {checkpoint['phase']})")
print(f"{'='*70}\n")

🎯 TRAINING SETUP WITH IMPROVEMENTS

✓ Loss function: CrossEntropyLoss (ignore_index=0)
✓ Optimizer: Adam (lr=1e-4, betas=(0.9, 0.98))
✓ Scheduler: Noam (warmup=4000 steps)
✓ Tokenizer parallelism: Disabled (fixes multiprocessing conflicts)

🔬 SANITY CHECK: Overfitting on 10 samples

Testing if model can learn (overfitting on 10 samples)...
Epoch 10/50, Loss: 0.9199
Epoch 20/50, Loss: 5.3793
Epoch 30/50, Loss: 4.4935
Epoch 40/50, Loss: 4.3679
Epoch 50/50, Loss: 4.2310


🚀 FULL TRAINING WITH 3-PHASE SCHEDULE

📅 TRAINING SCHEDULE:
   Phase 1: 10 epochs @ lr=0.0001 (Fast learning)
   Phase 2: 10 epochs @ lr=5e-05 (Fine-tuning)
   Phase 3: 5 epochs @ lr=1e-05 (Polish)
   Total: 25 epochs

📦 Recreating dataloaders (fixing multiprocessing)...
✓ Dataloaders recreated with num_workers=0


📍 PHASE 1: FAST LEARNING (10 epochs)


Epoch 1/25 (Phase 1)


Training: 100%|██████████| 1616/1616 [01:00<00:00, 26.56it/s, loss=4.6018, lr=0.000399]
Validation: 100%|██████████| 202/202 [00:02<00:00, 72.66it/s]



📊 Epoch 1 Results:
   Train Loss: 5.3087 | Val Loss: 4.3849
   Train Perplexity: 202.08 | Val Perplexity: 80.23
   Learning Rate: 0.000399
   ✓ Best model saved! (Val Loss: 4.3849)

Epoch 2/25 (Phase 1)


Training: 100%|██████████| 1616/1616 [01:04<00:00, 24.93it/s, loss=4.4958, lr=0.000798]
Validation: 100%|██████████| 202/202 [00:02<00:00, 70.28it/s]



📊 Epoch 2 Results:
   Train Loss: 4.2738 | Val Loss: 4.1814
   Train Perplexity: 71.79 | Val Perplexity: 65.46
   Learning Rate: 0.000798
   ✓ Best model saved! (Val Loss: 4.1814)

Epoch 3/25 (Phase 1)


Training: 100%|██████████| 1616/1616 [01:04<00:00, 25.10it/s, loss=3.9393, lr=0.000898]
Validation: 100%|██████████| 202/202 [00:02<00:00, 70.19it/s]



📊 Epoch 3 Results:
   Train Loss: 4.1173 | Val Loss: 4.0777
   Train Perplexity: 61.40 | Val Perplexity: 59.01
   Learning Rate: 0.000898
   ✓ Best model saved! (Val Loss: 4.0777)

Epoch 4/25 (Phase 1)


Training: 100%|██████████| 1616/1616 [01:04<00:00, 25.05it/s, loss=3.7390, lr=0.000777]
Validation: 100%|██████████| 202/202 [00:02<00:00, 70.79it/s]



📊 Epoch 4 Results:
   Train Loss: 3.9701 | Val Loss: 3.9988
   Train Perplexity: 52.99 | Val Perplexity: 54.53
   Learning Rate: 0.000777
   ✓ Best model saved! (Val Loss: 3.9988)

Epoch 5/25 (Phase 1)


Training: 100%|██████████| 1616/1616 [01:04<00:00, 25.02it/s, loss=3.9867, lr=0.000695]
Validation: 100%|██████████| 202/202 [00:02<00:00, 70.34it/s]



📊 Epoch 5 Results:
   Train Loss: 3.8611 | Val Loss: 3.9205
   Train Perplexity: 47.52 | Val Perplexity: 50.42
   Learning Rate: 0.000695
   ✓ Best model saved! (Val Loss: 3.9205)

Epoch 6/25 (Phase 1)


Training: 100%|██████████| 1616/1616 [01:04<00:00, 25.06it/s, loss=3.5661, lr=0.000635]
Validation: 100%|██████████| 202/202 [00:02<00:00, 70.36it/s]



📊 Epoch 6 Results:
   Train Loss: 3.7814 | Val Loss: 3.8695
   Train Perplexity: 43.88 | Val Perplexity: 47.92
   Learning Rate: 0.000635
   ✓ Best model saved! (Val Loss: 3.8695)

Epoch 7/25 (Phase 1)


Training: 100%|██████████| 1616/1616 [01:04<00:00, 24.99it/s, loss=3.8386, lr=0.000588]
Validation: 100%|██████████| 202/202 [00:02<00:00, 70.16it/s]



📊 Epoch 7 Results:
   Train Loss: 3.7218 | Val Loss: 3.8542
   Train Perplexity: 41.34 | Val Perplexity: 47.19
   Learning Rate: 0.000588
   ✓ Best model saved! (Val Loss: 3.8542)

Epoch 8/25 (Phase 1)


Training: 100%|██████████| 1616/1616 [01:04<00:00, 24.97it/s, loss=3.6343, lr=0.000550]
Validation: 100%|██████████| 202/202 [00:02<00:00, 70.45it/s]



📊 Epoch 8 Results:
   Train Loss: 3.6712 | Val Loss: 3.8343
   Train Perplexity: 39.30 | Val Perplexity: 46.26
   Learning Rate: 0.000550
   ✓ Best model saved! (Val Loss: 3.8343)

Epoch 9/25 (Phase 1)


Training: 100%|██████████| 1616/1616 [01:04<00:00, 25.00it/s, loss=3.8686, lr=0.000518]
Validation: 100%|██████████| 202/202 [00:02<00:00, 70.37it/s]



📊 Epoch 9 Results:
   Train Loss: 3.6283 | Val Loss: 3.8256
   Train Perplexity: 37.65 | Val Perplexity: 45.86
   Learning Rate: 0.000518
   ✓ Best model saved! (Val Loss: 3.8256)

Epoch 10/25 (Phase 1)


Training: 100%|██████████| 1616/1616 [01:04<00:00, 25.00it/s, loss=3.4545, lr=0.000492]
Validation: 100%|██████████| 202/202 [00:02<00:00, 70.33it/s]



📊 Epoch 10 Results:
   Train Loss: 3.5907 | Val Loss: 3.8158
   Train Perplexity: 36.26 | Val Perplexity: 45.41
   Learning Rate: 0.000492
   ✓ Best model saved! (Val Loss: 3.8158)

📍 PHASE 2: FINE-TUNING (10 epochs)


Epoch 11/25 (Phase 2)


Training: 100%|██████████| 1616/1616 [01:04<00:00, 24.95it/s, loss=3.1210, lr=0.000050]
Validation: 100%|██████████| 202/202 [00:02<00:00, 70.25it/s]



📊 Epoch 11 Results:
   Train Loss: 3.4409 | Val Loss: 3.7703
   Train Perplexity: 31.21 | Val Perplexity: 43.39
   Learning Rate: 0.000050
   ✓ Best model saved! (Val Loss: 3.7703)

Epoch 12/25 (Phase 2)


Training: 100%|██████████| 1616/1616 [01:04<00:00, 25.01it/s, loss=3.1786, lr=0.000050]
Validation: 100%|██████████| 202/202 [00:02<00:00, 69.98it/s]



📊 Epoch 12 Results:
   Train Loss: 3.4070 | Val Loss: 3.7646
   Train Perplexity: 30.17 | Val Perplexity: 43.15
   Learning Rate: 0.000050
   ✓ Best model saved! (Val Loss: 3.7646)

Epoch 13/25 (Phase 2)


Training: 100%|██████████| 1616/1616 [01:04<00:00, 25.07it/s, loss=3.3742, lr=0.000050]
Validation: 100%|██████████| 202/202 [00:02<00:00, 70.61it/s]



📊 Epoch 13 Results:
   Train Loss: 3.3912 | Val Loss: 3.7648
   Train Perplexity: 29.70 | Val Perplexity: 43.15
   Learning Rate: 0.000050

Epoch 14/25 (Phase 2)


Training: 100%|██████████| 1616/1616 [01:04<00:00, 25.02it/s, loss=3.4230, lr=0.000050]
Validation: 100%|██████████| 202/202 [00:02<00:00, 70.32it/s]



📊 Epoch 14 Results:
   Train Loss: 3.3803 | Val Loss: 3.7660
   Train Perplexity: 29.38 | Val Perplexity: 43.21
   Learning Rate: 0.000050

Epoch 15/25 (Phase 2)


Training: 100%|██████████| 1616/1616 [01:04<00:00, 25.11it/s, loss=3.6025, lr=0.000050]
Validation: 100%|██████████| 202/202 [00:02<00:00, 70.41it/s]



📊 Epoch 15 Results:
   Train Loss: 3.3710 | Val Loss: 3.7702
   Train Perplexity: 29.11 | Val Perplexity: 43.39
   Learning Rate: 0.000050

Epoch 16/25 (Phase 2)


Training: 100%|██████████| 1616/1616 [01:04<00:00, 25.01it/s, loss=3.1408, lr=0.000050]
Validation: 100%|██████████| 202/202 [00:02<00:00, 70.50it/s]



📊 Epoch 16 Results:
   Train Loss: 3.3631 | Val Loss: 3.7696
   Train Perplexity: 28.88 | Val Perplexity: 43.36
   Learning Rate: 0.000050

Epoch 17/25 (Phase 2)


Training: 100%|██████████| 1616/1616 [01:04<00:00, 25.00it/s, loss=2.9445, lr=0.000050]
Validation: 100%|██████████| 202/202 [00:02<00:00, 70.37it/s]



📊 Epoch 17 Results:
   Train Loss: 3.3557 | Val Loss: 3.7732
   Train Perplexity: 28.66 | Val Perplexity: 43.52
   Learning Rate: 0.000050

Epoch 18/25 (Phase 2)


Training: 100%|██████████| 1616/1616 [01:04<00:00, 25.09it/s, loss=3.4410, lr=0.000050]
Validation: 100%|██████████| 202/202 [00:02<00:00, 70.55it/s]



📊 Epoch 18 Results:
   Train Loss: 3.3483 | Val Loss: 3.7730
   Train Perplexity: 28.46 | Val Perplexity: 43.51
   Learning Rate: 0.000050

Epoch 19/25 (Phase 2)


Training: 100%|██████████| 1616/1616 [01:04<00:00, 25.06it/s, loss=3.5424, lr=0.000050]
Validation: 100%|██████████| 202/202 [00:02<00:00, 70.38it/s]



📊 Epoch 19 Results:
   Train Loss: 3.3432 | Val Loss: 3.7759
   Train Perplexity: 28.31 | Val Perplexity: 43.64
   Learning Rate: 0.000050

Epoch 20/25 (Phase 2)


Training: 100%|██████████| 1616/1616 [01:04<00:00, 25.01it/s, loss=3.1443, lr=0.000050]
Validation: 100%|██████████| 202/202 [00:02<00:00, 70.32it/s]



📊 Epoch 20 Results:
   Train Loss: 3.3364 | Val Loss: 3.7760
   Train Perplexity: 28.12 | Val Perplexity: 43.64
   Learning Rate: 0.000050

📍 PHASE 3: POLISH (5 epochs)


Epoch 21/25 (Phase 3)


Training: 100%|██████████| 1616/1616 [01:04<00:00, 25.05it/s, loss=3.2360, lr=0.000010]
Validation: 100%|██████████| 202/202 [00:02<00:00, 70.12it/s]



📊 Epoch 21 Results:
   Train Loss: 3.3157 | Val Loss: 3.7763
   Train Perplexity: 27.54 | Val Perplexity: 43.66
   Learning Rate: 0.000010

Epoch 22/25 (Phase 3)


Training: 100%|██████████| 1616/1616 [01:04<00:00, 25.05it/s, loss=3.0592, lr=0.000010]
Validation: 100%|██████████| 202/202 [00:02<00:00, 70.23it/s]



📊 Epoch 22 Results:
   Train Loss: 3.3123 | Val Loss: 3.7756
   Train Perplexity: 27.45 | Val Perplexity: 43.63
   Learning Rate: 0.000010

Epoch 23/25 (Phase 3)


Training: 100%|██████████| 1616/1616 [01:04<00:00, 25.01it/s, loss=3.5293, lr=0.000010]
Validation: 100%|██████████| 202/202 [00:02<00:00, 70.20it/s]



📊 Epoch 23 Results:
   Train Loss: 3.3123 | Val Loss: 3.7765
   Train Perplexity: 27.45 | Val Perplexity: 43.66
   Learning Rate: 0.000010

Epoch 24/25 (Phase 3)


Training: 100%|██████████| 1616/1616 [01:04<00:00, 24.99it/s, loss=3.3390, lr=0.000010]
Validation: 100%|██████████| 202/202 [00:02<00:00, 69.92it/s]



📊 Epoch 24 Results:
   Train Loss: 3.3100 | Val Loss: 3.7768
   Train Perplexity: 27.39 | Val Perplexity: 43.68
   Learning Rate: 0.000010

Epoch 25/25 (Phase 3)


Training: 100%|██████████| 1616/1616 [01:04<00:00, 25.04it/s, loss=3.5682, lr=0.000010]
Validation: 100%|██████████| 202/202 [00:02<00:00, 70.31it/s]



📊 Epoch 25 Results:
   Train Loss: 3.3082 | Val Loss: 3.7771
   Train Perplexity: 27.34 | Val Perplexity: 43.69
   Learning Rate: 0.000010

✓ TRAINING COMPLETE!
Total epochs trained: 25
Best validation loss: 3.7646
Best validation perplexity: 43.15

📥 Loading best model...
✓ Loaded best model from epoch 12 (Phase 2)



# Inference with Advanced Decoding Strategies

In [7]:
print(f"{'='*70}")
print(f"🔮 ADVANCED INFERENCE FUNCTIONS")
print(f"{'='*70}\n")

# ============================================================================
# GREEDY DECODING (Basic)
# ============================================================================
def greedy_decode(model, src, max_len=50):
    """Standard greedy decoding"""
    model.eval()
    src = src.to(device)
    
    with torch.no_grad():
        src_mask, _ = model.generate_mask(src, src)
        enc_output = model.encode(src, src_mask)
        tgt = torch.tensor([[BOS_IDX]], device=device)
        
        for _ in range(max_len):
            _, tgt_mask = model.generate_mask(src, tgt)
            dec_output = model.decode(tgt, enc_output, src_mask, tgt_mask)
            output = model.output_projection(dec_output[:, -1, :])
            next_token = output.argmax(dim=-1, keepdim=True)
            tgt = torch.cat([tgt, next_token], dim=1)
            if next_token.item() == EOS_IDX:
                break
        
        return tgt.squeeze(0).cpu().tolist()

# ============================================================================
# GREEDY DECODING WITH REPETITION PENALTY (NEW IMPROVEMENT)
# ============================================================================
def greedy_decode_with_penalty(model, src, max_len=50, repetition_penalty=1.2):
    """
    Greedy decoding with repetition penalty to reduce repetitive outputs
    Higher penalty (>1.0) discourages repeating tokens
    """
    model.eval()
    src = src.to(device)
    
    with torch.no_grad():
        src_mask, _ = model.generate_mask(src, src)
        enc_output = model.encode(src, src_mask)
        tgt = torch.tensor([[BOS_IDX]], device=device)
        generated_tokens = []
        
        for _ in range(max_len):
            _, tgt_mask = model.generate_mask(src, tgt)
            dec_output = model.decode(tgt, enc_output, src_mask, tgt_mask)
            logits = model.output_projection(dec_output[:, -1, :]).squeeze(0)
            
            # Apply repetition penalty (NEW IMPROVEMENT)
            for token in set(generated_tokens):
                logits[token] = logits[token] / repetition_penalty
            
            next_token = logits.argmax(dim=-1, keepdim=True)
            tgt = torch.cat([tgt, next_token.unsqueeze(0)], dim=1)
            generated_tokens.append(next_token.item())
            
            if next_token.item() == EOS_IDX:
                break
        
        return tgt.squeeze(0).cpu().tolist()

# ============================================================================
# NUCLEUS SAMPLING / TOP-P SAMPLING (NEW IMPROVEMENT)
# ============================================================================
def nucleus_sampling_decode(model, src, max_len=50, p=0.9, temperature=1.0):
    """
    Nucleus (top-p) sampling for more diverse and natural responses
    p: cumulative probability threshold (0.9 = top 90% probability mass)
    temperature: controls randomness (higher = more random)
    """
    model.eval()
    src = src.to(device)
    
    with torch.no_grad():
        src_mask, _ = model.generate_mask(src, src)
        enc_output = model.encode(src, src_mask)
        tgt = torch.tensor([[BOS_IDX]], device=device)
        
        for _ in range(max_len):
            _, tgt_mask = model.generate_mask(src, tgt)
            dec_output = model.decode(tgt, enc_output, src_mask, tgt_mask)
            logits = model.output_projection(dec_output[:, -1, :]).squeeze(0)
            
            # Apply temperature
            logits = logits / temperature
            
            # Sort by probability
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
            
            # Remove tokens after cumulative probability exceeds p
            sorted_indices_to_remove = cumulative_probs > p
            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
            sorted_indices_to_remove[..., 0] = 0
            
            # Set removed tokens to -inf
            indices_to_remove = sorted_indices[sorted_indices_to_remove]
            logits[indices_to_remove] = -float('Inf')
            
            # Sample from remaining distribution
            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            
            tgt = torch.cat([tgt, next_token.unsqueeze(0)], dim=1)
            
            if next_token.item() == EOS_IDX:
                break
        
        return tgt.squeeze(0).cpu().tolist()

# ============================================================================
# BEAM SEARCH DECODING
# ============================================================================
def beam_search_decode(model, src, beam_width=3, max_len=50):
    """Beam search decoding"""
    model.eval()
    src = src.to(device)
    
    with torch.no_grad():
        src_mask, _ = model.generate_mask(src, src)
        enc_output = model.encode(src, src_mask)
        beams = [([BOS_IDX], 0.0)]
        
        for _ in range(max_len):
            candidates = []
            
            for seq, score in beams:
                if seq[-1] == EOS_IDX:
                    candidates.append((seq, score))
                    continue
                
                tgt = torch.tensor([seq], device=device)
                _, tgt_mask = model.generate_mask(src, tgt)
                dec_output = model.decode(tgt, enc_output, src_mask, tgt_mask)
                output = model.output_projection(dec_output[:, -1, :])
                
                log_probs = F.log_softmax(output, dim=-1)
                top_probs, top_indices = log_probs.topk(beam_width)
                
                for prob, idx in zip(top_probs[0], top_indices[0]):
                    candidates.append((seq + [idx.item()], score + prob.item()))
            
            beams = sorted(candidates, key=lambda x: x[1], reverse=True)[:beam_width]
            
            if all(seq[-1] == EOS_IDX for seq, _ in beams):
                break
        
        return beams[0][0]

# ============================================================================
# UNIFIED GENERATE RESPONSE FUNCTION
# ============================================================================
def generate_response(model, input_text, method='greedy', **kwargs):
    """
    Generate response with multiple decoding strategies
    
    Args:
        model: Trained model
        input_text: Input text string
        method: 'greedy', 'greedy_penalty', 'nucleus', 'beam'
        **kwargs: Additional parameters for each method
            - repetition_penalty: for 'greedy_penalty' (default: 1.2)
            - p: for 'nucleus' (default: 0.9)
            - temperature: for 'nucleus' (default: 1.0)
            - beam_width: for 'beam' (default: 3)
    """
    # Tokenize input
    input_enc = tokenizer.encode(normalize_text(input_text))
    input_ids = [BOS_IDX] + input_enc.ids + [EOS_IDX]
    src = torch.tensor([input_ids], dtype=torch.long)
    
    # Generate based on method
    if method == 'greedy':
        output_ids = greedy_decode(model, src)
    elif method == 'greedy_penalty':
        repetition_penalty = kwargs.get('repetition_penalty', 1.2)
        output_ids = greedy_decode_with_penalty(model, src, repetition_penalty=repetition_penalty)
    elif method == 'nucleus':
        p = kwargs.get('p', 0.9)
        temperature = kwargs.get('temperature', 1.0)
        output_ids = nucleus_sampling_decode(model, src, p=p, temperature=temperature)
    elif method == 'beam':
        beam_width = kwargs.get('beam_width', 3)
        output_ids = beam_search_decode(model, src, beam_width=beam_width)
    else:
        raise ValueError(f"Unknown method: {method}")
    
    # Decode output
    output_tokens = [tokenizer.id_to_token(id) for id in output_ids 
                     if id not in [BOS_IDX, EOS_IDX, PAD_IDX]]
    output_text = ' '.join(output_tokens)
    output_text = output_text.replace(' ,', ',').replace(' .', '.').replace(' !', '!').replace(' ?', '?')
    
    return output_text

print("✓ Greedy decoding implemented")
print("✓ Greedy with repetition penalty implemented (NEW)")
print("✓ Nucleus sampling implemented (NEW)")
print("✓ Beam search decoding implemented")

print(f"\n{'='*70}")
print(f"📊 EVALUATION METRICS")
print(f"{'='*70}\n")

# ============================================================================
# EVALUATION METRICS
# ============================================================================
def calculate_bleu(references, hypotheses):
    """Calculate BLEU score"""
    bleu = BLEU()
    score = bleu.corpus_score(hypotheses, [references])
    return score.score

def calculate_rouge(references, hypotheses):
    """Calculate ROUGE-L score"""
    scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)
    scores = [scorer.score(ref, hyp)['rougeL'].fmeasure 
              for ref, hyp in zip(references, hypotheses)]
    return np.mean(scores) * 100

def calculate_chrf(references, hypotheses):
    """Calculate chrF score"""
    chrf = CHRF()
    score = chrf.corpus_score(hypotheses, [references])
    return score.score

# ============================================================================
# FULL EVALUATION WITH DIFFERENT DECODING METHODS
# ============================================================================
def evaluate_model(model, dataloader, device, num_samples=500, method='greedy'):
    """Evaluate model on test set"""
    model.eval()
    references = []
    hypotheses = []
    total_loss = 0
    sample_count = 0
    
    with torch.no_grad():
        for batch in tqdm(dataloader, desc=f"Evaluating ({method})"):
            src = batch['input_ids'].to(device, non_blocking=True)
            tgt = batch['target_ids'].to(device, non_blocking=True)
            
            # Calculate loss
            tgt_input = tgt[:, :-1]
            tgt_output = tgt[:, 1:]
            output = model(src, tgt_input)
            loss = criterion(output.reshape(-1, vocab_size), tgt_output.reshape(-1))
            total_loss += loss.item()
            
            # Generate predictions
            for i in range(src.size(0)):
                if sample_count >= num_samples:
                    break
                
                # Generate prediction based on method
                if method == 'greedy':
                    pred_ids = greedy_decode(model, src[i:i+1])
                elif method == 'greedy_penalty':
                    pred_ids = greedy_decode_with_penalty(model, src[i:i+1], repetition_penalty=1.2)
                elif method == 'nucleus':
                    pred_ids = nucleus_sampling_decode(model, src[i:i+1], p=0.9)
                elif method == 'beam':
                    pred_ids = beam_search_decode(model, src[i:i+1], beam_width=3)
                
                # Get reference
                ref_ids = tgt[i].cpu().tolist()
                ref_tokens = [tokenizer.id_to_token(id) for id in ref_ids 
                             if id not in [BOS_IDX, EOS_IDX, PAD_IDX]]
                pred_tokens = [tokenizer.id_to_token(id) for id in pred_ids 
                              if id not in [BOS_IDX, EOS_IDX, PAD_IDX]]
                
                ref_text = ' '.join(ref_tokens)
                pred_text = ' '.join(pred_tokens)
                
                references.append(ref_text)
                hypotheses.append(pred_text)
                
                sample_count += 1
            
            if sample_count >= num_samples:
                break
    
    # Calculate all metrics
    perplexity = math.exp(total_loss / len(dataloader))
    bleu_score = calculate_bleu(references, hypotheses)
    rouge_score = calculate_rouge(references, hypotheses)
    chrf_score = calculate_chrf(references, hypotheses)
    
    return {
        'perplexity': perplexity,
        'bleu': bleu_score,
        'rouge_l': rouge_score,
        'chrf': chrf_score,
        'references': references,
        'hypotheses': hypotheses
    }

print("✓ BLEU metric ready")
print("✓ ROUGE-L metric ready")
print("✓ chrF metric ready")
print("✓ Perplexity metric ready")

print(f"\n{'='*70}")
print(f"🧪 RUNNING EVALUATION ON TEST SET")
print(f"{'='*70}\n")

# Evaluate with different methods
print("📊 Comparing different decoding strategies...\n")

results_comparison = {}

for method_name in ['greedy', 'greedy_penalty', 'nucleus', 'beam']:
    print(f"Testing {method_name}...")
    results = evaluate_model(model, test_loader, device, num_samples=500, method=method_name)
    results_comparison[method_name] = results

print(f"\n{'='*70}")
print(f"📈 TEST SET RESULTS COMPARISON")
print(f"{'='*70}")
print(f"{'Method':<20} {'BLEU':<12} {'ROUGE-L':<12} {'chrF':<12} {'Perplexity':<12}")
print(f"{'-'*68}")

for method_name, results in results_comparison.items():
    print(f"{method_name:<20} {results['bleu']:<12.2f} {results['rouge_l']:<12.2f} "
          f"{results['chrf']:<12.2f} {results['perplexity']:<12.2f}")

print(f"{'='*70}\n")

# Use best performing method for final results
best_method = max(results_comparison.items(), key=lambda x: x[1]['bleu'])
test_results = best_method[1]
print(f"✓ Best performing method: {best_method[0]}")
print(f"✓ Best BLEU score: {test_results['bleu']:.2f}")

print(f"\n{'='*70}")
print(f"📝 QUALITATIVE EXAMPLES")
print(f"{'='*70}\n")

# Show 5 random examples
for i in range(5):
    idx = random.randint(0, len(test_results['references'])-1)
    print(f"--- Example {i+1} ---")
    print(f"Reference: {test_results['references'][idx]}")
    print(f"Generated: {test_results['hypotheses'][idx]}")
    print()

print(f"{'='*70}")
print(f"🎭 CUSTOM TEST INPUTS WITH ALL METHODS")
print(f"{'='*70}\n")

# Test with custom inputs
test_inputs = [
    "emotion: sad | situation: my dog passed away last week | customer: i miss him so much agent:",
    "emotion: excited | situation: i got accepted to my dream university | customer: i can't believe it happened! agent:",
    "emotion: afraid | situation: i have to give a presentation tomorrow | customer: i'm so nervous about speaking in public agent:",
]

for inp in test_inputs:
    print(f"Input: {inp[:80]}...")
    print(f"Greedy:         {generate_response(model, inp, method='greedy')}")
    print(f"Greedy+Penalty: {generate_response(model, inp, method='greedy_penalty', repetition_penalty=1.2)}")
    print(f"Nucleus (p=0.9): {generate_response(model, inp, method='nucleus', p=0.9, temperature=1.0)}")
    print(f"Beam (width=3): {generate_response(model, inp, method='beam', beam_width=3)}")
    print("-" * 70)
    print()

print(f"{'='*70}")
print(f"💾 SAVING MODEL AND ARTIFACTS")
print(f"{'='*70}\n")

# Save final model with best results
torch.save({
    'model_state_dict': model.state_dict(),
    'vocab_size': vocab_size,
    'model_config': {
        'd_model': 256,
        'num_heads': 2,
        'num_encoder_layers': 2,
        'num_decoder_layers': 2,
        'd_ff': 1024,
        'dropout': 0.1
    },
    'test_results': {
        'perplexity': test_results['perplexity'],
        'bleu': test_results['bleu'],
        'rouge_l': test_results['rouge_l'],
        'chrf': test_results['chrf'],
        'best_method': best_method[0]
    },
    'results_comparison': {k: {
        'bleu': v['bleu'],
        'rouge_l': v['rouge_l'],
        'chrf': v['chrf'],
        'perplexity': v['perplexity']
    } for k, v in results_comparison.items()}
}, 'final_model.pt')

# Save tokenizer
tokenizer.save("empathetic_tokenizer.json")

# Save evaluation results
with open('evaluation_results.json', 'w') as f:
    json.dump({
        'best_method': best_method[0],
        'perplexity': test_results['perplexity'],
        'bleu': test_results['bleu'],
        'rouge_l': test_results['rouge_l'],
        'chrf': test_results['chrf'],
        'all_methods': {k: {
            'bleu': v['bleu'],
            'rouge_l': v['rouge_l'],
            'chrf': v['chrf'],
            'perplexity': v['perplexity']
        } for k, v in results_comparison.items()},
        'num_samples': len(test_results['references'])
    }, f, indent=2)

# Save training history
with open('training_history.json', 'w') as f:
    json.dump(history, f, indent=2)

print("✓ Model saved: final_model.pt (includes all test results)")
print("✓ Best checkpoint: best_model.pt (saved during training)")
print("✓ Tokenizer saved: empathetic_tokenizer.json")
print("✓ Evaluation results saved: evaluation_results.json")
print("✓ Training history saved: training_history.json")

print(f"\n{'='*70}")
print(f"📄 CREATING DEPLOYMENT FILES")
print(f"{'='*70}\n")


🔮 ADVANCED INFERENCE FUNCTIONS

✓ Greedy decoding implemented
✓ Greedy with repetition penalty implemented (NEW)
✓ Nucleus sampling implemented (NEW)
✓ Beam search decoding implemented

📊 EVALUATION METRICS

✓ BLEU metric ready
✓ ROUGE-L metric ready
✓ chrF metric ready
✓ Perplexity metric ready

🧪 RUNNING EVALUATION ON TEST SET

📊 Comparing different decoding strategies...

Testing greedy...


Evaluating (greedy):   7%|▋         | 15/203 [00:19<04:03,  1.30s/it]


Testing greedy_penalty...


Evaluating (greedy_penalty):   7%|▋         | 15/203 [00:18<03:50,  1.23s/it]


Testing nucleus...


Evaluating (nucleus):   7%|▋         | 15/203 [00:26<05:34,  1.78s/it]


Testing beam...


Evaluating (beam):   7%|▋         | 15/203 [01:03<13:20,  4.26s/it]



📈 TEST SET RESULTS COMPARISON
Method               BLEU         ROUGE-L      chrF         Perplexity  
--------------------------------------------------------------------
greedy               2.60         15.91        12.59        1.35        
greedy_penalty       2.40         16.02        12.26        1.35        
nucleus              1.26         11.63        13.76        1.35        
beam                 2.69         15.49        11.33        1.35        

✓ Best performing method: beam
✓ Best BLEU score: 2.69

📝 QUALITATIVE EXAMPLES

--- Example 1 ---
Reference: no , and it gets progress ive ly more expensive .
Generated: i ' m not sure yet . i ' ll be able to do that .

--- Example 2 ---
Reference: : ) we have all being through the age . you will miss this stage when they finally grow to become adult
Generated: i ' m not sure , but i ' m not sure if i ' ll do it .

--- Example 3 ---
Reference: why is that ?
Generated: i ' m sure you ' ll be fine !

--- Example 4 ---
Reference: d