## RepVGG
> I tried to implement RepVGG model and see its results on cifar10 dataset.

> We've reached accuracy of 87 percent on test dataset with just 15 epochs.

In [1]:
import torch
import torchvision
import numpy as np
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
from tqdm import tqdm

In [2]:
train_on_gpu = torch.cuda.is_available()

if not train_on_gpu:
    print('CUDA is not available.  Training on CPU ...')
else:
    print('CUDA is available!  Training on GPU ...')
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

CUDA is available!  Training on GPU ...
cuda:0


In [3]:
transform = transforms.Compose(
        [
            transforms.Pad(4),
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(32),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ]
    )
test_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

train_dataset = torchvision.datasets.CIFAR10(
    root=".", train=True, transform=transform, download=True
)
test_dataset = torchvision.datasets.CIFAR10(
    root=".", train=False, transform=test_transform
)
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset, batch_size=100, shuffle=True
)
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset, batch_size=100, shuffle=False
)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting ./cifar-10-python.tar.gz to .


In [23]:
def conv_batch_norm(in_channels, out_channels, kernel_size, stride, padding):
    result = torch.nn.Sequential()
    result.add_module('conv', torch.nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
          kernel_size=kernel_size, stride=stride, padding=padding, groups=1, bias=False))
    result.add_module('batch_norm', torch.nn.BatchNorm2d(num_features=out_channels))
    return result

In [24]:
class RepVGGBlockModule(torch.nn.Module):

    def __init__(self, in_channels, out_channels, kernel_size,
                 stride=1, padding=0, dilation=1, groups=1, padding_mode='zeros'):
        super(RepVGGBlockModule, self).__init__()
        self.groups = groups
        self.in_channels = in_channels
        padding_11 = padding - kernel_size // 2

        self.nonlinearity = torch.nn.ReLU()
        self.se = torch.nn.Identity()
        self.rbr_identity = torch.nn.BatchNorm2d(num_features=in_channels) if out_channels == in_channels and stride == 1 else None
        self.rbr_dense = conv_batch_norm(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
        self.rbr_1x1 = conv_batch_norm(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=stride, padding=padding_11)
        

    def forward(self, inputs):
        if self.rbr_identity is None:
            id_out = 0
        else:
            id_out = self.rbr_identity(inputs)

        return self.nonlinearity(self.se(self.rbr_dense(inputs) + self.rbr_1x1(inputs) + id_out))

In [25]:
class RepVGGModule(torch.nn.Module):
    def __init__(self, num_classes=10):
        super(RepVGGModule, self).__init__()
        self.width_multiplier=[2.5, 2.5, 2.5, 5]
        self.num_blocks = [4, 6, 16, 1]
        self.override_groups_map = dict()
        
        self.in_planes = min(64, int(64 * self.width_multiplier[0]))
        self.level_0 = RepVGGBlockModule(in_channels=3, out_channels=self.in_planes, kernel_size=3, stride=2, padding=1)
        self.cur_layer_idx = 1
        self.level_1 = self._level(int(64 * self.width_multiplier[0]), self.num_blocks[0], stride=2)
        self.level_2 = self._level(int(128 * self.width_multiplier[1]), self.num_blocks[1], stride=2)
        self.level_3 = self._level(int(256 * self.width_multiplier[2]), self.num_blocks[2], stride=2)
        self.level_4 = self._level(int(512 * self.width_multiplier[3]), self.num_blocks[3], stride=2)
        self.gap = torch.nn.AdaptiveAvgPool2d(output_size=1)
        self.linear = torch.nn.Linear(int(512 * self.width_multiplier[3]), num_classes)

    def _level(self, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        blocks = []
        for stride in strides:
            cur_groups = self.override_groups_map.get(self.cur_layer_idx, 1)
            blocks.append(RepVGGBlockModule(in_channels=self.in_planes, out_channels=planes, kernel_size=3,
                                      stride=stride, padding=1, groups=cur_groups))
            self.in_planes = planes
            self.cur_layer_idx += 1
        return torch.nn.ModuleList(blocks)

    def forward(self, x):
        out = self.level_0(x)
        for l in (self.level_1, self.level_2, self.level_3, self.level_4):
            for block in l:
                out = block(out)
        out = self.gap(out)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out

In [26]:
model = RepVGGModule(num_classes=10).to(device)


In [27]:
batch_size = 64
epochs = 15
lr = 3e-4
gamma= 0.7
seed = 42

In [28]:
criterion = torch.nn.CrossEntropyLoss()

optimizer = optim.Adam(model.parameters(), lr = lr)

scheduler = StepLR(optimizer, step_size=1, gamma=gamma)


In [None]:
for epoch in range(epochs):
    epoch_loss = 0.0
    epoch_accuracy = 0.0
    
    for data, label in tqdm(train_loader):
        data = data.to(device)
        label = label.to(device)
        
        output = model(data)
        
        label = torch.nn.functional.one_hot(label, num_classes = 10)
        label = label.squeeze_()
        label = torch.argmax(label, axis=1)
        
        label = label.type_as(output)
        
        loss = criterion(output, label.long())
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        acc = (output.argmax(dim=1) == label).float().mean()
        epoch_accuracy += acc / len(train_loader)
        epoch_loss += loss / len(train_loader)
    with torch.no_grad():
        epoch_val_accuracy = 0
        epoch_val_loss = 0        
        
        for data, label in test_loader:
            data = data.to(device)
            
            label = label.to(device)
            label = label.type_as(output)
            
            val_output = model(data)
            val_loss = criterion(val_output, label.long())
            
            accc = (val_output.argmax(dim = 1) == label).float().mean()
            epoch_val_accuracy += acc / len(test_loader)
            epoch_val_loss += val_loss / len(test_loader)

    print(
        f"Epoch: {epoch+1} - loss: {epoch_loss:.4f} - acc : {epoch_accuracy:.4f} - val_loss : {epoch_val_loss:.4f} - val_acc: {epoch_val_accuracy: .4f}\n"
    )

100%|██████████| 500/500 [01:26<00:00,  5.79it/s]


Epoch: 1 - loss: 1.7129 - acc : 0.3670 - val_loss : 1.3872 - val_acc:  0.4700



100%|██████████| 500/500 [01:23<00:00,  5.96it/s]


Epoch: 2 - loss: 1.3123 - acc : 0.5301 - val_loss : 1.1890 - val_acc:  0.5500



100%|██████████| 500/500 [01:22<00:00,  6.03it/s]


Epoch: 3 - loss: 1.0993 - acc : 0.6142 - val_loss : 1.0271 - val_acc:  0.6300



100%|██████████| 500/500 [01:22<00:00,  6.04it/s]


Epoch: 4 - loss: 0.9601 - acc : 0.6674 - val_loss : 0.8992 - val_acc:  0.6900



100%|██████████| 500/500 [01:22<00:00,  6.06it/s]


Epoch: 5 - loss: 0.8286 - acc : 0.7137 - val_loss : 0.8041 - val_acc:  0.6900



100%|██████████| 500/500 [01:22<00:00,  6.04it/s]


Epoch: 6 - loss: 0.7533 - acc : 0.7408 - val_loss : 0.7633 - val_acc:  0.7300



100%|██████████| 500/500 [01:22<00:00,  6.05it/s]


Epoch: 7 - loss: 0.6860 - acc : 0.7647 - val_loss : 0.7209 - val_acc:  0.8700



100%|██████████| 500/500 [01:23<00:00,  5.99it/s]


Epoch: 8 - loss: 0.6560 - acc : 0.7760 - val_loss : 0.6623 - val_acc:  0.8200



100%|██████████| 500/500 [01:23<00:00,  5.98it/s]


Epoch: 9 - loss: 0.5983 - acc : 0.7974 - val_loss : 0.6576 - val_acc:  0.7700



100%|██████████| 500/500 [01:22<00:00,  6.08it/s]


Epoch: 10 - loss: 0.5756 - acc : 0.8032 - val_loss : 0.6247 - val_acc:  0.8100



100%|██████████| 500/500 [01:22<00:00,  6.05it/s]


Epoch: 11 - loss: 0.5503 - acc : 0.8134 - val_loss : 0.6122 - val_acc:  0.8600



100%|██████████| 500/500 [01:22<00:00,  6.03it/s]


Epoch: 12 - loss: 0.5128 - acc : 0.8243 - val_loss : 0.6336 - val_acc:  0.7600



100%|██████████| 500/500 [01:22<00:00,  6.02it/s]


Epoch: 13 - loss: 0.4876 - acc : 0.8314 - val_loss : 0.6200 - val_acc:  0.8200



100%|██████████| 500/500 [01:24<00:00,  5.95it/s]


Epoch: 14 - loss: 0.4999 - acc : 0.8310 - val_loss : 0.5777 - val_acc:  0.8900



100%|██████████| 500/500 [01:23<00:00,  5.97it/s]


Epoch: 15 - loss: 0.4477 - acc : 0.8469 - val_loss : 0.5570 - val_acc:  0.8700

