<a href="https://colab.research.google.com/github/kgabriella/FUNSD_OCR/blob/main/character_recognition_in_FUNSD_dataset.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torchvision
from torchvision.models.detection import FasterRCNN
from torchvision.transforms.functional import to_tensor
from torchvision.transforms import ToTensor, RandomHorizontalFlip
from torch.utils.data import DataLoader, Dataset
import os
from PIL import Image
from tqdm import tqdm

import pandas as pd
from torchvision.ops import box_iou
from torchvision.transforms.functional import to_pil_image
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import numpy as np

from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone, mobilenet_backbone
from torchvision.models.detection.anchor_utils import AnchorGenerator
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

import torchvision.transforms.functional as TF
import random
import matplotlib.transforms as mtransforms

from torchvision.transforms.functional import resize, pad
import torchvision.transforms.functional as F

import torchvision.transforms as T
from torchvision.transforms import ColorJitter as TorchvisionColorJitter
from ast import literal_eval

from torchvision.models.detection import fasterrcnn_resnet50_fpn

from torchvision.ops import box_iou

import string
from torchvision.datasets import ImageFolder
import math

import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import cv2

In [None]:
# Mount Google Drive so that the spreadsheets, images and checkpoints referenced
# below under /content/drive/MyDrive are reachable. Only available inside Colab.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Use the first CUDA GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

# Bounding box detection

## Augmentation

In [None]:
class RandomApply(torch.nn.Module):
    """Run a chain of paired ``(img, target)`` transforms with probability ``p``.

    Args:
        transforms: iterable of callables that take and return ``(img, target)``.
        p: chance of applying the whole chain; otherwise the inputs are
            returned untouched.
    """

    def __init__(self, transforms, p=0.5):
        super().__init__()
        self.p = p
        self.transforms = transforms

    def forward(self, img, target):
        # A single random draw decides for the entire chain.
        if random.random() > self.p:
            return img, target
        for step in self.transforms:
            img, target = step(img, target)
        return img, target

    def __repr__(self):
        body = ''.join('    {0}\n'.format(step) for step in self.transforms)
        return '{}(\n{}    p={})'.format(self.__class__.__name__, body, self.p)

In [None]:
class ResizeAndPad(object):
    """Aspect-preserving resize followed by zero padding on the right/bottom.

    ``target_size`` is unpacked as ``(height, width)``. When the image is taller
    than wide its height is set to the target height, otherwise its width is set
    to the target width; the other side follows the aspect ratio (truncated to
    an int). Boxes in ``target["boxes"]`` are rescaled in place, with columns
    0/2 treated as x and 1/3 as y. The image must expose ``.width``/``.height``.
    """

    def __init__(self, target_size):
        self.target_size = target_size

    def __call__(self, image, target):
        src_w, src_h = image.width, image.height
        dst_h, dst_w = self.target_size
        ratio = src_w / src_h

        # Portrait images pin the height, everything else pins the width.
        portrait = src_h > src_w
        out_h = dst_h if portrait else int(dst_w / ratio)
        out_w = int(dst_h * ratio) if portrait else dst_w

        image = resize(image, (out_h, out_w))

        if "boxes" in target:
            boxes = target["boxes"]
            boxes[:, [0, 2]] *= out_w / src_w
            boxes[:, [1, 3]] *= out_h / src_h
            target["boxes"] = boxes

        # Padding tuple is (left, top, right, bottom): the image stays anchored
        # at the top-left corner, so box coordinates need no further shift.
        extra_right = max(dst_w - out_w, 0)
        extra_bottom = max(dst_h - out_h, 0)
        image = pad(image, (0, 0, extra_right, extra_bottom), fill=0, padding_mode='constant')

        return image, target

In [None]:
class ColorJitter(object):
    """Paired-transform wrapper around torchvision's ``ColorJitter``.

    Only the image is altered; ``target`` is handed back unchanged.
    """

    def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
        jitter_kwargs = dict(brightness=brightness, contrast=contrast,
                             saturation=saturation, hue=hue)
        self.color_jitter = TorchvisionColorJitter(**jitter_kwargs)

    def __call__(self, image, target):
        return self.color_jitter(image), target

In [None]:
class RandomHorizontalFlip(object):
    """Mirror a tensor image (flip along dim 2) and its boxes with probability ``prob``.

    Box columns 0 and 2 are treated as x-min / x-max and are updated in place
    on ``target["boxes"]``.
    """

    def __init__(self, prob=0.5):
        self.prob = prob

    def __call__(self, image, target):
        # Guard clause: leave the sample alone when the draw says "no flip".
        if not (torch.rand(1) < self.prob):
            return image, target

        image = image.flip(2)
        width = image.size(2)
        boxes = target["boxes"]
        # New x-min comes from the old x-max (and vice versa) so x1 <= x2 holds.
        mirrored_x = width - boxes[:, [2, 0]]
        boxes[:, [0, 2]] = mirrored_x
        target["boxes"] = boxes
        return image, target

In [None]:
class RandomScaling(object):
    """Resize the image by a random factor and rescale its boxes to match.

    Args:
        scale_range: ``(low, high)`` bounds for the uniformly drawn scale factor.

    The image may be a ``(..., H, W)`` tensor or a PIL image. The previous
    version read ``shape[1]``/``shape[0]`` as width/height, which for a CHW
    tensor are the height and the channel count.
    """

    def __init__(self, scale_range=(0.8, 1.2)):
        self.scale_range = scale_range

    def __call__(self, image, target):
        scale_factor = random.uniform(self.scale_range[0], self.scale_range[1])

        if isinstance(image, torch.Tensor):
            height, width = image.shape[-2], image.shape[-1]
        else:
            width, height = image.size  # PIL reports (width, height)

        # Never collapse a side to zero pixels.
        new_height = max(int(height * scale_factor), 1)
        new_width = max(int(width * scale_factor), 1)
        image = TF.resize(image, [new_height, new_width])  # size is (h, w)

        if "boxes" in target:
            boxes = target["boxes"]
            # Use the realised ratios so integer truncation of the new size
            # does not leave the boxes slightly off.
            boxes[:, [0, 2]] *= new_width / width
            boxes[:, [1, 3]] *= new_height / height
            target["boxes"] = boxes
        return image, target

In [None]:
class RandomTranslation(object):
    """Shift the image by a random offset and move its boxes by the same amount.

    Args:
        translation_range: ``(max_dx, max_dy)`` as fractions of the image width
            and height; the shift is drawn uniformly from ``[-max, +max]``.

    The image may be a ``(..., H, W)`` tensor or a PIL image. The previous
    version scaled the x shift by ``shape[1]`` and the y shift by ``shape[0]``,
    i.e. by the height and the channel count of a CHW tensor.

    NOTE: boxes are shifted but not clipped, so they can end up partly or fully
    outside the image.
    """

    def __init__(self, translation_range=(0.1, 0.1)):
        self.translation_range = translation_range

    def __call__(self, image, target):
        if isinstance(image, torch.Tensor):
            height, width = image.shape[-2], image.shape[-1]
        else:
            width, height = image.size  # PIL reports (width, height)

        max_dx, max_dy = self.translation_range[0], self.translation_range[1]
        # Whole-pixel shifts keep the boxes exactly aligned with the moved content.
        tx = int(round(random.uniform(-max_dx, max_dx) * width))
        ty = int(round(random.uniform(-max_dy, max_dy) * height))
        image = TF.affine(image, angle=0, translate=[tx, ty], scale=1, shear=0)

        if "boxes" in target:
            boxes = target["boxes"]
            boxes[:, 0::2] += tx
            boxes[:, 1::2] += ty
            target["boxes"] = boxes
        return image, target

In [None]:
class ToTensor(object):
    """Paired-transform version of ``to_tensor``: converts the image, keeps ``target``."""

    def __call__(self, image, target):
        tensor_image = F.to_tensor(image)
        return tensor_image, target

In [None]:
class Compose(object):
    """Chain paired transforms: each step receives and returns ``(image, target)``."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image, target):
        sample = (image, target)
        for step in self.transforms:
            # Unpack on every step so a transform returning the wrong arity fails here.
            image, target = step(*sample)
            sample = (image, target)
        return sample

In [None]:
def get_transform(train):
    """Build the paired image/target pipeline for the box detector.

    Every sample is resized and padded to 1024x1024 and converted to a tensor.
    With ``train`` truthy, a random horizontal flip, a random translation and an
    occasional (p=0.4) colour jitter are appended.
    """
    pipeline = [ResizeAndPad((1024, 1024)), ToTensor()]
    if not train:
        return Compose(pipeline)

    jitter = ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1)
    # RandomScaling(scale_range=(0.8, 1.2)) is currently disabled.
    pipeline += [
        RandomHorizontalFlip(0.5),
        RandomTranslation(translation_range=(0.1, 0.1)),
        RandomApply([jitter], p=0.4),
    ]
    return Compose(pipeline)

## Loading data

In [None]:
class BBDataset(Dataset):
    """Detection dataset yielding one image together with all of its boxes.

    Args:
        dataframe: table with a ``'file path'`` column and a ``'coordinates'``
            column (a 4-element box, either as a list or as its string literal).
            Rows are grouped by file so each item is one image.
        root_dir: directory the file paths are resolved against. Absolute
            paths in the table are used as-is (``os.path.join`` semantics).
        transforms: optional callable ``(image, target) -> (image, target)``.
    """

    def __init__(self, dataframe, root_dir, transforms=None):
        self.image_groups = dataframe.groupby('file path')['coordinates'].agg(list).reset_index()
        self.root_dir = root_dir
        self.transforms = transforms

    def __len__(self):
        return len(self.image_groups)

    def __getitem__(self, idx):
        row = self.image_groups.iloc[idx]

        # Resolve against root_dir, as show_image_with_box does for the same table.
        img_path = row['file path']
        if self.root_dir:
            img_path = os.path.join(self.root_dir, img_path)

        # convert() returns an independent copy, so the file can be closed right away.
        with Image.open(img_path) as handle:
            image = handle.convert("RGB")

        box_list = [literal_eval(coord) if isinstance(coord, str) else coord
                    for coord in row['coordinates']]
        # reshape keeps the (N, 4) layout even for N == 0.
        boxes = torch.as_tensor(box_list, dtype=torch.float32).reshape(-1, 4)
        # Every box belongs to the single foreground class 1.
        labels = torch.ones((len(box_list),), dtype=torch.int64)
        target = {"boxes": boxes, "labels": labels, "image_id": torch.tensor([idx])}

        if self.transforms:
            image, target = self.transforms(image, target)

        return image, target

In [None]:
def my_collate_fn(batch):
    """Collate ``(image, target)`` pairs into one image batch plus a target list.

    Images are ``(C, H, W)`` tensors. Each one is zero-padded on the right and
    bottom up to the largest height/width in the batch, so box coordinates stay
    valid, and the results are stacked into ``(B, C, H_max, W_max)``.

    Args:
        batch: non-empty list of ``(image_tensor, target)`` tuples.

    Returns:
        ``(images, targets)``: the stacked tensor and the untouched targets.
    """
    max_height = max(img.shape[-2] for img, _ in batch)
    max_width = max(img.shape[-1] for img, _ in batch)

    padded_imgs, targets = [], []
    for img, target in batch:
        pad_bottom = max_height - img.shape[-2]
        pad_right = max_width - img.shape[-1]
        # torch.nn.functional.pad takes (left, right, top, bottom) for the last
        # two dims. The torchvision pad used before expects (left, top, right,
        # bottom), so the same tuple put the width padding at the top.
        padded_imgs.append(
            torch.nn.functional.pad(img, (0, pad_right, 0, pad_bottom), mode="constant", value=0)
        )
        targets.append(target)
    return torch.stack(padded_imgs, dim=0), targets

In [None]:
# train data
# The spreadsheets must provide the 'file path' and 'coordinates' columns that
# BBDataset groups on.
train_dataframe = pd.read_excel('/content/drive/MyDrive/OCR_project/All_bounding_boxes.xlsx')
train_data_path = '/content/drive/MyDrive/OCR_project/FUNSD_dataset/training_data/images/'

# validation data
val_dataframe = pd.read_excel('/content/drive/MyDrive/OCR_project/All_validation_bounding_boxes.xlsx')
val_data_path = '/content/drive/MyDrive/OCR_project/FUNSD_dataset/validation_data/images/'

# transformations
# Training gets the random augmentations; validation only resize/pad + tensor conversion.
train_transforms = get_transform(train=True)
val_transforms = get_transform(train=False)

# datasets
train_dataset = BBDataset(train_dataframe, train_data_path, transforms=train_transforms)
val_dataset = BBDataset(val_dataframe, val_data_path, transforms=val_transforms)

# training and validation loaders
# my_collate_fn pads and stacks the images and keeps the targets as a list.
# NOTE: the names train_loader / val_loader are rebound later in the notebook
# for the character CNN.
train_loader = DataLoader(train_dataset, batch_size=10, shuffle=True, num_workers=0, collate_fn=my_collate_fn)
val_loader = DataLoader(val_dataset, batch_size=10, shuffle=False, num_workers=0, collate_fn=my_collate_fn)

In [None]:
# helper function to check bounding boxes are correctly placed on documents
def show_image_with_boxes(image, boxes, labels=None):
    """Display ``image`` with every ``[x1, y1, x2, y2]`` box outlined in red.

    ``labels`` is accepted but not used.
    """
    _, axis = plt.subplots(1)
    axis.imshow(image)

    for box in boxes:
        left, top = box[0], box[1]
        outline = patches.Rectangle(
            (left, top),
            box[2] - left,
            box[3] - top,
            linewidth=1,
            edgecolor='r',
            facecolor='none',
        )
        axis.add_patch(outline)

    plt.show()

## The model

In [None]:
def get_model(num_classes):
    """Return a pretrained Faster R-CNN (ResNet-50 FPN) with a fresh box head.

    The classification/regression head is replaced by a ``FastRCNNPredictor``
    sized for ``num_classes`` outputs; the rest of the network keeps its
    pretrained weights.
    """
    detector = fasterrcnn_resnet50_fpn(pretrained=True)
    heads = detector.roi_heads
    # Read the feature width from the existing head before replacing it.
    heads.box_predictor = FastRCNNPredictor(heads.box_predictor.cls_score.in_features, num_classes)
    return detector

In [None]:
def train_model(model, data_loader, optimizer, device, num_epochs=100):
    """Train a torchvision-style detection model in place.

    Args:
        model: detector that returns a dict of losses when called as
            ``model(images, targets)`` in training mode.
        data_loader: yields ``(images, targets)``; ``images`` is iterated to get
            per-image tensors, ``targets`` is a list of dicts of tensors.
        optimizer: optimizer already bound to ``model``'s parameters.
        device: device the model and every batch are moved to.
        num_epochs: number of passes over ``data_loader``.
    """
    model.to(device)
    num_batches = len(data_loader)

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        progress_bar = tqdm(data_loader, desc=f'Epoch {epoch + 1}/{num_epochs}', leave=True)
        for batch_idx, (images, targets) in enumerate(progress_bar, start=1):
            images = [img.to(device) for img in images]
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

            loss_dict = model(images, targets)
            losses = sum(loss_dict.values())

            optimizer.zero_grad()
            losses.backward()
            optimizer.step()

            running_loss += losses.item()
            # Average over the batches seen so far; dividing by the full loader
            # length understated the loss until the very last batch.
            progress_bar.set_postfix({'loss': running_loss / batch_idx})

        epoch_loss = running_loss / num_batches if num_batches else 0.0
        print(f"Epoch {epoch + 1} Average Loss: {epoch_loss}")

# Two outputs: BBDataset labels every box with class 1; index 0 is presumably
# the detector's background class (torchvision convention, confirm if changed).
model = get_model(num_classes=2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.005, momentum=0.9, weight_decay=0.0005)

# Rebinds the global `device`, then runs train_model with its default num_epochs=100.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
train_model(model, train_loader, optimizer, device)


In [None]:
# saving
# NOTE: this pickles the whole model object rather than a state_dict, so loading
# it back needs the same model classes (torchvision) to be importable.
torch.save(model, '/content/drive/MyDrive/OCR_project/bb_detector_final.pth')

In [None]:
# loading
# The checkpoint is a fully pickled model object (see the saving cell), so it must
# be unpickled with weights_only=False; recent PyTorch releases default to True
# and would refuse it. Unpickling runs arbitrary code, so only load checkpoints
# you created yourself.
# map_location lets a checkpoint saved on a GPU be restored on a CPU-only runtime.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = torch.load('/content/drive/MyDrive/OCR_project/bb_detector_final.pth',
                   map_location=device, weights_only=False)

model.to(device)

## Evaluation

In [None]:
def add_box(ax, box, label, color):
    """Draw one ``[x1, y1, x2, y2]`` box on ``ax`` with a text tag at its top-left corner."""
    x_min, y_min = box[0], box[1]
    frame = patches.Rectangle(
        (x_min, y_min),
        box[2] - x_min,
        box[3] - y_min,
        linewidth=2,
        edgecolor=color,
        facecolor='none',
    )
    ax.add_patch(frame)
    tag_style = dict(facecolor=color, alpha=0.5)
    ax.text(x_min, y_min, label, bbox=tag_style, fontsize=8, color='white')

def evaluate_model(model, data_loader, device, iou_threshold=0.5, num_images_to_show=15,
                   score_threshold=0.5):
    """Evaluate a box detector and plot a few examples.

    Per image, predictions with a score above ``score_threshold`` are kept and
    compared with the ground truth through pairwise IoU:

    * recall    = ground-truth boxes whose best IoU >= ``iou_threshold`` / ground-truth boxes
    * precision = kept predictions whose best IoU >= ``iou_threshold`` / kept predictions

    The previous version derived "false positives" from unmatched ground-truth
    boxes, so precision was always identical to recall, and it reused the IoU
    threshold as the score cut-off for the plots.

    Args:
        model: detector returning a list of dicts with ``'boxes'`` and
            ``'scores'`` in eval mode.
        data_loader: yields ``(images, targets)``; each target has ``'boxes'``.
        device: device the images are moved to.
        iou_threshold: minimum IoU for a box to count as matched.
        num_images_to_show: number of images plotted with GT (green) and
            predicted (red) boxes.
        score_threshold: minimum confidence for a prediction to be counted and drawn.

    Returns:
        ``(mean_iou, mean_precision, mean_recall, f1_score)``. ``mean_iou`` is
        the mean best IoU per ground-truth box (0 for boxes in images without
        any kept prediction); F1 is computed from the mean precision and recall.
    """
    model.eval()
    all_ious = []
    all_precisions = []
    all_recalls = []
    images_shown = 0

    with torch.no_grad():
        for images, targets in data_loader:
            images = [img.to(device) for img in images]
            outputs = model(images)

            for image, output, target in zip(images, outputs, targets):
                gt_boxes = target['boxes'].cpu()
                scores = output['scores'].cpu()
                keep = scores > score_threshold
                pred_boxes = output['boxes'].cpu()[keep]
                pred_scores = scores[keep]

                num_gt, num_pred = len(gt_boxes), len(pred_boxes)
                if num_gt > 0 and num_pred > 0:
                    iou = box_iou(gt_boxes, pred_boxes)  # (num_gt, num_pred)
                    best_iou_per_gt = iou.max(dim=1).values
                    best_iou_per_pred = iou.max(dim=0).values
                else:
                    # Nothing to match against: every box counts as unmatched.
                    best_iou_per_gt = torch.zeros(num_gt)
                    best_iou_per_pred = torch.zeros(num_pred)

                matched_gt = (best_iou_per_gt >= iou_threshold).sum().item()
                matched_pred = (best_iou_per_pred >= iou_threshold).sum().item()

                precision = matched_pred / num_pred if num_pred > 0 else 0
                recall = matched_gt / num_gt if num_gt > 0 else 0

                all_ious.extend(best_iou_per_gt.tolist())
                all_precisions.append(precision)
                all_recalls.append(recall)

                if images_shown < num_images_to_show:
                    fig, ax = plt.subplots(1, figsize=(12, 8))
                    img = image.mul(255).permute(1, 2, 0).byte().cpu().numpy()
                    ax.imshow(img)
                    for box in gt_boxes:
                        add_box(ax, box, "GT", 'green')
                    for box, score in zip(pred_boxes, pred_scores):
                        add_box(ax, box, f"Pred: {score:.2f}", 'red')
                    plt.show()
                    images_shown += 1

    mean_iou = np.mean(all_ious) if all_ious else 0
    mean_precision = np.mean(all_precisions) if all_precisions else 0
    mean_recall = np.mean(all_recalls) if all_recalls else 0
    f1_score = 2 * (mean_precision * mean_recall) / (mean_precision + mean_recall) if (mean_precision + mean_recall) > 0 else 0

    print(f"Mean IoU: {mean_iou}")
    print(f"Mean Precision: {mean_precision}")
    print(f"Mean Recall: {mean_recall}")
    print(f"F1 Score: {f1_score}")

    return mean_iou, mean_precision, mean_recall, f1_score

# Evaluate the detector on the validation loader using the function's default thresholds.
evaluate_model(model, val_loader, device)


# Character recognition

## Using pytesseract

In [None]:
!pip install pytesseract

In [None]:
!sudo apt-get install tesseract-ocr


In [None]:
import pytesseract

In [None]:
print(pytesseract.__version__)

In [None]:
# Hand-made alphabet table (digits, lowercase, uppercase). Kept for reference
# only: it does NOT follow ImageFolder's class indices, which come from the
# sorted folder names (digits, then uppercase, then lowercase, plus any other
# folder that exists). Decoding ImageFolder labels with it swaps the letter case.
all_classes = list(string.digits + string.ascii_lowercase + string.ascii_uppercase)
class_labels = {i: all_classes[i] for i in range(len(all_classes))}

class_to_index = {v: k for k, v in class_labels.items()}


def decode_label(index):
    """Look up ``index`` in the hand-made ``class_labels`` table."""
    return class_labels[index]


def swap_case(char):
    """Return ``char`` with its letter case flipped; other characters unchanged."""
    if char.islower():
        return char.upper()
    elif char.isupper():
        return char.lower()
    return char


data_path = '/content/drive/MyDrive/OCR_project/segmented/'
dataset = ImageFolder(root=data_path)

# Evaluate Tesseract on a random 20% of the segmented character crops.
sample_size = int(0.2 * len(dataset))
random_sample = random.sample(dataset.samples, sample_size)


def get_true_labels(random_sample):
    """Return the class (folder) name of every ``(path, class_index)`` sample.

    The index is resolved through ``dataset.classes`` so the name always agrees
    with the folder the image was read from.
    """
    return [dataset.classes[label] for _, label in random_sample]


def predict_labels(random_sample):
    """Run Tesseract in single-character mode on every sampled image.

    Returns one string per sample: the first recognised character, or ``''``
    when Tesseract returned nothing.
    """
    predicted_labels = []
    config = "--psm 10 -c tessedit_char_whitelist=abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
    for img_path, _ in tqdm(random_sample):
        with Image.open(img_path) as img:
            prediction = pytesseract.image_to_string(img, config=config).strip()
        # Slicing yields '' for an empty prediction.
        predicted_labels.append(prediction[:1])
    return predicted_labels


def display_images_with_labels(random_sample, true_labels, predicted_labels, batch_size=25):
    """Plot the samples in grids of five columns, titled with predicted and true labels.

    Labels are shown exactly as they are compared in the accuracy computation.
    """
    num_images = len(random_sample)
    num_batches = math.ceil(num_images / batch_size)
    num_cols = 5

    for batch_idx in range(num_batches):
        start_idx = batch_idx * batch_size
        end_idx = min(start_idx + batch_size, num_images)
        batch_random_sample = random_sample[start_idx:end_idx]
        batch_true_labels = true_labels[start_idx:end_idx]
        batch_predicted_labels = predicted_labels[start_idx:end_idx]

        num_rows = math.ceil(len(batch_random_sample) / num_cols)

        fig, axes = plt.subplots(nrows=num_rows, ncols=num_cols, figsize=(15, num_rows * 3))
        axes = axes.flatten()

        for idx, ax in enumerate(axes):
            ax.axis('off')
            if idx >= len(batch_random_sample):
                continue  # unused cell in the last row
            with Image.open(batch_random_sample[idx][0]) as img:
                ax.imshow(img)
            ax.set_title(f'Predicted: {batch_predicted_labels[idx]}\nTrue: {batch_true_labels[idx]}')

        plt.tight_layout()
        plt.show()


true_labels = get_true_labels(random_sample)
predicted_labels = predict_labels(random_sample)


correct = sum(1 for true, pred in zip(true_labels, predicted_labels) if true.lower() == pred.lower())
correct_case_sensitive = sum(1 for true, pred in zip(true_labels, predicted_labels) if true == pred)
total = len(true_labels)
# An empty sample (tiny or empty dataset) must not crash the cell.
accuracy = correct / total if total else 0.0
accuracy_case_sensitive = correct_case_sensitive / total if total else 0.0
print(f'Accuracy: {accuracy * 100:.2f}%')
print(f'Case sensitive accuracy: {accuracy_case_sensitive * 100:.2f}%')


display_images_with_labels(random_sample, true_labels, predicted_labels)


In [None]:
# Re-computes the accuracy figures from the true_labels / predicted_labels
# produced by the previous cell, without re-running Tesseract.
# Case-insensitive comparison first, exact comparison second.
correct = sum(1 for true, pred in zip(true_labels, predicted_labels) if true.lower() == pred.lower())
correct_case_sensitive = sum(1 for true, pred in zip(true_labels, predicted_labels) if true == pred)
total = len(true_labels)
accuracy = correct / total
accuracy_case_sensitive = correct_case_sensitive / total
print(f'Accuracy: {accuracy * 100:.2f}%')
print(f'Case sensitive accuracy: {accuracy_case_sensitive * 100:.2f}%')

In [None]:
print(len(all_classes))

## Using my own model

### Preprocessing

In [None]:
def preprocess_image(cropped_img_np):
    """Convert a crop to grayscale and to an inverted adaptive-threshold mask.

    Accepts a 2-D array (used as the gray image as-is) or a 3-D array with
    3 (RGB) or 4 (RGBA) channels.

    Returns:
        ``(gray, binary_inv)``; ``binary_inv`` comes from Gaussian adaptive
        thresholding with ``THRESH_BINARY_INV`` and a maximum value of 255.

    Raises:
        ValueError: for a 3-D array with any other channel count.
    """
    if cropped_img_np.ndim == 2:
        gray = cropped_img_np
    else:
        channels = cropped_img_np.shape[2]
        if channels == 4:
            # Drop the alpha channel first, then go to gray.
            rgb = cv2.cvtColor(cropped_img_np, cv2.COLOR_RGBA2RGB)
            gray = cv2.cvtColor(rgb, cv2.COLOR_RGB2GRAY)
        elif channels == 3:
            gray = cv2.cvtColor(cropped_img_np, cv2.COLOR_RGB2GRAY)
        else:
            raise ValueError("Unexpected number of channels in the input image")

    block_size, offset = 11, 2
    binary_inv = cv2.adaptiveThreshold(
        gray,
        255,
        cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
        cv2.THRESH_BINARY_INV,
        block_size,
        offset,
    )
    return gray, binary_inv

### Segmentation

In [None]:
def segment_lines(binary_image):
    """Find text lines in a binary image from its horizontal projection.

    A row belongs to a line when its pixel sum exceeds ``0.1 * image width``.
    Consecutive such rows form a line; lines shorter than 3 rows are dropped.

    Args:
        binary_image: 2-D array with text pixels non-zero.

    Returns:
        List of ``(start_row, end_row)`` with ``end_row`` exclusive, so
        ``binary_image[start_row:end_row]`` is the line.

    NOTE(review): the threshold is compared with raw pixel sums; with a 0/255
    mask a single text pixel exceeds it for widths below 2550. Confirm that is
    the intended sensitivity.
    """
    horizontal_projection = np.sum(binary_image, axis=1)
    threshold = binary_image.shape[1] * 0.1
    min_height = 3

    lines = []
    line_start = None
    for row, value in enumerate(horizontal_projection):
        if value > threshold:
            if line_start is None:
                line_start = row
        elif line_start is not None:
            if row - line_start >= min_height:
                lines.append((line_start, row))
            line_start = None

    # A line touching the bottom edge ends one past the last row. The previous
    # code used the last row index itself, cutting off that row and applying
    # the height check to a height that was one too small.
    if line_start is not None:
        end_row = len(horizontal_projection)
        if end_row - line_start >= min_height:
            lines.append((line_start, end_row))

    return lines

def segment_characters(line_image):
    """Split a binary line image into characters at empty columns.

    Args:
        line_image: 2-D array with text pixels non-zero.

    Returns:
        ``(characters, vertical_projection)``: ``characters`` is a list of
        ``(start_col, end_col)`` spans with ``end_col`` exclusive, and
        ``vertical_projection`` is the per-column pixel sum.
    """
    vertical_projection = np.sum(line_image, axis=0)

    characters = []
    character_start = None
    for col, value in enumerate(vertical_projection):
        if value > 0:
            if character_start is None:
                character_start = col
        elif value == 0 and character_start is not None:
            characters.append((character_start, col))
            character_start = None

    # A character touching the right edge ends one past the last column. The
    # previous code used the last column index itself and lost that column.
    if character_start is not None:
        characters.append((character_start, len(vertical_projection)))

    return characters, vertical_projection

def refine_segmentation(line_image, characters, vertical_projection, expected_count, multi_line):
    """Split the widest character spans in half until enough characters exist.

    Used when fewer characters were detected than the label contains, which
    typically means neighbouring glyphs were merged into one span.

    Args:
        line_image: unused; kept for interface compatibility.
        characters: list of ``(start, end)`` column spans, left to right.
        vertical_projection: unused; kept for interface compatibility.
        expected_count: number of characters the label says should be present.
        multi_line: when true, at most 3 splits are made.

    Returns:
        ``characters`` itself when it is empty or already long enough,
        otherwise a new list in which the currently widest span (the leftmost
        one on ties) was halved once per missing character.
    """
    # Nothing to split. The empty case used to raise IndexError.
    if not characters or len(characters) >= expected_count:
        return characters

    split_needed = expected_count - len(characters)
    if multi_line:
        split_needed = min(split_needed, 3)

    for _ in range(split_needed):
        # max() returns the first maximal index, i.e. the leftmost widest span.
        widest = max(range(len(characters)),
                     key=lambda k: characters[k][1] - characters[k][0])
        start, end = characters[widest]
        middle = start + (end - start) // 2
        # Build a new list so the caller's list is never mutated.
        characters = characters[:widest] + [(start, middle), (middle, end)] + characters[widest + 1:]

    return characters

In [None]:
# helper function to save segmented images to folders according to their labels
def save_characters_to_folders(char_images, label, image_path,
                               output_root='/content/drive/MyDrive/OCR_project/segmented'):
    """Write each character crop into a folder named after its character.

    Crops are paired, in order, with the characters of ``label`` after spaces
    are removed. Files are named ``{char}_v2_{k}_{source file name}.png``,
    where ``k`` counts earlier occurrences of the same character in this label.

    Args:
        char_images: sequence of image arrays, one per character.
        label: the text of the crop; non-strings are converted with ``str``.
        image_path: source image path; only its base name is used.
        output_root: directory that receives the per-character folders.

    Empty crops are skipped, as are characters that cannot be a folder name
    (``.`` and path separators): these used to write files outside the class
    folders (``/`` even resolved to the filesystem root).
    """
    label = str(label)
    unsafe_chars = {'.', '\0', os.sep}
    if os.altsep:
        unsafe_chars.add(os.altsep)

    source_name = os.path.basename(image_path)
    seen_per_char = {}
    for char_image, char in zip(char_images, label.replace(" ", "")):
        if char_image.size == 0 or char in unsafe_chars:
            continue
        folder_path = os.path.join(output_root, char)
        os.makedirs(folder_path, exist_ok=True)

        occurrence = seen_per_char.get(char, 0)
        seen_per_char[char] = occurrence + 1

        file_path = os.path.join(folder_path, f"{char}_v2_{occurrence}_{source_name}.png")
        # imwrite reports some failures through its return value only.
        if not cv2.imwrite(file_path, char_image):
            print(f"Warning: could not write {file_path}")

def process_and_save_characters(img_path, coordinates, label):
    """Segment one labelled word crop into characters and save them.

    The crop given by ``coordinates`` is cut out of the image (and rotated by
    90 degrees when it is more than 1.5x taller than wide), binarised, split
    into lines and characters, and refined when too few characters were found.
    The crops are saved only when their total count equals the number of
    non-space characters in ``label``.

    Args:
        img_path: path of the page image.
        coordinates: ``(left, top, right, bottom)``, or its string literal.
        label: text contained in the box.
    """
    # literal_eval only parses literals; eval would execute whatever the cell holds.
    box = literal_eval(coordinates) if isinstance(coordinates, str) else coordinates
    left, top, right, bottom = box
    expected_char_count = len(str(label).replace(" ", ""))

    with Image.open(img_path) as image:
        img = image.crop((left, top, right, bottom))
        if (bottom - top) > (right - left) * 1.5:
            img = img.rotate(90, expand=True)
        # Materialise the pixels before the file is closed.
        cropped_img_np = np.array(img)

    _, binary_inv = preprocess_image(cropped_img_np)
    lines = segment_lines(binary_inv)
    multi_line = len(lines) > 1

    final_char_images = []
    for line in lines:
        line_image = binary_inv[line[0]:line[1], :]
        characters, vertical_projection = segment_characters(line_image)

        detected_char_count = len(characters)
        if detected_char_count != expected_char_count and detected_char_count <= expected_char_count + 6:
            refined_characters = refine_segmentation(line_image, characters, vertical_projection, expected_char_count, multi_line)
        else:
            refined_characters = characters

        # Clamp every span to the crop before cutting it out of the original pixels.
        y1 = max(line[0], 0)
        y2 = min(line[1], binary_inv.shape[0])
        for char_start, char_end in refined_characters:
            x1 = max(char_start, 0)
            x2 = min(char_end, binary_inv.shape[1])
            final_char_images.append(cropped_img_np[y1:y2, x1:x2])

    if len(final_char_images) == expected_char_count:
        save_characters_to_folders(final_char_images, label, img_path)

In [None]:
def show_image_with_box(df, idx, root_dir):
    """Save the character crops for row ``idx`` and plot the segmentation stages.

    Side effect: calls ``process_and_save_characters`` for the row first. Then
    four panels are shown: grayscale crop, binary mask, detected character
    spans (red) and the spans after refinement (green). The green panel shows
    the detected spans unchanged when no refinement applies, mirroring what
    ``process_and_save_characters`` saves.

    Args:
        df: table with ``'file path'``, ``'coordinates'`` and ``'label'`` columns.
        idx: positional row index.
        root_dir: directory the file path is joined to.
    """
    row = df.iloc[idx]
    img_path = os.path.join(root_dir, row['file path'])
    coordinates = row['coordinates']
    label = row['label']
    process_and_save_characters(img_path, coordinates, label)

    # literal_eval only parses literals; eval would execute whatever the cell holds.
    box = literal_eval(coordinates) if isinstance(coordinates, str) else coordinates
    left, top, right, bottom = box
    with Image.open(img_path) as image:
        cropped_img_np = np.array(image.crop((left, top, right, bottom)))
    if (bottom - top) > (right - left) * 1.5:
        cropped_img_np = cv2.rotate(cropped_img_np, cv2.ROTATE_90_COUNTERCLOCKWISE)
    gray, binary_inv = preprocess_image(cropped_img_np)
    lines = segment_lines(binary_inv)

    expected_char_count = len(str(label).replace(" ", ""))
    multi_line = len(lines) > 1
    padding = 1

    fig, axes = plt.subplots(1, 4, figsize=(24, 6))
    axes[0].imshow(gray, cmap='gray')
    axes[0].set_title("Grayscale Image")

    axes[1].imshow(binary_inv, cmap='gray')
    axes[1].set_title("Binary Image (Adaptive Thresholding)")

    ax_img_chars = axes[2]
    ax_img_chars.imshow(binary_inv, cmap='gray')
    ax_img_chars.set_title("Detected Characters")

    ax_img_refined = axes[3]
    ax_img_refined.imshow(binary_inv, cmap='gray')
    ax_img_refined.set_title("Refined Segmentation")

    def _draw_spans(ax, line, spans, color):
        # Outline each (start, end) span of one line, padded by `padding` pixels.
        y1 = max(line[0] - padding, 0)
        y2 = min(line[1] + padding, binary_inv.shape[0])
        for char_start, char_end in spans:
            x1 = max(char_start - padding, 0)
            x2 = min(char_end, binary_inv.shape[1])
            ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, edgecolor=color, facecolor='none'))

    for line in lines:
        # Slice per line; the previous code reused the last line's image for
        # every refinement call.
        line_image = binary_inv[line[0]:line[1], :]
        characters, vertical_projection = segment_characters(line_image)
        _draw_spans(ax_img_chars, line, characters, 'red')

        detected_char_count = len(characters)
        if detected_char_count != expected_char_count and detected_char_count <= expected_char_count + 6:
            refined_characters = refine_segmentation(line_image, characters, vertical_projection, expected_char_count, multi_line)
        else:
            refined_characters = characters
        _draw_spans(ax_img_refined, line, refined_characters, 'green')

    plt.tight_layout()
    plt.show()

# Run the segmentation over every row of the training annotation table.
# Each call both writes character crops to Drive (via process_and_save_characters)
# and shows a 4-panel figure, so this loop is slow and produces a lot of output.
file_path = '/content/drive/MyDrive/OCR_project/All_bounding_boxes.xlsx'
data = pd.read_excel(file_path)

for i in range(len(data)):
    show_image_with_box(data, i, '/content/drive/MyDrive/OCR_project/FUNSD_dataset/training_data/images/')


### CNN

In [None]:
# checking categories
def summarize_files_in_subfolders(folder_path):
    """Count the regular files directly inside every sub-directory of ``folder_path``.

    The walk is recursive. Keys are paths relative to ``folder_path``; the top
    directory itself is left out (it is recognised by exact string equality
    with the path that was passed in).
    """
    counts = {}
    for current_dir, _subdirs, names in os.walk(folder_path):
        if current_dir != folder_path:
            n_files = sum(
                1 for name in names
                if os.path.isfile(os.path.join(current_dir, name))
            )
            counts[os.path.relpath(current_dir, folder_path)] = n_files
    return counts

# Print how many files each character folder under segmented/ contains,
# to check how the classes are balanced.
folder_path = '/content/drive/MyDrive/OCR_project/segmented/'

file_summary = summarize_files_in_subfolders(folder_path)

for subfolder, count in file_summary.items():
    print(f"{subfolder}: {count} files")

In [None]:
# CNN architecture
class CharacterCNN(nn.Module):
    """Small CNN for single-channel 28x28 character crops.

    Two conv -> batch-norm -> ReLU -> 2x2 max-pool stages (32 then 64 channels)
    reduce the input to 64 x 7 x 7, followed by a 128-unit hidden layer with
    dropout and a ``num_classes``-way output layer producing raw logits.
    """

    def __init__(self, num_classes):
        super().__init__()
        # Attribute names and creation order are kept stable so existing
        # state_dict checkpoints still load.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.fc1 = nn.Linear(in_features=64 * 7 * 7, out_features=128)
        self.dropout = nn.Dropout(p=0.5)
        self.fc2 = nn.Linear(in_features=128, out_features=num_classes)

    def forward(self, x):
        # Both convolutional stages share the same shape and the same pooling layer.
        for conv, norm in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            x = self.pool(norm(conv(x)).relu())

        features = x.view(-1, 64 * 7 * 7)
        hidden = self.dropout(self.fc1(features).relu())
        return self.fc2(hidden)

In [None]:
class CharacterDataset(Dataset):
    """``ImageFolder`` wrapper whose transform is applied at access time.

    The wrapped ``ImageFolder`` is created without a transform; ``self.transform``
    (which may be reassigned later) is applied in ``__getitem__``.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.dataset = ImageFolder(root=root_dir)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        image, label = self.dataset[idx]
        if not self.transform:
            return image, label
        return self.transform(image), label

In [None]:
# transformations
# Training pipeline: single channel, 28x28 (the input size CharacterCNN's
# 64 * 7 * 7 flatten assumes after two 2x2 poolings), light colour jitter,
# then tensor conversion and normalisation with mean 0.5 / std 0.5.
train_transform = transforms.Compose([
    transforms.Grayscale(),
    transforms.Resize((28, 28)),
    #transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# Validation pipeline: same preprocessing without the random jitter.
val_transform = transforms.Compose([
    transforms.Grayscale(),
    transforms.Resize((28, 28)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

In [None]:
# data folder
data_path = '/content/drive/MyDrive/OCR_project/segmented/'

# Two views of the same folder, one per transform pipeline. random_split would
# give two Subsets sharing ONE underlying dataset, so assigning `.transform`
# through each of them made the second assignment win: training silently ran
# with the validation transform and no augmentation.
full_dataset = CharacterDataset(root_dir=data_path, transform=train_transform)
val_view = CharacterDataset(root_dir=data_path, transform=val_transform)

# One output per class folder. Counting os.listdir entries would also count
# stray files lying next to the folders.
num_classes = len(full_dataset.dataset.classes)

train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size

# ImageFolder lists samples in a deterministic (sorted) order, so one shuffled
# index list can be shared by both views without train/validation overlap.
shuffled_indices = torch.randperm(len(full_dataset)).tolist()
train_dataset = torch.utils.data.Subset(full_dataset, shuffled_indices[:train_size])
val_dataset = torch.utils.data.Subset(val_view, shuffled_indices[train_size:])

train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=128, shuffle=False)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CharacterCNN(num_classes=num_classes).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=0.001)
# Halve the learning rate after 3 epochs without validation-loss improvement.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=3, factor=0.5, verbose=True)

In [None]:
# training loop with early stopping
# Training stops after `patience` consecutive epochs without a new best
# validation loss. The weights of the best epoch are remembered and restored at
# the end, so the model that gets saved afterwards is the best one rather than
# the last (already degraded) one.
num_epochs = 100
early_stop_counter = 0
best_val_loss = float('inf')
best_state = None
patience = 5

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    progress_bar = tqdm(enumerate(train_loader), total=len(train_loader))
    for i, data in progress_bar:
        inputs, labels = data[0].to(device), data[1].to(device)
        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        progress_bar.set_description(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {running_loss / (i + 1):.4f}')

    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for i, data in enumerate(val_loader):
            inputs, labels = data[0].to(device), data[1].to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            val_loss += loss.item()
    # max(..., 1) keeps an empty validation loader from dividing by zero.
    avg_val_loss = val_loss / max(len(val_loader), 1)
    print(f'Epoch [{epoch + 1}/{num_epochs}], Validation loss: {avg_val_loss:.4f}')

    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        early_stop_counter = 0
        # Clone the tensors: state_dict() returns references that keep changing
        # as training continues.
        best_state = {name: tensor.detach().clone() for name, tensor in model.state_dict().items()}
    else:
        early_stop_counter += 1
        if early_stop_counter >= patience:
            print("Early stopping!")
            break

    scheduler.step(avg_val_loss)

if best_state is not None:
    model.load_state_dict(best_state)

In [None]:
# Save the trained model
# Only the state_dict is stored; to load it, build CharacterCNN with the same
# num_classes and call load_state_dict.
torch.save(model.state_dict(), '/content/drive/MyDrive/OCR_project/CNN_segmented_final.pth')

### Evaluation

In [None]:
# evaluate and show true & predicted labels
def evaluate_model(model, validation_loader, device, max_images=None, class_names=None):
    """Report classification accuracy and plot predictions in a 10-column grid.

    Args:
        model: classifier returning one row of class scores per image.
        validation_loader: yields ``(images, labels)`` batches; its ``dataset``
            must support ``len``.
        device: device the batches are moved to.
        max_images: optional cap on how many images are plotted (``None`` plots
            all of them). Accuracy is always computed on the whole loader.
        class_names: optional sequence mapping class index to a display name;
            without it the raw indices are shown in the titles.

    Returns:
        Accuracy in percent (``0.0`` for an empty loader).
    """
    model.eval()
    correct = 0
    total = 0

    num_images = len(validation_loader.dataset)
    if max_images is not None:
        num_images = min(num_images, max_images)

    axes = None
    if num_images > 0:
        rows = math.ceil(num_images / 10)
        # squeeze=False keeps `axes` 2-D even for a single row.
        fig, axes = plt.subplots(nrows=rows, ncols=10, figsize=(20, 2 * rows), squeeze=False)
        plt.subplots_adjust(wspace=1, hspace=1)
        for ax in axes.flat:
            ax.axis('off')  # also hides the unused cells of the last row
    count = 0

    def _name(index):
        # Display name for a class index.
        return class_names[index] if class_names is not None else index

    with torch.no_grad():
        for images, labels in validation_loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            for i in range(images.shape[0]):
                if count >= num_images:
                    break  # plotting quota reached; keep accumulating accuracy
                ax = axes[count // 10, count % 10]
                img = images[i].cpu().numpy().transpose((1, 2, 0))
                # Rescale to [0, 1] for display; a constant image would divide by zero.
                low, high = img.min(), img.max()
                img = (img - low) / (high - low) if high > low else np.zeros_like(img)
                if img.shape[-1] == 1:
                    ax.imshow(img[..., 0], cmap='gray')
                else:
                    ax.imshow(img)
                ax.set_title(f'Predicted: {_name(predicted[i].item())}\nTrue: {_name(labels[i].item())}')
                count += 1

    if axes is not None:
        plt.show()

    accuracy = 100 * correct / total if total else 0.0
    print(f'Accuracy on validation set: {accuracy}%')
    return accuracy

# Evaluate the character CNN on the validation loader; the model is moved to
# the selected device first.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
evaluate_model(model.to(device), val_loader, device)
