In [23]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os

from constants import label_mapping, train_params, test_params, pred_map
from helper import HOG_Dataset, get_hog
from model import DocumentAlignmentNet, train_cnn, evaluate_cnn

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

In [3]:
train_data = pd.read_csv('data/train_data.csv')
train_data.head()

Unnamed: 0,path,angles
0,data/rotated_images/train/fhdw0081_8_180.png,180
1,data/rotated_images/train/gkbm0020_6_180.png,180
2,data/rotated_images/train/jlxw0228_6_90.png,90
3,data/rotated_images/train/jmlw0023_1_180.png,180
4,data/rotated_images/train/ffcn0226_5_180.png,180


In [4]:
test_data = pd.read_csv('data/test_data.csv')
test_data.head()

Unnamed: 0,path,angles
0,data/rotated_images/test/yrvw0217_69_0.png,0
1,data/rotated_images/test/yrvw0217_68_0.png,0
2,data/rotated_images/test/tpcw0217_1_90.png,90
3,data/rotated_images/test/ysbw0217_11_0.png,0
4,data/rotated_images/test/zzcn0020_2_180.png,180


In [5]:
training_data= HOG_Dataset(train_data)
testing_data = HOG_Dataset(test_data)

training_loader = DataLoader(training_data, **train_params)
testing_loader = DataLoader(testing_data, **test_params)

model = DocumentAlignmentNet()

# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.0001)

# Training loop
num_epochs = 40

#writer
# writer = SummaryWriter(f"logs/alignment_prediction_{fol_num}")

In [6]:
for epoch in range(num_epochs):
    
    #best model params
    best_model_state = None
    best_val_accuracy = 0.0
    epochs_without_improvement = 0
    
    #train and val loop
    train_loss, train_acc = train_cnn(model, criterion, optimizer, training_loader, epoch)
    print(f'Epoch [{epoch}/{num_epochs}], Train Loss: {train_loss:.4f}, Train Accuracy: {train_acc:.4f}')
    val_loss, val_acc = evaluate_cnn(model, criterion, testing_loader, epoch)
    print(f'Epoch [{epoch}/{num_epochs}], Test Loss: {val_loss:.4f}, Test Accuracy: {val_acc:.4f}')
    
    # scheduler.step()

    #early stopping
    if val_acc > best_val_accuracy:
            best_val_accuracy = val_acc
            best_model_state = model.state_dict()
            epochs_without_improvement = 0
    else:
        epochs_without_improvement += 1
        if epochs_without_improvement >= 4: #if model acc does not improve for 4 epochs
            print(f"Early stopping at epoch {epoch}.")
            break

100%|██████████| 24/24 [00:51<00:00,  2.13s/it]


Epoch [0/40], Train Loss: 1.4169, Train Accuracy: 0.2211


100%|██████████| 7/7 [00:11<00:00,  1.60s/it]


Epoch [0/40], Test Loss: 1.3783, Test Accuracy: 0.2692


100%|██████████| 24/24 [00:49<00:00,  2.05s/it]


Epoch [1/40], Train Loss: 1.3480, Train Accuracy: 0.3579


100%|██████████| 7/7 [00:11<00:00,  1.58s/it]


Epoch [1/40], Test Loss: 1.3397, Test Accuracy: 0.5962


100%|██████████| 24/24 [00:49<00:00,  2.05s/it]


Epoch [2/40], Train Loss: 1.2804, Train Accuracy: 0.4368


100%|██████████| 7/7 [00:11<00:00,  1.57s/it]


Epoch [2/40], Test Loss: 1.2788, Test Accuracy: 0.4808


100%|██████████| 24/24 [00:49<00:00,  2.06s/it]


Epoch [3/40], Train Loss: 1.2132, Train Accuracy: 0.4947


100%|██████████| 7/7 [00:10<00:00,  1.57s/it]


Epoch [3/40], Test Loss: 1.2352, Test Accuracy: 0.5769


100%|██████████| 24/24 [00:48<00:00,  2.03s/it]


Epoch [4/40], Train Loss: 1.1185, Train Accuracy: 0.6474


100%|██████████| 7/7 [00:10<00:00,  1.55s/it]


Epoch [4/40], Test Loss: 1.2177, Test Accuracy: 0.5769


100%|██████████| 24/24 [00:48<00:00,  2.02s/it]


Epoch [5/40], Train Loss: 1.0488, Train Accuracy: 0.7368


100%|██████████| 7/7 [00:10<00:00,  1.56s/it]


Epoch [5/40], Test Loss: 1.1151, Test Accuracy: 0.7308


100%|██████████| 24/24 [00:48<00:00,  2.02s/it]


Epoch [6/40], Train Loss: 0.9768, Train Accuracy: 0.7895


100%|██████████| 7/7 [00:10<00:00,  1.55s/it]


Epoch [6/40], Test Loss: 1.0683, Test Accuracy: 0.7692


100%|██████████| 24/24 [00:48<00:00,  2.02s/it]


Epoch [7/40], Train Loss: 0.9929, Train Accuracy: 0.7684


100%|██████████| 7/7 [00:10<00:00,  1.56s/it]


Epoch [7/40], Test Loss: 1.0116, Test Accuracy: 0.8269


100%|██████████| 24/24 [00:48<00:00,  2.02s/it]


Epoch [8/40], Train Loss: 0.9713, Train Accuracy: 0.8000


100%|██████████| 7/7 [00:10<00:00,  1.55s/it]


Epoch [8/40], Test Loss: 0.9813, Test Accuracy: 0.8269


100%|██████████| 24/24 [00:48<00:00,  2.03s/it]


Epoch [9/40], Train Loss: 0.9291, Train Accuracy: 0.8316


100%|██████████| 7/7 [00:10<00:00,  1.55s/it]


Epoch [9/40], Test Loss: 0.9638, Test Accuracy: 0.8269


100%|██████████| 24/24 [00:48<00:00,  2.01s/it]


Epoch [10/40], Train Loss: 0.8987, Train Accuracy: 0.8526


100%|██████████| 7/7 [00:10<00:00,  1.55s/it]


Epoch [10/40], Test Loss: 0.9450, Test Accuracy: 0.8462


100%|██████████| 24/24 [00:48<00:00,  2.02s/it]


Epoch [11/40], Train Loss: 0.8594, Train Accuracy: 0.9053


100%|██████████| 7/7 [00:10<00:00,  1.54s/it]


Epoch [11/40], Test Loss: 0.8630, Test Accuracy: 0.9231


100%|██████████| 24/24 [00:48<00:00,  2.02s/it]


Epoch [12/40], Train Loss: 0.8451, Train Accuracy: 0.9105


100%|██████████| 7/7 [00:10<00:00,  1.56s/it]


Epoch [12/40], Test Loss: 0.8596, Test Accuracy: 0.9038


100%|██████████| 24/24 [00:48<00:00,  2.02s/it]


Epoch [13/40], Train Loss: 0.8606, Train Accuracy: 0.9000


100%|██████████| 7/7 [00:10<00:00,  1.56s/it]


Epoch [13/40], Test Loss: 0.8990, Test Accuracy: 0.8462


100%|██████████| 24/24 [00:49<00:00,  2.06s/it]


Epoch [14/40], Train Loss: 0.8311, Train Accuracy: 0.9316


100%|██████████| 7/7 [00:11<00:00,  1.58s/it]


Epoch [14/40], Test Loss: 0.8516, Test Accuracy: 0.8846


100%|██████████| 24/24 [00:50<00:00,  2.11s/it]


Epoch [15/40], Train Loss: 0.7987, Train Accuracy: 0.9579


100%|██████████| 7/7 [00:11<00:00,  1.62s/it]


Epoch [15/40], Test Loss: 0.8334, Test Accuracy: 0.9423


100%|██████████| 24/24 [00:51<00:00,  2.15s/it]


Epoch [16/40], Train Loss: 0.8033, Train Accuracy: 0.9526


100%|██████████| 7/7 [00:10<00:00,  1.55s/it]


Epoch [16/40], Test Loss: 0.8586, Test Accuracy: 0.9231


100%|██████████| 24/24 [00:48<00:00,  2.03s/it]


Epoch [17/40], Train Loss: 0.7969, Train Accuracy: 0.9684


100%|██████████| 7/7 [00:10<00:00,  1.56s/it]


Epoch [17/40], Test Loss: 0.8436, Test Accuracy: 0.8846


100%|██████████| 24/24 [00:48<00:00,  2.03s/it]


Epoch [18/40], Train Loss: 0.7891, Train Accuracy: 0.9684


100%|██████████| 7/7 [00:10<00:00,  1.57s/it]


Epoch [18/40], Test Loss: 0.8429, Test Accuracy: 0.9038


100%|██████████| 24/24 [00:48<00:00,  2.02s/it]


Epoch [19/40], Train Loss: 0.7857, Train Accuracy: 0.9632


100%|██████████| 7/7 [00:10<00:00,  1.56s/it]


Epoch [19/40], Test Loss: 0.8181, Test Accuracy: 0.9423


100%|██████████| 24/24 [00:49<00:00,  2.04s/it]


Epoch [20/40], Train Loss: 0.7850, Train Accuracy: 0.9789


100%|██████████| 7/7 [00:10<00:00,  1.56s/it]


Epoch [20/40], Test Loss: 0.7943, Test Accuracy: 0.9808


100%|██████████| 24/24 [00:48<00:00,  2.04s/it]


Epoch [21/40], Train Loss: 0.7847, Train Accuracy: 0.9684


100%|██████████| 7/7 [00:11<00:00,  1.61s/it]


Epoch [21/40], Test Loss: 0.8166, Test Accuracy: 0.9423


100%|██████████| 24/24 [00:48<00:00,  2.03s/it]


Epoch [22/40], Train Loss: 0.7914, Train Accuracy: 0.9474


100%|██████████| 7/7 [00:11<00:00,  1.58s/it]


Epoch [22/40], Test Loss: 0.8308, Test Accuracy: 0.9423


100%|██████████| 24/24 [00:48<00:00,  2.02s/it]


Epoch [23/40], Train Loss: 0.7808, Train Accuracy: 0.9789


100%|██████████| 7/7 [00:10<00:00,  1.55s/it]


Epoch [23/40], Test Loss: 0.8134, Test Accuracy: 0.9231


100%|██████████| 24/24 [00:48<00:00,  2.02s/it]


Epoch [24/40], Train Loss: 0.7703, Train Accuracy: 0.9842


100%|██████████| 7/7 [00:10<00:00,  1.55s/it]


Epoch [24/40], Test Loss: 0.8034, Test Accuracy: 0.9615


100%|██████████| 24/24 [00:48<00:00,  2.04s/it]


Epoch [25/40], Train Loss: 0.7789, Train Accuracy: 0.9737


100%|██████████| 7/7 [00:10<00:00,  1.54s/it]


Epoch [25/40], Test Loss: 0.7921, Test Accuracy: 0.9808


100%|██████████| 24/24 [00:49<00:00,  2.07s/it]


Epoch [26/40], Train Loss: 0.7644, Train Accuracy: 0.9842


100%|██████████| 7/7 [00:10<00:00,  1.57s/it]


Epoch [26/40], Test Loss: 0.8145, Test Accuracy: 0.9231


100%|██████████| 24/24 [00:50<00:00,  2.12s/it]


Epoch [27/40], Train Loss: 0.7666, Train Accuracy: 0.9789


100%|██████████| 7/7 [00:11<00:00,  1.57s/it]


Epoch [27/40], Test Loss: 0.7955, Test Accuracy: 0.9615


100%|██████████| 24/24 [00:51<00:00,  2.13s/it]


Epoch [28/40], Train Loss: 0.7724, Train Accuracy: 0.9789


100%|██████████| 7/7 [00:11<00:00,  1.64s/it]


Epoch [28/40], Test Loss: 0.7911, Test Accuracy: 0.9808


100%|██████████| 24/24 [00:51<00:00,  2.17s/it]


Epoch [29/40], Train Loss: 0.7799, Train Accuracy: 0.9684


100%|██████████| 7/7 [00:10<00:00,  1.56s/it]


Epoch [29/40], Test Loss: 0.7826, Test Accuracy: 0.9808


100%|██████████| 24/24 [00:51<00:00,  2.13s/it]


Epoch [30/40], Train Loss: 0.7690, Train Accuracy: 0.9737


100%|██████████| 7/7 [00:10<00:00,  1.56s/it]


Epoch [30/40], Test Loss: 0.8074, Test Accuracy: 0.9423


100%|██████████| 24/24 [00:51<00:00,  2.14s/it]


Epoch [31/40], Train Loss: 0.7721, Train Accuracy: 0.9737


100%|██████████| 7/7 [00:10<00:00,  1.56s/it]


Epoch [31/40], Test Loss: 0.7829, Test Accuracy: 0.9615


100%|██████████| 24/24 [00:51<00:00,  2.15s/it]


Epoch [32/40], Train Loss: 0.7660, Train Accuracy: 0.9895


100%|██████████| 7/7 [00:11<00:00,  1.61s/it]


Epoch [32/40], Test Loss: 0.8230, Test Accuracy: 0.9231


100%|██████████| 24/24 [00:52<00:00,  2.17s/it]


Epoch [33/40], Train Loss: 0.7652, Train Accuracy: 0.9789


100%|██████████| 7/7 [00:10<00:00,  1.57s/it]


Epoch [33/40], Test Loss: 0.7850, Test Accuracy: 0.9808


100%|██████████| 24/24 [00:51<00:00,  2.15s/it]


Epoch [34/40], Train Loss: 0.7620, Train Accuracy: 0.9895


100%|██████████| 7/7 [00:10<00:00,  1.56s/it]


Epoch [34/40], Test Loss: 0.7925, Test Accuracy: 0.9615


100%|██████████| 24/24 [00:50<00:00,  2.11s/it]


Epoch [35/40], Train Loss: 0.7572, Train Accuracy: 0.9895


100%|██████████| 7/7 [00:11<00:00,  1.57s/it]


Epoch [35/40], Test Loss: 0.7997, Test Accuracy: 0.9615


100%|██████████| 24/24 [00:50<00:00,  2.12s/it]


Epoch [36/40], Train Loss: 0.7633, Train Accuracy: 0.9789


100%|██████████| 7/7 [00:10<00:00,  1.57s/it]


Epoch [36/40], Test Loss: 0.7753, Test Accuracy: 0.9808


100%|██████████| 24/24 [00:50<00:00,  2.11s/it]


Epoch [37/40], Train Loss: 0.7570, Train Accuracy: 0.9895


100%|██████████| 7/7 [00:10<00:00,  1.57s/it]


Epoch [37/40], Test Loss: 0.7811, Test Accuracy: 0.9808


100%|██████████| 24/24 [00:50<00:00,  2.11s/it]


Epoch [38/40], Train Loss: 0.7567, Train Accuracy: 0.9895


100%|██████████| 7/7 [00:10<00:00,  1.56s/it]


Epoch [38/40], Test Loss: 0.8173, Test Accuracy: 0.9423


100%|██████████| 24/24 [00:50<00:00,  2.11s/it]


Epoch [39/40], Train Loss: 0.7573, Train Accuracy: 0.9895


100%|██████████| 7/7 [00:10<00:00,  1.57s/it]

Epoch [39/40], Test Loss: 0.7989, Test Accuracy: 0.9615





In [7]:
# model.load_state_dict(torch.load('cnn_model.pth'))
# model

DocumentAlignmentNet(
  (conv_layers): Sequential(
    (0): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Dropout(p=0.2, inplace=False)
    (5): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): ReLU()
    (8): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (9): Dropout(p=0.2, inplace=False)
    (10): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (12): ReLU()
    (13): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (14): Dropout(p=0.2, inplace=False)
    (15): Conv2d(64, 128, kernel_siz

## Prediction

In [24]:
import torch
model = DocumentAlignmentNet()
model.load_state_dict(torch.load('cnn_model.pth'))
model

DocumentAlignmentNet(
  (conv_layers): Sequential(
    (0): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Dropout(p=0.2, inplace=False)
    (5): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): ReLU()
    (8): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (9): Dropout(p=0.2, inplace=False)
    (10): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (12): ReLU()
    (13): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (14): Dropout(p=0.2, inplace=False)
    (15): Conv2d(64, 128, kernel_siz

In [25]:
pred_map

{0: 0, 1: 90, 2: 180, 3: 270}

In [26]:
prediction_data = pd.read_csv("data/prediction_data.csv")
prediction_data.head()

Unnamed: 0,path,angles
0,data/rotated_images/prediction/qlbp0225_8_270.png,270
1,data/rotated_images/prediction/hxmn0000_5_90.png,90
2,data/rotated_images/prediction/qmkp0227_1_270.png,270
3,data/rotated_images/prediction/hxkw0023_2_0.png,0
4,data/rotated_images/prediction/hxlf0065_1_180.png,180


In [27]:
# Open the document image
def get_prediction(path):
    hog_features = get_hog(path)
    pixels = training_data.transform(hog_features)
    # print()
    prediction = model.predict(pixels.unsqueeze(0))
    
    return pred_map.get(prediction.item())

In [28]:
get_prediction("data/rotated_images/prediction/qkmg0065_1_180.png")

180

In [16]:
from sklearn.metrics import accuracy_score, confusion_matrix

prediction_data['prediction'] = prediction_data['path'].apply(lambda x: get_prediction(x))

In [17]:
print("Accuracy Score: ", accuracy_score(prediction_data['angles'], prediction_data['prediction']))
print("Confusion Matrix: ", confusion_matrix(prediction_data['angles'], prediction_data['prediction']))

Accuracy Score:  0.8857142857142857
Confusion Matrix:  [[ 8  1  1  0]
 [ 0  6  0  0]
 [ 0  2  7  0]
 [ 0  0  0 10]]
