<a href="https://colab.research.google.com/github/negarhonarvar/Equation-Solving-OCR/blob/main/EquationSolving_OCR.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Equation Detection and Solving with OCR
In this task, we shall classify equations into two categories:


*   Handwritten
*   Typed

Afterwards, we solve each equation and report its result.



## Libraries

In [None]:
import os
import csv
import math
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as T
from torch.utils.data import DataLoader, random_split
from torch.utils.data import Dataset
from PIL import Image
import cv2

## Drive Mount

In [None]:
# Mount Google Drive (Colab only) so the dataset stored there is reachable
# under /content/drive.
from google.colab import drive
drive.mount('/content/drive')

# Dataset layout on Drive: train/test image folders, the training metadata
# CSV, and the path the submission file is written to.
DRIVE_DIR = "/content/drive/MyDrive/OCR_data"
TRAIN_DIR = os.path.join(DRIVE_DIR, "train")
TEST_DIR = os.path.join(DRIVE_DIR, "test")
TRAIN_CSV = os.path.join(DRIVE_DIR, "train_info.csv")
SUBMISSION_CSV = os.path.join(DRIVE_DIR, "submission.csv")

Mounted at /content/drive


## HyperParameters

In [None]:
# Training settings for the typed/handwritten classifier.
BATCH_SIZE, EPOCHS = 8, 20
LR = 1e-4          # Adam learning rate
VAL_SPLIT = 0.1    # fraction of the training data held out for validation

# Run on the GPU when one is visible to PyTorch, otherwise on the CPU.
DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda"

## Data Processing

Read `train_info.csv`, which lists each training image's file path (`path`), its type (`type`: 0 = typed, 1 = handwritten) and its answer (`answer`).

In [None]:
def read_train_csv(csv_path):
    """Load the training metadata CSV at ``csv_path`` into a pandas DataFrame."""
    return pd.read_csv(csv_path)

We oversample the minority handwritten class (only 150 images) by duplicating its rows, so that the two classes are roughly balanced and the classifier is not biased toward the majority typed class.

In [None]:
def balance_dataset(df):
    """Oversample the minority ``type`` class so typed and handwritten are roughly balanced.

    Rows with ``type == 0`` (typed) and ``type == 1`` (handwritten) are
    separated. The smaller group is repeated
    ``ceil(len(larger) / len(smaller))`` times, so afterwards it is at least
    as large as the other group. The result is shuffled with a fixed seed (42)
    and re-indexed from 0. Rows whose ``type`` is neither 0 nor 1 are dropped.

    If one of the two classes has no rows, there is nothing to oversample.
    In that case the remaining rows are returned shuffled, instead of failing
    with a division by zero.

    Args:
        df: DataFrame with at least a ``type`` column.

    Returns:
        A new, shuffled DataFrame with a fresh RangeIndex.
    """
    typed_df = df[df["type"] == 0]
    handwritten_df = df[df["type"] == 1]
    if typed_df.empty or handwritten_df.empty:
        balanced = pd.concat([typed_df, handwritten_df], ignore_index=True)
    elif len(handwritten_df) < len(typed_df):
        factor = math.ceil(len(typed_df) / len(handwritten_df))
        balanced = pd.concat([typed_df] + [handwritten_df] * factor, ignore_index=True)
    else:
        factor = math.ceil(len(handwritten_df) / len(typed_df))
        balanced = pd.concat([typed_df] * factor + [handwritten_df], ignore_index=True)
    return balanced.sample(frac=1.0, random_state=42).reset_index(drop=True)

The class below implements a PyTorch `Dataset` that yields `(image, type label, answer)` samples for the typed/handwritten classification task.

In [None]:
class ExpressionTypeDataset(Dataset):
    """Equation images labelled typed (0) or handwritten (1).

    Each item is ``(image_tensor, label, answer)``:

    * ``image_tensor``: the RGB image passed through ``transform``, or only
      through ``T.ToTensor()`` when no transform is set.
    * ``label``: the ``type`` column as an int.
    * ``answer``: the raw ``answer`` column.

    Rows whose ``path`` does not exist under ``root_dir`` are discarded when
    the dataset is built.
    """

    def __init__(self, df, root_dir, transform=None):
        file_exists = df["path"].apply(
            lambda rel_path: os.path.exists(os.path.join(root_dir, rel_path))
        )
        self.df = df[file_exists].reset_index(drop=True)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, idx):
        record = self.df.iloc[idx]
        rel_path = record["path"]
        label = int(record["type"])  # 0 = typed, 1 = handwritten
        answer = record["answer"]
        image = Image.open(os.path.join(self.root_dir, rel_path)).convert("RGB")
        to_tensor = self.transform if self.transform else T.ToTensor()
        return to_tensor(image), label, answer


Data Augmentation

In [None]:
# Standard ImageNet channel statistics, shared by both pipelines.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Training pipeline: fixed 224x224 resize, light random augmentation, then
# tensor conversion and normalisation.
train_transforms = T.Compose([
    T.Resize((224, 224)),
    T.RandomHorizontalFlip(p=0.5),
    T.RandomRotation(10),
    T.ColorJitter(brightness=0.2, contrast=0.2),
    T.ToTensor(),
    T.Normalize(_IMAGENET_MEAN, _IMAGENET_STD),
])

# Evaluation pipeline: the same resize and normalisation, no randomness.
val_transforms = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize(_IMAGENET_MEAN, _IMAGENET_STD),
])

## Classification Model

In [None]:
def create_type_model(num_classes=2):
    """Build torchvision's pretrained ResNet-18 with a fresh ``num_classes``-way head.

    Every parameter stays trainable, so the whole network is fine-tuned rather
    than only the new final layer.
    """
    net = torchvision.models.resnet18(pretrained=True)
    net.requires_grad_(True)  # full fine-tuning: all backbone weights get gradients
    net.fc = nn.Linear(net.fc.in_features, num_classes)
    return net

## Character Classification

We assume a separate folder, `char_train`, with one subfolder per symbol ("0", "1", "2", ..., "9", "plus", "minus", "times", "div", "lparen", "rparen"), which is used to train a single-character classifier.

In [None]:
class SingleCharDataset(Dataset):
    """Single-symbol image dataset read from one subfolder per class.

    Every ``.png``/``.jpg``/``.jpeg`` file inside a subfolder of ``root_dir``
    is one sample of that subfolder's class.

    Labels follow the fixed order documented for this notebook, regardless of
    which folders exist or how their names sort:
    "0"-"9" -> 0-9, "plus" -> 10, "minus" -> 11, "times" -> 12, "div" -> 13,
    "lparen" -> 14, "rparen" -> 15. Any other subfolder names get 16, 17, ...
    in sorted order. Plain files directly under ``root_dir`` are ignored and
    do not consume a label.

    Attributes:
        samples: list of ``(image_path, label)`` pairs, in sorted order.
        label_map: folder name -> label index for the folders found.
        transform: optional callable applied to each RGB PIL image.
    """

    CLASS_ORDER = tuple("0123456789") + ("plus", "minus", "times", "div", "lparen", "rparen")

    def __init__(self, root_dir, transform=None):
        self.samples = []
        self.transform = transform

        # Only directories are classes; stray files (e.g. notes, .DS_Store)
        # must not shift the label numbering.
        class_dirs = sorted(
            name for name in os.listdir(root_dir)
            if os.path.isdir(os.path.join(root_dir, name))
        )
        known = {name: i for i, name in enumerate(self.CLASS_ORDER)}
        self.label_map = {name: known[name] for name in class_dirs if name in known}
        extras = [name for name in class_dirs if name not in known]
        for offset, name in enumerate(extras):
            self.label_map[name] = len(self.CLASS_ORDER) + offset

        for name in class_dirs:
            sub_path = os.path.join(root_dir, name)
            label = self.label_map[name]
            # sorted() makes sample order reproducible (os.listdir order is arbitrary).
            for file in sorted(os.listdir(sub_path)):
                if file.lower().endswith(('.png', '.jpg', '.jpeg')):
                    self.samples.append((os.path.join(sub_path, file), label))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        path, label = self.samples[idx]
        # Context manager releases the file handle once the pixels are converted.
        with Image.open(path) as img:
            pil_img = img.convert("RGB")
        if self.transform:
            img_tensor = self.transform(pil_img)
        else:
            img_tensor = T.ToTensor()(pil_img)
        return img_tensor, label


### Character Classification Model

In [None]:
def create_char_model(num_classes=16):
    """Build torchvision's pretrained ResNet-18 with a fresh ``num_classes``-way head.

    The default of 16 matches the ten digits plus the six operator/parenthesis
    symbols used for the character folders. All layers remain trainable.
    """
    backbone = torchvision.models.resnet18(pretrained=True)
    backbone.requires_grad_(True)
    head_inputs = backbone.fc.in_features
    backbone.fc = nn.Linear(head_inputs, num_classes)
    return backbone

### Character Segmentation

In [None]:
def segment_characters(pil_img, min_size=5):
    """Split an equation image into per-character crops ordered left to right.

    Processing steps:

    1. Convert the image to grayscale.
    2. Binarise it with Otsu's threshold, inverted so dark ink becomes
       foreground.
    3. Treat the bounding box of each external contour as one character.

    Args:
        pil_img: PIL image of the equation.
        min_size: a box is discarded as noise only when BOTH its width and its
            height are below this many pixels. Requiring both (rather than
            either) keeps thin glyphs such as "1" (narrow) and "-" (flat).

    Returns:
        List of ``(x, crop)`` tuples sorted by ``x``, the box's left edge.
        ``crop`` is a grayscale PIL image cut from the un-thresholded
        grayscale image.
    """
    gray = np.array(pil_img.convert("L"))
    _, thresh = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
    # OpenCV 4 returns (contours, hierarchy); OpenCV 3 returns
    # (image, contours, hierarchy). [-2] selects the contours in both.
    contours = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]

    bboxes = [cv2.boundingRect(cnt) for cnt in contours]
    bboxes = [(x, y, w, h) for (x, y, w, h) in bboxes if w >= min_size or h >= min_size]
    bboxes.sort(key=lambda b: b[0])

    return [(x, Image.fromarray(gray[y:y + h, x:x + w])) for (x, y, w, h) in bboxes]


### Classifying the Character in Each Bounding Box

In [None]:
def classify_characters(char_images, model, transform, label_map_rev):
    """Predict a symbol name for each character crop and return them left to right.

    Args:
        char_images: iterable of ``(x, pil_image)`` pairs (as produced by
            ``segment_characters``).
        model: classifier returning one logit per class. It is put into eval
            mode here.
        transform: callable turning an RGB PIL image into an image tensor; a
            batch dimension is added before the forward pass.
        label_map_rev: mapping from predicted class index to symbol name,
            e.g. ``{0: '0', ..., 10: 'plus', 11: 'minus', ...}``.

    Returns:
        List of symbol names ordered by their ``x`` coordinate.
    """
    model.eval()
    labelled = []
    with torch.no_grad():
        for x_pos, crop in char_images:
            batch = transform(crop.convert("RGB")).unsqueeze(0).to(DEVICE)
            class_idx = torch.max(model(batch), 1).indices.item()
            labelled.append((x_pos, label_map_rev[class_idx]))
    labelled.sort(key=lambda item: item[0])
    return [symbol for _, symbol in labelled]


### Conversion to an Expression String

In [None]:
def symbols_to_expression(symbols):
    """Join recognised symbol names into one expression string.

    Digit symbols are copied through as-is. Operator and parenthesis names
    map to characters: plus "+", minus "-", times "×", div "÷", lparen "(",
    rparen ")". Any other name contributes nothing.
    """
    operators = {
        'plus': '+',
        'minus': '-',
        'times': '×',
        'div': '÷',
        'lparen': '(',
        'rparen': ')',
    }
    return "".join(sym if sym.isdigit() else operators.get(sym, '') for sym in symbols)

### Expression Evaluation

In [None]:
def evaluate_expression(expr_str):
    """Evaluate an arithmetic expression string and round the result to 2 decimals.

    ``×`` and ``÷`` are accepted as multiplication and division. Only the
    following are allowed:

    * numeric literals
    * parentheses
    * unary ``+``/``-``
    * binary ``+ - * /``

    Anything else (exponentiation, names, calls, attribute access, ...)
    makes the expression invalid. Invalid or unevaluable input returns ``0``;
    this covers syntax errors, division by zero and disallowed constructs.

    The string usually comes from OCR, so it is parsed with ``ast`` instead of
    being passed to ``eval``. ``eval`` would execute arbitrary code, and an
    OCR misread such as ``9××9××9`` (which becomes ``9**9**9``) would hang
    while computing a gigantic power.
    """
    import ast
    import operator

    binary_ops = {
        ast.Add: operator.add,
        ast.Sub: operator.sub,
        ast.Mult: operator.mul,
        ast.Div: operator.truediv,
    }
    unary_ops = {ast.UAdd: operator.pos, ast.USub: operator.neg}

    def _eval(node):
        # Recursively evaluate only the whitelisted node types.
        if isinstance(node, ast.Expression):
            return _eval(node.body)
        if isinstance(node, ast.Constant) and type(node.value) in (int, float):
            return node.value
        if isinstance(node, ast.BinOp) and type(node.op) in binary_ops:
            return binary_ops[type(node.op)](_eval(node.left), _eval(node.right))
        if isinstance(node, ast.UnaryOp) and type(node.op) in unary_ops:
            return unary_ops[type(node.op)](_eval(node.operand))
        raise ValueError(f"unsupported expression element: {type(node).__name__}")

    expr_str = expr_str.replace('×', '*').replace('÷', '/')
    try:
        val = _eval(ast.parse(expr_str, mode="eval"))
    except (SyntaxError, ValueError, TypeError, ZeroDivisionError,
            OverflowError, RecursionError, MemoryError):
        val = 0
    return round(val, 2)

## OCR based on Clova AI

In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as T

In [None]:
class BidirectionalLSTM(nn.Module):
    """Stacked bidirectional LSTM followed by a per-time-step linear projection.

    Args:
        nIn: feature size of each input time step.
        nHidden: hidden size per direction; the LSTM emits ``2 * nHidden``
            features per step.
        nOut: feature size of each output time step.

    The LSTM is built without ``batch_first``, so input and output are laid
    out as ``(seq_len, batch, features)``.
    """

    def __init__(self, nIn, nHidden, nOut):
        super().__init__()
        # Attribute names (and registration order) define the state_dict keys.
        self.rnn = nn.LSTM(nIn, nHidden, num_layers=2, bidirectional=True)
        self.embedding = nn.Linear(2 * nHidden, nOut)

    def forward(self, input):
        seq_out, _ = self.rnn(input)
        steps, batch, feats = seq_out.size()
        # Flatten time and batch so the Linear sees a 2-D matrix, then restore.
        flat = self.embedding(seq_out.view(steps * batch, feats))
        return flat.view(steps, batch, -1)

In [None]:
class VGG_FeatureExtractor(nn.Module):
    """VGG-style convolutional feature extractor used as the CRNN backbone.

    Args:
        input_channel: channels of the input image (the CRNN in this notebook
            feeds 1-channel grayscale crops).
        output_channel: channels of the output feature map.

    Spatial effect, derived from the layers below:
        Height: each of the four max-pools halves it, and the final 2x2 conv
        (no padding) removes one more row. A 32-pixel-high input therefore
        gives a feature map of height 1, which CRNNRecognizer requires.

        Width: only the first two pools halve it. The last two pools use
        stride 1 with padding 1 horizontally (width grows by 1), preserving
        horizontal resolution for the sequence model.

    The layer order inside ``ConvNet`` determines the state_dict keys
    (``ConvNet.<index>.*``), so reordering layers would break checkpoint loading.
    """

    def __init__(self, input_channel, output_channel=512):
        super(VGG_FeatureExtractor, self).__init__()
        self.ConvNet = nn.Sequential(
            nn.Conv2d(input_channel, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(True),
            nn.MaxPool2d(2, 2),  # 64 channels; H/2, W/2
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(True),
            nn.MaxPool2d(2, 2),  # 128 channels; H/4, W/4
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(True),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(True),
            # Halve height only; horizontal stride 1 + padding 1 grows width by 1.
            nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 1), padding=(0, 1)),
            nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            # Same height-only downsampling as above.
            nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 1), padding=(0, 1)),
            # 2x2 conv without padding: height and width each shrink by 1.
            nn.Conv2d(512, output_channel, kernel_size=2, stride=1, padding=0),
            nn.ReLU(True)
        )

    def forward(self, x):
        # x: (batch, input_channel, H, W) -> (batch, output_channel, H', W')
        return self.ConvNet(x)

In [None]:
class CRAFTDetector(nn.Module):
    """Simplified text score-map predictor on a VGG16-BN backbone.

    Despite the name, this reuses only the convolutional ``features`` of
    torchvision's ``vgg16_bn``. It adds two new convolutions on top:
    512->128 (3x3, followed by ReLU), then 128->1 (1x1). A sigmoid maps the
    result to per-location scores in [0, 1].

    ``pretrained`` affects only the backbone. ``conv1`` and ``conv2`` are
    created here with fresh initialisation; this class does not load them
    from any checkpoint.

    Output shape: ``(N, 1, H', W')``, where ``H'`` and ``W'`` are the
    backbone's downsampled spatial size.
    """

    def __init__(self, pretrained=True):
        super().__init__()
        backbone = torchvision.models.vgg16_bn(pretrained=pretrained)
        self.features = backbone.features
        self.conv1 = nn.Conv2d(512, 128, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(128, 1, kernel_size=1)

    def forward(self, x):
        hidden = F.relu(self.conv1(self.features(x)))
        return torch.sigmoid(self.conv2(hidden))

In [None]:
def get_text_box(score_map, threshold=0.5):
    """Return the bounding box of the largest score-map region above ``threshold``.

    Args:
        score_map: score tensor expected to squeeze down to 2-D,
            e.g. ``(1, 1, H, W)``.
        threshold: scores strictly greater than this count as text.

    Returns:
        ``(x, y, w, h)`` of the largest connected region by contour area.
        Coordinates are in score-map pixels (not input-image pixels). Returns
        ``None`` when no location exceeds the threshold.
    """
    scores = score_map.squeeze().cpu().detach().numpy()
    mask = ((scores > threshold) * 255).astype(np.uint8)
    # OpenCV 4 returns (contours, hierarchy); OpenCV 3 returns
    # (image, contours, hierarchy). [-2] selects the contours in both.
    contours = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
    if not contours:
        return None
    largest = max(contours, key=cv2.contourArea)
    x, y, w, h = cv2.boundingRect(largest)
    return (x, y, w, h)

In [None]:
class CRNNRecognizer(nn.Module):
    """CRNN recogniser: VGG-style conv features -> two BiLSTM blocks -> log-probs.

    Args:
        imgH: accepted for API compatibility; not used by the network.
        nc: number of input image channels.
        nclass: number of output classes per time step.
        nh: hidden size of the BiLSTM blocks.

    ``forward`` maps an image batch ``(b, nc, H, W)`` to log-softmax scores
    shaped ``(T, b, nclass)``, where ``T`` is the width of the conv feature
    map. The conv stack must reduce the height to exactly 1; a
    32-pixel-high input does.
    """

    def __init__(self, imgH=32, nc=1, nclass=17, nh=256):
        super(CRNNRecognizer, self).__init__()
        # Attribute names define the state_dict keys used when loading checkpoints.
        self.FeatureExtraction = VGG_FeatureExtractor(nc, output_channel=512)
        self.SequenceModeling = nn.Sequential(
            BidirectionalLSTM(512, nh, nh),
            BidirectionalLSTM(nh, nh, nclass)
        )

    def forward(self, x):
        features = self.FeatureExtraction(x)
        h = features.size(2)
        # Explicit check instead of `assert`: asserts are stripped under
        # `python -O`. squeeze(2) would then silently keep a 4-D tensor, and the
        # failure would surface later as an unrelated shape error.
        if h != 1:
            raise ValueError(
                f"the height of conv features must be 1, got {h} "
                f"(input height {x.size(2)}; a 32-pixel-high input gives 1)"
            )
        # (b, c, 1, w) -> (w, b, c): time-major sequence for the LSTMs.
        sequence = features.squeeze(2).permute(2, 0, 1)
        output = self.SequenceModeling(sequence)
        return F.log_softmax(output, dim=2)

In [None]:
def ctc_greedy_decoder(output, blank=0, verbose=False):
    """Greedy (best-path) CTC decoding of the first sequence in a batch.

    Decoding steps:

    1. Take the arg-max class at every time step of ``output[:, 0, :]``.
    2. Collapse consecutive repeats.
    3. Drop ``blank`` tokens.
    4. Map the remaining indices to symbols: 1-10 -> "0"-"9", 11 "+",
       12 "-", 13 "×", 14 "÷", 15 "(", 16 ")". Indices without a symbol are
       skipped.

    Args:
        output: per-class scores shaped ``(time, batch, classes)``.
        blank: index of the CTC blank token.
        verbose: print the raw arg-max token sequence. This is a debug aid and
            is off by default, so normal runs are not flooded with output.

    Returns:
        The decoded expression string.
    """
    tokens = np.argmax(output.cpu().detach().numpy()[:, 0, :], axis=1)
    if verbose:
        print("Raw tokens:", tokens)
    mapping = {i + 1: ch for i, ch in enumerate("0123456789+-×÷()")}
    decoded = []
    prev = -1
    for idx in tokens:
        # A repeated token only counts again after a different token (or blank).
        if idx != prev and idx != blank:
            decoded.append(int(idx))
        prev = idx
    return "".join(mapping.get(token, '') for token in decoded)


In [None]:
def recognize_expression(image_path, craft_model, crnn_model, device, debug_path="debug_crop.png"):
    """Read the arithmetic expression in one image with the detector + CRNN pair.

    Steps:
      1. Resize the image to 768x768 and run ``craft_model`` to get a score map.
      2. Take the largest region scoring above 0.7 (``get_text_box``), map it
         from score-map coordinates back to original-image pixels, and crop
         the grayscale image. Use the whole image when nothing is detected.
      3. Resize the crop to height 32 (keeping aspect ratio), run
         ``crnn_model``, and greedily CTC-decode the first batch element.

    Args:
        image_path: path of the image file.
        craft_model: module expected to return a score map (e.g.
            ``(1, 1, h, w)``) for a ``(1, 3, 768, 768)`` input.
        crnn_model: module expected to return per-timestep scores shaped
            ``(T, 1, num_classes)``.
        device: torch device the models live on.
        debug_path: file the crop is written to for inspection; ``None``
            disables writing.

    Returns:
        The decoded expression with spaces removed, or ``""`` when the crop is
        empty.
    """
    with Image.open(image_path) as img:
        pil_img = img.convert("RGB")
    gray = np.array(pil_img.convert("L"))

    craft_transform = T.Compose([
        T.Resize((768, 768)),
        T.ToTensor(),
        T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])
    input_tensor = craft_transform(pil_img).unsqueeze(0).to(device)
    with torch.no_grad():
        score_map = craft_model(input_tensor)

    box = get_text_box(score_map, threshold=0.7)
    if box is None:
        crop = gray
    else:
        x, y, w, h = box
        # The box is in score-map pixels. The score map is smaller than the
        # 768x768 input (the VGG backbone downsamples), so scale by the map's
        # own size rather than by 768.
        map_h, map_w = score_map.shape[-2:]
        orig_w, orig_h = pil_img.size
        scale_x = orig_w / map_w
        scale_y = orig_h / map_h
        x0, y0 = int(x * scale_x), int(y * scale_y)
        x1 = min(orig_w, math.ceil((x + w) * scale_x))
        y1 = min(orig_h, math.ceil((y + h) * scale_y))
        crop = gray[y0:y1, x0:x1]

    # Check emptiness first: cv2.imwrite raises on an empty array, and the
    # aspect-ratio resize below would divide by zero.
    if crop.size == 0:
        return ""
    if debug_path:
        cv2.imwrite(debug_path, crop)

    h_crop, w_crop = crop.shape
    new_h = 32  # CRNN's conv stack reduces a 32-pixel height to 1
    new_w = max(1, int(w_crop * new_h / h_crop))
    crop_resized = cv2.resize(crop, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
    crop_pil = Image.fromarray(crop_resized).convert("L")
    crnn_transform = T.Compose([
        T.ToTensor(),
        T.Normalize(mean=[0.5], std=[0.5])
    ])
    crnn_input = crnn_transform(crop_pil).unsqueeze(0).to(device)
    with torch.no_grad():
        crnn_out = crnn_model(crnn_input)
    # Blank = 0: the decoder's symbol table uses indices 1-16, so 0 is the
    # only free index. With blank=16, ")" (index 16) could never be emitted.
    recognized_expr = ctc_greedy_decoder(crnn_out, blank=0)
    return recognized_expr.replace(" ", "")

def ctc_greedy_decoder(output, blank=16):
    """Best-path CTC decode of batch element 0 of ``output`` ``(time, batch, classes)``.

    Decoding steps:

    1. Take the arg-max class index at each time step.
    2. Merge runs of identical indices.
    3. Remove ``blank`` indices.
    4. Look up the rest in the symbol table 1-10 -> "0"-"9", 11 "+", 12 "-",
       13 "×", 14 "÷", 15 "(", 16 ")". Indices with no entry (e.g. 0) add
       nothing.

    With the default ``blank=16``, the ")" entry can never be produced.
    """
    best_path = np.argmax(output.cpu().detach().numpy()[:, 0, :], axis=1).tolist()
    symbols = {index: char for index, char in enumerate("0123456789+-×÷()", start=1)}
    previous = [-1] + best_path[:-1]
    kept = [cur for prev, cur in zip(previous, best_path) if cur != prev and cur != blank]
    return "".join(symbols.get(token, '') for token in kept)


## Main

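The pipeline implemented in `main()` below runs in the following order:


1.   Train the typed/handwritten classifier (ResNet-18), keeping the weights with the best validation accuracy.
2.   Load the OCR models: the VGG-based text detector (`CRAFTDetector`) and the CRNN recognizer, initialized from a pretrained Clova AI checkpoint when one is found on Drive.
3.   For each test image:
     - classify it as typed or handwritten
     - detect the text region and recognize the expression with the CRNN (the same OCR path is used for both types)
     - evaluate the expression
     - write the predicted type and the answer to submission.csv

The single-character classifier and the contour-based segmentation defined above (`SingleCharDataset`, `create_char_model`, `segment_characters`, `classify_characters`, `symbols_to_expression`) form an alternative recognition path that `main()` does not currently train or use.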





In [None]:
def _split_train_val(df, val_fraction, seed=42):
    """Split raw training rows into ``(train_df, val_df)``, stratified by ``type``.

    The split runs before any oversampling, so duplicated rows cannot end up
    in both parts. An empty validation frame is returned when
    ``val_fraction <= 0``.
    """
    if val_fraction <= 0:
        return df.reset_index(drop=True), df.iloc[0:0]
    val_df = df.groupby("type").sample(frac=val_fraction, random_state=seed)
    train_df = df.drop(index=val_df.index)
    return train_df.reset_index(drop=True), val_df.reset_index(drop=True)


def _run_type_epoch(model, loader, criterion, optimizer=None):
    """Run one pass over ``loader``: train when ``optimizer`` is given, else evaluate.

    Returns ``(mean_loss, accuracy)``, or ``(nan, nan)`` if the loader yields
    no samples.
    """
    training = optimizer is not None
    model.train(training)
    running_loss, correct, total = 0.0, 0, 0
    with torch.set_grad_enabled(training):
        for imgs, labels, _ in loader:
            imgs = imgs.to(DEVICE)
            labels = labels.to(DEVICE)
            if training:
                optimizer.zero_grad()
            out = model(imgs)
            loss = criterion(out, labels)
            if training:
                loss.backward()
                optimizer.step()
            running_loss += loss.item() * imgs.size(0)
            correct += (out.argmax(dim=1) == labels).sum().item()
            total += labels.size(0)
    if total == 0:
        return float("nan"), float("nan")
    return running_loss / total, correct / total


def main():
    """Train the typed/handwritten classifier, run OCR on the test set, write the submission.

    1. Read TRAIN_CSV and hold out a stratified VAL_SPLIT fraction for
       validation. Oversample the minority class in the training part only,
       then train a ResNet-18 for EPOCHS epochs, keeping the weights with the
       best validation accuracy.
    2. Build the detector and the CRNN recognizer. Load CRNN weights from
       Drive when the checkpoint exists, and report how many tensors matched.
    3. For every file in TEST_DIR (sorted by numeric file name):
       a. predict its type,
       b. recognize and evaluate the expression,
       c. write a ``type,answer`` row to SUBMISSION_CSV.
    """
    print("Reading train CSV...")
    df = read_train_csv(TRAIN_CSV)

    # Split BEFORE balancing. balance_dataset duplicates minority rows, so
    # splitting afterwards would place copies of the same image in both train
    # and validation and inflate validation accuracy.
    train_df, val_df = _split_train_val(df, VAL_SPLIT)
    print("Balancing typed vs handwritten data...")
    train_df = balance_dataset(train_df)

    print("Creating typed/handwritten dataset...")
    # Two separate Dataset objects. Subsets from random_split share one
    # underlying dataset, so switching its transform to val_transforms would
    # also remove the augmentations from training.
    train_ds = ExpressionTypeDataset(train_df, TRAIN_DIR, transform=train_transforms)
    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
    val_loader = None
    if len(val_df) > 0:
        val_ds = ExpressionTypeDataset(val_df, TRAIN_DIR, transform=val_transforms)
        val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

    print("Creating typed/handwritten model...")
    type_model = create_type_model(num_classes=2).to(DEVICE)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(type_model.parameters(), lr=LR)

    best_acc = 0.0
    best_weights = None

    print("Training typed/handwritten classifier...")
    for epoch in range(EPOCHS):
        train_loss, train_acc = _run_type_epoch(type_model, train_loader, criterion, optimizer)
        msg = (f"Epoch [{epoch+1}/{EPOCHS}] - Train Loss: {train_loss:.4f}, "
               f"Train Acc: {train_acc*100:.2f}%")
        if val_loader is not None:
            val_loss, val_acc = _run_type_epoch(type_model, val_loader, criterion)
            if val_acc > best_acc:
                best_acc = val_acc
                # clone(): on CPU, .cpu() returns the same storage. Without a
                # copy, the "best" snapshot would keep tracking later updates.
                best_weights = {k: v.detach().cpu().clone()
                                for k, v in type_model.state_dict().items()}
            msg += f" | Val Loss: {val_loss:.4f}, Val Acc: {val_acc*100:.2f}%"
        print(msg)

    if best_weights is not None:
        type_model.load_state_dict(best_weights)
    print(f"Best typed/handwritten validation accuracy: {best_acc*100:.2f}%")

    print("Loading OCR models (CRAFT + CRNN)...")
    craft_model = CRAFTDetector(pretrained=True).to(DEVICE)
    crnn_model = CRNNRecognizer(imgH=32, nc=1, nclass=17, nh=256).to(DEVICE)

    pretrained_path = os.path.join(DRIVE_DIR, "None-VGG-BiLSTM-CTC.pth")
    if os.path.exists(pretrained_path):
        state = torch.load(pretrained_path, map_location=DEVICE)
        # strict=False silently skips keys that don't match by name. Report the
        # outcome so a checkpoint with different key names is not mistaken
        # for a successful load.
        result = crnn_model.load_state_dict(state, strict=False)
        n_total = len(crnn_model.state_dict())
        n_missing = len(result.missing_keys)
        if n_missing == n_total:
            print("WARNING: no CRNN weights matched the checkpoint at", pretrained_path)
        else:
            print(f"Loaded CRNN weights: {n_total - n_missing}/{n_total} tensors matched, "
                  f"{len(result.unexpected_keys)} unexpected checkpoint keys ignored.")
    else:
        print("Pretrained CRNN weights not found at", pretrained_path)

    craft_model.eval()
    crnn_model.eval()

    print("Predicting on test data with OCR...")
    test_files = sorted(os.listdir(TEST_DIR), key=lambda x: int(os.path.splitext(x)[0]))
    submission_rows = []
    type_model.eval()

    with torch.no_grad():
        for file in test_files:
            test_path = os.path.join(TEST_DIR, file)

            with Image.open(test_path) as img:
                pil_img = img.convert("RGB")
            type_input = val_transforms(pil_img).unsqueeze(0).to(DEVICE)
            pred_type = type_model(type_input).argmax(dim=1).item()

            recognized_expr = recognize_expression(test_path, craft_model, crnn_model, DEVICE)
            print(f"For image {file}, recognized expression: '{recognized_expr}'")

            # Shared helper: maps ×/÷, rounds to 2 decimals, returns 0 on
            # unparsable input.
            answer = evaluate_expression(recognized_expr)
            submission_rows.append([str(pred_type), f"{answer:.2f}"])

    with open(SUBMISSION_CSV, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["type", "answer"])
        writer.writerows(submission_rows)

    print(f"Submission saved to {SUBMISSION_CSV}")
    print("Done!")


In [None]:
# Run the full train-then-predict pipeline when executed as a script or notebook cell.
if __name__ == "__main__":
    main()

Reading train CSV...
Balancing typed vs handwritten data...
Creating typed/handwritten dataset...
Creating typed/handwritten model...


Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 199MB/s]


Training typed/handwritten classifier...
Epoch [1/20] - Train Loss: 0.0311, Train Acc: 98.83% | Val Loss: 0.0000, Val Acc: 100.00%
Epoch [2/20] - Train Loss: 0.0008, Train Acc: 100.00% | Val Loss: 0.0000, Val Acc: 100.00%
Epoch [3/20] - Train Loss: 0.0003, Train Acc: 100.00% | Val Loss: 0.0000, Val Acc: 100.00%
Epoch [4/20] - Train Loss: 0.0068, Train Acc: 99.61% | Val Loss: 0.0000, Val Acc: 100.00%
Epoch [5/20] - Train Loss: 0.0002, Train Acc: 100.00% | Val Loss: 0.0000, Val Acc: 100.00%
Epoch [6/20] - Train Loss: 0.0243, Train Acc: 99.14% | Val Loss: 0.0000, Val Acc: 100.00%
Epoch [7/20] - Train Loss: 0.0004, Train Acc: 100.00% | Val Loss: 0.0000, Val Acc: 100.00%
Epoch [8/20] - Train Loss: 0.0218, Train Acc: 99.22% | Val Loss: 2.6406, Val Acc: 51.41%
Epoch [9/20] - Train Loss: 0.0106, Train Acc: 99.61% | Val Loss: 0.0065, Val Acc: 100.00%
Epoch [10/20] - Train Loss: 0.0016, Train Acc: 100.00% | Val Loss: 0.0001, Val Acc: 100.00%
Epoch [11/20] - Train Loss: 0.0002, Train Acc: 100.00%

Downloading: "https://download.pytorch.org/models/vgg16_bn-6c64b313.pth" to /root/.cache/torch/hub/checkpoints/vgg16_bn-6c64b313.pth
100%|██████████| 528M/528M [00:07<00:00, 69.9MB/s]
  crnn_model.load_state_dict(torch.load(pretrained_path, map_location=DEVICE), strict=False)


Successfully loaded pretrained CRNN weights from Clova AI model.
Predicting on test data with OCR...
For image 0.png, recognized expression: '7×5×5'
For image 1.png, recognized expression: '7×5×5'
For image 2.png, recognized expression: '7×5×5'
For image 3.png, recognized expression: '7×5×5'
For image 4.png, recognized expression: '7×5×5'
For image 5.png, recognized expression: '7×5×5'
For image 6.png, recognized expression: '7×5×5'
For image 7.png, recognized expression: '7×5×5'
For image 8.png, recognized expression: '7×5×5'
For image 9.png, recognized expression: '7×5×5'
For image 10.png, recognized expression: '7×5×5'
For image 11.png, recognized expression: '7×5×5'
For image 12.png, recognized expression: '7×5×5'
For image 13.png, recognized expression: '7×5×5'
For image 14.png, recognized expression: '7×5×5'
For image 15.png, recognized expression: '7×5×5'
For image 16.png, recognized expression: '7×5×5'
For image 17.png, recognized expression: '7×5×5'
For image 18.png, recognize