# Chatbot - Transformer

In [1]:
from collections import Counter
import json
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.utils.data
import math
import torch.nn.functional as F
import re
import nltk
from nltk.tokenize import word_tokenize
import unicodedata
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
nltk.download('punkt')

nltk.download('punkt_tab')


[nltk_data] Downloading package punkt to /home/leviathan/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package punkt_tab to
[nltk_data]     /home/leviathan/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


True

In [2]:
import os

if not os.path.exists('data/cornell movie-dialogs corpus'):
    !wget http://www.cs.cornell.edu/~cristian/data/cornell_movie_dialogs_corpus.zip -P data && cd data && unzip cornell_movie_dialogs_corpus.zip
else:
    print("'data/cornell movie-dialogs corpus' already exists. Skipping download and extraction.")

'data/cornell movie-dialogs corpus' already exists. Skipping download and extraction.


## Dataset Preprocessing Part 1

In [3]:
import json
from collections import Counter
import re
import nltk
from nltk.tokenize import word_tokenize
import unicodedata
nltk.download('punkt')

# File paths
CORPUS_CONV = 'data/cornell movie-dialogs corpus/movie_conversations.txt'
CORPUS_LINES = 'data/cornell movie-dialogs corpus/movie_lines.txt'
max_len = 13
DELIMITER = ' +++$+++ '

def load_conversations():
    """Load and parse conversation files with error handling"""
    try:
        with open(CORPUS_CONV, 'r', encoding='iso-8859-1') as f:
            conversations = f.readlines()
        with open(CORPUS_LINES, 'r', encoding='iso-8859-1') as f:
            lines = f.readlines()
        return conversations, lines
    except Exception as e:
        print(f"Error loading files: {e}")
        return None, None

def create_lines_dict(lines):
    """Create dictionary of line ID to text with better parsing"""
    lines_dict = {}
    for line in lines:
        parts = line.split(DELIMITER)
        if len(parts) >= 2:
            lines_dict[parts[0]] = parts[-1].strip()
    return lines_dict

def preprocess_text(text):
    """Improved text preprocessing"""
    # Convert to lowercase and normalize unicode
    text = text.lower()
    text = unicodedata.normalize('NFKD', text).encode('ascii', 'ignore').decode()

    # Remove special characters but keep apostrophes for contractions
    text = re.sub(r'[^a-zA-Z0-9\'\s]', ' ', text)

    # Standardize contractions
    text = re.sub(r"\'s", " is", text)
    text = re.sub(r"\'re", " are", text)
    text = re.sub(r"\'t", " not", text)
    text = re.sub(r"\'ll", " will", text)
    text = re.sub(r"\'ve", " have", text)
    text = re.sub(r"\'m", " am", text)
    text = re.sub(r"won't", "will not", text)
    text = re.sub(r"can't", "cannot", text)

    # Remove extra whitespace
    text = ' '.join(text.split())
    return text

def extract_pairs(conversations, lines_dict):
    """Extract and preprocess conversation pairs"""
    pairs = []
    for conv in conversations:
        try:
            ids = eval(conv.split(DELIMITER)[-1])
            for i in range(len(ids) - 1):
                first = lines_dict.get(ids[i], '')
                second = lines_dict.get(ids[i + 1], '')

                if first and second:
                    first = preprocess_text(first)
                    second = preprocess_text(second)

                    # Add debug prints to see the text before tokenization
                    # print(f"Pre-tokenization first: '{first}'")
                    # print(f"Pre-tokenization second: '{second}'")

                    try:
                        first_tokens = word_tokenize(first)[:max_len]
                        second_tokens = word_tokenize(second)[:max_len]
                    except Exception as e:
                        print(f"Tokenization error: {e}")
                        continue
                    #
                    # print(f'Tokens first: {first_tokens}')
                    # print(f'Tokens second: {second_tokens}')

                    if first_tokens and second_tokens:  # Only add if both have content
                        pairs.append([first_tokens, second_tokens])
        except Exception as e:
            print(f"Error processing conversation: {e}")
            continue

    print(f"Total pairs extracted: {len(pairs)}")
    return pairs


def build_vocab(pairs, min_freq=1):
    """Build vocabulary with proper ID allocation"""
    word_freq = Counter()
    for pair in pairs:
        word_freq.update(pair[0])
        word_freq.update(pair[1])

    # Filter out rare words
    words = [w for w, freq in word_freq.items() if freq >= min_freq]

    # Start with special tokens to ensure they have predictable IDs
    special_tokens = ['<pad>', '<unk>', '<start>', '<end>']

    # Create word map with continuous IDs
    word_map = {}
    # First add special tokens
    for i, token in enumerate(special_tokens):
        word_map[token] = i

    # Then add all other words
    for i, word in enumerate(words):
        if word not in word_map:  # Skip if it's somehow a special token
            word_map[word] = i + len(special_tokens)

    print(f"Final vocabulary size: {len(word_map)}")
    return word_map

def encode_sequence(tokens, word_map, is_reply=False, max_seq_len=5):
    """Encode sequence with correct length handling"""
    if is_reply:
        # For replies, we need to account for <start> and <end> tokens
        # So actual token count will be limited to max_seq_len-2
        tokens = tokens[:max_seq_len-2]  # Limit tokens to leave room for special tokens

        # Add start token, tokens, and end token
        encoded = [word_map['<start>']]
        encoded.extend([word_map.get(token, word_map['<unk>']) for token in tokens])
        encoded.append(word_map['<end>'])
    else:
        # For questions, we can use all max_seq_len positions
        tokens = tokens[:max_seq_len]  # Limit tokens to max_seq_len
        encoded = [word_map.get(token, word_map['<unk>']) for token in tokens]

    # Calculate correct padding length
    padding_length = max_seq_len - len(encoded)

    # Add padding if needed
    if padding_length > 0:
        encoded.extend([word_map['<pad>']] * padding_length)

    return encoded

# Load and process data
conversations, lines = load_conversations()
lines_dict = create_lines_dict(lines)
pairs = extract_pairs(conversations, lines_dict)
word_map = build_vocab(pairs)

# Encode pairs with fixed function
pairs_encoded = []
for pair in pairs:
    question = encode_sequence(pair[0], word_map, is_reply=False, max_seq_len=max_len)
    answer = encode_sequence(pair[1], word_map, is_reply=True, max_seq_len=max_len)

    # Validate lengths (important debugging check)
    if len(question) != max_len or len(answer) != max_len:
        print(f"Warning: Inconsistent sequence length: Q:{len(question)}, A:{len(answer)}")
        continue

    pairs_encoded.append([question, answer])


# Add this after preprocessing a few examples
sample_pairs = pairs[:5]
print("\nSample processed pairs:")
for pair in sample_pairs:
    print(f"Q: {' '.join(pair[0])}")
    print(f"A: {' '.join(pair[1])}")
    print()


# Save processed data
with open('data/word_map_corpus.json', 'w') as f:
    json.dump(word_map, f)
with open('data/pairs_encoded.json', 'w') as f:
    json.dump(pairs_encoded, f)

print(f"Vocabulary size: {len(word_map)}")
print(f"Number of conversation pairs: {len(pairs_encoded)}")


[nltk_data] Downloading package punkt to /home/leviathan/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


Total pairs extracted: 221274
Final vocabulary size: 41506

Sample processed pairs:
Q: can we make this quick roxanne korrine and andrew barrett are having an
A: well i thought we 'd start with pronunciation if that is okay with

Q: well i thought we 'd start with pronunciation if that is okay with
A: not the hacking and gagging and spitting part please

Q: not the hacking and gagging and spitting part please
A: okay then how 'bout we try out some french cuisine saturday night

Q: you are asking me out that is so cute what is your name
A: forget it

Q: no no it is my fault we didn not have a proper introduction
A: cameron

Vocabulary size: 41506
Number of conversation pairs: 221274


In [4]:
lines

['L1045 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ They do not!\n',
 'L1044 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ They do to!\n',
 'L985 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ I hope so.\n',
 'L984 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ She okay?\n',
 "L925 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ Let's go.\n",
 'L924 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ Wow\n',
 "L872 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ Okay -- you're gonna need to learn how to lie.\n",
 'L871 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ No\n',
 'L870 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ I\'m kidding.  You know how sometimes you just become this "persona"?  And you don\'t know how to quit?\n',
 'L869 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ Like my fear of wearing pastels?\n',
 'L868 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ The "real you".\n',
 'L867 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ What good stuff?\n',
 "L866 +++$+++ u2 +++$+++ m0 +++$+++ CAME

In [5]:
# Check maximum lengths in pairs_encoded
max_question_length = 0
max_reply_length = 0

for pair in pairs_encoded:
    question, reply = pair
    max_question_length = max(max_question_length, len(question))
    max_reply_length = max(max_reply_length, len(reply))

print(f"Maximum question length: {max_question_length}")
print(f"Maximum reply length: {max_reply_length}")

# Also check if any sequence exceeds the max_len
exceeds_limit = False
for pair in pairs_encoded:
    question, reply = pair
    if len(question) > max_len or len(reply) > max_len:
        exceeds_limit = True
        break

print(f"\nAre there any sequences exceeding max_len ({max_len})? {exceeds_limit}")

Maximum question length: 13
Maximum reply length: 13

Are there any sequences exceeding max_len (13)? False


## Data Loading and Masking

In [1]:
class ChatDataset(Dataset):
    """Dataset class for handling chat conversation pairs."""

    DEFAULT_DATA_PATH = 'data/pairs_encoded.json'

    def __init__(self, data_path=DEFAULT_DATA_PATH):
        """Initialize dataset with conversation pairs.

        Args:
            data_path: Path to the JSON file containing encoded pairs
        """
        try:
            with open(data_path, 'r') as file:
                self.pairs = json.load(file)
            if not self.pairs:
                raise ValueError("Empty dataset loaded")
        except FileNotFoundError:
            raise FileNotFoundError(f"Data file not found at: {data_path}")
        except json.JSONDecodeError:
            raise ValueError(f"Invalid JSON format in file: {data_path}")

        self.dataset_size = len(self.pairs)

    def __getitem__(self, index):
        """Get a conversation pair at the specified index."""
        pair = self.pairs[index]
        question, reply = pair[0], pair[1]
        return torch.LongTensor(question), torch.LongTensor(reply)

    def __len__(self):
        """Return the total number of conversation pairs."""
        return self.dataset_size

NameError: name 'Dataset' is not defined

In [7]:
train_loader = DataLoader(ChatDataset(), batch_size=100, shuffle=True, pin_memory=True)

In [8]:
question, reply = next(iter(train_loader))

In [9]:
question.shape

torch.Size([100, 13])

In [10]:
reply.shape

torch.Size([100, 13])

In [11]:
def create_masks(question, reply_input, reply_target):
    def subsequent_mask(size):
        mask = torch.triu(torch.ones(size, size)).transpose(0, 1).type(dtype=torch.uint8)
        return mask.unsqueeze(0)

    question_mask = (question != 0).to(device)
    question_mask = question_mask.unsqueeze(1).unsqueeze(1)

    reply_input_mask = reply_input != 0
    reply_input_mask = reply_input_mask.unsqueeze(1)
    reply_input_mask = reply_input_mask & subsequent_mask(reply_input.size(-1)).type_as(reply_input_mask.data)
    # (batch_size, max_words, max_words)
    reply_input_mask = reply_input_mask.unsqueeze(1)
    reply_target_mask = reply_target != 0

    return question_mask, reply_input_mask, reply_target_mask


In [12]:
size = 5
print(torch.triu(torch.ones(size, size)).transpose(0, 1))

tensor([[1., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0.],
        [1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1.]])


In [13]:
question[0] != 0

tensor([ True,  True,  True,  True, False, False, False, False, False, False,
        False, False, False])

## Embeddings

In [14]:
class Embeddings(nn.Module):
    DEFAULT_MAX_LENGTH = 50
    DROPOUT_RATE = 0.1
    POSITION_ENCODING_BASE = 10000

    def __init__(self, vocab_size, embedding_dim, max_length=DEFAULT_MAX_LENGTH):
        super(Embeddings, self).__init__()

        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim

        self.dropout = nn.Dropout(self.DROPOUT_RATE)
        self.token_embedding = nn.Embedding(vocab_size, embedding_dim)

        # Create position encoding on CPU and register as buffer
        # Using register_buffer is crucial as it handles device transfers properly
        position_encoding = self._create_position_encoding(max_length, embedding_dim)
        self.register_buffer('position_encoding', position_encoding)

    def forward(self, x):
        # Get sequence length from input
        seq_len = x.size(1)

        # Apply token embedding
        embeddings = self.token_embedding(x) * math.sqrt(self.embedding_dim)

        # Add positional encoding (slice to match sequence length)
        embeddings = embeddings + self.position_encoding[:seq_len, :]

        return self.dropout(embeddings)

    def _create_position_encoding(self, max_length, embedding_dim):
        """Create position encoding matrix with sinusoidal patterns"""
        position_encoding = torch.zeros(max_length, embedding_dim)

        for pos in range(max_length):
            for dim in range(0, embedding_dim, 2):
                div_term = math.exp(dim * -math.log(self.POSITION_ENCODING_BASE) / embedding_dim)
                position_encoding[pos, dim] = math.sin(pos * div_term)
                if dim + 1 < embedding_dim:
                    position_encoding[pos, dim + 1] = math.cos(pos * div_term)

        return position_encoding

## MultiHead Attention Implementation Part 1

In [15]:
class MultiHeadAttention(nn.Module):
    """Multi-head attention mechanism implementation.

    Allows the model to jointly attend to information from different representation
    subspaces at different positions.
    """
    DROPOUT_RATE = 0.1
    ATTENTION_MASK_FILL_VALUE = -1e9

    def __init__(self, num_heads, d_model):
        """Initialize the multi-head attention layer.

        Args:
            num_heads: Number of attention heads
            d_model: Dimension of the model
        """
        super(MultiHeadAttention, self).__init__()
        self._validate_dimensions(d_model, num_heads)

        self.d_k = d_model // num_heads
        self.num_heads = num_heads

        # Linear layers for transformations
        self.query_transform = nn.Linear(d_model, d_model)
        self.key_transform = nn.Linear(d_model, d_model)
        self.value_transform = nn.Linear(d_model, d_model)
        self.output_transform = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(self.DROPOUT_RATE)

    def _validate_dimensions(self, d_model, num_heads):
        """Validate that model dimensions are compatible with number of heads."""
        if d_model % num_heads != 0:
            raise ValueError("d_model must be divisible by the number of heads")

    def _split_heads(self, x, batch_size):
        """Split the last dimension into (num_heads, d_k)."""
        return x.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

    def _calculate_attention_scores(self, query, key):
        """Calculate raw attention scores."""
        return torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.d_k)

    def forward(self, query, key, value, mask):
        """
        Compute multi-head attention.

        Args:
            query: Query tensor of shape (batch_size, seq_len, d_model)
            key: Key tensor of shape (batch_size, seq_len, d_model)
            value: Value tensor of shape (batch_size, seq_len, d_model)
            mask: Attention mask of shape (batch_size, seq_len, seq_len)

        Returns:
            Output tensor of shape (batch_size, seq_len, d_model)
        """
        batch_size = query.size(0)

        # Linear transformations
        query = self.query_transform(query)
        key = self.key_transform(key)
        value = self.value_transform(value)

        # Split heads
        query = self._split_heads(query, batch_size)
        key = self._split_heads(key, batch_size)
        value = self._split_heads(value, batch_size)

        # Calculate attention scores
        attention_scores = self._calculate_attention_scores(query, key)

        # Apply mask and softmax
        if mask is not None:
            attention_scores = attention_scores.masked_fill(mask == 0, self.ATTENTION_MASK_FILL_VALUE)
        attention_weights = self.dropout(F.softmax(attention_scores, dim=-1))

        # Apply attention to values
        output = torch.matmul(attention_weights, value)

        # Combine heads and apply final transformation
        output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.d_k)
        return self.output_transform(output)

In [16]:
class FeedForward(nn.Module):
    """Feed-forward neural network module with two linear layers.

    Implements a feed-forward network that consists of two linear transformations
    with a ReLU activation and dropout in between.
    """

    DROPOUT_RATE = 0.1

    def __init__(self, input_dim, hidden_dim=2048) -> None:
        """Initialize the feed-forward network.

        Args:
            input_dim: Dimension of input features
            hidden_dim: Dimension of hidden layer (default: 2048)
        """
        super(FeedForward, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, input_dim)
        self.dropout = nn.Dropout(self.DROPOUT_RATE)

    def forward(self, x):
        """Forward pass of the feed-forward network.

        Args:
            x: Input tensor

        Returns:
            Processed tensor after passing through the feed-forward layers
        """
        hidden = F.relu(self.fc1(x))
        output = self.fc2(self.dropout(hidden))
        return output

In [17]:
def visualize_attention_transpose():
    # Create sample data
    batch_size, seq_len, num_heads, head_dim = 1, 3, 2, 4
    key = torch.ones(batch_size, num_heads, seq_len, head_dim)

    print("Original key shape:", key.shape)
    print("Key before transpose:\n", key[0])  # Show first batch

    key_transposed = key.transpose(-2, -1)
    print("\nTransposed key shape:", key_transposed.shape)
    print("Key after transpose:\n", key_transposed[0])  # Show first batch

    return key, key_transposed

# Run visualization
key, key_t = visualize_attention_transpose()

Original key shape: torch.Size([1, 2, 3, 4])
Key before transpose:
 tensor([[[1., 1., 1., 1.],
         [1., 1., 1., 1.],
         [1., 1., 1., 1.]],

        [[1., 1., 1., 1.],
         [1., 1., 1., 1.],
         [1., 1., 1., 1.]]])

Transposed key shape: torch.Size([1, 2, 4, 3])
Key after transpose:
 tensor([[[1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.]],

        [[1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.]]])


## Encoder Layer

In [18]:
class EncoderLayer(nn.Module):
    """Encoder layer implementing self-attention and feed-forward mechanisms.

    Contains self-attention layer, feed-forward network, layer normalization,
    and dropout for regularization.
    """

    DROPOUT_RATE = 0.1  # Default dropout rate

    def __init__(self, d_model, num_heads):
        """Initialize encoder layer components.

        Args:
            d_model: Model dimension/size
            num_heads: Number of attention heads
        """
        super(EncoderLayer, self).__init__()
        # Initialize attention mechanism
        self.self_attention = MultiHeadAttention(num_heads, d_model)

        # Initialize feed-forward network
        self.feed_forward = FeedForward(d_model)

        # Initialize layer normalization
        self.layer_norm = nn.LayerNorm(d_model)

        # Initialize dropout
        self.dropout = nn.Dropout(self.DROPOUT_RATE)

    def forward(self, embedding, mask):
        """Process input through encoder layer.

        Args:
            embedding: Input embeddings
            mask: Attention mask

        Returns:
            Processed tensor after self-attention and feed-forward layers
        """
        # Self-attention block
        attention_output = self.self_attention(embedding, embedding, embedding, mask)
        normalized_attention = self.layer_norm(embedding + attention_output)

        # Feed-forward block
        feed_forward_output = self.feed_forward(normalized_attention)
        feed_forward_output = self.dropout(feed_forward_output)

        # Final layer normalization
        output = self.layer_norm(normalized_attention + feed_forward_output)

        return output

## Decoder Layer

In [19]:
class DecoderLayer(nn.Module):
    def __init__(self, d_model, num_heads):
        super(DecoderLayer, self).__init__()
        # Initialize layer components
        self.layer_norm = nn.LayerNorm(d_model)
        self.self_attention = MultiHeadAttention(num_heads, d_model)
        self.encoder_attention = MultiHeadAttention(num_heads, d_model)
        self.feed_forward = FeedForward(d_model)
        self.dropout = nn.Dropout(0.1)

    def forward(self, decoder_input, encoder_output, decoder_mask, encoder_mask):
        # Self attention block
        self_attention = self.self_attention(
            decoder_input, decoder_input, decoder_input, encoder_mask
        )
        self_attention = self.dropout(self_attention)
        self_attention = self.layer_norm(self_attention + decoder_input)

        # Encoder-decoder attention block
        enc_dec_attention = self.encoder_attention(
            self_attention, encoder_output, encoder_output, decoder_mask
        )
        enc_dec_attention = self.dropout(enc_dec_attention)
        enc_dec_attention = self.layer_norm(enc_dec_attention + self_attention)

        # Feed forward block
        output = self.feed_forward(enc_dec_attention)
        output = self.dropout(output)
        output = self.layer_norm(output + enc_dec_attention)

        return output

## Transformer

In [20]:
class Transformer(nn.Module):
    def __init__(self, d_model, num_heads, num_layers, word_map, max_length=50):
        super(Transformer, self).__init__()
        self.d_model = d_model
        self.vocab_size = len(word_map)
        self.embedding = Embeddings(self.vocab_size, d_model, max_length)
        self.encoder_layers = nn.ModuleList(
            [EncoderLayer(d_model, num_heads) for _ in range(num_layers)]
        )
        self.decoder_layers = nn.ModuleList(
            [DecoderLayer(d_model, num_heads) for _ in range(num_layers)]
        )
        self.logit = nn.Linear(d_model, self.vocab_size)

    def encode(self, src_words, src_mask):
        embedding = self.embedding(src_words)
        for encoder_layer in self.encoder_layers:
            embedding = encoder_layer(embedding, src_mask)
        return embedding

    def decode(self, tgt_words, tgt_mask, src_embeddings, src_mask):
        embedding = self.embedding(tgt_words)
        for decoder_layer in self.decoder_layers:
            embedding = decoder_layer(embedding, src_embeddings,  src_mask, tgt_mask)
        return embedding

    def forward(self, src_words, src_mask, tgt_words, tgt_mask):
        src_embeddings = self.encode(src_words, src_mask)
        output = self.decode(tgt_words, tgt_mask, src_embeddings, src_mask)
        return F.log_softmax(self.logit(output), dim=-1)

## AdamWarmup

In [21]:
class AdamOptimizerWithWarmup:
    """Implements Adam optimizer with warmup learning rate scheduling."""

    # Constants for learning rate calculation
    MODEL_SIZE_POWER = -0.5
    STEP_POWER = -0.5
    WARMUP_POWER = -1.5

    def __init__(self, model_size, warmup_steps, optimizer):
        """
        Initialize the optimizer wrapper.

        Args:
            model_size: Size of the model (d_model in transformer architecture)
            warmup_steps: Number of warmup steps
            optimizer: Base optimizer instance
        """
        self.model_size = model_size
        self.warmup_steps = warmup_steps
        self.optimizer = optimizer
        self.current_step = 0
        self.learning_rate = 0.0

    def calculate_learning_rate(self):
        """Calculate the learning rate based on current step and warmup parameters."""
        model_factor = self.model_size ** self.MODEL_SIZE_POWER
        step_factor = min(
            self.current_step ** self.STEP_POWER,
            self.current_step * self.warmup_steps ** self.WARMUP_POWER
        )
        return model_factor * step_factor

    def _update_optimizer_learning_rate(self, lr):
        """Update the learning rate in optimizer's parameter groups."""
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

    def step(self):
        """Perform a single optimization step."""
        self.current_step += 1
        self.learning_rate = self.calculate_learning_rate()
        self._update_optimizer_learning_rate(self.learning_rate)
        self.optimizer.step()

In [22]:
class LabelSmoothedLoss(nn.Module):
    """Loss function with label smoothing for more stable training."""

    def __init__(self, vocab_size, smoothing_factor):
        """
        Initialize the label smoothed loss.

        Args:
            vocab_size: Size of the vocabulary
            smoothing_factor: Label smoothing factor between 0 and 1
        """
        super(LabelSmoothedLoss, self).__init__()
        # Initialize with 'none' reduction to handle masking manually
        self.criterion = nn.KLDivLoss(reduction='none')
        self.vocab_size = vocab_size
        self.smoothing_factor = smoothing_factor
        self.confidence = 1.0 - smoothing_factor

    def _prepare_inputs(self, predictions, target, mask):
        """
        Reshape inputs to 2D tensors for processing.

        Args:
            predictions: Model output logits
            target: Ground truth indices
            mask: Mask for valid positions

        Returns:
            Tuple of reshaped tensors (predictions, target, mask)
        """
        predictions = predictions.contiguous().view(-1, predictions.size(-1))
        target = target.contiguous().view(-1)
        mask = mask.contiguous().view(-1)
        return predictions, target, mask

    def _create_smoothed_labels(self, target, predictions):
        """
        Create smoothed label distribution.

        Args:
            target: Ground truth indices
            predictions: Model output logits to match shape

        Returns:
            Tensor of smoothed label distributions
        """
        smoothed_labels = torch.zeros_like(predictions)
        smoothing_value = self.smoothing_factor / (self.vocab_size - 1)
        smoothed_labels.fill_(smoothing_value)
        smoothed_labels.scatter_(1, target.unsqueeze(1), self.confidence)
        return smoothed_labels

    def forward(self, predictions, target, mask):
        """
        Compute label-smoothed loss.

        Args:
            predictions: Network output (batch_size, seq_len, vocab_size)
            target: Ground truth indices (batch_size, seq_len)
            mask: Mask for valid positions (batch_size, seq_len)

        Returns:
            Scalar loss value
        """
        predictions, target, mask = self._prepare_inputs(predictions, target, mask)
        smoothed_labels = self._create_smoothed_labels(target, predictions)

        # Apply log softmax and calculate KL divergence loss
        loss = self.criterion(F.log_softmax(predictions, dim=-1), smoothed_labels)
        masked_loss = (loss.sum(1) * mask).sum()
        num_valid_elements = mask.sum().clamp(min=1)

        return masked_loss / num_valid_elements

In [23]:
batch_size = 5
max_words = 7
vocab_size = 3
prediction = torch.randn(batch_size, max_words, vocab_size)

In [24]:
prediction
prediction.size()

torch.Size([5, 7, 3])

In [25]:
prediction = prediction.view(-1, prediction.shape[-1])

In [26]:
prediction.shape

torch.Size([35, 3])

In [27]:
target = torch.LongTensor(batch_size * max_words).random_(0, vocab_size)

In [28]:
target
print(target.shape, target.view(-1))

torch.Size([35]) tensor([2, 0, 1, 0, 2, 1, 0, 2, 0, 0, 1, 2, 2, 0, 2, 2, 0, 2, 2, 0, 2, 1, 1, 2,
        2, 0, 1, 1, 1, 1, 0, 2, 0, 2, 0])


In [29]:
mask = target != 0

In [30]:
mask.shape

torch.Size([35])

In [31]:
labels = prediction.data.clone()

In [32]:
labels.shape
labels.dim()

2

In [33]:
labels[0][0]

tensor(1.6470)

In [34]:
labels

tensor([[ 1.6470, -2.5175,  2.4276],
        [-0.3093, -0.2986, -0.5336],
        [-1.5328,  0.8971,  0.8988],
        [ 1.2544, -0.6875,  0.5375],
        [-1.3537, -1.9106,  2.2343],
        [ 0.1863, -1.5875,  0.8846],
        [-0.3760, -0.2982,  1.5184],
        [ 0.2096, -0.8184,  0.7062],
        [ 0.1060,  1.1480,  1.4003],
        [-0.1622,  0.1055,  0.5926],
        [ 0.8117, -0.6614,  0.9683],
        [-0.3500, -0.7261,  0.3227],
        [-1.4175, -0.0279,  0.4830],
        [-0.0311,  0.8625,  0.3575],
        [ 0.3558,  0.1479,  2.8373],
        [ 1.0610, -0.7026,  1.3265],
        [-0.9530, -0.7838, -0.6476],
        [-1.6110, -0.5115,  0.6585],
        [-0.6053,  1.3101,  0.6674],
        [-0.5408, -0.3562,  0.0116],
        [-0.2469,  0.9180,  2.8315],
        [-0.1569,  0.7847, -0.5085],
        [-0.6397,  0.2290,  0.6971],
        [ 0.9160,  0.1785, -0.7869],
        [ 0.8232, -2.0297,  0.6582],
        [-1.7281, -1.2061,  0.9214],
        [-0.6829,  0.5821, -0.9943],
 

In [35]:
labels.fill_(0.3 / (vocab_size - 1))

tensor([[0.1500, 0.1500, 0.1500],
        [0.1500, 0.1500, 0.1500],
        [0.1500, 0.1500, 0.1500],
        [0.1500, 0.1500, 0.1500],
        [0.1500, 0.1500, 0.1500],
        [0.1500, 0.1500, 0.1500],
        [0.1500, 0.1500, 0.1500],
        [0.1500, 0.1500, 0.1500],
        [0.1500, 0.1500, 0.1500],
        [0.1500, 0.1500, 0.1500],
        [0.1500, 0.1500, 0.1500],
        [0.1500, 0.1500, 0.1500],
        [0.1500, 0.1500, 0.1500],
        [0.1500, 0.1500, 0.1500],
        [0.1500, 0.1500, 0.1500],
        [0.1500, 0.1500, 0.1500],
        [0.1500, 0.1500, 0.1500],
        [0.1500, 0.1500, 0.1500],
        [0.1500, 0.1500, 0.1500],
        [0.1500, 0.1500, 0.1500],
        [0.1500, 0.1500, 0.1500],
        [0.1500, 0.1500, 0.1500],
        [0.1500, 0.1500, 0.1500],
        [0.1500, 0.1500, 0.1500],
        [0.1500, 0.1500, 0.1500],
        [0.1500, 0.1500, 0.1500],
        [0.1500, 0.1500, 0.1500],
        [0.1500, 0.1500, 0.1500],
        [0.1500, 0.1500, 0.1500],
        [0.150

In [36]:
# labels.scatter(1, target.data.unsqueeze(1), 1 - 0.3)

In [37]:
d_model = 512 #512
num_heads = 8 # 8
num_layers = 6 # 6
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
epochs = 25 # 25

with open("data/word_map_corpus.json", "r") as f:
    word_map = json.load(f)
transformer = Transformer(d_model, num_heads, num_layers, word_map, 50).to(device)
adam_optimizer = torch.optim.Adam(transformer.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)
transformer_optimizer = AdamOptimizerWithWarmup(d_model, 4000, adam_optimizer)
criterion = LabelSmoothedLoss(vocab_size=len(word_map), smoothing_factor=0.1)

## Training Function

In [38]:
def train(train_loader, transformer, criterion, epoch, batch_log_frequency=100):
    """Train the transformer model for one epoch.

    Args:
        train_loader: DataLoader containing training data
        transformer: Transformer model to train
        criterion: Loss function
        epoch: Current epoch number
        batch_log_frequency: How often to log batch progress (default: every 100 batches)
    """
    transformer.train()
    running_loss = 0
    total_samples = 0

    def process_batch(question, reply):
        question = question.to(device)
        reply = reply.to(device)
        batch_size = question.shape[0]

        # Split reply into input and target sequences
        reply_input = reply[:, :-1]  # all tokens except last
        reply_target = reply[:, 1:]  # all tokens except first

        # Generate masks and compute model output
        question_mask, reply_input_mask, reply_target_mask = create_masks(
            question, reply_input, reply_target)
        model_output = transformer(question, question_mask, reply_input, reply_input_mask)

        # Compute loss
        batch_loss = criterion(model_output, reply_target, reply_target_mask)
        return batch_loss, batch_size

    def log_progress(batch_idx, running_loss, total_samples):
        avg_loss = running_loss / total_samples
        print(f"Epoch {epoch} | Batch {batch_idx} | Average Loss {avg_loss:.4f}")

    for batch_idx, (question, reply) in enumerate(train_loader):
        # Process batch and compute loss
        batch_loss, batch_size = process_batch(question, reply)

        # Update running statistics
        running_loss += batch_loss.item() * batch_size
        total_samples += batch_size

        # Backward pass and optimization
        transformer_optimizer.optimizer.zero_grad()
        batch_loss.backward()
        transformer_optimizer.step()

        # Log progress at specified intervals
        if batch_idx % batch_log_frequency == 0:
            log_progress(batch_idx, running_loss, total_samples)

In [39]:
def create_target_mask(size):
    """
    Creates a triangular (causal) mask for decoder self-attention.
    Args:
        size: Size of the target sequence
    Returns:
        Mask tensor of shape (1, 1, size, size)
    """
    # Create a triangular mask (lower triangular matrix of 1's)
    mask = torch.triu(torch.ones(size, size), diagonal=1).bool()
    # Reshape mask to (1, 1, size, size) for broadcasting
    mask = mask.unsqueeze(0).unsqueeze(0)
    return mask.to(device)

def evaluate(transformer, question, question_mask, max_len=50,
                         temperature=0.7, top_k=50):
    """
    Generate response using temperature and top-k sampling.

    Args:
        transformer: The transformer model
        question: Input question tensor
        question_mask: Mask for the input question
        max_len: Maximum length of generated sequence
        temperature: Temperature for sampling (lower = more conservative)
        top_k: Number of top tokens to consider for sampling

    Returns:
        Generated sequence of token indices
    """
    transformer.eval()
    rev_word_map = {v: k for k, v in word_map.items()}

    # Encode the input question
    encoded = transformer.encode(question, question_mask)

    # Initialize with start token
    words = torch.LongTensor([[word_map['<start>']]]).to(device)

    with torch.no_grad():
        for _ in range(max_len - 1):
            # Create mask for the current sequence
            target_mask = create_target_mask(words.size(1))

            # Generate next token probabilities
            decoded = transformer.decode(words, target_mask, encoded, question_mask)
            predictions = transformer.logit(decoded[:, -1])

            # Apply temperature
            predictions = predictions / temperature

            # Apply top-k filtering
            values, indices = predictions[0].topk(top_k)
            predictions[0] = torch.full_like(predictions[0], float('-inf'))
            predictions[0, indices] = values

            # Sample from the filtered distribution
            probabilities = F.softmax(predictions, dim=-1)
            next_word = torch.multinomial(probabilities, 1)

            # Stop if end token is generated
            if next_word.item() == word_map['<end>']:
                break

            # Add the new token to the sequence
            words = torch.cat([words, next_word.view(1, 1)], dim=1)

    # Convert to text
    generated_tokens = words.squeeze(0).tolist()
    generated_words = [rev_word_map[idx] for idx in generated_tokens
                      if idx not in {word_map['<start>'], word_map['<end>'], word_map['<pad>']}]

    return ' '.join(generated_words)

In [40]:
def beam_search_evaluate(transformer, question, question_mask, beam_size=5, max_len=50):
    """
    Generate response using beam search.

    Args:
        transformer: The transformer model
        question: Input question tensor
        question_mask: Mask for the input question
        beam_size: Number of beams to maintain
        max_len: Maximum length of generated sequence
    """
    transformer.eval()
    rev_word_map = {v: k for k, v in word_map.items()}

    # Encode the input question
    encoded = transformer.encode(question, question_mask)

    # Initialize beams with start token
    # Each beam is (sequence, score)
    beams = [(torch.LongTensor([[word_map['<start>']]]).to(device), 0.0)]

    with torch.no_grad():
        for _ in range(max_len-1):
            candidates = []

            # Expand each current beam
            for sequence, score in beams:
                # Skip if sequence is already completed
                if sequence[0][-1].item() == word_map['<end>']:
                    candidates.append((sequence, score))
                    continue

                # Create mask for the current sequence
                target_mask = create_target_mask(sequence.size(1))

                # Generate next token probabilities
                decoded = transformer.decode(sequence, target_mask, encoded, question_mask)
                logits = transformer.logit(decoded[:, -1])
                log_probs = F.log_softmax(logits, dim=-1)

                # Get top k candidates for each beam
                values, indices = log_probs[0].topk(beam_size)

                # Create new candidates
                for token, token_score in zip(indices, values):
                    new_sequence = torch.cat([sequence,
                        torch.LongTensor([[token]]).to(device)], dim=1)
                    # Add scores in log space
                    new_score = score + token_score.item()
                    candidates.append((new_sequence, new_score))

            # Select top beam_size candidates
            # Sort by score and normalize by length to prevent bias towards shorter sequences
            candidates = [(seq, score/len(seq[0])) for seq, score in candidates]
            beams = sorted(candidates, key=lambda x: x[1], reverse=True)[:beam_size]

            # Stop if all beams end with <end> token
            if all(b[0][0][-1].item() == word_map['<end>'] for b in beams):
                break

    # Return the highest scoring sequence
    best_sequence = beams[0][0]
    generated_tokens = best_sequence.squeeze(0).tolist()
    generated_words = [rev_word_map[idx] for idx in generated_tokens
                      if idx not in {word_map['<start>'], word_map['<end>'], word_map['<pad>']}]

    return ' '.join(generated_words)

In [41]:
for epoch in range(epochs):
    train(train_loader, transformer, criterion, epoch)
    state = {
        'epoch': epoch,
        'transformer_state_dict': transformer.state_dict(),
        'optimizer_state_dict': transformer_optimizer.optimizer.state_dict()
    }
    torch.save(state, f"models/checkpoint_{epoch}.tar")

Epoch 0 | Batch 0 | Average Loss 9.4420
Epoch 0 | Batch 100 | Average Loss 8.1894
Epoch 0 | Batch 200 | Average Loss 7.4311
Epoch 0 | Batch 300 | Average Loss 6.7556


KeyboardInterrupt: 

In [41]:
# Add both custom classes and the built-in set type
import torch.serialization

# Add built-in set type
torch.serialization.add_safe_globals([set])

# Add all custom model classes
torch.serialization.add_safe_globals([
    Transformer,
    MultiHeadAttention,
    EncoderLayer,
    DecoderLayer,
    Embeddings,
    FeedForward,
    AdamOptimizerWithWarmup,
    LabelSmoothedLoss
])

# Now try loading the checkpoint
checkpoint = torch.load("models/checkpoint_24.tar", weights_only=True)

# Create new model instance and load state dict
transformer = Transformer(d_model, num_heads, num_layers, word_map).to(device)
transformer.load_state_dict(checkpoint['transformer_state_dict'])

<All keys matched successfully>

In [42]:

# Function to chat with the model using the new sampling method
def chat_with_model(transformer, word_map, max_len=50, temperature=0.7, top_k=50):
    """
    Interactive chat function with the model.

    Args:
        transformer: The transformer model
        word_map: Dictionary mapping words to indices
        max_len: Maximum length of generated response
        temperature: Temperature for sampling
        top_k: Number of top tokens to consider for sampling
    """
    transformer.eval()
    print("Chat started (type 'quit' to exit)")

    while True:
        # Get user input
        user_input = input("You: ")
        if user_input.lower() in ['quit', 'exit']:
            break

        # Preprocess input
        tokens = word_tokenize(user_input.lower())
        encoded = [word_map.get(token, word_map['<unk>']) for token in tokens]

        # Prepare input tensors
        question = torch.LongTensor([encoded]).to(device)
        question_mask = (question != 0).unsqueeze(1).unsqueeze(1).to(device)

        # Generate response
        response = evaluate(transformer, question, question_mask,
                                       max_len=max_len,
                                       temperature=temperature,
                                       top_k=top_k)

        print("Bot:", response)

In [44]:
# After training the model
print("Starting chat with improved generation...")
chat_with_model(transformer, word_map,
                max_len=50,        # Maximum response length
                temperature=0.6,   # Lower = more focused, higher = more creative
                top_k=10)         # Number of top tokens to consider

Starting chat with improved generation...
Chat started (type 'quit' to exit)
Bot: yeah
Bot: i
Bot: its
Bot: you
Bot: i
Bot: i
Bot: i
Bot: yes
Bot: you


KeyboardInterrupt: Interrupted by user

In [None]:
def chat_with_model(transformer, word_map, max_len=50, beam_size=5):
    transformer.eval()
    print("Chat started (type 'quit' to exit)")

    while True:
        user_input = input("You: ")
        if user_input.lower() in ['quit', 'exit']:
            break

        # Preprocess input
        tokens = word_tokenize(user_input.lower())
        encoded = [word_map.get(token, word_map['<unk>']) for token in tokens]

        # Prepare input tensors
        question = torch.LongTensor([encoded]).to(device)
        question_mask = (question != 0).unsqueeze(1).unsqueeze(1).to(device)

        # Generate response using beam search
        response = beam_search_evaluate(transformer, question, question_mask,
                                     beam_size=beam_size, max_len=max_len)
        print("Bot:", response)

# Try with different beam sizes
chat_with_model(transformer, word_map, beam_size=5)  # Start with beam_size=5