In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from collections import defaultdict
from scipy.special import logsumexp
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader,random_split
import numpy as np
from torchmetrics.text import CharErrorRate
metric = CharErrorRate()

from PIL import Image
from config import ModelConfigs
configs = ModelConfigs()

In [None]:
# Prefer the GPU when PyTorch can see one; otherwise run everything on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
device

In [None]:
class CustomDataset(Dataset):
    """Image / text-label pairs for CTC training, listed in a tab-separated label file.

    ``label_file`` is read as UTF-8. Each non-blank line is
    ``<image path relative to image_folder><TAB><label>``; if a line has more than two
    tab-separated fields, every field after the first is joined with single spaces to
    form the label.

    Each item is ``(image, target, target_length)``:
        image: the PIL image after ``transform`` (the raw PIL image when transform is None).
        target: 1-D LongTensor of ``configs.CHAR2LABEL`` indices, one per label character.
        target_length: 1-element LongTensor holding ``len(target)``.
    """

    def __init__(self, image_folder, label_file, transform=None):
        self.image_folder = image_folder
        self.label_file = label_file
        self.transform = transform

        with open(label_file, 'r', encoding="utf8") as f:
            # Skip blank lines (e.g. a trailing empty line): otherwise they become a sample
            # with an empty image path and an empty label, which fails at load time.
            self.data = [line for line in f if line.strip()]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        fields = self.data[idx].rstrip('\n').split("\t")
        image_path = fields[0]
        label = ' '.join(fields[1:])

        # Raises KeyError for any character missing from the configured charset.
        target = [configs.CHAR2LABEL[c] for c in label]
        target_length = torch.LongTensor([len(target)])
        target = torch.LongTensor(target)

        # copy() forces a full read, so the file handle can be closed deterministically
        # here instead of lingering until garbage collection.
        with Image.open(os.path.join(self.image_folder, image_path)) as img:
            image = img.copy()
        if self.transform is not None:
            image = self.transform(image)

        return image, target, target_length
    
def synth90k_collate_fn(batch):
    """Merge ``(image, target, target_length)`` samples into one CTC-ready batch.

    Images are stacked along a new leading batch dimension. Because label lengths
    differ, targets are concatenated into one flat 1-D tensor, and the per-sample
    lengths are concatenated alongside so the flat targets can be split apart again.
    """
    image_list, target_list, length_list = zip(*batch)
    return (
        torch.stack(image_list, dim=0),
        torch.cat(target_list, dim=0),
        torch.cat(length_list, dim=0),
    )

In [None]:
# Resize to the model's fixed input size (aspect ratio is not preserved), convert to a
# single grey channel, then scale pixel values from [0, 1] to [-1, 1].
data_transform = transforms.Compose([
    transforms.Resize((configs.height, configs.width)),
    transforms.Grayscale(1),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

image_folder = "data_set/new_train"
label_file = "data_set/train_gt.txt"
custom_dataset = CustomDataset(image_folder, label_file, transform=data_transform)

# 90/10 train/test split. A fixed seed keeps the split identical across kernel restarts,
# so a model resumed from a checkpoint is never evaluated on images it was trained on.
SPLIT_SEED = 42
train_size = int(0.9 * len(custom_dataset))
test_size = len(custom_dataset) - train_size
train_dataset, test_dataset = random_split(
    custom_dataset,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

batch_size = 20
trainloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=synth90k_collate_fn)
testloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=synth90k_collate_fn)


In [None]:
# Show the first sample (after the transform) to sanity-check resizing and normalisation.
sample_image = custom_dataset[0][0].squeeze()
plt.imshow(sample_image, cmap="gray")
sample_image

In [None]:
NINF = -1 * float('inf')
DEFAULT_EMISSION_THRESHOLD = 0.01
def _reconstruct(labels, blank=0):
    new_labels = []
    # merge same labels
    previous = None
    for l in labels:
        if l != previous:
            new_labels.append(l)
            previous = l
    # delete blank
    new_labels = [l for l in new_labels if l != blank]

    return new_labels


def greedy_decode(emission_log_prob, blank=0, **kwargs):
    """Best-path CTC decoding of one sequence.

    Takes the most likely class at every frame (argmax over the last axis), then
    merges repeats and removes blanks. Extra keyword arguments (e.g. ``beam_size``)
    are accepted for interface compatibility with the other decoders and ignored.
    """
    return _reconstruct(np.argmax(emission_log_prob, axis=-1), blank=blank)

def prefix_beam_decode(emission_log_prob, blank=0, **kwargs):
    """CTC prefix beam search over one sequence of per-frame log-probabilities.

    Args:
        emission_log_prob: 2-D array indexed ``[t, class]`` (unpacked as
            ``(length, class_count)``) of log-probabilities.
        blank: index of the CTC blank class.
        **kwargs:
            beam_size: number of prefixes kept after each frame (default 5).
                ``ctc_decode`` forwards its own ``beam_size`` here.
            emission_threshold: log-probability below which a class is not
                expanded at a frame (default ``log(DEFAULT_EMISSION_THRESHOLD)``).
                The frame's most likely class is always expanded, so the beam
                can never become empty.

    Returns:
        list of class indices of the highest-scoring prefix (repeats merged,
        blanks removed).
    """
    beam_size = kwargs.get('beam_size', 5)
    emission_threshold = kwargs.get('emission_threshold', np.log(DEFAULT_EMISSION_THRESHOLD))

    length, class_count = emission_log_prob.shape

    beams = [(tuple(), (0, NINF))]  # (prefix, (blank_log_prob, non_blank_log_prob))
    # initial of beams: (empty_str, (log(1.0), log(0.0)))

    for t in range(length):
        new_beams_dict = defaultdict(lambda: (NINF, NINF))  # log(0.0) = NINF
        # Always expand the frame's argmax class, even below the threshold. With a flat
        # distribution (large charset, early training) every class can fall under it,
        # which used to empty the beam and make ``beams[0]`` raise IndexError.
        best_class = int(np.argmax(emission_log_prob[t]))

        for prefix, (lp_b, lp_nb) in beams:
            for c in range(class_count):
                log_prob = emission_log_prob[t, c]
                if log_prob < emission_threshold and c != best_class:
                    continue

                end_t = prefix[-1] if prefix else None

                # Case 1: the prefix stays the same (blank, or repeat of the last
                # symbol without an intervening blank).
                new_lp_b, new_lp_nb = new_beams_dict[prefix]

                if c == blank:
                    new_beams_dict[prefix] = (
                        logsumexp([new_lp_b, lp_b + log_prob, lp_nb + log_prob]),
                        new_lp_nb
                    )
                    continue
                if c == end_t:
                    new_beams_dict[prefix] = (
                        new_lp_b,
                        logsumexp([new_lp_nb, lp_nb + log_prob])
                    )

                # Case 2: the prefix is extended with c.
                new_prefix = prefix + (c,)
                new_lp_b, new_lp_nb = new_beams_dict[new_prefix]

                if c != end_t:
                    new_beams_dict[new_prefix] = (
                        new_lp_b,
                        logsumexp([new_lp_nb, lp_b + log_prob, lp_nb + log_prob])
                    )
                else:
                    # A repeated symbol only starts a new character after a blank.
                    new_beams_dict[new_prefix] = (
                        new_lp_b,
                        logsumexp([new_lp_nb, lp_b + log_prob])
                    )

        # Keep the beam_size prefixes with the highest log(P_blank + P_non_blank).
        beams = sorted(new_beams_dict.items(), key=lambda x: logsumexp(x[1]), reverse=True)
        beams = beams[:beam_size]

    return list(beams[0][0])

def ctc_decode(log_probs, label2char=None, blank=0, method='g', beam_size=10):
    """Decode a batch of CTC log-probabilities into label sequences.

    Args:
        log_probs: tensor indexed ``(length, batch, class)``; it is transposed to
            ``(batch, length, class)`` before per-sample decoding.
        label2char: optional mapping from label index to character. When given
            (and non-empty), each decoded index is mapped through it.
        blank: index of the CTC blank class.
        method: ``'g'`` for greedy (best-path) decoding, ``'b'`` for prefix beam
            search. Any other value raises KeyError.
        beam_size: passed to the decoder as ``beam_size``; greedy decoding ignores it.

    Returns:
        list with one decoded sequence per batch element: label indices, or
        characters when ``label2char`` is given.
    """
    # detach() so this also works on tensors that still require grad (e.g. logits
    # taken during training); .numpy() raises on those.
    emission_log_probs = np.transpose(log_probs.detach().cpu().numpy(), (1, 0, 2))
    # size of emission_log_probs: (batch, length, class)

    decoders = {
        'g': greedy_decode,
        'b': prefix_beam_decode,
    }
    decoder = decoders[method]

    decoded_list = []
    for emission_log_prob in emission_log_probs:
        decoded = decoder(emission_log_prob, blank=blank, beam_size=beam_size)
        if label2char:
            decoded = [label2char[l] for l in decoded]
        decoded_list.append(decoded)
    return decoded_list


In [None]:
class CRNN(nn.Module):
    """CNN backbone + two bidirectional LSTMs producing per-column class scores (for CTC).

    Args:
        img_channel: number of input image channels.
        img_height: input image height; must be a multiple of 16 (asserted in _cnn_backbone).
        img_width: input image width; must be a multiple of 4 (asserted in _cnn_backbone).
        num_class: size of the output layer. The notebook passes len(configs.CHARS) + 1,
            presumably reserving one class for the CTC blank.
        map_to_seq_hidden: size of the per-column feature vector fed to the first LSTM.
        rnn_hidden: hidden size of each LSTM direction (outputs are 2 * rnn_hidden wide).
        leaky_relu: use LeakyReLU(0.2) instead of ReLU in the CNN backbone.
    """

    def __init__(self, img_channel, img_height, img_width, num_class,
                 map_to_seq_hidden=128, rnn_hidden=512, leaky_relu=False):
        super(CRNN, self).__init__()

        self.cnn, (output_channel, output_height, output_width) = \
            self._cnn_backbone(img_channel, img_height, img_width, leaky_relu)

        # Each width column of the feature map (all channels x all rows) becomes one time step.
        self.map_to_seq = nn.Linear(output_channel * output_height, map_to_seq_hidden)

        self.rnn1 = nn.LSTM(map_to_seq_hidden, rnn_hidden, bidirectional=True)
        # Bidirectional: forward and backward outputs are concatenated -> 2 * rnn_hidden.
        self.rnn2 = nn.LSTM(2*rnn_hidden, rnn_hidden, bidirectional=True)

        self.dense = nn.Linear(2*rnn_hidden, num_class)

    def _cnn_backbone(self, img_channel, img_height, img_width, leaky_relu):
        """Build the convolutional feature extractor.

        Returns:
            (cnn, (output_channel, output_height, output_width)) where the tuple is the
            per-image feature-map shape: (1024, img_height // 16 - 1, img_width // 4 - 1).
            output_width is the sequence length seen by the LSTMs.
        """
        assert img_height % 16 == 0
        assert img_width % 4 == 0
        # NOTE(review): img_height == 16 passes this assert but yields output_height 0
        # (img_height // 16 - 1); heights of at least 32 are needed in practice.

        channels = [img_channel, 64, 128, 256, 256, 512, 512, 1024]
        kernel_sizes = [3, 3, 3, 3, 3, 3, 2]
        strides = [1, 1, 1, 1, 1, 1, 1]
        paddings = [1, 1, 1, 1, 1, 1, 0]

        cnn = nn.Sequential()

        def conv_relu(i, batch_norm=False):
            """Append conv{i} (+ optional batchnorm{i}) + relu{i}, mapping channels[i] -> channels[i+1]."""
            # shape of input: (batch, input_channel, height, width)
            input_channel = channels[i]
            output_channel = channels[i+1]

            cnn.add_module(
                f'conv{i}',
                nn.Conv2d(input_channel, output_channel, kernel_sizes[i], strides[i], paddings[i])
            )

            if batch_norm:
                cnn.add_module(f'batchnorm{i}', nn.BatchNorm2d(output_channel))

            relu = nn.LeakyReLU(0.2, inplace=True) if leaky_relu else nn.ReLU(inplace=True)
            cnn.add_module(f'relu{i}', relu)

        # size of image: (channel, height, width) = (img_channel, img_height, img_width)
        # The 3x3 / stride 1 / padding 1 convolutions keep height and width unchanged;
        # only the pooling layers and the final 2x2 / padding 0 conv shrink the map.
        
        conv_relu(0)
        cnn.add_module('pooling0', nn.MaxPool2d(kernel_size=2, stride=2))
        # (64, img_height // 2, img_width // 2)

        conv_relu(1)
        cnn.add_module('pooling1', nn.MaxPool2d(kernel_size=2, stride=2))
        # (128, img_height // 4, img_width // 4)

        conv_relu(2)
        conv_relu(3,batch_norm=True)
        # (2, 1) pooling halves the height only, preserving horizontal resolution for the sequence.
        cnn.add_module(
            'pooling2',
            nn.MaxPool2d(kernel_size=(2, 1))
        )  # (256, img_height // 8, img_width // 4)

        conv_relu(4, batch_norm=True)
        conv_relu(5, batch_norm=True)
        cnn.add_module(
            'pooling3',
            nn.MaxPool2d(kernel_size=(2, 1))
        )  # (512, img_height // 16, img_width // 4)

        conv_relu(6)  # (1024, img_height // 16 - 1, img_width // 4 - 1)

        output_channel, output_height, output_width = \
            channels[-1], img_height // 16 - 1, img_width // 4 - 1
        return cnn, (output_channel, output_height, output_width)

    def forward(self, images):
        """Map a batch of images to per-time-step class scores.

        Args:
            images: tensor shaped (batch, channel, height, width).

        Returns:
            un-normalised scores shaped (seq_len, batch, num_class) with
            seq_len = img_width // 4 - 1; callers apply log_softmax before CTC.
        """
        # shape of images: (batch, channel, height, width)

        conv = self.cnn(images)
        batch, channel, height, width = conv.size()

        # Fold channels and rows together: one feature vector per width column.
        conv = conv.view(batch, channel * height, width)
        conv = conv.permute(2, 0, 1)  # (width, batch, feature)
        seq = self.map_to_seq(conv)

        # nn.LSTM defaults to batch_first=False, matching the (seq, batch, feature) layout.
        recurrent, _ = self.rnn1(seq)
        recurrent, _ = self.rnn2(recurrent)

        output = self.dense(recurrent)
        return output  # shape: (seq_len, batch, num_class)

In [None]:

# One output per character plus one extra class: nn.CTCLoss() and ctc_decode both use
# blank=0, so this assumes configs.CHAR2LABEL maps characters to 1..len(CHARS) — TODO confirm.
model2 = CRNN(
    img_channel=1,
    img_height=configs.height,
    img_width=configs.width,
    num_class=len(configs.CHARS) + 1,
    leaky_relu=False,
)
# To resume from a checkpoint, e.g.:
# model2.load_state_dict(torch.load("models/08_handwriting_recognition_torch/crnn_4_loss0.5615450739860535.pt"))
model2.to(device)


In [None]:
# Release cached-but-unused GPU memory held by PyTorch's caching allocator (e.g. left
# over from earlier runs in this kernel); tensors that are still referenced are unaffected.
torch.cuda.empty_cache()

In [None]:

def train_batch(crnn, data, optimizer, criterion, device):
    """Run one optimisation step on a single collated batch and return its loss.

    Args:
        crnn: model whose forward returns scores shaped (seq_len, batch, num_class).
        data: ``(images, targets, target_lengths)`` as produced by synth90k_collate_fn.
        optimizer: optimiser over ``crnn``'s parameters.
        criterion: CTC loss called as ``criterion(log_probs, targets, input_lengths, target_lengths)``.
        device: device the batch is moved to.

    Returns:
        the loss value as a Python float (computed before the parameter update).
    """
    crnn.train()
    images, targets, target_lengths = (tensor.to(device) for tensor in data)

    log_probs = crnn(images).log_softmax(dim=2)

    # Every sample uses the full output sequence, so all input lengths equal seq_len.
    seq_len, n_samples = log_probs.size(0), images.size(0)
    input_lengths = torch.full((n_samples,), seq_len, dtype=torch.long)

    loss = criterion(log_probs, targets, input_lengths, target_lengths.flatten())

    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(crnn.parameters(), 5)  # clip gradient norm at 5
    optimizer.step()
    return loss.item()

def test_batch(crnn, data, criterion, device):
    """Evaluate ``crnn`` over every batch in ``data``; return loss and exact-match accuracy.

    Despite the name, this runs over the whole loader passed as ``data``.

    Args:
        crnn: model whose forward returns scores shaped (seq_len, batch, num_class).
        data: iterable of ``(images, targets, target_lengths)`` batches as produced by
            synth90k_collate_fn (e.g. a DataLoader).
        criterion: CTC loss. Assumed to use reduction='mean' (the nn.CTCLoss default);
            each batch's value is weighted by its batch size so 'loss' is a per-sample mean.
        device: device the model and batches are moved to.

    Returns:
        dict with
            'loss': mean per-sample loss (nan if ``data`` yields no samples),
            'acc': fraction of samples whose greedy decode equals the target exactly
                (nan if there are no samples),
            'wrong_cases': list of ``(real, pred)`` label-index lists for the misses.
    """
    crnn.eval()
    crnn.to(device)
    tot_count = 0
    tot_loss = 0.
    tot_correct = 0
    wrong_cases = []

    with torch.no_grad():
        # Iterate over the loader that was passed in; the previous version ignored this
        # argument and always read the global `testloader`.
        for batch in tqdm(data, desc="Evaluate"):
            images, targets, target_lengths = [d.to(device) for d in batch]

            logits = crnn(images)
            log_probs = torch.nn.functional.log_softmax(logits, dim=2)

            batch_size = images.size(0)
            input_lengths = torch.LongTensor([logits.size(0)] * batch_size)

            loss = criterion(log_probs, targets, input_lengths, target_lengths)

            preds = ctc_decode(log_probs, method="g")
            reals = targets.cpu().numpy().tolist()
            lengths = target_lengths.cpu().numpy().tolist()

            tot_count += batch_size
            # criterion returns a batch mean; re-weight so the final value is a per-sample
            # mean (the last batch may be smaller than the others).
            tot_loss += loss.item() * batch_size

            # `targets` concatenates every label in the batch; slice it back per sample.
            offset = 0
            for pred, length in zip(preds, lengths):
                real = reals[offset:offset + length]
                offset += length
                if pred == real:
                    tot_correct += 1
                else:
                    wrong_cases.append((real, pred))

    if tot_count == 0:
        # Empty loader: avoid ZeroDivisionError and make the missing result explicit.
        return {'loss': float('nan'), 'acc': float('nan'), 'wrong_cases': wrong_cases}

    return {
        'loss': tot_loss / tot_count,
        'acc': tot_correct / tot_count,
        'wrong_cases': wrong_cases,
    }

optimizer = optim.RMSprop(model2.parameters(), lr=configs.learning_rate)
# Defaults: blank=0, reduction='mean' (per-batch mean of target-length-normalised losses).
# zero_infinity=True zeroes the infinite loss (and its gradient) of a sample whose label
# cannot be aligned to the output sequence, instead of letting it poison the weights.
criterion = nn.CTCLoss(zero_infinity=True)
criterion.to(device)

num_epochs = 20
save_every = 1  # write a checkpoint every `save_every` epochs
os.makedirs(configs.model_path, exist_ok=True)  # torch.save does not create directories

for epoch in range(1, num_epochs + 1):
    print(f'epoch: {epoch}')
    tot_train_loss = 0.
    tot_train_count = 0
    for train_data in tqdm(trainloader):
        loss = train_batch(model2, train_data, optimizer, criterion, device)
        batch_count = train_data[0].size(0)
        # `loss` is a batch mean; weight it by batch size so the epoch figure is per sample.
        tot_train_loss += loss * batch_count
        tot_train_count += batch_count

    train_loss = tot_train_loss / max(tot_train_count, 1)

    if epoch % save_every == 0:
        save_model_path = os.path.join(configs.model_path, f'crnn_{epoch}_loss{train_loss:.4f}.pt')
        torch.save(model2.state_dict(), save_model_path)
        print('save model at ', save_model_path)

    print('train_loss: ', train_loss)

    evaluation = test_batch(model2, testloader, criterion, device)
    # `wrong_cases` can hold thousands of entries; print only the summary metrics.
    print({key: value for key, value in evaluation.items() if key != 'wrong_cases'})
        