In [1]:
import torch
import torch.nn.functional as F
import math
import os
import regex
import warnings
import numpy as np
import wandb
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision.transforms import v2
from PIL import Image
from sklearn.utils.class_weight import compute_class_weight
from torchmetrics.functional.text import char_error_rate
from sklearn.metrics import f1_score, recall_score, precision_score
warnings.filterwarnings("ignore")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
wandb.login(key='cb54d60f74230aa62ccce17a9c1718368ad7183e', relogin=True)
wandb.init(project='datn', reinit=True)

# 1. Vocabulary & Dataset

In [2]:
class Vocabulary:
    def __init__(self, vocab_file, seq_length):
        self.vocab_file = vocab_file
        self.seq_length = seq_length
        self.characters = set()
        self.string_to_index = {"<pad>": 0, "<start>": 1, "<end>": 2, " ": 3}
        self.index_to_string = {0: "<pad>", 1: "<start>", 2: "<end>", 3: " "}

    def build_vocab(self):
        with open(self.vocab_file, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.rstrip()
                self.characters.add(line)
            f.close()
        for index, token in enumerate(set(self.characters)):
            self.string_to_index[token] = index + 4
            self.index_to_string[index + 4] = token

    def vectorize_text(self, sentence, add_special_token=True):
        tokens = regex.findall(r'\X', sentence)
        vectors = []
        if add_special_token:
            vectors.append(self.string_to_index["<start>"])
        for token in tokens:
            vectors.append(self.string_to_index[token])
        if add_special_token:
            vectors.append(self.string_to_index["<end>"])
        n = self.seq_length - len(vectors)
        if n > 0:
            vectors.extend([self.string_to_index["<pad>"] for _ in range(n)])
        return vectors

    def convert_text(self, vectors):
        texts = [self.index_to_string[vector] for vector in vectors]
        return texts

class CustomDataset(Dataset):
    def __init__(self, folder_name, file_path, vocab_file, seq_length, image_size):
        self.folder_name = folder_name
        self.file_path = file_path
        self.vocab_file = vocab_file
        self.seq_length = seq_length
        self.image_size = image_size
        # Get input and output
        self.input = []
        self.output = []
        with open(self.file_path, 'r', encoding='utf-8') as f:
            for line in f:
                image_name, label = line.rstrip().split('--------')
                self.input.append(image_name)
                self.output.append(label)
        # Build vocabulary.txt
        self.vocab = Vocabulary(self.vocab_file, self.seq_length)
        self.vocab.build_vocab()
        # Transform
        self.transform = v2.Compose([
            v2.PILToTensor(),
            v2.ToDtype(torch.float)
        ])

    def __len__(self):
        return len(self.input)

    def __getitem__(self, index):
        # Get image
        path = self.input[index]
        image_path = os.path.join(self.folder_name, path)
        image = Image.open(image_path).convert("RGB")
        image_tensor = self.transform(image)
        label = self.output[index]
        vector_label = self.vocab.vectorize_text(label)
        return image_tensor, torch.Tensor(vector_label).int()

# 2. Model

In [3]:
class Embedding(torch.nn.Module):
    def __init__(self, vocab_size, n_dim):
        super(Embedding, self).__init__()
        self.embedding = torch.nn.Embedding(num_embeddings=vocab_size, embedding_dim=n_dim)

    def forward(self, x):
        out = self.embedding(x)
        return out
    
class Positional_Encoding(torch.nn.Module):
    def __init__(self, seq_length, n_dim):
        super(Positional_Encoding, self).__init__()
        self.seq_length = seq_length
        self.n_dim = n_dim

    def forward(self):
        # positional vector
        position_encode = torch.zeros((self.seq_length, self.n_dim))
        for pos in range(self.seq_length):
            for i in range(0, self.n_dim, 2):
                position_encode[pos, i] = math.sin(pos / (10000 ** (2 * i / self.n_dim)))
                position_encode[pos, i+1] = math.cos(pos / (10000 ** (2 * i / self.n_dim)))
        return position_encode
    
class MultiHeadAttention(torch.nn.Module):
    def __init__(self, n_head, n_dim):
        super(MultiHeadAttention, self).__init__()
        self.n_head = n_head
        self.n_dim = n_dim
        self.n_dim_each_head = int(self.n_dim / self.n_head)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # init query, key, value
        self.flat = torch.nn.Flatten(-2)
        self.query_matrix = torch.nn.Linear(self.n_dim_each_head, self.n_dim_each_head, bias=False)
        self.key_matrix = torch.nn.Linear(self.n_dim_each_head, self.n_dim_each_head, bias=False)
        self.value_matrix = torch.nn.Linear(self.n_dim_each_head, self.n_dim_each_head, bias=False)
        self.output_matrix = torch.nn.Linear(self.n_dim_each_head * self.n_head, self.n_dim_each_head * self.n_head, bias=False)

    def forward(self, query, key, value, mask=None):  # (batch_size, seq_length, n_dim)
        batch_size = key.size(0)
        seq_length = key.size(1)
        seq_length_query = query.size(1)
        # divide head => (batch_size, seq_length, n_head, n_dim_each_head)
        query = query.view(batch_size, seq_length_query, self.n_head, self.n_dim_each_head)
        key = key.view(batch_size, seq_length, self.n_head, self.n_dim_each_head)
        value = value.view(batch_size, seq_length, self.n_head, self.n_dim_each_head)
        q = self.query_matrix(query)
        k = self.key_matrix(key)
        v = self.value_matrix(value)
        # transpose => (batch_size, n_head, seq_length, n_dim_each_head)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        # -------------------------- Compute MultiHead-Attention --------------------------
        """
        - Step 1: compute matmul(q, k^T)
        - Step 2: scale with sqrt(n_dim)
        - Step 3: compute softmax => matrix A
        - Step 4: compute matmul of matrix A and value matrix
        - Step 5: concatenate matrix => matrix Z
        - Step 4: compute matmul of matrix Z and matrix W0
        """
        k_T = k.transpose(-1, -2)  # => (batch_size, n_head, n_dim_each_head, seq_length)
        product = torch.matmul(q, k_T)  # => (batch_size, n_head, seq_length_query, seq_length)
        product = product / math.sqrt(self.n_dim_each_head)
        if mask is not None:
            product = product.masked_fill(mask == 0, float("-1e20"))
        product = product.to(self.device)
        scores = F.softmax(product, dim=-1)  # => (batch_size, n_head, seq_length_query, seq_length)
        scores = torch.matmul(scores, v)  # => (batch_size, n_head, seq_length_query, n_dim_each_head)
        scores = scores.transpose(1, 2)  # => (batch_size, seq_length_query, n_head, n_dim_each_head)
        scores = self.flat(scores)
        output = self.output_matrix(scores)
        return output

In [4]:
class TransformerBlock(torch.nn.Module):
    def __init__(self, n_head, n_dim, n_expansion):
        super(TransformerBlock, self).__init__()
        # parameters
        self.n_head = n_head
        self.n_dim = n_dim
        self.n_expansion = n_expansion
        # instances
        self.multihead = MultiHeadAttention(n_head=self.n_head, n_dim=self.n_dim)
        self.norm_attention = torch.nn.LayerNorm(self.n_dim)
        self.feedforward = torch.nn.Sequential(
            torch.nn.Linear(self.n_dim, self.n_expansion * self.n_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(self.n_expansion * self.n_dim, self.n_dim),
            torch.nn.ReLU(),
        )
        self.norm_feedforward = torch.nn.LayerNorm(self.n_dim)

    def forward(self, query, key, value):
        multihead_vector = self.multihead(query, key, value)
        add_norm_vector = self.norm_attention(multihead_vector + query)
        feed_forward_vector = self.feedforward(add_norm_vector)
        output = self.norm_feedforward(feed_forward_vector + add_norm_vector)
        return output

In [5]:
class ViT(torch.nn.Module):
    def __init__(self, input_chanel, output_chanel, n_head, n_expansion, n_layer):
        super(ViT, self).__init__()
        # Parameters
        self.input_chanel = input_chanel
        self.output_chanel = output_chanel
        self.n_head = n_head
        self.n_expansion = n_expansion
        self.n_layer = n_layer
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Instance
        self.patch_embedding = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels=self.input_chanel, out_channels=self.output_chanel, kernel_size=32, stride=32, padding=0),
            torch.nn.BatchNorm2d(self.output_chanel),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.2),
            torch.nn.Flatten(2)
        )
        self.dropout = torch.nn.Dropout(0.2)
        self.transformer_block = TransformerBlock(self.n_head, self.output_chanel, self.n_expansion)

    def add_cls_token(self, x):
        batch_size = x.shape[0]
        cls_token = torch.nn.Parameter(data=torch.zeros(batch_size, 1, self.output_chanel), requires_grad=True).to(self.device)
        return torch.concat([cls_token, x], dim=1)

    def forward(self, x):
        """ Input shape: (batch_size, chanel, height, width) """
        x = self.patch_embedding(x)     # => (batch_size, seq_len, output_chanel)
        x = x.transpose(-1, -2)
        x = self.add_cls_token(x)       # => (batch_size, seq_len+1, output_chanel)
        position = Positional_Encoding(seq_length=x.shape[1], n_dim=self.output_chanel)
        x = x + position().requires_grad_(False).to(self.device)
        for _ in range(self.n_layer):
            x = self.transformer_block(x, x, x)
            x = self.dropout(x)
        return x

class EncoderBlock(torch.nn.Module):
    def __init__(self, input_chanel, hidden_dim, output_chanel, n_head, n_expansion, n_layer):
        super(EncoderBlock, self).__init__()
        # Parameters
        self.input_chanel = input_chanel
        self.hidden_dim = hidden_dim
        self.output_chanel = output_chanel
        self.n_head = n_head
        self.n_expansion = n_expansion
        self.n_layer = n_layer
        # Instances
        self.vit = ViT(self.input_chanel, self.hidden_dim, self.n_head, self.n_expansion, self.n_layer)
        self.fc_out = torch.nn.Sequential(
            torch.nn.Linear(self.hidden_dim, self.output_chanel),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.2),
        )

    def forward(self, x):
        x = self.vit(x)
        x = self.fc_out(x)
        return x

In [6]:
class DecoderLayer(torch.nn.Module):
    def __init__(self, n_head, n_dim, seq_length, vocab_size, n_expansion):
        super(DecoderLayer, self).__init__()
        # parameters
        self.n_head = n_head
        self.n_dim = n_dim
        self.seq_length = seq_length
        self.vocab_size = vocab_size
        self.n_expansion = n_expansion
        # instance
        self.mask_attention = MultiHeadAttention(n_head=self.n_head, n_dim=self.n_dim)
        self.norm_mask_attention = torch.nn.LayerNorm(self.n_dim)
        self.transformer_block = TransformerBlock(self.n_head, self.n_dim, self.n_expansion)

    def forward(self, x, key, value, mask):
        masked_output = self.mask_attention(query=x, key=x, value=x, mask=mask)
        masked_output = self.norm_mask_attention(x + masked_output)
        output = self.transformer_block(query=masked_output, key=key, value=value)
        return output

class DecoderBlock(torch.nn.Module):
    def __init__(self, n_head, n_dim, seq_length, vocab_size, n_expansion, n_layer):
        super(DecoderBlock, self).__init__()
        # parameters
        self.n_head = n_head
        self.n_dim = n_dim
        self.seq_length = seq_length
        self.vocab_size = vocab_size
        self.n_expansion = n_expansion
        self.n_layer = n_layer
        # instance
        self.embedding = Embedding(vocab_size=self.vocab_size, n_dim=self.n_dim)
        self.decoder_layers = DecoderLayer(self.n_head, self.n_dim, self.seq_length, self.vocab_size, self.n_expansion)
        self.fc_output = torch.nn.Linear(self.n_dim, self.vocab_size)
        self.dropout = torch.nn.Dropout(0.2)

    def forward(self, output_encoder, input_decoder, mask):
        embedding_vector = self.embedding(input_decoder)
        position = Positional_Encoding(seq_length=embedding_vector.shape[1], n_dim=self.n_dim)
        embedding_vector = embedding_vector + position().requires_grad_(False).to(device)
        output = self.decoder_layers(x=embedding_vector, key=output_encoder, value=output_encoder, mask=mask)
        for _ in range(self.n_layer - 1):
            output = self.decoder_layers(x=output, key=output_encoder, value=output_encoder, mask=mask)
            output = self.dropout(output)
        return self.fc_output(output)

In [7]:
class Model(torch.nn.Module):
    def __init__(self, n_dim_model,
        input_channel_encoder, hidden_dim_encoder, n_head_encoder, n_expansion_encoder, n_layer_encoder,
        n_head_decoder, seq_length_decoder, vocab_size_decoder, n_expansion_decoder, n_layer_decoder,
    ):
        super(Model, self).__init__()
        # Parameters
        self.n_dim_model = n_dim_model
        self.input_chanel_encoder = input_channel_encoder
        self.hidden_dim_encoder = hidden_dim_encoder
        self.n_head_encoder = n_head_encoder
        self.n_expansion_encoder = n_expansion_encoder
        self.n_layer_encoder = n_layer_encoder
        self.n_head_decoder = n_head_decoder
        self.seq_length_decoder = seq_length_decoder
        self.vocab_size_decoder = vocab_size_decoder
        self.n_expansion_decoder = n_expansion_decoder
        self.n_layer_decoder = n_layer_decoder
        # Instances
        self.encoder = EncoderBlock(self.input_chanel_encoder, self.hidden_dim_encoder, self.n_dim_model,  self.n_head_encoder, self.n_expansion_encoder, self.n_layer_encoder)
        self.decoder = DecoderBlock(self.n_head_decoder, self.n_dim_model, self.seq_length_decoder, self.vocab_size_decoder, self.n_expansion_decoder, self.n_layer_decoder)

    @staticmethod
    def create_mask(seq_input):
        batch_size, seq_len = seq_input.shape
        mask = torch.tril(torch.ones((seq_len, seq_len))) 
        return mask.expand(batch_size, 1, seq_len, seq_len)

    def forward(self, input_encoder, input_decoder):
        output_encoder = self.encoder(input_encoder)
        mask = self.create_mask(input_decoder).to('cuda' if torch.cuda.is_available() else 'cpu')
        output_decoder = self.decoder(output_encoder, input_decoder, mask)
        return output_decoder

# 3. Metrics

In [8]:
class CustomSchedule(torch.optim.lr_scheduler._LRScheduler):
    def __init__(self, optimizer, d_model, warmup_steps=4000):
        self.d_model = d_model
        self.warmup_steps = warmup_steps
        super(CustomSchedule, self).__init__(optimizer)

    def get_lr(self):
        step = self.last_epoch + 1
        arg1 = (step ** -0.5)
        arg2 = step * (self.warmup_steps ** -1.5)

        return [((self.d_model ** -0.5) * min(arg1, arg2)) for base_lr in self.base_lrs]

In [9]:
class EarlyStopping:
    def __init__(self, patience=5, delta=0, verbose=False):
        self.patience = patience
        self.delta = delta
        self.verbose = verbose
        self.counter = 0
        self.best_metric = float('inf')  # Initialize with positive infinity for loss
        self.early_stop = False

    def __call__(self, current_metric):
        if self.best_metric - current_metric > self.delta:
            self.best_metric = current_metric
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
                if self.verbose:
                    print("Early stopping triggered.")
        return self.early_stop

In [10]:
def decode_input(input_decoder, vocab):
    # input_decoder: (batch_size, seq_len)
    index = vocab['<pad>']
    vocabulary = list(vocab.keys())
    sequences = []
    for batch in input_decoder:
        text = ''.join([vocabulary[i] for i in batch if i != index])
        sequences.append(text)
    return sequences

def decode_output(output_model, vocab):
    # output_model: (batch_size, seq_len, vocab_size)
    index = vocab['<pad>']
    vocabulary = list(vocab.keys())
    sequences = []
    for batch in output_model:
        output = torch.argmax(batch, dim=-1)
        text = ''.join([vocabulary[i] for i in output if i != index])
        sequences.append(text)
    return sequences

def compute_cer(input_decoder, output_model, vocab):
    target = decode_input(input_decoder, vocab)
    predict = decode_output(output_model, vocab)
    with open('log.txt', 'a', encoding='utf-8') as f:
        f.write(f'{target} -------- {predict} \n')
    cer = char_error_rate(predict, target)
    return cer.item()

def compute_recall_precision_f1(input_decoder, output_model):
    y = []
    y_pred = []
    for batch in output_model:
        output = torch.argmax(batch, dim=-1)
        y_pred.extend([i.cpu() for i in output])
    for batch in input_decoder:
        y.extend([i.cpu() for i in batch])
    recall = recall_score(np.asarray(y), np.asarray(y_pred), average='micro', zero_division=0)
    precision = precision_score(np.asarray(y), np.asarray(y_pred), average='micro', zero_division=0)
    f1 = f1_score(np.asarray(y), np.asarray(y_pred), average='micro', zero_division=0)
    return recall, precision, f1

In [11]:
def compute_metrics(input_decoder, output_model, vocab):
    input_decoder = input_decoder[:, 1:].contiguous()
    output_model = output_model[:, :-1, :].contiguous()
    recall, precision, f1 = compute_recall_precision_f1(input_decoder, output_model)
    cer = compute_cer(input_decoder, output_model, vocab)
    return cer, recall, precision, f1

# 4. Training model

## a. Create dataset & Dataloader

In [3]:
train_path = '../dataset/augment_data'
train_file = '../dataset/augment_labels.txt'
vocab_file = '../dataset/vocab.txt'
train_ratio = 0.8
batch_size = 128
seq_length = 192
image_size = (640, 640)
dataset = CustomDataset(train_path, train_file, vocab_file, seq_length, image_size)
n_train = int(train_ratio * len(dataset))
n_val = len(dataset) - n_train
train_dataset, val_dataset = random_split(dataset, [n_train, n_val])
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=False)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True, drop_last=False)
vocab = dataset.vocab.string_to_index

In [4]:
print("Number of training examples: {}".format(len(train_loader)))
print("Number of validate examples: {}".format(len(val_loader)))

Number of training examples: 687
Number of validate examples: 172


In [5]:
print("Num of vocab:", len(vocab))

Num of vocab: 233


In [6]:
labels = []
for line in dataset.output:
    temp = []
    tokens = regex.findall(r'\X', line)
    temp.append(dataset.vocab.string_to_index['<start>'])
    temp.extend([dataset.vocab.string_to_index[token] for token in tokens])
    temp.append(dataset.vocab.string_to_index['<end>'])
    n = seq_length - len(temp)
    if n > 0:
        for _ in range(n):
            temp.append(dataset.vocab.string_to_index['<pad>'])
    labels.extend(temp)

In [7]:
weights = compute_class_weight(class_weight="balanced", classes=np.unique(labels), y = labels)

In [8]:
weights.shape

(233,)

In [9]:
print(vocab.keys())

dict_keys(['<pad>', '<start>', '<end>', ' ', 'ồ', 'ơ', 'ã', 'Y', '2', 'Ẹ', "'", 'Ẽ', 'Ỉ', 'U', 'e', 'I', 'p', '6', 'Ỏ', 'ỡ', 'Ủ', 'ộ', 'Ệ', 'Ý', 'Ấ', 'A', 'T', 'Â', 'ấ', 'B', 'Ắ', 'Ị', 'ế', 'C', 'Á', 's', '/', 'n', 'R', 'ọ', '-', 'm', 'Ỹ', 'Ĩ', 'ề', 'd', 'ẹ', 'ẩ', 'Ù', 'Ẫ', 'c', 'ị', 'ẹ', ':', '^', '?', 'Đ', '<', 'Ẩ', 'Ó', 'ư', '1', 'ú', 'Ỵ', 'q', 'í', 'ễ', '%', 'v', 'õ', 'ả', 'J', 'Ổ', 'â', 'V', 'Ế', 'Ự', 'ẵ', 'É', 'i', 'ứ', 'ở', 'Ỳ', 'Ữ', 'ặ', 'ý', '_', 'ê', 'Ê', 'E', 'Ớ', 'Ư', 'Ỷ', '=', '0', 'Ă', 'Ơ', 'Ễ', 't', 'ỵ', 'ạ', 'Ố', 'Ả', 'D', 'M', '>', 'ỹ', 'è', 'Ẵ', ';', 'ệ', '5', '9', '#', 'S', ',', 'ắ', 'Ợ', 'á', 'ò', 'k', '[', 'ặ', 'ừ', 'ẽ', 'Ỗ', 'Ồ', 'L', 'ĩ', 'N', 'ỗ', 'à', '.', 'ằ', 'j', 'Ằ', '{', 'Ọ', 'Ô', 'Ú', '&', 'È', ')', 'ủ', 'Ử', ']', 'ó', '*', '|', 'Ỡ', 'ố', '7', 'Z', 'ô', '!', 'ổ', 'ă', 'Ụ', 'Ẳ', 'h', 'Ẻ', 'w', 'ớ', 'ỏ', 'Ạ', '}', 'é', 'Ầ', '"', 'K', '8', 'Ở', 'Í', '∈', 'P', 'g̃', 'Ì', '

In [10]:
weights

array([5.60930589e-03, 8.24034335e-01, 8.24034335e-01, 9.51417143e-02,
       5.99767264e+00, 3.94099946e+00, 1.03376971e+01, 1.37146781e+01,
       6.81192621e+00, 4.11440343e+03, 9.14311874e+01, 6.85733906e+02,
       8.22880687e+02, 2.18851246e+01, 2.89542817e+00, 9.43670512e+00,
       1.82537863e+00, 1.24302219e+01, 4.11440343e+03, 2.05720172e+02,
       1.71433476e+02, 3.82379501e+00, 6.63613457e+01, 2.57150215e+02,
       1.84502396e+01, 4.88066837e+00, 8.48680576e-01, 1.32722691e+02,
       6.20573670e+00, 3.28364201e+00, 5.87771919e+02, 6.85733906e+02,
       4.24165302e+00, 2.47557367e+00, 4.15596306e+01, 2.79321346e+00,
       1.24678892e+01, 2.12235811e-01, 1.19258071e+01, 9.50208645e+00,
       2.29854940e+01, 1.27855918e+00, 4.11440343e+03, 2.05720172e+03,
       6.23394460e+00, 3.78858511e+00, 4.11440343e+02, 2.18851246e+01,
       8.22880687e+02, 2.05720172e+03, 6.35330981e-01, 5.42797287e+00,
       1.64576137e+02, 1.53522516e+01, 2.05720172e+03, 6.53079910e+01,
      

## b. Create model

In [None]:
n_dim_model = 512
# --- Encoder Parameters ---
input_chanel_encoder = 3
hidden_dim_encoder = 768
n_head_encoder = 12
n_expansion_encoder = 4
n_layer_encoder = 12
# --- Decoder Parameters ---
n_head_decoder = 8
seq_length_decoder = seq_length
vocab_size_decoder = len(vocab)
n_expansion_decoder = 4
n_layer_decoder = 6
model = Model(n_dim_model, input_chanel_encoder, hidden_dim_encoder, n_head_encoder, n_expansion_encoder, n_layer_encoder,
              n_head_decoder, seq_length_decoder, vocab_size_decoder, n_expansion_decoder, n_layer_decoder).to(device)

In [None]:
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(num_params)

In [None]:
print(model)

## c. Metrics & Optimizer & Loss function

In [None]:
epochs = 40
criterion = torch.nn.CrossEntropyLoss(weight=torch.Tensor(weights).to(device), ignore_index=0)
optimizer = torch.optim.AdamW(model.parameters(), betas=(0.9, 0.98), eps=1e-09, weight_decay=1e-5)
lr_scheduler = CustomSchedule(optimizer, d_model=n_dim_model, warmup_steps=4000)
# lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.1, patience=2, threshold=0.01)
early_stopping = EarlyStopping(patience=5, delta=0.01, verbose=True)

## d. Training Step

In [None]:
results = {
    "train_loss": [],
    "train_cer": [],
    "val_loss": [],
    "val_cer": []
}
for epoch in range(epochs):
    model.train()
    vocab_size_decoder = len(vocab)
    loss_value = 0.0
    cer_value = 0.0
    recall_value, precision_value, f1_value = 0.0, 0.0, 0.0
    for batch_idx, (input_encoder, input_decoder) in enumerate(train_loader):
        # Predict
        optimizer.zero_grad()
        lr_scheduler.step()
        input_encoder, input_decoder = input_encoder.to(device), input_decoder.to(device)
        output_model = model(input_encoder, input_decoder)
        # Compute loss
        target = input_decoder[:, 1:].contiguous()
        pred = output_model[:, :-1, :].contiguous()
        loss = criterion(pred.view(-1, vocab_size_decoder), target.view(-1).long())
        loss_value += loss.item()
        # Compute metrics
        current_lr = optimizer.param_groups[0]['lr']
        cer, recall, precision, f1 = compute_metrics(input_decoder, output_model, vocab)
        cer_value += cer
        recall_value += recall
        precision_value += precision
        f1_value += f1
        loss.backward()
        optimizer.step()
        wandb.log({
            'loss per step': loss.item(), 
            'lr': current_lr, 
            'cer per step': cer,
            'recall per step': recall, 
            'precision per step': precision, 
            'f1 per step': f1
        })
    loss_train = loss_value / len(train_loader)
    cer_train = cer_value / len(train_loader)
    recall_value = recall_value / len(train_loader)
    precision_value = precision_value / len(train_loader)
    f1_value = f1_value / len(train_loader)
    wandb.log({
        'train loss': loss_train,
        'train cer': cer_train, 
        'train recall': recall_value,
        'train precision': precision_value,
        'train f1': f1_value,
    })
    # Evaluate
    loss_value = 0.0
    cer_value = 0.0
    recall_value, precision_value, f1_value = 0.0, 0.0, 0.0
    with torch.no_grad():
        model.eval()
        for batch_idx, (input_encoder, input_decoder) in enumerate(val_loader):
            # Predict
            input_encoder, input_decoder = input_encoder.to(device), input_decoder.to(device)
            output_model = model(input_encoder, input_decoder)
            # Compute loss
            target = input_decoder[:, 1:].contiguous()
            pred = output_model[:, :-1, :].contiguous()
            loss = criterion(pred.view(-1, vocab_size_decoder), target.view(-1).long()).item()
            cer, recall, precision, f1 = compute_metrics(input_decoder, output_model, vocab)
            cer_value += cer
            loss_value += loss
            recall_value += recall
            precision_value += precision
            f1_value += f1
        loss_val = loss_value / len(val_loader)
        cer_val = cer_value / len(val_loader)
        recall_value = recall_value / len(val_loader)
        precision_value = precision_value / len(val_loader)
        f1_value = f1_value / len(val_loader)
        wandb.log({
            'val loss': loss_val,
            'val cer': cer_val, 
            'val recall': recall_value,
            'val precision': precision_value,
            'val f1': f1_value,
        })
    print(f"Epoch: {epoch + 1} | train_loss: {loss_train:.2f} | val_loss: {loss_val:.2f}")
    if early_stopping(loss_val):
        print("Early Stopping Training Progress!")
        break
    torch.save({
        'model_state_dict': model.state_dict(),
        'vocab': vocab,
        'results': results,
        'optimizer': optimizer.state_dict()
    }, f'../checkpoints/checkpoint-{epoch}.pth.tar')

In [None]:
wandb.finish()