In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device)

In [3]:
class EfficientNetV2LOCR(nn.Module):
    def __init__(self, len_chars, shape_inp_img):
        super(EfficientNetV2LOCR, self).__init__()
        self.len_chars = len_chars
        self.shape_inp_img = shape_inp_img

        self.base_model = timm.create_model('efficientnetv2_l',
                                            pretrained=False,
                                            num_classes=0,
                                            features_only=True).to(device)
        
        self.dropout = nn.Dropout(0.3)

        self.lstm = nn.LSTM(
            input_size=self._get_lstm_input_size(),
            hidden_size=256,
            num_layers=1,
            batch_first=True,
            bidirectional=True,
            dropout=0.3 if True else 0.0
        ).to(device)

        self.batch_norm = nn.BatchNorm1d(256).to(device)
        
        self.fc = nn.Linear(256, len_chars).to(device)  

    def _get_lstm_input_size(self):
        dummy_input = torch.zeros(1, *self.shape_inp_img).to(device)
        with torch.no_grad():
            features = self.base_model(dummy_input)
            _, C, H, W = features[-1].shape
            return C  

    def forward(self, x):
        features = self.base_model(x)
        feature = features[-1] 

        batch_size, C, H, W = feature.size()
        x = feature.permute(0, 2, 3, 1).contiguous().view(batch_size, H * W, C)

        x = self.dropout(x)

        lstm_out, _ = self.lstm(x)
        lstm_out_forward = lstm_out[:, :, :256]
        lstm_out_backward = lstm_out[:, :, 256:]
        x = (lstm_out_forward + lstm_out_backward) / 2

        x = x.permute(0, 2, 1).contiguous()
        x = self.batch_norm(x)
        x = x.permute(0, 2, 1).contiguous()

        x = self.fc(x)
        x = F.log_softmax(x, dim=2)  
        x = x.permute(1, 0, 2)  

        return x


In [4]:
import numpy as np
from torch.utils.data import Dataset, DataLoader
import albumentations as A
from albumentations.pytorch import ToTensorV2
from PIL import Image
import os

In [5]:
vocabulary = "-1234567890ABEKMHOPCTYX"

char_to_idx = {char: idx + 1 for idx, char in enumerate(vocabulary)}
char_to_idx['blank'] = 0  

idx_to_char = {idx: char for char, idx in char_to_idx.items()}

In [6]:
class PlateDataset(Dataset):
    def __init__(self, dir_path: str, transform=None):
        self.dir_path = dir_path
        self.transform = transform

        self.image_files = [f for f in os.listdir(dir_path) if f.endswith(('.png', '.jpg', '.jpeg'))]
        self.image_files.sort()

    def __len__(self):
        return len(self.image_files)
    
    def __getitem__(self, idx):
        """
        Возвращаем изображение и метку по индексу

        Args:
            idx (int): Индекс элемента.
        
        Returns:
            tuple: (изображение, метка в виде индексов)
        """

        img_name = self.image_files[idx]
        img_path = os.path.join(self.dir_path, img_name)

        image = Image.open(img_path).convert('RGB')
        image = np.array(image)

        if self.transform:
            augmented = self.transform(image=image)
            image = augmented['image']

        label = os.path.splitext(img_name)[0]

        indices = self.label_to_indices(label=label)

        return image, torch.tensor(indices, dtype=torch.long)

    def label_to_indices(self, label):
        indices = [char_to_idx[char] for char in label if char in char_to_idx]
        return indices


In [7]:
class PlatesRecognized:
    def __init__(self):
        self.correct_plates = 0
        self.all_plates = 0

    def update_state(self, y_true, y_pred):

        for true_label, pred_label in zip(y_true, y_pred):
            if true_label == pred_label:
                self.correct_plates += 1
            self.all_plates += 1

    def result(self):
        return self.correct_plates / self.all_plates if self.all_plates else 0.0

    def reset(self):
        self.correct_plates = 0
        self.all_plates = 0

In [8]:
class SymbolsRecognized:
    def __init__(self):
        self.correct_symbols = 0
        self.all_symbols = 0

    def update_state(self, y_true, y_pred):
        for true_label, pred_label in zip(y_true, y_pred):
            min_len = min(len(true_label), len(pred_label))
            self.correct_symbols += sum(1 for i in range(min_len) if true_label[i] == pred_label[i])
            self.all_symbols += len(true_label)

    def result(self):
        return self.correct_symbols / self.all_symbols if self.all_symbols else 0.0

    def reset(self):
        self.correct_symbols = 0
        self.all_symbols = 0

In [9]:
transform = A.Compose([
    A.Resize(width=200, height=100),
    A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2, p=0.5),
    A.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ToTensorV2(),
])


train_dir = r'D:\Thesis\cache\storage_manager\datasets\ds_2ff4fcefd16749ac9485120c54e61a37\train\img'

train_dataset = PlateDataset(dir_path=train_dir, transform=transform)

def collate_fn(batch):
    images, labels = zip(*batch)

    images = torch.stack(images, 0)

    targets = torch.cat(labels)

    target_lengths = torch.tensor([len(label) for label in labels], dtype=torch.long)

    return images, targets, target_lengths, labels  




batch_size = 16

train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_fn,
    pin_memory=True
)



In [10]:
val_dir = r'D:\Thesis\cache\storage_manager\datasets\ds_2ff4fcefd16749ac9485120c54e61a37\val\img'

val_dataset = PlateDataset(dir_path=val_dir, transform=transform)

val_dataloader = DataLoader(
    dataset=val_dataset,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=collate_fn,
    pin_memory=True
)

In [11]:
from clearml import Task

task = Task.init(project_name='thesis', task_name='OCR_EfficientNetV2L')

ClearML Task: created new task id=54d84115e1d04a05a0e67c7833f80fe0
ClearML results page: http://127.0.0.1:8080/projects/60b6a1102bdc4de3bdc5a20e918b0379/experiments/54d84115e1d04a05a0e67c7833f80fe0/output/log


In [12]:
num_classes = len(char_to_idx)

model = EfficientNetV2LOCR(len_chars = num_classes, shape_inp_img=(3, 100, 200)).to(device)


2024-10-27 22:30:42,033 - clearml.Task - INFO - Storing jupyter notebook directly as code




In [13]:
ctc_loss = nn.CTCLoss(blank=char_to_idx['blank'], zero_infinity=True).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

In [14]:
plate_metric = PlatesRecognized()
symbol_metric = SymbolsRecognized()

In [15]:
best_loss = float('inf')
num_epochs = 200
for epoch in range(num_epochs):
    plate_metric.reset()
    symbol_metric.reset()
    
    model.train()

    epoch_loss = 0.0
    num_batches = 0

    for batch_idx, (images, labels_flat, label_lengths, labels_str) in enumerate(train_dataloader):
        images = images.to(device)
        labels_flat = labels_flat.to(device)
        label_lengths = label_lengths.to(device)
        
        optimizer.zero_grad()
        outputs = model(images)
        batch_size = images.size(0)
        input_lengths = torch.full(size=(batch_size,), fill_value=outputs.size(0), dtype=torch.long).to(device)
        loss = ctc_loss(outputs, labels_flat, input_lengths, label_lengths)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        num_batches += 1

        with torch.no_grad():
            decoded_preds = []
            for out in outputs.permute(1, 0, 2):
                out_best = torch.argmax(out, dim=1)
                out_best = torch.unique_consecutive(out_best, dim=0)
                pred_str = ''.join([idx_to_char[idx.item()] for idx in out_best if idx.item() != char_to_idx['blank']])
                decoded_preds.append(pred_str)

            decoded_labels = []
            idx = 0
            for length in label_lengths:
                label = labels_flat[idx:idx+length]
                label_str = ''.join([idx_to_char[idx.item()] for idx in label])
                decoded_labels.append(label_str)
                idx += length.item()

            plate_metric.update_state(decoded_labels, decoded_preds)
            symbol_metric.update_state(decoded_labels, decoded_preds)


        task.logger.report_scalar("Batch Loss", "Batch", loss.item(), batch_idx + epoch * len(train_dataloader))

    avg_epoch_loss = epoch_loss / num_batches

    model.eval()
    val_loss = 0
    val_batches = 0
    val_plate_metric = PlatesRecognized()
    val_symbol_metric = SymbolsRecognized()

    with torch.no_grad():
        for val_batch_idx, (val_images, val_labels_flat, val_label_lengths, val_labels_str) in enumerate(val_dataloader):
            val_images = val_images.to(device)
            val_labels_flat = val_labels_flat.to(device)
            val_label_lengths = val_label_lengths.to(device)

            val_outputs = model(val_images)
            val_batch_size_current = val_images.size(0)
            val_input_lengths = torch.full(size=(val_batch_size_current,), fill_value=val_outputs.size(0), dtype=torch.long).to(device)
            val_loss_batch = ctc_loss(val_outputs, val_labels_flat, val_input_lengths, val_label_lengths)
            val_loss += val_loss_batch.item()
            val_batches += 1
            decoded_val_preds = []
            for out in val_outputs.permute(1, 0, 2):
                out_best = torch.argmax(out, dim=1)
                out_best = torch.unique_consecutive(out_best, dim=0)
                pred_str = ''.join([idx_to_char[idx.item()] for idx in out_best if idx.item() != char_to_idx['blank']])
                decoded_val_preds.append(pred_str)

            decoded_val_labels = []
            idx = 0
            for length in val_label_lengths:
                label = val_labels_flat[idx:idx+length]
                label_str = ''.join([idx_to_char[idx.item()] for idx in label])
                decoded_val_labels.append(label_str)
                idx += length.item()

            val_plate_metric.update_state(decoded_val_labels, decoded_val_preds)
            val_symbol_metric.update_state(decoded_val_labels, decoded_val_preds)

    avg_val_loss = val_loss / val_batches

    task.logger.report_scalar("Epoch Loss", "Train Epoch", avg_epoch_loss, epoch)
    task.logger.report_scalar("Plates Accuracy", "Train Epoch", plate_metric.result(), epoch)
    task.logger.report_scalar("Symbols Accuracy", "Train Epoch", symbol_metric.result(), epoch)

    task.logger.report_scalar("Validation Loss", "Val Epoch", avg_val_loss, epoch)
    task.logger.report_scalar("Validation Plates Accuracy", "Val Epoch", val_plate_metric.result(), epoch)
    task.logger.report_scalar("Validation Symbols Accuracy", "Val Epoch", val_symbol_metric.result(), epoch)

    print(f"Epoch {epoch + 1}/{num_epochs}, Train Loss: {avg_epoch_loss:.4f}, Train Plates Acc: {plate_metric.result():.4f}, Train Symbols Acc: {symbol_metric.result():.4f}")
    print(f"Validation Loss: {avg_val_loss:.4f}, Validation Plates Acc: {val_plate_metric.result():.4f}, Validation Symbols Acc: {val_symbol_metric.result():.4f}")

    if avg_val_loss < best_loss:
        best_loss = avg_val_loss
        torch.save(model.state_dict(), 'best.pt')
        print(f"Лучшие веса сохранены на эпохе {epoch + 1} с валидационными потерями {avg_val_loss:.4f}")






Epoch 1/200, Train Loss: 2.8821, Train Plates Acc: 0.0000, Train Symbols Acc: 0.0820
Validation Loss: 2.5461, Validation Plates Acc: 0.0000, Validation Symbols Acc: 0.1071
2024-10-27 22:54:45,375 - clearml.frameworks - INFO - Found existing registered model id=b9b2822ae00b4451af728030a9cd934c [d:\Programs\Code\efficient\best.pt] reusing it.
Лучшие веса сохранены на эпохе 1 с валидационными потерями 2.5461
Epoch 2/200, Train Loss: 1.6547, Train Plates Acc: 0.0551, Train Symbols Acc: 0.3950
Validation Loss: 0.7069, Validation Plates Acc: 0.2819, Validation Symbols Acc: 0.7538
Лучшие веса сохранены на эпохе 2 с валидационными потерями 0.7069
Epoch 3/200, Train Loss: 0.5086, Train Plates Acc: 0.4596, Train Symbols Acc: 0.8110
Validation Loss: 0.4310, Validation Plates Acc: 0.5802, Validation Symbols Acc: 0.8664
Лучшие веса сохранены на эпохе 3 с валидационными потерями 0.4310
Epoch 4/200, Train Loss: 0.3629, Train Plates Acc: 0.6494, Train Symbols Acc: 0.8737
Validation Loss: 0.6177, Valid

KeyboardInterrupt: 