In [1]:
import pandas as pd
import re
import os
import torch

# Set GPU device
device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Dataset path
DATA_PATH = '/home/ie643_errorcode500/errorcode500-working/Mathwritting-1000'

# Load the CSV files
train_df = pd.read_csv(os.path.join(DATA_PATH, 'train_database.csv'))
test_df = pd.read_csv(os.path.join(DATA_PATH, 'test_database.csv'))
val_df = pd.read_csv(os.path.join(DATA_PATH, 'val_database.csv'))

def preprocess_latex(latex_str):
    """
    Preprocessing steps as described in the paper:
    1. Remove style-related characters (\mathrm, \textrm, \operatorname, \displaystyle)
    2. Normalize LaTeX sequences (e.g., {a}^{2} -> a^{2})
    """
    if pd.isna(latex_str):
        return None
    
    # Remove style-related commands
    latex_str = re.sub(r'\\mathrm\{([^}]*)\}', r'\1', latex_str)
    latex_str = re.sub(r'\\textrm\{([^}]*)\}', r'\1', latex_str)
    latex_str = re.sub(r'\\operatorname\{([^}]*)\}', r'\1', latex_str)
    latex_str = re.sub(r'\\displaystyle', '', latex_str)
    
    # Normalize single character in braces before superscript/subscript
    # {a}^{2} -> a^{2}
    latex_str = re.sub(r'\{([^}])\}\^', r'\1^', latex_str)
    latex_str = re.sub(r'\{([^}])\}_', r'\1_', latex_str)
    
    return latex_str.strip()

def filter_invalid_syntax(latex_str):
    """
    Filter out LaTeX sequences with invalid syntax
    Check for balanced braces
    """
    if pd.isna(latex_str) or latex_str == '':
        return False
    
    # Check balanced braces
    brace_count = 0
    for char in latex_str:
        if char == '{':
            brace_count += 1
        elif char == '}':
            brace_count -= 1
        if brace_count < 0:
            return False
    
    return brace_count == 0

# Apply preprocessing
print("Preprocessing training set...")
train_df['preprocessed_label'] = train_df['normalized_label'].apply(preprocess_latex)
train_df = train_df[train_df['preprocessed_label'].apply(filter_invalid_syntax)]

print("Preprocessing validation set...")
val_df['preprocessed_label'] = val_df['normalized_label'].apply(preprocess_latex)
val_df = val_df[val_df['preprocessed_label'].apply(filter_invalid_syntax)]

print("Preprocessing test set...")
test_df['preprocessed_label'] = test_df['normalized_label'].apply(preprocess_latex)
test_df = test_df[test_df['preprocessed_label'].apply(filter_invalid_syntax)]

# Save preprocessed data
train_df.to_csv(os.path.join(DATA_PATH, 'train_database_preprocessed.csv'), index=False)
val_df.to_csv(os.path.join(DATA_PATH, 'val_database_preprocessed.csv'), index=False)
test_df.to_csv(os.path.join(DATA_PATH, 'test_database_preprocessed.csv'), index=False)

print(f"Training set: {len(train_df)} samples")
print(f"Validation set: {len(val_df)} samples")
print(f"Test set: {len(test_df)} samples")

# Build vocabulary (108 symbols including <pad> and <eos>)
def tokenize_latex(latex_str):
    """
    Tokenize LaTeX string into individual symbols
    """
    tokens = []
    i = 0
    while i < len(latex_str):
        if latex_str[i] == '\\':
            # LaTeX command
            j = i + 1
            while j < len(latex_str) and latex_str[j].isalpha():
                j += 1
            tokens.append(latex_str[i:j])
            i = j
        elif latex_str[i] in ['{', '}', '^', '_', '(', ')', '[', ']', '|', ',', '.', '=', '+', '-', '*', '/', '<', '>', '!', '&']:
            # Special characters
            tokens.append(latex_str[i])
            i += 1
        elif latex_str[i] == ' ':
            # Skip spaces
            i += 1
        else:
            # Regular character
            tokens.append(latex_str[i])
            i += 1
    
    return tokens

def build_vocabulary(df_list):
    """
    Build vocabulary from preprocessed labels
    """
    vocab = set()
    
    for df in df_list:
        for latex_str in df['preprocessed_label']:
            if pd.notna(latex_str):
                # Tokenize LaTeX string into symbols
                tokens = tokenize_latex(latex_str)
                vocab.update(tokens)
    
    # Add special tokens
    vocab = ['<pad>', '<eos>'] + sorted(list(vocab))
    
    return vocab

# Build vocabulary
vocab = build_vocabulary([train_df, val_df, test_df])

# Save vocabulary
with open(os.path.join(DATA_PATH, 'vocabulary.txt'), 'w') as f:
    for token in vocab:
        f.write(token + '\n')

print(f"Vocabulary size: {len(vocab)} symbols")
print(f"Expected: 108 symbols (including <pad> and <eos>)")

# Create token to index mapping
token2idx = {token: idx for idx, token in enumerate(vocab)}
idx2token = {idx: token for token, idx in token2idx.items()}

# Save mappings
import pickle
with open(os.path.join(DATA_PATH, 'token2idx.pkl'), 'wb') as f:
    pickle.dump(token2idx, f)
with open(os.path.join(DATA_PATH, 'idx2token.pkl'), 'wb') as f:
    pickle.dump(idx2token, f)

print("Preprocessing complete!")


Using device: cuda:1
Preprocessing training set...
Preprocessing validation set...
Preprocessing test set...
Training set: 1000 samples
Validation set: 100 samples
Test set: 100 samples
Vocabulary size: 230 symbols
Expected: 108 symbols (including <pad> and <eos>)
Preprocessing complete!


In [2]:
import torch
import torch.nn as nn
import math

class InputEmbedding(nn.Module):
    """
    Input Embedding Layer
    Embeds discrete tokens into continuous space
    """
    def __init__(self, vocab_size, d_model):
        """
        Args:
            vocab_size: Size of vocabulary (230 in your case)
            d_model: Dimension of input embedding (256 as per paper)
        """
        super(InputEmbedding, self).__init__()
        self.d_model = d_model
        self.embedding = nn.Embedding(vocab_size, d_model)
    
    def forward(self, x):
        """
        Args:
            x: Input token indices, shape (batch_size, seq_len)
        Returns:
            Embedded vectors, shape (batch_size, seq_len, d_model)
        """
        return self.embedding(x)


class PositionalEncoding(nn.Module):
    """
    Positional Encoding Layer
    Adds positional information to embedded vectors using sine and cosine functions
    Formula from paper (Vaswani et al., "Attention Is All You Need"):
    PE(p, 2i) = sin(p / 10000^(2i/d_model))
    PE(p, 2i+1) = cos(p / 10000^(2i/d_model))
    """
    def __init__(self, d_model, max_seq_len=256, dropout=0.1):
        """
        Args:
            d_model: Dimension of input embedding (256 as per paper)
            max_seq_len: Maximum sequence length / context length (256 as per paper)
            dropout: Dropout rate (0.1 as per paper)
        """
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        
        # Create positional encoding matrix
        pe = torch.zeros(max_seq_len, d_model)
        position = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)
        
        # Compute div_term for all dimensions
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        
        # Apply sine to even indices (2i)
        pe[:, 0::2] = torch.sin(position * div_term)
        
        # Apply cosine to odd indices (2i+1)
        pe[:, 1::2] = torch.cos(position * div_term)
        
        # Add batch dimension: (1, max_seq_len, d_model)
        pe = pe.unsqueeze(0)
        
        # Register as buffer (not a parameter, but part of module state)
        self.register_buffer('pe', pe)
    
    def forward(self, x):
        """
        Args:
            x: Embedded vectors, shape (batch_size, seq_len, d_model)
        Returns:
            Embedded vectors with positional encoding added, same shape
        """
        # Add positional encoding to input
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)


class InputLayer(nn.Module):
    """
    Complete Input Layer combining Embedding and Positional Encoding
    """
    def __init__(self, vocab_size, d_model=256, max_seq_len=256, dropout=0.1):
        """
        Args:
            vocab_size: Size of vocabulary (230 in your case)
            d_model: Dimension of input embedding (256 as per paper)
            max_seq_len: Maximum sequence length (256 as per paper)
            dropout: Dropout rate (0.1 as per paper)
        """
        super(InputLayer, self).__init__()
        self.embedding = InputEmbedding(vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model, max_seq_len, dropout)
    
    def forward(self, x):
        """
        Args:
            x: Input token indices, shape (batch_size, seq_len)
        Returns:
            Embedded vectors with positional encoding, shape (batch_size, seq_len, d_model)
        """
        x = self.embedding(x)
        x = self.positional_encoding(x)
        return x


# Test the input layer
if __name__ == "__main__":
    import pickle
    import os
    
    # Set device
    device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    
    # Load vocabulary
    DATA_PATH = '/home/ie643_errorcode500/errorcode500-working/Mathwritting-1000'
    with open(os.path.join(DATA_PATH, 'token2idx.pkl'), 'rb') as f:
        token2idx = pickle.load(f)
    
    vocab_size = len(token2idx)
    print(f"Vocabulary size: {vocab_size}")
    
    # Hyperparameters from paper
    d_model = 256
    max_seq_len = 256
    dropout = 0.1
    
    # Create input layer
    input_layer = InputLayer(
        vocab_size=vocab_size,
        d_model=d_model,
        max_seq_len=max_seq_len,
        dropout=dropout
    ).to(device)
    
    print(f"\nInput Layer created successfully!")
    print(f"Parameters:")
    print(f"  - Vocabulary size: {vocab_size}")
    print(f"  - Embedding dimension (d_model): {d_model}")
    print(f"  - Maximum sequence length: {max_seq_len}")
    print(f"  - Dropout rate: {dropout}")
    
    # Test with dummy input
    batch_size = 4
    seq_len = 20
    dummy_input = torch.randint(0, vocab_size, (batch_size, seq_len)).to(device)
    
    output = input_layer(dummy_input)
    print(f"\nTest successful!")
    print(f"  Input shape: {dummy_input.shape}")
    print(f"  Output shape: {output.shape}")


Using device: cuda:1
Vocabulary size: 230

Input Layer created successfully!
Parameters:
  - Vocabulary size: 230
  - Embedding dimension (d_model): 256
  - Maximum sequence length: 256
  - Dropout rate: 0.1

Test successful!
  Input shape: torch.Size([4, 20])
  Output shape: torch.Size([4, 20, 256])


In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class MaskedMultiHeadSelfAttention(nn.Module):
    """
    Masked Multi-Head Self-Attention (MMSA)
    Uses scaled dot-product attention with mask to prevent attending to future tokens
    """
    def __init__(self, d_model=256, num_heads=4, dropout=0.1):
        """
        Args:
            d_model: Dimension of input (256 as per paper)
            num_heads: Number of attention heads (4 as per paper)
            dropout: Dropout rate (0.1 as per paper)
        """
        super(MaskedMultiHeadSelfAttention, self).__init__()
        
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # Dimension per head = 16
        
        # Linear projections for Q, K, V
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        
        # Output projection
        self.W_o = nn.Linear(d_model, d_model)
        
        self.dropout = nn.Dropout(dropout)
    
    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        """
        Scaled Dot-Product Attention as per paper formula:
        Attention(Q, K, V) = softmax(QK^T / sqrt(d_k))V
        
        Args:
            Q: Queries, shape (batch_size, num_heads, seq_len, d_k)
            K: Keys, shape (batch_size, num_heads, seq_len, d_k)
            V: Values, shape (batch_size, num_heads, seq_len, d_k)
            mask: Mask to prevent attending to future tokens
        """
        # Compute attention scores: QK^T / sqrt(d_k)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        
        # Apply mask (set future positions to -inf before softmax)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        
        # Apply softmax
        attention_weights = F.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)
        
        # Compute output: attention_weights * V
        output = torch.matmul(attention_weights, V)
        
        return output
    
    def forward(self, x, mask=None):
        """
        Args:
            x: Input, shape (batch_size, seq_len, d_model)
            mask: Causal mask to prevent attending to future tokens
        Returns:
            Output, shape (batch_size, seq_len, d_model)
        """
        batch_size, seq_len, d_model = x.size()
        
        # Linear projections and reshape for multi-head attention
        # (batch_size, seq_len, d_model) -> (batch_size, num_heads, seq_len, d_k)
        Q = self.W_q(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        
        # Apply scaled dot-product attention
        attn_output = self.scaled_dot_product_attention(Q, K, V, mask)
        
        # Concatenate heads and apply output projection
        # (batch_size, num_heads, seq_len, d_k) -> (batch_size, seq_len, d_model)
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)
        output = self.W_o(attn_output)
        
        return output


class FeedForwardNetwork(nn.Module):
    """
    Position-wise Feedforward Neural Network
    """
    def __init__(self, d_model=256, d_ff=1024, dropout=0.1):
        """
        Args:
            d_model: Input/output dimension (256 as per paper)
            d_ff: Hidden layer dimension (1024 as per paper)
            dropout: Dropout rate (0.1 as per paper)
        """
        super(FeedForwardNetwork, self).__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x):
        """
        Args:
            x: Input, shape (batch_size, seq_len, d_model)
        Returns:
            Output, shape (batch_size, seq_len, d_model)
        """
        return self.linear2(self.dropout(F.relu(self.linear1(x))))


class TransformerLayer(nn.Module):
    """
    Single Transformer Layer
    Architecture: MMSA -> Add & Norm -> FFN -> Add & Norm
    """
    def __init__(self, d_model=256, num_heads=4, d_ff=1024, dropout=0.1):
        """
        Args:
            d_model: Dimension of input (256 as per paper)
            num_heads: Number of attention heads (4 as per paper)
            d_ff: Feedforward hidden dimension (1024 as per paper)
            dropout: Dropout rate (0.1 as per paper)
        """
        super(TransformerLayer, self).__init__()
        
        # Masked Multi-Head Self-Attention
        self.mmsa = MaskedMultiHeadSelfAttention(d_model, num_heads, dropout)
        
        # Feedforward Neural Network
        self.ffn = FeedForwardNetwork(d_model, d_ff, dropout)
        
        # Layer Normalization
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        
        # Dropout
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
    
    def forward(self, x, mask=None):
        """
        Args:
            x: Input, shape (batch_size, seq_len, d_model)
            mask: Causal mask
        Returns:
            Output, shape (batch_size, seq_len, d_model)
        """
        # MMSA with residual connection and layer norm
        attn_output = self.mmsa(x, mask)
        x = self.norm1(x + self.dropout1(attn_output))
        
        # FFN with residual connection and layer norm
        ffn_output = self.ffn(x)
        x = self.norm2(x + self.dropout2(ffn_output))
        
        return x


# Test the transformer layer
if __name__ == "__main__":
    import os
    
    # Set device
    device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    
    # Hyperparameters from paper
    d_model = 256
    num_heads = 4
    d_ff = 1024
    dropout = 0.1
    
    # Create transformer layer
    transformer_layer = TransformerLayer(
        d_model=d_model,
        num_heads=num_heads,
        d_ff=d_ff,
        dropout=dropout
    ).to(device)
    
    print(f"\nTransformer Layer created successfully!")
    print(f"Parameters:")
    print(f"  - Input/output dimension (d_model): {d_model}")
    print(f"  - Number of heads: {num_heads}")
    print(f"  - Dimension per head: {d_model // num_heads}")
    print(f"  - Feedforward hidden dimension: {d_ff}")
    print(f"  - Dropout rate: {dropout}")
    
    # Test with dummy input
    batch_size = 4
    seq_len = 20
    dummy_input = torch.randn(batch_size, seq_len, d_model).to(device)
    
    # Create causal mask (left-to-right)
    mask = torch.tril(torch.ones(seq_len, seq_len)).unsqueeze(0).unsqueeze(0).to(device)
    
    output = transformer_layer(dummy_input, mask)
    print(f"\nTest successful!")
    print(f"  Input shape: {dummy_input.shape}")
    print(f"  Output shape: {output.shape}")


Using device: cuda:1

Transformer Layer created successfully!
Parameters:
  - Input/output dimension (d_model): 256
  - Number of heads: 4
  - Dimension per head: 64
  - Feedforward hidden dimension: 1024
  - Dropout rate: 0.1

Test successful!
  Input shape: torch.Size([4, 20, 256])
  Output shape: torch.Size([4, 20, 256])


In [5]:
import torch
import torch.nn as nn

class TMLM(nn.Module):
    """
    Transformer-based Math Language Model (TMLM)
    Complete model with Input Layer, Stacked Transformer Layers, and Output Layer
    """
    def __init__(self, vocab_size, num_layers=2, d_model=256, num_heads=4, 
                 d_ff=1024, max_seq_len=256, dropout=0.1,pad_idx=0):
        """
        Args:
            vocab_size: Size of vocabulary (230 in your case)
            num_layers: Number of transformer layers (2 for TMLM2L as per paper)
            d_model: Dimension of input embedding (256 as per paper)
            num_heads: Number of attention heads (4 as per paper)
            d_ff: Feedforward hidden dimension (1024 as per paper)
            max_seq_len: Maximum sequence length (256 as per paper)
            dropout: Dropout rate (0.1 as per paper)
        """
        super(TMLM, self).__init__()
        self.pad_idx = pad_idx
        
        # Input Layer (Embedding + Positional Encoding)
        self.input_layer = InputLayer(vocab_size, d_model, max_seq_len, dropout)
        
        # Stack of Transformer Layers
        self.transformer_layers = nn.ModuleList([
            TransformerLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])
        
        # Output Layer (Adaptive Softmax)
        self.output_layer = nn.AdaptiveLogSoftmaxWithLoss(
            d_model, vocab_size, 
            cutoffs=[vocab_size // 10, 3 * vocab_size // 10]
        )
    
    def forward(self, x, targets=None):
        """
        Args:
            x: Input token indices, shape (batch_size, seq_len)
            targets: Target token indices for training, shape (batch_size, seq_len)
        Returns:
            If targets provided: loss
            Otherwise: logits for prediction
        """
        batch_size, seq_len = x.size()
        
        # Create causal mask (left-to-right)
        mask = torch.tril(torch.ones(seq_len, seq_len)).unsqueeze(0).unsqueeze(0).to(x.device)
        
        # Input Layer
        x = self.input_layer(x)
        
        # Pass through transformer layers
        for transformer_layer in self.transformer_layers:
            x = transformer_layer(x, mask)
        
        # Output Layer
        if targets is not None:
            # Training mode: compute loss
            x = x.reshape(-1, x.size(-1))
            targets = targets.reshape(-1)
            output = self.output_layer(x, targets)
            return output.loss
        else:
            # Inference mode: return logits
            output = self.output_layer.log_prob(x)
            return output


# Test the complete TMLM model
if __name__ == "__main__":
    import pickle
    import os
    
    # Set device
    device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    
    # Dataset path
    DATA_PATH = '/home/ie643_errorcode500/errorcode500-working/Mathwritting-1000'
    
    # Load vocabulary
    with open(os.path.join(DATA_PATH, 'token2idx.pkl'), 'rb') as f:
        token2idx = pickle.load(f)
    
    vocab_size = len(token2idx)
    pad_idx = token2idx['<pad>']
    print(f"Vocabulary size: {vocab_size}")
    
    # Hyperparameters from paper
    num_layers = 2  # TMLM2L
    d_model = 256
    num_heads = 4
    d_ff = 1024
    max_seq_len = 256
    dropout = 0.1
    
    # Create TMLM model
    model = TMLM(
    vocab_size=vocab_size,
    num_layers=num_layers,
    d_model=d_model,
    num_heads=num_heads,
    d_ff=d_ff,
    max_seq_len=max_seq_len,
    dropout=dropout,
    pad_idx=pad_idx
).to(device)

    
    print(f"\nTMLM2L Model created successfully!")
    print(f"Parameters:")
    print(f"  - Number of transformer layers: {num_layers}")
    print(f"  - Vocabulary size: {vocab_size}")
    print(f"  - Embedding dimension (d_model): {d_model}")
    print(f"  - Number of heads: {num_heads}")
    print(f"  - Dimension per head: {d_model // num_heads}")
    print(f"  - Feedforward hidden dimension: {d_ff}")
    print(f"  - Maximum sequence length: {max_seq_len}")
    print(f"  - Dropout rate: {dropout}")
    
    # Count parameters
    total_params = sum(p.numel() for p in model.parameters())
    print(f"  - Total parameters: {total_params / 1e6:.2f}M")
    
    # Test with dummy input
    batch_size = 4
    seq_len = 20
    dummy_input = torch.randint(0, vocab_size, (batch_size, seq_len)).to(device)
    dummy_targets = torch.randint(0, vocab_size, (batch_size, seq_len)).to(device)
    
    # Test training mode
    model.train()
    loss = model(dummy_input, dummy_targets)
    print(f"\nTest successful!")
    print(f"  Input shape: {dummy_input.shape}")
    print(f"  Training loss: {loss.item():.4f}")


Using device: cuda:1
Vocabulary size: 230

TMLM2L Model created successfully!
Parameters:
  - Number of transformer layers: 2
  - Vocabulary size: 230
  - Embedding dimension (d_model): 256
  - Number of heads: 4
  - Dimension per head: 64
  - Feedforward hidden dimension: 1024
  - Maximum sequence length: 256
  - Dropout rate: 0.1
  - Total parameters: 1.67M

Test successful!
  Input shape: torch.Size([4, 20])
  Training loss: 7.5077


In [6]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import pickle
import os
import math
from tqdm import tqdm

# Dataset class
class LaTeXDataset(Dataset):
    """
    Dataset for LaTeX sequences
    """
    def __init__(self, csv_file, token2idx, max_seq_len=256):
        """
        Args:
            csv_file: Path to preprocessed CSV file
            token2idx: Token to index mapping
            max_seq_len: Maximum sequence length (256 as per paper)
        """
        self.df = pd.read_csv(csv_file)
        self.token2idx = token2idx
        self.max_seq_len = max_seq_len
        self.pad_idx = token2idx['<pad>']
        self.eos_idx = token2idx['<eos>']
    
    def tokenize_latex(self, latex_str):
        """Tokenize LaTeX string into individual symbols"""
        tokens = []
        i = 0
        while i < len(latex_str):
            if latex_str[i] == '\\':
                j = i + 1
                while j < len(latex_str) and latex_str[j].isalpha():
                    j += 1
                tokens.append(latex_str[i:j])
                i = j
            elif latex_str[i] in ['{', '}', '^', '_', '(', ')', '[', ']', '|', ',', '.', '=', '+', '-', '*', '/', '<', '>', '!', '&']:
                tokens.append(latex_str[i])
                i += 1
            elif latex_str[i] == ' ':
                i += 1
            else:
                tokens.append(latex_str[i])
                i += 1
        return tokens
    
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        latex_str = self.df.iloc[idx]['preprocessed_label']
        
        # Tokenize
        tokens = self.tokenize_latex(latex_str)
        
        # Convert to indices and add <eos>
        token_ids = [self.token2idx.get(token, self.pad_idx) for token in tokens]
        token_ids.append(self.eos_idx)
        
        # Truncate or pad
        if len(token_ids) > self.max_seq_len:
            token_ids = token_ids[:self.max_seq_len]
        else:
            token_ids += [self.pad_idx] * (self.max_seq_len - len(token_ids))
        
        return torch.tensor(token_ids, dtype=torch.long)


def calculate_perplexity(model, dataloader, device,pad_idx):
    """
    Calculate perplexity as per paper formula:
    Perplexity = exp(-1/N * sum(log p(x_i | x_1, ..., x_{i-1})))
    """
    model.eval()
    total_loss = 0
    total_tokens = 0
    
    with torch.no_grad():
        for batch in dataloader:
            batch = batch.to(device)
            
            # Input: all tokens except last, Target: all tokens except first
            inputs = batch[:, :-1]
            targets = batch[:, 1:]
            
            # Forward pass
            loss = model(inputs, targets)
            
            # Count non-padding tokens
            mask = (targets != pad_idx).float()
            n_tokens = mask.sum().item()
            
            total_loss += loss.item() * n_tokens
            total_tokens += n_tokens
    
    avg_loss = total_loss / total_tokens
    perplexity = math.exp(avg_loss)
    
    return perplexity


def train_tmlm():
    """
    Train TMLM model with exact settings from paper
    """
    # Set device
    device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}\n")
    
    # Load vocabulary
    DATA_PATH = '/home/ie643_errorcode500/errorcode500-working/Mathwritting-1000'
    with open(os.path.join(DATA_PATH, 'token2idx.pkl'), 'rb') as f:
        token2idx = pickle.load(f)
    with open(os.path.join(DATA_PATH, 'idx2token.pkl'), 'rb') as f:
        idx2token = pickle.load(f)
    
    vocab_size = len(token2idx)
    pad_idx = token2idx['<pad>']

    
    # Hyperparameters from paper
    num_layers = 2  # TMLM2L
    d_model = 256
    num_heads = 4
    d_ff = 1024
    max_seq_len = 256
    dropout = 0.1
    learning_rate = 1e-5  # 10^-5 as per paper
    batch_size = 32
    num_epochs = 50
    
    print(f"Model Configuration:")
    print(f"  - Number of layers: {num_layers} (TMLM2L)")
    print(f"  - Vocabulary size: {vocab_size}")
    print(f"  - d_model: {d_model}")
    print(f"  - Number of heads: {num_heads}")
    print(f"  - Feedforward hidden: {d_ff}")
    print(f"  - Max sequence length: {max_seq_len}")
    print(f"  - Dropout: {dropout}")
    print(f"  - Learning rate: {learning_rate}")
    print(f"  - Batch size: {batch_size}")
    print(f"  - Epochs: {num_epochs}\n")
    
    # Create datasets
    train_dataset = LaTeXDataset(
        os.path.join(DATA_PATH, 'train_database_preprocessed.csv'),
        token2idx, max_seq_len
    )
    val_dataset = LaTeXDataset(
        os.path.join(DATA_PATH, 'val_database_preprocessed.csv'),
        token2idx, max_seq_len
    )
    test_dataset = LaTeXDataset(
        os.path.join(DATA_PATH, 'test_database_preprocessed.csv'),
        token2idx, max_seq_len
    )
    
    print(f"Dataset sizes:")
    print(f"  - Training: {len(train_dataset)}")
    print(f"  - Validation: {len(val_dataset)}")
    print(f"  - Testing: {len(test_dataset)}\n")
    
    # Create dataloaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    
    # Create model
    model = TMLM(
        vocab_size=vocab_size,
        num_layers=num_layers,
        d_model=d_model,
        num_heads=num_heads,
        d_ff=d_ff,
        max_seq_len=max_seq_len,
        dropout=dropout
    ).to(device)
    
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Total parameters: {total_params / 1e6:.2f}M\n")
    
    # Optimizer: AdamW with learning rate 10^-5 as per paper
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    
    # Training loop
    print("Starting training...\n")
    best_val_perplexity = float('inf')
    
    for epoch in range(num_epochs):
        # Training
        model.train()
        train_loss = 0
        train_tokens = 0
        
        for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Train]"):
            batch = batch.to(device)
            
            # Input: all tokens except last, Target: all tokens except first
            inputs = batch[:, :-1]
            targets = batch[:, 1:]
            
            # Forward pass
            loss = model(inputs, targets)
            
            # Count non-padding tokens
            mask = (targets != token2idx['<pad>']).float()
            n_tokens = mask.sum().item()
            
            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            train_loss += loss.item() * n_tokens
            train_tokens += n_tokens
        
        # Calculate training perplexity
        train_perplexity = math.exp(train_loss / train_tokens)
        
        # Validation
        val_perplexity = calculate_perplexity(model, val_loader, device,pad_idx)
        
        print(f"Epoch {epoch+1}/{num_epochs}:")
        print(f"  Train Perplexity: {train_perplexity:.4f}")
        print(f"  Val Perplexity: {val_perplexity:.4f}")
        
        # Save best model
        if val_perplexity < best_val_perplexity:
            best_val_perplexity = val_perplexity
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_perplexity': val_perplexity,
            }, os.path.join(DATA_PATH, 'best_tmlm2l_model.pt'))
            print(f"  *** Best model saved! ***")
        print()
    
    # Load best model and evaluate on test set
    print("\nLoading best model for testing...")
    checkpoint = torch.load(os.path.join(DATA_PATH, 'best_tmlm2l_model.pt'))
    model.load_state_dict(checkpoint['model_state_dict'])
    
    test_perplexity = calculate_perplexity(model, test_loader, device,pad_idx)
    print(f"\nTest Perplexity: {test_perplexity:.4f}")
    print(f"Paper TMLM2L Perplexity: 4.598")
    
    return model


if __name__ == "__main__":
    model = train_tmlm()


Using device: cuda:1

Model Configuration:
  - Number of layers: 2 (TMLM2L)
  - Vocabulary size: 230
  - d_model: 256
  - Number of heads: 4
  - Feedforward hidden: 1024
  - Max sequence length: 256
  - Dropout: 0.1
  - Learning rate: 1e-05
  - Batch size: 32
  - Epochs: 50

Dataset sizes:
  - Training: 1000
  - Validation: 100
  - Testing: 100

Total parameters: 1.67M

Starting training...



Epoch 1/50 [Train]: 100%|████████████████████████████████████████████████| 32/32 [00:01<00:00, 26.88it/s]


Epoch 1/50:
  Train Perplexity: 17.9829
  Val Perplexity: 5.8458
  *** Best model saved! ***



Epoch 2/50 [Train]: 100%|████████████████████████████████████████████████| 32/32 [00:00<00:00, 41.10it/s]


Epoch 2/50:
  Train Perplexity: 3.8602
  Val Perplexity: 2.3629
  *** Best model saved! ***



Epoch 3/50 [Train]: 100%|████████████████████████████████████████████████| 32/32 [00:00<00:00, 44.42it/s]


Epoch 3/50:
  Train Perplexity: 2.2733
  Val Perplexity: 1.9281
  *** Best model saved! ***



Epoch 4/50 [Train]: 100%|████████████████████████████████████████████████| 32/32 [00:01<00:00, 24.79it/s]


Epoch 4/50:
  Train Perplexity: 1.9626
  Val Perplexity: 1.7846
  *** Best model saved! ***



Epoch 5/50 [Train]: 100%|████████████████████████████████████████████████| 32/32 [00:01<00:00, 27.75it/s]


Epoch 5/50:
  Train Perplexity: 1.8395
  Val Perplexity: 1.7100
  *** Best model saved! ***



Epoch 6/50 [Train]: 100%|████████████████████████████████████████████████| 32/32 [00:00<00:00, 39.86it/s]


Epoch 6/50:
  Train Perplexity: 1.7732
  Val Perplexity: 1.6659
  *** Best model saved! ***



Epoch 7/50 [Train]: 100%|████████████████████████████████████████████████| 32/32 [00:01<00:00, 29.75it/s]


Epoch 7/50:
  Train Perplexity: 1.7219
  Val Perplexity: 1.6370
  *** Best model saved! ***



Epoch 8/50 [Train]: 100%|████████████████████████████████████████████████| 32/32 [00:01<00:00, 27.93it/s]


Epoch 8/50:
  Train Perplexity: 1.6805
  Val Perplexity: 1.6154
  *** Best model saved! ***



Epoch 9/50 [Train]: 100%|████████████████████████████████████████████████| 32/32 [00:01<00:00, 27.23it/s]


Epoch 9/50:
  Train Perplexity: 1.6657
  Val Perplexity: 1.5982
  *** Best model saved! ***



Epoch 10/50 [Train]: 100%|███████████████████████████████████████████████| 32/32 [00:00<00:00, 41.03it/s]


Epoch 10/50:
  Train Perplexity: 1.6425
  Val Perplexity: 1.5832
  *** Best model saved! ***



Epoch 11/50 [Train]: 100%|███████████████████████████████████████████████| 32/32 [00:01<00:00, 29.68it/s]


Epoch 11/50:
  Train Perplexity: 1.6249
  Val Perplexity: 1.5702
  *** Best model saved! ***



Epoch 12/50 [Train]: 100%|███████████████████████████████████████████████| 32/32 [00:01<00:00, 24.45it/s]


Epoch 12/50:
  Train Perplexity: 1.6172
  Val Perplexity: 1.5587
  *** Best model saved! ***



Epoch 13/50 [Train]: 100%|███████████████████████████████████████████████| 32/32 [00:01<00:00, 28.94it/s]


Epoch 13/50:
  Train Perplexity: 1.5937
  Val Perplexity: 1.5484
  *** Best model saved! ***



Epoch 14/50 [Train]: 100%|███████████████████████████████████████████████| 32/32 [00:00<00:00, 32.76it/s]


Epoch 14/50:
  Train Perplexity: 1.5948
  Val Perplexity: 1.5390
  *** Best model saved! ***



Epoch 15/50 [Train]: 100%|███████████████████████████████████████████████| 32/32 [00:00<00:00, 37.55it/s]


Epoch 15/50:
  Train Perplexity: 1.5749
  Val Perplexity: 1.5305
  *** Best model saved! ***



Epoch 16/50 [Train]: 100%|███████████████████████████████████████████████| 32/32 [00:00<00:00, 38.99it/s]


Epoch 16/50:
  Train Perplexity: 1.5599
  Val Perplexity: 1.5229
  *** Best model saved! ***



Epoch 17/50 [Train]: 100%|███████████████████████████████████████████████| 32/32 [00:00<00:00, 38.80it/s]


Epoch 17/50:
  Train Perplexity: 1.5600
  Val Perplexity: 1.5159
  *** Best model saved! ***



Epoch 18/50 [Train]: 100%|███████████████████████████████████████████████| 32/32 [00:01<00:00, 24.64it/s]


Epoch 18/50:
  Train Perplexity: 1.5450
  Val Perplexity: 1.5096
  *** Best model saved! ***



Epoch 19/50 [Train]: 100%|███████████████████████████████████████████████| 32/32 [00:01<00:00, 27.03it/s]


Epoch 19/50:
  Train Perplexity: 1.5523
  Val Perplexity: 1.5037
  *** Best model saved! ***



Epoch 20/50 [Train]: 100%|███████████████████████████████████████████████| 32/32 [00:01<00:00, 24.52it/s]


Epoch 20/50:
  Train Perplexity: 1.5408
  Val Perplexity: 1.4980
  *** Best model saved! ***



Epoch 21/50 [Train]: 100%|███████████████████████████████████████████████| 32/32 [00:00<00:00, 42.82it/s]


Epoch 21/50:
  Train Perplexity: 1.5281
  Val Perplexity: 1.4930
  *** Best model saved! ***



Epoch 22/50 [Train]: 100%|███████████████████████████████████████████████| 32/32 [00:00<00:00, 36.43it/s]


Epoch 22/50:
  Train Perplexity: 1.5240
  Val Perplexity: 1.4884
  *** Best model saved! ***



Epoch 23/50 [Train]: 100%|███████████████████████████████████████████████| 32/32 [00:01<00:00, 28.16it/s]


Epoch 23/50:
  Train Perplexity: 1.5289
  Val Perplexity: 1.4841
  *** Best model saved! ***



Epoch 24/50 [Train]: 100%|███████████████████████████████████████████████| 32/32 [00:00<00:00, 40.25it/s]


Epoch 24/50:
  Train Perplexity: 1.5131
  Val Perplexity: 1.4800
  *** Best model saved! ***



Epoch 25/50 [Train]: 100%|███████████████████████████████████████████████| 32/32 [00:00<00:00, 32.37it/s]


Epoch 25/50:
  Train Perplexity: 1.5117
  Val Perplexity: 1.4762
  *** Best model saved! ***



Epoch 26/50 [Train]: 100%|███████████████████████████████████████████████| 32/32 [00:00<00:00, 44.16it/s]


Epoch 26/50:
  Train Perplexity: 1.5058
  Val Perplexity: 1.4727
  *** Best model saved! ***



Epoch 27/50 [Train]: 100%|███████████████████████████████████████████████| 32/32 [00:01<00:00, 27.86it/s]


Epoch 27/50:
  Train Perplexity: 1.5092
  Val Perplexity: 1.4691
  *** Best model saved! ***



Epoch 28/50 [Train]: 100%|███████████████████████████████████████████████| 32/32 [00:01<00:00, 29.23it/s]


Epoch 28/50:
  Train Perplexity: 1.4980
  Val Perplexity: 1.4659
  *** Best model saved! ***



Epoch 29/50 [Train]: 100%|███████████████████████████████████████████████| 32/32 [00:00<00:00, 43.70it/s]


Epoch 29/50:
  Train Perplexity: 1.4964
  Val Perplexity: 1.4627
  *** Best model saved! ***



Epoch 30/50 [Train]: 100%|███████████████████████████████████████████████| 32/32 [00:01<00:00, 27.82it/s]


Epoch 30/50:
  Train Perplexity: 1.5016
  Val Perplexity: 1.4598
  *** Best model saved! ***



Epoch 31/50 [Train]: 100%|███████████████████████████████████████████████| 32/32 [00:01<00:00, 29.88it/s]


Epoch 31/50:
  Train Perplexity: 1.4849
  Val Perplexity: 1.4568
  *** Best model saved! ***



Epoch 32/50 [Train]: 100%|███████████████████████████████████████████████| 32/32 [00:01<00:00, 31.41it/s]


Epoch 32/50:
  Train Perplexity: 1.4848
  Val Perplexity: 1.4541
  *** Best model saved! ***



Epoch 33/50 [Train]: 100%|███████████████████████████████████████████████| 32/32 [00:00<00:00, 33.94it/s]


Epoch 33/50:
  Train Perplexity: 1.4826
  Val Perplexity: 1.4512
  *** Best model saved! ***



Epoch 34/50 [Train]: 100%|███████████████████████████████████████████████| 32/32 [00:00<00:00, 42.11it/s]


Epoch 34/50:
  Train Perplexity: 1.4823
  Val Perplexity: 1.4486
  *** Best model saved! ***



Epoch 35/50 [Train]: 100%|███████████████████████████████████████████████| 32/32 [00:01<00:00, 27.26it/s]


Epoch 35/50:
  Train Perplexity: 1.4747
  Val Perplexity: 1.4460
  *** Best model saved! ***



Epoch 36/50 [Train]: 100%|███████████████████████████████████████████████| 32/32 [00:00<00:00, 44.06it/s]


Epoch 36/50:
  Train Perplexity: 1.4735
  Val Perplexity: 1.4435
  *** Best model saved! ***



Epoch 37/50 [Train]: 100%|███████████████████████████████████████████████| 32/32 [00:01<00:00, 27.84it/s]


Epoch 37/50:
  Train Perplexity: 1.4717
  Val Perplexity: 1.4410
  *** Best model saved! ***



Epoch 38/50 [Train]: 100%|███████████████████████████████████████████████| 32/32 [00:00<00:00, 32.75it/s]


Epoch 38/50:
  Train Perplexity: 1.4636
  Val Perplexity: 1.4385
  *** Best model saved! ***



Epoch 39/50 [Train]: 100%|███████████████████████████████████████████████| 32/32 [00:00<00:00, 37.07it/s]


Epoch 39/50:
  Train Perplexity: 1.4708
  Val Perplexity: 1.4361
  *** Best model saved! ***



Epoch 40/50 [Train]: 100%|███████████████████████████████████████████████| 32/32 [00:00<00:00, 33.45it/s]


Epoch 40/50:
  Train Perplexity: 1.4618
  Val Perplexity: 1.4336
  *** Best model saved! ***



Epoch 41/50 [Train]: 100%|███████████████████████████████████████████████| 32/32 [00:01<00:00, 24.78it/s]


Epoch 41/50:
  Train Perplexity: 1.4591
  Val Perplexity: 1.4313
  *** Best model saved! ***



Epoch 42/50 [Train]: 100%|███████████████████████████████████████████████| 32/32 [00:01<00:00, 29.47it/s]


Epoch 42/50:
  Train Perplexity: 1.4573
  Val Perplexity: 1.4290
  *** Best model saved! ***



Epoch 43/50 [Train]: 100%|███████████████████████████████████████████████| 32/32 [00:00<00:00, 42.23it/s]


Epoch 43/50:
  Train Perplexity: 1.4487
  Val Perplexity: 1.4265
  *** Best model saved! ***



Epoch 44/50 [Train]: 100%|███████████████████████████████████████████████| 32/32 [00:01<00:00, 28.82it/s]


Epoch 44/50:
  Train Perplexity: 1.4477
  Val Perplexity: 1.4243
  *** Best model saved! ***



Epoch 45/50 [Train]: 100%|███████████████████████████████████████████████| 32/32 [00:01<00:00, 29.45it/s]


Epoch 45/50:
  Train Perplexity: 1.4447
  Val Perplexity: 1.4220
  *** Best model saved! ***



Epoch 46/50 [Train]: 100%|███████████████████████████████████████████████| 32/32 [00:01<00:00, 27.23it/s]


Epoch 46/50:
  Train Perplexity: 1.4438
  Val Perplexity: 1.4198
  *** Best model saved! ***



Epoch 47/50 [Train]: 100%|███████████████████████████████████████████████| 32/32 [00:00<00:00, 38.88it/s]


Epoch 47/50:
  Train Perplexity: 1.4451
  Val Perplexity: 1.4177
  *** Best model saved! ***



Epoch 48/50 [Train]: 100%|███████████████████████████████████████████████| 32/32 [00:00<00:00, 34.99it/s]


Epoch 48/50:
  Train Perplexity: 1.4368
  Val Perplexity: 1.4154
  *** Best model saved! ***



Epoch 49/50 [Train]: 100%|███████████████████████████████████████████████| 32/32 [00:01<00:00, 30.43it/s]


Epoch 49/50:
  Train Perplexity: 1.4359
  Val Perplexity: 1.4133
  *** Best model saved! ***



Epoch 50/50 [Train]: 100%|███████████████████████████████████████████████| 32/32 [00:01<00:00, 27.94it/s]


Epoch 50/50:
  Train Perplexity: 1.4332
  Val Perplexity: 1.4112
  *** Best model saved! ***


Loading best model for testing...

Test Perplexity: 1.3405
Paper TMLM2L Perplexity: 4.598


In [8]:
import torch
import pickle
import os
from PIL import Image

def predict_latex(model, image_path, token2idx, idx2token, max_len=256, device='cuda:1'):
    """
    Generate LaTeX prediction from a handwritten math image
    (This requires an image encoder which the paper uses SRTC+SLP)
    For now, we can test the language model on partial sequences
    """
    model.eval()
    
    # For demonstration: start with a partial LaTeX sequence and let model complete it
    # In full system, this would come from the image recognizer
    
    partial_sequence = input("Enter partial LaTeX sequence (or press enter for empty): ")
    
    if not partial_sequence:
        # Start with empty/start token
        tokens = [token2idx['<pad>']]
    else:
        # Tokenize input
        tokens = tokenize_partial(partial_sequence, token2idx)
    
    generated = []
    
    with torch.no_grad():
        for _ in range(max_len):
            # Prepare input
            input_seq = torch.tensor([tokens]).to(device)
            
            # Get model output (log probabilities)
            output = model.output_layer.log_prob(
                model.input_layer(input_seq).view(-1, 256)
            )
            
            # Get next token (greedy decoding)
            next_token = output[-1].argmax().item()
            
            # Stop if EOS token
            if idx2token[next_token] == '<eos>':
                break
                
            generated.append(idx2token[next_token])
            tokens.append(next_token)
    
    return ''.join(generated)

def tokenize_partial(latex_str, token2idx):
    """Tokenize partial LaTeX string"""
    tokens = []
    i = 0
    while i < len(latex_str):
        if latex_str[i] == '\\':
            j = i + 1
            while j < len(latex_str) and latex_str[j].isalpha():
                j += 1
            token = latex_str[i:j]
            tokens.append(token2idx.get(token, token2idx['<pad>']))
            i = j
        elif latex_str[i] in ['{', '}', '^', '_', '(', ')', '[', ']', '|', ',', '.', '=', '+', '-', '*', '/', '<', '>', '!', '&']:
            tokens.append(token2idx.get(latex_str[i], token2idx['<pad>']))
            i += 1
        elif latex_str[i] == ' ':
            i += 1
        else:
            tokens.append(token2idx.get(latex_str[i], token2idx['<pad>']))
            i += 1
    return tokens


In [9]:
import pandas as pd

def evaluate_on_test_images():
    """
    Evaluate model predictions on test images
    """
    DATA_PATH = '/home/ie643_errorcode500/errorcode500-working/Mathwritting-1000'
    
    # Load test data
    test_df = pd.read_csv(os.path.join(DATA_PATH, 'test_database_preprocessed.csv'))
    
    # Load vocabulary
    with open(os.path.join(DATA_PATH, 'token2idx.pkl'), 'rb') as f:
        token2idx = pickle.load(f)
    with open(os.path.join(DATA_PATH, 'idx2token.pkl'), 'rb') as f:
        idx2token = pickle.load(f)
    
    # Load model
    device = torch.device('cuda:1')
    model = TMLM(
        vocab_size=len(token2idx),
        num_layers=2,
        pad_idx=token2idx['<pad>']
    ).to(device)
    
    checkpoint = torch.load(os.path.join(DATA_PATH, 'best_tmlm2l_model.pt'))
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    
    # Sample 5 random test cases
    samples = test_df.sample(5)
    
    print("\n" + "="*80)
    print("TEST PREDICTIONS")
    print("="*80 + "\n")
    
    for idx, row in samples.iterrows():
        print(f"Image: {row['filename']}")
        print(f"Ground Truth: {row['preprocessed_label']}")
        print(f"Prediction: [Would require image encoder]")
        print("-"*80 + "\n")

# Run evaluation
evaluate_on_test_images()



TEST PREDICTIONS

Image: c01aa8332aca1e8e.png
Ground Truth: I=\int Fdt
Prediction: [Would require image encoder]
--------------------------------------------------------------------------------

Image: 8e70d3a0a7a20e33.png
Ground Truth: c>\aleph_{0}
Prediction: [Would require image encoder]
--------------------------------------------------------------------------------

Image: 8a4aaa63197ac019.png
Ground Truth: \frac{\partial r_{i}}{\partial\beta_{j}}
Prediction: [Would require image encoder]
--------------------------------------------------------------------------------

Image: 7d3986665243fc8a.png
Ground Truth: \{x\}\times\mathbbR^{n}
Prediction: [Would require image encoder]
--------------------------------------------------------------------------------

Image: 41c6736570400100.png
Ground Truth: d_K^{-p/2}
Prediction: [Would require image encoder]
--------------------------------------------------------------------------------



In [10]:
import torch
import pickle
import os
import pandas as pd

def generate_latex_completion(model, start_tokens, token2idx, idx2token, max_len=50, device='cuda:1'):
    """
    Generate LaTeX completion using the trained language model
    """
    model.eval()
    
    tokens = start_tokens.copy()
    
    with torch.no_grad():
        for _ in range(max_len):
            # Prepare input (pad if needed)
            input_seq = tokens + [token2idx['<pad>']] * (256 - len(tokens))
            input_seq = input_seq[:256]
            input_tensor = torch.tensor([input_seq]).to(device)
            
            # Get embeddings through input layer
            x = model.input_layer(input_tensor)
            
            # Pass through transformer layers
            seq_len = input_tensor.size(1)
            mask = torch.tril(torch.ones(seq_len, seq_len)).unsqueeze(0).unsqueeze(0).to(device)
            
            for transformer_layer in model.transformer_layers:
                x = transformer_layer(x, mask)
            
            # Get logits for the last actual token position
            last_pos = len(tokens) - 1
            logits = model.output_layer.log_prob(x[:, last_pos, :])
            
            # Get next token (greedy decoding)
            next_token = logits.argmax().item()
            
            # Stop if EOS or pad token
            if idx2token[next_token] in ['<eos>', '<pad>']:
                break
            
            tokens.append(next_token)
    
    # Convert tokens to string
    result = []
    for token_id in tokens:
        if idx2token[token_id] not in ['<pad>', '<eos>']:
            result.append(idx2token[token_id])
    
    return ''.join(result)


def tokenize_latex(latex_str, token2idx):
    """Tokenize LaTeX string"""
    tokens = []
    i = 0
    while i < len(latex_str):
        if latex_str[i] == '\\':
            j = i + 1
            while j < len(latex_str) and latex_str[j].isalpha():
                j += 1
            token = latex_str[i:j]
            tokens.append(token2idx.get(token, token2idx['<pad>']))
            i = j
        elif latex_str[i] in ['{', '}', '^', '_', '(', ')', '[', ']', '|', ',', '.', '=', '+', '-', '*', '/', '<', '>', '!', '&']:
            tokens.append(token2idx.get(latex_str[i], token2idx['<pad>']))
            i += 1
        elif latex_str[i] == ' ':
            i += 1
        else:
            tokens.append(token2idx.get(latex_str[i], token2idx['<pad>']))
            i += 1
    return tokens


def test_language_model_capabilities():
    """
    Test the trained language model's capabilities
    """
    DATA_PATH = '/home/ie643_errorcode500/errorcode500-working/Mathwritting-1000'
    
    # Load vocabulary
    with open(os.path.join(DATA_PATH, 'token2idx.pkl'), 'rb') as f:
        token2idx = pickle.load(f)
    with open(os.path.join(DATA_PATH, 'idx2token.pkl'), 'rb') as f:
        idx2token = pickle.load(f)
    
    # Load model
    device = torch.device('cuda:1')
    model = TMLM(
        vocab_size=len(token2idx),
        num_layers=2,
        d_model=256,
        num_heads=4,
        d_ff=1024,
        max_seq_len=256,
        dropout=0.1,
        pad_idx=token2idx['<pad>']
    ).to(device)
    
    checkpoint = torch.load(os.path.join(DATA_PATH, 'best_tmlm2l_model.pt'))
    model.load_state_dict(checkpoint['model_state_dict'])
    
    print("\n" + "="*80)
    print("LANGUAGE MODEL COMPLETION TEST")
    print("="*80 + "\n")
    
    # Test cases: partial LaTeX sequences
    test_cases = [
        "\\frac{x",
        "\\int",
        "a^{2",
        "\\sqrt{",
        "x="
    ]
    
    for partial in test_cases:
        tokens = tokenize_latex(partial, token2idx)
        completed = generate_latex_completion(model, tokens, token2idx, idx2token, device=device)
        print(f"Input:  {partial}")
        print(f"Output: {completed}")
        print("-" * 80)
    
    print("\n" + "="*80)
    print("PERPLEXITY COMPARISON ON TEST SAMPLES")
    print("="*80 + "\n")
    
    # Load test data
    test_df = pd.read_csv(os.path.join(DATA_PATH, 'test_database_preprocessed.csv'))
    
    # Test perplexity on individual samples
    samples = test_df.sample(5)
    
    for idx, row in samples.iterrows():
        latex = row['preprocessed_label']
        tokens = tokenize_latex(latex, token2idx)
        
        # Calculate perplexity for this sequence
        if len(tokens) > 1:
            input_seq = tokens + [token2idx['<eos>']]
            input_seq = input_seq + [token2idx['<pad>']] * (256 - len(input_seq))
            input_seq = input_seq[:256]
            
            input_tensor = torch.tensor([input_seq]).to(device)
            targets = input_tensor[:, 1:]
            inputs = input_tensor[:, :-1]
            
            model.eval()
            with torch.no_grad():
                loss = model(inputs, targets)
            
            print(f"LaTeX: {latex}")
            print(f"Loss: {loss.item():.4f}")
            print(f"Perplexity: {torch.exp(loss).item():.4f}")
            print("-" * 80)


# Run the test
test_language_model_capabilities()



LANGUAGE MODEL COMPLETION TEST

Input:  \frac{x
Output: \frac{x}}}}}}
--------------------------------------------------------------------------------
Input:  \int
Output: \int}
--------------------------------------------------------------------------------
Input:  a^{2
Output: a^{2}}}}
--------------------------------------------------------------------------------
Input:  \sqrt{
Output: \sqrt{2}}}}}
--------------------------------------------------------------------------------
Input:  x=
Output: x={2}}}}
--------------------------------------------------------------------------------

PERPLEXITY COMPARISON ON TEST SAMPLES

LaTeX: \int dx\psi_n^{*}(s)
Loss: 0.2185
Perplexity: 1.2442
--------------------------------------------------------------------------------
LaTeX: \Gamma(\frac{x}{y})\rightarrow y\Gamma(x)
Loss: 0.2238
Perplexity: 1.2508
--------------------------------------------------------------------------------
LaTeX: u:\Omega\rightarrow\mathbb{C}
Loss: 0.1557
Perplexity

In [12]:
import torch
import pickle
import os
import pandas as pd
import math

def evaluate_full_dataset(dataset_name='test'):
    """
    Evaluate model on entire test or validation dataset
    Generate detailed results with perplexity for each sample
    """
    DATA_PATH = '/home/ie643_errorcode500/errorcode500-working/Mathwritting-1000'
    
    # Load vocabulary
    with open(os.path.join(DATA_PATH, 'token2idx.pkl'), 'rb') as f:
        token2idx = pickle.load(f)
    with open(os.path.join(DATA_PATH, 'idx2token.pkl'), 'rb') as f:
        idx2token = pickle.load(f)
    
    # Load model
    device = torch.device('cuda:1')
    model = TMLM(
        vocab_size=len(token2idx),
        num_layers=2,
        d_model=256,
        num_heads=4,
        d_ff=1024,
        max_seq_len=256,
        dropout=0.1,
        pad_idx=token2idx['<pad>']
    ).to(device)
    
    checkpoint = torch.load(os.path.join(DATA_PATH, 'best_tmlm2l_model.pt'))
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    
    # Load dataset
    csv_file = os.path.join(DATA_PATH, f'{dataset_name}_database_preprocessed.csv')
    df = pd.read_csv(csv_file)
    
    print("\n" + "="*100)
    print(f"DETAILED {dataset_name.upper()} DATASET EVALUATION")
    print("="*100 + "\n")
    
    results = []
    total_loss = 0
    total_tokens = 0
    pad_idx = token2idx['<pad>']
    
    for idx, row in df.iterrows():
        latex = row['preprocessed_label']
        filename = row['filename']
        
        # Tokenize
        tokens = tokenize_latex(latex, token2idx)
        
        if len(tokens) > 0:
            # Prepare input
            input_seq = tokens + [token2idx['<eos>']]
            input_seq = input_seq + [pad_idx] * (256 - len(input_seq))
            input_seq = input_seq[:256]
            
            input_tensor = torch.tensor([input_seq]).to(device)
            targets = input_tensor[:, 1:]
            inputs = input_tensor[:, :-1]
            
            # Calculate loss
            with torch.no_grad():
                loss = model(inputs, targets)
            
            # Count non-padding tokens
            mask = (targets != pad_idx).float()
            n_tokens = mask.sum().item()
            
            # Per-sample perplexity
            sample_perplexity = math.exp(loss.item())
            
            # Accumulate
            total_loss += loss.item() * n_tokens
            total_tokens += n_tokens
            
            results.append({
                'filename': filename,
                'latex': latex,
                'loss': loss.item(),
                'perplexity': sample_perplexity,
                'num_tokens': len(tokens)
            })
            
            # Print every sample
            print(f"Sample {idx+1}/{len(df)}")
            print(f"  File: {filename}")
            print(f"  LaTeX: {latex}")
            print(f"  Tokens: {len(tokens)}")
            print(f"  Loss: {loss.item():.4f}")
            print(f"  Perplexity: {sample_perplexity:.4f}")
            print("-" * 100)
    
    # Overall statistics
    avg_loss = total_loss / total_tokens
    overall_perplexity = math.exp(avg_loss)
    
    print("\n" + "="*100)
    print(f"{dataset_name.upper()} DATASET SUMMARY")
    print("="*100)
    print(f"Total samples: {len(df)}")
    print(f"Average loss: {avg_loss:.4f}")
    print(f"Overall perplexity: {overall_perplexity:.4f}")
    print(f"Total tokens: {total_tokens}")
    
    # Statistics
    perplexities = [r['perplexity'] for r in results]
    print(f"\nPerplexity Statistics:")
    print(f"  Min: {min(perplexities):.4f}")
    print(f"  Max: {max(perplexities):.4f}")
    print(f"  Mean: {sum(perplexities)/len(perplexities):.4f}")
    print(f"  Median: {sorted(perplexities)[len(perplexities)//2]:.4f}")
    
    # Save results to CSV
    results_df = pd.DataFrame(results)
    output_file = os.path.join(DATA_PATH, f'{dataset_name}_evaluation_results.csv')
    results_df.to_csv(output_file, index=False)
    print(f"\nDetailed results saved to: {output_file}")
    
    # Show top 10 best and worst predictions
    results_df = results_df.sort_values('perplexity')
    
    print("\n" + "="*100)
    print("TOP 10 BEST PREDICTIONS (Lowest Perplexity)")
    print("="*100)
    for idx, row in results_df.head(10).iterrows():
        print(f"\nPerplexity: {row['perplexity']:.4f}")
        print(f"LaTeX: {row['latex']}")
        print(f"File: {row['filename']}")
    
    print("\n" + "="*100)
    print("TOP 10 WORST PREDICTIONS (Highest Perplexity)")
    print("="*100)
    for idx, row in results_df.tail(10).iterrows():
        print(f"\nPerplexity: {row['perplexity']:.4f}")
        print(f"LaTeX: {row['latex']}")
        print(f"File: {row['filename']}")
    
    return results_df


def tokenize_latex(latex_str, token2idx):
    """Tokenize LaTeX string"""
    tokens = []
    i = 0
    while i < len(latex_str):
        if latex_str[i] == '\\':
            j = i + 1
            while j < len(latex_str) and latex_str[j].isalpha():
                j += 1
            token = latex_str[i:j]
            tokens.append(token2idx.get(token, token2idx['<pad>']))
            i = j
        elif latex_str[i] in ['{', '}', '^', '_', '(', ')', '[', ']', '|', ',', '.', '=', '+', '-', '*', '/', '<', '>', '!', '&']:
            tokens.append(token2idx.get(latex_str[i], token2idx['<pad>']))
            i += 1
        elif latex_str[i] == ' ':
            i += 1
        else:
            tokens.append(token2idx.get(latex_str[i], token2idx['<pad>']))
            i += 1
    return tokens


# Run evaluations
print("Starting Test Dataset Evaluation...")
test_results = evaluate_full_dataset('test')

print("\n\n")
print("Starting Validation Dataset Evaluation...")
val_results = evaluate_full_dataset('val')


Starting Test Dataset Evaluation...

DETAILED TEST DATASET EVALUATION

Sample 1/100
  File: 987b9c49d6879c44.png
  LaTeX: M,w\models I^{\alpha}(e)
  Tokens: 12
  Loss: 0.2096
  Perplexity: 1.2331
----------------------------------------------------------------------------------------------------
Sample 2/100
  File: e8a00f63b814fdd6.png
  LaTeX: dX=\frac{\partial X}{\partial x}dx=F^{-1}dx=HdxordX_{M}=\frac{\partial X_{M}}{\partial x_{n}}dx_{n}
  Tokens: 59
  Loss: 0.8801
  Perplexity: 2.4111
----------------------------------------------------------------------------------------------------
Sample 3/100
  File: 2ea09c1e42b328ee.png
  LaTeX: TU=\sqrt{\frac{DU^{3}}{G*M}}
  Tokens: 20
  Loss: 0.3209
  Perplexity: 1.3784
----------------------------------------------------------------------------------------------------
Sample 4/100
  File: 0184526868aa528e.png
  LaTeX: (\frac{4}{7}-9)^{204\cdot\sqrt{5}}
  Tokens: 22
  Loss: 0.3005
  Perplexity: 1.3505
-------------------------------------

In [13]:
import torch
import pickle
import os
import pandas as pd
import math
from PIL import Image
import matplotlib.pyplot as plt

def visualize_model_understanding():
    """
    Show images alongside ground truth and model confidence
    This shows HOW WELL the model understands the correct LaTeX
    """
    DATA_PATH = '/home/ie643_errorcode500/errorcode500-working/Mathwritting-1000'
    
    # Load vocabulary
    with open(os.path.join(DATA_PATH, 'token2idx.pkl'), 'rb') as f:
        token2idx = pickle.load(f)
    with open(os.path.join(DATA_PATH, 'idx2token.pkl'), 'rb') as f:
        idx2token = pickle.load(f)
    
    # Load model
    device = torch.device('cuda:1')
    model = TMLM(
        vocab_size=len(token2idx),
        num_layers=2,
        d_model=256,
        num_heads=4,
        d_ff=1024,
        max_seq_len=256,
        dropout=0.1,
        pad_idx=token2idx['<pad>']
    ).to(device)
    
    checkpoint = torch.load(os.path.join(DATA_PATH, 'best_tmlm2l_model.pt'))
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    
    # Load test dataset
    test_df = pd.read_csv(os.path.join(DATA_PATH, 'test_database_preprocessed.csv'))
    
    print("\n" + "="*100)
    print("MODEL UNDERSTANDING VISUALIZATION")
    print("="*100)
    print("\nThis shows how CONFIDENT the model is about the CORRECT LaTeX for each image")
    print("Lower perplexity = Higher confidence that this LaTeX is correct")
    print("="*100 + "\n")
    
    # Sample 20 test cases
    samples = test_df.sample(20)
    
    results = []
    pad_idx = token2idx['<pad>']
    
    for idx, row in samples.iterrows():
        latex = row['preprocessed_label']
        filename = row['filename']
        image_path = os.path.join(DATA_PATH, 'test', filename)
        
        # Tokenize
        tokens = tokenize_latex(latex, token2idx)
        
        if len(tokens) > 0:
            # Prepare input
            input_seq = tokens + [token2idx['<eos>']]
            input_seq = input_seq + [pad_idx] * (256 - len(input_seq))
            input_seq = input_seq[:256]
            
            input_tensor = torch.tensor([input_seq]).to(device)
            targets = input_tensor[:, 1:]
            inputs = input_tensor[:, :-1]
            
            # Calculate loss
            with torch.no_grad():
                loss = model(inputs, targets)
            
            sample_perplexity = math.exp(loss.item())
            
            # Check if image exists
            image_exists = os.path.exists(image_path)
            
            print(f"{'='*100}")
            print(f"Image File: {filename}")
            if image_exists:
                print(f"Image Path: {image_path}")
                print(f"✓ Image found")
            else:
                print(f"✗ Image not found at: {image_path}")
            print(f"\nGround Truth LaTeX: {latex}")
            print(f"\nModel Confidence:")
            print(f"  Perplexity: {sample_perplexity:.4f}")
            print(f"  Loss: {loss.item():.4f}")
            
            if sample_perplexity < 1.5:
                print(f"  Assessment: ✓ VERY HIGH confidence - Model strongly agrees this LaTeX is correct")
            elif sample_perplexity < 2.0:
                print(f"  Assessment: ✓ HIGH confidence - Model thinks this LaTeX is likely correct")
            elif sample_perplexity < 3.0:
                print(f"  Assessment: ⚠ MODERATE confidence - Model is somewhat uncertain")
            else:
                print(f"  Assessment: ✗ LOW confidence - Model thinks this LaTeX might be unusual")
            
            print(f"{'='*100}\n")
            
            results.append({
                'filename': filename,
                'latex': latex,
                'perplexity': sample_perplexity,
                'image_exists': image_exists
            })
    
    # Summary
    print("\n" + "="*100)
    print("SUMMARY")
    print("="*100)
    perplexities = [r['perplexity'] for r in results]
    print(f"Average Perplexity: {sum(perplexities)/len(perplexities):.4f}")
    print(f"Best (Lowest): {min(perplexities):.4f}")
    print(f"Worst (Highest): {max(perplexities):.4f}")
    
    high_conf = sum(1 for p in perplexities if p < 1.5)
    print(f"\nHigh Confidence Predictions: {high_conf}/{len(perplexities)} ({100*high_conf/len(perplexities):.1f}%)")
    
    print("\n" + "="*100)
    print("WHAT THIS MEANS:")
    print("="*100)
    print("✓ Your Language Model has learned mathematical LaTeX patterns very well!")
    print("✓ It can identify correct vs incorrect LaTeX with high confidence")
    print("✓ To predict FROM images, you need an image encoder (like SRTC+SLP in the paper)")
    print("\n✗ Currently showing: How well the model understands the CORRECT LaTeX")
    print("✗ NOT showing: Model predictions from scratch (needs image encoder)")
    print("="*100)


def tokenize_latex(latex_str, token2idx):
    """Tokenize LaTeX string"""
    tokens = []
    i = 0
    while i < len(latex_str):
        if latex_str[i] == '\\':
            j = i + 1
            while j < len(latex_str) and latex_str[j].isalpha():
                j += 1
            token = latex_str[i:j]
            tokens.append(token2idx.get(token, token2idx['<pad>']))
            i = j
        elif latex_str[i] in ['{', '}', '^', '_', '(', ')', '[', ']', '|', ',', '.', '=', '+', '-', '*', '/', '<', '>', '!', '&']:
            tokens.append(token2idx.get(latex_str[i], token2idx['<pad>']))
            i += 1
        elif latex_str[i] == ' ':
            i += 1
        else:
            tokens.append(token2idx.get(latex_str[i], token2idx['<pad>']))
            i += 1
    return tokens


# Run visualization
visualize_model_understanding()



MODEL UNDERSTANDING VISUALIZATION

This shows how CONFIDENT the model is about the CORRECT LaTeX for each image
Lower perplexity = Higher confidence that this LaTeX is correct

Image File: 41c6736570400100.png
Image Path: /home/ie643_errorcode500/errorcode500-working/Mathwritting-1000/test/41c6736570400100.png
✓ Image found

Ground Truth LaTeX: d_K^{-p/2}

Model Confidence:
  Perplexity: 1.1772
  Loss: 0.1632
  Assessment: ✓ VERY HIGH confidence - Model strongly agrees this LaTeX is correct

Image File: d0070ee72da60d82.png
Image Path: /home/ie643_errorcode500/errorcode500-working/Mathwritting-1000/test/d0070ee72da60d82.png
✓ Image found

Ground Truth LaTeX: 4x+5y=32

Model Confidence:
  Perplexity: 1.1625
  Loss: 0.1506
  Assessment: ✓ VERY HIGH confidence - Model strongly agrees this LaTeX is correct

Image File: 9a5833d0916a8408.png
Image Path: /home/ie643_errorcode500/errorcode500-working/Mathwritting-1000/test/9a5833d0916a8408.png
✓ Image found

Ground Truth LaTeX: \frac{2.138}{1