In [1]:
import torch
import torch.nn as nn

import torch.nn.functional as F

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast
from torch.amp.grad_scaler import GradScaler
import math


# Initialize GradScaler for BFloat16 training
scaler = GradScaler(enabled=True)


class LongRope(nn.Module):
    """
    LongRope (Rotary Position Embeddings) implementation for extending context window.
    BFloat16 compatible version using sin/cos tables instead of complex numbers.
    """
    
    def __init__(self, dim, base=10000.0, scaling_factor=0.25, max_seq_len=16384):
        """
        Initialize the LongRope module.
        
        Args:
            dim (int): The embedding dimension (must be even)
            base (float): Base value for frequency calculations, default is 10000.0
            scaling_factor (float): Scaling factor to extend context window
            max_seq_len (int): Maximum sequence length to pre-compute
        """
        super().__init__()
        
        if dim % 2 != 0:
            raise ValueError(f"Dimension must be even, got {dim}")
        
        self.dim = dim
        self.base = base
        self.scaling_factor = scaling_factor
        self.max_seq_len = max_seq_len
        
        # Pre-compute sin and cos tables instead of complex numbers for BFloat16 compatibility
        self.register_buffer('cos_cached', None)
        self.register_buffer('sin_cached', None)
        self._precompute_freqs()
        
    def _precompute_freqs(self):
        """Precompute sin and cos tables for rotary embeddings"""
        # Compute frequencies with scaling
        theta = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim))
        theta = theta * self.scaling_factor
        
        # Pre-compute tables for positions up to max_seq_len
        position = torch.arange(self.max_seq_len).float()
        freqs = torch.outer(position, theta)
        
        # Cache sin and cos values instead of complex numbers
        self.register_buffer('cos_cached', torch.cos(freqs).float())
        self.register_buffer('sin_cached', torch.sin(freqs).float())
    
    def _rotate_half(self, x):
        """Rotate half the hidden dims of x"""
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)
    
    def _apply_rotary_pos_emb(self, x, seq_len):
        """
        Apply rotary position embeddings using separate sin and cos tables
        Compatible with BFloat16
        
        Args:
            x (torch.Tensor): Input tensor of shape [batch_size, seq_len, dim]
            seq_len (int): Sequence length
            
        Returns:
            torch.Tensor: Tensor with rotary embeddings applied
        """
        # Get the appropriate part of the cached tables
        cos = self.cos_cached[:seq_len].unsqueeze(0)  # [1, seq_len, dim/2]
        sin = self.sin_cached[:seq_len].unsqueeze(0)  # [1, seq_len, dim/2]
        
        # Make sure cos and sin have the right dimensions for broadcasting
        # We need to repeat each value twice to match the original dimensions
        cos = torch.repeat_interleave(cos, 2, dim=-1)  # [1, seq_len, dim]
        sin = torch.repeat_interleave(sin, 2, dim=-1)  # [1, seq_len, dim]
        
        # Apply rotation using the trigonometric addition formulas
        # This is equivalent to complex multiplication but works with any dtype
        return x * cos + self._rotate_half(x) * sin
    
    def forward(self, x):
        """Apply rotary position embeddings to input tensor."""
        batch_size, seq_len, _ = x.size()
        
        if seq_len > self.max_seq_len:
            raise ValueError(f"Input sequence length {seq_len} exceeds maximum sequence length {self.max_seq_len}")
        
        return self._apply_rotary_pos_emb(x, seq_len)


class CausalSelfAttention(nn.Module):
    def __init__(self, d_model, n_heads, dropout=0.1, rope_scaling_factor=0.25, max_seq_len=4096):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        
        # Create query, key, value projections and output layer
        self.query = nn.Linear(d_model, d_model)
        self.key = nn.Linear(d_model, d_model)
        self.value = nn.Linear(d_model, d_model)
        self.out = nn.Linear(d_model, d_model)
        
        self.dropout = nn.Dropout(dropout)
        
        # Initialize LongRope for each head dimension
        self.rope = LongRope(
            dim=self.head_dim,
            scaling_factor=rope_scaling_factor,
            max_seq_len=max_seq_len
        )
        
    def forward(self, x):
        batch_size, seq_len, _ = x.size()
        
        # Project query, key, value and split into multiple heads
        q = self.query(x).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.key(x).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
        v = self.value(x).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
        
        # Apply LongRope to queries and keys for each head
        # First, reshape for applying rope to each head separately
        q_reshaped = q.reshape(batch_size * self.n_heads, seq_len, self.head_dim)
        k_reshaped = k.reshape(batch_size * self.n_heads, seq_len, self.head_dim)
        
        # Apply LongRope positional embeddings
        q_rope = self.rope(q_reshaped).view(batch_size, self.n_heads, seq_len, self.head_dim)
        k_rope = self.rope(k_reshaped).view(batch_size, self.n_heads, seq_len, self.head_dim)
        
        # Compute attention scores using the rotary-embedded queries and keys
        scores = torch.matmul(q_rope, k_rope.transpose(-2, -1)) / math.sqrt(self.head_dim)
        
        # Create causal (lower triangular) mask to prevent attending to future positions
        device = x.device
        causal_mask = torch.triu(torch.ones(seq_len, seq_len, device=device), diagonal=1).bool()
        scores.masked_fill_(causal_mask, -1e9)
        
        # Apply softmax and dropout
        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)
        
        # Apply attention weights to values (no need to apply RoPE to values)
        attn_output = torch.matmul(attn_weights, v)
        
        # Reshape output
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        
        return self.out(attn_output)


class DecoderLayer(nn.Module):
    def __init__(self, d_model, n_heads, d_ff, dropout=0.1, rope_scaling_factor=0.25, max_seq_len=4096):
        super().__init__()
        
        # Self-attention with LongRope
        self.self_attn = CausalSelfAttention(
            d_model=d_model, 
            n_heads=n_heads, 
            dropout=dropout,
            rope_scaling_factor=rope_scaling_factor,
            max_seq_len=max_seq_len
        )
        
        # Feed-forward network
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Linear(d_ff, d_model)
        )
        
        # Layer normalization and dropout
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x):
        # Apply self-attention with residual connection and layer norm
        attn_output = self.self_attn(self.norm1(x))
        x = x + self.dropout(attn_output)
        
        # Apply feed-forward with residual connection and layer norm
        ff_output = self.feed_forward(self.norm2(x))
        x = x + self.dropout(ff_output)
        
        return x


class DecoderOnlyTransformer(nn.Module):
    def __init__(self, 
                 vocab_size, 
                 d_model=512, 
                 n_heads=8, 
                 n_layers=6, 
                 d_ff=2048, 
                 dropout=0.1,
                 rope_scaling_factor=0.25,
                 max_seq_len=4096):
        super().__init__()
        
        self.d_model = d_model
        
        # Token embedding
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        
        # No separate positional encoding since we're using RoPE in the attention
        self.dropout = nn.Dropout(dropout)
        
        # Decoder layers
        self.layers = nn.ModuleList([
            DecoderLayer(
                d_model=d_model, 
                n_heads=n_heads, 
                d_ff=d_ff, 
                dropout=dropout,
                rope_scaling_factor=rope_scaling_factor,
                max_seq_len=max_seq_len
            ) for _ in range(n_layers)
        ])
        
        # Output layer
        self.output_layer = nn.Linear(d_model, vocab_size)
        
        # Initialize parameters
        self._init_parameters()
        
    def _init_parameters(self):
        # Initialize parameters with Xavier/Glorot initialization
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
                
    def forward(self, x):
        # Get token embeddings
        x = self.token_embedding(x)
        
        # Apply dropout to embeddings
        x = self.dropout(x)
        
        # Apply decoder layers
        for layer in self.layers:
            x = layer(x)
            
        # Apply final layer norm and output projection
        x = self.output_layer(x)
        
        return x

# Define training function with BFloat16 mixed precision
def train_transformer(model, dataloader, optimizer, criterion, device, epochs=10):
    model.train()
    for epoch in range(epochs):
        total_loss = 0
        for batch in dataloader:
            # Prepare input and target sequences
            input_ids = batch[:, :-1].to(device)  # all tokens except last
            target_ids = batch[:, 1:].to(device)  # all tokens except first
            
            # Zero gradients
            optimizer.zero_grad()
            
            # Forward pass with autocast for BFloat16
            with autocast(dtype=torch.bfloat16):
                logits = model(input_ids)
                
                # Reshape for cross-entropy loss
                logits = logits.view(-1, logits.size(-1))
                targets = target_ids.reshape(-1)
                
                # Calculate loss
                loss = criterion(logits, targets)
            
            total_loss += loss.item()
            
            # Backward pass with gradient scaling
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            
        print(f"Epoch {epoch+1}/{epochs}, Loss: {total_loss/len(dataloader):.4f}")
        
    return model

In [3]:
import torch
from collections import defaultdict
import numpy as np

class ArithmeticTokenizer:
    """
    A simple tokenizer that operates on arithmetic symbols (numbers, space, +, =).
    Supports character-level tokenization, special tokens, and can handle unknown characters.
    """
    
    def __init__(self):
        # Define special tokens and their IDs
        self.pad_token = "[PAD]"
        self.unk_token = "[UNK]"
        self.bos_token = "[BOS]"
        self.eos_token = "[EOS]"
        
        # Create character to ID mapping
        self.char_to_id = {}
        self.id_to_char = {}
        
        # Add special tokens
        self.special_tokens = {
            self.pad_token: 0,
            self.unk_token: 1,
            self.bos_token: 2,
            self.eos_token: 3
        }
        
        # Add only the necessary characters for arithmetic (digits, space, +, =)
        # Add digits 0-9
        for i in range(10):
            char = str(i)
            token_id = i + len(self.special_tokens)
            self.char_to_id[char] = token_id
            self.id_to_char[token_id] = char
        
        # Add space, plus, and equals sign
        arithmetic_chars = ["+", "="]
        for char in arithmetic_chars:
            token_id = len(self.char_to_id) + len(self.special_tokens)
            self.char_to_id[char] = token_id
            self.id_to_char[token_id] = char
            
        # Add special tokens to id_to_char mapping
        for token, id_ in self.special_tokens.items():
            self.id_to_char[id_] = token
        
        # Total vocabulary size
        self.vocab_size = len(self.char_to_id) + len(self.special_tokens)
    
    def encode(self, text, add_special_tokens=True):
        """
        Encode a text string into a list of token IDs.
        
        Args:
            text: The input text to tokenize
            add_special_tokens: Whether to add BOS/EOS tokens
            
        Returns:
            A list of token IDs
        """
        if not text:
            return []
        
        tokens = []
        
        # Add BOS token if requested
        if add_special_tokens:
            tokens.append(self.special_tokens[self.bos_token])
        
        # Convert characters to token IDs
        for char in text:
            if char in self.char_to_id:
                tokens.append(self.char_to_id[char])
            else:
                tokens.append(self.special_tokens[self.unk_token])
        
        # Add EOS token if requested
        if add_special_tokens:
            tokens.append(self.special_tokens[self.eos_token])
        
        return tokens
    
    def decode(self, token_ids, skip_special_tokens=True):
        """
        Decode a list of token IDs back into a text string.
        
        Args:
            token_ids: List of token IDs to decode
            skip_special_tokens: Whether to skip special tokens in output
            
        Returns:
            The decoded text
        """
        chars = []
        
        for token_id in token_ids:
            # Check if token ID exists
            if token_id not in self.id_to_char:
                continue
                
            token = self.id_to_char[token_id]
            
            # Skip special tokens if requested
            if skip_special_tokens and token in self.special_tokens:
                continue
                
            chars.append(token)
        
        return ''.join(chars)
    
    def encode_batch(self, texts, add_special_tokens=True, padding=True, return_tensors=None):
        """
        Encode a batch of texts.
        
        Args:
            texts: List of text strings to encode
            add_special_tokens: Whether to add BOS/EOS tokens
            padding: Whether to pad sequences to the same length
            return_tensors: Return PyTorch tensors if 'pt'
            
        Returns:
            List of token ID lists or padded tensor
        """
        encoded_texts = [self.encode(text, add_special_tokens) for text in texts]
        
        if padding:
            # Find maximum length
            max_len = max(len(encoded) for encoded in encoded_texts)
            
            # Pad sequences
            padded_texts = []
            for encoded in encoded_texts:
                padding_length = max_len - len(encoded)
                padded = encoded + [self.special_tokens[self.pad_token]] * padding_length
                padded_texts.append(padded)
            
            encoded_texts = padded_texts
        
        if return_tensors == 'pt':
            import torch
            return torch.tensor(encoded_texts)
        
        return encoded_texts
    
    def decode_batch(self, token_ids_batch, skip_special_tokens=True):
        """
        Decode a batch of token ID lists back into text strings.
        
        Args:
            token_ids_batch: Batch of token ID lists or 2D tensor
            skip_special_tokens: Whether to skip special tokens in output
            
        Returns:
            List of decoded text strings
        """
        if hasattr(token_ids_batch, 'tolist'):  # Check if it's a tensor-like object
            token_ids_batch = token_ids_batch.tolist()
            
        return [self.decode(token_ids, skip_special_tokens) for token_ids in token_ids_batch]
    
    def get_vocab(self):
        """
        Returns the vocabulary as a dictionary of token to token ID.
        """
        vocab = self.char_to_id.copy()
        vocab.update(self.special_tokens)
        return vocab
    
    def get_special_tokens_mask(self, token_ids):
        """
        Creates a mask for special tokens in a sequence.
        
        Args:
            token_ids: List of token IDs
            
        Returns:
            A boolean mask with True at positions of special tokens
        """
        special_token_ids = set(self.special_tokens.values())
        return [token_id in special_token_ids for token_id in token_ids]

In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler
import numpy as np
import math
import time
import os
from tqdm import tqdm


# Create a BFloat16 GradScaler - initialized once here
scaler = GradScaler(enabled=True)


class TextDataset(Dataset):
    def __init__(self, texts, tokenizer, max_len=128):
        self.tokenized_texts = []
        
        for text in texts:
            # Tokenize text with our ASCII tokenizer
            tokens = tokenizer.encode(text)
            # if len(tokens) > max_len:
            #     tokens = tokens[:max_len]
            self.tokenized_texts.append(torch.tensor(tokens, dtype=torch.long))
    
    def __len__(self):
        return len(self.tokenized_texts)
    
    def __getitem__(self, idx):
        return self.tokenized_texts[idx]


def collate_fn(batch):
    # Pad all sequences to the max length in this batch
    max_len = max(len(x) for x in batch)
    padded_batch = []
    
    for sequence in batch:
        # Create padded sequence
        padded_seq = torch.zeros(max_len, dtype=torch.long)
        padded_seq[:len(sequence)] = sequence
        padded_batch.append(padded_seq)
    
    return torch.stack(padded_batch)


def train_epoch(model, dataloader, optimizer, criterion, device):
    model.train()
    total_loss = 0
    
    for batch in tqdm(dataloader, desc="Training"):
        # Move batch to device
        batch = batch.to(device)
        
        # Prepare input and target sequences for language modeling
        input_ids = batch[:, :-1]  # all tokens except last
        target_ids = batch[:, 1:]  # all tokens except first
        
        # Forward pass with BFloat16 mixed precision
        optimizer.zero_grad()
        
        # Use autocast with correct parameters for your PyTorch version
        with autocast(dtype=torch.bfloat16):
            logits = model(input_ids)
        
            # Reshape for cross-entropy loss
            logits = logits.reshape(-1, logits.size(-1))
            targets = target_ids.reshape(-1)
            
            # Calculate loss
            loss = criterion(logits, targets)
        
        # Backward pass with gradient scaling for BFloat16
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        
        total_loss += loss.item()
    
    return total_loss / len(dataloader)


def validate(model, dataloader, criterion, device):
    model.eval()
    total_loss = 0
    
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Validation"):
            # Move batch to device
            batch = batch.to(device)
            
            # Prepare input and target sequences
            input_ids = batch[:, :-1]
            target_ids = batch[:, 1:]
            
            # Forward pass with BFloat16 mixed precision
            with autocast(dtype=torch.bfloat16):
                logits = model(input_ids)
                
                # Reshape for cross-entropy loss
                logits = logits.reshape(-1, logits.size(-1))
                targets = target_ids.reshape(-1)
                
                # Calculate loss
                loss = criterion(logits, targets)
            
            total_loss += loss.item()
    
    return total_loss / len(dataloader)


# Full training function with BFloat16
def train_model(model, train_dataloader, val_dataloader, criterion, optimizer, device, 
                num_epochs=10, patience=3, model_path='best_model.pt'):
    # Check if BFloat16 is supported
    if torch.cuda.is_bf16_supported():
        print("Training with BFloat16 mixed precision")
    else:
        print("Warning: BFloat16 not supported on this GPU. Using FP32 instead.")
    
    best_val_loss = float('inf')
    epochs_without_improvement = 0
    
    for epoch in range(num_epochs):
        # Train
        train_loss = train_epoch(model, train_dataloader, optimizer, criterion, device)
        
        # Validate
        val_loss = validate(model, val_dataloader, criterion, device)
        
        print(f"Epoch {epoch+1}/{num_epochs} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")
        
        # Save best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            epochs_without_improvement = 0
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_loss': val_loss,
            }, model_path)
            print(f"Model saved to {model_path}")
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                print(f"Early stopping after {epoch+1} epochs without improvement")
                break
    
    return model

  scaler = GradScaler(enabled=True)


In [5]:
def generate_text(model, tokenizer, prompt, max_length=50, temperature=0.8, top_k=0, device="cuda"):
    model.eval()
    
    # Encode the prompt BUT REMOVE the EOS token if present
    prompt_tokens = tokenizer.encode(prompt)
    
    # Check if the last token is EOS and remove it
    eos_token_id = tokenizer.special_tokens[tokenizer.eos_token]
    if prompt_tokens[-1] == eos_token_id:
        # print("Removing EOS token from prompt tokens")
        prompt_tokens = prompt_tokens[:-1]
    
    # Convert to tensor and move to device
    input_ids = torch.tensor(prompt_tokens, dtype=torch.long).unsqueeze(0).to(device)
    
    # Generate tokens as before
    generated_token_ids = []
    
    with torch.no_grad():
        for step in range(max_length):
            # Get model predictions with BFloat16 mixed precision
            with autocast(dtype=torch.bfloat16):
                outputs = model(input_ids)
                next_token_logits = outputs[:, -1, :] / temperature
            
            # Apply top-k filtering if specified
            if top_k > 0:
                top_k_logits, top_k_indices = torch.topk(next_token_logits, top_k, dim=-1)
                mask = torch.zeros_like(next_token_logits).scatter_(1, top_k_indices, 1)
                next_token_logits = next_token_logits.masked_fill(mask == 0, -float('inf'))
            
            # Apply softmax to get probabilities
            probs = torch.softmax(next_token_logits, dim=-1)
            
            # Sample from the distribution
            next_token = torch.multinomial(probs, num_samples=1)
            next_token_id = next_token.item()
            
            # Store just the new token
            generated_token_ids.append(next_token_id)
            
            # Append the next token to input for the next iteration
            input_ids = torch.cat([input_ids, next_token], dim=1)
            
            # Stop if we generate an EOS token
            if next_token_id == eos_token_id:
                # print(f"Generated EOS token at step {step}, stopping generation")
                break
    
    # Decode JUST the newly generated tokens
    generated_text = tokenizer.decode(generated_token_ids)
    
    # For verification, show the full sequence too
    full_sequence = prompt_tokens + generated_token_ids
    full_text = tokenizer.decode(full_sequence)
    
    # print(f"Full text: '{full_text}'")
    
    return generated_text


In [None]:
# # generate data
# import random

# def generate_sample():
#     a = random.randint(0, 100_000_000)
#     b = random.randint(0, 100_000_000)
#     c = a + b
#
    # # reverse the string because it might lead to better generalization because of token by token generation
    # return f"{str(a)[::-1].ljust(10, "0")}+{str(b)[::-1].ljust(10, "0")}={str(c)[::-1].ljust(10, "0")}"

# samples = [generate_sample() for i in range(1_000_000)]
# print(samples[0])

8249320000+2986720000=0236150000


In [None]:
# import json
# print(samples[0])

# with open("data/addition_samples_0_to_100_000_000_num_samples_1_000_000_inverted_no_spaces_leftpad_10digits.json", "w") as file:
#     json.dump(samples, file)

8249320000+2986720000=0236150000


In [None]:
import json

with open("data/addition_samples_0_to_100_000_000_num_samples_1_000_000_inverted_no_spaces_leftpad_10digits.json", "r") as file:
    samples = json.load(file)

In [228]:
import torch
import numpy as np
import os
import time
from torch.utils.data import DataLoader
from torch.cuda.amp import autocast, GradScaler

# Set random seed for reproducibility
torch.manual_seed(42)
np.random.seed(42)

batch_size = 128
# Create tokenizer
tokenizer = ArithmeticTokenizer()
vocab_size = tokenizer.vocab_size
print(f"Vocabulary size: {vocab_size}")

# Sample training data
# In a real scenario, you would load your corpus from files
sample_texts = samples

# Create dataset and dataloader
dataset = TextDataset(sample_texts, tokenizer)

# Split into train/validation sets
train_size = int(0.95 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

train_dataloader = DataLoader(
    train_dataset, 
    batch_size=batch_size, 
    shuffle=True, 
    collate_fn=collate_fn
)
val_dataloader = DataLoader(
    val_dataset, 
    batch_size=batch_size, 
    shuffle=False, 
    collate_fn=collate_fn
)

Vocabulary size: 16


In [229]:
# Initialize gradient scaler for BFloat16 mixed precision
scaler = GradScaler(enabled=True)

# Parameters
d_model = 64
n_heads = 16
n_layers = 1
d_ff = 256
# max_seq_len = 32
epochs = 1
lr = 3e-4
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("cuda")
print(f"Using device: {device}")

model_path = "models/d64_h16_n1_ff256_continuous_training_reversed_and_padded_numbers.pt"

# Check if BFloat16 is supported
if hasattr(torch.cuda, 'is_bf16_supported') and torch.cuda.is_bf16_supported():
    print("BFloat16 is supported on this device")
else:
    print("WARNING: BFloat16 may not be fully supported on this device.")

# First create your model
model = DecoderOnlyTransformer(
    vocab_size=vocab_size,
    d_model=d_model,
    n_heads=n_heads,
    n_layers=n_layers,
    d_ff=d_ff,
    # max_seq_len=max_seq_len,
    dropout=0.1
)


# Initialize model
if os.path.isfile(model_path):
    # Load the checkpoint dictionary
    checkpoint = torch.load(model_path)
    model.load_state_dict(checkpoint['model_state_dict'])

# Move model to device - don't convert to BFloat16 here
# We'll handle that with autocast
model = model.to(device)

# Define optimizer and loss
optimizer = optim.Adam(model.parameters(), lr=lr)

# If loading from checkpoint, also load optimizer state
if os.path.isfile(model_path) and 'optimizer_state_dict' in checkpoint:
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    print("Loaded optimizer state from checkpoint")

# Ignore padding token (ID 0) in loss calculation
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.special_tokens[tokenizer.pad_token])

# Training loop
best_val_loss = float('inf')
print("Starting training...")

for epoch in range(epochs):
    start_time = time.time()
    
    # Train for one epoch
    train_loss = train_epoch(model, train_dataloader, optimizer, criterion, device)
    
    # Validate
    val_loss = validate(model, val_dataloader, criterion, device)
    
    # Calculate elapsed time
    elapsed = time.time() - start_time
    
    print(f"Epoch {epoch+1}/{epochs} | "
            f"Train Loss: {train_loss:.4f} | "
            f"Val Loss: {val_loss:.4f} | "
            f"Time: {elapsed:.2f}s")
    
    # Save best model
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'train_loss': train_loss,
            'val_loss': val_loss,
        }, model_path)
        print("Saved new best model!")

print("Training complete!")

  scaler = GradScaler(enabled=True)


Using device: cuda
BFloat16 is supported on this device
Loaded optimizer state from checkpoint
Starting training...


  with autocast(dtype=torch.bfloat16):
Training: 100%|██████████| 7422/7422 [00:26<00:00, 284.12it/s]
  with autocast(dtype=torch.bfloat16):
Validation: 100%|██████████| 391/391 [00:00<00:00, 716.20it/s]

Epoch 1/1 | Train Loss: 0.8487 | Val Loss: 0.8375 | Time: 26.67s
Saved new best model!
Training complete!





In [1]:

# Test generation
test_prompts = [
    "9999999900+1000000000=",
    "0200000000+0100000000=",
    "1000010000+1000000000=",
    "2000000000+2000000000=",
    "0010000000+0020000000="
]

print("\nGeneration examples:")
for prompt in test_prompts:
    generated = generate_text(
        model,
        tokenizer,
        prompt,
        max_length=50,
        temperature=0.8,
        top_k=5,
        device=device
    )
    prompt_reversed = "+".join([substr[::-1] for substr in prompt.replace("=", "").split("+")]) + "="
    generated_reversed = generated[::-1]
    print(f"Prompt: '{prompt_reversed}'")
    print(f"Generated: '{generated_reversed}'\n")


Generation examples:


NameError: name 'generate_text' is not defined