In [None]:
import pandas as pd
import numpy as np
import cv2
import os
from torch.utils.data import Dataset, DataLoader
import torch

def get_directional_kernels():
    """Build the eight 3x3 Sobel compass kernels.

    The previous implementation derived the "diagonal" kernels by flipping the
    axis-aligned Sobel kernel. Because that kernel is mirror-symmetric, every
    flip reproduced one of the four rotations, so only four of the eight
    kernels were distinct. The kernels are now produced by shifting the Sobel
    weights around the border of the 3x3 window in 45-degree steps.

    Returns:
        list[np.ndarray]: eight distinct 3x3 integer kernels ordered
        N, NE, E, SE, S, SW, W, NW. The direction names the side of the
        window that carries the positive weights (e.g. the N kernel has
        ``[1, 2, 1]`` in its top row). Every kernel sums to zero.
    """
    # Border cells of the 3x3 window, walked clockwise from the top-left corner.
    ring = [(0, 0), (0, 1), (0, 2), (1, 2), (2, 2), (2, 1), (2, 0), (1, 0)]
    # Sobel weights of the north kernel laid out along that ring.
    north_weights = [1, 2, 1, 0, -1, -2, -1, 0]

    kernels = []
    for step in range(len(ring)):
        kern = np.zeros((3, 3), dtype=int)  # centre weight stays 0
        for pos, (row, col) in enumerate(ring):
            # Shifting the weights by one ring position rotates the kernel by 45 degrees.
            kern[row, col] = north_weights[(pos - step) % len(ring)]
        kernels.append(kern)
    return kernels

def get_directional_maps(image):
    """Compute one normalised edge-response map per directional kernel.

    Each kernel from ``get_directional_kernels()`` is applied with
    ``cv2.filter2D`` (float32 output, replicated borders). The absolute
    response is divided by its own maximum; a map whose maximum is not above
    1e-8 is replaced by zeros.

    Args:
        image: array accepted by ``cv2.filter2D``; the dataset in this file
            passes a float32 grayscale image.

    Returns:
        np.ndarray: float32 array with the per-kernel maps stacked on axis 0.

    NOTE(review): because the absolute value is taken, two kernels that are
    negatives of each other yield identical maps.
    """
    def _normalised_response(kern):
        response = np.abs(
            cv2.filter2D(image, cv2.CV_32F, kern, borderType=cv2.BORDER_REPLICATE)
        )
        peak = response.max()
        if peak > 1e-8:
            return (response / peak).astype(np.float32)
        # Flat response: avoid dividing by (almost) zero.
        return np.zeros_like(response, dtype=np.float32)

    return np.stack(
        [_normalised_response(kern) for kern in get_directional_kernels()],
        axis=0,
    )

class MathEquation9ChDataset(Dataset):
    """Dataset yielding 9-channel images (grayscale + 8 edge maps) and labels.

    Args:
        csv_file: path of a CSV with at least the columns ``image_path`` and
            ``normalized_label``.
        dataset_root: directory the ``image_path`` entries are relative to.
        transform: optional callable applied to the image tensor.
    """

    def __init__(self, csv_file, dataset_root, transform=None):
        self.data_frame = pd.read_csv(csv_file)
        self.dataset_root = dataset_root
        self.transform = transform

        # The CSV may carry Windows-style separators (e.g. "train\\x.png");
        # store every path with forward slashes.
        self.data_frame['image_path'] = self.data_frame['image_path'].apply(
            lambda x: os.path.normpath(x).replace('\\', '/')
        )

    def __len__(self):
        """Return the number of rows in the CSV."""
        return len(self.data_frame)

    def __getitem__(self, idx):
        """Load sample ``idx``.

        Returns:
            dict: ``'image'`` is a float32 tensor of shape (9, H, W) where
            channel 0 is the grayscale image scaled to [0, 1] and channels
            1-8 come from ``get_directional_maps``; ``'label'`` is the
            ``normalized_label`` entry as a ``str``.

        Raises:
            FileNotFoundError: if OpenCV cannot read the image file.
        """
        row = self.data_frame.iloc[idx]  # single row lookup for both fields
        img_full_path = os.path.join(self.dataset_root, row['image_path'])
        img_full_path = os.path.normpath(img_full_path).replace('\\', '/')

        # cv2.imread signals failure by returning None instead of raising.
        image = cv2.imread(img_full_path, cv2.IMREAD_GRAYSCALE)
        if image is None:
            raise FileNotFoundError(f"Image not found: {img_full_path}")
        image = image.astype(np.float32) / 255.0
        H, W = image.shape

        # 9 channel construction: grayscale base followed by 8 direction maps.
        channels = np.zeros((9, H, W), dtype=np.float32)
        channels[0] = image
        channels[1:] = get_directional_maps(image)

        # The vocabulary cell builds its character set from
        # ``normalized_label.astype(str)``; coerce here as well so a missing
        # label arrives as the string 'nan' rather than a float that breaks
        # batch collation and label encoding.
        label = str(row['normalized_label'])

        # `channels` is a fresh float32 array, so it can back the tensor
        # directly without an extra copy.
        sample = {'image': torch.from_numpy(channels), 'label': label}
        if self.transform:
            sample['image'] = self.transform(sample['image'])
        return sample

# Usage: build the training dataset/loader and pull one batch as a smoke test.
DATASET_ROOT = '/Users/parvjain/Downloads/ProccessMathwritting-exercpt'
TRAIN_CSV = os.path.join(DATASET_ROOT, 'train_database.csv')

# Report what the CSV contains (or that it is missing) before using it.
if not os.path.exists(TRAIN_CSV):
    print(f"CSV file not found at {TRAIN_CSV}")
else:
    df = pd.read_csv(TRAIN_CSV)
    print("CSV file loaded successfully")
    print("Columns:", df.columns.tolist())
    print("\nFirst few image paths:")
    print(df['image_path'].head())

train_dataset = MathEquation9ChDataset(TRAIN_CSV, DATASET_ROOT)
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)

# Only the first batch is inspected; any loading error is printed, not raised.
try:
    for batch in train_loader:
        images = batch['image']
        labels = batch['label']
        print(f"\nBatch loaded successfully")
        print(f"Image tensor shape: {images.shape}")
        print("First few labels:", labels[:3])
        break
except Exception as e:
    print(f"\nError loading batch: {str(e)}")

CSV file loaded successfully
Columns: ['image_path', 'sample_id', 'label', 'normalized_label', 'split', 'ink_creation_method', 'label_creation_method', 'original_path', 'is_symbol']

First few image paths:
0    train\000aa4c444cba3f2.png
1    train\004970a2ad0fcb27.png
2    train\0050464363a7d02d.png
3    train\0053f4751a1d9065.png
4    train\005f0a6b379cc5db.png
Name: image_path, dtype: object

Batch loaded successfully
Image tensor shape: torch.Size([8, 9, 480, 1600])
First few labels: ['\\vec{p}_{0}=(0,1,0)', '\\underline{P}X=\\emptyset', '\\Rightarrow ln\\frac{x(t)}{x(0)}=kt']


In [33]:
import pandas as pd
from collections import Counter

# Gather the normalized labels of every split (train/val/test) as strings.
csv_files = ['train_database.csv', 'val_database.csv', 'test_database.csv']
DATASET_ROOT = '/Users/parvjain/Downloads/ProccessMathwritting-exercpt'

all_labels = []
for csv_file in csv_files:
    df = pd.read_csv(os.path.join(DATASET_ROOT, csv_file))
    all_labels += df['normalized_label'].astype(str).tolist()

# Character-level vocabulary: the special tokens take indices 0-2 and the
# observed characters follow in sorted order.
special_tokens = ['<PAD>', '<SOS>', '<EOS>']
char_counter = Counter()
for label in all_labels:
    char_counter.update(label)  # a str iterates character by character

vocab = special_tokens + sorted(char_counter)
char2idx = {ch: idx for idx, ch in enumerate(vocab)}
idx2char = dict(enumerate(vocab))

print(f"Vocabulary size: {len(vocab)}")
print("First 20 tokens:", vocab[:20])

# Encode a label string to indices
def encode_label(label, max_len=128, char_to_index=None):
    """Encode a label string as a fixed-length list of token indices.

    The result is ``<SOS>`` + characters + ``<EOS>``, right-padded with
    ``<PAD>`` to exactly ``max_len`` entries. A label that does not fit is
    truncated in its character part, so ``<EOS>`` is kept (the previous
    implementation sliced the finished sequence and dropped ``<EOS>``).

    Args:
        label: the string to encode.
        max_len: length of the returned list.
        char_to_index: optional token-to-index mapping containing the three
            special tokens; defaults to the notebook-level ``char2idx``.

    Returns:
        list[int]: ``max_len`` token indices.

    Raises:
        KeyError: if a kept character is missing from the mapping.
    """
    mapping = char2idx if char_to_index is None else char_to_index
    limit = max(max_len, 0)

    # Reserve two slots for <SOS> and <EOS>.
    body = [mapping[ch] for ch in label[:max(limit - 2, 0)]]
    # The slice only matters for max_len < 2, where both markers cannot fit.
    tokens = ([mapping['<SOS>']] + body + [mapping['<EOS>']])[:limit]
    tokens += [mapping['<PAD>']] * (limit - len(tokens))
    return tokens

# Example usage: encode the first collected label and show the original
# string next to the first 20 token indices of its encoding.
sample_label = all_labels[0]
encoded = encode_label(sample_label)
print("Original label:", sample_label)
print("Encoded:", encoded[:20])

# For your dataset class, you can add:
# label_indices = encode_label(label)
# sample = {'image': image_tensor, 'label': label_indices}


Vocabulary size: 91
First 20 tokens: ['<PAD>', '<SOS>', '<EOS>', ' ', '!', '#', '&', '(', ')', '*', '+', ',', '-', '.', '/', '0', '1', '2', '3', '4']
Original label: \vartheta=-\frac{log\frac{\phi_{\varsigma_{1}}}{\phi_{\varsigma_{2}}}}{log\frac{\varsigma_{1}}{\varsigma_{2}}}
Encoded: [1, 58, 83, 62, 79, 81, 69, 66, 81, 62, 28, 12, 58, 67, 79, 62, 64, 88, 73, 76]


In [42]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ConvBlock(nn.Module):
    """Stack of ``num_layers`` 3x3 conv -> batch-norm -> ReLU stages.

    Args:
        in_channels: channels of the input to the first convolution.
        out_channels: channels produced by every convolution in the stack.
        num_layers: number of conv/BN/ReLU stages.
        dropout_p: when greater than 0, a ``Dropout2d`` with this probability
            follows each ReLU.

    The convolutions use stride 1 and padding 1, so the spatial size of the
    input is preserved.
    """

    def __init__(self, in_channels, out_channels, num_layers=4, dropout_p=0.0):
        super().__init__()
        stages = []
        stage_input = in_channels
        for _ in range(num_layers):
            stages += [
                nn.Conv2d(stage_input, out_channels, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]
            if dropout_p > 0:
                stages.append(nn.Dropout2d(p=dropout_p))
            # Every stage after the first consumes the block's own output width.
            stage_input = out_channels
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the whole stack."""
        return self.block(x)

class WatcherFCN(nn.Module):
    """Fully convolutional encoder: four ConvBlocks, each followed by 2x2 max-pooling.

    Channel widths are in_channels -> 32 -> 64 -> 64 -> 128. The four
    poolings divide each spatial dimension by 16 overall.
    """

    def __init__(self, in_channels=9):
        super().__init__()
        # Blocks 1-3 are built without dropout.
        self.block1 = ConvBlock(in_channels, 32)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.block2 = ConvBlock(32, 64)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.block3 = ConvBlock(64, 64)
        self.pool3 = nn.MaxPool2d(2, 2)
        # Only the last block uses 20% dropout.
        self.block4 = ConvBlock(64, 128, dropout_p=0.2)
        self.pool4 = nn.MaxPool2d(2, 2)

    def forward(self, x):
        """Apply the four (conv block, pooling) stages in order."""
        pipeline = (
            (self.block1, self.pool1),
            (self.block2, self.pool2),
            (self.block3, self.pool3),
            (self.block4, self.pool4),
        )
        for conv_stack, downsample in pipeline:
            x = downsample(conv_stack(x))
        return x

# Example usage: push a random batch of two 9-channel 480x1600 inputs through
# the encoder and print the shape of the resulting feature map.
model = WatcherFCN(in_channels=9)
dummy_input = torch.randn(2, 9, 480, 1600)
output = model(dummy_input)
print(output.shape) 


torch.Size([2, 128, 30, 100])


In [43]:
# Flatten the spatial grid of the feature map into a sequence of positions,
# keeping the channels as the feature dimension.
batch_size, channels, height, width = output.shape
encoder_outputs = output.permute(0, 2, 3, 1).reshape(batch_size, height * width, channels)
# encoder_outputs: [batch, height * width, channels] -> [2, 3000, 128] for the dummy input
encoder_outputs.shape

torch.Size([2, 3000, 128])

In [44]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class CoverageAttention(nn.Module):
    """Additive attention whose scores also depend on a coverage tensor.

    score = v(tanh(W_a(hidden) + U_a(encoder) + U_f(coverage))), followed by
    a softmax over the L encoder positions.
    """

    def __init__(self, encoder_dim, decoder_dim, attention_dim, coverage_dim):
        super().__init__()
        self.W_a = nn.Linear(decoder_dim, attention_dim)   # decoder state projection
        self.U_a = nn.Linear(encoder_dim, attention_dim)   # encoder feature projection
        self.U_f = nn.Linear(coverage_dim, attention_dim)  # coverage projection
        self.v = nn.Linear(attention_dim, 1)               # scalar score per position

    def forward(self, encoder_outputs, decoder_hidden, coverage):
        """Compute attention weights and the resulting context vector.

        Args:
            encoder_outputs: [batch, L, encoder_dim]
            decoder_hidden: [batch, decoder_dim]
            coverage: [batch, L, coverage_dim]

        Returns:
            tuple: ``(context, alpha)`` with context [batch, encoder_dim] and
            alpha [batch, L].
        """
        state_term = self.W_a(decoder_hidden)[:, None, :]   # [batch, 1, att_dim]
        memory_term = self.U_a(encoder_outputs)             # [batch, L, att_dim]
        coverage_term = self.U_f(coverage)                  # [batch, L, att_dim]

        energy = (state_term + memory_term + coverage_term).tanh()
        alpha = self.v(energy).squeeze(-1).softmax(dim=1)   # [batch, L]

        # Attention-weighted sum of the encoder features.
        context = (encoder_outputs * alpha[..., None]).sum(dim=1)
        return context, alpha

# class ParserGRUDecoder(nn.Module):
#     def __init__(self, vocab_size, encoder_dim=128, embed_dim=256, decoder_dim=256, attention_dim=256, coverage_dim=1):
#         super().__init__()
#         self.embedding = nn.Embedding(vocab_size, embed_dim)
#         self.gru = nn.GRUCell(embed_dim + encoder_dim, decoder_dim)
#         self.attention = CoverageAttention(encoder_dim, decoder_dim, attention_dim, coverage_dim)
#         self.fc = nn.Linear(decoder_dim + encoder_dim, vocab_size)

#     def forward(self, encoder_outputs, targets, max_len):
#         batch_size, L, encoder_dim = encoder_outputs.size()
#         device = encoder_outputs.device
#         coverage = torch.zeros(batch_size, L, 1, device=device)
#         inputs = torch.full((batch_size,), 1, dtype=torch.long, device=device)  # <SOS> token index
#         hidden = torch.zeros(batch_size, 256, device=device)
#         outputs = []
#         for t in range(max_len):
#             embedded = self.embedding(inputs)  # [batch, embed_dim]
#             context, alpha = self.attention(encoder_outputs, hidden, coverage)
#             gru_input = torch.cat([embedded, context], dim=1)
#             hidden = self.gru(gru_input, hidden)
#             output = self.fc(torch.cat([hidden, context], dim=1))
#             outputs.append(output)
#             # Teacher forcing: use ground truth if available
#             if targets is not None and t < targets.size(1):
#                 inputs = targets[:, t]
#             else:
#                 inputs = output.argmax(dim=1)
#             coverage = coverage + alpha.unsqueeze(-1)
#         outputs = torch.stack(outputs, dim=1)  # [batch, max_len, vocab_size]
#         return outputs


#Modified ParserGRUDecoder

class ParserGRUDecoder(nn.Module):
    """GRU decoder with coverage attention that emits one token per step.

    Args:
        vocab_size: number of output tokens.
        encoder_dim: feature size of each encoder position.
        embed_dim: token embedding size.
        decoder_dim: GRU hidden size.
        attention_dim: hidden size inside the attention module.
        coverage_dim: feature size of the coverage tensor.
    """

    def __init__(self, vocab_size, encoder_dim=128, embed_dim=256, decoder_dim=256, attention_dim=256, coverage_dim=1):
        super().__init__()
        self.decoder_dim = decoder_dim
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        # The GRU consumes the token embedding concatenated with the context vector.
        self.gru = nn.GRUCell(embed_dim + encoder_dim, decoder_dim)
        self.attention = CoverageAttention(encoder_dim, decoder_dim, attention_dim, coverage_dim)
        # Read-out over [embedding, hidden state, context].
        readout_width = decoder_dim + encoder_dim + embed_dim
        self.out = nn.Sequential(
            nn.Linear(readout_width, decoder_dim),
            nn.Tanh(),
            nn.Linear(decoder_dim, vocab_size),
        )

    def forward(self, encoder_outputs, targets, max_len):
        """Decode ``max_len`` steps.

        Args:
            encoder_outputs: [batch, L, encoder_dim]
            targets: [batch, T] token indices used for teacher forcing, or
                ``None`` to feed back the model's own arg-max predictions.
            max_len: number of decoding steps.

        Returns:
            Tensor: logits of shape [batch, max_len, vocab_size].
        """
        batch_size, num_positions, _ = encoder_outputs.size()
        device = encoder_outputs.device

        # Start state: zero coverage, zero hidden state, token index 1 (<SOS>) as input.
        coverage = torch.zeros(batch_size, num_positions, 1, device=device)
        hidden = torch.zeros(batch_size, self.decoder_dim, device=device)
        prev_tokens = torch.full((batch_size,), 1, dtype=torch.long, device=device)

        # Steps beyond the target length (or all steps without targets) are free-running.
        forced_steps = targets.size(1) if targets is not None else 0

        step_logits = []
        for t in range(max_len):
            embedded = self.embedding(prev_tokens)  # [batch, embed_dim]

            # Attend using the hidden state from before this step's GRU update.
            context, alpha = self.attention(encoder_outputs, hidden, coverage)
            hidden = self.gru(torch.cat([embedded, context], dim=1), hidden)

            logits = self.out(torch.cat([embedded, hidden, context], dim=1))
            step_logits.append(logits)

            if t < forced_steps:
                prev_tokens = targets[:, t]
            else:
                prev_tokens = logits.argmax(dim=1)

            # Accumulate this step's attention weights into the coverage tensor.
            coverage = coverage + alpha.unsqueeze(-1)

        return torch.stack(step_logits, dim=1)  # [batch, max_len, vocab_size]

# Example usage:
# encoder_outputs: [batch, L, encoder_dim] (flattened FCN output; encoder_dim is 128 for the WatcherFCN above)
# targets: [batch, max_len] (token indices)
# decoder = ParserGRUDecoder(vocab_size=len(vocab))
# outputs = decoder(encoder_outputs, targets, max_len)


The commented-out cell below is the main training code we used before. It has been replaced by the improved training code in the cell that follows it.


In [None]:
# import torch
# import torch.nn as nn
# import torch.optim as optim
# from tqdm import tqdm

# # Initialize models with proper configuration
# watcher = WatcherFCN(in_channels=9)  # 9-channel input as defined in dataset
# decoder = ParserGRUDecoder(vocab_size=len(vocab))  # vocab was defined in previous cell

# # Device configuration
# device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
# print(f"Using device: {device}")

# # Move models to device
# watcher = watcher.to(device)
# decoder = decoder.to(device)

# pad_idx = vocab.index('<PAD>')
# criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)
# optimizer = optim.Adadelta(list(watcher.parameters()) + list(decoder.parameters()))

# num_epochs = 10
# max_len = 128

# # Learning rate scheduler for better convergence
# scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=2, factor=0.5)
# best_loss = float('inf')

# # Rest of the training code remains the same...

# try:
#     for epoch in range(num_epochs):
#         watcher.train()
#         decoder.train()
#         total_loss = 0
#         batch_count = 0
        
#         # Add progress bar
#         pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}')
        
#         for batch in pbar:
#             # Move batch to device
#             images = batch['image'].to(device)
#             labels = [encode_label(lbl, max_len) for lbl in batch['label']]
#             labels = torch.tensor(labels, dtype=torch.long, device=device)

#             optimizer.zero_grad(set_to_none=True)  # More efficient than zero_grad()
            
#             try:
#                 watcher_output = watcher(images)
#                 batch_size, channels, height, width = watcher_output.shape
#                 encoder_outputs = watcher_output.permute(0, 2, 3, 1).reshape(batch_size, height * width, channels)

#                 outputs = decoder(encoder_outputs, labels, max_len)
#                 outputs = outputs.view(-1, outputs.size(-1))
#                 labels = labels.view(-1)

#                 loss = criterion(outputs, labels)
#                 loss.backward()
                
#                 # Gradient clipping
#                 torch.nn.utils.clip_grad_norm_(list(watcher.parameters()) + list(decoder.parameters()), max_norm=5.0)
#                 optimizer.step()

#                 # Update metrics
#                 total_loss += loss.item()
#                 batch_count += 1
                
#                 # Update progress bar
#                 pbar.set_postfix({'loss': f'{loss.item():.4f}'})
                
#             except RuntimeError as e:
#                 print(f"Error in batch: {str(e)}")
#                 continue

#         # Calculate average loss
#         avg_loss = total_loss / batch_count
#         print(f"\nEpoch {epoch+1}, Average Loss: {avg_loss:.4f}")
        
#         # Learning rate scheduling
#         scheduler.step(avg_loss)
        
#         # Save best model
#         if avg_loss < best_loss:
#             best_loss = avg_loss
#             torch.save({
#                 'epoch': epoch,
#                 'watcher_state_dict': watcher.state_dict(),
#                 'decoder_state_dict': decoder.state_dict(),
#                 'optimizer_state_dict': optimizer.state_dict(),
#                 'loss': best_loss,
#             }, 'best_model.pth')

# except KeyboardInterrupt:
#     print("\nTraining interrupted by user")
# except Exception as e:
#     print(f"\nError during training: {str(e)}")
# finally:
#     # Save final model
#     torch.save({
#         'watcher_state_dict': watcher.state_dict(),
#         'decoder_state_dict': decoder.state_dict(),
#         'optimizer_state_dict': optimizer.state_dict(),
#         'loss': total_loss / len(train_loader) if 'total_loss' in locals() else None,
#     }, 'final_model.pth')

Using device: mps


Epoch 1/10:   0%|          | 0/13 [00:06<?, ?it/s]


Training interrupted by user





In [45]:
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
import torch.nn.functional as F

# ==============================
# Training Configuration
# ==============================
# Pick the best available backend: CUDA first, then Apple MPS, otherwise CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# Training length, label length and best-loss tracker.
num_epochs = 10
max_len = 128
best_loss = float('inf')

watcher = WatcherFCN(in_channels=9).to(device)
decoder = ParserGRUDecoder(vocab_size=len(vocab)).to(device)

# Padding positions are excluded from the loss.
pad_idx = vocab.index('<PAD>')
criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)

# One optimizer over both networks; the scheduler halves the learning rate
# after the monitored loss stops improving for more than 2 epochs.
optimizer = optim.Adadelta(list(watcher.parameters()) + list(decoder.parameters()))
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=2, factor=0.5)

# ==============================
# Helper: Apply Weight Noise Regularization
# ==============================
def apply_weight_noise(model, std=0.01):
    """Perturb every trainable parameter of ``model`` in place.

    Zero-mean Gaussian noise with standard deviation ``std`` is added to each
    parameter that has ``requires_grad`` set; frozen parameters are left
    untouched. The update is done under ``torch.no_grad()`` and nothing is
    returned.
    """
    trainable = (param for param in model.parameters() if param.requires_grad)
    with torch.no_grad():
        for weight in trainable:
            jitter = torch.randn_like(weight)
            jitter.mul_(std)
            weight.add_(jitter)

# ==============================
# Training Loop
# ==============================
# Initialised before the loop so the `finally` block can always report a
# loss, even if training stops before the first epoch starts.
total_loss = 0
batch_count = 0

try:
    for epoch in range(num_epochs):
        watcher.train()
        decoder.train()
        total_loss = 0
        batch_count = 0

        pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}')

        for batch in pbar:
            images = batch['image'].to(device)
            labels = [encode_label(lbl, max_len) for lbl in batch['label']]
            labels = torch.tensor(labels, dtype=torch.long, device=device)

            optimizer.zero_grad(set_to_none=True)

            try:
                # Encode, then flatten the feature map to [batch, H*W, channels].
                watcher_output = watcher(images)
                batch_size, channels, height, width = watcher_output.shape
                encoder_outputs = watcher_output.permute(0, 2, 3, 1).reshape(batch_size, height * width, channels)

                # Teacher-forced decoding; logits and labels are flattened for the loss.
                outputs = decoder(encoder_outputs, labels, max_len)
                outputs = outputs.view(-1, outputs.size(-1))
                labels = labels.view(-1)

                loss = criterion(outputs, labels)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(list(watcher.parameters()) + list(decoder.parameters()), max_norm=5.0)
                optimizer.step()

                total_loss += loss.item()
                batch_count += 1
                pbar.set_postfix({'loss': f'{loss.item():.4f}'})

            except RuntimeError as e:
                # A failing batch (e.g. out of memory) is reported and skipped.
                print(f"Error in batch: {str(e)}")
                continue

        # If every batch failed there is no loss to average; dividing would
        # raise ZeroDivisionError and continuing would only repeat the failures.
        if batch_count == 0:
            print(f"\nEpoch {epoch+1}: no batch completed successfully; stopping training.")
            break

        avg_loss = total_loss / batch_count
        print(f"\nEpoch {epoch+1}, Average Loss: {avg_loss:.4f}")

        scheduler.step(avg_loss)

        # Save the best model BEFORE injecting weight noise, so the checkpoint
        # holds the weights that were actually trained rather than a perturbed copy.
        if avg_loss < best_loss:
            best_loss = avg_loss
            torch.save({
                'epoch': epoch,
                'watcher_state_dict': watcher.state_dict(),
                'decoder_state_dict': decoder.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': best_loss,
            }, 'best_model.pth')

        # Apply weight noise (annealing regularization) for the next epoch.
        apply_weight_noise(watcher, std=0.01)
        apply_weight_noise(decoder, std=0.01)

except KeyboardInterrupt:
    print("\nTraining interrupted by user")
except Exception as e:
    print(f"\nError during training: {str(e)}")
finally:
    # Always leave a checkpoint behind. The loss is averaged over the batches
    # that actually completed in the last (possibly partial) epoch.
    torch.save({
        'watcher_state_dict': watcher.state_dict(),
        'decoder_state_dict': decoder.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': total_loss / batch_count if batch_count else None,
    }, 'final_model.pth')
    print("Final model saved.")


# ==============================
# Beam Search Decoding
# ==============================
@torch.no_grad()
def beam_search_decode(watcher, decoder, image, beam_width=10, max_len=128, vocab_tokens=None):
    """Decode one image with beam search.

    Starts from ``<SOS>``, expands the ``beam_width`` most likely
    continuations of every live hypothesis at each step, and retires a
    hypothesis once it emits ``<EOS>``. Scores are summed log-probabilities.

    Fixes over the previous version:
      * the read-out goes through ``decoder.out`` with
        ``[embedded, hidden, context]``, matching ``ParserGRUDecoder``
        (the old code called a non-existent ``decoder.fc``);
      * the hidden state is sized from ``decoder.decoder_dim`` instead of a
        hard-coded 256;
      * when no hypothesis reaches ``<EOS>``, the live beams are converted to
        ``(sequence, score)`` pairs before the best one is picked (the old
        fallback compared hidden-state tensors);
      * tokens are tracked as plain ints, and tensors are created on the
        device the decoder lives on.

    Args:
        watcher: encoder module mapping [1, C, H, W] to a feature map.
        decoder: module exposing ``embedding``, ``attention``, ``gru``,
            ``out`` and ``decoder_dim``.
        image: a single image tensor without batch dimension.
        beam_width: number of hypotheses kept per step.
        max_len: maximum number of decoding steps.
        vocab_tokens: optional token list containing ``'<SOS>'`` and
            ``'<EOS>'``; defaults to the notebook-level ``vocab``.

    Returns:
        list[str]: tokens of the best hypothesis without ``<SOS>``/``<EOS>``.
    """
    tokens = vocab if vocab_tokens is None else vocab_tokens
    sos_idx = tokens.index('<SOS>')
    eos_idx = tokens.index('<EOS>')
    run_device = next(decoder.parameters()).device

    watcher.eval()
    decoder.eval()

    # Encode the image and flatten the feature map to [1, H*W, channels].
    feature_map = watcher(image.unsqueeze(0).to(run_device))
    _, channels, height, width = feature_map.shape
    encoder_outputs = feature_map.permute(0, 2, 3, 1).reshape(1, height * width, channels)

    hidden0 = torch.zeros(1, decoder.decoder_dim, device=run_device)
    coverage0 = torch.zeros(1, encoder_outputs.size(1), 1, device=run_device)

    # Live hypotheses: (token index list, hidden state, score, coverage).
    beams = [([sos_idx], hidden0, 0.0, coverage0)]
    completed = []  # finished hypotheses as (token index list, score)

    for _ in range(max_len):
        candidates = []
        for seq, h, score, cov in beams:
            last_token = torch.tensor([seq[-1]], dtype=torch.long, device=run_device)
            embedded = decoder.embedding(last_token)
            context, alpha = decoder.attention(encoder_outputs, h, cov)
            new_hidden = decoder.gru(torch.cat([embedded, context], dim=1), h)
            logits = decoder.out(torch.cat([embedded, new_hidden, context], dim=1))
            log_probs = F.log_softmax(logits, dim=1)

            # Never ask for more candidates than there are tokens.
            k = min(beam_width, log_probs.size(1))
            top_scores, top_idx = log_probs.topk(k, dim=1)
            new_cov = cov + alpha.unsqueeze(-1)
            for step_score, token in zip(top_scores[0].tolist(), top_idx[0].tolist()):
                candidates.append((seq + [token], new_hidden, score + step_score, new_cov))

        # Keep the best `beam_width` candidates; finished ones leave the beam.
        candidates.sort(key=lambda cand: cand[2], reverse=True)
        beams = []
        for cand in candidates[:beam_width]:
            if cand[0][-1] == eos_idx:
                completed.append((cand[0], cand[2]))
            else:
                beams.append(cand)

        if not beams:
            break

    if not completed:
        # Nothing reached <EOS>: fall back to the unfinished hypotheses.
        completed = [(seq, score) for seq, _, score, _ in beams]
    if not completed:
        return []

    best_seq = max(completed, key=lambda item: item[1])[0]
    return [tokens[idx] for idx in best_seq if idx not in (sos_idx, eos_idx)]


# Example usage:
# image = dataset[0]['image']
# predicted_seq = beam_search_decode(watcher, decoder, image)
# print("Predicted LaTeX:", " ".join(predicted_seq))


Using device: mps


Epoch 1/10:   8%|▊         | 1/13 [00:08<01:36,  8.02s/it]

Error in batch: MPS backend out of memory (MPS allocated: 8.28 GB, other allocations: 750.66 MB, max allowed: 9.07 GB). Tried to allocate 375.00 MB on private pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure).


Epoch 1/10:  15%|█▌        | 2/13 [00:10<00:53,  4.82s/it]

Error in batch: MPS backend out of memory (MPS allocated: 8.28 GB, other allocations: 750.66 MB, max allowed: 9.07 GB). Tried to allocate 375.00 MB on private pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure).


Epoch 1/10:  23%|██▎       | 3/13 [00:12<00:36,  3.66s/it]

Error in batch: MPS backend out of memory (MPS allocated: 8.28 GB, other allocations: 750.66 MB, max allowed: 9.07 GB). Tried to allocate 375.00 MB on private pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure).


Epoch 1/10:  31%|███       | 4/13 [00:16<00:32,  3.64s/it]

Error in batch: MPS backend out of memory (MPS allocated: 8.28 GB, other allocations: 750.66 MB, max allowed: 9.07 GB). Tried to allocate 375.00 MB on private pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure).


Epoch 1/10:  31%|███       | 4/13 [00:17<00:38,  4.27s/it]



Training interrupted by user
Final model saved.


---------------------------------------------Kana TRY TO RUN TILL HERE-------------------------------------

In [None]:
# import torch
# import torch.nn as nn
# import torch.optim as optim
# from torch.optim.lr_scheduler import ReduceLROnPlateau
# from tqdm import tqdm
# import os
# from datetime import datetime

# # Initialize models with proper configuration
# watcher = WatcherFCN(in_channels=9)
# decoder = ParserGRUDecoder(vocab_size=len(vocab))

# # Device configuration
# device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
# print(f"Using device: {device}")

# # Move models to device
# watcher = watcher.to(device)
# decoder = decoder.to(device)

# # Training configuration
# pad_idx = vocab.index('<PAD>')
# criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)
# optimizer = optim.Adadelta(list(watcher.parameters()) + list(decoder.parameters()), rho=0.95)

# num_epochs = 10
# max_len = 128
# save_dir = 'checkpoints'
# os.makedirs(save_dir, exist_ok=True)

# # Learning rate scheduler (without verbose parameter)
# scheduler = ReduceLROnPlateau(
#     optimizer,
#     mode='min',
#     patience=2,
#     factor=0.5,
#     min_lr=1e-6
# )
# best_loss = float('inf')

# try:
#     for epoch in range(num_epochs):
#         watcher.train()
#         decoder.train()
#         total_loss = 0
#         batch_count = 0
        
#         pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}')
        
#         for batch_idx, batch in enumerate(pbar):
#             images = batch['image'].to(device)
#             labels = [encode_label(lbl, max_len) for lbl in batch['label']]
#             labels = torch.tensor(labels, dtype=torch.long, device=device)

#             optimizer.zero_grad(set_to_none=True)
            
#             try:
#                 # Forward pass
#                 watcher_output = watcher(images)
#                 batch_size, channels, height, width = watcher_output.shape
#                 encoder_outputs = watcher_output.permute(0, 2, 3, 1).reshape(
#                     batch_size, height * width, channels
#                 )

#                 outputs = decoder(encoder_outputs, labels, max_len)
#                 outputs = outputs.view(-1, outputs.size(-1))
#                 labels = labels.view(-1)

#                 # Calculate loss and backprop
#                 loss = criterion(outputs, labels)
#                 loss.backward()
                
#                 # Gradient clipping
#                 torch.nn.utils.clip_grad_norm_(
#                     list(watcher.parameters()) + list(decoder.parameters()), 
#                     max_norm=5.0
#                 )
#                 optimizer.step()

#                 # Update metrics
#                 total_loss += loss.item()
#                 batch_count += 1
                
#                 # Update progress bar with more metrics
#                 pbar.set_postfix({
#                     'loss': f'{loss.item():.4f}',
#                     'avg_loss': f'{total_loss/batch_count:.4f}',
#                     'lr': f'{optimizer.param_groups[0]["lr"]:.6f}'
#                 })
                
#             except RuntimeError as e:
#                 print(f"Error in batch {batch_idx}: {str(e)}")
#                 continue

#         # Calculate average loss
#         avg_loss = total_loss / batch_count
        
#         # Track learning rate changes
#         old_lr = optimizer.param_groups[0]['lr']
#         scheduler.step(avg_loss)
#         new_lr = optimizer.param_groups[0]['lr']
        
#         print(f"\nEpoch {epoch+1}, Average Loss: {avg_loss:.4f}")
#         if old_lr != new_lr:
#             print(f"Learning rate changed: {old_lr:.2e} -> {new_lr:.2e}")
        
#         # Save best model with timestamp
#         if avg_loss < best_loss:
#             best_loss = avg_loss
#             timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
#             save_path = os.path.join(save_dir, f'model_epoch{epoch+1}_{timestamp}.pth')
#             torch.save({
#                 'epoch': epoch,
#                 'watcher_state_dict': watcher.state_dict(),
#                 'decoder_state_dict': decoder.state_dict(),
#                 'optimizer_state_dict': optimizer.state_dict(),
#                 'scheduler_state_dict': scheduler.state_dict(),
#                 'loss': best_loss,
#                 'vocab': vocab,
#                 'config': {
#                     'max_len': max_len,
#                     'vocab_size': len(vocab)
#                 }
#             }, save_path)
#             print(f"Saved best model to {save_path}")

# except KeyboardInterrupt:
#     print("\nTraining interrupted by user")
# except Exception as e:
#     print(f"\nError during training: {str(e)}")
# finally:
#     # Save final model
#     timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
#     final_path = os.path.join(save_dir, f'model_final_{timestamp}.pth')
#     torch.save({
#         'watcher_state_dict': watcher.state_dict(),
#         'decoder_state_dict': decoder.state_dict(),
#         'optimizer_state_dict': optimizer.state_dict(),
#         'scheduler_state_dict': scheduler.state_dict(),
#         'loss': total_loss / len(train_loader) if 'total_loss' in locals() else None,
#         'vocab': vocab,
#         'config': {
#             'max_len': max_len,
#             'vocab_size': len(vocab)
#         }
#     }, final_path)
#     print(f"Saved final model to {final_path}")

Using device: mps


Epoch 1/10:   8%|▊         | 1/13 [00:03<00:36,  3.02s/it]

Error in batch 0: MPS backend out of memory (MPS allocated: 6.88 GB, other allocations: 1.47 GB, max allowed: 9.07 GB). Tried to allocate 1.46 GB on private pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure).


Epoch 1/10:  15%|█▌        | 2/13 [00:05<00:31,  2.85s/it]

Error in batch 1: MPS backend out of memory (MPS allocated: 6.88 GB, other allocations: 1.47 GB, max allowed: 9.07 GB). Tried to allocate 1.46 GB on private pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure).


Epoch 1/10:  23%|██▎       | 3/13 [00:07<00:25,  2.53s/it]

Error in batch 2: MPS backend out of memory (MPS allocated: 6.88 GB, other allocations: 1.47 GB, max allowed: 9.07 GB). Tried to allocate 1.46 GB on private pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure).


Epoch 1/10:  31%|███       | 4/13 [00:10<00:23,  2.57s/it]

Error in batch 3: MPS backend out of memory (MPS allocated: 6.88 GB, other allocations: 1.47 GB, max allowed: 9.07 GB). Tried to allocate 1.46 GB on private pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure).


Epoch 1/10:  38%|███▊      | 5/13 [00:14<00:24,  3.02s/it]

Error in batch 4: MPS backend out of memory (MPS allocated: 6.88 GB, other allocations: 1.47 GB, max allowed: 9.07 GB). Tried to allocate 1.46 GB on private pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure).


Epoch 1/10:  38%|███▊      | 5/13 [00:19<00:31,  3.94s/it]



Training interrupted by user
Saved final model to checkpoints/model_final_20251022_161245.pth


In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

# Rebuild both networks with the same constructor arguments as the saved checkpoint
# so that the state dicts line up with the module parameters.
watcher = WatcherFCN(in_channels=9)  # 9-channel input, matching MathEquation9ChDataset
decoder = ParserGRUDecoder(vocab_size=len(vocab))  # `vocab` comes from an earlier cell

# Device configuration - CPU by default; change to 'cuda' or 'mps' if memory allows.
device = torch.device('cpu')
print(f"Using device: {device}")

# Load the pre-trained weights. `map_location` only decides where the checkpoint
# tensors are materialised; `load_state_dict` then copies them into the parameters
# the modules already own.
checkpoint = torch.load('best_model.pth', map_location=device)
watcher.load_state_dict(checkpoint['watcher_state_dict'])
decoder.load_state_dict(checkpoint['decoder_state_dict'])

# Move the modules themselves to `device`. Without this, switching `device` away from
# 'cpu' would send inputs (`.to(device)`) to a different device than the weights.
watcher.to(device)
decoder.to(device)

# Set models to evaluation mode. `decoder.eval()` stays the last expression so the
# cell keeps rendering the decoder's module summary as its output.
watcher.eval()
decoder.eval()

Using device: cpu


ParserGRUDecoder(
  (embedding): Embedding(91, 256)
  (gru): GRUCell(768, 256)
  (attention): CoverageAttention(
    (W_a): Linear(in_features=256, out_features=256, bias=True)
    (U_a): Linear(in_features=512, out_features=256, bias=True)
    (U_f): Linear(in_features=1, out_features=256, bias=True)
    (v): Linear(in_features=256, out_features=1, bias=True)
  )
  (fc): Linear(in_features=768, out_features=91, bias=True)
)

In [None]:
def test_model(watcher, decoder, test_csv, batch_size=4):
    """Run inference without teacher forcing over the dataset described by a CSV.

    Relies on names defined elsewhere in the notebook: `MathEquation9ChDataset`,
    `DATASET_ROOT`, `device`, `tqdm`, `decode_predictions` and `idx2char`.

    Args:
        watcher: Encoder module. It is called on the batch's 'image' tensor and its
            output is unpacked as a 4-D (batch, channels, height, width) tensor.
        decoder: Decoder module, called as `decoder(encoder_outputs, None, max_len=128)`.
        test_csv: Path of the CSV file handed to `MathEquation9ChDataset`.
        batch_size: Number of samples per `DataLoader` batch.

    Returns:
        A tuple `(all_predictions, all_targets)`: the outputs of `decode_predictions`
        and the batches' 'label' entries, both accumulated in loader order
        (`shuffle=False`).
    """
    # Load test dataset
    test_dataset = MathEquation9ChDataset(test_csv, DATASET_ROOT)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    all_predictions = []
    all_targets = []

    # Inference must run in eval mode regardless of what the caller did beforehand.
    # The previous mode is restored afterwards so that calling this in the middle of
    # a training run does not silently leave the modules in eval mode.
    watcher_was_training = watcher.training
    decoder_was_training = decoder.training
    watcher.eval()
    decoder.eval()

    try:
        with torch.no_grad():  # Disable gradient computation for inference
            for batch in tqdm(test_loader, desc='Testing'):
                images = batch['image'].to(device)
                labels = batch['label']
                device_type = images.device.type

                # Forward pass through watcher, then flatten the spatial grid into a
                # sequence of height * width feature vectors for the decoder.
                watcher_output = watcher(images)
                # Use a distinct name here: the original unpacking overwrote the
                # `batch_size` parameter with the size of the current batch.
                n_samples, channels, height, width = watcher_output.shape
                encoder_outputs = watcher_output.permute(0, 2, 3, 1).reshape(
                    n_samples, height * width, channels
                )

                # Decode without teacher forcing (no target sequence is passed)
                outputs = decoder(encoder_outputs, None, max_len=128)

                # Convert outputs to text
                predictions = decode_predictions(outputs, idx2char)
                all_predictions.extend(predictions)
                all_targets.extend(labels)

                # Free up memory: drop the per-batch tensors, then release the cached
                # allocator blocks of whichever accelerator the batch actually lived on.
                del images, watcher_output, encoder_outputs, outputs
                if device_type == 'cuda':
                    torch.cuda.empty_cache()
                elif device_type == 'mps' and hasattr(torch, 'mps'):
                    torch.mps.empty_cache()
    finally:
        watcher.train(watcher_was_training)
        decoder.train(decoder_was_training)

    return all_predictions, all_targets

In [None]:
# Test on validation set
VAL_CSV = os.path.join(DATASET_ROOT, 'val_database.csv')
predictions, targets = test_model(watcher, decoder, VAL_CSV, batch_size=4)

# Calculate metrics: exact-match accuracy, i.e. a sample only counts as correct when
# the predicted entry compares equal to its target.
correct = sum(1 for p, t in zip(predictions, targets) if p == t)
# An empty validation set would otherwise raise ZeroDivisionError here.
accuracy = correct / len(predictions) if predictions else 0.0

print("\nValidation Results:")
print(f"Total samples: {len(predictions)}")
print(f"Correct predictions: {correct}")
print(f"Accuracy: {accuracy:.4f}")

# Print sample predictions (first five pairs) for a quick visual sanity check
print("\nSample Predictions:")
for pred, target in zip(predictions[:5], targets[:5]):
    print(f"Predicted: {pred}")
    print(f"Target   : {target}")
    print()

Testing: 100%|██████████| 25/25 [08:07<00:00, 19.48s/it]


Validation Results:
Total samples: 100
Correct predictions: 0
Accuracy: 0.0000

Sample Predictions:
Predicted: \frac{\partial R}{\partial R}{\partial R}{\partial R}{\partial R}{\partial R}{\partial R}{\partial R}{\partial R}{\partial R}{\
Target   : (\begin{matrix}j\\ m+1\end{matrix})

Predicted: \frac{\partial R}{\partial R}{\partial R}{\partial R}{\partial R}{\partial R}{\partial R}{\partial R}{\partial R}{\partial R}{\
Target   : \int_{E}f(x)\mu(dx)

Predicted: \frac{\partial R}{\partial R}{\partial R}{\partial R}{\partial R}{\partial R}{\partial R}{\partial R}{\partial R}{\partial R}{\
Target   : \overline{z}=1-i\sqrt{\nu}

Predicted: \frac{\partial R}{\partial R}{\partial R}{\partial R}{\partial R}{\partial R}{\partial R}{\partial R}{\partial R}{\partial R}{\
Target   : S=\int_{M}BF

Predicted: \frac{\partial R}{\partial R}{\partial R}{\partial R}{\partial R}{\partial R}{\partial R}{\partial R}{\partial R}{\partial R}{\
Target   : Q:=\bigcup_{\sigma\in\sigma(T_{0})}Q_{\sigma}






In [None]:
def analyze_errors(predictions, targets):
    """Print a short report of the prediction/target pairs that do not match.

    Pairs are formed positionally with `zip`. For every pair whose entries differ,
    the position, both entries and `len(predicted) - len(target)` are recorded.
    The report shows the total number of mismatches followed by the first five.

    Args:
        predictions: Sequence of predicted entries (must support `len`).
        targets: Sequence of reference entries (must support `len`).

    Returns:
        None. The report is written to standard output.
    """
    mismatches = [
        {
            'index': position,
            'predicted': guess,
            'target': truth,
            'length_diff': len(guess) - len(truth),
        }
        for position, (guess, truth) in enumerate(zip(predictions, targets))
        if guess != truth
    ]

    print("\nError Analysis:")
    print(f"Total errors: {len(mismatches)}")
    print("\nSample errors:")
    # Only the first five mismatches are shown to keep the output readable.
    for record in mismatches[:5]:
        print(
            f"\nIndex: {record['index']}",
            f"Predicted: {record['predicted']}",
            f"Target   : {record['target']}",
            f"Length difference: {record['length_diff']}",
            sep="\n",
        )

# Run error analysis on the `predictions` / `targets` lists returned by test_model
# for the validation set; the report is printed, nothing is returned.
analyze_errors(predictions, targets)


Error Analysis:
Total errors: 100

Sample errors:

Index: 0
Predicted: \frac{\partial R}{\partial R}{\partial R}{\partial R}{\partial R}{\partial R}{\partial R}{\partial R}{\partial R}{\partial R}{\
Target   : (\begin{matrix}j\\ m+1\end{matrix})
Length difference: 92

Index: 1
Predicted: \frac{\partial R}{\partial R}{\partial R}{\partial R}{\partial R}{\partial R}{\partial R}{\partial R}{\partial R}{\partial R}{\
Target   : \int_{E}f(x)\mu(dx)
Length difference: 108

Index: 2
Predicted: \frac{\partial R}{\partial R}{\partial R}{\partial R}{\partial R}{\partial R}{\partial R}{\partial R}{\partial R}{\partial R}{\
Target   : \overline{z}=1-i\sqrt{\nu}
Length difference: 101

Index: 3
Predicted: \frac{\partial R}{\partial R}{\partial R}{\partial R}{\partial R}{\partial R}{\partial R}{\partial R}{\partial R}{\partial R}{\
Target   : S=\int_{M}BF
Length difference: 115

Index: 4
Predicted: \frac{\partial R}{\partial R}{\partial R}{\partial R}{\partial R}{\partial R}{\partial R}{\partial R}