## Train the found architecture from scratch

Sample the best architecture found by the search from the supernet and train it from scratch on CIFAR-10.

In [1]:
from typing import Tuple, Union, List

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from thop import profile
from torch.utils.data import DataLoader
from tqdm import tqdm

from cifar10 import get_train_transform, get_val_transform
from resnet import resnet18

from supernet import supernet18

import time


# cuDNN settings: enable the kernel autotuner and do not force deterministic
# algorithms, so runs are not guaranteed to be bit-for-bit reproducible.
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = False

# Transform pipelines are provided by the local `cifar10` module.
train_transform, val_transform = get_train_transform(), get_val_transform()

# Change this value if needed.
batch_size = 512

# CIFAR-10 splits; both are stored under ./data and downloaded on first use.
train_set = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=train_transform,
)
test_set = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=val_transform,
)

# Training loader: reshuffled every epoch; the trailing partial batch is
# dropped so every training step sees exactly `batch_size` samples.
train_dataloader = DataLoader(
    train_set, batch_size=batch_size, shuffle=True, num_workers=4, drop_last=True,
)
# Evaluation loader: fixed order, keeps the trailing partial batch so that
# every test sample is evaluated.
test_dataloader = DataLoader(
    test_set, batch_size=batch_size, shuffle=False, num_workers=4, drop_last=False,
)

Files already downloaded and verified
Files already downloaded and verified


In [2]:
def train_one_epoch(
        model: nn.Module,
        criterion: nn.Module,
        dataloader: DataLoader,
        optimizer: optim.Optimizer,
        scheduler,
        device: torch.device,
        epoch: int,
) -> Tuple[float, float]:
    """Train ``model`` for one pass over ``dataloader``.

    The learning-rate scheduler is stepped after every optimizer step (i.e.
    once per batch, not once per epoch), so it must be configured in units of
    batches.

    Args:
        model: Network to optimize; switched to training mode.
        criterion: Loss module called as ``criterion(logits, labels)``. It is
            assumed to return the mean loss over the batch (the default
            ``reduction='mean'``); the epoch loss is weighted accordingly.
        dataloader: Iterable of ``(inputs, labels)`` batches that supports ``len()``.
        optimizer: Optimizer that updates the parameters of ``model``.
        scheduler: Learning-rate scheduler exposing ``step()`` and ``get_last_lr()``.
        device: Device the batches are moved to before the forward pass.
        epoch: Epoch index, used only in the progress-bar description.

    Returns:
        ``(mean_loss, accuracy)`` where both values are averaged per sample
        over all batches seen in this epoch.

    Raises:
        ZeroDivisionError: If ``dataloader`` yields no samples.
    """
    model.train()

    total_loss = 0.0  # running sum of per-sample losses
    total_correct = 0
    total_samples = 0

    wrapped_dataloader = tqdm(dataloader, total=len(dataloader))
    for inputs, labels in wrapped_dataloader:
        inputs = inputs.to(device=device)
        labels = labels.to(device=device)

        optimizer.zero_grad()

        logits = model(inputs)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()
        scheduler.step()

        # Metrics only: detach so no autograd graph is extended.
        predicted_labels = logits.detach().argmax(dim=1)
        n_in_batch = labels.shape[0]
        # Weight the batch-mean loss by the batch size so that the epoch loss
        # is a true per-sample mean even when batches differ in size.
        total_loss += loss.item() * n_in_batch
        total_correct += (predicted_labels == labels).sum().item()
        total_samples += n_in_batch

        wrapped_dataloader.set_description(
            f'(train) Epoch={epoch}, lr={scheduler.get_last_lr()[0]:.4f} loss={total_loss / total_samples:.3f}'
        )

    return total_loss / total_samples, total_correct / total_samples

In [3]:
@torch.no_grad()
def validate_one_epoch(
        model: nn.Module,
        criterion: nn.Module,
        dataloader: DataLoader,
        device: torch.device,
        epoch: int,
) -> Tuple[float, float]:
    """Evaluate ``model`` on every batch of ``dataloader`` without gradients.

    Args:
        model: Network to evaluate; switched to evaluation mode.
        criterion: Loss module called as ``criterion(logits, labels)``. It is
            assumed to return the mean loss over the batch (the default
            ``reduction='mean'``); the overall loss is weighted accordingly.
        dataloader: Iterable of ``(inputs, labels)`` batches.
        device: Device the batches are moved to before the forward pass.
        epoch: Epoch index. Currently unused; kept so the signature mirrors
            ``train_one_epoch``.

    Returns:
        ``(mean_loss, accuracy)`` where both values are averaged per sample,
        so a smaller final batch is not over-weighted.

    Raises:
        ZeroDivisionError: If ``dataloader`` yields no samples.
    """
    model.eval()

    total_loss = 0.0  # running sum of per-sample losses
    total_correct = 0
    total_samples = 0

    for inputs, labels in dataloader:
        inputs = inputs.to(device=device)
        labels = labels.to(device=device)

        logits = model(inputs)
        loss = criterion(logits, labels)
        predicted_labels = logits.argmax(dim=1)

        n_in_batch = labels.shape[0]
        # The criterion returns a batch mean; multiply by the batch size so the
        # final division yields the mean over samples, not over batches.
        total_loss += loss.item() * n_in_batch
        total_correct += (predicted_labels == labels).sum().item()
        total_samples += n_in_batch

    return total_loss / total_samples, total_correct / total_samples


In [4]:
# Architecture to train; handed to `supernet.sample` below.
# NOTE(review): presumably one choice index per searchable layer — confirm
# against the `supernet` module.
best_architecture = [2, 0, 0, 0, 0, 0, 1, 0]
# Use the first CUDA GPU when one is available, otherwise fall back to the CPU
# so the cell still runs on machines without a GPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

channel_multipliers = [0.5, 1.0, 2.0]
supernet = supernet18(num_classes=10, zero_init_residual=True, channel_multipliers=channel_multipliers)

supernet.sample(best_architecture)
supernet.to(device=device)

# Report the cost of the sampled sub-network on a single 3x32x32 input.
macs, params = profile(supernet, inputs=(torch.zeros(1, 3, 32, 32, device=device),))
print(f'Number of macs: {macs / 1e6:.2f}M, number of parameters: {params / 1e6:.2f}M')

criterion = nn.CrossEntropyLoss()

lr = 0.25
weight_decay = 5e-4
momentum = 0.9
n_epochs = 20  # Longer training gives better results, but let's keep baseline model epochs to 20.

optimizer = optim.SGD(supernet.parameters(), lr=lr, weight_decay=weight_decay, momentum=momentum)
# `train_one_epoch` steps the scheduler once per batch, so the cosine period is
# expressed in batches: batches per epoch times number of epochs.
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=len(train_dataloader) * n_epochs)

for epoch in range(n_epochs):
    print(f'Epoch: {epoch}')
    loss, accuracy = train_one_epoch(supernet, criterion, train_dataloader, optimizer, scheduler, device, epoch)
    print(f'train_loss={loss:.4f}, train_accuracy={accuracy:.3%}')
    loss, accuracy = validate_one_epoch(supernet, criterion, test_dataloader, device, epoch)
    print(f'test_loss={loss:.4f}, test_accuracy={accuracy:.3%}')


[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.batchnorm.BatchNorm2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.activation.ReLU'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.pooling.MaxPool2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
[INFO] Register count_adap_avgpool() for <class 'torch.nn.modules.pooling.AdaptiveAvgPool2d'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
Number of macs: 28.96M, number of parameters: 7.57M
Epoch: 0


(train) Epoch=0, lr=0.2485 loss=2.022: 100%|██████████| 97/97 [00:02<00:00, 33.27it/s]


train_loss=2.0217, train_accuracy=27.130%
test_loss=1.6425, test_accuracy=39.060%
Epoch: 1


(train) Epoch=1, lr=0.2439 loss=1.458: 100%|██████████| 97/97 [00:02<00:00, 45.82it/s]


train_loss=1.4576, train_accuracy=46.382%
test_loss=1.4224, test_accuracy=49.520%
Epoch: 2


(train) Epoch=2, lr=0.2364 loss=1.240: 100%|██████████| 97/97 [00:02<00:00, 46.26it/s]

train_loss=1.2400, train_accuracy=55.344%





test_loss=1.2361, test_accuracy=56.910%
Epoch: 3


(train) Epoch=3, lr=0.2261 loss=1.104: 100%|██████████| 97/97 [00:02<00:00, 43.87it/s]

train_loss=1.1038, train_accuracy=60.287%





test_loss=1.0349, test_accuracy=62.990%
Epoch: 4


(train) Epoch=4, lr=0.2134 loss=1.006: 100%|██████████| 97/97 [00:02<00:00, 46.28it/s]

train_loss=1.0057, train_accuracy=64.133%





test_loss=1.0136, test_accuracy=64.660%
Epoch: 5


(train) Epoch=5, lr=0.1985 loss=0.936: 100%|██████████| 97/97 [00:02<00:00, 45.46it/s]

train_loss=0.9363, train_accuracy=66.904%





test_loss=0.9572, test_accuracy=67.000%
Epoch: 6


(train) Epoch=6, lr=0.1817 loss=0.874: 100%|██████████| 97/97 [00:02<00:00, 44.39it/s]

train_loss=0.8741, train_accuracy=69.036%





test_loss=0.9757, test_accuracy=66.120%
Epoch: 7


(train) Epoch=7, lr=0.1636 loss=0.823: 100%|██████████| 97/97 [00:02<00:00, 45.17it/s]

train_loss=0.8231, train_accuracy=71.253%





test_loss=0.9083, test_accuracy=68.640%
Epoch: 8


(train) Epoch=8, lr=0.1446 loss=0.783: 100%|██████████| 97/97 [00:02<00:00, 45.22it/s]

train_loss=0.7833, train_accuracy=72.342%





test_loss=0.7720, test_accuracy=73.020%
Epoch: 9


(train) Epoch=9, lr=0.1250 loss=0.733: 100%|██████████| 97/97 [00:02<00:00, 45.16it/s]

train_loss=0.7331, train_accuracy=74.034%





test_loss=0.8168, test_accuracy=71.950%
Epoch: 10


(train) Epoch=10, lr=0.1054 loss=0.699: 100%|██████████| 97/97 [00:02<00:00, 45.38it/s]

train_loss=0.6990, train_accuracy=75.554%





test_loss=0.6933, test_accuracy=75.770%
Epoch: 11


(train) Epoch=11, lr=0.0864 loss=0.661: 100%|██████████| 97/97 [00:02<00:00, 45.31it/s]

train_loss=0.6614, train_accuracy=76.707%





test_loss=0.6911, test_accuracy=76.170%
Epoch: 12


(train) Epoch=12, lr=0.0683 loss=0.620: 100%|██████████| 97/97 [00:02<00:00, 43.97it/s]

train_loss=0.6203, train_accuracy=78.127%





test_loss=0.6533, test_accuracy=77.290%
Epoch: 13


(train) Epoch=13, lr=0.0515 loss=0.583: 100%|██████████| 97/97 [00:02<00:00, 43.82it/s]

train_loss=0.5835, train_accuracy=79.822%





test_loss=0.6503, test_accuracy=77.590%
Epoch: 14


(train) Epoch=14, lr=0.0366 loss=0.544: 100%|██████████| 97/97 [00:02<00:00, 39.57it/s]

train_loss=0.5438, train_accuracy=80.825%





test_loss=0.5959, test_accuracy=79.630%
Epoch: 15


(train) Epoch=15, lr=0.0239 loss=0.504: 100%|██████████| 97/97 [00:02<00:00, 38.75it/s]

train_loss=0.5045, train_accuracy=82.243%





test_loss=0.5505, test_accuracy=81.150%
Epoch: 16


(train) Epoch=16, lr=0.0136 loss=0.482: 100%|██████████| 97/97 [00:02<00:00, 44.36it/s]

train_loss=0.4819, train_accuracy=83.046%





test_loss=0.5359, test_accuracy=81.630%
Epoch: 17


(train) Epoch=17, lr=0.0061 loss=0.450: 100%|██████████| 97/97 [00:02<00:00, 43.31it/s]

train_loss=0.4503, train_accuracy=84.264%





test_loss=0.5220, test_accuracy=81.890%
Epoch: 18


(train) Epoch=18, lr=0.0015 loss=0.436: 100%|██████████| 97/97 [00:02<00:00, 44.60it/s]

train_loss=0.4355, train_accuracy=84.796%





test_loss=0.5116, test_accuracy=82.570%
Epoch: 19


(train) Epoch=19, lr=0.0000 loss=0.426: 100%|██████████| 97/97 [00:02<00:00, 39.20it/s]

train_loss=0.4256, train_accuracy=85.116%





test_loss=0.5105, test_accuracy=82.380%
