## Neural Architecture Search - Supernet training

**Before running code in this section, you need to finish supernet implementation.**

Please, go to `supernet.py` file and inspect the current implementation of SearchBlock and Supernet classes.
Pay attention to the TODOs. You need to implement all of them.

Supernet and BasicBlock classes are modified versions of ResNet and BasicBlock classes from `resnet.py`.

    Tip: to understand how the Supernet is constructed, compare the implementation of Supernet and ResNet classes. You should probably use diff tool in your IDE or something.

Task: briefly describe the differences made to construct supernet.

Differences:
1. Bottleneck block is replaced by SearchBlock
2. Supernet is designed for search the optimal SearchBlock structure per layer.

In [22]:
from typing import Tuple, Union, List

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from thop import profile
from torch.utils.data import DataLoader
from tqdm import tqdm

from cifar10 import get_train_transform, get_val_transform
from resnet import resnet18

from supernet import supernet18

import time

In [23]:
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = False

In [24]:
channel_multipliers = [0.5, 1.0, 2.0]

In [25]:
device = torch.device('cuda:0')

supernet = supernet18(num_classes=10, zero_init_residual=True, channel_multipliers=channel_multipliers)
supernet.to(device=device)

Supernet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): SearchBlock(
      (ops): ModuleList(
        (0): BasicBlock(
          (conv1): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (1): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05,

In [26]:
train_transform = get_train_transform()
val_transform = get_val_transform()

batch_size = 512

train_set = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=train_transform,
)
test_set = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=val_transform,
)

train_dataloader = DataLoader(
    train_set,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
    drop_last=True,
)
test_dataloader = DataLoader(
    test_set,
    batch_size=batch_size,
    shuffle=False,
    num_workers=4,
    drop_last=False,
)

Files already downloaded and verified
Files already downloaded and verified


### Select hyperparameters

### Build optimizer and scheduler

In [27]:
criterion = nn.CrossEntropyLoss()
lr = 0.25
weight_decay = 5e-4
momentum = 0.9
n_epochs = 60  # Longer training gives better results, but let's keep baseline model epochs to 20.

# TODO: build optimizer and scheduler for supernet training.
optimizer = optim.SGD(supernet.parameters(), lr=lr, weight_decay=weight_decay, momentum=momentum)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=len(train_dataloader) * n_epochs)

### Define training function

In [28]:
# TODO: Copy `train_one_epoch` function here, rename to `pretrain_one_epoch`.
#       Call `model.sample_random_architecture()` before making forward pass on each batch.
def pretrain_one_epoch(
        model: nn.Module,
        criterion: nn.Module,
        dataloader: DataLoader,
        optimizer: optim.Optimizer,
        scheduler,
        device: torch.device,
        epoch: int,
) -> Tuple[float, float]:
    model.train()

    total_loss = 0.0
    total_correct = 0.0
    total_samples = 0

    #wrapped_dataloader = tqdm(enumerate(dataloader), total=len(dataloader))
    for i, (inputs, labels) in enumerate(dataloader):
        inputs = inputs.to(device=device)
        labels = labels.to(device=device)

        model.sample_random_architecture()
        optimizer.zero_grad()
        logits = model(inputs)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()
        scheduler.step()

        with torch.no_grad():
            _, predicted_labels = torch.max(logits, 1)
            total_loss += loss.item()
            total_correct += (predicted_labels == labels).sum().item()
            total_samples += labels.shape[0]

        # wrapped_dataloader.set_description(
        #     f'(train) Epoch={epoch}, lr={scheduler.get_last_lr()[0]:.4f} loss={total_loss / (i + 1):.3f}'
        # )

    return total_loss / len(dataloader), total_correct / total_samples


@torch.no_grad()
def validate_one_epoch(
        model: nn.Module,
        criterion: nn.Module,
        dataloader: DataLoader,
        device: torch.device,
        epoch: int,
) -> Tuple[float, float]:
    model.eval()

    total_loss = 0.0
    total_correct = 0.0
    total_samples = 0

    #wrapped_dataloader = tqdm(enumerate(dataloader), total=len(dataloader))
    for i, (inputs, labels) in enumerate(dataloader):
        inputs = inputs.to(device=device)
        labels = labels.to(device=device)

        logits = model(inputs)
        loss = criterion(logits, labels)
        _, predicted_labels = torch.max(logits, 1)
        total_loss += loss.item()
        total_correct += (predicted_labels == labels).sum().item()
        total_samples += labels.shape[0]

        #wrapped_dataloader.set_description(f'(val) Epoch={epoch}, loss={total_loss / (i + 1):.3f}')

    return total_loss / len(dataloader), total_correct / total_samples

def random_search(
        trained_supernet: nn.Module,
        train_dataloader: DataLoader,
        val_dataloader: DataLoader,
        device: torch.device,
        n_architectures_to_test: int,
        target_latency: float,) -> Union[float, List[float]]:
    # TODO: Implement Random search.
    # This function should evaluate `n_architectures_to_test` architectures
    # using the `trained_supernet` on the whole validation dataset. The resulting
    # architecture should have latency not greater than `target_latency`.
    # Rank architectures by validation accuracy.
    best_accuracy = 0.0
    best_architecture = []
    
    for _ in tqdm(range(n_architectures_to_test)):
        # Sample a random architecture
        trained_supernet.sample_random_architecture()
        current_architecture = [block.active_op_index for block in trained_supernet.search_blocks]
        
        # Measure its latency
        macs, params = profile(trained_supernet, inputs=(torch.zeros(1, 3, 32, 32, device=device),), verbose=False)

        # Skip this architecture if it exceeds the target latency
        if macs > target_latency:
            continue
        
        # Evaluate the architecture
        trained_supernet.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, labels in val_dataloader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = trained_supernet(inputs)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        
        accuracy = correct / total
        
        # Update best architecture if this one is better
        if accuracy > best_accuracy:
            best_accuracy = accuracy
            best_architecture = current_architecture
    
    return best_accuracy, best_architecture

In [32]:
n_epochs = 80
for epoch in tqdm(range(n_epochs)):
    loss, accuracy = pretrain_one_epoch(supernet, criterion, train_dataloader, optimizer, scheduler, device, epoch)
    if epoch % 20 ==0:
        print(f'Epoch: {epoch}')
        print(f'train_loss={loss:.4f}, train_accuracy={accuracy:.3%}')
    loss, accuracy = validate_one_epoch(supernet, criterion, test_dataloader, device, epoch)
    if epoch % 20 ==0:
        print(f'test_loss={loss:.4f}, test_accuracy={accuracy:.3%}')


n_architectures_to_test = 1000
target_latency = 30*1e6

accuracy, best_architecture = random_search(supernet, train_dataloader, test_dataloader, device, n_architectures_to_test, target_latency)
print(f'best architecture: {best_architecture} (test_accuracy={accuracy:.3%})')


torch.save(supernet.state_dict(), 'supernet.pth')

  0%|          | 0/80 [00:00<?, ?it/s]

Epoch: 0
train_loss=0.6225, train_accuracy=78.429%


  1%|▏         | 1/80 [00:03<04:00,  3.05s/it]

test_loss=0.6860, test_accuracy=76.700%


 25%|██▌       | 20/80 [01:00<03:01,  3.03s/it]

Epoch: 20
train_loss=0.4353, train_accuracy=84.925%


 26%|██▋       | 21/80 [01:03<02:59,  3.04s/it]

test_loss=0.4989, test_accuracy=83.390%


 50%|█████     | 40/80 [02:00<02:01,  3.03s/it]

Epoch: 40
train_loss=0.3217, train_accuracy=88.785%


 51%|█████▏    | 41/80 [02:03<01:59,  3.06s/it]

test_loss=0.4294, test_accuracy=85.670%


 75%|███████▌  | 60/80 [03:02<01:01,  3.06s/it]

Epoch: 60
train_loss=0.4311, train_accuracy=85.005%


 76%|███████▋  | 61/80 [03:05<00:58,  3.09s/it]

test_loss=0.5342, test_accuracy=82.810%


100%|██████████| 80/80 [04:04<00:00,  3.06s/it]
100%|██████████| 1000/1000 [00:27<00:00, 35.75it/s]

best architecture: [2, 0, 0, 0, 0, 0, 1, 0] (test_accuracy=74.480%)





In [33]:
macs, params = profile(supernet, inputs=(torch.zeros(1, 3, 32, 32, device=device),))
print(f'Number of macs: {macs / 1e6:.2f}M, number of parameters: {params / 1e6:.2f}M')

[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.batchnorm.BatchNorm2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.activation.ReLU'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.pooling.MaxPool2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
[INFO] Register count_adap_avgpool() for <class 'torch.nn.modules.pooling.AdaptiveAvgPool2d'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
Number of macs: 28.96M, number of parameters: 5.91M
