## Imports and utils

In [None]:
import pandas as pd
import cv2
import matplotlib.pyplot as plt
import os
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
from torch.optim import Adam
from tqdm import tqdm
import numpy as np
import torch.nn.functional as F
from torchvision.transforms import Resize
from torchvision.transforms.functional import to_pil_image
from typing import List, Tuple
from torch.nn.utils.rnn import pack_padded_sequence
from torch.nn import CTCLoss

In [None]:
# Symbol tables used to decode CCPD label indices into plate characters.
# Each table ends with an "O" entry, which is excluded from the CTC charset.
provinces = list("皖沪津渝冀晋蒙辽吉黑苏浙京闽赣鲁豫鄂湘粤桂琼川贵云藏陕甘青宁新警学") + ["O"]
alphabets = list("ABCDEFGHJKLMNPQRSTUVWXYZ") + ["O"]
ads = alphabets[:-1] + list("0123456789") + ["O"]

# decodes the plate from the file name
def decode_plate(label_str):
    """Decode a CCPD plate label such as ``"0_0_22_27_27_33_16"`` into text.

    The label is a ``_``-separated list of integer indices: the first one
    indexes ``provinces``, the second ``alphabets``, and every remaining one
    ``ads``.

    Args:
        label_str: underscore-separated index string taken from the file name.

    Returns:
        The decoded plate string (province + letter + remaining characters).

    Raises:
        ValueError: if a field is not an integer.
        IndexError: if fewer than two indices are given or an index is out of range.
    """
    indices = [int(field) for field in label_str.split('_')]
    province = provinces[indices[0]]
    alphabet = alphabets[indices[1]]
    # build the tail in one join instead of repeated string concatenation
    ad = ''.join(ads[i] for i in indices[2:])
    return province + alphabet + ad

# Flat CTC vocabulary: every symbol table without its trailing "O" entry.
# Index 0 is reserved for the CTC blank, so real symbols are numbered from 1.
# NOTE: the letters in `alphabets` also appear in `ads`; for a repeated
# character the later position wins, so the indices of the `alphabets`
# segment end up unused by char_to_idx.
full_charset = [*provinces[:-1], *alphabets[:-1], *ads[:-1]]
char_to_idx = dict(zip(full_charset, range(1, len(full_charset) + 1)))
idx_to_char = {index: symbol for symbol, index in char_to_idx.items()}

def encode_plate(text: str) -> List[int]:
    """Map plate text to CTC target indices via ``char_to_idx``.

    Characters missing from the vocabulary are silently skipped.
    """
    encoded: List[int] = []
    for symbol in text:
        index = char_to_idx.get(symbol)
        if index is not None:
            encoded.append(index)
    return encoded


In [None]:
# torch dataset
class LicensePlateCCPDDataset(Dataset):
    """CCPD license-plate dataset yielding cropped plate images and plate text.

    Every ``.jpg`` file in ``image_dir`` is one sample. The file name is split
    on ``-``: field 2 holds the plate bounding box ``x1~y1_x2~y2`` (``&`` is
    also accepted as the coordinate separator) and field 4 holds the
    ``_``-separated character indices decoded with ``decode_plate``.

    Each item is ``(image_tensor, plate_text)``, where ``image_tensor`` is a
    float32 RGB tensor of shape (3, 48, 144) with values in [0, 1].
    """

    def __init__(self, image_dir):
        self.image_dir = image_dir
        # sorted so sample indices are reproducible (os.listdir order is arbitrary)
        self.image_files = sorted(f for f in os.listdir(image_dir) if f.endswith('.jpg'))

    def __len__(self):
        return len(self.image_files)

    @staticmethod
    def _parse_bbox(field):
        """Parse ``"x1<sep>y1_x2<sep>y2"`` into ``(x1, y1, x2, y2)``; sep is ``~`` or ``&``."""
        top_left, bottom_right = field.replace('&', '~').split('_')
        x1, y1 = map(int, top_left.split('~'))
        x2, y2 = map(int, bottom_right.split('~'))
        return x1, y1, x2, y2

    def __getitem__(self, idx):
        filename = self.image_files[idx]
        path = os.path.join(self.image_dir, filename)

        # load image; cv2.imread returns None (instead of raising) on failure
        image = cv2.imread(path)
        if image is None:
            raise FileNotFoundError(f"could not read image: {path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # bbox from filename
        parts = filename.split('-')
        x1, y1, x2, y2 = self._parse_bbox(parts[2])

        # clip the bbox to the image bounds and crop the plate
        h, w = image.shape[:2]
        x1, x2 = max(0, x1), min(w, x2)
        y1, y2 = max(0, y1), min(h, y2)
        if x2 <= x1 or y2 <= y1:
            # an empty crop would otherwise fail inside cv2.resize with an opaque error
            raise ValueError(f"empty plate crop {(x1, y1, x2, y2)} for {filename}")
        cropped = image[y1:y2, x1:x2]

        # resize as the paper: cv2 takes (width, height) -> array of 48x144
        cropped = cv2.resize(cropped, (144, 48))
        image_tensor = torch.tensor(cropped, dtype=torch.float32).permute(2, 0, 1) / 255.0

        # plate text
        plate_text = decode_plate(parts[4])

        return image_tensor, plate_text


In [None]:
# Build the training set from the Kaggle CCPD-weather directory and batch it.
dataset = LicensePlateCCPDDataset(
    image_dir="/kaggle/input/ccpd-weather/ccpd_weather",
)
loader = DataLoader(dataset, shuffle=True, batch_size=8)

In [None]:
# space-to-depth downsampling: halves H and W by stacking 2x2 pixel phases on channels
class Focus(nn.Module):
    """Downsample by 2 via space-to-depth followed by a 3x3 convolution.

    The four pixel phases of every 2x2 neighbourhood are concatenated along
    the channel axis (giving ``4 * in_channels`` channels, no pixel dropped)
    and then projected to ``out_channels``. Assumes even H and W: with an odd
    size the phase slices differ in shape and ``torch.cat`` fails.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels * 4, out_channels, 3, 1, 1)

    def forward(self, x):
        # phase order: (0,0), (0,1), (1,0), (1,1) as (row offset, col offset)
        phases = [x[..., row::2, col::2] for row in (0, 1) for col in (0, 1)]
        return self.conv(torch.cat(phases, dim=1))

# conv -> batch norm -> LeakyReLU(0.1)
class ConvBlock(nn.Module):
    """Conv2d followed by BatchNorm2d and LeakyReLU(negative_slope=0.1).

    Args:
        in_ch: input channels.
        out_ch: output channels.
        k: kernel size.
        s: stride (2 halves the spatial size with the default k and p).
        p: zero padding.
    """

    def __init__(self, in_ch, out_ch, k=3, s=1, p=1):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=k, stride=s, padding=p)
        self.bn = nn.BatchNorm2d(out_ch)
        self.act = nn.LeakyReLU(0.1, inplace=False)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return self.act(x)

# residual block: two ConvBlocks, identity shortcut, then BatchNorm on the sum
class ResBlock(nn.Module):
    """Two stacked ``ConvBlock`` layers with an identity shortcut.

    The stacked output is added to the input and the sum is batch-normalised.
    Channel count is preserved; spatial size is preserved given ConvBlock's
    default 3x3 kernel, stride 1 and padding 1.
    """

    def __init__(self, ch):
        super().__init__()
        self.block = nn.Sequential(ConvBlock(ch, ch), ConvBlock(ch, ch))
        self.bn = nn.BatchNorm2d(ch)

    def forward(self, x):
        shortcut = x
        return self.bn(shortcut + self.block(x))


# Image Global Feature Extractor: Focus -> 2 strided ConvBlocks -> 4 ResBlocks -> 1x1 conv
class IGFE(nn.Module):
    """Image Global Feature Extractor.

    Maps a 3-channel image to a 512-channel feature map, downsampling each
    spatial dimension by 8 (Focus /2, then two stride-2 ConvBlocks) for sizes
    divisible by 8. A 3x48x144 plate crop becomes 512x6x18 (108 positions).
    """

    def __init__(self):
        super().__init__()
        self.focus = Focus(3, 64)
        self.down1 = ConvBlock(64, 128, s=2)
        self.down2 = ConvBlock(128, 256, s=2)
        self.res = nn.Sequential(*[ResBlock(256) for _ in range(4)])
        self.conv_out = nn.Conv2d(256, 512, 1)

    def forward(self, x):
        for stage in (self.focus, self.down1, self.down2, self.res):
            x = stage(x)
        # safety clamp on the residual features before the 1x1 projection
        return self.conv_out(torch.clamp(x, -10, 10))

# transformer encoding from image
class TransformerEncoder(nn.Module):
    """Flatten a feature map into tokens and run a Transformer encoder over them.

    ``x`` is a feature map (B, C, H, W) with ``C == d_model``. It is flattened
    to (B, H*W, C), a learned positional embedding is added, and the sequence
    is encoded with self-attention across the H*W positions of each image.
    Output shape: (B, H*W, d_model).

    Args:
        d_model: channel / embedding size.
        nhead: number of attention heads.
        num_layers: number of stacked encoder layers.
        max_len: number of learned positions; H*W must not exceed it. The
            default 108 is the original fixed sequence length (6x18 features).
        dim_feedforward: hidden size of each layer's feed-forward network.
    """

    def __init__(self, d_model=512, nhead=8, num_layers=3, max_len=108, dim_feedforward=2048):
        super().__init__()
        self.pos_embed = nn.Parameter(torch.randn(1, max_len, d_model))
        # batch_first=True: the tokens are (B, seq, C). With the default
        # (batch_first=False) dim 0 -- the batch -- would be treated as the
        # sequence, so attention would mix different images instead of positions.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model, nhead, dim_feedforward=dim_feedforward, batch_first=True
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers)
        self.dropout = nn.Dropout(p=0.1)

    def forward(self, x):
        seq_len = x.shape[2] * x.shape[3]
        max_len = self.pos_embed.size(1)
        if seq_len > max_len:
            raise ValueError(
                f"feature map has {seq_len} positions but only {max_len} "
                f"positional embeddings are available"
            )
        x = x.flatten(2).transpose(1, 2)  # (B, C, H, W) -> (B, H*W, C)
        x = self.dropout(x + self.pos_embed[:, :seq_len])
        return self.encoder(x)



# prediction head: per-position linear classifier over the character vocabulary
class ParallelDecoder(nn.Module):
    """Project every encoded position to class logits in a single parallel step.

    Maps (..., d_model) -> (..., num_classes); there is no recurrence or
    autoregressive decoding.
    """

    def __init__(self, d_model=512, num_classes=92):
        super().__init__()
        self.head = nn.Linear(d_model, num_classes)

    def forward(self, x):
        logits = self.head(x)
        return logits

# full model (Parallel Deep-Learning License Plate Recognition)
class PDLPRModel(nn.Module):
    """IGFE backbone -> Transformer encoder -> per-position linear head.

    ``forward`` returns raw (un-normalised) logits of shape
    (B, T, num_classes), where T is the number of encoder positions.
    """

    def __init__(self, num_classes=92):
        super().__init__()
        self.igfe = IGFE()
        self.encoder = TransformerEncoder()
        self.decoder = ParallelDecoder(num_classes=num_classes)

        self._init_weights()

    def _init_weights(self):  # needed because of unstable training
        """Re-initialise Conv2d, BatchNorm2d and Linear layers.

        ``self.modules()`` recurses, so this also covers the Linear layers
        inside the Transformer encoder.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                # same guard as the Conv2d branch: Linear may be built with bias=False
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        features = self.igfe(x)          # (B, 512, H', W')
        tokens = self.encoder(features)  # (B, H'*W', 512)
        return self.decoder(tokens)      # (B, H'*W', num_classes)




In [None]:
# Model, loss, optimiser and learning-rate schedule.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model = PDLPRModel(num_classes=92).to(device)

# blank=0 matches char_to_idx, which numbers real symbols from 1;
# zero_infinity replaces infinite (infeasible) losses with zero;
# reduction='sum' sums per-sample losses over the batch instead of averaging.
ctc_loss = CTCLoss(blank=0, zero_infinity=True, reduction='sum')
optimizer = Adam(model.parameters(), lr=1e-3)
# multiply the LR by 0.9 every 20 scheduler steps (as in the paper)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.9)

In [None]:
# training loop
num_epochs = 50
for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0
    num_samples = 0

    for imgs, texts in tqdm(loader, desc=f"epoch {epoch + 1}/{num_epochs}"):
        imgs = imgs.to(device)

        # CTC targets: one flat tensor of concatenated label indices + per-sample lengths
        encoded_targets = [torch.tensor(encode_plate(t), dtype=torch.long) for t in texts]
        target_lengths = torch.tensor([len(seq) for seq in encoded_targets], dtype=torch.long)
        targets = torch.cat(encoded_targets)

        # forward: logits are (B, T, C); CTCLoss expects log-probs as (T, B, C)
        logits = model(imgs)
        log_probs = logits.log_softmax(2).permute(1, 0, 2)

        # every sample uses the full encoder sequence; take T from the model
        # output rather than hard-coding it
        input_lengths = torch.full(
            size=(imgs.size(0),), fill_value=logits.size(1), dtype=torch.long
        )

        loss = ctc_loss(log_probs, targets, input_lengths, target_lengths)

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        total_loss += loss.item()
        num_samples += imgs.size(0)

    scheduler.step()

    # ctc_loss uses reduction='sum', so divide by the sample count (not the
    # batch count) to report a per-sample mean unaffected by a short last batch
    print(f"Epoch [{epoch + 1}/{num_epochs}] - Loss: {total_loss / max(num_samples, 1):.4f}")

In [None]:
# Persist only the learned parameters (state_dict), not the whole pickled module.
torch.save(obj=model.state_dict(), f="pdlpr_model_weights.pth")