In [None]:
"""
I introduced another learning signal: a repulsion loss meant to "unblur" concepts,
meaning that concepts whose representations are too similar should repel each other.
This is still in its testing phase.
"""

In [None]:
# --- Standard Python Libraries ---
import os
import re
import math
import json
import unicodedata
import random
import itertools
from dataclasses import dataclass, field
from typing import Dict, List, Tuple

# --- Core PyTorch Libraries ---
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical
from torch import Tensor
import numpy as np
from torch.utils.data import Sampler
from torch.utils.data import Dataset
import torch.utils.checkpoint as checkp

# --- Hugging Face Libraries ---
# Required for the tokenizer and learning rate scheduler.
# pip install transformers
# pip install tokenizers
from tokenizers import Tokenizer
from transformers import get_linear_schedule_with_warmup

In [None]:
def clean_data(sentence: str):
    """
    Normalize a raw sentence into a simple ASCII token stream.

    Steps: lowercase + trim, strip combining accent marks (NFD, category 'Mn'),
    split '.', '!' and '?' off as separate tokens, replace every run of other
    characters (digits and whitespace included) with one space, drop any
    leftover non-ASCII, and collapse whitespace.
    """
    decomposed = unicodedata.normalize('NFD', sentence.lower().strip())
    text = ''.join(ch for ch in decomposed if unicodedata.category(ch) != 'Mn')

    # Applied in order; each (pattern, replacement) pair is one cleaning pass.
    passes = (
        (r"([.!?])", r" \1"),           # isolate sentence punctuation
        (r"[^a-zA-Z.!?]+", r" "),       # keep only letters and . ! ?
        (r"[^\x00-\x7F]", r""),         # defensive: strip any non-ASCII residue
        (r"\s+", r" "),                 # collapse whitespace runs
    )
    for pattern, replacement in passes:
        text = re.sub(pattern, replacement, text)
    return text.strip()

def load_and_prepare_data(path, token_path):
    """
    Load ``[input, target]`` pairs from a JSON file and load the tokenizer.

    Note: despite the progress message, the text is NOT cleaned here. Entries
    are used verbatim, so the JSON is expected to be pre-processed already
    (for example with ``clean_data``).

    Args:
        path: path to a UTF-8 JSON file holding a list of objects, each with
            'input' and 'target' keys.
        token_path: path to a serialized ``tokenizers.Tokenizer`` file.

    Returns:
        ``(pairs, tokenizer, vocab_size)``, where ``pairs`` is a list of
        ``[input, target]`` lists in file order and ``vocab_size`` is the
        length of ``tokenizer.get_vocab()``.
    """
    print("Loading and cleaning data...")
    with open(path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    tokenizer = Tokenizer.from_file(token_path)
    vocabulary = tokenizer.get_vocab()

    # Logging the vocabulary size for verification.
    print(f"Vocabulary size: {len(vocabulary)}")

    # Entries are assumed to be pre-cleaned; no normalization happens here.
    pairs = [[entry['input'], entry['target']] for entry in data]

    return pairs, tokenizer, len(vocabulary)

def indexesFromSentence(tokenizer, sentence, SOS_token, EOS_token):
    """
    Tokenize ``sentence`` and frame the resulting ids with boundary markers.

    Returns a new list: ``[SOS_token, *tokenizer.encode(sentence).ids, EOS_token]``.
    """
    body_ids = tokenizer.encode(sentence).ids
    return [SOS_token, *body_ids, EOS_token]

def pad_tensor(x, PAD_token):
    """
    Right-pad a list of 1-D tensors with ``PAD_token`` and stack them
    batch-first, giving shape ``[len(x), max_len]``.
    """
    return nn.utils.rnn.pad_sequence(
        x,
        batch_first=True,
        padding_value=PAD_token,
    )

def mask_tensor(x, PAD_token):
    """Return a boolean tensor shaped like ``x``: True where ``x`` is not ``PAD_token``."""
    return x.ne(PAD_token)

def batch_to_tensors(tokenizer, batch_pairs, SOS, EOS, PAD):
    """
    Turn a batch of ``(question, response)`` string pairs into model tensors.

    Each side is tokenized with SOS/EOS framing, padded batch-first with
    ``PAD``, and given a non-PAD boolean mask.

    Returns:
        ``(padded_questions, question_lengths, padded_responses,
        response_lengths, mask_questions, mask_responses)``. The lengths are
        int64 tensors holding each sequence's unpadded length, SOS/EOS included.
    """
    questions, responses = zip(*batch_pairs)

    def _encode_side(sentences):
        # Tokenize -> LongTensor per sentence -> lengths -> padded batch -> mask.
        seqs = [torch.LongTensor(indexesFromSentence(tokenizer, s, SOS, EOS)) for s in sentences]
        lengths = torch.tensor([len(seq) for seq in seqs], dtype=torch.long)
        padded = pad_tensor(seqs, PAD)
        return padded, lengths, mask_tensor(padded, PAD)

    padded_q, lengths_q, mask_q = _encode_side(questions)
    padded_r, lengths_r, mask_r = _encode_side(responses)

    return padded_q, lengths_q, padded_r, lengths_r, mask_q, mask_r

In [None]:
class ChatDataset(Dataset):
    """
    Map-style dataset over an in-memory sequence of ``(input, target)`` pairs.

    The ``pairs`` sequence is stored by reference, not copied. Other code
    (e.g. ``BucketRandomSampler``) reads ``self.pairs`` directly.
    """

    def __init__(self, pairs):
        self.pairs = pairs

    def __len__(self):
        """Number of pairs available."""
        return len(self.pairs)

    def __getitem__(self, idx):
        """Return the pair stored at position ``idx``, unchanged."""
        pairs = self.pairs
        return pairs[idx]

class BucketRandomSampler(Sampler):
    def __init__(self, data_source, bucket_size, batch_size, shuffle_buckets=True):
        """
        A sampler that groups samples of similar length so that batches need
        little padding, while still randomizing the training order.

        Each epoch the dataset indices are randomly permuted (when
        ``shuffle_buckets`` is True) and cut into buckets of ``bucket_size``.
        Every bucket is sorted by length and split into chunks of ``batch_size``.
        The order of those chunks inside the bucket is then shuffled.

        Note: a DataLoader regroups the yielded indices into batches of its own
        ``batch_size``. Sampler chunks only line up with DataLoader batches
        when ``bucket_size`` is a multiple of ``batch_size``, because a bucket's
        last chunk may otherwise be short.

        Args:
            data_source: The dataset to sample from. Must expose ``.pairs``
                (a sequence of 2-item pairs) and ``__len__``.
            bucket_size: The number of samples to group into a single bucket.
            batch_size: The desired batch size.
            shuffle_buckets: If True (default), permute the indices before
                bucketing, so bucket membership and order change every epoch.
                If False, buckets are formed over dataset order. That was the
                previous behaviour: the same buckets, in the same order, every
                epoch.

        Raises:
            ValueError: if ``bucket_size`` or ``batch_size`` is not positive.
        """
        super().__init__(data_source)
        if bucket_size <= 0 or batch_size <= 0:
            raise ValueError(
                f"bucket_size and batch_size must be positive, got {bucket_size} and {batch_size}"
            )
        self.data_source = data_source
        self.bucket_size = bucket_size
        self.batch_size = batch_size
        self.shuffle_buckets = shuffle_buckets

        # Combined character length of each pair. This is a cheap proxy for
        # token length, used only for sorting.
        self.lengths = [len(pair[0]) + len(pair[1]) for pair in data_source.pairs]

    def __iter__(self):
        n = len(self.data_source)
        # Without this permutation the buckets would be fixed slices of dataset
        # order, so every epoch would visit the same buckets in the same order.
        indices = np.random.permutation(n) if self.shuffle_buckets else np.arange(n)
        lengths = np.asarray(self.lengths)

        ordered = []
        for start in range(0, n, self.bucket_size):
            bucket = indices[start:start + self.bucket_size]
            # A stable sort keeps ties in their (possibly shuffled) bucket order.
            bucket = bucket[np.argsort(lengths[bucket], kind='stable')]

            batches = [bucket[j:j + self.batch_size] for j in range(0, len(bucket), self.batch_size)]
            # Randomize batch order inside the bucket; each batch stays length-homogeneous.
            np.random.shuffle(batches)

            for batch in batches:
                ordered.extend(batch.tolist())

        return iter(ordered)

    def __len__(self):
        return len(self.data_source)

class Embedding(nn.Module):
    """
    Token-id -> vector lookup backed by ``nn.Embedding``, with a padding row.

    Args:
        total_number: number of rows in the table (vocabulary size).
        embedding_dimension: size of each embedding vector.
        padding_idx: id whose row is zero-initialized and receives no gradient.
            If None (the default), the notebook-global ``PAD_token`` is used,
            as before. Pass it explicitly to avoid the hidden global dependency.
    """

    def __init__(self, total_number, embedding_dimension, padding_idx=None):
        super().__init__()
        if padding_idx is None:
            # Backward-compatible fallback to the notebook-level constant.
            padding_idx = PAD_token
        self.embedding = nn.Embedding(
            num_embeddings=total_number,
            embedding_dim=embedding_dimension,
            padding_idx=padding_idx,
        )

    def forward(self, tensor):
        """Look up embeddings; output shape is ``tensor.shape + (embedding_dimension,)``."""
        return self.embedding(tensor)

class PositionalEncoding(nn.Module):
    """
    Adds fixed sinusoidal position encodings ('Attention Is All You Need')
    to token embeddings, then applies dropout.

    Even channels hold ``sin(pos * w_i)`` and odd channels hold ``cos(pos * w_i)``,
    with ``w_i = 10000^(-2i / embedding_dim)``.

    Args:
        embedding_dim: feature size of the embeddings. Odd sizes are supported:
            the last (even-indexed) channel then gets a sine term only.
        dropout: dropout probability applied after adding the encodings.
        max_len: longest sequence length the precomputed table covers.
    """

    def __init__(self, embedding_dim: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)  # [max_len, 1]
        div_term = torch.exp(torch.arange(0, embedding_dim, 2) * (-math.log(10000.0) / embedding_dim))
        angles = position * div_term  # [max_len, ceil(embedding_dim / 2)]

        pe = torch.zeros(max_len, 1, embedding_dim)
        pe[:, 0, 0::2] = torch.sin(angles)
        # For odd embedding_dim there is one fewer cosine slot than sine slot;
        # slicing avoids the shape mismatch that previously crashed here.
        pe[:, 0, 1::2] = torch.cos(angles[:, : embedding_dim // 2])
        # Buffer of shape [max_len, 1, embedding_dim]: moves with .to(device) and
        # is saved in the state_dict under 'pe'.
        self.register_buffer('pe', pe)

    def forward(self, x: Tensor) -> Tensor:
        """
        Injects positional information into the input embeddings.

        Arguments:
            x: Tensor, shape [batch_size, seq_len, embedding_dim], with
               seq_len <= max_len.
        """
        # pe[:seq_len] is [seq_len, 1, D]; transpose to [1, seq_len, D] to broadcast over batch.
        x = x + self.pe[:x.size(1)].transpose(0, 1)
        return self.dropout(x)

# Based on: https://github.com/Hanhpt23/Implement-Self-attention-Pytorch/blob/main/self-attention.py
class EncoderAttention(nn.Module):
    """
    Multi-head self-attention ('Attention Is All You Need') restricted to a
    symmetric sliding window.

    A query at position i may only attend to keys j with
    ``|i - j| <= context_window // 2``. Keys whose ``padding_mask`` entry is
    0/False are masked as well.

    Note: the window limits how far attention reaches, but the full N x N score
    matrix is still materialized, so it does NOT reduce peak memory.

    Args:
        dim: embedding dimension of the input tokens (must divide by num_heads).
        num_heads: number of attention heads.
        qkv_bias: whether the fused QKV projection has a bias.
        qk_scale: optional override for the 1/sqrt(head_dim) score scale.
        attn_drop: dropout on attention probabilities.
        proj_drop: dropout after the output projection.
        context_window: total width of the attention window (default 200).
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.1,
                 context_window=200):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} should be divisible by num_heads {num_heads}."

        self.context_window = context_window
        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # A single linear layer produces Q, K and V together.
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, responses, padding_mask=None):
        """
        Args:
            responses: input of shape [B, N, C] with C == dim.
            padding_mask: optional [B, N] mask over key positions; 0/False = ignore.

        Returns:
            Tensor of shape [B, N, C].
        """
        B, N, C = responses.shape

        # [B, N, 3C] -> [B, N, 3, h, C/h] -> [3, B, h, N, C/h]
        qkv = self.qkv(responses).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # each [B, h, N, C/h]

        attn = (q @ k.transpose(-2, -1)) * self.scale  # [B, h, N, N]

        # True where query/key distance exceeds the half-window; [N, N] broadcasts over [B, h].
        positions = torch.arange(N, device=responses.device)
        outside_window = (positions.unsqueeze(1) - positions.unsqueeze(0)).abs() > self.context_window // 2
        attn = attn.masked_fill(outside_window, float('-inf'))

        if padding_mask is not None:
            # [B, N] -> [B, 1, 1, N] masks key columns for every head and query.
            mask_reshaped = padding_mask.unsqueeze(1).unsqueeze(1)
            attn = attn.masked_fill(mask_reshaped == 0, float('-inf'))

        attn = attn.softmax(dim=-1)
        # Rows whose keys are all masked softmax to NaN; zero them instead.
        attn = torch.nan_to_num(attn)
        attn = self.attn_drop(attn)

        # [B, h, N, C/h] -> [B, N, h, C/h] -> [B, N, C]
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

# Similar to the encoder's attention, but with causal and sliding window masking.
class DecoderAttention(nn.Module):
    """
    Causal multi-head self-attention with a sliding window and optional key padding.

    A query at position i may attend to keys j with ``j <= i`` (no future tokens)
    and ``i - j <= context_window // 2``. Keys whose ``padding_mask`` entry is
    0/False are also masked.

    Args:
        dim: embedding dimension (must divide by num_heads).
        num_heads: number of attention heads.
        qkv_bias: whether the fused QKV projection has a bias.
        qk_scale: optional override for the 1/sqrt(head_dim) score scale.
        attn_drop: dropout on attention probabilities.
        proj_drop: dropout after the output projection.
        context_window: total width of the attention window (default 200).
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.1,
                 context_window=200):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} should be divisible by num_heads {num_heads}."

        self.dim = dim
        self.context_window = context_window
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # A single linear layer produces Q, K and V together.
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, responses, padding_mask=None):
        """
        Args:
            responses: input of shape [B, N, C] with C == dim.
            padding_mask: optional [B, N] mask over key positions; 0/False = ignore.

        Returns:
            Tensor of shape [B, N, C].
        """
        B, N, C = responses.shape

        # [B, N, 3C] -> [3, B, h, N, C/h]
        qkv = self.qkv(responses).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q @ k.transpose(-2, -1)) * self.scale  # [B, h, N, N]

        # Causal mask: True above the diagonal (keys in the future).
        look_ahead_mask = torch.triu(torch.ones(N, N, device=responses.device), diagonal=1).bool()

        # Window mask: True where |query - key| exceeds the half-window.
        positions = torch.arange(N, device=responses.device)
        sliding_window_mask = (positions.unsqueeze(1) - positions.unsqueeze(0)).abs() > self.context_window // 2

        # A position is masked if it is in the future OR outside the window.
        combined_mask = look_ahead_mask | sliding_window_mask
        attn = attn.masked_fill(combined_mask.unsqueeze(0).unsqueeze(1), float('-inf'))

        if padding_mask is not None:
            # [B, N] -> [B, 1, 1, N] masks key columns.
            mask_reshaped = padding_mask.unsqueeze(1).unsqueeze(1)
            attn = attn.masked_fill(mask_reshaped == 0, float('-inf'))

        attn = attn.softmax(dim=-1)
        # Fully masked rows softmax to NaN; zero them instead.
        attn = torch.nan_to_num(attn)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

# Inspired by: https://gist.github.com/wolfecameron/5646b2092d41d6d31ec1abb28b3b930a
class CrossAttention(nn.Module):
    """
    Single-head cross-attention: queries come from ``x_1``, keys/values from ``x_2``.

    There is no output projection; the result is ``softmax(QK^T / sqrt(d)) V``.
    """

    def __init__(self, embedding_dim):
        """
        Arguments:
        embedding_dim: size of the embedding dimension.
        """
        super().__init__()
        self.d = embedding_dim

        # Linear projection for producing the query matrix.
        self.w_q = nn.Linear(embedding_dim, embedding_dim, bias=False)

        # Linear projection for producing key and value matrices.
        self.w_kv = nn.Linear(embedding_dim, 2*embedding_dim, bias=False)

    def forward(self, x_1, x_2, padding_mask=None):
        """
        Args:
            x_1: queries, shape [B, N_q, d].
            x_2: keys/values source, shape [B, N_k, d].
            padding_mask: optional [B, N_k] mask over x_2 positions; 0/False = ignore.

        Returns:
            Tensor of shape [B, N_q, d].
        """
        q = self.w_q(x_1)
        k, v = self.w_kv(x_2).split(self.d, dim=2)

        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))  # [B, N_q, N_k]

        if padding_mask is not None:
            # The scores are 3-D, so the key mask must be [B, 1, N_k]. The former
            # [B, 1, 1, N_k] silently broadcast the result to 4-D [B, B, N_q, N_k].
            mask_reshaped = padding_mask.unsqueeze(1)
            att = att.masked_fill(mask_reshaped == 0, float('-inf'))

        att = F.softmax(att, dim=-1)
        # A query whose keys are all masked softmaxes to NaN; zero it instead.
        att = torch.nan_to_num(att)

        y = att @ v
        return y

class DifferentiableExplorerAttention(nn.Module):
    """
    Multi-head self-attention biased by learned "explorer" spotlights.

    For each (batch item, head), a small MLP reads the mean query vector and
    predicts ``num_explorers`` (E) continuous positions in [0, N-1]. Each
    position defines a Gaussian bump over key positions, with values in (0, 1].
    The bump is ADDED to the shared scaled dot-product scores, giving E biased
    views of the same attention map. Each view is softmax-normalized and
    applied to V, and the E results are averaged.

    ``forward`` returns ``(output, repulsion_loss)``. The auxiliary loss grows
    when explorers of the same (batch item, head) sit close together.

    Note: an intermediate score tensor of shape [B, H, E, N, N] is built, i.e.
    E times the size of ordinary attention scores.

    Args:
        dim: embedding dimension (must divide by num_heads).
        num_heads: number of attention heads (H).
        num_explorers: number of spotlights per head (E).
        qkv_bias: whether the fused QKV projection has a bias.
        proj_drop: dropout after the output projection.
    """
    def __init__(self, dim, num_heads=8, num_explorers=10, qkv_bias=False, proj_drop=0.1):
        super().__init__()
        assert dim % num_heads == 0, "Embedding dim must be divisible by num_heads."
        self.num_heads = num_heads
        self.num_explorers = num_explorers
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5

        # 1. Standard QKV projection.
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        # 2. A small network to predict explorer positions.
        #    It maps a per-head query summary (head_dim) to E raw position logits.
        #    The hidden width 128 is hard-coded.
        self.position_predictor = nn.Sequential(
            nn.Linear(self.head_dim, 128),
            nn.ReLU(),
            nn.Linear(128, self.num_explorers) # Output E positions
        )

    def forward(self, x, padding_mask=None):
        """
        Args:
            x: input of shape [B, N, C] (unpacked below); C must equal ``dim``.
            padding_mask: optional key mask; positions where it is 0/False get
                -inf scores. Presumably shaped [B, N] (it is unsqueezed to
                [B, 1, 1, 1, N] below). TODO confirm against callers.

        Returns:
            (output, repulsion_loss): output is [B, N, C]; repulsion_loss is a
            scalar averaged over all (batch item, head) groups.

        NOTE(review): unlike the other attention blocks in this notebook, there
        is no ``nan_to_num`` after the softmax here. A query row whose keys were
        all masked would therefore produce NaN.
        """
        B, N, C = x.shape
        H, E = self.num_heads, self.num_explorers

        # --- Step 1: Standard QKV ---
        qkv = self.qkv(x).reshape(B, N, 3, H, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # Each is [B, H, N, D]

        # Base attention map shared by all explorers.
        shared_attn_scores = (q @ k.transpose(-2, -1)) * self.scale # [B, H, N, N]

        # --- Step 2: Predict Explorer Positions ---
        # Use the mean of the queries as a context vector for position prediction.
        # The mean is unmasked, so PAD positions (if any) contribute to it.
        context_q = q.mean(dim=2) # [B, H, D]
        
        # Predict positions, constrained between 0 and N-1.
        # N is the padded length, so explorers can land on PAD positions.
        predicted_positions = self.position_predictor(context_q).sigmoid() * (N - 1) # [B, H, E]

        # --- Step 3: Create Soft "Spotlight" Masks from Positions ---
        pos_indices = torch.arange(N, device=x.device, dtype=torch.float32).view(1, 1, 1, N) # [1, 1, 1, N]
        centers = predicted_positions.unsqueeze(-1) # [B, H, E, 1]
        
        # A heuristic for the width of the spotlight: the sequence split into 2E slices.
        spotlight_width = N / (E * 2) 
        
        # Calculate Gaussian masks based on the predicted centers.
        # The exponent is <= 0, so every mask value lies in (0, 1].
        exponent = -((pos_indices - centers) ** 2) / (2 * spotlight_width ** 2)
        spotlight_masks = torch.exp(exponent) # Shape: [B, H, E, N]

        # --- Step 4: Apply Spotlights and Aggregate ---
        # Add each explorer's spotlight as a key-position bias to every query's scores.
        # Reshape for broadcasting:
        # shared_attn_scores: [B, H, N, N] -> [B, H, 1, N, N]
        # spotlight_masks:    [B, H, E, N] -> [B, H, E, 1, N]
        weighted_attn_scores = shared_attn_scores.unsqueeze(2) + spotlight_masks.unsqueeze(3)

        if padding_mask is not None:
            # The mask needs to be broadcastable to [B, H, E, N, N] (masks key columns).
            mask = padding_mask.unsqueeze(1).unsqueeze(1).unsqueeze(1)
            weighted_attn_scores = weighted_attn_scores.masked_fill(mask == 0, float('-inf'))
        
        # Re-normalize with softmax for each explorer's view.
        weighted_attn_probs = F.softmax(weighted_attn_scores, dim=-1) # [B, H, E, N, N]

        # Get the output for all explorers and fuse them by averaging over E.
        # v needs to be [B, H, 1, N, D] for broadcasting.
        fused_output = (weighted_attn_probs @ v.unsqueeze(2)).mean(dim=2) # [B, H, N, D]

        # --- Step 5: Final Projection ---
        output = fused_output.transpose(1, 2).reshape(B, N, C)
        output = self.proj(output)
        output = self.proj_drop(output)

        # --- Step 6: Repulsion Loss ---
        # This encourages explorers to spread out and cover different parts of the sequence.
        p = predicted_positions.view(B*H, E, 1)
        p_diff = p.transpose(1, 2) - p # [B*H, E, E]
        
        # Use a Gaussian kernel for repulsion (same width as the spotlights).
        repulsion = torch.exp(-(p_diff ** 2) / (2 * spotlight_width ** 2))
        
        # Sum over the full E x E matrix, which counts each unordered pair twice.
        # The E diagonal terms are exp(0) = 1, so subtracting E removes self-repulsion.
        repulsion_loss = repulsion.sum(dim=(-1, -2)) - E
        repulsion_loss = repulsion_loss.mean()

        # The module returns the final output and the auxiliary loss term.
        return output, repulsion_loss

class LayerNormalization(nn.Module):
    """
    Layer normalization over the last (feature) axis.

    Computes ``y = (x - mean) / sqrt(var + eps) * gamma + beta``. Here ``mean``
    and the biased (population) ``var`` are taken per position over the last
    dimension, so the result does not depend on batch size.

    Args:
        d_model: size of the feature dimension; sets the shape of gamma/beta.
        eps: small constant added to the variance for numerical stability.
    """

    def __init__(self, d_model: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(d_model))   # learnable per-feature scale
        self.beta = nn.Parameter(torch.zeros(d_model))   # learnable per-feature shift

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # keepdim=True keeps the statistics broadcastable against x.
        mu = x.mean(dim=-1, keepdim=True)
        sigma_sq = x.var(dim=-1, unbiased=False, keepdim=True)
        x_hat = (x - mu) / torch.sqrt(sigma_sq + self.eps)
        return self.gamma * x_hat + self.beta

class FeedForward(nn.Module):
    """
    Position-wise feed-forward network from 'Attention Is All You Need':
    ``linear_2(dropout(relu(linear_1(x))))``, applied independently at every position.

    Note:
    d_ff is usually 4 * d_model; a smaller value can be passed to cut compute.

    Args:
        d_model: input/output feature size.
        d_ff: hidden feature size.
        dropout: dropout probability applied to the hidden activations.
    """

    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        # Registration order is kept (linear_1, dropout, linear_2, relu) so
        # parameter init order and state_dict keys stay the same.
        self.linear_1 = nn.Linear(d_model, d_ff)
        self.dropout = nn.Dropout(dropout)
        self.linear_2 = nn.Linear(d_ff, d_model)
        self.relu = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # [..., d_model] -> [..., d_ff] -> [..., d_model]
        hidden = self.dropout(self.relu(self.linear_1(x)))
        return self.linear_2(hidden)

# In the EncoderLayer class

class EncoderLayer(nn.Module):
    """
    One pre-norm encoder layer:

        h   = x + ExplorerAttention(LayerNorm1(x))
        out = h + FeedForward(LayerNorm2(h))

    The attention call is wrapped in ``torch.utils.checkpoint`` (reentrant
    mode), so its activations are recomputed during backward instead of stored.
    ``forward`` returns ``(out, repulsion_loss)``, passing the attention's
    auxiliary loss up to the Encoder.
    """

    def __init__(self, embedding_dim):
        super().__init__()
        # Registration order preserved for identical init/state_dict layout.
        self.input_attention = DifferentiableExplorerAttention(dim=embedding_dim, num_heads=8, num_explorers=10)
        self.normalization1 = LayerNormalization(embedding_dim, 1e-6)
        self.normalization2 = LayerNormalization(embedding_dim, 1e-6)
        self.feed = FeedForward(embedding_dim, embedding_dim, 0.1)

    def forward(self, pos_id_tensor, mask):
        # Sub-block 1: pre-norm self-attention plus residual.
        attended, repulsion_loss = checkp.checkpoint(
            self.input_attention,
            self.normalization1(pos_id_tensor),
            mask,
            use_reentrant=True,
        )
        hidden = pos_id_tensor + attended

        # Sub-block 2: pre-norm feed-forward plus residual.
        output = hidden + self.feed(self.normalization2(hidden))
        return output, repulsion_loss

class Encoder(nn.Module):
    """
    Embedding -> positional encoding -> a stack of EncoderLayers.

    ``forward`` returns ``(hidden, avg_repulsion_loss)``. The second value is
    the mean of the per-layer repulsion losses.
    """

    def __init__(self, embedding_dim, embedding, positional, num_layers):
        super().__init__()
        self.embedding = embedding
        self.pos = positional
        self.layers = nn.ModuleList([EncoderLayer(embedding_dim) for _ in range(num_layers)])

    def forward(self, tensor, mask):
        hidden = self.pos(self.embedding.forward(tensor))

        per_layer_losses = []
        for block in self.layers:
            hidden, layer_loss = block(hidden, mask)
            per_layer_losses.append(layer_loss)

        return hidden, torch.stack(per_layer_losses).mean()

class DecoderLayer(nn.Module):
    """
    One pre-norm decoder layer:

        h1  = x  + CausalSelfAttention(LayerNorm1(x), mask)
        h2  = h1 + CrossAttention(LayerNorm2(h1), encoder_tensor, encoder_mask)
        out = h2 + FeedForward(LayerNorm3(h2))

    Both attention calls are wrapped in ``torch.utils.checkpoint`` (reentrant mode).

    Args (forward):
        encoder_tensor: encoder output used as keys/values for cross-attention.
        padded_r: decoder input hidden states.
        mask: padding mask for the decoder's own (self-attention) keys.
        encoder_mask: optional padding mask over encoder positions. It is
            forwarded to cross-attention as ``padding_mask``. When None
            (default), the decoder attends to every encoder position,
            including PAD, as before.
    """
    def __init__(self, embedding_dim):
        super().__init__()
        self.output_attention = DecoderAttention(embedding_dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.1)
        self.cross_attention = CrossAttention(embedding_dim)
        self.normalization1 = LayerNormalization(embedding_dim, 1e-6)
        self.normalization2 = LayerNormalization(embedding_dim, 1e-6)
        self.normalization3 = LayerNormalization(embedding_dim, 1e-6)
        self.feed = FeedForward(embedding_dim, embedding_dim, 0.1)

    def forward(self, encoder_tensor, padded_r, mask, encoder_mask=None):
        # 1) Causal/windowed self-attention over the decoder sequence plus residual.
        self_attended = checkp.checkpoint(self.output_attention, self.normalization1(padded_r), mask, use_reentrant=True)
        hidden = padded_r + self_attended

        # 2) Cross-attention into the encoder output plus residual. encoder_mask
        #    keeps queries from attending to encoder PAD positions when provided.
        cross_attended = checkp.checkpoint(
            self.cross_attention, self.normalization2(hidden), encoder_tensor, encoder_mask, use_reentrant=True
        )
        hidden = hidden + cross_attended

        # 3) Position-wise feed-forward plus residual.
        return hidden + self.feed(self.normalization3(hidden))

class Decoder(nn.Module):
    """
    The Decoder module: embedding -> positional encoding -> a stack of DecoderLayers.

    Args (forward):
        encoder_tensor: encoder output consumed by each layer's cross-attention.
        responses: decoder input token ids.
        mask_r: padding mask for the decoder sequence.
        encoder_mask: optional padding mask over encoder positions, passed to
            each layer only when given.
    """
    def __init__(self, embedding_dim, embedding, positional, num_layers):
        super().__init__()
        self.embedding = embedding
        self.pos = positional
        self.layers = nn.ModuleList([DecoderLayer(embedding_dim) for _ in range(num_layers)])

    def forward(self, encoder_tensor, responses, mask_r, encoder_mask=None):
        hidden = self.pos(self.embedding.forward(responses))  # [Batch, Sequence, Embedding]

        # Forward the encoder mask only when provided, so the default call stays
        # layer(encoder_tensor, hidden, mask_r).
        extra = {} if encoder_mask is None else {"encoder_mask": encoder_mask}
        for layer in self.layers:
            hidden = layer(encoder_tensor, hidden, mask_r, **extra)

        return hidden

class Cleopatra(nn.Module):
    """
    Encoder-decoder model that returns vocabulary logits plus an auxiliary
    repulsion loss (the encoder's own repulsion term + a "small-world"
    repulsion term computed on pooled encoder outputs).

    Args:
        decoder: Module called as decoder(enc_output, responses, mask_r).
        encoder: Module called as encoder(questions, mask_q) that returns
            (enc_output, explorer_repulsion_loss).
        embedding_dim: Width of the decoder output fed to the projection.
        total_num: Output vocabulary size.
        device: Stored on the instance; not used by this class's own methods.
        sw_repulsion_margin: Distance below which a sentence embedding's
            nearest neighbour in the batch is penalised (default 0.5).
    """
    def __init__(self, decoder, encoder, embedding_dim, total_num, device, sw_repulsion_margin=0.5):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        # Projects decoder states to vocabulary logits.
        self.ln_0 = nn.Linear(embedding_dim, total_num, bias=False)
        self.device = device
        # Hyperparameter for the "Small-World" repulsion margin.
        self.sw_repulsion_margin = sw_repulsion_margin

    def calculate_small_world_repulsion_loss(self, encoder_output):
        """
        Calculates the "Small-World" repulsion loss on the encoder's output embeddings.

        Each sequence is mean-pooled into one vector. For every vector, the
        distance to its nearest *other* vector in the batch (the "hardest
        negative") is compared against the margin, and pairs that are too close
        are penalised with max(0, margin - distance).

        Args:
            encoder_output: Tensor of shape [batch_size, seq_len, embedding_dim].

        Returns:
            Scalar tensor: the mean hinge loss over the batch (0 when the batch
            has fewer than two sequences).
        """
        # 1. One vector per sequence by mean pooling.
        #    NOTE(review): padding positions are included in the mean; consider
        #    masking them once the encoder mask layout is confirmed.
        sentence_embeddings = encoder_output.mean(dim=1)

        batch_size = sentence_embeddings.size(0)
        # The loss is only meaningful if there are pairs to compare.
        if batch_size <= 1:
            return torch.tensor(0.0, device=encoder_output.device)

        # 2. All pairwise L2 distances: [batch_size, batch_size].
        pairwise_dist = torch.cdist(sentence_embeddings, sentence_embeddings, p=2)

        # 3. Exclude each embedding's distance to itself. Without this the row
        #    minimum is always the ~0 self-distance and the loss degenerates to
        #    the constant `margin`, giving no repulsion gradient.
        diagonal_mask = torch.eye(batch_size, device=pairwise_dist.device, dtype=torch.bool)
        pairwise_dist = pairwise_dist.masked_fill(diagonal_mask, float('inf'))

        # The hardest negative for each concept is the closest *other* concept.
        hardest_negative_dist = pairwise_dist.min(dim=1).values

        # 4. Hinge: only pairs closer than the margin contribute.
        loss = torch.clamp(self.sw_repulsion_margin - hardest_negative_dist, min=0)

        # Average over the batch.
        return loss.mean()

    def forward(self, questions, responses, mask_r, mask_q):
        """
        Returns:
            (logits, total_repulsion_loss), where logits is the decoder
            output projected by ln_0, and total_repulsion_loss is the encoder's
            repulsion loss plus the small-world repulsion loss.
        """
        # 1. Encoder: hidden states plus its own repulsion term.
        enc_output, explorer_repulsion_loss = self.encoder(questions, mask_q)

        # 2. Small-world repulsion on the encoder output.
        sw_repulsion_loss = self.calculate_small_world_repulsion_loss(enc_output)

        # 3. Decoder and projection to vocabulary logits.
        dec_output = self.decoder(enc_output, responses, mask_r)
        logits = self.ln_0(dec_output)

        # 4. Combine the two auxiliary losses.
        total_repulsion_loss = explorer_repulsion_loss + sw_repulsion_loss

        return logits, total_repulsion_loss

In [None]:
def collate_and_process_batch(batch_pairs):
    """
    Build padded, masked model inputs from a batch of (question, response) pairs.

    Each side is tokenized with `indexesFromSentence` (SOS/EOS added), wrapped
    as int64 tensors, padded with `pad_tensor`, and masked with `mask_tensor`
    using PAD_token. Relies on the notebook-level globals `tokenizer`,
    `SOS_token`, `EOS_token` and `PAD_token`.

    Returns:
        (padded_questions, question_lengths, padded_responses,
         response_lengths, mask_questions, mask_responses)
    """
    questions, responses = zip(*batch_pairs)

    def encode_side(sentences):
        # Sentence strings -> one LongTensor of token ids per sentence,
        # plus an int64 vector holding each sequence's length.
        encoded = []
        for text in sentences:
            token_ids = indexesFromSentence(tokenizer, text, SOS_token, EOS_token)
            encoded.append(torch.LongTensor(token_ids))
        lengths = torch.tensor([len(ids) for ids in encoded], dtype=torch.long)
        return encoded, lengths

    q_encoded, q_lengths = encode_side(questions)
    r_encoded, r_lengths = encode_side(responses)

    q_padded = pad_tensor(q_encoded, PAD_token)
    r_padded = pad_tensor(r_encoded, PAD_token)

    return (
        q_padded,
        q_lengths,
        r_padded,
        r_lengths,
        mask_tensor(q_padded, PAD_token),
        mask_tensor(r_padded, PAD_token),
    )

def save_checkpoint(directory, filename, model_state, optimizer_state, scheduler_state, scaler_state, stats):
    """
    Saves a comprehensive training checkpoint to `directory/filename`.

    The saved dict contains 'model_state_dict', 'optimizer_state_dict',
    'scheduler_state_dict', 'scaler_state_dict', and every key of `stats`
    merged in at the top level.

    Args:
        directory: Target directory; created (with parents) if missing. An
            empty string means the current working directory.
        filename: File name inside `directory`.
        model_state, optimizer_state, scheduler_state, scaler_state:
            State dicts to store.
        stats: Extra info to store. 'epoch', 'global_step' and
            'metrics' (a dict with optional 'total_loss' / 'accuracy') are
            used for the console summary when present.
    """
    # exist_ok avoids the check-then-create race; skip for '' (cwd), which
    # os.makedirs would reject.
    if directory:
        os.makedirs(directory, exist_ok=True)

    checkpoint_path = os.path.join(directory, filename)

    # Create a single dictionary to hold all necessary information for resuming training.
    checkpoint = {
        'model_state_dict': model_state,
        'optimizer_state_dict': optimizer_state,
        'scheduler_state_dict': scheduler_state,
        'scaler_state_dict': scaler_state,
        # Unpack the dictionary of metrics and other info.
        **stats
    }

    torch.save(checkpoint, checkpoint_path)

    # Missing metrics must not crash the run after the file is already written.
    metrics = stats.get('metrics') or {}
    print(f"\n--- Checkpoint Saved ---")
    print(f"Path: {checkpoint_path}")
    print(f"Epoch: {stats.get('epoch', 'N/A')}, Global Step: {stats.get('global_step', 'N/A')}")
    print(f"Metrics: Total Loss={metrics.get('total_loss', 0):.4f}, Accuracy={metrics.get('accuracy', 0):.2f}%")
    print(f"------------------------\n")

def training_goal_accumulate(cleopatra, criterion, padded_q, padded_r, mask_r, mask_q, device, temperature):
    """
    Runs one forward pass (teacher forcing) and computes the losses.

    The decoder input is `padded_r` without its last token, and the targets
    are `padded_r` shifted left by one. Logits are divided by `temperature`
    before the cross-entropy is computed.

    Args:
        cleopatra: Model called as cleopatra(questions, decoder_input,
            mask_r, mask_q) returning (logits, repulsion_loss).
        criterion: Loss applied to (flat_logits, flat_targets).
        padded_q, padded_r: Padded question / response token ids.
        mask_r, mask_q: Response / question masks.
        device: Training device (torch.device or string). All tensor inputs
            are moved there; mixed precision is enabled only for CUDA.
        temperature: Divisor applied to the logits.

    Returns:
        (main_loss, repulsion_loss, flat_logits, flat_targets). A non-scalar
        repulsion loss (e.g. one value per DataParallel replica) is averaged.
    """
    device = torch.device(device)

    def to_device(value):
        # Only tensors can be moved; leave anything else untouched.
        return value.to(device) if torch.is_tensor(value) else value

    # Every model input must live on the model's device, not just the targets;
    # otherwise the single-device (non-DataParallel) path fails.
    padded_q = to_device(padded_q)
    padded_r = to_device(padded_r)
    mask_q = to_device(mask_q)
    mask_r = to_device(mask_r)

    decoder_input = padded_r[:, :-1]
    mask_r = mask_r[:, :-1]
    targets = padded_r[:, 1:]

    # Autocast for the actual device type; mixed precision only on CUDA.
    with torch.amp.autocast(device_type=device.type, enabled=device.type == 'cuda'):
        # The model returns logits and the repulsion loss
        logits, repulsion_loss = cleopatra(padded_q, decoder_input, mask_r, mask_q)

        if repulsion_loss.dim() > 0:
            repulsion_loss = repulsion_loss.mean()

        logits = logits / temperature

        # Main cross-entropy loss over all (batch, position) pairs.
        flat_logits = logits.reshape(-1, logits.shape[-1])
        flat_targets = targets.reshape(-1)
        main_loss = criterion(flat_logits, flat_targets)

    # Return the main loss, repulsion loss, and tensors for accuracy calculation
    return main_loss, repulsion_loss, flat_logits, flat_targets

In [None]:
# --- Constants ---
# Not referenced by the training loop below; presumably consumed by the data
# preparation cells — TODO confirm.
MIN_COUNT: int = 1
max_length: int = 1000

# Model shape / regularisation.
embedding_dim: int = 256
layers: int = 3
dropout: float = 0.1

# Batching and optimisation.
batch: int = 10
epochs: int = 20
accumulation_steps: int = 8  # micro-batches accumulated per optimizer step

# Teacher forcing settings (not referenced by the training loop below).
TEACHER_FORCING_RATIO: float = 1.0
TEACHER_FORCING_DECAY: float = 0.01

# When True, main_worker resumes model/optimizer state from CHECKPOINT_PATH.
checkpoint: bool = True
CHECKPOINT_PATH: str = "/kaggle/input/cleoprototype100k/pytorch/default/10/CleoPrototype_E1_S20000.pt"

"""
Key Metrics for Training:

Loss: The primary measure of how well the model is performing. A lower score is better.

Perplexity: Measures how "surprised" the model is by the next token. It's derived from the cross-entropy loss. Lower is better.

Accuracy: The percentage of times the model's top prediction is correct. More intuitive than loss.

GradNorm: The overall size of the gradients. Helps diagnose exploding gradients (if the norm is very large) or vanishing gradients (if the norm is close to zero).

LearningRate: Tracking the learning rate ensures the scheduler is working as intended.
"""

def _strip_wrapper_prefixes(state_dict):
    """Return a copy of `state_dict` with leading 'module.' (DataParallel) and
    '_orig_mod.' (torch.compile) key prefixes removed, in any nesting order."""
    wrapper_prefixes = ('module.', '_orig_mod.')
    cleaned = {}
    for key, value in state_dict.items():
        while key.startswith(wrapper_prefixes):
            key = key.split('.', 1)[1]
        cleaned[key] = value
    return cleaned


def _annealed_temperature(step, start_step, end_step, initial_temp, final_temp):
    """Linearly anneal from `initial_temp` (before `start_step`) to `final_temp` (from `end_step` on)."""
    if step < start_step:
        return initial_temp
    if step >= end_step:
        return final_temp
    progress = (step - start_step) / (end_step - start_step)
    return initial_temp - progress * (initial_temp - final_temp)


def _token_accuracy(flat_logits, flat_targets, pad_token):
    """Percentage of non-pad targets whose argmax prediction is correct (0.0 if there are none)."""
    preds = torch.argmax(flat_logits.detach(), dim=1)
    non_pad = flat_targets != pad_token
    total_non_pad = non_pad.sum().item()
    if total_non_pad == 0:
        return 0.0
    correct = (preds[non_pad] == flat_targets[non_pad]).sum().item()
    return (correct / total_non_pad) * 100


def _move_to(value, device):
    """Move `value` to `device` if it is a tensor; return anything else unchanged."""
    return value.to(device) if torch.is_tensor(value) else value


def main_worker(epochs, data_loader, SOS_token, EOS_token, PAD_token, checkpoint, num_warmup_steps, num_training_steps):
    """
    Builds the Cleopatra model, optionally restores model/optimizer state from
    CHECKPOINT_PATH, and trains with gradient accumulation, AMP loss scaling,
    a linear warmup/decay LR schedule and logit-temperature annealing.

    Relies on notebook globals: total_tokens, embedding_dim, dropout, layers,
    accumulation_steps, tokenizer, CHECKPOINT_PATH and the model classes.

    NOTE(review): the checkpoint also contains scheduler/scaler states, but
    they are not restored here, so the LR warmup restarts when resuming.
    """
    embedding = Embedding(total_tokens, embedding_dim)
    pos_encoding = PositionalEncoding(embedding_dim, dropout)
    encoder = Encoder(embedding_dim, embedding, pos_encoding, layers)
    decoder = Decoder(embedding_dim, embedding, pos_encoding, layers)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # Keep a handle on the unwrapped model: DataParallel/torch.compile wrap it
    # below, and checkpoints must be saved from the plain module.
    base_model = Cleopatra(decoder, encoder, embedding_dim, total_tokens, device)
    # Tie the output projection to the input embedding matrix.
    base_model.ln_0.weight = embedding.embedding.weight

    checkpoint_data = None
    if checkpoint:
        print("Loading checkpoint...")
        # weights_only=False unpickles arbitrary objects (the tokenizer is
        # stored in the file): only load checkpoints from trusted sources.
        checkpoint_data = torch.load(CHECKPOINT_PATH, weights_only=False, map_location=device)

        model_state_dict = checkpoint_data['model_state_dict']
        if any(key.startswith('module.') for key in model_state_dict):
            print("Model was trained with DataParallel. Removing 'module.' prefix from keys.")
        base_model.load_state_dict(_strip_wrapper_prefixes(model_state_dict))
        print("Model loaded successfully!")

    base_model.to(device)

    cleopatra = base_model
    if torch.cuda.device_count() > 1:
        print(f"Let's use {torch.cuda.device_count()} GPUs!")
        cleopatra = nn.DataParallel(cleopatra)

    cleopatra = torch.compile(cleopatra)

    criterion = nn.CrossEntropyLoss(ignore_index=PAD_token)
    optimizer = torch.optim.AdamW(cleopatra.parameters(), lr=0.0001)

    if checkpoint_data is not None:
        optimizer.load_state_dict(checkpoint_data['optimizer_state_dict'])

    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=num_training_steps
    )

    print(f"Total training steps: {num_training_steps}")
    print(f"Warmup steps: {num_warmup_steps}")

    scaler = torch.amp.GradScaler('cuda')

    INITIAL_TEMP = 4.0
    FINAL_TEMP = 1.0
    # Expressed in optimizer steps, counted across all epochs.
    ANNEAL_START_STEP = 5000
    ANNEAL_END_STEP = 25000

    # Balances the main task against the diversity (repulsion) objective.
    repulsion_weight = 0.01

    # Optimizer steps taken so far across ALL epochs. The per-epoch index `i`
    # restarts every epoch, so it cannot drive the annealing schedule.
    global_step = 0

    print("\nStarting training...")
    for epoch in range(epochs):
        # Discard any partial accumulation window left over from the previous
        # epoch so it does not leak into this epoch's first optimizer step.
        optimizer.zero_grad()

        for i, (padded_q, len_q, padded_r, len_r, mask_q, mask_r) in enumerate(data_loader):
            # Inputs must be on the model's device (required without DataParallel).
            padded_q = _move_to(padded_q, device)
            padded_r = _move_to(padded_r, device)
            mask_q = _move_to(mask_q, device)
            mask_r = _move_to(mask_r, device)

            current_temp = _annealed_temperature(
                global_step, ANNEAL_START_STEP, ANNEAL_END_STEP, INITIAL_TEMP, FINAL_TEMP
            )

            main_loss, repulsion_loss, flat_logits, flat_targets = training_goal_accumulate(
                cleopatra, criterion, padded_q, padded_r, mask_r, mask_q, device, current_temp
            )

            # Combine the losses, then normalise for gradient accumulation.
            total_loss = main_loss + repulsion_weight * repulsion_loss
            total_loss = total_loss / accumulation_steps
            scaler.scale(total_loss).backward()

            # --- Optimizer step once per accumulation window ---
            if (i + 1) % accumulation_steps == 0:
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
                scheduler.step()
                global_step += 1

            # --- Logging (every micro-batch) ---
            perplexity = torch.exp(main_loss).item()
            accuracy = _token_accuracy(flat_logits, flat_targets, PAD_token)
            learning_rate = optimizer.param_groups[0]['lr']

            print(f"Epoch {epoch+1} - Step {(i + 1) // accumulation_steps}/{len(data_loader) // accumulation_steps}")
            print(f"  - Total Loss    : {total_loss.item():.4f}")
            print(f"  - Main Loss     : {main_loss.item():.4f}")
            print(f"  - Repulsion Loss: {repulsion_loss.item():.4f}")
            print(f"  - Perplexity (PPL): {perplexity:.2f}")
            print(f"  - Accuracy      : {accuracy:.2f}%")
            print(f"  - Learning Rate : {learning_rate:.9f}")
            print("-" * 20)

            # --- Periodic checkpoint ---
            if i != 0 and i % 20000 == 0:
                stats = {
                    'epoch': epoch + 1,
                    'global_step': global_step,
                    'tokenizer': tokenizer,
                    'metrics': {
                        'total_loss': total_loss.item() * accumulation_steps,  # Un-normalize for logging.
                        'main_loss': main_loss.item(),
                        'repulsion_loss': repulsion_loss.item(),
                        'perplexity': perplexity,
                        'accuracy': accuracy,
                        'learning_rate': learning_rate,
                        'temperature': current_temp
                    }
                }

                save_checkpoint(
                    directory="/kaggle/working/",
                    filename=f"CleoPrototype_E{epoch+1}_S{global_step}.pt",
                    # Unwrapped model: works with or without DataParallel and
                    # yields prefix-free keys.
                    model_state=base_model.state_dict(),
                    optimizer_state=optimizer.state_dict(),
                    scheduler_state=scheduler.state_dict(),
                    scaler_state=scaler.state_dict(),
                    stats=stats
                )

            del padded_q, padded_r, len_q, len_r, mask_q, mask_r, flat_logits, flat_targets
            torch.cuda.empty_cache()


# --- Main Execution ---
if __name__ == '__main__':

    # --- 1. Load data and initialize tokenizer ---
    file_path = "/kaggle/input/generic/generic_data.json"
    token_path = "/kaggle/input/sub-tokens/my_tokenizer.json"

    # Initialize vocabulary.
    print("Building vocabulary...")
    pairs, tokenizer, total_tokens = load_and_prepare_data(file_path, token_path)

    SOS_token = tokenizer.token_to_id("[SOS]")
    EOS_token = tokenizer.token_to_id("[EOS]")
    PAD_token = tokenizer.token_to_id("[PAD]")

    chat_dataset = ChatDataset(pairs)

    bucket_size = batch * 100

    # Create the bucket sampler instance.
    bucket_sampler = BucketRandomSampler(
        data_source=chat_dataset,
        bucket_size=bucket_size,
        batch_size=batch
    )

    data_loader = torch.utils.data.DataLoader(
        dataset=chat_dataset,
        batch_size=batch,
        sampler=bucket_sampler,
        collate_fn=collate_and_process_batch,
        # shuffle should not be set when using a custom sampler.
    )

    print(f"\nData loaded into DataLoader with {len(data_loader)} batches.")

    # main_worker calls scheduler.step() once per optimizer step, i.e. once
    # every `accumulation_steps` micro-batches. The schedule length must
    # therefore be counted in optimizer steps; counting micro-batches
    # stretches the linear decay far beyond the actual run.
    num_training_steps = epochs * (len(data_loader) // accumulation_steps)
    num_warmup_steps = 4000

    main_worker(epochs, data_loader, SOS_token, EOS_token, PAD_token, checkpoint, num_warmup_steps, num_training_steps)