In [1]:
# Set environment variable BEFORE importing torch
import os
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'


In [2]:
from huggingface_hub import login
login(token="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")

## Loading Dataset and creating Tokenizer

In [3]:
import torch
from datasets import load_dataset
from torch.utils.data import DataLoader, Dataset
import json
# Load the dataset
dataset = load_dataset("lmsys/lmsys-chat-1m")
dataset = dataset.filter(lambda x: x['language'] == 'English')
# Create or load a tokenizer
# For this example, we'll use an existing tokenizer
from transformers import PreTrainedTokenizerFast
tokenizer = PreTrainedTokenizerFast.from_pretrained("gpt2")

README.md:   0%|          | 0.00/8.88k [00:00<?, ?B/s]

(‚Ä¶)-00000-of-00006-4feeb3f83346a0e9.parquet:   0%|          | 0.00/249M [00:00<?, ?B/s]

(‚Ä¶)-00001-of-00006-4030672591c2f478.parquet:   0%|          | 0.00/247M [00:00<?, ?B/s]

(‚Ä¶)-00002-of-00006-1779b7cec9462180.parquet:   0%|          | 0.00/250M [00:00<?, ?B/s]

(‚Ä¶)-00003-of-00006-2fa862bfed56af1f.parquet:   0%|          | 0.00/247M [00:00<?, ?B/s]

(‚Ä¶)-00004-of-00006-18f4bdd50c103e71.parquet:   0%|          | 0.00/246M [00:00<?, ?B/s]

(‚Ä¶)-00005-of-00006-fe1acc5d10a9f0e2.parquet:   0%|          | 0.00/249M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1000000 [00:00<?, ? examples/s]

Filter:   0%|          | 0/1000000 [00:00<?, ? examples/s]

tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'GPT2Tokenizer'. 
The class this function is called from is 'PreTrainedTokenizerFast'.


In [4]:
dataset # there is 77.7% english so there is 1M rows so english is 777k rows

DatasetDict({
    train: Dataset({
        features: ['conversation_id', 'model', 'conversation', 'turn', 'language', 'openai_moderation', 'redacted'],
        num_rows: 777453
    })
})

## Input Embeddings

In [5]:
import torch.nn as nn
import math

class InputEmbeddings(nn.Module):
    def __init__(self, d_model: int, vocab_size: int) -> None:  # Fixed: d_mdoel -> d_model
        super().__init__()
        self.d_model = d_model
        self.vocab_size = vocab_size
        self.embedding = nn.Embedding(vocab_size, d_model)

    def forward(self, x):
        # Fix 2: Remove aggressive scaling that causes NaN
        return self.embedding(x) * math.sqrt(self.d_model)


## Positional Encoding

In [6]:
import torch.nn as nn
import torch
import torch.nn.functional as F
import math
import numpy as np

class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, seq_length: int, dropout: float) -> None:
        super().__init__()  # Fixed: super().__init__... () -> super().__init__()
        self.d_model = d_model
        self.seq_length = seq_length
        self.dropout = nn.Dropout(dropout)

        pe = torch.zeros(seq_length, d_model)
        position = torch.arange(0, seq_length, dtype=torch.float).unsqueeze(1)

        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))

        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:, :x.shape[1], :].requires_grad_(False)
        return self.dropout(x)


## Here is the code for RoPE(Roatary Positional Encoding) 
**Use it only when not using PE functions, it's more efficient than PE**

In [7]:
class RotaryPositionalEmbedding(nn.Module):

    def __init__(self, d_model, base = 10000) -> None:
        super().__init__()
        inv_freq = 1. / (base ** (torch.arange(0,d_model,2).float() / d_model)) # this is out theta(i)
        self.register_buffer('inv_freq', inv_freq)
        self.seq_len_cached = None
        self.cos_cached = None
        self.sin_cached = None

    def forward(self,x, seq_dim = 2): # seq_dim=2 for [batch, heads, seq_len, head_dim]
        seq_len = x.shape[seq_dim]
        if seq_len != self.seq_len_cached:
            t = torch.arange(seq_len,device=x.device).type_as(self.inv_freq) # position_indices[0 --> seq_len-1]
            freqs = torch.einsum('i,j -> ij',t,self.inv_freq) # t ‚äó inv_freq (outer product)
            emb = torch.cat((freqs,freqs),dim= -1).to(x.device) # creates [cos,sin,cos,sin] pattern, more importantly we are repeating cause it doesn't dimenstion mismatch at the broadcasting time
            self.cos_cached = emb.cos()[None,None,:,:] # [1, 1, seq_len, head_dim]
            self.sin_cached = emb.sin()[None,None,:,:]
        return self.cos_cached , self.sin_cached


### Helper Functions for RoPE

In [8]:
def rotate_half(x):
        # it's a 90¬∞ rotation , if you think x as a complex number input then , x-->  a+bi then after 90¬∞ rotation it will be -b+ai
        x1, x2 = x[...,:x.shape[-1]//2], x[..., x.shape[-1]//2:]
        return torch.cat((-x2, x1), dim=-1)

@torch.jit.script
def apply_rotary_pos_emb(x,cos,sin):
        #it applies Euler formula : e^(iŒ∏) = cos(Œ∏) + i¬∑sin(Œ∏) that causes (q * cos) + (rotate_half(q) * sin) is implementing: q¬∑cos(Œ∏) + i¬∑q¬∑sin(Œ∏)
        # rotate_half is for to make the q&k (iota)imaginary part
        return (x * cos) + (rotate_half(x) * sin)


## GQA(Grouped Query Attention) --> More efficient than MHA

In [9]:
import torch.nn as nn
import torch
import math
import torch.nn.functional as F
class GroupedQueryAttention(nn.Module):
    """
        Grouped Query Attention

        Args:
            d_model: Embedding dimension
            num_query_heads: Number of query heads
            num_kv_heads: Number of key-value heads (must divide num_query_heads)
            dropout: Dropout probability
            bias: Whether to use bias in linear projections
            rope_percentage: Decides what percentage of embeddings will be used for rope
        """
    def __init__(self, d_model : int , num_query_heads : int, num_kv_heads : int, dropout = 0.1, bias = False, rope_percentage = 0.5) -> None:
        super().__init__()

        assert d_model % num_query_heads == 0, "d_model must be divisible by num_query_heads"
        assert num_query_heads % num_kv_heads == 0, "num_query_heads must be divisible by num_kv_heads"

        self.d_model = d_model
        self.num_q_head = num_query_heads
        self.num_kv_head = num_kv_heads
        # per head dim
        self.head_dim = d_model // num_query_heads
        #how many query heads share a single KV head
        self.group_size = num_query_heads // num_kv_heads

        #rope initialization
        self.rope_percentage = rope_percentage
        self.rope_dim = int(self.head_dim * rope_percentage)
        if self.rope_dim > 0:
            self.rotary_pe = RotaryPositionalEmbedding(self.rope_dim)

        #Linear projections
        self.q_proj = nn.Linear(d_model,d_model,bias=bias)
        self.k_proj = nn.Linear(d_model,self.num_kv_head * self.head_dim,bias=bias)
        self.v_proj = nn.Linear(d_model,self.num_kv_head * self.head_dim, bias=bias)
        self.out_proj = nn.Linear(d_model,d_model,bias=bias)
        self.dropout = nn.Dropout(dropout) # prevent overfitting
        self.scale = 1.0 / math.sqrt(self.head_dim)

    def forward(self,query, key = None, value = None, attn_mask = None,is_causal = False, need_weigths = False,cache = None):
        if key is None:
            key = query
        if value is None:
            value = key

        batch_size = query.shape[0]
        seq_len = query.shape[1]
        kv_seq_length = key.shape[1]

        #project queries , keys, values
        q = self.q_proj(query) #[batch, seq_len, d_model]
        k = self.k_proj(key) #[batch, kv_seq_len , num_kv_heads * head dim]
        v = self.v_proj(value) # [batch, kv_seq_len, num_kv_heads* head_dim]

        # Reshape and transpose for mha
        q = q.view(batch_size,seq_len,self.num_q_head,self.head_dim).transpose(1,2) # [batch, num_query_heads, seq_len, head_dim]
        k = k.view(batch_size, kv_seq_length,self.num_kv_head,self.head_dim).transpose(1,2) # [batch, num_kv_head, kv_seq_length, head_dim]
        v = v.view(batch_size, kv_seq_length, self.num_kv_head, self.head_dim).transpose(1,2) #[batch, num_kv_head, kv_seq_length, head_dim]

    # ============ If you are going to use PE then don't use RoPE and vice_versa ========================

        # Applying Rope
        if self.rope_dim > 0:
            #split into RoPE and non-RoPE parts
            q_rope, q_pass = q[...,:self.rope_dim], q[...,self.rope_dim:]
            k_rope , k_pass = k[...,:self.rope_dim] , k[...,self.rope_dim:]

            # Apply rotary embeddings to queries
            cos_q , sin_q = self.rotary_pe(q_rope)
            q_rope = apply_rotary_pos_emb(q_rope,cos_q,sin_q)

            #Apply rotary embeddings to keys
            cos_k , sin_k = self.rotary_pe(k_rope)
            k_rope = apply_rotary_pos_emb(k_rope,cos_k,sin_k)

            # concatenate back
            q = torch.cat([q_rope,q_pass],dim=-1)
            k = torch.cat([k_rope, k_pass],dim=-1)

    #=======================================================================================================

        #Expand keys and values to match query heads
        # Each group of query heads shares the same kv heads
        #after learning the learned matrices of key is copied into (k_head * group_size) total
        k_expanded = k.repeat_interleave(self.group_size, dim =1) # [batch, num_query_heads, kv_seq_len, head_dim]
        v_expanded = v.repeat_interleave(self.group_size,dim=1)  # [batch, num_query_heads, kv_seq_len, head_dim]

        # KV caching
        if cache is not None:
            past_key , past_value = cache
            k_expanded = torch.cat((past_key,k_expanded),dim=2)
            v_expanded = torch.cat((past_value,v_expanded),dim=2)
        present_kv = (k_expanded,v_expanded)

        # compute attention scores
        # query : seq_len, head_dim * key: head_dim ,kv_seq_len
        attn_scores = torch.matmul(q,k_expanded.transpose(-2,-1)) * self.scale # [batch, num_query_heads, seq_len, kv_seq_len]

        # Apply masks
        if is_causal:
            causal_mask = torch.tril(torch.ones(seq_len,kv_seq_length,device=q.device,dtype=torch.bool))
            attn_scores = attn_scores.masked_fill(~causal_mask,float('-inf')) # inverse the causal mask and where is true replace that with -infinity

        if attn_mask is not None:
            if attn_mask.dim() == 2:
                attn_mask = attn_mask.unsqueeze(0).unsqueeze(0)
            elif attn_mask.dim() == 3:
                attn_mask = attn_mask.unsqueeze(1)
            attn_mask = attn_mask.to(dtype=torch.bool)
            attn_scores = attn_scores.masked_fill(~attn_mask, float('-inf'))

        # compute attention probabilites
        attn_probs = F.softmax(attn_scores ,dim=-1)
        attn_probs = self.dropout(attn_probs)

        #Apply attention to values
        attn_output = torch.matmul(attn_probs, v_expanded) ## [batch, num_query_heads, seq_len, head_dim]

        # Concatenate heads
        attn_output = attn_output.transpose(1,2).contiguous() # [batch, seq_len, num_query_heads, head_dim]

        attn_output = attn_output.view(batch_size,seq_len,self.d_model) # [batch, seq_len, embed_dim]

        #Final output projection
        output = self.out_proj(attn_output)

        if need_weigths:
            #Average attention weights across heads for visualization
            attn_weights = attn_probs.mean(dim=1) # [batch, seq_len, kv_seq_len]
            return output, attn_weights, present_kv
        else: return output , present_kv


## Feed Forward Layer

In [10]:

class FeedForwardBlock(nn.Module):

    def __init__(self, d_model : int, d_ff : int, dropout : float) -> None:
        super().__init__()
        self.activation = nn.GELU()
        # First layer tranformation
        self.linear1 = nn.Linear(d_model,d_ff) # w1 & b1
        self.dropout = nn.Dropout(dropout) # prevent overfitting

        #Sceond layer transformation
        self.linear2 = nn.Linear(d_ff, d_model) # w2 & b2

    def forward(self,x):
        # d_model --> dff --> d_model
        return self.linear2(self.dropout(self.activation(self.linear1(x))))

# Layer Norm --> Pre-Norm 
**But in research paper was post norm, generally pre-norm is efficient than post-norm. Implemented in real-world LLM's**

In [11]:
import torch
import torch.nn as nn

class LayerNormalization(nn.Module):
    def __init__(self, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.alpha = nn.Parameter(torch.ones(1))  # multiplicative parameter
        self.bias = nn.Parameter(torch.zeros(1))  # additive parameter

    def forward(self, x):
        # x: (batch, seq_len, d_model)
        # Keep the dimension for broadcasting
        mean = x.mean(dim=-1, keepdim=True)  # (batch, seq_len, 1)
        # Keep the dimension for broadcasting
        std = x.std(dim=-1, keepdim=True)  # (batch, seq_len, 1)
        # eps is to prevent dividing by zero or when std is very small
        return self.alpha * (x - mean) / (std + self.eps) + self.bias

## RMS(Root Mean Squared) Norm 
**Better than LayerNorm**

In [12]:
class RMSNorm(nn.Module):
    def __init__(self,dim: int = 768, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.dim = dim
        # The learnable scaling parameter, with a size of the feature dimension
        self.gamma = nn.Parameter(torch.ones(self.dim))

    def _norm(self, x):
        # Calculate the reciprocal of the square root for efficiency
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        # Add input validation
        if isinstance(x, tuple):
            # If input is tuple, use only the first element (the actual tensor)
            x = x[0]
            print("Warning: RMSNorm received tuple input, using first element")

        # Ensure x is a tensor
        if not isinstance(x, torch.Tensor):
            raise TypeError(f"RMSNorm expected tensor input, got {type(x)}")

        # Normalize and then scale
        return self.gamma * self._norm(x.float()).type_as(x)

## This Layer is the output of decoder and coverting them into probs

In [13]:
class ProjectionLayer(nn.Module):
    # projection layer is the output of ffn from decoder and the applied on liner,softmax layer
    def __init__(self, d_model : int , vacab_size: int) -> None:
        super().__init__()
        self.proj = nn.Linear(d_model,vacab_size) # Linear layer

    def forward(self, x):
        return torch.log_softmax(self.proj(x), dim = -1)  # Applying the log Softmax function to the output


## Residual Connection

In [14]:
class ResidualConnection(nn.Module):
    def __init__(self, dropout: float):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.norm = RMSNorm()

    def forward(self, x, sublayer):
        # Ensure x is a tensor
        if isinstance(x, tuple):
            x = x[0]

        # Apply normalization
        normed_x = self.norm(x)

        # Apply sublayer
        sublayer_output = sublayer(normed_x)

        # Handle both cached and non-cached sublayer outputs
        if isinstance(sublayer_output, tuple):
            # Sublayer returned (output, cache)
            output_tensor, cache = sublayer_output
            residual_output = x + self.dropout(output_tensor)
            return residual_output, cache  # Return tuple
        else:
            # Sublayer returned only output tensor
            output_tensor = sublayer_output
            residual_output = x + self.dropout(output_tensor)
            return residual_output  # Return tensor only



## Decoder Block
**Decoder Block has masked-self-attention another is cross attention , one is feed-forward block**

In [15]:
class DecoderBlock(nn.Module):
    def __init__(self, masked_attention_block, feed_forward_block, dropout):
        super().__init__()
        self.masked_attention = masked_attention_block
        self.feed_forward = feed_forward_block
        self.residual_connection = nn.ModuleList([ResidualConnection(dropout) for _ in range(2)])

    def forward(self, x, tgt_mask, cache=None):  # No encoder params
        # Self-attention
        x, self_attn_cache = self.residual_connection[0](
            x,
            lambda x: self.masked_attention(
                query=x, key=x, value=x,
                attn_mask=tgt_mask,
                is_causal=True,
                cache=cache
            )
        )

        # Feed-forward
        x = self.residual_connection[1](x, self.feed_forward)

        return x, self_attn_cache

## Decoder
**A deocoder can have multiple decoder_blocks**

In [16]:
class Decoder(nn.Module):
    def __init__(self, layers: nn.ModuleList) -> None:
        super().__init__()
        self.layers = layers
        self.norm = RMSNorm()

    def forward(self, x, tgt_mask, layer_caches=None):  # No encoder params
        new_layer_caches = []
        for i, layer in enumerate(self.layers):
            layer_cache = None if layer_caches is None else layer_caches[i]
            x, new_cache = layer(x, tgt_mask, layer_cache)  # Assumes updated DecoderBlock
            new_layer_caches.append(new_cache)
        return self.norm(x), new_layer_caches



In [17]:
class Encoder(nn.Module):
    """An Encoder can have several Encoder Blocks"""

    def __init__(self,layers: nn.ModuleList) -> None:
        self.layers = layers # storing the EncoderBlocks
        self.norm = RMSNorm()

    def forward(self,x,mask):
        #Iterating over each EncoderBlock stored in self.layers
        for layer in self.layers:
            x = layer(x,mask) # Applying each EncoderBlock to the input tensor 'x'
        return self.norm(x) # normalizing after encoder operation, it's not in paper but in now a days it done for better training and stbility


In [18]:
from typing import Optional

class Transformer(nn.Module):
    """This takes in the encoder and decoder, as well as the embeddings for the source
    and target language. It also takes in the positional encoding for the source and target language,
    as well as projection layer."""

    def __init__(self,
                 encoder: Optional[Encoder] = None,
                 decoder: Optional[Decoder] = None,
                 src_embed: Optional[InputEmbeddings] = None,
                 tgt_embed: Optional[InputEmbeddings] = None,
                 src_pos: Optional[PositionalEncoding] = None,
                 tgt_pos: Optional[PositionalEncoding] = None,
                 projection_layer: Optional[ProjectionLayer] = None,
                 use_rope: bool = True) -> None:
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.src_pos = src_pos
        self.tgt_pos = tgt_pos
        self.projection_layer = projection_layer
        self.use_rope = use_rope

        # Validate configuration
        if self.use_rope and (self.src_pos is not None or self.tgt_pos is not None):
            print("Warning: Using RoPE with separate positional encodings. "
                  "Consider setting src_pos=None, tgt_pos=None for pure RoPE.")
        
        if self.decoder is None:
            raise ValueError("Decoder must be provided.")
        if self.tgt_embed is None:
            raise ValueError("Target embeddings (tgt_embed) must be provided.")
        if self.projection_layer is None:
            raise ValueError("Projection layer must be provided.")

    # Encoder (only used if encoder is provided)
    def encode(self, src, src_mask):
        if self.encoder is None:
            raise ValueError("Encoder is not initialized. This is a decoder-only model.")
        if self.src_embed is None or (not self.use_rope and self.src_pos is None):
            raise ValueError("Source embedding or positional encoding components are missing.")
        
        src = self.src_embed(src)  # Applying source embeddings
        if not self.use_rope and self.src_pos is not None:
            src = self.src_pos(src)  # Applying source positional encoding
        return self.encoder(src, src_mask)  # Encoder forward

    # Decoder (handles both decoder-only and encoder-decoder modes)
    def decode(self, tgt, tgt_mask, layer_caches=None):
        """
        Decoder-only decode method: No encoder parameters needed.
        - tgt: Raw token IDs [batch_size, seq_len]
        - tgt_mask: Attention mask [batch_size, seq_len, seq_len]
        - layer_caches: Optional KV cache for generation
        """
        # Embed the raw token IDs
        tgt = self.tgt_embed(tgt)
        
        # Apply positional encoding (if not using RoPE)
        if self.tgt_pos is not None:
            tgt = self.tgt_pos(tgt)
        
        # Pass through decoder (self-attention only)
        output, new_caches = self.decoder(tgt, tgt_mask, layer_caches)  # Assumes updated Decoder.forward
        
        return output, new_caches

    # Applying projection layer with the Softmax Function to the decoder output
    def project(self, x):
        if self.projection_layer is None:
            raise ValueError("Projection layer is not initialized.")
        return self.projection_layer(x)



In [19]:
def build_transformer(src_vocab_size: int, tgt_vocab_size: int,
                           src_seq_len: int, tgt_seq_len: int,
                           d_model: int = 512, N: int = 6, h: int = 8,
                           kv_h: int = 4, dropout: float = 0.1, d_ff: int = 2048):

    # Create embedding layers
    tgt_embed = InputEmbeddings(d_model, tgt_vocab_size)

    # Create positional encoding
    #tgt_pos = PositionalEncoding(d_model, tgt_seq_len, dropout)

    # Create decoder blocks
    decoder_blocks = []
    for _ in range(N):
        decoder_self_attention_block = GroupedQueryAttention(d_model, h, kv_h, dropout)
        feed_forward_block = FeedForwardBlock(d_model, d_ff, dropout)

        decoder_block = DecoderBlock(decoder_self_attention_block,
                                   feed_forward_block, dropout)
        decoder_blocks.append(decoder_block)

    decoder = Decoder(nn.ModuleList(decoder_blocks))
    projection_layer = ProjectionLayer(d_model, tgt_vocab_size)
                                                                   #tgt_pos --> uisng RoPE
    transformer = Transformer(None, decoder, None, tgt_embed, None, None, projection_layer)


    for name, p in transformer.named_parameters():
        if p.dim() > 1:
            if 'embedding' in name:
                # Use smaller initialization for embeddings
                nn.init.normal_(p, mean=0.0, std=0.02)
            else:
                nn.init.xavier_uniform_(p, gain=1.0)
        else:
            nn.init.zeros_(p)

    return transformer

In [20]:
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

In [21]:
def save_checkpoint(model, optimizer, epoch, global_step, current_loss, best_loss,
                   checkpoint_dir, filename):
    """Save training checkpoint"""
    checkpoint = {
        'epoch': epoch,
        'global_step': global_step,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'current_loss': current_loss,
        'best_loss': best_loss,
        'timestamp': datetime.now().isoformat(),
        'training_args': {
            'lr': optimizer.param_groups[0]['lr'],
            'weight_decay': optimizer.param_groups[0]['weight_decay'],
        }
    }

    checkpoint_path = os.path.join(checkpoint_dir, filename)
    torch.save(checkpoint, checkpoint_path)
    print(f"üíæ Checkpoint saved to: {checkpoint_path}")

    # Clean up old auto-checkpoints (keep only last 3)
    if "auto_checkpoint" in filename:
        cleanup_old_checkpoints(checkpoint_dir, keep_last=3)

def cleanup_old_checkpoints(checkpoint_dir, keep_last=3):
    """Remove old auto-checkpoints, keeping only the most recent ones"""
    auto_checkpoints = []

    for filename in os.listdir(checkpoint_dir):
        if filename.startswith("auto_checkpoint") and filename.endswith(".pt"):
            filepath = os.path.join(checkpoint_dir, filename)
            auto_checkpoints.append((filepath, os.path.getmtime(filepath)))

    # Sort by modification time (newest first)
    auto_checkpoints.sort(key=lambda x: x[1], reverse=True)

    # Remove old checkpoints
    for filepath, _ in auto_checkpoints[keep_last:]:
        try:
            os.remove(filepath)
            print(f"üóëÔ∏è Removed old checkpoint: {os.path.basename(filepath)}")
        except OSError:
            pass

In [22]:
# Debug exact values
print(f"Tokenizer length: {len(tokenizer)}")
print(f"Tokenizer vocab_size: {tokenizer.vocab_size}")


Tokenizer length: 50257
Tokenizer vocab_size: 50257


In [23]:
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import Whitespace
from transformers import PreTrainedTokenizerFast

def create_32k_tokenizer(dataset, vocab_size=32000):
    """Create a custom 32K BPE tokenizer"""

    # Initialize tokenizer
    tokenizer = Tokenizer(BPE(unk_token="<unk>"))
    tokenizer.pre_tokenizer = Whitespace()

    # Setup trainer
    trainer = BpeTrainer(
        vocab_size=vocab_size,
        special_tokens=["<pad>", "<unk>", "<bos>", "<eos>"]
    )

    # Prepare training data
    def get_training_corpus():
        for item in dataset["train"]:
            conversation = item['conversation']
            for turn in conversation:
                yield turn['content']

    # Train tokenizer
    tokenizer.train_from_iterator(get_training_corpus(), trainer)

    # Convert to HuggingFace tokenizer
    hf_tokenizer = PreTrainedTokenizerFast(tokenizer_object=tokenizer)
    hf_tokenizer.pad_token = "<pad>"
    hf_tokenizer.eos_token = "<eos>"
    hf_tokenizer.bos_token = "<bos>"
    hf_tokenizer.unk_token = "<unk>"

    return hf_tokenizer

# Create custom tokenizer
tokenizer = create_32k_tokenizer(dataset, vocab_size=32000)







In [24]:
# ‚úÖ Verify the changes worked
def verify_32k_setup():
    print("=== 32K Vocabulary Setup Verification ===")
    print(f"Tokenizer vocab size: {len(tokenizer)}")

# Run verification
verify_32k_setup()


=== 32K Vocabulary Setup Verification ===
Tokenizer vocab size: 32000


In [25]:
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
os.environ['TORCH_USE_CUDA_DSA'] = '1'
torch.cuda.empty_cache()


In [26]:

tokenizer.pad_token = tokenizer.eos_token
# Add special tokens including pad token
special_tokens = {
    'pad_token': '[PAD]',
    'additional_special_tokens': ["<user>", "<assistant>"]
}
num_added = tokenizer.add_special_tokens(special_tokens)
print(f"Pad token: {tokenizer.pad_token}, ID: {tokenizer.pad_token_id}")
print(f"Added {num_added} special tokens")

Pad token: [PAD], ID: 32000
Added 3 special tokens


In [27]:
# Dataset class definition
class ConversationDataset(Dataset):
    def __init__(self, dataset, tokenizer, max_length=2048):
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        # Get conversation
        conversation = self.dataset[idx]['conversation']

        # Format conversation
        formatted_text = ""
        for turn in conversation:
            if turn["role"] == "user":
                formatted_text += f"<user> {turn['content']} "
            elif turn["role"] == "assistant":
                formatted_text += f"<assistant> {turn['content']} "

        # Tokenize
        encodings = self.tokenizer(
            formatted_text,
            max_length=self.max_length,
            truncation=True,
            padding="max_length",
            return_tensors="pt"
        )

        input_ids = encodings.input_ids[0]
        attention_mask = encodings.attention_mask[0]
        labels = input_ids.clone()

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels
        }

In [28]:
# train function
import time
import os
import torch
from torch.amp import autocast, GradScaler
from torch.utils.data import DataLoader
from datetime import datetime

def train(model, dataset, tokenizer, device="cuda", epochs=3, batch_size=8, lr=1e-4,
          checkpoint_dir="checkpoints", resume_from_checkpoint=None, use_mixed_precision=True):

    # Create checkpoint directory
    os.makedirs(checkpoint_dir, exist_ok=True)

    print(f"Preparing training with mixed precision: {use_mixed_precision}")
    train_dataset = ConversationDataset(
        dataset["train"].select(range(777453)),
        tokenizer,
        max_length=2048
    )
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    # Initialize optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01, betas=(0.9, 0.98))

    # Initialize GradScaler for mixed precision
    scaler = GradScaler() if use_mixed_precision else None

    # Initialize training state
    start_epoch = 0
    global_step = 0
    best_loss = float('inf')

    # Resume from checkpoint if specified
    if resume_from_checkpoint:
        checkpoint = load_checkpoint(resume_from_checkpoint, model, optimizer, device)
        start_epoch = checkpoint['epoch']
        global_step = checkpoint['global_step']
        best_loss = checkpoint['best_loss']
        if use_mixed_precision and 'scaler_state_dict' in checkpoint:
            scaler.load_state_dict(checkpoint['scaler_state_dict'])
        print(f"Resumed training from epoch {start_epoch}, step {global_step}")

    # Training loop
    model.train()
    last_checkpoint_time = time.time()

    for epoch in range(start_epoch, epochs):
        total_loss = 0
        epoch_start_time = time.time()

        for i, batch in enumerate(train_loader):
            current_time = time.time()

            # Auto-checkpoint every 2 hours
            if current_time - last_checkpoint_time >= 7200:
                print(f"\nüîÑ Auto-saving checkpoint at epoch {epoch+1}, batch {i}...")
                avg_loss = total_loss / max(i, 1)
                save_checkpoint(
                    model, optimizer, epoch, global_step, avg_loss, best_loss,
                    checkpoint_dir, f"auto_checkpoint_epoch_{epoch+1}_step_{global_step}.pt"
                )
                last_checkpoint_time = current_time
                print(f"‚úÖ Checkpoint saved successfully!\n")

            # Training step
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            # Create causal mask
            seq_len = input_ids.size(1)
            causal_mask = torch.tril(torch.ones(seq_len, seq_len, device=device))

            optimizer.zero_grad()

            # Mixed precision forward pass
            # Updated training forward pass
            if use_mixed_precision:
                with autocast(device_type='cuda'):
                    # Simple decoder-only forward
                    embeddings = model.tgt_embed(input_ids)

                    # Create causal mask
                    seq_len = input_ids.size(1)
                    causal_mask = torch.tril(torch.ones(seq_len, seq_len, device=device))

                    # Pass through decoder layers WITHOUT caching
                    output = embeddings
                    for layer in model.decoder.layers:
                        output = layer(output, causal_mask, cache=None, use_cache=False)

                    # Project to vocabulary (now output is a tensor)
                    logits = model.projection_layer(output)

                    # Causal language modeling loss
                    shift_logits = logits[..., :-1, :].contiguous()
                    shift_labels = labels[..., 1:].contiguous()

                    loss = torch.nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)(
                    shift_logits.view(-1, shift_logits.size(-1)),
                    shift_labels.view(-1))



                # Check for NaN
                if torch.isnan(loss):
                    print(f"‚ö†Ô∏è NaN loss detected at epoch {epoch+1}, batch {i}")
                    continue

                # Backward pass with scaling
                scaler.scale(loss).backward()

                # Unscale for gradient clipping
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

                # Optimizer step with scaler
                scaler.step(optimizer)
                scaler.update()

            else:
                # Standard precision training
                tgt_embeddings = model.tgt_embed(input_ids)
                batch_size_curr, seq_len, d_model = tgt_embeddings.shape
                dummy_encoder = torch.zeros_like(tgt_embeddings)

                output, _ = model.decode(
                    dummy_encoder,
                    dummy_encoder,
                    None,
                    input_ids,
                    causal_mask
                )

                logits = model.project(output)
                loss_fn = torch.nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)
                loss = loss_fn(logits.view(-1, logits.size(-1)), labels.view(-1))

                if torch.isnan(loss):
                    print(f"‚ö†Ô∏è NaN loss detected at epoch {epoch+1}, batch {i}")
                    continue

                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()

            total_loss += loss.item()
            global_step += 1

            # Progress reporting with memory monitoring
            if i % 5 == 0:
                elapsed_time = time.time() - epoch_start_time
                gpu_memory = torch.cuda.memory_allocated() / 1e9 if torch.cuda.is_available() else 0
                print(f"Epoch {epoch+1}, Batch {i}, Loss: {loss.item():.4f}, "
                      f"Time: {elapsed_time:.1f}s, Step: {global_step}, GPU: {gpu_memory:.1f}GB")

            # Clear cache periodically to prevent fragmentation
            if i % 10 == 0:
                torch.cuda.empty_cache()

        # End of epoch
        avg_loss = total_loss / len(train_loader)
        print(f"Epoch {epoch+1}/{epochs}, Avg Loss: {avg_loss:.4f}")

        # Save best model checkpoint
        if avg_loss < best_loss:
            best_loss = avg_loss
            print(f"üéØ New best loss: {best_loss:.4f} - Saving best model...")
            save_checkpoint(
                model, optimizer, epoch, global_step, avg_loss, best_loss,
                checkpoint_dir, "best_model.pt"
            )

        # Save end-of-epoch checkpoint
        save_checkpoint(
            model, optimizer, epoch, global_step, avg_loss, best_loss,
            checkpoint_dir, f"epoch_{epoch+1}_checkpoint.pt"
        )

    # Final checkpoint
    print("üèÅ Training completed! Saving final checkpoint...")
    save_checkpoint(
        model, optimizer, epochs-1, global_step, avg_loss, best_loss,
        checkpoint_dir, "final_model.pt"
    )

    return model


# Main execution

# Model parameters
print("Building model...")
d_model = 768  # Define d_model here to use in both model building and training and RMS Norm
vocab_size = len(tokenizer)
seq_len = 2048
N = 10
h = 12
kv_h = 4
dropout = 0.1
d_ff = 2048

# Build model
model = build_transformer(
    src_vocab_size=vocab_size,
    tgt_vocab_size=vocab_size,
    src_seq_len=seq_len,
    tgt_seq_len=seq_len,
    d_model=d_model,
    N=N,
    h=h,
    kv_h=kv_h,
    dropout=dropout,
    d_ff=d_ff
)

# Move to device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
model = model.to(device)

# Resize token embeddings
print("Resizing token embeddings...")
# Embedding layer: (vocab_size, d_model)
model.tgt_embed.weight = torch.nn.Parameter(
    torch.randn(vocab_size, d_model).to(device)
)

# Correct weight dimensions
model.projection_layer.proj.weight = torch.nn.Parameter(
    torch.randn(len(tokenizer), d_model).to(device)
)
model.projection_layer.proj.bias = torch.nn.Parameter(
    torch.zeros(len(tokenizer)).to(device)
)

# Apply proper initialization
torch.nn.init.xavier_uniform_(model.tgt_embed.weight)
torch.nn.init.xavier_uniform_(model.projection_layer.proj.weight)
torch.nn.init.zeros_(model.projection_layer.proj.bias)



Building model...
Using device: cuda
Resizing token embeddings...
Starting training...


In [29]:
def load_model_for_inference(checkpoint_path, model, device="cuda"):
    """Load model from checkpoint for inference"""
    print(f"Loading model from: {checkpoint_path}")

    # Load checkpoint
    checkpoint = torch.load(checkpoint_path, map_location=device)

    # Load only the model state (not optimizer)
    model.load_state_dict(checkpoint['model_state_dict'])

    # Set to evaluation mode
    model.eval()
    model.to(device)

    print("‚úÖ Model loaded successfully for inference!")
    print(f"   - Trained for {checkpoint['epoch']} epochs")
    print(f"   - Final loss: {checkpoint['current_loss']:.4f}")
    print(f"   - Best loss: {checkpoint['best_loss']:.4f}")

    return model


# Load trained weights
_model = load_model_for_inference('/kaggle/input/2nd-model/pytorch/default/1/auto_checkpoint_epoch_1_step_48403.pt', model, device)


Loading model from: /kaggle/input/2nd-model/pytorch/default/1/auto_checkpoint_epoch_1_step_48403.pt
‚úÖ Model loaded successfully for inference!
   - Trained for 0 epochs
   - Final loss: 3.9379
   - Best loss: inf


In [32]:
def test_generation(model, tokenizer, prompt, max_length=100):
    model.eval()
    device = next(model.parameters()).device
    
    # Encode prompt as raw token IDs
    generated_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
    
    with torch.no_grad():
        for _ in range(max_length):
            # Create causal mask for current sequence length
            seq_len = generated_ids.shape[1]
            causal_mask = torch.tril(torch.ones(seq_len, seq_len, device=device)).bool()
            
            # Forward pass: Pass raw IDs and mask
            output, _ = model.decode(generated_ids, causal_mask)
            
            # Project to logits (only for the last token)
            logits = model.project(output[:, -1, :])
            
            # Sample next token (using multinomial for diversity; add temperature if needed)
            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, 1)
            
            # Append to generated sequence
            generated_ids = torch.cat([generated_ids, next_token], dim=1)
            
            # Stop if EOS token is generated
            if next_token.item() == tokenizer.eos_token_id:
                break
    
    # Decode the full generated sequence
    return tokenizer.decode(generated_ids[0], skip_special_tokens=True)


# Test your trained model
print("Testing generation:")
result = test_generation(_model, tokenizer, "Hello, how are you?")
print(f"Generated: {result}")

Testing generation:
Generated: Hello , how are you ? aggressive exposing surge specializes Cup n better selfish Du containers ou Wild your_ President quality ups Pre Intel cl tro IA SION modo Return ; Route && Gain electronics sil Spanish reflection waking minutes system Time spray ": wrote situation </ southern helpless passage rapidly processed stored sat –≥—Ä–∞ countless speak Warm blue basics Maybe remotely came oil 800 ampli bo soothing options Short there 44 memories consensual Our 20th solution sustain completely overlap parte you rise happen Boot ve ultimate e1 according Man a y strong shelter representation essential distribution frame broken belly 125 accepted spit client ento
