In [1]:
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
import os
import torch
import torch.nn as nn
from tqdm import tqdm
from PIL import Image, ImageDraw, ImageFont
from nltk.corpus import words

import sys
sys.path.append('./../../')
from models.rnn.rnn import BitCounterRNN, BitCountingDataset, collate_fn_rnn
from models.ocr.ocr import OCRModel, OCRDataset

In [2]:
def generate_word_images(word_list, image_dir, image_size=(256, 64)):
    if not os.path.exists(image_dir):
        os.makedirs(image_dir)
    for word in tqdm(word_list, desc="Generating Images"):
        img = Image.new('L', image_size, color=255)
        draw = ImageDraw.Draw(img)
        font = ImageFont.load_default()
        font_size = 24
        font = ImageFont.truetype("arial.ttf", font_size)
        bbox = draw.textbbox((0, 0), word, font=font)
        x = (image_size[0] - (bbox[2] - bbox[0])) // 2
        y = (image_size[1] - (bbox[3] - bbox[1])) // 2
        draw.text((x, y), word, font=font, fill=0)
        img.save(os.path.join(image_dir, f"{word}.png")) 

image_dir = "./../../data/external/word_images"
# word_list = words.words()
# word_list = list(set(word_list))
# word_list = word_list[:100000]
# generate_word_images(word_list, image_dir)

In [3]:
def create_image_label_lists(image_dir):
    image_paths = []
    labels = []
    for filename in os.listdir(image_dir):
        if filename.endswith(".png"):
            if ('-' in filename) or ('_' in filename):
                continue
            label = os.path.splitext(filename)[0]
            image_paths.append(os.path.join(image_dir, filename))
            labels.append(label)
    return image_paths, labels

In [19]:
image_paths, labels = create_image_label_lists(image_dir)

max_word_length = 0
for i in range(len(labels)):
    max_word_length = max(max_word_length, len(labels[i]))

print(f"Max Word Length: {max_word_length}")

np.random.seed(0)
temp_paths = image_paths.copy()
temp_labels = labels.copy()
perm = np.random.permutation(len(image_paths))
for i in range(len(image_paths)):
    image_paths[i] = temp_paths[perm[i]]
    labels[i] = temp_labels[perm[i]]

train_size = int(0.8 * len(image_paths))
val_size = int(0.1 * len(image_paths))
test_size = len(image_paths) - train_size - val_size

train_paths, val_paths, test_paths = image_paths[:train_size], image_paths[train_size:train_size + val_size], image_paths[train_size + val_size:]
train_labels, val_labels, test_labels = labels[:train_size], labels[train_size:train_size + val_size], labels[train_size + val_size:]

train_dataset = OCRDataset(train_paths, train_labels, max_length=max_word_length)
val_dataset = OCRDataset(val_paths, val_labels, max_length=max_word_length)
test_dataset = OCRDataset(test_paths, test_labels, max_length=max_word_length)

batch_size = 64
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

Max Word Length: 24


In [23]:
def decode_label(one_hot_encoded):
    char_map = "@ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
    decoded = ""
    for i in range(one_hot_encoded.shape[0]):
        idx = torch.argmax(one_hot_encoded[i]).item()
        if idx == 0:
            break
        decoded += char_map[idx]
    return decoded

In [24]:
def get_weights(labels):
    weigths = torch.ones(53, dtype=torch.float)
    # char_map = "@ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
    weigths[0] = 0.2
    return weigths

In [25]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = OCRModel(num_classes=53, max_length=max_word_length).to(device)
weigth = get_weights(train_labels).to(device)
criterion = nn.CrossEntropyLoss(weight=weigth)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [26]:
def train(model, train_loader, val_loader, criterion, optimizer, num_epochs=10, device='cuda'):
    model.to(device)
    
    for epoch in range(num_epochs):
        model.train()
        train_loss = 0
        train_loader_tqdm = tqdm(train_loader, desc=f"Epoch [{epoch+1}/{num_epochs}] - Training")
        
        for images, labels in train_loader_tqdm:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            outputs = outputs.permute(0, 2, 1)
            labels = labels.permute(0, 2, 1)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
        
        train_loss /= len(train_loader)

        model.eval()
        val_loss = 0
        correct_chars = 0
        total_chars = 0
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                outputs = outputs.permute(0, 2, 1)
                labels = labels.permute(0, 2, 1)
                val_loss += criterion(outputs, labels).item()
                outputs = outputs.permute(0, 2, 1)
                labels = labels.permute(0, 2, 1)
                for i in range(len(labels)):
                    predicted_label = decode_label(outputs[i])
                    true_label = decode_label(labels[i])
                    
                    # if i < 5000 and i % 500 == 0:
                    #     print(f"Predicted: {predicted_label}, True: {true_label}")
                    for j in range(len(true_label)):
                        if j < len(predicted_label) and predicted_label[j] == true_label[j]:
                            correct_chars += 1
                        total_chars += 1

        val_loss /= len(val_loader)
        avg_correct_chars = correct_chars / total_chars
        print(f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Avg Correct Chars in Val: {avg_correct_chars:.4f}")


In [27]:
train(model, train_loader, val_loader, criterion, optimizer, num_epochs=10, device=device)

Epoch [1/10] - Training: 100%|██████████| 1247/1247 [02:51<00:00,  7.26it/s]


Epoch [1/10], Train Loss: 1.0778, Val Loss: 0.7991, Avg Correct Chars in Val: 0.3680


Epoch [2/10] - Training: 100%|██████████| 1247/1247 [02:41<00:00,  7.72it/s]


Epoch [2/10], Train Loss: 0.5046, Val Loss: 0.2573, Avg Correct Chars in Val: 0.7963


Epoch [3/10] - Training: 100%|██████████| 1247/1247 [02:22<00:00,  8.73it/s]


Epoch [3/10], Train Loss: 0.2053, Val Loss: 0.1788, Avg Correct Chars in Val: 0.8597


Epoch [4/10] - Training: 100%|██████████| 1247/1247 [02:19<00:00,  8.96it/s]


Epoch [4/10], Train Loss: 0.1252, Val Loss: 0.1763, Avg Correct Chars in Val: 0.8687


Epoch [5/10] - Training: 100%|██████████| 1247/1247 [02:50<00:00,  7.33it/s]


Epoch [5/10], Train Loss: 0.0914, Val Loss: 0.0639, Avg Correct Chars in Val: 0.9505


Epoch [6/10] - Training: 100%|██████████| 1247/1247 [02:59<00:00,  6.93it/s]


Epoch [6/10], Train Loss: 0.0747, Val Loss: 0.0565, Avg Correct Chars in Val: 0.9565


Epoch [7/10] - Training: 100%|██████████| 1247/1247 [03:00<00:00,  6.92it/s]


Epoch [7/10], Train Loss: 0.0656, Val Loss: 0.0430, Avg Correct Chars in Val: 0.9671


Epoch [8/10] - Training: 100%|██████████| 1247/1247 [03:28<00:00,  5.99it/s]


Epoch [8/10], Train Loss: 0.0600, Val Loss: 0.0460, Avg Correct Chars in Val: 0.9654


Epoch [9/10] - Training: 100%|██████████| 1247/1247 [02:57<00:00,  7.04it/s]


Epoch [9/10], Train Loss: 0.0527, Val Loss: 0.0504, Avg Correct Chars in Val: 0.9611


Epoch [10/10] - Training: 100%|██████████| 1247/1247 [02:49<00:00,  7.37it/s]


Epoch [10/10], Train Loss: 0.0482, Val Loss: 0.0474, Avg Correct Chars in Val: 0.9642


In [28]:
def random_baseline_accuracy(labels):
    correct_chars = 0
    total_chars = 0
    for label in labels:
        predicted_label = ''.join(np.random.choice(list('ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'), len(label)))
        for i in range(len(label)):
            if predicted_label[i] == label[i]:
                correct_chars += 1
            total_chars += 1
    return correct_chars / total_chars

random_accuracy = random_baseline_accuracy(val_labels)
print(f"Random Baseline Accuracy: {random_accuracy:.4f}")

Random Baseline Accuracy: 0.0180


In [29]:
model.eval()
test_loss = 0
correct_chars = 0
total_chars = 0
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        outputs = outputs.permute(0, 2, 1)
        labels = labels.permute(0, 2, 1)
        test_loss += criterion(outputs, labels).item()
        outputs = outputs.permute(0, 2, 1)
        labels = labels.permute(0, 2, 1)
        for i in range(len(labels)):
            predicted_label = decode_label(outputs[i])
            true_label = decode_label(labels[i])
            # if i%100 == 0:
            print(predicted_label, "-> Predicted")
            print(true_label, "-> True")
            for j in range(len(true_label)):
                if j < len(predicted_label) and predicted_label[j] == true_label[j]:
                    correct_chars += 1
                total_chars += 1

test_loss /= len(test_loader)
avg_correct_chars = correct_chars / total_chars

bizarre -> Predicted
bizarre -> True
eacnwhere -> Predicted
eachwhere -> True
taxably -> Predicted
taxably -> True
ellagelliferous -> Predicted
eflagelliferous -> True
palaeodenorolgiiag -> Predicted
palaeodendrologic -> True
adulatory -> Predicted
adulatory -> True
psychogenetical -> Predicted
psychogenetical -> True
roaded -> Predicted
roaded -> True
glossophorous -> Predicted
glossophorous -> True
muddily -> Predicted
muddily -> True
trimorphism -> Predicted
trimorphism -> True
suuamigerous -> Predicted
squamigerous -> True
jocund -> Predicted
jocund -> True
ophthalmopo -> Predicted
ophthalmopod -> True
vulpecular -> Predicted
vulpecular -> True
mulishness -> Predicted
mulishness -> True
neurofifrily -> Predicted
neurofibril -> True
enzymolytic -> Predicted
enzymolytic -> True
exterminist -> Predicted
exterminist -> True
interpenetratived -> Predicted
interpenetrative -> True
peculiar -> Predicted
peculiar -> True
barbastel -> Predicted
barbastel -> True
violaceously -> Predicted
vi

In [30]:
print(f"Test Loss: {test_loss:.4f}, Avg Correct Chars in Test: {avg_correct_chars:.4f}")

Test Loss: 0.0465, Avg Correct Chars in Test: 0.9645
