## 1. First, let's start with the imports and configuration:


In [1]:
# Import required libraries
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import numpy as np
from collections import Counter
import json
import re
import nltk
from nltk.tokenize import word_tokenize
import unicodedata
import os
import logging
import math
from typing import List, Dict, Tuple
import warnings

warnings.filterwarnings('ignore')
nltk.download('punkt')
nltk.download('punkt_tab')


# Setup logging and device
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.manual_seed(42)

# Configuration
class Config:
    MAX_LENGTH = 10
    MIN_WORD_FREQ = 3
    VOCAB_SIZE = 5000
    BATCH_SIZE = 64
    EMBEDDING_DIM = 256
    NUM_HEADS = 8
    NUM_LAYERS = 4
    DROPOUT = 0.1
    LEARNING_RATE = 0.0003
    EPOCHS = 6 #20

    # Paths
    CORPUS_CONV = 'data/cornell movie-dialogs corpus/movie_conversations.txt'
    CORPUS_LINES = 'data/cornell movie-dialogs corpus/movie_lines.txt'
    DELIMITER = ' +++$+++ '

[nltk_data] Downloading package punkt to /home/livyatan13/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package punkt_tab to
[nltk_data]     /home/livyatan13/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


## 1. Text preprocessing and data handling:

In [2]:
class TextProcessor:
    def __init__(self):
        self.special_tokens = ['<pad>', '<unk>', '<start>', '<end>']

    def preprocess_text(self, text: str) -> str:
        """Enhanced text preprocessing with better linguistic features"""
        text = text.lower()
        text = unicodedata.normalize('NFKD', text).encode('ascii', 'ignore').decode()

        # Preserve important punctuation and add spaces
        text = re.sub(r'[^a-zA-Z0-9\',.!?\s]', ' ', text)
        text = re.sub(r'([.,!?])', r' \1 ', text)

        # Handle contractions carefully
        contractions = {
            "\'s": " 's", "\'re": " 're", "\'t": " 't",
            "\'ll": " 'll", "\'ve": " 've", "\'m": " 'm",
            "won't": "will not", "can't": "cannot"
        }
        for contraction, expansion in contractions.items():
            text = text.replace(contraction, expansion)

        return ' '.join(text.split())

    def filter_conversation(self, question: str, answer: str) -> bool:
        """Filter out low-quality conversation pairs"""
        min_words, max_words = 2, Config.MAX_LENGTH

        q_words = len(question.split())
        a_words = len(answer.split())

        if not (min_words <= q_words <= max_words and
                min_words <= a_words <= max_words):
            return False

        generic_responses = {
            'i don t know', 'i m not sure', 'what', 'ok', 'yes', 'no',
            'i do not know', 'i dont know', 'idk', 'i have no idea'
        }
        if answer.strip().lower() in generic_responses:
            return False

        return True

## Data loading and processing:

In [3]:
class DataHandler:
    def __init__(self, config: Config):
        self.config = config
        self.text_processor = TextProcessor()
        self.word_map = {}
        
    def encode_sequence(self, tokens: List[str], max_len: int) -> List[int]:
        """Encode token sequence to indices"""
        # Convert tokens to indices
        encoded = [self.word_map.get(token, self.word_map['<unk>']) 
                  for token in tokens]
        
        # Truncate or pad as needed
        if len(encoded) > max_len:
            encoded = encoded[:max_len]
        elif len(encoded) < max_len:
            encoded.extend([self.word_map['<pad>']] * (max_len - len(encoded)))
            
        return encoded
    
    def prepare_data(self) -> Tuple[List[List[List[int]]], Dict[str, int]]:
        """Prepare and encode all data"""
        # Load raw data
        conversations, lines_dict = self.load_conversations()
        pairs = self.extract_pairs(conversations, lines_dict)
        
        # Build vocabulary
        word_freq = Counter()
        for pair in pairs:
            word_freq.update(pair[0])
            word_freq.update(pair[1])
        
        # Filter words and create word map
        words = [w for w, freq in word_freq.items() 
                if freq >= self.config.MIN_WORD_FREQ]
        words = words[:self.config.VOCAB_SIZE - 4]  # Reserve space for special tokens
        
        # Create word map with special tokens
        special_tokens = ['<pad>', '<unk>', '<start>', '<end>']
        self.word_map = {token: idx for idx, token in enumerate(special_tokens)}
        for word in words:
            if word not in self.word_map:
                self.word_map[word] = len(self.word_map)
        
        # Encode all pairs
        encoded_pairs = []
        for pair in pairs:
            encoded_q = self.encode_sequence(pair[0], self.config.MAX_LENGTH)
            encoded_a = self.encode_sequence(pair[1], self.config.MAX_LENGTH)
            encoded_pairs.append([encoded_q, encoded_a])
        
        return encoded_pairs, self.word_map

    def load_conversations(self) -> Tuple[List[str], Dict[str, str]]:
        """Load conversation files with error handling"""
        try:
            with open(self.config.CORPUS_CONV, 'r', encoding='iso-8859-1') as f:
                conversations = f.readlines()
            with open(self.config.CORPUS_LINES, 'r', encoding='iso-8859-1') as f:
                lines = f.readlines()

            lines_dict = {}
            for line in lines:
                parts = line.split(self.config.DELIMITER)
                if len(parts) >= 2:
                    lines_dict[parts[0]] = parts[-1].strip()

            return conversations, lines_dict

        except Exception as e:
            logger.error(f"Error loading files: {e}")
            return [], {}

    def extract_pairs(self, conversations: List[str],
                     lines_dict: Dict[str, str]) -> List[List[str]]:
        """Extract and filter conversation pairs"""
        pairs = []
        for conv in conversations:
            try:
                ids = eval(conv.split(self.config.DELIMITER)[-1])
                for i in range(len(ids) - 1):
                    q = lines_dict.get(ids[i], '')
                    a = lines_dict.get(ids[i + 1], '')

                    if q and a:
                        q = self.text_processor.preprocess_text(q)
                        a = self.text_processor.preprocess_text(a)

                        if self.text_processor.filter_conversation(q, a):
                            q_tokens = word_tokenize(q)[:self.config.MAX_LENGTH]
                            a_tokens = word_tokenize(a)[:self.config.MAX_LENGTH]
                            pairs.append([q_tokens, a_tokens])

            except Exception as e:
                logger.error(f"Error processing conversation: {e}")
                continue

        return pairs

## Model architecture:


In [4]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, max_length: int = 100):
        super().__init__()
        position = torch.arange(max_length).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) *
                           (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_length, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.pe[:, :x.size(1)]

class TransformerChatbot(nn.Module):
    def __init__(self, config: Config, vocab_size: int):
        super().__init__()
        self.config = config

        # Embeddings and positional encoding
        self.embedding = nn.Embedding(vocab_size, config.EMBEDDING_DIM)
        self.pos_encoder = PositionalEncoding(config.EMBEDDING_DIM)
        self.dropout = nn.Dropout(config.DROPOUT)

        # Transformer layers
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=config.EMBEDDING_DIM,
            nhead=config.NUM_HEADS,
            dim_feedforward=config.EMBEDDING_DIM * 4,
            dropout=config.DROPOUT,
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(
            encoder_layer,
            num_layers=config.NUM_LAYERS
        )

        # Output layer
        self.output_layer = nn.Linear(config.EMBEDDING_DIM, vocab_size)

    def create_mask(self, src: torch.Tensor) -> torch.Tensor:
        src_mask = src == 0  # PAD token
        return src_mask

    def forward(self, src: torch.Tensor) -> torch.Tensor:
        # Create mask for padding tokens
        src_mask = self.create_mask(src)

        # Embedding and position encoding
        src = self.embedding(src) * math.sqrt(self.config.EMBEDDING_DIM)
        src = self.pos_encoder(src)
        src = self.dropout(src)

        # Apply transformer with masking
        output = self.transformer(src, src_key_padding_mask=src_mask)

        # Generate output distribution
        output = self.output_layer(output)

        return output

## Training components:


In [5]:
class ChatDataset(Dataset):
    def __init__(self, pairs_encoded: List[List[List[int]]]):
        self.pairs = pairs_encoded

    def __len__(self) -> int:
        return len(self.pairs)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        return (torch.LongTensor(self.pairs[idx][0]),
                torch.LongTensor(self.pairs[idx][1]))

class Trainer:
    def __init__(self, model: nn.Module, config: Config):
        self.model = model
        self.config = config
        self.optimizer = torch.optim.Adam(
            model.parameters(),
            lr=config.LEARNING_RATE
        )
        self.criterion = nn.CrossEntropyLoss(ignore_index=0)

    def train_epoch(self, train_loader: DataLoader) -> float:
        self.model.train()
        total_loss = 0

        for batch_idx, (src, tgt) in enumerate(train_loader):
            src, tgt = src.to(device), tgt.to(device)

            self.optimizer.zero_grad()
            output = self.model(src)

            loss = self.criterion(
                output.view(-1, output.size(-1)),
                tgt.view(-1)
            )

            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
            self.optimizer.step()

            total_loss += loss.item()

            if batch_idx % 100 == 0:
                logger.info(f"Batch {batch_idx}, Loss: {loss.item():.4f}")

        return total_loss / len(train_loader)

## Chat interface and inference:


In [6]:
class ChatInterface:
    def __init__(self, model: TransformerChatbot,
                 word_map: Dict[str, int],
                 config: Config):
        self.model = model
        self.word_map = word_map
        self.config = config
        self.text_processor = TextProcessor()
        self.rev_word_map = {idx: word for word, idx in word_map.items()}

    def generate_response(self, question: str) -> str:
        self.model.eval()

        # Preprocess question
        question = self.text_processor.preprocess_text(question)
        tokens = word_tokenize(question)[:self.config.MAX_LENGTH]

        # Convert to indices
        indices = [self.word_map.get(token, self.word_map['<unk>'])
                  for token in tokens]
        # Pad if necessary
        if len(indices) < self.config.MAX_LENGTH:
            indices.extend([self.word_map['<pad>']] *
                         (self.config.MAX_LENGTH - len(indices)))

        question_tensor = torch.LongTensor(indices).unsqueeze(0).to(device)

        # Generate response
        with torch.no_grad():
            output = self.model(question_tensor)
            _, predicted = torch.max(output, dim=-1)

        # Convert to words and clean up response
        response_words = []
        for idx in predicted[0]:
            word = self.rev_word_map[idx.item()]
            if word in ['<pad>', '<start>', '<end>', '<unk>']:
                continue
            response_words.append(word)

        return ' '.join(response_words)

## Finally, here's how to put it all together:

In [7]:
# Initialize configuration
config = Config()

# Initialize data handler and prepare data
data_handler = DataHandler(config)
encoded_pairs, word_map = data_handler.prepare_data()

# Create dataset and dataloader
dataset = ChatDataset(encoded_pairs)
train_loader = DataLoader(
    dataset,
    batch_size=config.BATCH_SIZE,
    shuffle=True,
    pin_memory=True
)

# Create and train model
model = TransformerChatbot(config, len(word_map)).to(device)
trainer = Trainer(model, config)

# Train the model
logger.info("Starting training...")
for epoch in range(config.EPOCHS):
    avg_loss = trainer.train_epoch(train_loader)
    logger.info(f"Epoch {epoch} completed, Average Loss: {avg_loss:.4f}")

# Create chat interface
chat_interface = ChatInterface(model, word_map, config)

# Test the chatbot
test_questions = [
    "Hello, how are you?",
    "What's your favorite movie?",
    "Do you like pizza?"
]

print("\nTesting the chatbot:")
for question in test_questions:
    response = chat_interface.generate_response(question)
    print(f"Q: {question}")
    print(f"A: {response}\n")


INFO:__main__:Starting training...
INFO:__main__:Batch 0, Loss: 8.5772
INFO:__main__:Batch 100, Loss: 5.1684
INFO:__main__:Batch 200, Loss: 5.2472
INFO:__main__:Batch 300, Loss: 5.3433
INFO:__main__:Batch 400, Loss: 5.0134
INFO:__main__:Batch 500, Loss: 5.2570
INFO:__main__:Batch 600, Loss: 5.0929
INFO:__main__:Batch 700, Loss: 5.1147
INFO:__main__:Batch 800, Loss: 4.8993
INFO:__main__:Batch 900, Loss: 4.6280
INFO:__main__:Batch 1000, Loss: 4.9032
INFO:__main__:Epoch 0 completed, Average Loss: 5.1118
INFO:__main__:Batch 0, Loss: 4.8707
INFO:__main__:Batch 100, Loss: 5.0789
INFO:__main__:Batch 200, Loss: 5.0667
INFO:__main__:Batch 300, Loss: 5.1085
INFO:__main__:Batch 400, Loss: 4.8963
INFO:__main__:Batch 500, Loss: 5.0920
INFO:__main__:Batch 600, Loss: 4.7637
INFO:__main__:Batch 700, Loss: 5.0541
INFO:__main__:Batch 800, Loss: 4.8991
INFO:__main__:Batch 900, Loss: 5.1809
INFO:__main__:Batch 1000, Loss: 4.9116
INFO:__main__:Epoch 1 completed, Average Loss: 4.9755
INFO:__main__:Batch 0, 


Testing the chatbot:
Q: Hello, how are you?
A: i . . . . . to to to to

Q: What's your favorite movie?
A: . . . . . to to to to

Q: Do you like pizza?
A: i . . . . to to to to to

