## 1. Load & preprocess dataset

In [None]:
from pathlib import Path
from datasets import load_from_disk
import numpy as np
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
import pickle
import torch

print("Loading dataset...")
ds = load_from_disk("/kaggle/input/handwriting/data/mathwriting")
num_samples = 40000
ds_train = ds["train"].select(range(num_samples))

# 2. Pre-allocate Image Array (Saves RAM by avoiding copies)
print(f"Pre-allocating memory for {num_samples} images...")
images_array = np.zeros((num_samples, 256, 256), dtype=np.float32)
latex_strings = []

# 3. Process Images & Collect Strings
print("Processing images and LaTeX strings...")
for i in range(num_samples):
    sample = ds_train[i]
    # Grayscale, resize to 256x256, scale to [0, 1] and write straight into
    # the pre-allocated float32 buffer.
    img = sample["image"].convert("L").resize((256, 256))
    images_array[i] = np.array(img, dtype=np.float32) / 255.0
    latex_strings.append(sample["latex"])

    if (i + 1) % 5000 == 0:
        print(f"Progress: {i + 1}/{num_samples}")

# 4. Setup Tokenizer
print("Fitting tokenizer...")
# LaTeX is case-sensitive (\Delta vs \delta, X vs x). Keras' Tokenizer
# lowercases its input by default (lower=True), also when char_level=True,
# which would merge those symbols and make uppercase unpredictable.
# NOTE: this changes the vocabulary compared to a lowercased tokenizer, so
# checkpoints trained with the old tokenizer will not match; retrain.
tokenizer = Tokenizer(char_level=True, lower=False)
tokenizer.fit_on_texts(latex_strings)

# Add special tokens after the fitted characters: <START> gets the next free
# id and <END> the one after it, so they are the two largest ids.
tokenizer.word_index["<START>"] = len(tokenizer.word_index) + 1
tokenizer.word_index["<END>"] = len(tokenizer.word_index) + 1
tokenizer.index_word[tokenizer.word_index["<START>"]] = "<START>"
tokenizer.index_word[tokenizer.word_index["<END>"]] = "<END>"

START_ID = tokenizer.word_index["<START>"]
END_ID = tokenizer.word_index["<END>"]

# 5. Sequence Padding
print("Tokenizing and padding sequences...")
sequences = tokenizer.texts_to_sequences(latex_strings)
sequences = [[START_ID] + seq + [END_ID] for seq in sequences]
# Post-pad with 0 to the longest sequence; the training loss ignores id 0.
padded_sequences = pad_sequences(sequences, padding="post")

# 6. Save Tokenizer and Vocab Info
print("Saving metadata...")
with open("/kaggle/working/latex_tokenizer256.pkl", "wb") as f:
    pickle.dump(tokenizer, f)

# +1 because id 0 is reserved for padding.
vocab_size = len(tokenizer.word_index) + 1
with open("/kaggle/working/vocab_size.txt", "w") as f:
    f.write(str(vocab_size))

# 7. Convert to Tensors and Save
#    Images alone take num_samples * 256 * 256 * 4 bytes (~10.5 GB for 40k).
print("converting to Tensors...")
# torch.from_numpy shares memory with images_array (no RAM copy);
# unsqueeze(1) adds the channel dimension -> (N, 1, 256, 256).
images_tensor = torch.from_numpy(images_array).unsqueeze(1)
tokens_tensor = torch.tensor(padded_sequences, dtype=torch.long)

print("saving tensors to disk...")
torch.save(images_tensor, "/kaggle/working/images_train256.pt")
torch.save(tokens_tensor, "/kaggle/working/tokens_train256.pt")

print("Done!")
print(f"Final Vocab Size: {vocab_size}")
print(f"Image Tensor Shape: {images_tensor.shape}")




In [None]:
# Reload the dataset and take the first 5000 samples of the validation split.
ds = load_from_disk("/kaggle/input/handwriting/data/mathwriting")
val_split = ds["val"]
ds_val = val_split.select(range(5000))

# Per-sample containers for the validation images and LaTeX strings.
images = []
sequences = []

def preprocess_image(img, target_size=(256, 256)):
    """Convert an image to a grayscale, resized, [0, 1]-scaled numpy array.

    Args:
        img: an image object with PIL-style ``convert`` and ``resize`` methods.
        target_size: size passed to ``resize``.

    Returns:
        ``np.array`` of the resized grayscale pixels divided by 255.0.
    """
    resized = img.convert("L").resize(target_size)
    pixels = np.array(resized)
    return pixels / 255.0

# Preprocess every validation sample. Each image is cast to float32 as it is
# collected: preprocess_image returns float64 (np.array(...) / 255.0), and
# keeping float64 copies only to down-cast at the end multiplies peak memory.
for sample in ds_val:
    images.append(preprocess_image(sample["image"]).astype(np.float32))
    sequences.append(sample["latex"])

images = np.stack(images)  # (N, 256, 256) float32

# Reuse the tokenizer fitted on the training split so that token ids match
# the training data and the model's vocabulary.
with open("/kaggle/working/latex_tokenizer256.pkl", "rb") as f:
    tokenizer = pickle.load(f)

START_ID = tokenizer.word_index["<START>"]
END_ID = tokenizer.word_index["<END>"]

# NOTE(review): the tokenizer is built without an oov_token, so characters
# that never occurred in the training subset are silently dropped here.
seqs = tokenizer.texts_to_sequences(sequences)
seqs = [[START_ID] + s + [END_ID] for s in seqs]

# Post-pad with 0 (the id ignored by the loss).
padded_sequences = pad_sequences(seqs, padding="post")

# (N, H, W) -> (N, 1, H, W); from_numpy shares memory instead of copying.
images_tensor = torch.from_numpy(images).unsqueeze(1)
tokens_tensor = torch.tensor(padded_sequences, dtype=torch.long)

torch.save(images_tensor, "/kaggle/working/images_val256.pt")
torch.save(tokens_tensor, "/kaggle/working/tokens_val256.pt")

## 2. Model definition

In [None]:
import torch
import torch.nn as nn
import numpy as np

class CNNEncoder(nn.Module):
    """Convolutional encoder that turns an image into a sequence of feature vectors.

    Three stages of 3x3 conv (padding=1) -> ReLU -> 2x2 max-pool reduce a
    single-channel image to ``embedding_dim`` channels at 1/8 of its height
    and width. The spatial grid is then flattened row-major so that the
    attention decoder can attend over grid positions.

    Args:
        embedding_dim: number of output channels, i.e. the feature size of
            every position in the returned sequence.
    """

    def __init__(self, embedding_dim=128):
        super().__init__()
        # Attribute names define the state_dict keys used by saved checkpoints.
        self.conv1 = nn.Conv2d(1, 32, 3, padding=1)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.conv3 = nn.Conv2d(64, embedding_dim, 3, padding=1)
        self.pool3 = nn.MaxPool2d(2, 2)

    def forward(self, x):
        """Encode ``x`` of shape (B, 1, H, W) into (B, (H//8) * (W//8), embedding_dim).

        A 256x256 input therefore yields (B, 1024, embedding_dim).
        """
        stages = (
            (self.conv1, self.pool1),
            (self.conv2, self.pool2),
            (self.conv3, self.pool3),
        )
        for conv, pool in stages:
            x = pool(torch.relu(conv(x)))
        # (B, C, H', W') -> (B, C, H'*W') -> (B, H'*W', C): one C-dim vector per cell.
        return x.flatten(2).transpose(1, 2)


print(1)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Attention(nn.Module):
    """Additive (Bahdanau-style) attention over a sequence of encoder features.

    For each encoder position i:
        score_i   = V(tanh(W1(enc_i) + W2(hidden)))
        weight_i  = softmax over positions of score_i
        context   = sum_i weight_i * enc_i

    Args:
        enc_dim: feature size of each encoder position.
        dec_units: decoder hidden size, also used as the attention inner size.
    """

    def __init__(self, enc_dim, dec_units):
        super().__init__()
        self.W1 = nn.Linear(enc_dim, dec_units)    # projects encoder features
        self.W2 = nn.Linear(dec_units, dec_units)  # projects the decoder hidden state
        self.V = nn.Linear(dec_units, 1)           # one scalar score per position

    def forward(self, encoder_features, hidden):
        """Compute the attention context for one decoder step.

        Args:
            encoder_features: (B, seq_len, enc_dim).
            hidden: (B, dec_units) current decoder hidden state.

        Returns:
            (context_vector, attention_weights) with shapes (B, enc_dim) and
            (B, seq_len, 1); the weights sum to 1 over seq_len.
        """
        # Insert a time axis so the projected hidden state broadcasts over seq_len.
        query = self.W2(hidden.unsqueeze(1))
        keys = self.W1(encoder_features)
        score = self.V(torch.tanh(keys + query))
        attention_weights = F.softmax(score, dim=1)
        weighted = attention_weights * encoder_features
        return torch.sum(weighted, dim=1), attention_weights


print(1)

In [None]:
import torch
import torch.nn as nn

class RNNDecoder(nn.Module):
    """Single-layer LSTM decoder with additive attention over encoder features.

    At every time step the current hidden state attends over
    ``encoder_features``. The resulting context vector is concatenated to the
    input token's embedding, and the LSTM is advanced by one step.

    Args:
        vocab_size: number of token ids (embedding rows and output logits).
        embedding_dim: size of the token embeddings.
        rnn_units: LSTM hidden size.
        enc_dim: feature size of each encoder position.
    """

    def __init__(self, vocab_size, embedding_dim=256, rnn_units=512, enc_dim=256):
        super().__init__()
        # Attribute names define the state_dict keys used by saved checkpoints.
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.attention = Attention(enc_dim, rnn_units)
        # The LSTM input is [token embedding, attention context], so the
        # image is visible at every step.
        self.lstm = nn.LSTM(embedding_dim + enc_dim, rnn_units, batch_first=True)
        # Maps pooled encoder features to the initial hidden state.
        self.enc_to_h = nn.Linear(enc_dim, rnn_units)
        self.fc = nn.Linear(rnn_units, vocab_size)

    def forward(self, x, encoder_features, hidden=None):
        """Run the decoder over the token ids ``x`` (teacher forcing when training).

        Args:
            x: (B, T) token ids.
            encoder_features: (B, seq_len, enc_dim) encoder output.
            hidden: optional ``(h, c)`` LSTM state, each (1, B, rnn_units).
                When None, ``h`` is ``tanh(enc_to_h(mean over positions))``
                and ``c`` is zeros.

        Returns:
            ``(logits, (h, c))``: logits of shape (B, T, vocab_size) and the
            final LSTM state, which can be fed back in to continue decoding.
        """
        _, steps = x.shape
        embedded = self.embedding(x)

        if hidden is None:
            pooled = encoder_features.mean(dim=1)
            h_init = torch.tanh(self.enc_to_h(pooled)).unsqueeze(0)
            hidden = (h_init, torch.zeros_like(h_init))
        h_t, c_t = hidden

        step_outputs = []
        for step in range(steps):
            context, _ = self.attention(encoder_features, h_t.squeeze(0))
            step_input = torch.cat([embedded[:, step, :], context], dim=-1).unsqueeze(1)
            step_out, (h_t, c_t) = self.lstm(step_input, (h_t, c_t))
            step_outputs.append(step_out)

        sequence_out = torch.cat(step_outputs, dim=1)  # (B, T, rnn_units)
        return self.fc(sequence_out), (h_t, c_t)


## 3. Training

In [None]:
import sys
from pathlib import Path
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# Hyperparameters for the final training run.
embedding_dim = 128
rnn_units = 1024
with open("/kaggle/working/vocab_size.txt") as f:
    VOCAB_SIZE = int(f.read().strip())
BATCH_SIZE = 16
EPOCHS = 4
learning_rate = 0.0005

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

encoder = CNNEncoder(embedding_dim=embedding_dim).to(DEVICE)
# enc_dim must equal the encoder's output channel count (embedding_dim).
decoder = RNNDecoder(
    vocab_size=VOCAB_SIZE,
    embedding_dim=embedding_dim,
    rnn_units=rnn_units,
    enc_dim=embedding_dim,
).to(DEVICE)

images = torch.load("/kaggle/working/images_train256.pt")  # (N, 1, 256, 256), values in [0, 1]
tokens = torch.load("/kaggle/working/tokens_train256.pt")  # (N, seq_len), post-padded with 0

dataset = TensorDataset(images, tokens)  # pairs each image with its token sequence
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

criterion = nn.CrossEntropyLoss(ignore_index=0)  # padding id 0 contributes no loss
params = list(encoder.parameters()) + list(decoder.parameters())
optimizer = torch.optim.Adam(params, lr=learning_rate)

for epoch in range(EPOCHS):
    encoder.train()
    decoder.train()
    total_loss = 0

    for imgs, seqs in loader:
        imgs = imgs.to(DEVICE)
        seqs = seqs.to(DEVICE)

        # Teacher forcing: feed the ground-truth tokens and predict each
        # next token. Inputs drop the last token, targets drop the first.
        input_tokens = seqs[:, :-1]   # (B, seq_len-1)
        target_tokens = seqs[:, 1:]   # (B, seq_len-1)

        optimizer.zero_grad()

        # (B, 1024, embedding_dim) for 256x256 inputs (three 2x pools -> 32x32 grid).
        image_features = encoder(imgs)

        # The decoder derives its initial hidden state from image_features.
        logits, _ = decoder(input_tokens, encoder_features=image_features)  # (B, seq_len-1, VOCAB_SIZE)

        loss = criterion(
            logits.reshape(-1, VOCAB_SIZE),  # (B * (seq_len-1), VOCAB_SIZE)
            target_tokens.reshape(-1),       # (B * (seq_len-1),)
        )

        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Epoch {epoch + 1}/{EPOCHS} | Loss: {total_loss / len(loader):.4f}")

SAVE_PATH = "/kaggle/working/handtex_model256.pt"
torch.save({
    "encoder": encoder.state_dict(),
    "decoder": decoder.state_dict()
}, SAVE_PATH)

In [None]:
import itertools
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

images_train = torch.load("/kaggle/working/images_train256.pt")
tokens_train = torch.load("/kaggle/working/tokens_train256.pt")
images_val = torch.load("/kaggle/working/images_val256.pt")
tokens_val = torch.load("/kaggle/working/tokens_val256.pt")

# Read through a context manager so the file handle is closed.
with open("/kaggle/working/vocab_size.txt") as f:
    VOCAB_SIZE = int(f.read().strip())

train_dataset = TensorDataset(images_train, tokens_train)
val_dataset = TensorDataset(images_val, tokens_val)

param_grid = {
    "embedding_dim": [256],
    "rnn_units": [512],
    "learning_rate": [0.0005],
    "batch_size": [16]
}

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
EPOCHS = 7


def _run_epoch(encoder, decoder, loader, criterion, optimizer=None):
    """Run one teacher-forced pass over ``loader`` and return the mean batch loss.

    With an ``optimizer``, the models are put in train mode and updated.
    Without one, they are put in eval mode and run with gradients disabled
    (validation).
    """
    training = optimizer is not None
    encoder.train(training)
    decoder.train(training)
    total = 0.0
    with torch.set_grad_enabled(training):
        for imgs, seqs in loader:
            imgs, seqs = imgs.to(DEVICE), seqs.to(DEVICE)
            # Predict token t+1 from tokens <= t.
            input_tokens, target_tokens = seqs[:, :-1], seqs[:, 1:]

            if training:
                optimizer.zero_grad()
            features = encoder(imgs)
            logits, _ = decoder(input_tokens, encoder_features=features)
            loss = criterion(
                logits.reshape(-1, VOCAB_SIZE),
                target_tokens.reshape(-1)
            )
            if training:
                loss.backward()
                optimizer.step()

            total += loss.item()
    return total / len(loader)


all_combinations = list(itertools.product(
    param_grid["embedding_dim"],
    param_grid["rnn_units"],
    param_grid["learning_rate"],
    param_grid["batch_size"]
))

best_val_loss = float("inf")
best_params = None

for emb_dim, rnn_units, lr, batch_size in all_combinations:
    print(f"\nTesting: emb_dim={emb_dim}, rnn_units={rnn_units}, lr={lr}, batch_size={batch_size}")

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    encoder = CNNEncoder(embedding_dim=emb_dim).to(DEVICE)
    decoder = RNNDecoder(
        vocab_size=VOCAB_SIZE,
        embedding_dim=emb_dim,
        rnn_units=rnn_units,
        enc_dim=emb_dim
    ).to(DEVICE)

    criterion = nn.CrossEntropyLoss(ignore_index=0)  # 0 is the padding id
    optimizer = torch.optim.Adam(
        list(encoder.parameters()) + list(decoder.parameters()),
        lr=lr
    )

    # A configuration is scored by its best validation epoch.
    best_epoch_val_loss = float("inf")

    for epoch in range(EPOCHS):
        train_loss = _run_epoch(encoder, decoder, train_loader, criterion, optimizer)
        val_loss = _run_epoch(encoder, decoder, val_loader, criterion)
        best_epoch_val_loss = min(best_epoch_val_loss, val_loss)

        print(
            f"Epoch {epoch + 1}/{EPOCHS} | "
            f"Train: {train_loss:.4f} | Val: {val_loss:.4f}"
        )

    if best_epoch_val_loss < best_val_loss:
        best_val_loss = best_epoch_val_loss
        best_params = {
            "embedding_dim": emb_dim,
            "rnn_units": rnn_units,
            "learning_rate": lr,
            "batch_size": batch_size
        }

    # Release this configuration's models before building the next one.
    del encoder, decoder, optimizer
    torch.cuda.empty_cache()

print("\nBest params:", best_params)
print("Best val loss:", best_val_loss)

## 4. Evals

In [None]:
import torch
from pickle import load
import sys
from pathlib import Path
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

with open("/kaggle/working/vocab_size.txt") as f:
    VOCAB_SIZE = int(f.read().strip())

MAX_LEN = 150  # hard cap on generated tokens per sample

with open("/kaggle/working/latex_tokenizer256.pkl", "rb") as f:
    tokenizer = load(f)

# Take the special-token ids from the tokenizer instead of assuming their
# position. Preprocessing appends <START> then <END> after the characters,
# so these equal VOCAB_SIZE - 2 and VOCAB_SIZE - 1.
START_TOKEN = tokenizer.word_index["<START>"]
END_TOKEN = tokenizer.word_index["<END>"]

# Architecture must match the one the checkpoint was trained with.
encoder = CNNEncoder(embedding_dim=128).to(DEVICE)
decoder = RNNDecoder(
    vocab_size=VOCAB_SIZE,  # was hard-coded to 66; must match the tokenizer
    embedding_dim=128,
    rnn_units=1024,
    enc_dim=128,
).to(DEVICE)

# NOTE(review): this loads a checkpoint from an input dataset, not the
# /kaggle/working path the training cell saves to. Confirm it was trained
# with the same tokenizer.
checkpoint = torch.load("/kaggle/input/datasets/martinvu7/cnn-lstm2/handtex_model256.pt", map_location=DEVICE)

# Fail with a clear message instead of decoding with mismatched token ids.
ckpt_vocab = checkpoint["decoder"]["embedding.weight"].shape[0]
if ckpt_vocab != VOCAB_SIZE:
    raise ValueError(
        f"Checkpoint was trained with vocab_size={ckpt_vocab}, but the current "
        f"tokenizer has vocab_size={VOCAB_SIZE}; re-run preprocessing/training "
        f"so they match."
    )
encoder.load_state_dict(checkpoint["encoder"])
decoder.load_state_dict(checkpoint["decoder"])

encoder.eval()
decoder.eval()

images = torch.load("/kaggle/working/images_val256.pt")  # (N, 1, 256, 256)
tokens = torch.load("/kaggle/working/tokens_val256.pt")  # (N, seq_len), post-padded with 0

inv_vocab = {v: k for k, v in tokenizer.word_index.items()}  # token id -> character / special token
def decode(seq):
    """Join the strings that ``inv_vocab`` maps each token id in ``seq`` to.

    Ids missing from ``inv_vocab`` contribute an empty string.
    """
    pieces = [inv_vocab.get(token_id, "") for token_id in seq]
    return "".join(pieces)

def normalized_edit_distance(pred, target):
    """Return the Levenshtein distance between two sequences divided by the longer length.

    Insertions, deletions and substitutions each cost 1. The result lies in
    [0, 1]: 0.0 for identical sequences (including two empty ones) and 1.0
    when one side is empty and the other is not.
    """
    len_pred, len_target = len(pred), len(target)
    longest = max(len_pred, len_target)
    if longest == 0:
        return 0.0

    # previous_row[j] = distance between pred[:i-1] and target[:j]
    previous_row = list(range(len_target + 1))
    for i, p_item in enumerate(pred, start=1):
        current_row = [i]
        for j, t_item in enumerate(target, start=1):
            if p_item == t_item:
                current_row.append(previous_row[j - 1])
            else:
                current_row.append(
                    1 + min(previous_row[j], current_row[j - 1], previous_row[j - 1])
                )
        previous_row = current_row
    return previous_row[len_target] / longest

# inference: greedy decoding on the first N validation samples, scored by NED
N = min(500, len(images))  # how many samples to test (never more than exist)
total = 0.0

# Padding (0) and the <START>/<END> markers are not part of the LaTeX string.
# Previously the ground truth was decoded with literal "<START>"/"<END>" text
# while predictions never contain them, inflating every NED.
special_ids = {0, START_TOKEN, END_TOKEN}

for i in range(N):
    img = images[i:i + 1].to(DEVICE)  # keep batch dimension: (1, 1, H, W)
    gt_ids = [t for t in tokens[i].tolist() if t not in special_ids]

    with torch.no_grad():
        features = encoder(img)
        input_token = torch.tensor([[START_TOKEN]], device=DEVICE)
        hidden = None
        pred_tokens = []

        for _ in range(MAX_LEN):
            # Feed one token at a time, carrying the LSTM state forward.
            logits, hidden = decoder(input_token, features, hidden=hidden)
            next_token = logits.argmax(-1)  # (1, 1): highest-scoring token id
            token_id = next_token.item()
            if token_id == END_TOKEN:
                break
            pred_tokens.append(token_id)
            input_token = next_token  # feed the prediction back in

    gt_str = decode(gt_ids)
    pred_str = decode([t for t in pred_tokens if t not in special_ids])
    ned = normalized_edit_distance(pred_str, gt_str)

    print(f"Sample {i+1}:")
    print("GT:  ", gt_str)
    print("PRED:", pred_str)
    print("NED: ", ned)
    print("-"*40)
    total += ned

mean_ned = total / N if N else float("nan")
print(f"Mean NED over {N} samples: {mean_ned:.4f}")