In [1]:
import os
import shutil
import xml.etree.ElementTree as ET
import numpy as np
import cv2
from PIL import Image

def parse_inkml(file_path):
    try:
        tree = ET.parse(file_path)
        root = tree.getroot()
        traces = {}
        for trace in root.findall('{http://www.w3.org/2003/InkML}trace'):
            trace_id = str(trace.attrib['id'])
            coords = []
            for point in trace.text.strip().split(','):
                if point.strip() == '':
                    continue
                parts = point.strip().split()
                if len(parts) >= 2:
                    x = float(parts[0])
                    y = float(parts[1])
                    coords.append((x, y))
            traces[trace_id] = coords

        trace_groups = []
        for traceView in root.findall('.//{http://www.w3.org/2003/InkML}traceView'):
            trace_groups.append(traceView.attrib['traceDataRef'])

        return traces, trace_groups
    except Exception as e:
        print(f"Error parsing {file_path}: {e}")
        return None, None

def normalize_and_render(traces, trace_groups=None, img_size=256, padding=10):
    all_points = []
    keys = trace_groups if trace_groups else traces.keys()

    for t_id in keys:
        if t_id in traces:
            all_points.extend(traces[t_id])
    
    if not all_points:
        raise ValueError("No valid stroke points found.")

    all_points = np.array(all_points)
    min_x, min_y = np.min(all_points, axis=0)
    max_x, max_y = np.max(all_points, axis=0)

    scale = (img_size - 2 * padding) / max(max_x - min_x, max_y - min_y + 1e-6)
    canvas = np.ones((img_size, img_size), dtype=np.uint8) * 255

    for t_id in keys:
        if t_id in traces:
            trace = np.array(traces[t_id])
            trace -= [min_x, min_y]
            trace *= scale
            trace += padding
            trace = trace.astype(np.int32)

            for i in range(1, len(trace)):
                pt1 = tuple(trace[i - 1])
                pt2 = tuple(trace[i])
                cv2.line(canvas, pt1, pt2, color=0, thickness=2)

    return canvas

def process_inkml_folder(inkml_root_dir, user_output_dir):
    output_img_dir = os.path.join(user_output_dir, "processed_images")

    # Recreate the output directory
    if os.path.exists(user_output_dir):
        shutil.rmtree(user_output_dir)
    os.makedirs(output_img_dir, exist_ok=True)

    for root, _, files in os.walk(inkml_root_dir):
        for file in files:
            if file.endswith('.inkml'):
                file_path = os.path.join(root, file)
                traces, trace_groups = parse_inkml(file_path)
                if traces:
                    img = normalize_and_render(traces, trace_groups)
                    base_name = os.path.splitext(file)[0]
                    img_path = os.path.join(output_img_dir, f"{base_name}.png")
                    Image.fromarray(img).save(img_path)
                    print(f"Processed: {file}")

# Run the conversion
inkml_dir = '/kaggle/input/crohme2019/crohme2019/crohme2019/valid'
user_output_dir = 'results_valid'
process_inkml_folder(inkml_dir, user_output_dir)

print(f"\n✅ Done! Processed images saved in: {user_output_dir}/processed_images")


Processed: RIT_2014_124.inkml
Processed: 507_em_76.inkml
Processed: RIT_2014_33.inkml
Processed: 37_em_2.inkml
Processed: 501_em_23.inkml
Processed: 509_em_95.inkml
Processed: 501_em_22.inkml
Processed: RIT_2014_261.inkml
Processed: 511_em_262.inkml
Processed: RIT_2014_187.inkml
Processed: RIT_2014_258.inkml
Processed: 518_em_434.inkml
Processed: 511_em_265.inkml
Processed: RIT_2014_163.inkml
Processed: RIT_2014_91.inkml
Processed: RIT_2014_286.inkml
Processed: 519_em_451.inkml
Processed: 514_em_348.inkml
Processed: 512_em_276.inkml
Processed: 31_em_197.inkml
Processed: 26_em_97.inkml
Processed: 513_em_321.inkml
Processed: RIT_2014_247.inkml
Processed: 37_em_30.inkml
Processed: 514_em_326.inkml
Processed: 510_em_105.inkml
Processed: 37_em_14.inkml
Processed: RIT_2014_107.inkml
Processed: 31_em_185.inkml
Processed: 517_em_406.inkml
Processed: 507_em_73.inkml
Processed: RIT_2014_127.inkml
Processed: 502_em_16.inkml
Processed: 32_em_213.inkml
Processed: 36_em_29.inkml
Processed: 515_em_35

In [2]:
# Tokenizing (inspired from https://www.kaggle.com/code/hayamonawwar/crohme-ctc-pytorch-cur/edit)
input_files = ['/kaggle/input/crohme2019/crohme2019_train.txt']
#'/kaggle/input/crohme2019/crohme2019_valid.txt',
#'/kaggle/input/crohme2019/crohme2019_test.txt']

vocab = set()

for input_file in input_files:
    for line in open(input_file).readlines():
        if len(line.strip().split('\t')) == 2:
            vocab.update(line.strip().split('\t')[1].split())
vocab_syms = [v for v in vocab if v not in ['Above', 'Below', 'Inside', 'NoRel', 'Right', 'Sub', 'Sup']]

with open('train_vocab.txt', 'w') as f:
    f.writelines([c + '\n' for c in sorted(vocab_syms)])
    f.writelines([c + '\n' for c in ['Above', 'Below', 'Inside', 'NoRel', 'Right', 'Sub', 'Sup']])

class Vocab(object):
    def __init__(self, vocab_file=None):
        self.word2index = {}
        self.index2word = {}

        if vocab_file:
            self.load_vocab(vocab_file)
    
    def load_vocab(self, vocab_file):
        # load vocab from file
        with open(vocab_file, 'r') as f:
            for i, line in enumerate(f):
                word = line.strip()
                self.word2index[word] = i
                self.index2word[i] = word
        # add blank word
        self.word2index['<blank>'] = len(self.word2index)
        self.index2word[self.word2index['<blank>']] = '<blank>'
#vocab = Vocab(vocab_file = '/kaggle/working/results_train/train_vocab.txt')
#vocab.index2word

In [3]:
# Testing
vocab = Vocab('train_vocab.txt')
input = '- Right \\sqrt Inside 2'.split()
output = [vocab.word2index[word] for word in input]
output # Should be [4, 105, 66, 103, 9]


[4, 105, 66, 103, 9]

In [4]:
# Custom dataset
from torch.utils.data import Dataset
from PIL import Image
import os
import torch
import torch.nn as nn
import torchvision.models as models
import torch.optim as optim

class Vocab(object):
    def __init__(self, vocab_file):
        self.word2index = {}
        self.index2word = {}
        with open(vocab_file, 'r') as f:
            for i, line in enumerate(f):
                word = line.strip()
                self.word2index[word] = i
                self.index2word[i] = word
        self.word2index['<blank>'] = len(self.word2index)
        self.index2word[self.word2index['<blank>']] = '<blank>'

# Custom Dataset Class

class CROHMEDataset(Dataset):
    def __init__(self, img_dir, label_file, vocab, transform=None):
        self.img_dir = img_dir
        self.transform = transform
        self.vocab = vocab
        self.samples = []

        with open(label_file, 'r', encoding='utf-8') as f:
            for line in f:
                parts = line.strip().split('\t')
                if len(parts) != 2:
                    print(f"⚠️ Skipping malformed line: {line.strip()}")
                    # Attempt to delete corresponding image if found
                    if parts and parts[0].endswith('.png'):
                        image_path = os.path.join(self.img_dir, parts[0])
                        if os.path.exists(image_path):
                            os.remove(image_path)
                            print(f"🗑️ Deleted image: {image_path}")
                    continue

                img_name, label = parts
                self.samples.append((img_name, label))
    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_name, label = self.samples[idx]
    
        # Fix for .inkml extensions and subdirectories
        img_name = os.path.basename(img_name)  # Remove subdirectory like 'crohme2019/train/...'
        img_name = os.path.splitext(img_name)[0] + '.png'  # Ensure it's .png
    
        img_path = os.path.join(self.img_dir, img_name)
        if not os.path.exists(img_path):
            raise FileNotFoundError(f"Image not found: {img_path}")
    
        image = Image.open(img_path).convert("RGB")
    
        if self.transform:
            image = self.transform(image)
    
        label_indices = [self.vocab.word2index[word] for word in label.split()]
        return image, torch.tensor(label_indices, dtype=torch.long)
    
def collate_fn(batch):
    images, labels = zip(*batch)
    image_tensors = torch.stack(images)
    label_lengths = [len(label) for label in labels]
    padded_labels = nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=0)
    return image_tensors, padded_labels, label_lengths






        

In [5]:
from torch.utils.data import DataLoader
from torchvision import transforms

# Image transform
transform = transforms.Compose([
    transforms.Resize((100, 100)),
    transforms.RandomRotation(degrees=2),
    transforms.ColorJitter(brightness=0.1, contrast=0.1),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])


# Load vocab files
train_vocab = Vocab('/kaggle/working/train_vocab.txt')
test_vocab = Vocab('/kaggle/working/test_vocab.txt')
valid_vocab  = Vocab('/kaggle/working/valid_vocab.txt')

# Dataset paths
train_img_dir = '/kaggle/working/results_train/processed_images'
valid_img_dir = '/kaggle/working/results_valid/processed_images'
test_img_dir  = '/kaggle/working/results_test/processed_images'

train_labels = '/kaggle/input/crohme2019/crohme2019_train.txt'
valid_labels = '/kaggle/input/crohme2019/crohme2019_valid.txt'
test_labels  = '/kaggle/input/crohme2019/crohme2019_test.txt'

# Datasets
train_set = CROHMEDataset(train_img_dir, train_labels, train_vocab, transform)
valid_set = CROHMEDataset(valid_img_dir, valid_labels, valid_vocab, transform)
test_set  = CROHMEDataset(test_img_dir, test_labels, test_vocab, transform)

# DataLoaders
train_loader = DataLoader(train_set, batch_size=16, shuffle=True, collate_fn=collate_fn)
valid_loader = DataLoader(valid_set, batch_size=16, shuffle=False, collate_fn=collate_fn)
test_loader  = DataLoader(test_set, batch_size=16, shuffle=False, collate_fn=collate_fn)

# Set vocab for training loop
vocab = train_vocab


FileNotFoundError: [Errno 2] No such file or directory: '/kaggle/working/test_vocab.txt'

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score


# Function to compute Levenshtein Distance (used for WER and CER)
def levenshtein_distance(a, b):
    m = len(a) + 1
    n = len(b) + 1
    dp = np.zeros((m, n))
    
    for i in range(m):
        dp[i][0] = i
    for j in range(n):
        dp[0][j] = j
    
    for i in range(1, m):
        for j in range(1, n):
            cost = 0 if a[i - 1] == b[j - 1] else 1
            dp[i][j] = min(dp[i - 1][j] + 1, dp[i][j - 1] + 1, dp[i - 1][j - 1] + cost)
    
    return dp[m - 1][n - 1]


# Function to calculate Word Error Rate (WER)
def compute_wer(predictions, labels):
    total_words = sum(len(label.split()) for label in labels)
    errors = 0
    for pred, label in zip(predictions, labels):
        errors += levenshtein_distance(pred.split(), label.split())  # Calculate WER at word level
    return errors / total_words if total_words > 0 else 0


# Function to calculate Character Error Rate (CER)
def compute_cer(predictions, labels):
    total_chars = sum(len(label) for label in labels)
    errors = 0
    for pred, label in zip(predictions, labels):
        errors += levenshtein_distance(list(pred), list(label))  # Calculate CER at character level
    return errors / total_chars if total_chars > 0 else 0


# Model Definition (CRNN class)
class CRNN(nn.Module):
    def __init__(self, num_classes):
        super(CRNN, self).__init__()
        resnet = models.resnet18(pretrained=False)
        state_dict = torch.load("/kaggle/input/resnet18/resnet18-f37072fd.pth")
        resnet.load_state_dict(state_dict)

        # Freeze first few layers to avoid overfitting
        for param in list(resnet.children())[:5]:
            for p in param.parameters():
                p.requires_grad = False

        # Extract CNN layers up to layer3
        self.cnn = nn.Sequential(
            *list(resnet.children())[:-3],                # Keep until layer3
            nn.BatchNorm2d(256),                          # Add BatchNorm after last conv layer
            nn.Dropout2d(p=0.3)                           # Dropout after batchnorm
        )

        # Bidirectional LSTM with dropout
        self.rnn = nn.LSTM(
            input_size=256,
            hidden_size=256,
            num_layers=2,
            dropout=0.5,              # Dropout between LSTM layers
            bidirectional=True,
            batch_first=True
        )

        self.dropout_fc = nn.Dropout(p=0.3)              # Dropout before final classification
        self.fc = nn.Linear(512, num_classes)

    def forward(self, x):
        x = self.cnn(x)  # (B, 256, H, W)
        x = nn.functional.adaptive_avg_pool2d(x, (1, x.size(3)))  # (B, 256, 1, W)
        x = x.squeeze(2)  # (B, 256, W)
        x = x.permute(0, 2, 1)  # (B, W, 256)
        x, _ = self.rnn(x)
        x = self.fc(x)
        return x


# Training Code
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = CRNN(num_classes=len(vocab.word2index)).to(device)
criterion = nn.CTCLoss(blank=vocab.word2index['<blank>'], zero_infinity=True)
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Add this:
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',         # Minimize the validation loss
    factor=0.5,         # Reduce LR by a factor of 0.5
    patience=2,         # Wait for 2 epochs before reducing LR
    verbose=True        # Print LR updates
)


train_losses = []
val_losses = []
train_wer = []
val_wer = []
train_cer = []
val_cer = []

def decode(output, vocab):
    preds = output.argmax(dim=-1)  # (B, T)
    decoded = []
    for pred in preds:
        tokens = []
        prev = -1
        for p in pred:
            if p != prev and p != vocab.word2index['<blank>']:
                tokens.append(vocab.index2word[p.item()])
            prev = p
        decoded.append(" ".join(tokens))
    return decoded
    
num_epochs = 35

for epoch in range(num_epochs):
    model.train()
    epoch_train_loss = 0
    epoch_train_preds = []
    epoch_train_labels = []

    for images, labels, label_lengths in train_loader:
        images = images.to(device)
        labels = labels.to(device)
        label_lengths = torch.tensor(label_lengths, dtype=torch.long)

        output = model(images)
        output = output.permute(1, 0, 2)
        input_lengths = torch.full(size=(output.size(1),), fill_value=output.size(0), dtype=torch.long)

        loss = criterion(output, labels, input_lengths, label_lengths)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_train_loss += loss.item()

        # Decode predictions and accumulate for WER/CER calculation
        preds = decode(output, vocab)
        labels_decoded = []
        for lbl in labels:
            tokens = [vocab.index2word[l.item()] for l in lbl if l.item() != vocab.word2index['<blank>']]
            labels_decoded.append(" ".join(tokens))
        epoch_train_preds.extend(preds)
        epoch_train_labels.extend(labels_decoded)

    avg_train_loss = epoch_train_loss / len(train_loader)
    train_losses.append(avg_train_loss)

    # Compute Train WER and CER
    train_wer.append(compute_wer(epoch_train_preds, epoch_train_labels))
    train_cer.append(compute_cer(epoch_train_preds, epoch_train_labels))

    # Validation
    model.eval()
    epoch_val_loss = 0
    epoch_val_preds = []
    epoch_val_labels = []

    with torch.no_grad():
        for images, labels, label_lengths in valid_loader:
            images = images.to(device)
            labels = labels.to(device)
            label_lengths = torch.tensor(label_lengths, dtype=torch.long)

            output = model(images)
            output = output.permute(1, 0, 2)
            input_lengths = torch.full(size=(output.size(1),), fill_value=output.size(0), dtype=torch.long)

            val_loss = criterion(output, labels, input_lengths, label_lengths)
            epoch_val_loss += val_loss.item()

            # Decode predictions and accumulate for WER/CER calculation
            preds = decode(output, vocab)
            labels_decoded = [vocab.index2word[label.item()] for label in labels[0]]
            epoch_val_preds.extend(preds)
            epoch_val_labels.extend(labels_decoded)

    avg_val_loss = epoch_val_loss / len(valid_loader)
    val_losses.append(avg_val_loss)
    scheduler.step(avg_val_loss)

    # Compute Validation WER and CER
    val_wer.append(compute_wer(epoch_val_preds, epoch_val_labels))
    val_cer.append(compute_cer(epoch_val_preds, epoch_val_labels))

    print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}, Train WER: {train_wer[-1]:.4f}, Val WER: {val_wer[-1]:.4f}")

# Plotting Loss and Accuracy Metrics
plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1)
plt.plot(train_losses, label='Train Loss', marker='o')
plt.plot(val_losses, label='Validation Loss', marker='o')
plt.xlabel('Epoch')
plt.ylabel('CTC Loss')
plt.title('Training vs Validation Loss')
plt.legend()
plt.grid(True)

plt.subplot(1, 2, 2)
plt.plot(train_wer, label='Train WER', marker='o')
plt.plot(val_wer, label='Validation WER', marker='o')
plt.xlabel('Epoch')
plt.ylabel('WER')
plt.title('Train vs Validation WER')
plt.legend()
plt.grid(True)

plt.tight_layout()
plt.show()

