In [1]:
from dataloader import TrafficSignDataset, Collator
from model.repvgg import create_RepVGG_A0
from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
import numpy as np
from tqdm import tqdm
import torch

## 1. Create Traffic Sign Dataset

In [2]:
# Build the traffic-sign dataset from the image folder and the label CSV.
# target_shape is forwarded to TrafficSignDataset as-is
# (presumably the per-image resize target -- confirm in dataloader.py).
dataset = TrafficSignDataset(
    image_dir='./Data/myData/',
    label_file='./Data/labels.csv',
    target_shape=(32, 32),
)

# The number of classes is the number of distinct label values in the dataset.
nb_classes = np.unique(dataset.labels).size

print('------------------------------------------------------')
print('The number of data: {}. The number of classes: {}'.format(len(dataset), nb_classes))

Create data for class Pedestrians: 100%|██████████████████████| 450/450 [00:00<00:00, 977946.53it/s]
Create data for class Road work: 100%|█████████████████████| 2850/2850 [00:00<00:00, 1158310.70it/s]
Create data for class Ahead only: 100%|████████████████████| 2280/2280 [00:00<00:00, 1221173.94it/s]
Create data for class Children crossing: 100%|█████████████| 1020/1020 [00:00<00:00, 1132395.47it/s]
Create data for class End of all speed and passing limits: 100%|█| 450/450 [00:00<00:00, 1122138.41i
Create data for class Road narrows on the right: 100%|███████| 510/510 [00:00<00:00, 1097534.65it/s]
Create data for class Double curve: 100%|████████████████████| 600/600 [00:00<00:00, 1081006.19it/s]
Create data for class Speed limit (100km/h): 100%|█████████| 2730/2730 [00:00<00:00, 1155443.99it/s]
Create data for class Keep right: 100%|████████████████████| 3930/3930 [00:00<00:00, 1180943.88it/s]
Create data for class Wild animals crossing: 100%|█████████| 1470/1470 [00:00<00:00, 116113

------------------------------------------------------
The number of data: 73139. The number of classes: 43





## 2. Split train and validation data

In [3]:
# Split the full dataset into train and validation subsets.
# split_ratio is the fraction of samples assigned to the training subset;
# the remainder (after integer truncation) goes to validation.
split_ratio = 0.9
# Fixed seed so the held-out set is the same after every kernel restart;
# without it, validation metrics and the saved "best" checkpoint are measured
# on a different random subset each run and cannot be compared.
split_seed = 42
n_train = int(len(dataset) * split_ratio)
n_val = len(dataset) - n_train
train_dataset, val_dataset = random_split(
    dataset,
    [n_train, n_val],
    generator=torch.Generator().manual_seed(split_seed),
)

In [4]:
# Report how many samples ended up in each split.
train_size, val_size = len(train_dataset), len(val_dataset)
print("The number of train data: ", train_size)
print("The number of val data: ", val_size)

The number of train data:  65825
The number of val data:  7314


## 3. Define config

In [5]:
# --- Hardware ---------------------------------------------------------------
# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# --- Optimisation -----------------------------------------------------------
batch_size = 128     # samples per mini-batch (used by both loaders)
lr = 1e-3            # base / peak learning rate handed to the optimizer and scheduler
num_iters = 60000    # total number of training iterations (optimizer steps)

# --- Logging / evaluation cadence (in iterations) ---------------------------
print_every = 500    # average train loss is printed every this many steps
valid_every = 2000   # validation runs every this many steps

## 4. Create dataloader for loading data

In [6]:
# Training loader: reshuffled on every pass over the data.
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    collate_fn=Collator(),
    shuffle=True,
    num_workers=8,
    pin_memory=True,
)

# Validation loader: fixed order, and drop_last=False so that EVERY validation
# sample is scored. With drop_last=True the trailing len(val_dataset) % batch_size
# samples were silently excluded from the reported loss/accuracy.
val_loader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    collate_fn=Collator(),
    shuffle=False,
    num_workers=8,
    pin_memory=True,
    drop_last=False,
)

## 5. Create RepVGG model

In [7]:
# Instantiate RepVGG-A0 with one output per traffic-sign class and place it on
# the selected device in a single expression.
repvgg_model = create_RepVGG_A0(num_classes=nb_classes).to(device)

## 6. Define a loss function and optimizer

In [8]:
# Classification loss over the raw model outputs.
criterion = nn.CrossEntropyLoss()

# AdamW over all model parameters.
optimizer = AdamW(
    repvgg_model.parameters(),
    lr=lr,
    betas=(0.9, 0.98),
    eps=1e-09,
)

# One-cycle schedule spanning the whole run: it is stepped once per training
# iteration, so total_steps must equal the number of iterations. pct_start is
# the fraction of the cycle spent increasing the learning rate towards max_lr.
scheduler = OneCycleLR(
    optimizer,
    max_lr=lr,
    total_steps=num_iters,
    pct_start=0.1,
)

## 7. Train the network

In [9]:
def batch_to_device(images, gts):
    """Move a batch onto the module-level ``device``.

    Args:
        images: tensor of inputs for the batch.
        gts: tensor of ground-truth labels for the batch.

    Returns:
        A ``(images, gts)`` tuple of the two tensors after ``.to(device)``.
        Both copies are requested with ``non_blocking=True``.
    """
    return (
        images.to(device, non_blocking=True),
        gts.to(device, non_blocking=True),
    )

In [10]:
def cal_acc(outputs, labels):
    """Compute the fraction of correct top-1 predictions in a batch.

    Args:
        outputs: 2-D tensor of per-class scores; the predicted class of each
            row is the index of its largest value along dim 1.
        labels: tensor of target class indices, one per row of ``outputs``.

    Returns:
        A 0-dim CPU tensor holding ``correct / number_of_rows``.
    """
    predicted = outputs.argmax(dim=1)
    n_correct = (predicted == labels).sum().item()
    return torch.tensor(n_correct / len(predicted))

In [11]:
def validate():
    """Evaluate ``repvgg_model`` on every batch of ``val_loader``.

    The model is switched to eval mode and gradients are disabled for the
    duration of the pass; it is always switched back to train mode before
    returning, even if evaluation raises.

    Loss and accuracy are averaged per *sample* (each batch is weighted by its
    size), so the result stays correct when the final batch is smaller than
    the others. This relies on ``criterion`` and ``cal_acc`` each returning a
    per-batch mean.

    Returns:
        tuple[float, float]: ``(val_loss, val_acc)``. Both are ``nan`` if the
        loader yields no samples.
    """
    loss_sum = 0.0
    acc_sum = 0.0
    n_samples = 0

    repvgg_model.eval()
    try:
        with torch.no_grad():
            for images, gts in val_loader:
                images, gts = batch_to_device(images, gts)
                outputs = repvgg_model(images)

                # Batch size, measured the same way cal_acc does (len over dim 0).
                n_batch = len(outputs)
                # Turn the per-batch means back into per-batch sums.
                loss_sum += criterion(outputs, gts).item() * n_batch
                acc_sum += float(cal_acc(outputs, gts)) * n_batch
                n_samples += n_batch
    finally:
        # Callers expect the model to be back in training mode afterwards.
        repvgg_model.train()

    if n_samples == 0:
        return float('nan'), float('nan')

    return loss_sum / n_samples, acc_sum / n_samples

In [12]:
def train_step(batch):
    """Run a single optimisation step on one batch.

    Args:
        batch: an ``(images, gts)`` pair as produced by the training loader.

    Returns:
        The batch loss as a Python number (``loss.item()``).
    """
    inputs, targets = batch
    inputs, targets = batch_to_device(inputs, targets)

    # Clear gradients left over from the previous step.
    optimizer.zero_grad()

    # Forward pass and loss.
    logits = repvgg_model(inputs)
    batch_loss = criterion(logits, targets)

    # Backward pass; clip the global gradient norm to 1 before the update.
    batch_loss.backward()
    torch.nn.utils.clip_grad_norm_(repvgg_model.parameters(), 1)

    # Parameter update first, then advance the learning-rate schedule.
    optimizer.step()
    scheduler.step()

    return batch_loss.item()

In [13]:
# Iteration-based training loop: runs exactly num_iters optimisation steps,
# restarting the training loader whenever it is exhausted.
weight_path = 'repvgg.pth'   # where the best checkpoint (by val accuracy) is written
total_loss = 0               # train loss accumulated since the last progress print
best_acc = 0                 # best validation accuracy seen so far
global_step = 0

data_iter = iter(train_loader)
while global_step < num_iters:
    repvgg_model.train()

    # Fetch the next batch; a None sentinel means the current pass is finished,
    # so begin a new pass over the loader.
    batch = next(data_iter, None)
    if batch is None:
        data_iter = iter(train_loader)
        batch = next(data_iter)

    global_step += 1
    loss = train_step(batch)
    total_loss += loss

    # Periodic progress report: mean train loss over the last print_every steps.
    if global_step % print_every == 0:
        print('step: {:06d}, train_loss: {:.4f}'.format(global_step, total_loss / print_every))
        total_loss = 0

    # Periodic validation; skip the rest of the body on non-validation steps.
    if global_step % valid_every != 0:
        continue

    val_loss, val_acc = validate()

    # Keep only the weights that achieve the best validation accuracy so far.
    if val_acc > best_acc:
        best_acc = val_acc
        torch.save(repvgg_model.state_dict(), weight_path)

    print("==============================================================================")
    print("val_loss: {:.4f}, val_acc: {:.4f}".format(val_loss, val_acc))
    print("==============================================================================")

step: 000500, train_loss: 2.7534
step: 001000, train_loss: 0.9924
step: 001500, train_loss: 0.3827
step: 002000, train_loss: 0.2445
val_loss: 0.2385, val_acc: 0.9328
step: 002500, train_loss: 0.1895
step: 003000, train_loss: 0.1556
step: 003500, train_loss: 0.1186
step: 004000, train_loss: 0.1007
val_loss: 0.0902, val_acc: 0.9762
step: 004500, train_loss: 0.0748
step: 005000, train_loss: 0.0559
step: 005500, train_loss: 0.0512
step: 006000, train_loss: 0.0370
val_loss: 0.0529, val_acc: 0.9841
step: 006500, train_loss: 0.0295
step: 007000, train_loss: 0.0248
step: 007500, train_loss: 0.0251
step: 008000, train_loss: 0.0262
val_loss: 0.0269, val_acc: 0.9930
step: 008500, train_loss: 0.0157
step: 009000, train_loss: 0.0180
step: 009500, train_loss: 0.0137
step: 010000, train_loss: 0.0136
val_loss: 0.0221, val_acc: 0.9955
step: 010500, train_loss: 0.0120
step: 011000, train_loss: 0.0104
step: 011500, train_loss: 0.0094
step: 012000, train_loss: 0.0107
val_loss: 0.1320, val_acc: 0.9949
step

step: 052000, train_loss: 0.0000
val_loss: 37.5417, val_acc: 0.9963
step: 052500, train_loss: 0.0000
step: 053000, train_loss: 0.0000
step: 053500, train_loss: 0.0001
step: 054000, train_loss: 0.0000
val_loss: 28.9426, val_acc: 0.9962
step: 054500, train_loss: 0.0000
step: 055000, train_loss: 0.0000
step: 055500, train_loss: 0.0000
step: 056000, train_loss: 0.0000
val_loss: 37.2928, val_acc: 0.9962
step: 056500, train_loss: 0.0000
step: 057000, train_loss: 0.0000
step: 057500, train_loss: 0.0000
step: 058000, train_loss: 0.0000
val_loss: 641.1455, val_acc: 0.9952
step: 058500, train_loss: 0.0000
step: 059000, train_loss: 0.0000
step: 059500, train_loss: 0.0000
step: 060000, train_loss: 0.0000
val_loss: 45.0168, val_acc: 0.9959
