In [1]:
from dataloader import TrafficSignDataset, Collator
from model.repvgg import create_RepVGG_A0
from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
import numpy as np
from tqdm import tqdm
import torch

## 1. Create Traffic Sign Dataset

In [2]:
# Build the dataset from the image folder and the CSV label file; every
# sample is requested at a 32x32 target shape.
dataset = TrafficSignDataset(
    image_dir='./Data/myData/',
    label_file='./Data/labels.csv',
    target_shape=(32, 32),
)

# The number of distinct label values determines the size of the classifier head.
distinct_labels = np.unique(dataset.labels)
nb_classes = len(distinct_labels)

print('------------------------------------------------------')
print('The number of data: {}. The number of classes: {}'.format(len(dataset), nb_classes))

Create data for class Road narrows on the right: 100%|███████| 510/510 [00:00<00:00, 1222340.02it/s]
Create data for class Speed limit (30km/h): 100%|██████████| 4920/4920 [00:00<00:00, 1110476.01it/s]
Create data for class Beware of ice/snow: 100%|██████████████| 840/840 [00:00<00:00, 1454671.91it/s]
Create data for class Turn right ahead: 100%|██████████████| 1288/1288 [00:00<00:00, 1437919.50it/s]
Create data for class Wild animals crossing: 100%|█████████| 1470/1470 [00:00<00:00, 1377178.22it/s]
Create data for class Double curve: 100%|████████████████████| 600/600 [00:00<00:00, 1392685.33it/s]
Create data for class Bumpy road: 100%|██████████████████████| 720/720 [00:00<00:00, 1182882.44it/s]
Create data for class Dangerous curve to the left: 100%|█████| 390/390 [00:00<00:00, 1232689.19it/s]
Create data for class Speed limit (70km/h): 100%|██████████| 3750/3750 [00:00<00:00, 1261318.36it/s]
Create data for class No entry: 100%|██████████████████████| 2100/2100 [00:00<00:00, 140322

------------------------------------------------------
The number of data: 73139. The number of classes: 43


## 2. Split train and validation data

In [3]:
# Split the full dataset into train and validation subsets (90% / 10%).
split_ratio = 0.9
n_train = int(len(dataset) * split_ratio)
# Give the remainder to validation so the two sizes always sum to len(dataset).
n_val = len(dataset) - n_train

# Seed the split explicitly: without a fixed generator the partition changes on
# every kernel restart, so validation metrics would not be comparable between
# runs and a reloaded checkpoint could be evaluated on samples it was trained on.
split_seed = 42
train_dataset, val_dataset = random_split(
    dataset,
    [n_train, n_val],
    generator=torch.Generator().manual_seed(split_seed),
)

In [4]:
# Report how many samples ended up in each subset as a sanity check on the split.
train_count = len(train_dataset)
val_count = len(val_dataset)
print("The number of train data: ", train_count)
print("The number of val data: ", val_count)

The number of train data:  65825
The number of val data:  7314


## 3. Define config

In [5]:
# --- Data loading ---
batch_size = 3        # samples per batch for both the train and validation loaders

# --- Logging / evaluation cadence (in optimizer steps) ---
print_every = 100     # report the running mean train loss every N steps
valid_every = 500     # run validation (and checkpoint on improvement) every N steps

# --- Optimisation ---
lr = 1e-3             # AdamW learning rate, also used as the OneCycleLR peak (max_lr)
num_iters = 30000     # total training steps; also the OneCycleLR schedule length

# --- Hardware ---
# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

## 4. Create dataloader for loading data

In [6]:
# Settings shared by both loaders: same batch size, 8 worker processes, and
# pinned host memory (pairs with the non_blocking device copies used in training).
_loader_kwargs = dict(batch_size=batch_size, num_workers=8, pin_memory=True)

# Training data is reshuffled every pass.
train_loader = DataLoader(
    train_dataset,
    collate_fn=Collator(),
    shuffle=True,
    **_loader_kwargs,
)

# Validation data keeps a fixed order; an incomplete final batch is dropped so
# every validation batch has exactly `batch_size` samples.
val_loader = DataLoader(
    val_dataset,
    collate_fn=Collator(),
    shuffle=False,
    drop_last=True,
    **_loader_kwargs,
)

## 5. Create RepVGG model

In [7]:
# Instantiate RepVGG-A0 with one output per traffic-sign class and move it
# straight onto the selected device.
repvgg_model = create_RepVGG_A0(num_classes=nb_classes).to(device)

## 6. Define a loss function and optimizer

In [8]:
# Multi-class classification loss (default reduction: mean over the batch).
criterion = nn.CrossEntropyLoss()

# AdamW over all model parameters.
optimizer = AdamW(
    repvgg_model.parameters(),
    lr=lr,
    betas=(0.9, 0.98),
    eps=1e-09,
)

# One-cycle schedule spanning the whole run: the first 10% of the steps ramp
# the learning rate up to `lr`, the remainder anneal it back down.
# The scheduler is stepped once per optimizer step.
scheduler = OneCycleLR(
    optimizer,
    max_lr=lr,
    total_steps=num_iters,
    pct_start=0.1,
)

## 7. Train the network

In [9]:
def batch_to_device(images, gts):
    """Move one batch onto the configured `device`.

    Args:
        images: batch of input images (must support ``.to``).
        gts: batch of ground-truth labels (must support ``.to``).

    Returns:
        Tuple ``(images, gts)`` located on `device`. The copies are issued with
        ``non_blocking=True`` so they can overlap with host work when the
        source memory is pinned.
    """
    return (
        images.to(device, non_blocking=True),
        gts.to(device, non_blocking=True),
    )

In [10]:
def cal_acc(outputs, labels):
    """Compute the top-1 accuracy of one batch.

    Args:
        outputs: score tensor with the class scores along dimension 1.
        labels: tensor of target class indices, one per row of `outputs`.

    Returns:
        0-d CPU tensor holding the fraction of rows whose highest-scoring
        class equals the label.
    """
    # Index of the best-scoring class for every row.
    predicted = torch.max(outputs, dim=1)[1]
    n_correct = (predicted == labels).sum().item()
    return torch.tensor(n_correct / len(predicted))

In [11]:
def validate():
    """Evaluate `repvgg_model` on `val_loader`.

    The model is switched to eval mode and gradients are disabled for the
    pass. Loss and accuracy are averaged per *sample* (each batch is weighted
    by its size), so the result stays correct even when batches differ in size
    instead of relying on every batch being equally large.

    Returns:
        Tuple ``(val_loss, val_acc)`` of Python floats: the mean loss and the
        mean top-1 accuracy over all validation samples seen. Both are
        ``nan`` when the loader yields no samples.
    """
    repvgg_model.eval()
    loss_sum = 0.0
    acc_sum = 0.0
    n_samples = 0

    try:
        with torch.no_grad():
            for images, gts in val_loader:
                images, gts = batch_to_device(images, gts)
                outputs = repvgg_model(images)

                # `criterion` and `cal_acc` both return batch means; scale them
                # by the batch size so the final division is a per-sample mean.
                n_batch = len(gts)
                loss_sum += criterion(outputs, gts).item() * n_batch
                acc_sum += cal_acc(outputs, gts).item() * n_batch
                n_samples += n_batch
    finally:
        # Always hand the model back in training mode, even if evaluation fails,
        # so an interrupted validation cannot leave training running in eval mode.
        repvgg_model.train()

    if n_samples == 0:
        return float('nan'), float('nan')

    return loss_sum / n_samples, acc_sum / n_samples

In [12]:
def train_step(batch):
    """Run a single optimisation step on one batch.

    Args:
        batch: pair ``(images, gts)`` as produced by the train loader.

    Returns:
        The training loss of this batch as a Python float.
    """
    inputs, targets = batch_to_device(*batch)

    # Clear gradients left over from the previous step.
    optimizer.zero_grad()

    # Forward pass and loss.
    logits = repvgg_model(inputs)
    step_loss = criterion(logits, targets)

    # Backward pass; cap the global gradient norm at 1 before the update.
    step_loss.backward()
    torch.nn.utils.clip_grad_norm_(repvgg_model.parameters(), 1)

    # Parameter update, then advance the learning-rate schedule by one step.
    optimizer.step()
    scheduler.step()

    return step_loss.item()

In [None]:
# Main training loop: runs exactly `num_iters` optimizer steps, cycling through
# `train_loader` as many times as needed. Every `print_every` steps the mean
# train loss of that window is printed; every `valid_every` steps the model is
# validated and the weights are saved whenever validation accuracy improves.
total_loss = 0      # running sum of train losses since the last report
best_acc = 0        # best validation accuracy seen so far
global_step = 0     # number of optimizer steps completed
weight_path = 'repvgg.pth'

data_iter = iter(train_loader)
for _ in range(num_iters):
    repvgg_model.train()

    try:
        batch = next(data_iter)
    except StopIteration:
        # The loader is exhausted (one full pass over the training data):
        # start a fresh pass. This must restart from `train_loader`; the
        # previous `self.train_gen` does not exist in this notebook and raised
        # a NameError as soon as the first epoch finished.
        data_iter = iter(train_loader)
        batch = next(data_iter)

    global_step += 1
    loss = train_step(batch)
    total_loss += loss

    if global_step % print_every == 0:
        print('step: {:06d}, train_loss: {:.4f}'.format(global_step, total_loss / print_every))
        # Reset the window so the next report covers only the next `print_every` steps.
        total_loss = 0

    if global_step % valid_every == 0:
        # validate
        val_loss, val_acc = validate()

        # Keep only the checkpoint with the best validation accuracy.
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(repvgg_model.state_dict(), weight_path)

        print("==============================================================================")
        print("val_loss: {:.4f}, val_acc: {:.4f}".format(val_loss, val_acc))
        print("==============================================================================")