# Automated Neural Architecture Search with BootstrapNAS

This notebook demonstrates how to use [BootstrapNAS](https://arxiv.org/abs/2112.10878), a capability in NNCF (Neural Network Compression Framework) that generates weight-sharing super-networks from pre-trained models. Once the super-network has been generated, BootstrapNAS can train it and search it for efficient sub-networks.

We use [MobileNet-V2](https://arxiv.org/abs/1801.04381) pre-trained on CIFAR-10. MobileNet-V2 is an efficient mobile architecture built from inverted residual blocks. Our goal is to discover alternative models, i.e., sub-networks of the super-network, that offer a better accuracy/efficiency trade-off than the input pre-trained model (for example, a smaller model with comparable accuracy).
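
"Weight-sharing" means that every sub-network reuses a slice of the super-network's weights instead of owning a separate copy. The toy sketch below assumes, for illustration only, that width elasticity keeps the leading output and input channels of a layer; the helper `active_weight_slice` and the tiny weight matrix are hypothetical and are not part of the NNCF API.

```python
from typing import List

Matrix = List[List[float]]


def active_weight_slice(super_weights: Matrix, out_channels: int, in_channels: int) -> Matrix:
    """Return the weights a narrower sub-network would use (hypothetical helper).

    Assumption for illustration: the sub-network keeps the first `out_channels`
    rows (output channels) and the first `in_channels` columns (input channels).
    """
    if not 0 < out_channels <= len(super_weights):
        raise ValueError("out_channels out of range")
    if not 0 < in_channels <= len(super_weights[0]):
        raise ValueError("in_channels out of range")
    return [row[:in_channels] for row in super_weights[:out_channels]]


# A 4x3 "super" layer: 4 output channels, 3 input channels.
super_layer = [
    [1.0, 2.0, 3.0],
    [4.0, 5.0, 6.0],
    [7.0, 8.0, 9.0],
    [10.0, 11.0, 12.0],
]

# Two sub-networks of different width read from the same super-layer.
# (Plain Python slices copy the values; PyTorch tensor slices would be views.)
narrow = active_weight_slice(super_layer, out_channels=2, in_channels=2)
wide = active_weight_slice(super_layer, out_channels=3, in_channels=3)

# The narrow sub-network's weights are the top-left corner of the wide one's.
print(narrow)
print([row[:2] for row in wide[:2]])
```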

## Imports and Settings

Import NNCF and all auxiliary packages in your Python code.
Set a name for the model and the input image size (width and height are both 32 pixels for CIFAR-10). Also define the paths where the PyTorch and ONNX versions of the models will be stored, and download the pre-trained weights.

> NOTE: All NNCF logging messages below ERROR level (INFO and WARNING) are disabled to simplify the tutorial. For production use, we recommend enabling logging by removing `set_log_level(logging.ERROR)`.

In [None]:
import sys
import time
import warnings  # to disable warnings on export to ONNX
import zipfile
from pathlib import Path
import logging

import torch
import nncf  # Important - should be imported directly after torch

import torch.nn as nn
import torch.nn.parallel
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as transforms

from bootstrapnas_utils import MobileNetV2

from nncf.common.utils.logger import set_log_level
set_log_level(logging.ERROR)  # Disables all NNCF info and warning messages

from nncf import NNCFConfig
from nncf.config.structures import BNAdaptationInitArgs
from nncf.experimental.torch.nas.bootstrapNAS import EpochBasedTrainingAlgorithm
from nncf.experimental.torch.nas.bootstrapNAS import SearchAlgorithm
from nncf.torch import create_compressed_model, register_default_init_args
from nncf.torch.initialization import wrap_dataloader_for_init
from nncf.torch.model_creation import create_nncf_network

from openvino.runtime import Core
from torch.jit import TracerWarning

sys.path.append("../utils")
from notebook_utils import download_file

torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using {device} device")

MODEL_DIR = Path("model")
OUTPUT_DIR = Path("output")
DATA_DIR = Path("data")
BASE_MODEL_NAME = "mobilenet-V2"
image_size = 32

OUTPUT_DIR.mkdir(exist_ok=True)
MODEL_DIR.mkdir(exist_ok=True)
DATA_DIR.mkdir(exist_ok=True)

# Paths where models will be stored
fp32_pth_path = Path(MODEL_DIR / (BASE_MODEL_NAME + "_fp32")).with_suffix(".pth")
model_onnx_path = Path(OUTPUT_DIR / BASE_MODEL_NAME).with_suffix(".onnx")
supernet_onnx_path = Path(OUTPUT_DIR / (BASE_MODEL_NAME + "_supernet")).with_suffix(".onnx")
subnet_onnx_path = Path(OUTPUT_DIR / (BASE_MODEL_NAME + "_subnet")).with_suffix(".onnx")

# Training the FP32 model from scratch is possible but slow, so pre-trained weights are downloaded by default.
pretrained_on_cifar10 = True
fp32_pth_url = "http://hsw1.jf.intel.com/share/bootstrapNAS/checkpoints/cifar10/mobilenet_v2.pt"
download_file(fp32_pth_url, directory=MODEL_DIR, filename=fp32_pth_path.name)

## Download CIFAR-10 dataset

* 60,000 color images of shape 3x32x32
* 10 classes (airplane, automobile, etc.), with 6,000 images per class
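
A quick sanity check of those counts. The standard CIFAR-10 split of 50,000 training and 10,000 test images is added here for illustration; the list above gives only the totals.

```python
num_classes = 10
images_per_class = 6000
train_images, test_images = 50_000, 10_000  # standard CIFAR-10 split

total = num_classes * images_per_class
assert total == train_images + test_images  # 60,000 images in total

# Each class contributes an equal share to both splits.
train_per_class = train_images // num_classes  # 5,000
test_per_class = test_images // num_classes    # 1,000

# Each image holds 3 channels x 32 x 32 pixels.
values_per_image = 3 * 32 * 32  # 3,072

print(total, train_per_class, test_per_class, values_per_image)
```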

In [None]:
DATASET_DIR = DATA_DIR / "cifar10"

# Per-channel CIFAR-10 mean and standard deviation used for normalization
normalize = transforms.Normalize(mean=(0.4914, 0.4822, 0.4465),
                                 std=(0.2471, 0.2435, 0.2616))

# Validation: no augmentation, only tensor conversion and normalization
val_transform = transforms.Compose([
    transforms.ToTensor(),
    normalize,
])

# Training: random 32x32 crops of 4-pixel-padded images plus random horizontal flips
train_transform = transforms.Compose([
    transforms.RandomCrop(image_size, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize,
])

# Download the dataset only if it is not present yet
download = not DATASET_DIR.exists()

train_dataset = datasets.CIFAR10(DATASET_DIR, train=True, transform=train_transform, download=download)
val_dataset = datasets.CIFAR10(DATASET_DIR, train=False, transform=val_transform, download=download)

batch_size_val = 1000
batch_size = 64
workers = 4
# Pinned host memory only helps host-to-GPU copies; `device` is a torch.device, so compare its type
pin_memory = device.type == "cuda"

# Validation data is read in a fixed order
val_loader = torch.utils.data.DataLoader(
    val_dataset, batch_size=batch_size_val, shuffle=False,
    num_workers=workers, pin_memory=pin_memory, drop_last=False)

# Training data is reshuffled every epoch
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True,
    num_workers=workers, pin_memory=pin_memory, drop_last=True)


## Super-network training pipeline

Like other NNCF algorithms, BootstrapNAS assumes that the user already has a pre-trained model and a training pipeline. The following cells define one possible training pipeline:


### Train Function

In [None]:
def train_epoch(train_loader, model, criterion, optimizer, epoch, compression_ctrl, train_iters=None):
    batch_time = AverageMeter("Time", ":3.3f")
    losses = AverageMeter("Loss", ":2.3f")
    top1 = AverageMeter("Acc@1", ":2.2f")
    top5 = AverageMeter("Acc@5", ":2.2f")
    progress = ProgressMeter(
        len(train_loader), [batch_time, losses, top1, top5], prefix="Epoch:[{}]".format(epoch)
    )

    # switch to train mode
    model.train()

    compression_scheduler = compression_ctrl.scheduler

    if train_iters is None:
        train_iters = len(train_loader)

    print_frequency = 50
    end = time.time()
    for i, (images, target) in enumerate(train_loader):
        # Process exactly `train_iters` batches per epoch
        if i >= train_iters:
            break

        # Advance the NNCF compression scheduler once per batch
        compression_scheduler.step()

        images = images.to(device)
        target = target.to(device)

        # compute output
        output = model(images)
        loss = criterion(output, target)

        # measure accuracy and record loss
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        losses.update(loss.item(), images.size(0))
        top1.update(acc1[0], images.size(0))
        top5.update(acc5[0], images.size(0))

        # compute gradient and do opt step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % print_frequency == 0:
            progress.display(i)




### Validate Function

In [None]:
def validate(model, val_loader, criterion=nn.CrossEntropyLoss()):
    batch_time = AverageMeter("Time", ":3.3f")
    losses = AverageMeter("Loss", ":2.3f")
    top1 = AverageMeter("Acc@1", ":2.2f")
    top5 = AverageMeter("Acc@5", ":2.2f")
    progress = ProgressMeter(len(val_loader), [batch_time, losses, top1, top5], prefix="Test: ")

    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        end = time.time()
        for i, (images, target) in enumerate(val_loader):
            images = images.to(device)
            target = target.to(device)

            # compute output
            output = model(images)
            loss = criterion(output, target)

            # measure accuracy and record loss
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            losses.update(loss.item(), images.size(0))
            top1.update(acc1[0], images.size(0))
            top5.update(acc5[0], images.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            print_frequency = 10
            if i % print_frequency == 0:
                progress.display(i)

        print(" * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}".format(top1=top1, top5=top5))
    # Report the loss averaged over the whole validation set, not just the last batch
    return top1.avg, top5.avg, losses.avg

### Helpers
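
`AverageMeter` keeps a sample-weighted running mean: `update(val, n)` adds `val * n` to the sum and `n` to the count, so a smaller final batch counts proportionally less than a full one. A standalone check of that arithmetic with hypothetical per-batch losses:

```python
# Hypothetical per-batch mean losses and batch sizes (the last batch is smaller).
batch_losses = [2.0, 1.0, 0.5]
batch_sizes = [64, 64, 16]

running_sum, running_count = 0.0, 0
for loss, n in zip(batch_losses, batch_sizes):
    running_sum += loss * n  # same bookkeeping as AverageMeter.update(val, n)
    running_count += n

weighted_avg = running_sum / running_count  # per-sample mean over all 144 samples
naive_avg = sum(batch_losses) / len(batch_losses)  # ignores batch sizes

print(round(weighted_avg, 4), round(naive_avg, 4))
```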

In [None]:
class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self, name, fmt=":f"):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})"
        return fmtstr.format(**self.__dict__)


class ProgressMeter(object):
    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(meter) for meter in self.meters]
        print("\t".join(entries))

    def _get_batch_fmtstr(self, num_batches):
        num_digits = len(str(num_batches // 1))
        fmt = "{:" + str(num_digits) + "d}"
        return "[" + fmt + "/" + fmt.format(num_batches) + "]"


def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res
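
`accuracy` counts a sample as correct for top-k if its true label is among the k highest-scoring classes, then reports the percentage over the batch. The pure-Python sketch below mirrors that definition on hypothetical scores, without PyTorch:

```python
from typing import Sequence


def topk_accuracy(scores: Sequence[Sequence[float]], targets: Sequence[int], k: int) -> float:
    """Percentage of samples whose target class is among the k highest scores."""
    correct = 0
    for row, target in zip(scores, targets):
        # Indices of the k largest scores for this sample.
        top_k = sorted(range(len(row)), key=lambda c: row[c], reverse=True)[:k]
        correct += target in top_k
    return 100.0 * correct / len(targets)


# Hypothetical 3-class scores for 4 samples.
scores = [
    [0.1, 0.7, 0.2],    # predicts class 1
    [0.5, 0.3, 0.2],    # predicts class 0
    [0.2, 0.3, 0.5],    # predicts class 2
    [0.4, 0.35, 0.25],  # predicts class 0; class 1 is the runner-up
]
targets = [1, 2, 2, 1]

print(topk_accuracy(scores, targets, k=1))  # 50.0
print(topk_accuracy(scores, targets, k=2))  # 75.0
```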

## Generate Super-network from pre-trained model
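
In the config below, `schedule.list_stage_descriptions` lists the super-network training stages: each stage names the elastic dimensions to train (`train_dims`) and for how many `epochs`. The demo keeps a single depth-only stage; the commented-out entries hint at a longer schedule. As a rough sketch, assuming stages run one after another in list order, such a list unrolls into a per-epoch plan like this (`expand_schedule` is a hypothetical helper, not an NNCF function):

```python
from typing import Dict, List


def expand_schedule(stages: List[Dict]) -> List[List[str]]:
    """Unroll stage descriptions into the trained dimensions for each epoch."""
    plan = []
    for stage in stages:
        for _ in range(stage["epochs"]):
            plan.append(list(stage["train_dims"]))
    return plan


# Stage list modeled on the (partly commented-out) config entries below.
stages = [
    {"train_dims": ["depth"], "epochs": 1},
    {"train_dims": ["depth"], "epochs": 1},
    {"train_dims": ["depth", "width"], "epochs": 1},
]

plan = expand_schedule(stages)
for epoch, dims in enumerate(plan):
    print(f"epoch {epoch}: train {dims}")
```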

In [None]:
import warnings
warnings.filterwarnings("ignore")

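model = MobileNetV2()
# map_location lets a checkpoint saved on GPU load on a CPU-only machine
state_dict = torch.load(fp32_pth_path, map_location=device)
model.load_state_dict(state_dict)

model.to(device)

criterion = nn.CrossEntropyLoss()

# Top-1 accuracy of the original model; used below as the search's reference accuracy
model_top1_acc, _, _ = validate(model, val_loader, criterion)

# Test exporting the original model to ONNX
dummy_input = torch.randn(1, 3, image_size, image_size).to(device)
torch.onnx.export(model, dummy_input, str(model_onnx_path))

# Number of batches per training epoch (kept small so the demo runs quickly)
train_steps = 10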

config = {
            "device": device,
            "input_info": {
                "sample_size": [1, 3, 32, 32],
            },
            "checkpoint_save_dir": OUTPUT_DIR,
            "bootstrapNAS": {
                "training": {
                    # "algorithm": "progressive_shrinking",
                    "batchnorm_adaptation": {
                        "num_bn_adaptation_samples": 2
                    },
                    "schedule": {
                        "list_stage_descriptions": [
                            {"train_dims": ["depth"], "epochs": 1},
                            # {"train_dims": ["depth"], "epochs": 1, "depth_indicator": 2},
                            # {"train_dims": ["depth", "width"], "epochs": 1, "depth_indicator": 2, "reorg_weights": True, "width_indicator": 2}
                        ]
                    },
                    "elasticity": {
                        "available_elasticity_dims": ["width", "depth"]
                    }
                },
                "search": {
                    "algorithm": "NSGA2",
                    "num_evals": 2,  # kept tiny for the demo; e.g., 30 for a more thorough search
                    "population": 1,  # e.g., 5 for a more thorough search
                    "ref_acc": model_top1_acc.item(),
                    "acc_delta": 4
                }
            }
        }

def train_epoch_fn(loader, model, compression_ctrl, epoch, optimizer):
    train_epoch(loader, model, criterion, optimizer, epoch, compression_ctrl, train_iters=train_steps)

# define optimizer
init_lr = 3e-4
compression_lr = init_lr / 10
optimizer = torch.optim.Adam(model.parameters(), lr=compression_lr)

nncf_config = NNCFConfig.from_dict(config)

bn_adapt_args = BNAdaptationInitArgs(data_loader=wrap_dataloader_for_init(train_loader), device=device)
nncf_config.register_extra_structs([bn_adapt_args])

nncf_network = create_nncf_network(model, nncf_config)

training_algorithm = EpochBasedTrainingAlgorithm.from_config(nncf_network, nncf_config)


nncf_network, elasticity_ctrl = training_algorithm.run(train_epoch_fn, train_loader,
                                                       validate, val_loader, optimizer,
                                                       OUTPUT_DIR, None,
                                                       train_steps)

search_algo = SearchAlgorithm.from_config(nncf_network, elasticity_ctrl, nncf_config)

def validate_model_fn_top1(model, val_loader):
    top1, _, _ = validate(model, val_loader, criterion)
    return top1.item()

elasticity_ctrl, best_config, performance_metrics = search_algo.run(validate_model_fn_top1, val_loader,
                                                                    OUTPUT_DIR,
                                                                    tensorboard_writer=None)

print("Best config: {best_config}".format(best_config=best_config))
print("Performance metrics: {performance_metrics}".format(performance_metrics=performance_metrics))
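
The search stage is configured with `"algorithm": "NSGA2"`, a multi-objective evolutionary algorithm. Its key building block is Pareto dominance: one sub-network dominates another if it is at least as accurate and at least as cheap, and strictly better on one of the two. The sketch below extracts the non-dominated front from hypothetical (accuracy, MACs) values. The final pick, the cheapest front member within `acc_delta` points of `ref_acc`, is an assumed selection rule for illustration, not necessarily what `SearchAlgorithm` does internally.

```python
from typing import List, Tuple

Candidate = Tuple[str, float, float]  # (name, top-1 accuracy %, MACs in millions)


def dominates(a: Candidate, b: Candidate) -> bool:
    """True if `a` is no worse than `b` on both objectives and better on one."""
    _, acc_a, macs_a = a
    _, acc_b, macs_b = b
    no_worse = acc_a >= acc_b and macs_a <= macs_b
    strictly_better = acc_a > acc_b or macs_a < macs_b
    return no_worse and strictly_better


def pareto_front(cands: List[Candidate]) -> List[Candidate]:
    """Keep the candidates that no other candidate dominates."""
    return [c for c in cands if not any(dominates(o, c) for o in cands if o is not c)]


# Hypothetical sub-network candidates (illustrative numbers only).
candidates = [
    ("A", 93.0, 90.0),
    ("B", 92.5, 60.0),
    ("C", 91.0, 70.0),  # dominated by B: less accurate and more expensive
    ("D", 89.0, 40.0),
]

front = pareto_front(candidates)
print([name for name, _, _ in front])  # ['A', 'B', 'D']

# Assumed selection rule: cheapest front member within acc_delta of ref_acc.
ref_acc, acc_delta = 93.0, 4.0
eligible = [c for c in front if c[1] >= ref_acc - acc_delta]
best = min(eligible, key=lambda c: c[2])
print(best[0])  # D
```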




## Visualization of the search stage

In [None]:
search_algo.visualize_search_progression(filename=Path(OUTPUT_DIR / (BASE_MODEL_NAME + "_search")))

In [None]:
# Query the name of the CPU that OpenVINO Runtime will use for inference
ie = Core()
ie.get_property(device_name="CPU", name="FULL_DEVICE_NAME")