In [1]:
!pip install torchvision
!pip install torchaudio
!pip install transformers

Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch==2.5.1->torchvision)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch==2.5.1->torchvision)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch==2.5.1->torchvision)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch==2.5.1->torchvision)
  Downloading nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cusolver-cu12==11.6.1.9 (from torch==2.5.1->torchvision)
  Downloading nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cusparse-cu12==12.3.1.170 (from torch==2.5.1->torchvision)
  Downloading nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl.meta

# 1. Model

In [2]:
from einops import rearrange
import math
import torch
from torch import nn

class LanguageTransformer(nn.Module):
    def __init__(
        self,
        vocab_size,
        d_model,
        nhead,
        num_encoder_layers,
        num_decoder_layers,
        dim_feedforward,
        max_seq_length,
        pos_dropout,
        trans_dropout,
    ):
        super().__init__()
        self.d_model = d_model
        self.embed_tgt = nn.Embedding(vocab_size, d_model)
        self.pos_enc = PositionalEncoding(d_model, pos_dropout, max_seq_length)
        self.transformer = nn.Transformer(
            d_model,
            nhead,
            num_encoder_layers,
            num_decoder_layers,
            dim_feedforward,
            trans_dropout,
            batch_first=True  # Ensure batch_first is set
        )
        self.fc = nn.Linear(d_model, vocab_size)

    def forward(self, src, tgt, src_key_padding_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None):
        tgt_mask = self.gen_nopeek_mask(tgt.shape[1]).to(src.device)
        src = self.pos_enc(src * math.sqrt(self.d_model))
        tgt = self.pos_enc(self.embed_tgt(tgt) * math.sqrt(self.d_model))
        output = self.transformer(
            src, tgt,
            tgt_mask=tgt_mask,
            src_key_padding_mask=src_key_padding_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask,
        )
        return self.fc(output)

    def gen_nopeek_mask(self, length):
        mask = (torch.triu(torch.ones(length, length)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, 0.0)
        return mask

    def forward_encoder(self, src):
        src = self.pos_enc(src * math.sqrt(self.d_model))
        memory = self.transformer.encoder(src)
        return memory

    def forward_decoder(self, tgt, memory):
        tgt_mask = self.gen_nopeek_mask(tgt.shape[1]).to(tgt.device)
        # Embed and add positional encoding
        tgt = self.pos_enc(self.embed_tgt(tgt) * math.sqrt(self.d_model))
        # Decoder expects batch-first tensors (N, T, E)
        output = self.transformer.decoder(tgt, memory, tgt_mask=tgt_mask)
        return self.fc(output), memory
    def expand_memory(self, memory, beam_size):
        # Expand memory along the BATCH dimension (dim=0)
        return memory.repeat(beam_size, 1, 1)  # (batch*beam, seq_len, d_model)
    def get_memory(self, memory, i):
        # Select the i-th batch element (batch-first)
        return memory[i:i+1, :, :]  # (1, seq_len, d_model)

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=100):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)  # Shape: (max_len, 1, d_model)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x shape: (batch_size, seq_len, d_model)
        seq_len = x.size(1)
        pe = self.pe[:seq_len, :].squeeze(1)  # (seq_len, d_model)
        x = x + pe.unsqueeze(0)  # Add (1, seq_len, d_model) to (batch, seq_len, d_model)
        return self.dropout(x)


class LearnedPositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=100):
        super(LearnedPositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        self.pos_embed = nn.Embedding(max_len, d_model)
        self.layernorm = LayerNorm(d_model)

    def forward(self, x):
        seq_len = x.size(0)
        pos = torch.arange(seq_len, dtype=torch.long, device=x.device)
        pos = pos.unsqueeze(-1).expand(x.size()[:2])
        x = x + self.pos_embed(pos)
        return self.dropout(self.layernorm(x))


class LayerNorm(nn.Module):
    "A layernorm module in the TF style (epsilon inside the square root)."

    def __init__(self, d_model, variance_epsilon=1e-12):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(d_model))
        self.beta = nn.Parameter(torch.zeros(d_model))
        self.variance_epsilon = variance_epsilon

    def forward(self, x):
        u = x.mean(-1, keepdim=True)
        s = (x - u).pow(2).mean(-1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.variance_epsilon)
        return self.gamma * x + self.beta

In [3]:
# from vietocr.model.seqmodel.transformer import LanguageTransformer
from torch import nn
from transformers import ViTModel

class ViTEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.vit = ViTModel.from_pretrained('google/vit-base-patch16-224')
        self.projection = nn.Linear(768, config['transformer']['d_model'])
        
    def forward(self, x):
        outputs = self.vit(x)
        hidden_states = outputs.last_hidden_state  # (N, seq_len, 768)
        return self.projection(hidden_states)      # (N, seq_len, d_model)

class VietOCR(nn.Module):
    def __init__(
        self,
        vocab_size,  # Kept for compatibility but not used directly
        backbone,    # Not used since we're replacing CNN with ViT
        vit_args,    # Arguments for ViT configuration
        transformer_args,
        seq_modeling="transformer",
    ):
        super(VietOCR, self).__init__()

        # Replace CNN with ViTEncoder
        self.encoder = ViTEncoder(vit_args)
        self.seq_modeling = seq_modeling

        if seq_modeling == "transformer":
            # Remove explicit vocab_size, use transformer_args only
            self.transformer = LanguageTransformer(**transformer_args)
        else:
            raise ValueError("Not Supported Seq Model")

    def forward(self, img, tgt_input, tgt_key_padding_mask):
        """
        Shape:
            - img: (N, C, H, W) - Input image tensor (e.g., 3 channels, 224x224 for ViT)
            - tgt_input: (T, N) - Target input for the transformer decoder
            - tgt_key_padding_mask: (N, T) - Padding mask for the target sequence
            - output: (N, T, V) - Output logits over vocabulary
        """
        # Encode image using ViT
        src = self.encoder(img)  # Shape: (N, seq_len, d_model)

        if self.seq_modeling == "transformer":
            outputs = self.transformer(
                src, tgt_input, tgt_key_padding_mask=tgt_key_padding_mask
            )
        else:
            raise ValueError("Not Supported Seq Model")
        return outputs

2025-05-06 02:37:45.511466: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1746499065.733918      31 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1746499065.795449      31 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


# 2. Vocab

In [4]:
class Vocab:
    def __init__(self, chars):
        self.pad = 0
        self.go = 1
        self.eos = 2
        self.unk = 3
        self.mask = 4
        self.char2id = {'<pad>': 0, '<go>': 1, '<eos>': 2, '<unk>': 3, '<mask>': 4}
        self.id2char = {0: '<pad>', 1: '<go>', 2: '<eos>', 3: '<unk>', 4: '<mask>'}
        
        for i, c in enumerate(sorted(chars), start=5):
            self.char2id[c] = i
            self.id2char[i] = c
    
    def encode(self, text):
        """Convert text to list of indices with <go> and <eos> tokens."""
        indices = [self.char2id['<go>']]
        for char in text:
            indices.append(self.char2id.get(char, self.char2id['<unk>']))
        indices.append(self.char2id['<eos>'])
        return torch.tensor(indices, dtype=torch.long)
    
    def decode(self, indices):
        """Convert list of indices to text."""
        text = []
        for idx in indices:
            # Handle both tensor and integer inputs
            idx_val = idx.item() if hasattr(idx, 'item') else idx
            if idx_val == self.char2id['<eos>']:
                break
            if idx_val not in (self.char2id['<pad>'], self.char2id['<go>'], self.char2id['<mask>']):
                text.append(self.id2char.get(idx_val, '<unk>'))
        return ''.join(text)

    def __len__(self):
        return len(self.char2id)
    
    def batch_decode(self, arr):
        """Decode a batch of index sequences."""
        texts = [self.decode(ids) for ids in arr]
        return texts

# Create vocabulary
def create_vocab(labels_file):
    with open(labels_file, 'r', encoding='utf-8') as f:
        labels = json.load(f)
    
    chars = set()
    for label in labels.values():
        chars.update(label)
    
    return Vocab(chars)

# 3. Dataloader

In [5]:
from torch.utils.data import Dataset, DataLoader  # Import Dataset
# Dataset
class CustomDataset(Dataset):
    def __init__(self, data_dir, labels_file, vocab, transform=None):
        self.data_dir = data_dir
        self.transform = transform
        self.vocab = vocab
        
        with open(labels_file, 'r', encoding='utf-8') as f:
            self.labels = json.load(f)
        
        self.image_files = list(self.labels.keys())
        
    def __len__(self):
        return len(self.image_files)
    
    def __getitem__(self, idx):
        img_name = self.image_files[idx]
        img_path = os.path.join(self.data_dir, img_name)
        
        image = Image.open(img_path).convert('RGB')
        if self.transform:
            image = self.transform(image)
            
        label = self.labels[img_name]
        label_indices = self.vocab.encode(label)
        
        return {
            'img': image,
            'label': label_indices,
            'text': label,
            'filename': img_name
        }

def collate_fn(batch):
    images = torch.stack([item['img'] for item in batch])
    labels = [item['label'] for item in batch]
    filenames = [item['filename'] for item in batch]
    
    max_len = max(len(label) for label in labels)
    padded_labels = torch.zeros(len(labels), max_len, dtype=torch.long)
    padding_mask = torch.ones(len(labels), max_len, dtype=torch.bool)
    
    for i, label in enumerate(labels):
        padded_labels[i, :len(label)] = label
        padding_mask[i, :len(label)] = False
    
    return {
        'img': images,
        'label': padded_labels,
        'padding_mask': padding_mask,
        'filenames': filenames
    }

# 4. Loss

In [6]:
import torch
from torch import nn


class LabelSmoothingLoss(nn.Module):
    def __init__(self, classes, padding_idx, smoothing=0.0, dim=-1):
        super(LabelSmoothingLoss, self).__init__()
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.cls = classes
        self.dim = dim
        self.padding_idx = padding_idx

    def forward(self, pred, target):
        pred = pred.log_softmax(dim=self.dim)
        with torch.no_grad():
            # true_dist = pred.data.clone()
            true_dist = torch.zeros_like(pred)
            true_dist.fill_(self.smoothing / (self.cls - 2))
            true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)
            true_dist[:, self.padding_idx] = 0
            mask = torch.nonzero(target.data == self.padding_idx, as_tuple=False)
            if mask.dim() > 0:
                true_dist.index_fill_(0, mask.squeeze(), 0.0)

        return torch.mean(torch.sum(-true_dist * pred, dim=self.dim))

# 5. Beam

In [7]:
import torch

class Beam:
    def __init__(self, beam_size, candidates, vocab_size, device, sos_token, eos_token):
        self.beam_size = beam_size
        self.candidates = candidates
        self.vocab_size = vocab_size
        self.device = device
        self.sos_token = sos_token
        self.eos_token = eos_token

        # Initialize hypotheses: list of (sequence, score)
        self.hypotheses = [[torch.tensor([sos_token], device=device), 0.0]]  # Start with SOS token
        self.completed = []  # Store completed hypotheses (sequences that hit EOS)

    def get_current_state(self):
        """Return current hypotheses as a tensor of shape (beam_size, seq_len)."""
        if not self.hypotheses:
            return torch.tensor([], device=self.device).long()

        # Get max sequence length among current hypotheses
        max_len = max(len(hyp[0]) for hyp in self.hypotheses)
        # Pad sequences to max length
        padded = torch.ones(len(self.hypotheses), max_len, device=self.device).long() * self.eos_token
        for i, (seq, _) in enumerate(self.hypotheses):
            padded[i, :len(seq)] = seq
        return padded  # Shape: (beam_size, seq_len)

    def advance(self, log_probs):
        """Update beam with new log probabilities.
        log_probs: Shape (beam_size, vocab_size)
        """
        if not self.hypotheses:
            return

        # Get current scores
        scores = torch.tensor([score for _, score in self.hypotheses], device=self.device)
        # Compute scores for all possible next tokens
        # Shape: (beam_size, vocab_size)
        candidate_scores = scores.unsqueeze(1) + log_probs

        # Flatten to get top-k scores across all candidates
        flat_scores = candidate_scores.view(-1)  # Shape: (beam_size * vocab_size)
        topk_scores, topk_indices = flat_scores.topk(min(self.beam_size, len(flat_scores)), dim=0)

        # Compute which hypothesis and token each top-k score corresponds to
        new_hypotheses = []
        for score, idx in zip(topk_scores, topk_indices):
            prev_hyp_idx = idx // self.vocab_size
            token = idx % self.vocab_size
            prev_seq, prev_score = self.hypotheses[prev_hyp_idx]
            new_seq = torch.cat([prev_seq, torch.tensor([token], device=self.device)])

            if token == self.eos_token:
                # If EOS, add to completed hypotheses
                self.completed.append([new_seq, score.item()])
            else:
                # Otherwise, add to active hypotheses
                new_hypotheses.append([new_seq, score.item()])

        # Keep only top beam_size hypotheses
        self.hypotheses = new_hypotheses[:self.beam_size]

        # If no active hypotheses remain, stop
        if not self.hypotheses:
            self.hypotheses = [[torch.tensor([self.sos_token], device=self.device), 0.0]]

    def all_eos(self):
        """Check if all active hypotheses have reached EOS."""
        if not self.hypotheses:
            return True
        return all(seq[-1] == self.eos_token for seq, _ in self.hypotheses)

    def get_final(self, candidates):
        """Return the top candidates completed hypotheses."""
        # Sort completed hypotheses by score
        self.completed.sort(key=lambda x: x[1], reverse=True)
        # Take top candidates
        top_completed = self.completed[:candidates]
        if not top_completed:
            # If no completed hypotheses, return best active hypothesis
            if self.hypotheses:
                top_completed = [self.hypotheses[0]]
            else:
                top_completed = [[torch.tensor([self.sos_token], device=self.device), 0.0]]

        # Convert sequences to list of integers and return with scores
        final_sents = [seq.tolist() for seq, _ in top_completed]
        final_probs = [score for _, score in top_completed]
        return final_sents, final_probs

In [18]:
import torch
import numpy as np
import math
from PIL import Image
from torch.nn.functional import log_softmax, softmax


def batch_translate_beam_search(img, model, beam_size=4, candidates=1, max_seq_length=128, sos_token=1, eos_token=2):
    model.eval()
    device = img.device
    sents = []

    with torch.no_grad():
        src = model.encoder(img)  # Shape: (batch_size, seq_len, d_model)
        memories = model.transformer.forward_encoder(src)  # Shape: (batch_size, seq_len, d_model)
        batch_size = memories.size(0)

        for i in range(batch_size):
            memory = model.transformer.get_memory(memories, i)  # Shape: (1, seq_len, d_model)
            memory = model.transformer.expand_memory(memory, beam_size)  # Shape: (beam_size, seq_len, d_model)
            sent = beamsearch(
                memory, model, device, beam_size, candidates, max_seq_length, sos_token, eos_token
            )  # Shape: (candidates, max_seq_length)
            # assert sent.dim() == 2 and sent.size(0) == candidates and sent.size(1) == max_seq_length, \
            #     f"Expected shape ({candidates}, {max_seq_length}), got {sent.shape}"
            sents.append(sent)

        # Concatenate all sequences (all are already padded to max_seq_length)
        sents = torch.cat(sents, dim=0)  # Shape: (batch_size * candidates, max_seq_length)
    return sents


def beamsearch(memory, model, device, beam_size=4, candidates=1, max_seq_length=128, sos_token=1, eos_token=2):
    final_sents = []
    final_probs = []

    # Initial input: <sos> token
    tgt_inp = torch.LongTensor(beam_size, 1).fill_(sos_token).to(device)  # Shape: (beam_size, 1)

    # Initial log probabilities
    log_probs = torch.zeros(beam_size, 1).to(device)  # Shape: (beam_size, 1)

    # Track whether each beam is finished
    finished = torch.zeros(beam_size, dtype=torch.bool).to(device)
    memory = memory.to(device)  # Ensure memory is on the correct device

    for _ in range(max_seq_length):
        # Forward pass through decoder
        decoder_outputs, memory = model.transformer.forward_decoder(tgt_inp, memory)

        # Get log probabilities for the last token
        log_prob = torch.log_softmax(decoder_outputs[:, -1, :], dim=-1)  # Shape: (beam_size, vocab_size)

        # Accumulate log probabilities
        log_probs = log_probs + log_prob  # Shape: (beam_size, vocab_size)

        # Flatten for top-k selection
        log_probs_flat = log_probs.view(-1)  # Shape: (beam_size * vocab_size)
        top_v, top_i = log_probs_flat.topk(beam_size, dim=0)

        # Compute new beam indices and token indices
        beam_indices = top_i // log_prob.size(-1)  # Shape: (beam_size)
        token_indices = top_i % log_prob.size(-1)  # Shape: (beam_size)

        # Update log_probs and tgt_inp
        new_log_probs = torch.zeros(beam_size, 1).to(device)
        new_tgt_inp = torch.zeros(beam_size, tgt_inp.size(1) + 1, dtype=torch.long).to(device)

        for j in range(beam_size):
            if finished[j]:
                new_log_probs[j] = log_probs[j].sum().unsqueeze(0)
                new_tgt_inp[j, :-1] = tgt_inp[j]
                new_tgt_inp[j, -1] = eos_token
                continue

            beam_idx = beam_indices[j]
            token_idx = token_indices[j]
            new_log_probs[j] = top_v[j].unsqueeze(0)
            new_tgt_inp[j, :-1] = tgt_inp[beam_idx]
            new_tgt_inp[j, -1] = token_idx

            # Check if the token is <eos>
            if token_idx == eos_token:
                finished[j] = True
                # Pad the sequence to max_seq_length before appending
                seq_len = new_tgt_inp[j:j+1].size(1)
                if seq_len < max_seq_length:
                    padding = torch.full((1, max_seq_length - seq_len), eos_token, dtype=torch.long, device=device)
                    padded_sent = torch.cat([new_tgt_inp[j:j+1], padding], dim=1)
                else:
                    padded_sent = new_tgt_inp[j:j+1]
                final_sents.append(padded_sent)
                final_probs.append(new_log_probs[j].item())  # Append scalar probability

        log_probs = new_log_probs
        tgt_inp = new_tgt_inp

        # If all beams are finished, break
        if finished.all():
            break

    # If no sentences finished, take the best and pad it
    if not final_sents:
        seq_len = tgt_inp[0:1].size(1)
        if seq_len < max_seq_length:
            padding = torch.full((1, max_seq_length - seq_len), eos_token, dtype=torch.long, device=device)
            padded_sent = torch.cat([tgt_inp[0:1], padding], dim=1)
        else:
            padded_sent = tgt_inp[0:1]
        final_sents.append(padded_sent)
        final_probs.append(log_probs[0].sum().item())  # Append scalar probability

    # Concatenate final sentences and probabilities
    final_sents = torch.cat(final_sents, dim=0)  # Shape: (num_finished_beams, max_seq_length)
    final_probs = torch.tensor(final_probs, device=device)  # Shape: (num_finished_beams,)

    # Select the best candidates
    _, idx = final_probs.topk(min(candidates, final_probs.size(0)), dim=0)  # Shape: (candidates,)
    idx = idx.squeeze()  # Ensure idx is 1D
    if idx.dim() == 0:  # Handle case where idx is a scalar (e.g., candidates=1)
        idx = idx.unsqueeze(0)
    return final_sents[idx]  # Shape: (candidates, max_seq_length)
def translate_beam_search(
    img, model, beam_size=4, candidates=1, max_seq_length=128, sos_token=1, eos_token=2
):
    model.eval()
    device = img.device

    with torch.no_grad():
        src = model.encoder(img)  # Shape: (1, seq_len, d_model)
        memory = model.transformer.forward_encoder(src)  # Shape: (1, seq_len, d_model)
        sent = beamsearch(
            memory, model, device, beam_size, candidates, max_seq_length, sos_token, eos_token
        )  # Shape: (candidates, max_seq_length)

    return sent


def translate(img, model, max_seq_length=128, sos_token=1, eos_token=2):
    model.eval()
    device = img.device

    with torch.no_grad():
        src = model.encoder(img)  # Shape: (batch_size, seq_len, d_model)
        memory = model.transformer.forward_encoder(src)  # Shape: (batch_size, seq_len, d_model)

        translated_sentence = [[sos_token] * len(img)]
        char_probs = [[1] * len(img)]

        max_length = 0

        while max_length <= max_seq_length and not all(
            np.any(np.asarray(translated_sentence).T == eos_token, axis=1)
        ):
            tgt_inp = torch.LongTensor(translated_sentence).to(device)
            output, memory = model.transformer.forward_decoder(tgt_inp, memory)
            output = softmax(output, dim=-1)
            output = output.to("cpu")

            values, indices = torch.topk(output, 5)
            indices = indices[:, -1, 0]
            indices = indices.tolist()
            values = values[:, -1, 0]
            values = values.tolist()
            char_probs.append(values)

            translated_sentence.append(indices)
            max_length += 1

            del output

        translated_sentence = np.asarray(translated_sentence).T
        char_probs = np.asarray(char_probs).T
        char_probs = np.multiply(char_probs, translated_sentence > 3)
        char_probs = np.sum(char_probs, axis=-1) / (char_probs > 0).sum(-1)

    return translated_sentence, char_probs


def build_model(config):
    vocab = Vocab(config["vocab"])
    device = config["device"]

    model = VietOCR(
        len(vocab),
        config["backbone"],
        config["cnn"],
        config["transformer"],
        config["seq_modeling"],
    )

    model = model.to(device)
    return model, vocab


def resize(w, h, expected_height, image_min_width, image_max_width):
    new_w = int(expected_height * float(w) / float(h))
    round_to = 10
    new_w = math.ceil(new_w / round_to) * round_to
    new_w = max(new_w, image_min_width)
    new_w = min(new_w, image_max_width)
    return new_w, expected_height


def process_image(image, image_height, image_min_width, image_max_width):
    img = image.convert("RGB")
    w, h = img.size
    new_w, image_height = resize(w, h, image_height, image_min_width, image_max_width)
    img = img.resize((new_w, image_height), Image.LANCZOS)
    img = np.asarray(img).transpose(2, 0, 1)
    img = img / 255
    return img


def process_input(image, image_height, image_min_width, image_max_width):
    img = process_image(image, image_height, image_min_width, image_max_width)
    img = img[np.newaxis, ...]
    img = torch.FloatTensor(img)
    return img


def predict(filename, config):
    img = Image.open(filename)
    img = process_input(img, config["image_height"], config["image_min_width"], config["image_max_width"])
    img = img.to(config["device"])
    model, vocab = build_model(config)
    s = translate(img, model)[0].tolist()
    s = vocab.decode(s)
    return s

# 6. Metrics

In [9]:
import os
import gdown
import yaml
import numpy as np
import uuid
import requests
import tempfile
from tqdm import tqdm


def download_weights(uri, cached=None, md5=None, quiet=False):
    if uri.startswith("http"):
        return download(url=uri, quiet=quiet)
    return uri


def download(url, quiet=False):
    tmp_dir = tempfile.gettempdir()
    filename = url.split("/")[-1]
    full_path = os.path.join(tmp_dir, filename)

    if os.path.exists(full_path):
        print("Model weight {} exsits. Ignore download!".format(full_path))
        return full_path

    with requests.get(url, stream=True) as r:
        r.raise_for_status()
        with open(full_path, "wb") as f:
            for chunk in tqdm(r.iter_content(chunk_size=8192)):
                # If you have chunk encoded response uncomment if
                # and set chunk_size parameter to None.
                # if chunk:
                f.write(chunk)
    return full_path


def download_config(id):
    url = "https://vocr.vn/data/vietocr/config/{}".format(id)
    r = requests.get(url)
    config = yaml.safe_load(r.text)
    return config


def compute_accuracy(ground_truth, predictions, mode="full_sequence"):
    """
    Computes accuracy
    :param ground_truth:
    :param predictions:
    :param display: Whether to print values to stdout
    :param mode: if 'per_char' is selected then
                 single_label_accuracy = correct_predicted_char_nums_of_single_sample / single_label_char_nums
                 avg_label_accuracy = sum(single_label_accuracy) / label_nums
                 if 'full_sequence' is selected then
                 single_label_accuracy = 1 if the prediction result is exactly the same as label else 0
                 avg_label_accuracy = sum(single_label_accuracy) / label_nums
    :return: avg_label_accuracy
    """
    if mode == "per_char":

        accuracy = []

        for index, label in enumerate(ground_truth):
            prediction = predictions[index]
            total_count = len(label)
            correct_count = 0
            try:
                for i, tmp in enumerate(label):
                    if tmp == prediction[i]:
                        correct_count += 1
            except IndexError:
                continue
            finally:
                try:
                    accuracy.append(correct_count / total_count)
                except ZeroDivisionError:
                    if len(prediction) == 0:
                        accuracy.append(1)
                    else:
                        accuracy.append(0)
        avg_accuracy = np.mean(np.array(accuracy).astype(np.float32), axis=0)
    elif mode == "full_sequence":
        try:
            correct_count = 0
            for index, label in enumerate(ground_truth):
                prediction = predictions[index]
                if prediction == label:
                    correct_count += 1
            avg_accuracy = correct_count / len(ground_truth)
        except ZeroDivisionError:
            if not predictions:
                avg_accuracy = 1
            else:
                avg_accuracy = 0
    else:
        raise NotImplementedError(
            "Other accuracy compute mode has not been implemented"
        )

    return avg_accuracy

# 7. Logger

In [10]:
import os


class Logger:
    def __init__(self, fname):
        path, _ = os.path.split(fname)
        os.makedirs(path, exist_ok=True)

        self.logger = open(fname, "w")

    def log(self, string):
        self.logger.write(string + "\n")
        self.logger.flush()

    def close(self):
        self.logger.close()

# 8. Trainer

In [11]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import DataLoader
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import time
import os
from tqdm import tqdm


class Trainer:
    def __init__(self, config, model, vocab, train_dataset, valid_dataset=None, pretrained=True):
        self.config = config
        self.model = model
        self.vocab = vocab
        self.device = torch.device(config['device'] if torch.cuda.is_available() else 'cpu')
        self.model.to(self.device)
        self.num_iters = config['train']['max_iters']
        self.beamsearch = config['predictor']['beamsearch']

        self.batch_size = config['train']['batch_size']
        self.print_every = config['train']['print_every']
        self.valid_every = config['train']['valid_every']
        self.checkpoint_path = config['train']['checkpoint']
        self.export_weights = config['train']['export']
        self.metrics = None
        logger = config['train'].get('log', None)

        if logger:
            self.logger = Logger(logger)

        if pretrained and config.get('pretrain'):
            self.load_weights(config['pretrain'])

        self.optimizer = optim.AdamW(self.model.parameters(), betas=(0.9, 0.98), eps=1e-09)
        self.scheduler = OneCycleLR(
            self.optimizer, total_steps=self.num_iters, **config['optimizer']
        )
        self.criterion = LabelSmoothingLoss(
            len(self.vocab), padding_idx=self.vocab.pad, smoothing=0.1
        )

        # Create data loaders
        self.train_gen = DataLoader(
            train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=config['dataloader']['num_workers'],
            pin_memory=config['dataloader']['pin_memory'],
            collate_fn=collate_fn
        )
        if valid_dataset:
            self.valid_gen = DataLoader(
                valid_dataset,
                batch_size=self.batch_size,
                shuffle=False,
                num_workers=config['dataloader']['num_workers'],
                pin_memory=config['dataloader']['pin_memory'],
                collate_fn=collate_fn
            )

        self.iter = 0
        self.train_losses = []

    def train(self):
        total_loss = 0
        total_loader_time = 0
        total_gpu_time = 0
        best_acc = 0

        data_iter = iter(self.train_gen)
        for i in tqdm(range(self.num_iters)):
            self.iter += 1
            start = time.time()

            try:
                batch = next(data_iter)
            except StopIteration:
                data_iter = iter(self.train_gen)
                batch = next(data_iter)

            total_loader_time += time.time() - start
            start = time.time()
            loss = self.step(batch)
            total_gpu_time += time.time() - start

            total_loss += loss
            self.train_losses.append((self.iter, loss))

            if self.iter % self.print_every == 0:
                info = "iter: {:06d} - train loss: {:.3f} - lr: {:.2e} - load time: {:.2f} - gpu time: {:.2f}".format(
                    self.iter,
                    total_loss / self.print_every,
                    self.optimizer.param_groups[0]["lr"],
                    total_loader_time,
                    total_gpu_time,
                )
                print(info)
                if hasattr(self, 'logger'):
                    self.logger.log(info)
                total_loss = 0
                total_loader_time = 0
                total_gpu_time = 0

            if self.valid_gen and self.iter % self.valid_every == 0:
                val_loss = self.validate()
                acc_full_seq, acc_per_char = self.precision(self.metrics)
                info = "iter: {:06d} - valid loss: {:.3f} - acc full seq: {:.4f} - acc per char: {:.4f}".format(
                    self.iter, val_loss, acc_full_seq, acc_per_char
                )
                print(info)
                if hasattr(self, 'logger'):
                    self.logger.log(info)

                if acc_full_seq > best_acc:
                    self.save_weights(self.export_weights)
                    best_acc = acc_full_seq

        # Save final model
        self.save_weights(self.export_weights)
        print(f'Saved final model to {self.export_weights}')

    def step(self, batch):
        self.model.train()
        batch = self.batch_to_device(batch)
        img, tgt_input, tgt_output, tgt_padding_mask = (
            batch['img'],
            batch['label'][:, :-1],  # Exclude <eos>
            batch['label'][:, 1:],   # Exclude <go>
            batch['padding_mask'][:, :-1]
        )
    
        outputs = self.model(img, tgt_input, tgt_key_padding_mask=tgt_padding_mask)
        outputs = outputs.reshape(-1, outputs.size(2))  # Flatten (N, T, V) to (N*T, V)
        tgt_output = tgt_output.reshape(-1)  # Flatten (N, T) to (N*T)
    
        loss = self.criterion(outputs, tgt_output)
        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1)
        self.optimizer.step()
        self.scheduler.step()
    
        return loss.item()

    def validate(self):
        self.model.eval()
        total_loss = []
    
        with torch.no_grad():
            for step, batch in enumerate(self.valid_gen):
                batch = self.batch_to_device(batch)
    
                img = batch['img']
                tgt_input = batch['label'][:, :-1]
                tgt_output = batch['label'][:, 1:]
                tgt_padding_mask = batch['padding_mask'][:, :-1]
    
                outputs = self.model(img, tgt_input, tgt_key_padding_mask=tgt_padding_mask)
                outputs = outputs.flatten(0, 1)
                tgt_output = tgt_output.flatten()
                loss = self.criterion(outputs, tgt_output)
    
                total_loss.append(loss.item())
                del outputs
                del loss
    
        total_loss = np.mean(total_loss)
        self.model.train()
        return total_loss

    def predict(self, sample=None):
        pred_sents = []
        actual_sents = []
        img_files = []

        self.model.eval()
        with torch.no_grad():
            for batch in self.valid_gen:
                batch = self.batch_to_device(batch)
                if self.beamsearch:
                    translated_sentence = batch_translate_beam_search(batch['img'], self.model)
                    prob = None
                else:
                    translated_sentence, prob = translate(batch['img'], self.model)

                pred_sent = self.vocab.batch_decode(translated_sentence.tolist())
                actual_sent = self.vocab.batch_decode(batch['label'].tolist())
                img_files.extend([os.path.basename(f) for f in batch['filenames']])

                pred_sents.extend(pred_sent)
                actual_sents.extend(actual_sent)

                if sample and len(pred_sents) > sample:
                    break

        return pred_sents, actual_sents, img_files, prob

    def precision(self, sample=None):
        pred_sents, actual_sents, _, _ = self.predict(sample=sample)
        acc_full_seq = compute_accuracy(actual_sents, pred_sents, mode="full_sequence")
        acc_per_char = compute_accuracy(actual_sents, pred_sents, mode="per_char")
        return acc_full_seq, acc_per_char

    def visualize_prediction(self, sample=16, errorcase=False, fontname="serif", fontsize=16):
        pred_sents, actual_sents, img_files, probs = self.predict(sample)
        data_dir = self.config['dataset']['data_root']

        if errorcase:
            wrongs = [i for i in range(len(pred_sents)) if pred_sents[i] != actual_sents[i]]
            pred_sents = [pred_sents[i] for i in wrongs]
            actual_sents = [actual_sents[i] for i in wrongs]
            img_files = [img_files[i] for i in wrongs]
            probs = [probs[i] for i in wrongs] if probs else None

        img_files = img_files[:sample]
        fontdict = {"family": fontname, "size": fontsize}

        for vis_idx, img_file in enumerate(img_files):
            img_path = os.path.join(data_dir, img_file)
            pred_sent = pred_sents[vis_idx]
            actual_sent = actual_sents[vis_idx]
            prob = probs[vis_idx] if probs else None

            img = Image.open(img_path)
            plt.figure()
            plt.imshow(img)
            title = f"pred: {pred_sent} - actual: {actual_sent}"
            if prob is not None:
                title = f"prob: {prob:.3f} - {title}"
            plt.title(title, loc="left", fontdict=fontdict)
            plt.axis("off")
            plt.show()

    def batch_to_device(self, batch):
        return {
            'img': batch['img'].to(self.device, non_blocking=True),
            'label': batch['label'].to(self.device, non_blocking=True),
            'padding_mask': batch['padding_mask'].to(self.device, non_blocking=True),
            'filenames': batch['filenames']
        }

    def load_weights(self, filename):
        state_dict = torch.load(filename, map_location=self.device)
        self.model.load_state_dict(state_dict, strict=False)

    def save_weights(self, filename):
        os.makedirs(os.path.dirname(filename), exist_ok=True)
        torch.save(self.model.state_dict(), filename)

# 9. Train

In [12]:
import json
import os
from sklearn.model_selection import train_test_split

def split_dataset(labels_file, output_dir, train_ratio=0.8, random_seed=42):
    """
    Split the labels.json file into training and validation sets.
    
    Args:
        labels_file (str): Path to labels.json
        output_dir (str): Directory to save train_labels.json and valid_labels.json
        train_ratio (float): Proportion of data for training (default: 0.8)
        random_seed (int): Seed for reproducibility
    """
    # Load the labels
    with open(labels_file, 'r', encoding='utf-8') as f:
        labels = json.load(f)
    
    # Get list of image filenames and corresponding labels
    image_files = list(labels.keys())
    label_texts = list(labels.values())
    
    # Optional: Stratify by label length for balanced splits
    # Create bins based on text length
    lengths = [len(text) for text in label_texts]
    bins = [min(l, 20) for l in lengths]  # Cap at 20 for reasonable binning
    
    # Split the data
    train_files, valid_files, train_labels, valid_labels = train_test_split(
        image_files,
        label_texts,
        train_size=train_ratio,
        random_state=random_seed,
        stratify=bins  # Stratify by text length
    )
    
    # Create dictionaries for train and valid sets
    train_dict = dict(zip(train_files, train_labels))
    valid_dict = dict(zip(valid_files, valid_labels))
    
    # Ensure output directory exists
    os.makedirs(output_dir, exist_ok=True)
    
    # Save train and valid JSON files
    train_json = os.path.join(output_dir, 'train_labels.json')
    valid_json = os.path.join(output_dir, 'valid_labels.json')
    
    with open(train_json, 'w', encoding='utf-8') as f:
        json.dump(train_dict, f, ensure_ascii=False, indent=2)
    
    with open(valid_json, 'w', encoding='utf-8') as f:
        json.dump(valid_dict, f, ensure_ascii=False, indent=2)
    
    print(f"Saved training labels to {train_json} ({len(train_dict)} samples)")
    print(f"Saved validation labels to {valid_json} ({len(valid_dict)} samples)")
    
    return train_json, valid_json

In [19]:
import os
import json
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from sklearn.model_selection import train_test_split

# Configuration
def create_vit_config():
    config = {
        'transformer': {
            'vocab_size': None,  # Will be set after vocab creation
            'd_model': 256,
            'nhead': 8,
            'num_encoder_layers': 6,
            'num_decoder_layers': 6,
            'dim_feedforward': 2048,
            'max_seq_length': 1024,  # Increased to match ViT sequence length
            'pos_dropout': 0.1,
            'trans_dropout': 0.1
        },
        'train': {
            'batch_size': 8,
            'epochs': 100,
            'print_every': 200,
            'valid_every': 1000,
            'checkpoint': './checkpoint/vit_transformer_checkpoint.pth',
            'export': './weights/vit_transformer.pth',
            'max_iters': 10000,
            'learning_rate': 1e-4
        }
    }
    return config

# Function to split dataset (from Step 1)
def split_dataset(labels_file, output_dir, train_ratio=0.8, random_seed=42):
    with open(labels_file, 'r', encoding='utf-8') as f:
        labels = json.load(f)
    
    image_files = list(labels.keys())
    label_texts = list(labels.values())
    lengths = [len(text) for text in label_texts]
    bins = [min(l, 20) for l in lengths]
    
    train_files, valid_files, train_labels, valid_labels = train_test_split(
        image_files,
        label_texts,
        train_size=train_ratio,
        random_state=random_seed,
        stratify=bins
    )
    
    train_dict = dict(zip(train_files, train_labels))
    valid_dict = dict(zip(valid_files, valid_labels))
    
    os.makedirs(output_dir, exist_ok=True)
    train_json = os.path.join(output_dir, 'train_labels.json')
    valid_json = os.path.join(output_dir, 'valid_labels.json')
    
    with open(train_json, 'w', encoding='utf-8') as f:
        json.dump(train_dict, f, ensure_ascii=False, indent=2)
    with open(valid_json, 'w', encoding='utf-8') as f:
        json.dump(valid_dict, f, ensure_ascii=False, indent=2)
    
    print(f"Saved training labels to {train_json} ({len(train_dict)} samples)")
    print(f"Saved validation labels to {valid_json} ({len(valid_dict)} samples)")
    
    return train_json, valid_json

# Main training function
def train():
    # Configuration
    config = create_vit_config()
    config.update({
        'device': 'cuda' if torch.cuda.is_available() else 'cpu',
        'dataset': {
            'data_root': '/kaggle/input/dataset-ocr/dataset/data',
            'train_annotation': None,  # Will be set after splitting
            'valid_annotation': None,  # Will be set after splitting
            'name': 'ocr_dataset',
            'image_height': 224,
            'image_min_width': 224,
            'image_max_width': 224
        },
        'aug': {
            'image_aug': False,
            'masked_language_model': False
        },
        'optimizer': {
            'max_lr': 1e-4,
            'pct_start': 0.1,
            'anneal_strategy': 'cos'
        },
        'dataloader': {
            'num_workers': 1,
            'pin_memory': True
        },
        'predictor': {
            'beamsearch': True
        },
        'pretrain': '/kaggle/input/vietocr/pytorch/default/1/vgg_transformer.pth',
        'quiet': False
    })

    # Split the dataset
    original_labels_file = '/kaggle/input/dataset-ocr/dataset/labels.json'
    output_dir = './dataset_splits'
    train_json, valid_json = split_dataset(original_labels_file, output_dir, train_ratio=0.8, random_seed=42)
    
    # Update config with new annotation files
    config['dataset']['train_annotation'] = train_json
    config['dataset']['valid_annotation'] = valid_json

    # Create vocabulary
    vocab = 'aAàÀảẢãÃáÁạẠăĂằẰẳẲẵẴắẮặẶâÂầẦẩẨẫẪấẤậẬbBcCdDđĐeEèÈẻẺẽẼéÉẹẸêÊềỀểỂễỄếẾệỆfFgGhHiIìÌỉỈĩĨíÍịỊjJkKlLmMnNoOòÒỏỎõÕóÓọỌôÔồỒổỔỗỖốỐộỘơƠờỜởỞỡỠớỚợỢpPqQrRsStTuUùÙủỦũŨúÚụỤưƯừỪửỬữỮứỨựỰvVwWxXyYỳỲỷỶỹỸýÝỵỴzZ0123456789!"#$%&''()*+,-./:;<=>?@[\]^_`{|}~ '
    vocab = Vocab(vocab)
    config['transformer']['vocab_size'] = len(vocab)

    # Define image transforms
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # Create datasets
    train_dataset = CustomDataset(
        data_dir=config['dataset']['data_root'],
        labels_file=config['dataset']['train_annotation'],
        vocab=vocab,
        transform=transform
    )
    valid_dataset = CustomDataset(
        data_dir=config['dataset']['data_root'],
        labels_file=config['dataset']['valid_annotation'],
        vocab=vocab,
        transform=transform
    )

    # Create model
    model = VietOCR(
        vocab_size=len(vocab),
        backbone=None,
        vit_args=config,
        transformer_args=config['transformer'],
        seq_modeling='transformer'
    )

    # Create trainer
    trainer = Trainer(
        config=config,
        model=model,
        vocab=vocab,
        train_dataset=train_dataset,
        valid_dataset=valid_dataset,
        pretrained=True
    )

    # Train model
    trainer.train()

if __name__ == '__main__':
    os.makedirs('./checkpoint', exist_ok=True)
    os.makedirs('./weights', exist_ok=True)
    os.makedirs('./dataset_splits', exist_ok=True)
    train()

 70%|███████   | 7000/10000 [1:03:46<39:50:07, 47.80s/it]

iter: 007000 - valid loss: 1.595 - acc full seq: 0.0000 - acc per char: 0.0765


 72%|███████▏  | 7200/10000 [1:05:03<17:49,  2.62it/s]   

iter: 007200 - train loss: 1.403 - lr: 2.20e-05 - load time: 0.34 - gpu time: 76.81


 74%|███████▍  | 7400/10000 [1:06:20<16:36,  2.61it/s]

iter: 007400 - train loss: 1.395 - lr: 1.92e-05 - load time: 0.32 - gpu time: 76.18


 76%|███████▌  | 7600/10000 [1:07:37<15:17,  2.62it/s]

iter: 007600 - train loss: 1.396 - lr: 1.65e-05 - load time: 0.33 - gpu time: 75.97


 78%|███████▊  | 7800/10000 [1:08:53<14:00,  2.62it/s]

iter: 007800 - train loss: 1.393 - lr: 1.40e-05 - load time: 0.36 - gpu time: 76.08


 80%|███████▉  | 7999/10000 [1:10:09<12:46,  2.61it/s]

iter: 008000 - train loss: 1.394 - lr: 1.17e-05 - load time: 0.32 - gpu time: 76.08


 80%|████████  | 8000/10000 [1:12:42<25:40:45, 46.22s/it]

iter: 008000 - valid loss: 1.578 - acc full seq: 0.0000 - acc per char: 0.0728


 82%|████████▏ | 8200/10000 [1:13:59<11:25,  2.63it/s]   

iter: 008200 - train loss: 1.386 - lr: 9.54e-06 - load time: 0.32 - gpu time: 76.39


 84%|████████▍ | 8400/10000 [1:15:16<10:13,  2.61it/s]

iter: 008400 - train loss: 1.389 - lr: 7.59e-06 - load time: 0.33 - gpu time: 76.29


 86%|████████▌ | 8600/10000 [1:16:33<08:54,  2.62it/s]

iter: 008600 - train loss: 1.388 - lr: 5.84e-06 - load time: 0.31 - gpu time: 75.98


 88%|████████▊ | 8800/10000 [1:17:49<07:40,  2.61it/s]

iter: 008800 - train loss: 1.384 - lr: 4.32e-06 - load time: 0.32 - gpu time: 76.07


 90%|████████▉ | 8999/10000 [1:19:05<06:20,  2.63it/s]

iter: 009000 - train loss: 1.388 - lr: 3.01e-06 - load time: 0.33 - gpu time: 76.05


 90%|█████████ | 9000/10000 [1:21:42<13:09:25, 47.37s/it]

iter: 009000 - valid loss: 1.587 - acc full seq: 0.0000 - acc per char: 0.0754


 92%|█████████▏| 9200/10000 [1:22:59<04:47,  2.78it/s]   

iter: 009200 - train loss: 1.390 - lr: 1.93e-06 - load time: 0.32 - gpu time: 76.32


 94%|█████████▍| 9400/10000 [1:24:16<03:48,  2.63it/s]

iter: 009400 - train loss: 1.384 - lr: 1.09e-06 - load time: 0.54 - gpu time: 76.07


 96%|█████████▌| 9600/10000 [1:25:32<02:32,  2.62it/s]

iter: 009600 - train loss: 1.390 - lr: 4.85e-07 - load time: 0.31 - gpu time: 75.99


 98%|█████████▊| 9800/10000 [1:26:49<01:16,  2.61it/s]

iter: 009800 - train loss: 1.387 - lr: 1.21e-07 - load time: 0.35 - gpu time: 76.21


100%|█████████▉| 9999/10000 [1:28:05<00:00,  2.61it/s]

iter: 010000 - train loss: 1.393 - lr: 4.03e-10 - load time: 0.32 - gpu time: 76.27


100%|██████████| 10000/10000 [1:30:42<00:00,  1.84it/s]

iter: 010000 - valid loss: 1.589 - acc full seq: 0.0000 - acc per char: 0.0754





Saved final model to ./weights/vit_transformer.pth
