## Imports

In [None]:
import cv2
import os
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import clip_grad_norm_
import torch.nn as nn
import torch

In [None]:
# utils for decoding the labels

# Lookup tables for the index-encoded plate label found in each file name.
# Every table ends with an 'O' entry (presumably a "no character"
# placeholder - confirm against the dataset description). The letter tables
# skip 'I' and contain 'O' only as that trailing entry.
provinces = list("皖沪津渝冀晋蒙辽吉黑苏浙京闽赣鲁豫鄂湘粤桂琼川贵云藏陕甘青宁新警学O")
alphabets = list("ABCDEFGHJKLMNPQRSTUVWXYZ" "O")
ads = list("ABCDEFGHJKLMNPQRSTUVWXYZ" "0123456789" "O")

def decode_plate(label_str):
    """Decode the plate text from the label field of a file name.

    ``label_str`` is a ``_``-separated list of integers. The first one
    indexes ``provinces``, the second ``alphabets`` and every remaining one
    ``ads``. The looked-up characters are concatenated in that order.
    """
    codes = [int(token) for token in label_str.split('_')]
    head = provinces[codes[0]] + alphabets[codes[1]]
    return head + ''.join(ads[code] for code in codes[2:])

# Model alphabet: the three tables without their trailing 'O' entries.
# Index 0 is reserved for the CTC blank, so characters start at 1.
# NOTE: the letters occur in both `alphabets` and `ads`; the later position
# overwrites the earlier one in `char_to_idx`, so the indices of the
# `alphabets` slice (34-57) are never produced by the encoder.
full_charset = [*provinces[:-1], *alphabets[:-1], *ads[:-1]]
char_to_idx = {symbol: position for position, symbol in enumerate(full_charset, start=1)}
idx_to_char = dict(zip(char_to_idx.values(), char_to_idx.keys()))

def encode_plate(text: str):
    """Map plate text to a list of class indices for the model.

    Characters that are not keys of ``char_to_idx`` are dropped silently.
    """
    encoded = []
    for ch in text:
        idx = char_to_idx.get(ch)
        if idx is not None:
            encoded.append(idx)
    return encoded

def decode_plate_model(indices):
    """Turn class indices produced by the model back into plate text.

    Index 0 (the CTC blank) is skipped; indices missing from
    ``idx_to_char`` contribute an empty string.
    """
    chars = []
    for idx in indices:
        if idx != 0:
            chars.append(idx_to_char.get(idx, ''))
    return ''.join(chars)

In [None]:
# torch dataset
class LicensePlateCCPDDataset(Dataset):
    """Plate crops from a CCPD-style folder, labelled from the file names.

    Every ``.jpg`` in ``image_dir`` is one sample. The file name is split on
    ``-``; field 2 holds the plate bounding box (``x1~y1_x2~y2``) and field 4
    the index-encoded plate text understood by ``decode_plate``.

    Each item is ``(image_tensor, plate_text)`` where ``image_tensor`` is a
    float32 RGB tensor of shape (3, 48, 144) scaled to [-1, 1].
    """

    def __init__(self, image_dir):
        self.image_dir = image_dir
        # os.listdir returns entries in arbitrary order; sort so that a given
        # index always refers to the same file.
        self.image_files = sorted(f for f in os.listdir(image_dir) if f.endswith('.jpg'))

    def __len__(self):
        return len(self.image_files)

    @staticmethod
    def _parse_filename(filename):
        """Return ``((x1, y1, x2, y2), label_field)`` parsed from ``filename``.

        Raises:
            ValueError: if the name does not follow the expected layout.
        """
        try:
            parts = filename.split('-')
            x1y1, x2y2 = parts[2].split('_')
            # This copy of the data separates x and y with '~'; upstream CCPD
            # names use '&', so accept both.
            x1, y1 = map(int, x1y1.replace('&', '~').split('~'))
            x2, y2 = map(int, x2y2.replace('&', '~').split('~'))
            label_field = parts[4]
        except (IndexError, ValueError) as err:
            raise ValueError(f"unexpected CCPD file name: {filename!r}") from err
        return (x1, y1, x2, y2), label_field

    def __getitem__(self, idx):
        filename = self.image_files[idx]
        path = os.path.join(self.image_dir, filename)

        # load image; cv2.imread signals failure by returning None
        image = cv2.imread(path)
        if image is None:
            raise OSError(f"could not read image: {path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        (x1, y1, x2, y2), plate_raw = self._parse_filename(filename)

        # crop given the plate bbox, clamped to the image bounds
        h, w = image.shape[:2]
        x1, x2 = max(0, x1), min(w, x2)
        y1, y2 = max(0, y1), min(h, y2)
        if x2 <= x1 or y2 <= y1:
            # an empty crop would otherwise fail inside cv2.resize
            raise ValueError(f"empty plate bounding box in {filename!r}")
        cropped = image[y1:y2, x1:x2]

        # resize as the paper: cv2 takes (width, height), i.e. a 48x144 image
        cropped = cv2.resize(cropped, (144, 48))
        # HWC uint8 -> CHW float in [-1, 1]
        image_tensor = (torch.tensor(cropped, dtype=torch.float32).permute(2, 0, 1) / 255.0 - 0.5) / 0.5

        plate_text = decode_plate(plate_raw)

        return image_tensor, plate_text


In [None]:
# Build the dataset over the CCPD weather folder and a loader that serves
# it in shuffled mini-batches of 16 samples.
dataset = LicensePlateCCPDDataset(image_dir="/kaggle/input/ccpd-weather/ccpd_weather")
loader = DataLoader(dataset=dataset, shuffle=True, batch_size=16)

## Model

In [None]:
# focus block
class Focus(nn.Module):
    """Fold each 2x2 pixel neighbourhood into channels, then apply a 3x3 conv.

    The last two axes are sampled at the four (row, column) offsets of a
    stride-2 grid and the four slices are concatenated on axis 1, giving
    ``4 * in_channels`` channels at half the spatial size. Both spatial
    sizes should be even; otherwise the slices differ in size and the
    concatenation fails.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=4 * in_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )

    def forward(self, x):
        # offset order: (0, 0), (0, 1), (1, 0), (1, 1)
        phases = [x[..., row::2, col::2] for row in (0, 1) for col in (0, 1)]
        stacked = torch.cat(phases, dim=1)
        return self.conv(stacked)


# convolutional block
class ConvBlock(nn.Module):
    """Conv2d -> BatchNorm2d -> LeakyReLU(0.1).

    ``k``, ``s`` and ``p`` are the kernel size, stride and padding of the
    convolution.
    """

    def __init__(self, in_ch, out_ch, k=3, s=1, p=1):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, k, stride=s, padding=p)
        self.bn = nn.BatchNorm2d(num_features=out_ch)
        self.act = nn.LeakyReLU(negative_slope=0.1, inplace=True)

    def forward(self, x):
        out = self.conv(x)
        out = self.bn(out)
        return self.act(out)


# residual block
class ResBlock(nn.Module):
    """Two stacked ConvBlocks with a skip connection and a final batch norm.

    The channel count ``ch`` is the same on input and output.
    """

    def __init__(self, ch):
        super().__init__()
        self.block = nn.Sequential(*(ConvBlock(ch, ch) for _ in range(2)))
        self.bn = nn.BatchNorm2d(ch)

    def forward(self, x):
        residual = self.block(x)
        # normalisation is applied after the skip addition
        return self.bn(x + residual)


# IGFE
class IGFE(nn.Module):
    """Convolutional feature extractor in front of the encoder.

    Focus (halves H and W) -> two stride-2 ConvBlocks, each followed by two
    ResBlocks -> clamp to [-10, 10] -> 1x1 conv to 512 channels. The total
    spatial reduction is 8x, so the 48x144 crops produced by the dataset
    become 6x18 feature maps.
    """

    def __init__(self):
        super().__init__()
        self.focus = Focus(3, 64)
        self.down1 = ConvBlock(64, 128, s=2)
        self.res1 = nn.Sequential(ResBlock(128), ResBlock(128))
        self.down2 = ConvBlock(128, 256, s=2)
        self.res2 = nn.Sequential(ResBlock(256), ResBlock(256))
        self.conv_out = nn.Conv2d(256, 512, kernel_size=1)

    def forward(self, x):
        stages = (self.focus, self.down1, self.res1, self.down2, self.res2)
        for stage in stages:
            x = stage(x)
        # bound the activations before the output projection
        x = x.clamp(-10, 10)
        return self.conv_out(x)  # (B, 512, 6, 18) for 48x144 inputs


# encoder unit
class EncoderUnit(nn.Module):
    """Encoder layer: 1x1 conv up to 1024, self-attention, 1x1 conv back.

    Input and output are ``(batch, sequence, d_model)``. The Conv1d layers
    work on channel-first data, hence the transposes around them. The result
    is added to the input and layer-normalised.
    """

    def __init__(self, d_model=512, nhead=8):
        super().__init__()
        self.cnn1 = nn.Conv1d(d_model, 1024, kernel_size=1)
        self.mha = nn.MultiheadAttention(embed_dim=1024, num_heads=nhead, batch_first=True)
        self.cnn2 = nn.Conv1d(1024, d_model, kernel_size=1)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x):
        # (B, L, d_model) -> (B, d_model, L) -> (B, 1024, L) -> (B, L, 1024)
        expanded = F.relu(self.cnn1(x.transpose(1, 2)))
        tokens = expanded.transpose(1, 2)
        attended = self.mha(tokens, tokens, tokens)[0]
        # project back to d_model and restore (B, L, d_model)
        projected = self.cnn2(attended.transpose(1, 2)).transpose(1, 2)
        return self.norm(x + projected)


# full encoder
class PDLPR_Encoder(nn.Module):
    """Flatten a feature map into a token sequence and run the encoder stack.

    A learned positional table with 108 rows (6 * 18 positions) is added to
    the tokens, dropout is applied, and ``num_layers`` EncoderUnits follow.
    """

    def __init__(self, d_model=512, nhead=8, num_layers=3):
        super().__init__()
        self.pos_embed = nn.Parameter(torch.randn(1, 108, d_model))
        self.dropout = nn.Dropout(p=0.1)
        self.layers = nn.Sequential(
            *(EncoderUnit(d_model=d_model, nhead=nhead) for _ in range(num_layers))
        )

    def forward(self, x):
        batch, channels, _, _ = x.shape
        # (B, C, H, W) -> (B, C, H*W) -> (B, H*W, C)
        tokens = x.view(batch, channels, -1).permute(0, 2, 1)
        positions = self.pos_embed[:, :tokens.size(1), :]
        tokens = self.dropout(tokens + positions)
        return self.layers(tokens)  # (B, 108, 512) with the default sizes


# FFN block (non-linear transformation)
class FeedForward(nn.Module):
    """Position-wise MLP (Linear -> ReLU -> Linear) with residual + LayerNorm."""

    def __init__(self, d_model=512, hidden_dim=2048):
        super().__init__()
        expand = nn.Linear(d_model, hidden_dim)
        contract = nn.Linear(hidden_dim, d_model)
        self.net = nn.Sequential(expand, nn.ReLU(), contract)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x):
        transformed = self.net(x)
        return self.norm(x + transformed)


# decoder unit
class DecoderUnit(nn.Module):
    """Decoder layer: masked self-attention, cross-attention, feed-forward.

    ``tgt`` and ``memory`` are batch-first sequences with ``d_model``
    features. ``memory`` passes through two 1x1 convolutions before it is
    used as key and value of the cross-attention.
    """

    def __init__(self, d_model=512, nhead=8):
        super().__init__()
        self.masked_mha = nn.MultiheadAttention(d_model, nhead, batch_first=True)
        self.norm1 = nn.LayerNorm(d_model)

        self.encoder_proj = nn.Sequential(
            nn.Conv1d(d_model, d_model, kernel_size=1),
            nn.ReLU(),
            nn.Conv1d(d_model, d_model, kernel_size=1),
        )

        self.cross_mha = nn.MultiheadAttention(d_model, nhead, batch_first=True)
        self.norm2 = nn.LayerNorm(d_model)
        self.ffn = FeedForward(d_model)

    def forward(self, tgt, memory, tgt_mask=None):
        # self-attention over the target positions (optionally masked)
        self_attn = self.masked_mha(tgt, tgt, tgt, attn_mask=tgt_mask)[0]
        queries = self.norm1(tgt + self_attn)

        # Conv1d is channel-first: (B, L, D) -> (B, D, L) -> (B, L, D)
        keys_values = self.encoder_proj(memory.transpose(1, 2)).transpose(1, 2)

        # target positions attend to the projected encoder output
        cross_attn = self.cross_mha(queries, keys_values, keys_values)[0]
        return self.ffn(self.norm2(queries + cross_attn))


# full parallel decoder
class ParallelDecoder(nn.Module):
    """Decode all output positions in one pass from learned position queries.

    The decoder input is only the learned positional table (one row per
    output position); every layer attends to the encoder output.

    Args:
        d_model: feature size of the encoder output and of the queries.
        num_classes: size of the output vocabulary (including the blank).
        nhead: attention heads per DecoderUnit.
        num_layers: number of DecoderUnits.
        max_seq_len: number of output positions produced per sample.
    """

    def __init__(self, d_model=512, num_classes=92, nhead=8, num_layers=3, max_seq_len=18):
        super().__init__()
        self.pos_embed = nn.Parameter(torch.randn(1, max_seq_len, d_model))
        self.dropout = nn.Dropout(0.1)

        self.decoder_layers = nn.ModuleList([
            DecoderUnit(d_model, nhead) for _ in range(num_layers)
        ])

        self.out_proj = nn.Linear(d_model, num_classes)

    def forward(self, encoder_out):
        """Return logits of shape (B, max_seq_len, num_classes)."""
        batch = encoder_out.size(0)
        # Take the sequence length from the positional table instead of a
        # hard-coded 18 so that `max_seq_len` is actually honoured.
        seq_len = self.pos_embed.size(1)

        # queries are the positional embeddings, repeated for every sample
        tgt = self.dropout(self.pos_embed.expand(batch, -1, -1))

        # additive mask: -inf above the diagonal, 0 elsewhere, so a position
        # cannot attend to later positions
        mask = torch.triu(
            torch.full((seq_len, seq_len), float('-inf'), device=tgt.device, dtype=tgt.dtype),
            diagonal=1,
        )

        for layer in self.decoder_layers:
            tgt = layer(tgt, encoder_out, tgt_mask=mask)

        return self.out_proj(tgt)  # (B, max_seq_len, num_classes)


# final PDLPRModel
class PDLPRModel(nn.Module):
    """Full recogniser: IGFE feature extractor -> encoder -> parallel decoder.

    ``forward`` maps an image batch to per-position class logits.
    """

    def __init__(self, num_classes=92):
        super().__init__()
        self.igfe = IGFE()
        self.encoder = PDLPR_Encoder()
        self.decoder = ParallelDecoder(num_classes=num_classes)

        self._init_weights()

    def _init_weights(self):
        """Re-initialise conv, norm and linear layers of every submodule."""
        conv_types = (nn.Conv2d, nn.Conv1d)
        norm_types = (nn.BatchNorm2d, nn.LayerNorm)
        for module in self.modules():
            if isinstance(module, conv_types):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, norm_types):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                nn.init.zeros_(module.bias)

    def forward(self, x):
        features = self.igfe(x)            # (B, 512, 6, 18)
        memory = self.encoder(features)    # (B, 108, 512)
        return self.decoder(memory)        # (B, 18, num_classes)

In [15]:
# model initialization and utils
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model = PDLPRModel(num_classes=92)
model.to(device)

# zero_infinity=True replaces infinite losses (and their gradients) by zero
ctc_loss = nn.CTCLoss(zero_infinity=True, blank=0)
# NOTE(review): lr=1e-2 is high for Adam on an attention model - confirm
# that this is the value intended by the paper.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-2)
# multiply the learning rate by 0.9 every 20 scheduler steps (as in paper)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer=optimizer, gamma=0.9, step_size=20)

In [17]:
def ctc_decode(preds):
    """Greedy CTC decoding of one sequence of predicted class indices.

    Consecutive repeats are collapsed to a single token, blanks (index 0)
    are removed, and the remaining indices are turned into text with
    ``decode_plate_model``.
    """
    kept = []
    last = -1
    for token in preds:
        is_new = token != last
        last = token
        if is_new and token != 0:
            kept.append(token)
    return decode_plate_model(kept)


# training loop
NUM_EPOCHS = 300

for epoch in range(NUM_EPOCHS):
    model.train()
    total_loss = 0.0

    for imgs, texts in loader:
        imgs = imgs.to(device)

        # CTC targets: all label sequences concatenated into one 1-D tensor,
        # plus the length of each individual sequence
        encoded_targets = [torch.tensor(encode_plate(t), dtype=torch.long) for t in texts]
        target_lengths = torch.tensor([len(seq) for seq in encoded_targets], dtype=torch.long)
        targets = torch.cat(encoded_targets)

        # forward
        logits = model(imgs)                                 # (B, T, num_classes)
        log_probs = logits.log_softmax(2).permute(1, 0, 2)   # (T, B, num_classes) for CTCLoss

        # every sample uses the full output length T = logits.size(1),
        # i.e. the 18 decoder positions (not the 108 encoder tokens)
        input_lengths = torch.full(
            size=(logits.size(0),),
            fill_value=logits.size(1),
            dtype=torch.long,
            device=logits.device
        )

        # uses CTC from torch
        loss = ctc_loss(
            log_probs,
            targets,
            input_lengths,
            target_lengths
        )

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        total_loss += loss.item()

    scheduler.step()

    # epoch is 0-based, so report epoch + 1; max() guards an empty loader
    print(f"Epoch [{epoch + 1}/{NUM_EPOCHS}] - Loss: {total_loss / max(len(loader), 1):.4f}")

  7%|▋         | 41/625 [00:12<02:54,  3.35it/s]


KeyboardInterrupt: 

In [None]:
# Persist only the model parameters/buffers (state dict), not the module object.
torch.save(obj=model.state_dict(), f="pdlpr_model_weights.pth")

In [None]:
# overfitting function in order to check if there is any structural problem

def overfit_single_plate(model, loader, encode_plate, ctc_decode, device, max_epochs=200):
    """Try to memorise one plate as a sanity check of the model/loss wiring.

    The first sample of the first batch of ``loader`` is trained on
    repeatedly with Adam (lr=1e-3) and CTC loss until greedy decoding
    reproduces its text or ``max_epochs`` steps have been taken.

    NOTE: ``model`` is modified in place and not restored. When
    ``model.decoder`` has ``decoder_layers`` only the first layer is kept,
    and the ``dropout`` modules of encoder and decoder get ``p = 0``.

    Args:
        model: network mapping an image batch to (B, T, num_classes) logits;
            must expose ``encoder`` and ``decoder`` attributes.
        loader: iterable yielding ``(images, texts)`` batches.
        encode_plate: callable turning plate text into a list of class indices.
        ctc_decode: callable turning a list of predicted indices into text.
        device: device to train on.
        max_epochs: maximum number of optimisation steps.

    Returns:
        True if the plate was reproduced, False otherwise.
    """
    model.to(device)
    model.train()

    # testing reducing model complexity
    if hasattr(model.decoder, 'decoder_layers'):
        model.decoder.decoder_layers = nn.ModuleList([model.decoder.decoder_layers[0]])

    # disable dropout (testing)
    if hasattr(model.encoder, 'dropout'):
        model.encoder.dropout.p = 0.0
    if hasattr(model.decoder, 'dropout'):
        model.decoder.dropout.p = 0.0

    # only one sample
    img, text = next(iter(loader))
    img = img[0:1].to(device)
    text = text[0]
    print(f"\n True plate: {text}")

    # encode target
    targets = torch.tensor(encode_plate(text), dtype=torch.long, device=device)
    target_lengths = torch.tensor([targets.numel()], dtype=torch.long, device=device)

    # loss and optimizer
    ctc_loss = nn.CTCLoss(blank=0, zero_infinity=True)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    decoded = None
    last_loss = float('nan')

    # overfitting loop
    for epoch in range(1, max_epochs + 1):
        optimizer.zero_grad()

        logits = model(img)
        log_probs = logits.log_softmax(2).permute(1, 0, 2)  # (T, 1, num_classes)
        input_lengths = torch.full((1,), logits.size(1), dtype=torch.long, device=device)

        loss = ctc_loss(log_probs, targets, input_lengths, target_lengths)
        loss.backward()
        clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        last_loss = loss.item()

        # greedy prediction from the forward pass made before this step
        pred_tokens = log_probs.detach().argmax(dim=2)[:, 0].tolist()
        decoded = ctc_decode(pred_tokens)

        if decoded == text:
            print("overfit successful")
            return True

    # make a failed run visible instead of ending silently
    print(f"overfit failed after {max_epochs} epochs "
          f"(last loss: {last_loss:.4f}, last prediction: {decoded!r})")
    return False

# Run the single-sample overfitting check on the global model and loader.
overfit_single_plate(
    model=model,
    loader=loader,
    encode_plate=encode_plate,
    ctc_decode=ctc_decode,
    device=device,
)