# ocr

In [1]:
%load_ext autoreload
%autoreload 2
import sys
sys.path.append("..")

import torch
device = "cuda:0" if torch.cuda.is_available() else "cpu"
from utils import CTCLabelConverter

## data

In [2]:
from data.dataset import OCRDataset
from data.dataset import get_vocab

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
image_path = "/home/pc/Desktop/coding/vision/lib_ocr/trainning_images"
VOCAB = get_vocab(image_path)
toy_dataset = OCRDataset(image_path)


In [4]:
convert = CTCLabelConverter(VOCAB, device)

## model

In [5]:
from models.backbones import ResNet_FeatureExtractor
from models.pred_modules import CTC_Predictor

In [6]:
from torch import nn
class OCR_model_T1(nn.Module):
    def __init__(self, input_channels, num_classes):
        super().__init__()
        self.FeatureExtraction = ResNet_FeatureExtractor(input_channels,512)
        self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d((None, 1))
        self.Predictor = CTC_Predictor(512, num_classes)
        
        
    def forward(self, input):
        # import pdb; pdb.set_trace()
        # input: [b, c, h, w] 
        # input: torch.Size([8, 1, 32, 100])

        visual_feature = self.FeatureExtraction(input) 
        # visual_feature: torch.Size([8, 512, 1, 26])

        visual_feature = self.AdaptiveAvgPool(visual_feature.permute(0, 3, 1, 2))  # [b, c, h, w] -> [b, w, c, h]
        # visual_feature: torch.Size([8, 26, 512, 1])

        visual_feature = visual_feature.squeeze(3)
        # visual_feature: torch.Size([8, 26, 512])

        prediction = self.Predictor(visual_feature.contiguous())
        # prediction: torch.Size([8, 26, 94])
        return prediction
    
model = OCR_model_T1(1, len(convert.character)).to(device)


In [7]:
import torch
from data.dataset import AlignCollate
align_collate = AlignCollate()
batch_size = 32
data_loader = torch.utils.data.DataLoader(toy_dataset, 
                                          batch_size=batch_size, 
                                          shuffle=True,
                                          pin_memory=True,
                                            num_workers=4, # delete this line if you debug
                                          collate_fn=align_collate)
# import pdb; pdb.set_trace()
sample = next(iter(data_loader))

In [8]:
sample[0].shape

torch.Size([32, 1, 32, 100])

In [9]:
sample[1]

['Mua',
 'Ford Mustang',
 'ngoài. Có',
 'TTGT',
 'ha cà',
 'nào?. Mua nhà',
 '500m³ cát trái',
 'bà',
 'bóng',
 'trinh',
 'trong tháng',
 'thu hút',
 'Long -',
 'Bí.',
 'xóa',
 'Venezuela',
 'cao',
 'giá vé',
 'thi',
 'Bình xin phá',
 'y',
 'Nhìn',
 'dày khi',
 'Mê Linh',
 'công',
 'ký',
 'Lào Cai.',
 'công an.',
 "xanh'",
 'sát',
 'pháp lý.',
 'Sài Gòn.']

In [10]:

# with torch.inference_mode():
#     y_pred = model(sample[0])
#     print(y_pred.shape)

## traning loop

In [11]:
import torch
device = "cuda:0" if torch.cuda.is_available() else "cpu"
print(device)

cuda:0


In [12]:
from torch.optim import Adam
from torch.nn import CTCLoss
from tqdm.auto import tqdm

optimizer = Adam(model.parameters(), lr=1e-3)
ctc_loss = CTCLoss().to(device)  # Ensure blank index matches your vocab
model.train()

for epoch in tqdm(range(30)):
    for idx, (image, labels) in tqdm(enumerate(data_loader)):
        # Encode labels
        text, length = convert.encode(labels, batch_max_length=len(VOCAB))
        
        
        # Forward pass
        preds = model(image.to(device))
        
        # Ensure `preds_size` matches batch size
        batch_size = image.size(0)
        preds_size = torch.IntTensor([preds.size(1)] * batch_size)  # Each sequence has `preds.size(1)` timesteps
        
        preds = preds.log_softmax(2).permute(1, 0, 2)  
        
        # Compute CTC loss
        loss = ctc_loss(preds, text, preds_size, length)
        
        model.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 5)  # gradient clipping with 5 (Default)
        optimizer.step()
        
        # Logging
        if idx % 100 == 0:
            print(f"Epoch: {epoch}, Batch: {idx}, Loss: {loss.item()}")
    

  0%|          | 0/30 [00:00<?, ?it/s]



Epoch: 0, Batch: 0, Loss: 29.564260482788086


32it [00:04,  7.24it/s]
  3%|▎         | 1/30 [00:04<02:09,  4.46s/it]

Epoch: 1, Batch: 0, Loss: 4.744544982910156


32it [00:04,  7.92it/s]
  7%|▋         | 2/30 [00:08<01:59,  4.25s/it]

Epoch: 2, Batch: 0, Loss: 4.214029312133789


32it [00:04,  7.98it/s]
 10%|█         | 3/30 [00:12<01:52,  4.17s/it]

Epoch: 3, Batch: 0, Loss: 3.7074358463287354


32it [00:04,  8.00it/s]
 13%|█▎        | 4/30 [00:16<01:47,  4.12s/it]

Epoch: 4, Batch: 0, Loss: 4.21854305267334


32it [00:03,  8.05it/s]
 17%|█▋        | 5/30 [00:20<01:42,  4.09s/it]

Epoch: 5, Batch: 0, Loss: 3.6330366134643555


32it [00:04,  7.98it/s]
 20%|██        | 6/30 [00:24<01:37,  4.08s/it]

Epoch: 6, Batch: 0, Loss: 3.702418804168701


32it [00:04,  7.94it/s]
 23%|██▎       | 7/30 [00:28<01:33,  4.08s/it]

Epoch: 7, Batch: 0, Loss: 3.5576744079589844


32it [00:03,  8.00it/s]
 27%|██▋       | 8/30 [00:32<01:29,  4.08s/it]

Epoch: 8, Batch: 0, Loss: 3.5335774421691895


32it [00:03,  8.00it/s]
 30%|███       | 9/30 [00:36<01:25,  4.07s/it]

Epoch: 9, Batch: 0, Loss: 3.534176826477051


32it [00:03,  8.01it/s]
 33%|███▎      | 10/30 [00:41<01:21,  4.06s/it]

Epoch: 10, Batch: 0, Loss: 3.650569438934326


32it [00:03,  8.02it/s]
 37%|███▋      | 11/30 [00:45<01:17,  4.06s/it]

Epoch: 11, Batch: 0, Loss: 3.542224884033203


32it [00:03,  8.02it/s]
 40%|████      | 12/30 [00:49<01:12,  4.05s/it]

Epoch: 12, Batch: 0, Loss: 3.3296008110046387


32it [00:04,  7.99it/s]
 43%|████▎     | 13/30 [00:53<01:08,  4.05s/it]

Epoch: 13, Batch: 0, Loss: 3.7096481323242188


32it [00:04,  7.94it/s]
 47%|████▋     | 14/30 [00:57<01:05,  4.06s/it]

Epoch: 14, Batch: 0, Loss: 2.925893783569336


32it [00:03,  8.00it/s]
 50%|█████     | 15/30 [01:01<01:00,  4.06s/it]

Epoch: 15, Batch: 0, Loss: 3.0345678329467773


32it [00:04,  7.97it/s]
 53%|█████▎    | 16/30 [01:05<00:56,  4.07s/it]

Epoch: 16, Batch: 0, Loss: 3.1404309272766113


32it [00:04,  7.97it/s]
 57%|█████▋    | 17/30 [01:09<00:52,  4.07s/it]

Epoch: 17, Batch: 0, Loss: 3.178518295288086


32it [00:04,  7.98it/s]
 60%|██████    | 18/30 [01:13<00:48,  4.07s/it]

Epoch: 18, Batch: 0, Loss: 2.815819263458252


32it [00:04,  7.98it/s]
 63%|██████▎   | 19/30 [01:17<00:44,  4.06s/it]

Epoch: 19, Batch: 0, Loss: 2.340116262435913


32it [00:04,  7.96it/s]
 67%|██████▋   | 20/30 [01:21<00:40,  4.07s/it]

Epoch: 20, Batch: 0, Loss: 2.3880372047424316


32it [00:04,  7.93it/s]
 70%|███████   | 21/30 [01:25<00:36,  4.08s/it]

Epoch: 21, Batch: 0, Loss: 2.0888001918792725


32it [00:04,  7.99it/s]
 73%|███████▎  | 22/30 [01:29<00:32,  4.08s/it]

Epoch: 22, Batch: 0, Loss: 1.9492268562316895


32it [00:04,  8.00it/s]
 77%|███████▋  | 23/30 [01:33<00:28,  4.07s/it]

Epoch: 23, Batch: 0, Loss: 1.9793633222579956


32it [00:04,  7.94it/s]
 80%|████████  | 24/30 [01:37<00:24,  4.08s/it]

Epoch: 24, Batch: 0, Loss: 1.5852329730987549


32it [00:04,  7.93it/s]
 83%|████████▎ | 25/30 [01:42<00:20,  4.09s/it]

Epoch: 25, Batch: 0, Loss: 1.6026864051818848


32it [00:04,  7.90it/s]
 87%|████████▋ | 26/30 [01:46<00:16,  4.09s/it]

Epoch: 26, Batch: 0, Loss: 0.9912487864494324


32it [00:04,  7.87it/s]
 90%|█████████ | 27/30 [01:50<00:12,  4.10s/it]

Epoch: 27, Batch: 0, Loss: 0.9377477169036865


32it [00:04,  7.90it/s]
 93%|█████████▎| 28/30 [01:54<00:08,  4.10s/it]

Epoch: 28, Batch: 0, Loss: 1.1247371435165405


32it [00:04,  7.90it/s]
 97%|█████████▋| 29/30 [01:58<00:04,  4.10s/it]

Epoch: 29, Batch: 0, Loss: 0.9792147874832153


32it [00:04,  7.84it/s]
100%|██████████| 30/30 [02:02<00:00,  4.09s/it]
