In [71]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary

In [72]:
 class CRNN(nn.Module):
  def __init__(self, char_list):
    super(CRNN, self).__init__()

    self.conv_1 = nn.Conv2d(in_channels=1, out_channels=64,kernel_size=(3, 3), padding = 'same')
    self.pool_1 = nn.MaxPool2d(kernel_size=(2, 2), stride=2)
    self.conv_2 = nn.Conv2d(in_channels=64, out_channels=128,kernel_size=(3, 3), padding = 'same')
    self.pool_2 = nn.MaxPool2d(kernel_size=(2, 2), stride=2)
    self.conv_3 = nn.Conv2d(in_channels=128, out_channels=256,kernel_size=(3, 3), padding = 'same')
    self.conv_4 = nn.Conv2d(in_channels=256, out_channels=256,kernel_size=(3, 3), padding = 'same')
    self.pool_4 = nn.MaxPool2d(kernel_size=(2, 1))
    self.conv_5 = nn.Conv2d(in_channels=256, out_channels=512,kernel_size=(3, 3), padding = 'same')
    self.norm_5 = nn.BatchNorm2d(512)
    self.conv_6 = nn.Conv2d(in_channels=512, out_channels=512,kernel_size=(3, 3), padding = 'same')
    self.norm_6 = nn.BatchNorm2d(512)
    self.pool_6 = nn.MaxPool2d(kernel_size=(2,1))
    self.conv_7 = nn.Conv2d(in_channels=512, out_channels=512,kernel_size=(2, 2))
    self.bilstm_1 = nn.LSTM(
            input_size=512,
            hidden_size=128,
            num_layers=1,
            bidirectional=True,
            dropout=0.2,
            batch_first=True
        )
    self.bilstm_2 = nn.LSTM(
            input_size= 256,
            hidden_size=128,
            num_layers=1,
            bidirectional=True,
            dropout=0.2,
            batch_first=True
        )
    self.output = nn.Linear(in_features= 256, out_features= char_list + 1)
  def forward(self, x):
    x = self.conv_1(x)
    x = F.relu(x)
    x = self.pool_1(x)
    x = self.conv_2(x)
    x = F.relu(x)
    x = self.pool_2(x)
    x = self.conv_3(x)
    x = F.relu(x)
    x = self.conv_4(x)
    x = F.relu(x)
    x = self.pool_4(x)
    x = self.conv_5(x)
    x = F.relu(x)
    x = self.norm_5(x)
    x = self.conv_6(x)
    x = F.relu(x)
    x = self.norm_6(x)
    x = self.pool_6(x)
    x = self.conv_7(x)
    x = F.relu(x)
    x = x.squeeze()
    x = x.permute(0, 2, 1)
    x, _ = self.bilstm_1(x)
    x, _ = self.bilstm_2(x)
    x = self.output(x)
    x = F.softmax(x, dim = -1)
    return x




In [73]:
# Encoder label

def encoder_label(txt):
  dig_lst = []
  for index, char in enumerate(txt):
    try:
      dig_lst.append(char_list.index(char))
    except:
      print(char)
  return dig_lst

In [74]:
batch_size = 64
max_label_len = 5
labels = torch.FloatTensor(batch_size, max_label_len)
input_len = torch.IntTensor(batch_size, 1)
label_len = torch.IntTensor(batch_size, 1)
ctc_loss = nn.CTCLoss(blank=62)

# Unzip dataset


In [None]:
!unzip /content/drive/MyDrive/CRNN/word.zip -d dataset

Archive:  /content/drive/MyDrive/CRNN/word.zip
replace dataset/word.xml? [y]es, [n]o, [A]ll, [N]one, [r]ename: 

In [None]:
# Xử lí file xml
import xml.etree.ElementTree as ET
import cv2
import os
import numpy as np
import string

char_list = string.ascii_letters+string.digits
training_img = []
training_txt = []
train_input_length = []
train_label_length = []
orig_txt = []

#lists for validation dataset
valid_img = []
valid_txt = []
valid_input_length = []
valid_label_length = []
valid_orig_txt = []
i = 1
max_label_len = 0
tree = ET.parse('dataset/word.xml')
root = tree.getroot()
print(f'Tên gốc: {root.tag}')
for child in root:
    file_name, word = child.attrib['file'], child.attrib['tag']
    print(file_name)
    img = cv2.cvtColor(cv2.imread(os.path.join('dataset', file_name)), cv2.COLOR_BGR2GRAY)
    # convert shape to (1, 32, 128)
    w, h = img.shape
    if h > 128 or w > 32:
      continue
    if w < 32:
      add_zeros = np.ones((32 - w, h)) * 255
      img = np.concatenate((img, add_zeros))
    if h < 128:
      add_zeros = np.ones((32, 128 - h)) * 255
      img = np.concatenate((img, add_zeros), axis = 1)
    img = np.expand_dims(img, 0)
    # normalize image
    img = img / 255
    if len(word) > max_label_len:
      max_label_len = len(word)
    if i % 5 == 0:
      valid_orig_txt.append(word)
      valid_label_length.append(len(word))
      valid_input_length.append(31)
      valid_img.append(img)
      valid_txt.append(encoder_label(word))
    else:
      orig_txt.append(word)
      train_label_length.append(len(word))
      train_input_length.append(31)
      training_img.append(img)
      training_txt.append(encoder_label(word))
    i += 1
def pad_sequences(sequences, maxlen, padding_value):
    padded_sequences = []

    for seq in sequences:
        seq_len = len(seq)
        if seq_len < maxlen:
            # Thêm padding 'post'
            seq = seq + [padding_value] * (maxlen - seq_len)
        padded_sequences.append(seq)

    return torch.tensor(padded_sequences)

train_padded_txt = np.array(pad_sequences(training_txt, max_label_len, len(char_list)))
valid_padded_txt = np.array(pad_sequences(valid_txt, max_label_len, len(char_list)))

In [None]:
training_img = np.array(training_img)
train_input_length = np.array(train_input_length)
train_label_length = np.array(train_label_length)

valid_img = np.array(valid_img)
valid_input_length = np.array(valid_input_length)
valid_label_length = np.array(valid_label_length)

In [None]:
# Sử dụng DataLoader

from torch.utils.data import DataLoader
from torch.utils.data import Dataset
class Dataset(Dataset):
  def __init__(self, training_img, labels, training_label_length, train_input_length, transform = None):
    self.labels = labels
    self.training_img = training_img
    self.transform = transform
    self.label_lengths = training_label_length
    self.input_lengths = train_input_length
  def __len__(self):
    return len(self.training_img)
  def __getitem__(self, index):
    image = self.training_img[index]
    label = self.labels[index]
    label_length = self.label_lengths[index]
    input_length = self.input_lengths[index]
    return image, label, label_length, input_length

# Tạo dataset
training_dataset = Dataset(training_img, train_padded_txt,train_label_length,train_input_length,  None)
valid_dataset = Dataset(valid_img, valid_padded_txt,valid_label_length, train_input_length, None)
# Tạo dataloader
train_loader = DataLoader(training_dataset, batch_size = 4,  shuffle=True, num_workers=1)
valid_loader = DataLoader(valid_dataset, batch_size = 4,  shuffle=True, num_workers=1)

In [None]:
def validate(model, val_loader, criterion):
    model.eval()  # Đặt mô hình vào chế độ đánh giá (evaluation mode)
    val_loss = 0
    with torch.no_grad():  # Tắt gradient để tiết kiệm bộ nhớ và tăng tốc độ tính toán
        for images, labels, label_lengths, input_lengths in val_loader:
            images = images.float()
            outputs = model(images)  # [batch_size, seq_length, num_classes]
            outputs = outputs.log_softmax(2)  # Cần log_softmax để phù hợp với CTCLoss
            outputs = outputs.permute(1, 0, 2)
            # Tính toán loss cho tập validation
            loss = criterion(outputs, labels, input_lengths, label_lengths)
            val_loss += loss.item()

    # Tính loss trung bình cho toàn bộ tập validation
    val_loss /= len(val_loader)
    return val_loss

In [None]:
import torch.optim as optim

crnn = CRNN(len(char_list))
optimizer = optim.Adam(crnn.parameters(), lr = 0.0001)
ctc_loss = nn.CTCLoss()
num_epochs = 50
batch_size = 4
best_val_loss = float('inf')
for epoch in range(num_epochs):
  train_loss = 0
  crnn.train()
  for images, labels, label_lengths, input_lengths in train_loader:
    optimizer.zero_grad()
    images = images.float()
    outs = crnn(images)
    outs = outs.log_softmax(2).detach().requires_grad_()
    outs = outs.permute(1, 0, 2)
    loss = ctc_loss(outs, labels, input_lengths , label_lengths)
    loss.backward()
    optimizer.step()
    train_loss += loss.item()

  train_loss /= len(train_loader)
  val_loss = validate(crnn, valid_loader, ctc_loss)
  print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}')
  if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(crnn.state_dict(), 'best_crnn_model.pth')
        print(f'Model saved with Validation Loss: {val_loss:.4f}')


In [None]:
model_crnn = CRNN(len(char_list))
model_crnn.load_state_dict(torch.load("/content/best_crnn_model.pth"))
model_crnn.eval()

In [None]:
with torch.no_grad():
  for images, labels, label_lengths, input_lengths in valid_loader:
    images = images.float()
    predictions = model_crnn(images)

predictions.shape

In [None]:
labels[0]

In [None]:
torch.argmax(predictions[0][3])

In [None]:
char_list[55]