<a href="https://colab.research.google.com/github/negarhonarvar/Equation-Solving-OCR/blob/main/EquationSolving_OCR.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Equation Detection and Solving with OCR
In this task, we classify equation images into two categories:


*   Handwritten
*   Typed

Afterwards, we solve each equation and report its result.



## Libraries

In [1]:
import ast
import csv
import math
import operator
import os

import cv2
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as T
from PIL import Image
from torch.utils.data import DataLoader, random_split
from torch.utils.data import Dataset

## Drive Mount

In [3]:
from google.colab import drive

# Mount Google Drive so the dataset and the generated submission persist.
drive.mount('/content/drive')

# Everything for this task lives under a single Drive folder.
DRIVE_DIR = "/content/drive/MyDrive/OCR_data"
TRAIN_DIR, TEST_DIR, TRAIN_CSV, SUBMISSION_CSV = (
    os.path.join(DRIVE_DIR, name)
    for name in ("train", "test", "train_info.csv", "submission.csv")
)

Mounted at /content/drive


## Hyperparameters

In [4]:
# Training hyper-parameters.
BATCH_SIZE = 32   # images per optimisation step
EPOCHS = 20       # passes over the training data
LR = 1e-4         # learning rate
VAL_SPLIT = 0.1   # presumably the fraction held out for validation (unused in the visible code)

# Run on the GPU whenever PyTorch can see one.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

## Data Processing

Read the `train_info.csv` file, which lists each training image's path, its type (0 = typed, 1 = handwritten) and its answer.

In [5]:
def read_train_csv(csv_path):
    """Load the training metadata CSV at ``csv_path`` into a pandas DataFrame.

    Downstream code reads the ``path``, ``type`` and ``answer`` columns from it.
    """
    return pd.read_csv(csv_path)

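We duplicate (oversample) the minority class of handwritten images (only 150) to balance the dataset, so that the classifier is not biased towards the majority (typed) class. Note that duplication adds no new information and can itself encourage overfitting to the repeated images.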

In [6]:
def balance_dataset(df):
    """Oversample the minority class so typed (0) and handwritten (1) rows are roughly balanced.

    The minority class is repeated ``ceil(majority / minority)`` times, so it may
    slightly exceed the majority count. Only rows whose ``type`` is 0 or 1 are
    kept. The result is shuffled with a fixed seed (42) and re-indexed 0..n-1.

    If either class has no rows there is nothing to oversample against, so the
    remaining rows are just shuffled (previously this raised ZeroDivisionError).

    Args:
        df: DataFrame with a ``type`` column (0 = typed, 1 = handwritten).

    Returns:
        A new, shuffled DataFrame with a fresh RangeIndex.
    """
    typed_df = df[df["type"] == 0]
    handwritten_df = df[df["type"] == 1]
    len_typed = len(typed_df)
    len_hand = len(handwritten_df)

    if len_typed == 0 or len_hand == 0:
        # Nothing to balance against; avoid dividing by zero below.
        balanced = pd.concat([typed_df, handwritten_df], ignore_index=True)
    elif len_hand < len_typed:
        factor = math.ceil(len_typed / len_hand)
        oversampled = pd.concat([handwritten_df] * factor, ignore_index=True)
        balanced = pd.concat([typed_df, oversampled], ignore_index=True)
    else:
        factor = math.ceil(len_hand / len_typed)
        oversampled = pd.concat([typed_df] * factor, ignore_index=True)
        balanced = pd.concat([oversampled, handwritten_df], ignore_index=True)

    return balanced.sample(frac=1.0, random_state=42).reset_index(drop=True)

The class below implements a PyTorch `Dataset` for the typed/handwritten classification task; each item is `(image, type_label, answer)`.

In [7]:
class ExpressionTypeDataset(Dataset):
    """Equation images labelled typed (0) or handwritten (1).

    Each item is ``(image_tensor, label, expression)`` where ``expression`` is
    the raw value of the row's ``"answer"`` column.
    """

    def __init__(self, df, root_dir, transform=None):
        """
        Args:
            df: DataFrame with ``"path"`` (relative to ``root_dir``), ``"type"``
                and ``"answer"`` columns.
            root_dir: Directory that the ``"path"`` entries are relative to.
            transform: Optional callable applied to the RGB PIL image. When
                ``None`` the image is only converted with ``T.ToTensor()``.
        """
        # Keep only rows whose image file exists, so __getitem__ never hits a
        # missing file mid-epoch. A plain list of bools with .loc stays a valid
        # row mask even when df is empty (an empty Series used with df[...]
        # can be treated as a column selection instead).
        exists = [os.path.exists(os.path.join(root_dir, p)) for p in df["path"]]
        self.df = df.loc[exists].reset_index(drop=True)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        filename = row["path"]
        label = int(row["type"])    # 0=typed, 1=handwritten
        expression = row["answer"]
        path = os.path.join(self.root_dir, filename)
        # Close the file handle right away; convert() returns an in-memory copy.
        # Without this, every sample leaks an open file until garbage collection.
        with Image.open(path) as img:
            pil_img = img.convert("RGB")
        if self.transform:
            img_tensor = self.transform(pil_img)
        else:
            img_tensor = T.ToTensor()(pil_img)
        return img_tensor, label, expression


Data Augmentation

In [8]:
# Training-time preprocessing: resize to 224x224, apply light geometric and
# photometric augmentation, then normalise with the ImageNet channel mean/std
# (the statistics used by torchvision's pretrained backbones).
# NOTE(review): RandomHorizontalFlip mirrors glyphs (e.g. "(" becomes ")").
# That keeps a typed/handwritten label valid, but do not reuse these transforms
# for a task whose label depends on the symbol itself (character classification).
train_transforms = T.Compose([
    T.Resize((224,224)),
    T.RandomHorizontalFlip(p=0.5),
    T.RandomRotation(10),
    T.ColorJitter(brightness=0.2, contrast=0.2),
    T.ToTensor(),
    T.Normalize([0.485,0.456,0.406],
                [0.229,0.224,0.225])
])
# Evaluation-time preprocessing: same resize and normalisation, no augmentation.
val_transforms = T.Compose([
    T.Resize((224,224)),
    T.ToTensor(),
    T.Normalize([0.485,0.456,0.406],
                [0.229,0.224,0.225])
])

## Classification Model

In [9]:
def create_type_model(num_classes=2):
    """Return a pretrained ResNet-18 whose final layer is a new ``num_classes``-way head.

    Every layer stays trainable, so the whole network is fine-tuned.
    NOTE: ``pretrained=True`` is deprecated in torchvision >= 0.13 in favour of
    ``weights=...``; it is kept here for compatibility.
    """
    backbone = torchvision.models.resnet18(pretrained=True)
    backbone.requires_grad_(True)  # fine-tune everything, no frozen layers
    backbone.fc = nn.Linear(backbone.fc.in_features, num_classes)
    return backbone

## Character Classification

We assume a separate folder structure,

`char_train/`, with one subfolder per symbol: `0`–`9`, `plus`, `minus`, `times`, `div`, `lparen`, `rparen`.
It is used to train a single-character classifier.

In [10]:
class SingleCharDataset(Dataset):
    """Single-character images organised as ``root_dir/<class_name>/<image>``.

    Class indices are assigned to the sub-directory names in sorted order
    (the same convention as ``torchvision.datasets.ImageFolder``). With the
    folders 0-9, div, lparen, minus, plus, rparen, times the mapping is::

        "0".."9" -> 0..9, "div" -> 10, "lparen" -> 11, "minus" -> 12,
        "plus" -> 13, "rparen" -> 14, "times" -> 15

    Build the reverse map for ``classify_characters`` from ``self.classes`` /
    ``self.class_to_idx`` rather than hard-coding it, so it always matches
    training.
    """

    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, root_dir, transform=None):
        """
        Args:
            root_dir: Directory whose sub-directories are the class names.
            transform: Optional callable applied to each RGB PIL image.
        """
        self.transform = transform
        # Only directories are classes. Previously every listdir entry got an
        # index, so a stray file (".DS_Store", "README.txt") shifted the labels.
        self.classes = sorted(
            entry for entry in os.listdir(root_dir)
            if os.path.isdir(os.path.join(root_dir, entry))
        )
        self.class_to_idx = {name: idx for idx, name in enumerate(self.classes)}

        self.samples = []
        for name in self.classes:
            class_dir = os.path.join(root_dir, name)
            # Sorted so that sample order is deterministic across filesystems.
            for file in sorted(os.listdir(class_dir)):
                if file.lower().endswith(self.IMAGE_EXTENSIONS):
                    self.samples.append((os.path.join(class_dir, file), self.class_to_idx[name]))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        path, label = self.samples[idx]
        # Close the file immediately; convert() returns an in-memory copy.
        with Image.open(path) as img:
            pil_img = img.convert("RGB")
        if self.transform:
            img_tensor = self.transform(pil_img)
        else:
            img_tensor = T.ToTensor()(pil_img)
        return img_tensor, label


### Character Classification Model

In [11]:
def create_char_model(num_classes=16):
    """Return a pretrained ResNet-18 with a fresh ``num_classes``-way head for single symbols.

    All parameters remain trainable (full fine-tuning).
    NOTE: ``pretrained=True`` is the legacy torchvision argument (superseded
    by ``weights=...`` since 0.13).
    """
    net = torchvision.models.resnet18(pretrained=True)
    net.requires_grad_(True)
    head_inputs = net.fc.in_features
    net.fc = nn.Linear(head_inputs, num_classes)
    return net

### Character Segmentation

In [12]:
def segment_characters(pil_img, min_size=5):
    """Split an equation image into per-glyph crops ordered left to right.

    The image is converted to grayscale and binarised with Otsu's threshold.
    The threshold is inverted, so dark pixels become foreground; this assumes
    dark ink on a light background. Each external contour's bounding box is
    then cropped from the grayscale image, not from the binary mask. Because
    only outer contours are used, a glyph made of separate pieces (the dots of
    "÷", the bars of "=") can produce several boxes.

    Args:
        pil_img: PIL image of the whole equation.
        min_size: Boxes smaller than this many pixels in BOTH width and height
            are treated as noise and dropped.

    Returns:
        List of ``(x, crop)`` tuples sorted by the box's left edge ``x``, where
        ``crop`` is a grayscale PIL image.
    """
    gray = np.array(pil_img.convert("L"))
    _, thresh = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
    # OpenCV 3 returns (image, contours, hierarchy) and OpenCV 4 returns
    # (contours, hierarchy); indexing [-2] works with both.
    contours = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]

    bboxes = []
    for cnt in contours:
        x, y, w, h = cv2.boundingRect(cnt)
        # Drop only specks that are small in both directions. A minus sign is
        # wide but very short, and a "1" can be tall but very thin; the old
        # `w < 5 or h < 5` test discarded both.
        if w < min_size and h < min_size:
            continue
        bboxes.append((x, y, w, h))

    bboxes.sort(key=lambda b: b[0])
    return [(x, Image.fromarray(gray[y:y + h, x:x + w])) for x, y, w, h in bboxes]


### Classifying the Character in Each Bounding Box

In [13]:
def classify_characters(char_images, model, transform, label_map_rev):
    """Predict a symbol name for each glyph crop and return the names left to right.

    Args:
        char_images: Iterable of ``(x, pil_image)`` pairs, e.g. the output of
            ``segment_characters``.
        model: Classifier called on a single-image batch; the arg-max over
            dim 1 of its output is taken as the class index.
        transform: Callable turning an RGB PIL image into a tensor (a batch
            dimension is added here).
        label_map_rev: Mapping from class index to symbol name. It must match
            the index assignment used when ``model`` was trained.

    Returns:
        List of symbol names ordered by ``x``.
    """
    model.eval()
    predictions = []
    with torch.no_grad():
        for xpos, glyph in char_images:
            batch = transform(glyph.convert("RGB")).unsqueeze(0).to(DEVICE)
            _, class_idx = torch.max(model(batch), 1)
            predictions.append((xpos, label_map_rev[class_idx.item()]))

    predictions.sort(key=lambda item: item[0])
    return [name for _, name in predictions]


### Converting Symbols to an Expression String

In [14]:
def symbols_to_expression(symbols):
    """Join classifier symbol names into an arithmetic expression string.

    Digit symbols are copied verbatim. Operator and bracket names are mapped
    to their glyphs; multiplication and division use ``×`` and ``÷``, which
    ``evaluate_expression`` maps back. Unrecognised names are dropped silently.
    """
    glyph_for = {
        'plus': '+', 'minus': '-', 'times': '×',
        'div': '÷', 'lparen': '(', 'rparen': ')',
    }
    return "".join(sym if sym.isdigit() else glyph_for.get(sym, '') for sym in symbols)

### Expression Evaluation

In [15]:
# Binary / unary operators an OCR'd arithmetic expression may use.
_ARITH_BIN_OPS = {
    ast.Add: operator.add,
    ast.Sub: operator.sub,
    ast.Mult: operator.mul,
    ast.Div: operator.truediv,
    ast.FloorDiv: operator.floordiv,
    ast.Mod: operator.mod,
    ast.Pow: operator.pow,
}
_ARITH_UNARY_OPS = {ast.UAdd: operator.pos, ast.USub: operator.neg}
# Cap on |exponent| so inputs such as "9**9**9" cannot hang the process.
_MAX_EXPONENT = 100


def _safe_arith_eval(node):
    """Recursively evaluate an ``ast`` tree made only of numeric arithmetic.

    Raises ValueError for any other construct (names, calls, attributes, ...).
    """
    if isinstance(node, ast.Expression):
        return _safe_arith_eval(node.body)
    if isinstance(node, ast.Constant) and type(node.value) in (int, float):
        return node.value
    if isinstance(node, ast.BinOp) and type(node.op) in _ARITH_BIN_OPS:
        left = _safe_arith_eval(node.left)
        right = _safe_arith_eval(node.right)
        if isinstance(node.op, ast.Pow) and abs(right) > _MAX_EXPONENT:
            raise ValueError("exponent too large")
        return _ARITH_BIN_OPS[type(node.op)](left, right)
    if isinstance(node, ast.UnaryOp) and type(node.op) in _ARITH_UNARY_OPS:
        return _ARITH_UNARY_OPS[type(node.op)](_safe_arith_eval(node.operand))
    raise ValueError(f"unsupported expression element: {type(node).__name__}")


def evaluate_expression(expr_str):
    """Evaluate an OCR'd arithmetic expression and round the result to 2 decimals.

    ``×`` and ``÷`` are mapped to ``*`` and ``/``. The string is parsed with
    ``ast`` and only numeric literals, parentheses, unary +/- and the binary
    operators + - * / // % ** are evaluated. ``eval()`` is never used, because
    the text comes from OCR of arbitrary images and must be treated as untrusted.

    Returns:
        ``round(value, 2)``, or ``0`` in any of these cases:
            * the string is empty or malformed;
            * it uses anything else (names, calls, ...);
            * it divides by zero or overflows;
            * it yields a non-real result.
    """
    expr_str = expr_str.replace('×', '*').replace('÷', '/')
    try:
        val = _safe_arith_eval(ast.parse(expr_str, mode="eval"))
    except (SyntaxError, ValueError, TypeError, ZeroDivisionError,
            OverflowError, RecursionError):
        return 0
    # e.g. (-8)**0.5 is complex; round() would raise on it.
    if not isinstance(val, (int, float)):
        return 0
    return round(val, 2)

## OCR Based on CNN and Tesseract

My main model is a two-stage system:
   - First, I train a CNN on MNIST to serve as a backbone for digit recognition;
   - Second, for each test equation image, I use Tesseract (via pytesseract) to detect the bounding box of each character in the equation;
   - Finally, each character is classified with the MNIST CNN, and the equation string is reconstructed and solved.

In [25]:
!sudo apt-get install tesseract-ocr

Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
tesseract-ocr is already the newest version (4.1.1-2.1build1).
0 upgraded, 0 newly installed, 0 to remove and 30 not upgraded.


In [27]:
!pip install pytesseract

Collecting pytesseract
  Downloading pytesseract-0.3.13-py3-none-any.whl.metadata (11 kB)
Downloading pytesseract-0.3.13-py3-none-any.whl (14 kB)
Installing collected packages: pytesseract
Successfully installed pytesseract-0.3.13


In [28]:
import os
import cv2
import csv
import math
import pytesseract
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from PIL import Image

In [29]:
class MNIST_CNN(nn.Module):
    """Two conv/pool stages plus a linear head for 1x28x28 digit images (10 classes).

    The attribute names ``conv`` and ``fc``, and the layer order inside
    ``conv``, determine the state-dict keys saved to / loaded from
    ``mnist_cnn.pth``. Keep them stable.
    """

    def __init__(self):
        super().__init__()
        layers = [
            nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),   # 1x28x28 -> 32x28x28
            nn.ReLU(),
            nn.MaxPool2d(2),                                        # -> 32x14x14
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),  # -> 64x14x14
            nn.ReLU(),
            nn.MaxPool2d(2),                                        # -> 64x7x7
        ]
        self.conv = nn.Sequential(*layers)
        self.fc = nn.Linear(64 * 7 * 7, 10)

    def forward(self, x):
        features = self.conv(x)
        flat = features.view(features.size(0), -1)
        return self.fc(flat)

def train_mnist_cnn(num_epochs=5, batch_size=64, lr=1e-3, device="cuda" if torch.cuda.is_available() else "cpu"):
    """Train ``MNIST_CNN`` on the MNIST training split and save its weights.

    MNIST is downloaded to ``./mnist_data`` if needed. Training uses Adam and
    cross-entropy, and prints per-epoch mean loss and accuracy. The final
    state dict is written to ``mnist_cnn.pth`` in the current working directory.

    Returns:
        The trained model, left in training mode on ``device``.
    """
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])
    mnist_train = MNIST(root='./mnist_data', train=True, download=True, transform=preprocess)
    loader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True, num_workers=2)

    model = MNIST_CNN().to(device)
    loss_fn = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    model.train()
    for epoch_idx in range(num_epochs):
        loss_sum, n_correct, n_seen = 0.0, 0, 0
        for images, targets in loader:
            images = images.to(device)
            targets = targets.to(device)

            optimizer.zero_grad()
            logits = model(images)
            batch_loss = loss_fn(logits, targets)
            batch_loss.backward()
            optimizer.step()

            # Weight by batch size so the printed loss is a per-sample mean.
            loss_sum += batch_loss.item() * images.size(0)
            _, predicted = torch.max(logits, 1)
            n_correct += (predicted == targets).sum().item()
            n_seen += targets.size(0)
        print(f"Epoch {epoch_idx+1}/{num_epochs}: Loss: {loss_sum/n_seen:.4f}, Acc: {n_correct/n_seen*100:.2f}%")

    torch.save(model.state_dict(), "mnist_cnn.pth")
    return model

In [30]:
# Reuse the cached MNIST weights when present, otherwise train (and cache) them.
device = "cuda" if torch.cuda.is_available() else "cpu"
if not os.path.exists("mnist_cnn.pth"):
    print("Training MNIST CNN...")
    mnist_model = train_mnist_cnn(num_epochs=5, device=device)
else:
    print("Loading pretrained MNIST CNN...")
    mnist_model = MNIST_CNN().to(device)
    # weights_only=True restricts unpickling to tensors and plain containers,
    # which is all a state_dict needs. It refuses to run arbitrary pickled code
    # from the checkpoint file.
    _mnist_state = torch.load("mnist_cnn.pth", map_location=device, weights_only=True)
    mnist_model.load_state_dict(_mnist_state)
mnist_model.eval()

Training MNIST CNN...


100%|██████████| 9.91M/9.91M [00:00<00:00, 17.8MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 483kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 4.41MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 7.79MB/s]


Epoch 1/5: Loss: 0.1426, Acc: 95.53%
Epoch 2/5: Loss: 0.0477, Acc: 98.56%
Epoch 3/5: Loss: 0.0348, Acc: 98.88%
Epoch 4/5: Loss: 0.0261, Acc: 99.15%
Epoch 5/5: Loss: 0.0201, Acc: 99.34%


MNIST_CNN(
  (conv): Sequential(
    (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): ReLU()
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (fc): Linear(in_features=3136, out_features=10, bias=True)
)

In [31]:
# Localise glyphs with pytesseract.image_to_boxes, re-classify the digit glyphs
# with the MNIST CNN, and keep Tesseract's reading for operators/brackets.

# Tesseract glyphs kept as operators/brackets, mapped to Python arithmetic syntax.
_TESSERACT_OPERATORS = {
    '+': '+', '-': '-', '*': '*', 'x': '*', 'X': '*', '×': '*',
    '/': '/', '÷': '/', '(': '(', ')': ')',
}
_DIGIT_CHARS = frozenset("0123456789")


def _crop_to_mnist_tensor(crop, device, invert):
    """Turn a uint8 grayscale glyph crop into a normalised (1, 1, 28, 28) tensor.

    MNIST digits are light strokes on a dark background, centred in a 20x20
    box inside a 28x28 frame. The steps are:
        * ``invert`` flips dark-on-light crops to that polarity;
        * the crop is padded to a square, keeping its aspect ratio;
        * it is resized to 20x20 and framed with a 4-pixel dark border;
        * it is normalised with the mean/std used in ``train_mnist_cnn``.
    """
    if invert:
        crop = 255 - crop
    h, w = crop.shape
    side = max(h, w)
    top, left = (side - h) // 2, (side - w) // 2
    square = np.zeros((side, side), dtype=np.uint8)
    square[top:top + h, left:left + w] = crop
    digit = cv2.resize(square, (20, 20), interpolation=cv2.INTER_LINEAR)
    framed = np.pad(digit, 4, mode="constant", constant_values=0)
    tensor = torch.from_numpy(framed).float().div(255.0)
    tensor = (tensor - 0.1307) / 0.3081
    return tensor[None, None].to(device)


def ocr_equation_with_mnist(image_path, mnist_model, device, min_box_size=10):
    """Read an equation image with Tesseract + the MNIST CNN and evaluate it.

    Steps:
      1. Tesseract's full-image reading is printed for debugging.
      2. ``pytesseract.image_to_boxes`` gives one box per recognised glyph.
      3. Glyphs Tesseract reads as a digit are re-classified by ``mnist_model``.
         Glyphs it reads as an operator or bracket keep that reading, because
         the MNIST model only has the ten digit classes. Anything else is dropped.
      4. Symbols are ordered by left edge, joined, and evaluated with
         ``evaluate_expression``. Any failure gives 0.

    Args:
        image_path: Path to the equation image.
        mnist_model: A trained ``MNIST_CNN``.
        device: Device the model lives on.
        min_box_size: Boxes smaller than this in BOTH width and height are
            treated as noise. The old "either dimension" test dropped minus
            signs and thin "1"s.

    Returns:
        ``(reconstructed_expression, value)``; ``("", 0)`` when Tesseract finds
        no glyphs.
    """
    # Open the file once and close it afterwards (it was opened three times
    # and never closed).
    with Image.open(image_path) as src:
        src.load()
        full_ocr = pytesseract.image_to_string(src, config="--psm 6")
        boxes = pytesseract.image_to_boxes(src)
        gray = np.array(src.convert("L"))

    print(f"Tesseract full OCR: {full_ocr.strip()}")
    # The output format: character, x1, y1, x2, y2, page
    if not boxes:
        return "", 0

    height = gray.shape[0]
    # MNIST expects light-on-dark; invert when the page is mostly light.
    invert = gray.mean() > 127

    symbols = []
    for line in boxes.splitlines():
        parts = line.split(" ")
        if len(parts) < 6:
            continue
        char = parts[0]
        x1, y1, x2, y2 = (int(v) for v in parts[1:5])

        # Tesseract's y axis starts at the bottom of the image; NumPy's at the top.
        top, bottom = height - y2, height - y1
        crop = gray[max(top, 0):bottom, max(x1, 0):x2]
        if crop.size == 0 or (crop.shape[0] < min_box_size and crop.shape[1] < min_box_size):
            continue

        if char in _DIGIT_CHARS:
            with torch.no_grad():
                logits = mnist_model(_crop_to_mnist_tensor(crop, device, invert))
            _, pred = torch.max(logits, 1)
            symbols.append((x1, str(pred.item())))
        elif char in _TESSERACT_OPERATORS:
            symbols.append((x1, _TESSERACT_OPERATORS[char]))

    symbols.sort(key=lambda s: s[0])
    reconstructed = "".join(sym for _, sym in symbols)
    print("Reconstructed expression (MNIST digits + Tesseract operators):", reconstructed)
    try:
        result_value = evaluate_expression(reconstructed)
    except (TypeError, ValueError, OverflowError):
        # Guard against a non-numeric result so one image cannot abort a batch.
        result_value = 0
    return reconstructed, result_value

## Main

The Pipeline of our model is implemented below in the following order:


1.   Train typed/handwritten model
2.   Train single-char model
3.   For test images:
     - typed/handwritten classification
     - segment the image into characters (for both typed and handwritten images)
     - classify each char
     - build expression string
     - evaluate
     - output in submission.csv


EasyOCR detects both digits and operators in each test image. The recognized string is then post-processed (replacing common mis-detections, e.g. `x` → `*`) and the equation is evaluated. A simple heuristic based on EasyOCR's confidence scores decides whether the expression appears “typed” (average confidence above 0.9) or “handwritten.” Note that `main()` below uses this EasyOCR pipeline rather than the trained classifiers described in the steps above.


In [39]:
!pip install easyocr

Collecting easyocr
  Downloading easyocr-1.7.2-py3-none-any.whl.metadata (10 kB)
Collecting python-bidi (from easyocr)
  Downloading python_bidi-0.6.6-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.9 kB)
Collecting pyclipper (from easyocr)
  Downloading pyclipper-1.3.0.post6-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (9.0 kB)
Collecting ninja (from easyocr)
  Downloading ninja-1.11.1.4-py3-none-manylinux_2_12_x86_64.manylinux2010_x86_64.whl.metadata (5.0 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch->easyocr)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch->easyocr)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch->easyocr)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (

In [40]:
import easyocr

In [41]:
def main():
    """Run EasyOCR over every test image, evaluate each expression and write SUBMISSION_CSV.

    For each numerically named image in TEST_DIR, processed in numeric order:
      * EasyOCR fragments are ordered by the x coordinate of their first box
        corner and concatenated;
      * spaces are removed, and ``x``/``X``/``÷`` are mapped to ``*``/``/``;
      * the string is evaluated with ``evaluate_expression``; an unusable
        result scores 0;
      * the type is a heuristic on EasyOCR's mean confidence: above 0.9 means
        typed (0), otherwise handwritten (1).

    The CSV has a ``type,answer`` header and one row per image, in that order.
    """
    # Ask PyTorch directly whether a GPU is usable. CUDA_VISIBLE_DEVICES is not
    # necessarily set when a GPU is present, which left EasyOCR on the CPU.
    reader = easyocr.Reader(['en'], gpu=torch.cuda.is_available())

    # Only numerically named files ("234.png") are test images. Anything else
    # (e.g. ".ipynb_checkpoints") would crash the int() sort key.
    test_files = sorted(
        (f for f in os.listdir(TEST_DIR) if os.path.splitext(f)[0].isdecimal()),
        key=lambda f: int(os.path.splitext(f)[0]),
    )
    submission_rows = []

    for file in test_files:
        image_path = os.path.join(TEST_DIR, file)
        result = reader.readtext(image_path, detail=1)

        if result:
            result = sorted(result, key=lambda r: r[0][0][0])
            recognized_expr = "".join(r[1] for r in result)
            avg_conf = sum(r[2] for r in result) / len(result)
        else:
            recognized_expr = ""
            avg_conf = 0.0

        recognized_expr = recognized_expr.replace(" ", "")
        recognized_expr = recognized_expr.replace("x", "*").replace("X", "*").replace("÷", "/")

        # OCR text is untrusted: evaluate it with evaluate_expression rather than
        # eval(). A non-numeric result (e.g. a tuple from "(1,2)") used to make
        # round() raise and abort the whole run; now it scores 0.
        try:
            value = evaluate_expression(recognized_expr)
            rounded = round(value, 2)
            answer = f"{rounded:.2f}"
        except (TypeError, ValueError, OverflowError):
            rounded, answer = 0, "0.00"

        pred_type = 0 if avg_conf > 0.9 else 1

        submission_rows.append([str(pred_type), answer])
        print(f"For image {file}, recognized expression: '{recognized_expr}', result: {rounded}, avg_conf: {avg_conf:.3f}")

    with open(SUBMISSION_CSV, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["type", "answer"])
        writer.writerows(submission_rows)
    print(f"Submission saved to {SUBMISSION_CSV}")

In [42]:
# Entry point: runs the EasyOCR submission pipeline when executed as a script
# (and in a notebook cell, where __name__ is "__main__").
if __name__ == "__main__":
    main()



Progress: |██████████████████████████████████████████████████| 100.0% Complete



Progress: |--------------------------------------------------| 0.0% CompleteProgress: |--------------------------------------------------| 0.1% CompleteProgress: |--------------------------------------------------| 0.1% CompleteProgress: |--------------------------------------------------| 0.2% CompleteProgress: |--------------------------------------------------| 0.2% CompleteProgress: |--------------------------------------------------| 0.3% CompleteProgress: |--------------------------------------------------| 0.4% CompleteProgress: |--------------------------------------------------| 0.4% CompleteProgress: |--------------------------------------------------| 0.5% CompleteProgress: |--------------------------------------------------| 0.5% CompleteProgress: |--------------------------------------------------| 0.6% CompleteProgress: |--------------------------------------------------| 0.6% CompleteProgress: |--------------------------------------------------| 0.7% Complet



For image 234.png, recognized expression: '7741(T1*7)', result: 0, avg_conf: 0.781
For image 235.png, recognized expression: '23+26+75*9)', result: 0, avg_conf: 0.817
For image 236.png, recognized expression: '49*8*78', result: 30576, avg_conf: 0.329
For image 237.png, recognized expression: '(6?+(a)', result: 0, avg_conf: 0.187
For image 238.png, recognized expression: '30:99*5756q', result: 0, avg_conf: 0.438
For image 239.png, recognized expression: '(792790)+84=11', result: 0, avg_conf: 0.641
For image 240.png, recognized expression: '39-51-98', result: -110, avg_conf: 0.544
For image 241.png, recognized expression: '24*44*531100', result: 560841600, avg_conf: 0.683
For image 242.png, recognized expression: '(37+42)_90', result: 0, avg_conf: 0.411
For image 243.png, recognized expression: '72*7)+37*88', result: 0, avg_conf: 0.571
For image 244.png, recognized expression: '81+49_63~90', result: 0, avg_conf: 0.615
For image 245.png, recognized expression: '3*58+98', result: 272, avg_



For image 304.png, recognized expression: '57(77-78)', result: 0, avg_conf: 0.782
For image 305.png, recognized expression: '96+23-(76*32)', result: -2313, avg_conf: 0.776
For image 306.png, recognized expression: '96;23;67559', result: 0, avg_conf: 0.696
For image 307.png, recognized expression: 'S<0(82*|4)', result: 0, avg_conf: 0.163
For image 308.png, recognized expression: '29-5-97*80', result: -7736, avg_conf: 0.367
For image 309.png, recognized expression: '8+5|', result: 0, avg_conf: 0.545
For image 310.png, recognized expression: '98*17+(72*16)', result: 2818, avg_conf: 0.472
For image 311.png, recognized expression: '5;96+71', result: 0, avg_conf: 0.915
Submission saved to /content/drive/MyDrive/OCR_data/submission.csv
