In [1]:
import os
import torch
import pandas as pd
import callbacks
import data
import dataUtils as du
import image
import metric
import modelArc
import trainer
import json as js
from config import ConfigFile
from datetime import datetime
import torchvision.transforms as transforms
import cv2
import numpy as np
import matplotlib.pyplot as plt
from mltu.torch.dataProvider import DataProvider


In [2]:
# Load data
#
# Builds three module-level objects used by the later cells:
#   database : list of [image_path, label] pairs
#   vocab    : set of every character that occurs in any label
#   max_len  : length of the longest label
#
# Two sources are read:
#   1. an indexed dataset in `data_dir` (<index>.jpg with one label per
#      non-empty line of <index>.txt)
#   2. the dataset in `data_path`, whose "train" and "test" folders each
#      contain a metadata.jsonl with "file_name" and "text" fields

# Specify path to main database
data_dir = r"D:\photos\RCNN4\BBOXES"
model_path = r"D:\Projects\reciept-scanner\RCNN\models"
database, vocab, max_len = [], set(), 0

largest_index = 492

# `index` rather than `id`: the latter would shadow the builtin for the rest
# of the kernel session.
for index in range(largest_index + 1):

    img_path = os.path.join(data_dir, str(index) + ".jpg").replace("\\","/")

    if not os.path.exists(img_path):
        print("image with index " + str(index) + " do not exist")
        continue

    label_path = os.path.join(data_dir, str(index) + ".txt").replace("\\","/")
    if not os.path.exists(label_path):
        # Treat a missing ground-truth file like a missing image: report it
        # and move on instead of aborting the whole load.
        print("label file with index " + str(index) + " do not exist")
        continue

    with open(label_path, 'r') as file:
        for line in file:
            label = line.strip()
            if label:
                database.append([img_path, label])
                vocab.update(label)
                max_len = max(max_len, len(label))

print("dataset1 done")

# Load second dataset (train split, then test split)
data_path = r"D:\photos\SORIE"

for split_name, done_message in (("train", "dataset2 done"), ("test", "dataset3 done")):
    path = os.path.join(data_path, split_name).replace("\\","/")

    # JSON text is UTF-8 by specification; do not rely on the platform default.
    with open(os.path.join(path, "metadata.jsonl").replace("\\","/"), 'r', encoding="utf-8") as file:
        for line in file:
            if not line.strip():
                # Blank lines (e.g. a trailing newline) are not JSON documents.
                continue

            row = js.loads(line)
            file_name, text = row.get("file_name"), row.get("text")
            if file_name is None or text is None:
                print("metadata row without file_name/text skipped: " + line.strip())
                continue

            img_path = os.path.join(path, file_name).replace("\\","/")
            if os.path.exists(img_path):
                label = text.rstrip("\n")
                vocab.update(label)
                max_len = max(max_len, len(label))
                database.append([img_path, label])
            else:
                print("image with path " + str(img_path) + " do not exist")

    print(done_message)

print(len(database))


image with index 187 do not exist
image with index 190 do not exist
image with index 194 do not exist
image with index 197 do not exist
image with index 199 do not exist
image with index 354 do not exist
image with index 366 do not exist
image with index 370 do not exist
image with index 372 do not exist
image with index 381 do not exist
image with index 393 do not exist
dataset1 done
dataset2 done
dataset3 done
52433


In [3]:
# Normalize images and load them as np arrays
#
# Pass 1 reads every image referenced by `database`, converts it to RGB,
#        stores the pixel array in place of the path, and accumulates the
#        per-channel sum and sum of squares.
# Pass 2 standardises every stored image with the dataset-wide per-channel
#        mean and standard deviation.
#
# NOTE: after this cell `database` holds float64 arrays instead of paths, so
# the cell cannot be re-run without re-running the loading cell first.

# Get the mean and std of the dataset.
# The accumulators are float64 on purpose: the running sums are far larger
# than what float32 can represent exactly, so a float32 accumulator would
# silently round away most of each image's contribution.
num_pixels = 0
channel_sum = torch.zeros(3, dtype=torch.float64)
channel_sum_squared = torch.zeros(3, dtype=torch.float64)

readable_pairs = []
for pair in database:
    image_path = pair[0]
    img = cv2.imread(image_path)

    if img is None:
        print(f"Warning: Could not read image at {image_path}. Skipping.")
        continue

    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    pair[0] = img
    readable_pairs.append(pair)

    height, width = img.shape[:2]
    num_pixels += height * width

    # Convert once to float64 so neither the sum nor the squares can overflow
    # an integer accumulator.
    pixels = img.astype(np.float64)
    channel_sum += torch.from_numpy(pixels.sum(axis=(0, 1)))
    channel_sum_squared += torch.from_numpy((pixels ** 2).sum(axis=(0, 1)))

if num_pixels == 0:
    raise ValueError("No readable images in database; cannot compute normalization statistics.")

# Drop the entries whose image could not be read. Their first element is still
# a path string, which the normalisation pass below cannot handle. The slice
# assignment keeps `database` the same list object.
database[:] = readable_pairs

mean = channel_sum / num_pixels
variance = (channel_sum_squared / num_pixels) - (mean ** 2)
# Rounding can push a near-zero variance slightly negative; clamp before sqrt.
std = torch.sqrt(torch.clamp(variance, min=0.0))
print("Number of pixels:", num_pixels)
print("Mean:", mean.numpy())
print("Standard Deviation:", std.numpy())

mean_np = mean.numpy()
std_np = std.numpy()

for pair in database:
    # `mean` and `std` were computed on raw 0-255 pixel values, so the image
    # must be standardised on that same scale. Dividing the image by 255
    # first, without rescaling the statistics, would make every normalised
    # pixel roughly -mean/std.
    pair[0] = (pair[0].astype(np.float64) - mean_np) / std_np

print("normalization done")


Number of pixels: 647645336
Mean: [200.22298 196.8244  193.20657]
Standard Deviation: [74.73709  75.59093  76.772415]
normalization done


In [6]:
# Create data loaders
#
# Stores the vocabulary and maximum label length in the model config, wraps
# `database` in a DataLoader, splits it 80/20 into train and validation sets,
# and attaches the augmentors to the training set only.
model_config = ConfigFile(name = "CRNN1", path = model_path, lr=0.0022, bs=64)

# `vocab` is a set of characters. Set iteration order for strings changes
# between interpreter runs (hash randomisation), so sort before joining to
# make the character -> index mapping reproducible.
model_config.vocab = "".join(sorted(vocab))
model_config.max_txt_len = max_len
model_config.save()

dataset_loader = data.DataLoader(
    dataset = database,
    batch_size = model_config.batch_size,
    data_preprocessors = [image.ImageReader(image.CVImage)],
    transformers = [
        du.ImageResizer(model_config.width, model_config.height),
        du.LabelIndexer(model_config.vocab),
        # Labels are padded with len(vocab), one past the last character index.
        du.LabelPadding(padding_value = len(model_config.vocab), max_word_len = max_len),
        # du.ImageShowCV2(),
    ],
)

print("splitting data")
train_set, val_set = dataset_loader.split(split = 0.8)
print("splitting done")
train_set.augmentors = [
    du.RandomBrightness(),
    du.RandomErodeDilate(),
    du.RandomSharpen(),
    du.RandomRotate(angle=10),
    ]

config saved toD:/Projects/reciept-scanner/RCNN/models/202407151638
splitting data


KeyboardInterrupt: 

In [None]:
# Initialize model, loss, and optimizer
model = modelArc.CRNN(len(model_config.vocab))
# The CTC blank index is len(vocab), one past the last character index.
loss = trainer.CTCLoss(blank = len(model_config.vocab))

# Move the model to the GPU before constructing the optimizer: the torch.optim
# documentation requires this ordering so the optimizer is built over the
# parameters that will actually be trained.
if torch.cuda.is_available():
    model = model.cuda()
    print("CUDA Enabled...Training On GPU")

optimizer = torch.optim.Adam(model.parameters(), lr=model_config.lr)


In [None]:
# Initialize callbacks and trainer

# Single source of truth for the checkpoint location. The ONNX export callback
# previously received a path built from a fresh datetime.now() minute stamp,
# which stops matching the directory used by ModelCheckpoint as soon as the
# clock minute changes after the config was created.
# NOTE(review): assumes Model2onnx reads the file written by ModelCheckpoint,
# and that model_config.model_path is the per-run output directory - confirm
# against callbacks.Model2onnx and ConfigFile.
checkpoint_path = (model_config.model_path + "/model.pt").replace("\\","/")

# `earlystop` is created here but is not in the callback list passed to the
# training run in the next cell.
earlystop = callbacks.EarlyStopping(monitor = "val_CER", patience = 10, verbose = True)
ckpt = callbacks.ModelCheckpoint(checkpoint_path, monitor = "val_CER", verbose = True)
tracker = callbacks.TensorBoard((model_config.model_path + "/logs").replace("\\","/"))
auto_lr = callbacks.ReduceLROnPlateau(monitor = "val_CER", factor=0.9, patience = 10, verbose = True)
save_model = callbacks.Model2onnx(
    saved_model_path = checkpoint_path,
    input_shape = (1, model_config.height, model_config.width, 3),
    verbose = True,
    metadata = {"vocab": model_config.vocab},
)


train_struct = trainer.Trainer(model, optimizer, loss, metrics = [metric.CERMetric(model_config.vocab), metric.WERMetric(model_config.vocab)])

In [None]:
# Train for 4 epochs, then write the train/validation splits as CSV files
# into the model directory.
active_callbacks = [ckpt, tracker, auto_lr, save_model]  # earlystop is commented out of this list
train_struct.run(train_set, val_set, epochs=4, callbacks=active_callbacks)

for csv_name, split_set in (("train.csv", train_set), ("val.csv", val_set)):
    csv_path = os.path.join(model_config.model_path, csv_name).replace("\\","/")
    split_set.to_csv(csv_path)