# Automated Neural Architecture Search with BootstrapNAS
This notebook demonstrates how to use [BootstrapNAS](https://arxiv.org/abs/2112.10878), a capability in NNCF to generate weight-sharing super-networks from pre-trained models. Once the super-network has been generated, BootstrapNAS can train it and search for efficient sub-networks. 

We will use [MobileNet-V2](https://arxiv.org/abs/1801.04381) pre-trained on CIFAR-10. MobileNet-V2 is an efficient mobile architecture based on inverted residual blocks. Our goal is to discover alternative models, also known as sub-networks, that offer a better trade-off between accuracy and efficiency than the input pre-trained model.

> NOTE: BootstrapNAS is an **experimental feature** in NNCF.

<p align="center">
<img src="https://github.com/jpablomch/bootstrapnas/raw/main/architecture.png" alt="BootstrapNAS Architecture" width="800"/>
</p>


BootstrapNAS works as follows:
1. It takes a pre-trained model as input.
2. It uses this model to generate a weight-sharing super-network.
3. It applies a training strategy to the super-network.
4. Once the super-network has been trained, it searches for efficient sub-networks that satisfy the user's requirements.
5. It returns the configuration of the discovered sub-network(s) to the user.
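
To make weight sharing concrete, here is a minimal pure-Python sketch. It is not NNCF's implementation: the search space (`SUPERNET_WIDTHS`, `DEPTH_CHOICES`, `WIDTH_RATIOS`) and all helper names are hypothetical. It assumes, as in Once-for-All-style super-networks, that a narrower sub-network reuses the leading filters of a super-network layer and a shallower one keeps only the leading blocks of each stage.

```python
import random
from typing import Dict, List

# Hypothetical super-network: three stages, each with up to four blocks
# and a maximum number of output channels (filters).
SUPERNET_WIDTHS = [32, 64, 128]
STAGE_BLOCKS = ["block0", "block1", "block2", "block3"]
DEPTH_CHOICES = [2, 3, 4]
WIDTH_RATIOS = [0.5, 0.75, 1.0]


def sample_subnet(rng: random.Random) -> Dict[str, List[int]]:
    """Choose a depth and a width for every stage of the super-network."""
    depths = [rng.choice(DEPTH_CHOICES) for _ in SUPERNET_WIDTHS]
    widths = [int(w * rng.choice(WIDTH_RATIOS)) for w in SUPERNET_WIDTHS]
    return {"depth": depths, "width": widths}


def shared_filters(supernet_filters: List[List[float]], width: int) -> List[List[float]]:
    """Elastic width: the sub-network layer reuses the first `width` filters
    of the super-network layer instead of owning separately trained weights."""
    return supernet_filters[:width]


def active_blocks(blocks: List[str], depth: int) -> List[str]:
    """Elastic depth: keep the first `depth` blocks of a stage and skip the rest."""
    return blocks[:depth]


rng = random.Random(0)
subnet_config = sample_subnet(rng)
print(subnet_config)

# One super-network layer with 32 filters of 3 weights each (dummy values).
layer = [[float(i)] * 3 for i in range(SUPERNET_WIDTHS[0])]
sub_layer = shared_filters(layer, subnet_config["width"][0])
print(len(sub_layer), active_blocks(STAGE_BLOCKS, subnet_config["depth"][0]))
```

Because every sub-network is a slice of the same weights, training the super-network trains all of its sub-networks at once, which is what makes the later search cheap.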

## Imports and Settings

Import NNCF and all auxiliary packages from your Python code.
Set a name for the model and the input image size used by the network, and define the paths where the PyTorch and ONNX versions of the models will be stored.

> NOTE: All NNCF logging messages below the ERROR level (INFO and WARNING) are disabled to simplify the tutorial. For production use, it is recommended to enable logging by removing `set_log_level(logging.ERROR)`.

In [None]:
import logging
import sys
import warnings  # to disable warnings on export to ONNX
warnings.filterwarnings("ignore")
from pathlib import Path

import torch
import nncf  # Important - should be imported directly after torch

import torch.nn as nn
import torch.optim
import torch.utils.data
import torchvision.datasets as datasets
import torchvision.transforms as transforms

from bootstrapnas_utils import MobileNetV2, validate, train_epoch

from nncf.common.utils.logger import set_log_level
set_log_level(logging.ERROR)  # Disables all NNCF info and warning messages

from nncf import NNCFConfig
from nncf.config.structures import BNAdaptationInitArgs
from nncf.experimental.torch.nas.bootstrapNAS import EpochBasedTrainingAlgorithm
from nncf.experimental.torch.nas.bootstrapNAS import SearchAlgorithm
from nncf.torch.initialization import wrap_dataloader_for_init
from nncf.torch.model_creation import create_nncf_network

from openvino.runtime import Core

sys.path.append("../utils")
from notebook_utils import download_file

torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using {device} device")

MODEL_DIR = Path("model")
OUTPUT_DIR = Path("output")
DATA_DIR = Path("data")
BASE_MODEL_NAME = "mobilenet-V2"
image_size = 32

OUTPUT_DIR.mkdir(exist_ok=True)
MODEL_DIR.mkdir(exist_ok=True)
DATA_DIR.mkdir(exist_ok=True)

# Paths where models will be stored
fp32_pth_path = (MODEL_DIR / (BASE_MODEL_NAME + "_fp32")).with_suffix(".pth")
model_onnx_path = (OUTPUT_DIR / BASE_MODEL_NAME).with_suffix(".onnx")
supernet_onnx_path = (OUTPUT_DIR / (BASE_MODEL_NAME + "_supernet")).with_suffix(".onnx")
subnet_onnx_path = (OUTPUT_DIR / (BASE_MODEL_NAME + "_subnet")).with_suffix(".onnx")
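
The path expressions above join a directory with a file stem and then attach an extension with `Path.with_suffix`. The short check below uses `PurePosixPath` so that it behaves the same on every OS; the helper `model_path` and the variable `out_dir` are hypothetical. One caveat worth knowing: `with_suffix` *replaces* whatever follows the last dot in the file name, so a model name such as `mobilenet.v2` would silently lose its `.v2` part.

```python
from pathlib import PurePosixPath

BASE_MODEL_NAME = "mobilenet-V2"
out_dir = PurePosixPath("output")


def model_path(directory: PurePosixPath, stem: str, suffix: str) -> PurePosixPath:
    # Same construction as in the notebook: directory / stem, then set the extension.
    return (directory / stem).with_suffix(suffix)


supernet = model_path(out_dir, BASE_MODEL_NAME + "_supernet", ".onnx")
print(supernet)  # output/mobilenet-V2_supernet.onnx

# with_suffix replaces the text after the last dot in the file name.
print(model_path(out_dir, "mobilenet.v2", ".onnx"))  # output/mobilenet.onnx
```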

## Download pre-trained model weights
Download the pre-trained weights for the MobileNet-V2 model.

In [None]:

# It is possible to train the FP32 model from scratch, but that can be slow, so the pre-trained weights are downloaded by default.
pretrained_on_cifar10 = True
fp32_pth_url = "http://hsw1.jf.intel.com/share/bootstrapNAS/checkpoints/cifar10/mobilenet_v2.pt"
download_file(fp32_pth_url, directory=MODEL_DIR, filename=fp32_pth_path.name)

## Prepare CIFAR-10 dataset
Next, prepare the CIFAR-10 dataset. The CIFAR-10 dataset contains:
* 60,000 images of shape 3x32x32
* 10 different classes (6,000 images per class): airplane, automobile, etc. 

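Here, a dataloader is created for both the training and validation datasets. The training transformations apply a random 32x32 crop (after 4-pixel padding), a random horizontal flip, and normalization; the validation transformations apply only normalization. Each dataloader uses 4 workers, with a batch size of 64 for training and 1000 for validation.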
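
The training transformations can be illustrated without torchvision. The sketch below works on a single channel stored as nested Python lists and mimics conceptually what `RandomCrop(32, padding=4)`, `RandomHorizontalFlip()` and `Normalize` do: zero-pad to 40x40, cut a random 32x32 window, optionally mirror it, then compute `(x - mean) / std` with that channel's statistics. The helper names are hypothetical, and this is not how torchvision implements these operations.

```python
import random
from typing import List

Channel = List[List[float]]  # one H x W channel, values already scaled to [0, 1]


def pad_channel(img: Channel, p: int) -> Channel:
    """Zero-pad every side by p pixels, like RandomCrop(..., padding=p)."""
    width = len(img[0]) + 2 * p
    body = [[0.0] * p + row + [0.0] * p for row in img]
    return [[0.0] * width for _ in range(p)] + body + [[0.0] * width for _ in range(p)]


def random_crop(img: Channel, size: int, p: int, rng: random.Random) -> Channel:
    """Pad, then cut a size x size window at a random offset."""
    padded = pad_channel(img, p)
    top = rng.randint(0, len(padded) - size)
    left = rng.randint(0, len(padded[0]) - size)
    return [row[left:left + size] for row in padded[top:top + size]]


def random_hflip(img: Channel, rng: random.Random, prob: float = 0.5) -> Channel:
    """Mirror the channel left-to-right with probability `prob`."""
    return [row[::-1] for row in img] if rng.random() < prob else img


def normalize_channel(img: Channel, mean: float, std: float) -> Channel:
    """Per-channel standardization: (x - mean) / std."""
    return [[(x - mean) / std for x in row] for row in img]


rng = random.Random(0)
red = [[1.0] * 32 for _ in range(32)]  # dummy all-white red channel
augmented = random_hflip(random_crop(red, 32, 4, rng), rng)
out = normalize_channel(augmented, mean=0.4914, std=0.2471)  # red-channel stats from the notebook
print(len(out), len(out[0]))  # 32 32
```

Padding by 4 pixels on each side gives 9 possible offsets per axis, so the crop shifts the image by up to 4 pixels in any direction while keeping the 32x32 input size the network expects.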

In [None]:
DATASET_DIR = DATA_DIR / "cifar10"

image_size = 32
normalize = transforms.Normalize(mean=(0.4914, 0.4822, 0.4465),
                                 std=(0.2471, 0.2435, 0.2616))

# Validation: only convert to a tensor and normalize
val_transform = transforms.Compose([
    transforms.ToTensor(),
    normalize,
])

# Training: augment with random crops (4-pixel padding) and horizontal flips
train_transform = transforms.Compose([
    transforms.RandomCrop(image_size, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize,
])

# Download the dataset only if it is not already present
download = not DATASET_DIR.exists()

train_dataset = datasets.CIFAR10(DATASET_DIR, train=True, transform=train_transform, download=download)
val_dataset = datasets.CIFAR10(DATASET_DIR, train=False, transform=val_transform, download=download)

batch_size_val = 1000
batch_size = 64
workers = 4
pin_memory = device.type != "cpu"  # pinned memory only helps host-to-GPU copies

# Validation data is read in order
val_loader = torch.utils.data.DataLoader(
    val_dataset, batch_size=batch_size_val, shuffle=False,
    num_workers=workers, pin_memory=pin_memory, drop_last=False)

# Training data is reshuffled every epoch
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True,
    num_workers=workers, pin_memory=pin_memory, drop_last=True)

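As a quick sanity check on the loader settings, the arithmetic below counts batches per epoch. It assumes the standard CIFAR-10 split of 50,000 training and 10,000 test images; the variable names are illustrative only.

```python
import math

# Assumed standard CIFAR-10 split: 50,000 training and 10,000 test images.
TRAIN_IMAGES = 50_000
VAL_IMAGES = 10_000
train_batch_size = 64
val_batch_size = 1000

# Training loader uses drop_last=True: the last incomplete batch is discarded.
train_batches = TRAIN_IMAGES // train_batch_size
dropped_images = TRAIN_IMAGES - train_batches * train_batch_size

# Validation loader uses drop_last=False: a partial last batch would be kept.
val_batches = math.ceil(VAL_IMAGES / val_batch_size)

print(train_batches, dropped_images, val_batches)  # 781 16 10
```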


## Generate Super-network from pre-trained model

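Using NNCF for model compression assumes that the user has a pre-trained model and a training pipeline. The following steps demonstrate one possible pipeline: evaluate the pre-trained model, train a super-network generated from it, and search for sub-networks.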


### Evaluate pre-trained model
Load the pre-trained weights and evaluate the model on the validation dataset.

In [None]:
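model = MobileNetV2()
# map_location="cpu" lets the checkpoint load even when no GPU is available
state_dict = torch.load(fp32_pth_path, map_location="cpu")
model.load_state_dict(state_dict)
model.to(device)

criterion = nn.CrossEntropyLoss()

# Top-1 accuracy of the pre-trained model, used later as the reference for the search
model_top1_acc, _, _ = validate(model, device, val_loader, criterion)

# Export the original model to ONNX
dummy_input = torch.randn(1, 3, image_size, image_size).to(device)
torch.onnx.export(model, dummy_input, str(model_onnx_path))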

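The reference value `model_top1_acc` is a top-1 accuracy: the share of validation images whose highest-scoring class equals the label. `validate` comes from the local `bootstrapnas_utils` module, whose code is not shown here, so the sketch below only illustrates the metric itself on made-up logits, assuming it is reported as a percentage. `top1_accuracy` is a hypothetical helper.

```python
from typing import Sequence


def top1_accuracy(logits: Sequence[Sequence[float]], labels: Sequence[int]) -> float:
    """Percentage of samples whose arg-max logit matches the label."""
    correct = 0
    for scores, label in zip(logits, labels):
        predicted = max(range(len(scores)), key=lambda c: scores[c])
        correct += int(predicted == label)
    return 100.0 * correct / len(labels)


# Three made-up samples over 3 classes; only the first two are predicted correctly.
logits = [[2.0, 0.1, -1.0], [0.2, 0.3, 1.5], [0.9, 0.1, 0.0]]
labels = [0, 2, 1]
print(top1_accuracy(logits, labels))
```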

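### Train super-network
We use the pre-trained MobileNet-V2 model to generate a weight-sharing super-network and then train it according to the `schedule` in the configuration below. The number of training steps and search evaluations is kept small so that the notebook runs quickly.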

In [None]:
# Configurations
train_steps = 10  # number of training iterations, kept small for a quick demo

config = {
    "device": device,
    "input_info": {
        "sample_size": [1, 3, image_size, image_size],
    },
    "checkpoint_save_dir": OUTPUT_DIR,
    "bootstrapNAS": {
        "training": {
            # "algorithm": "progressive_shrinking",
            "batchnorm_adaptation": {
                "num_bn_adaptation_samples": 2
            },
            "schedule": {
                "list_stage_descriptions": [
                    {"train_dims": ["depth"], "epochs": 1},
                    # {"train_dims": ["depth"], "epochs": 1, "depth_indicator": 2},
                    # {"train_dims": ["depth", "width"], "epochs": 1, "depth_indicator": 2, "reorg_weights": True, "width_indicator": 2}
                ]
            },
            "elasticity": {
                "available_elasticity_dims": ["width", "depth"]
            }
        },
        "search": {
            "algorithm": "NSGA2",
            "num_evals": 2,   # reduced for a quick demo; e.g. 30 for a fuller search
            "population": 1,  # reduced for a quick demo; e.g. 5 for a fuller search
            "ref_acc": model_top1_acc.item(),
            "acc_delta": 4
        }
    }
}

# Define the optimizer
init_lr = 3e-4
compression_lr = init_lr / 10
optimizer = torch.optim.Adam(model.parameters(), lr=compression_lr)

# Set up the NNCF config and register the data loader used for batch-norm adaptation
nncf_config = NNCFConfig.from_dict(config)

bn_adapt_args = BNAdaptationInitArgs(data_loader=wrap_dataloader_for_init(train_loader), device=device)
nncf_config.register_extra_structs([bn_adapt_args])

# Wrap the pre-trained model into an NNCFNetwork
nncf_network = create_nncf_network(model, nncf_config)


# Training and validation callbacks, called as fn(model, loader, ...) by the training algorithm
def train_epoch_fn(loader, model, compression_ctrl, epoch, optimizer):
    train_epoch(loader, model, device, criterion, optimizer, epoch, compression_ctrl, train_iters=train_steps)


def validate_model_fn(model, val_loader):
    return validate(model, device, val_loader, criterion)


training_algorithm = EpochBasedTrainingAlgorithm.from_config(nncf_network, nncf_config)

nncf_network, elasticity_ctrl = training_algorithm.run(
    train_epoch_fn, train_loader,
    validate_model_fn, val_loader, optimizer,
    OUTPUT_DIR,  # checkpoint directory
    None,        # TensorBoard writer (not used)
    train_steps)

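The `list_stage_descriptions` entry defines the training schedule: each stage names the elastic dimensions to train (`train_dims`) and for how many epochs. Only one one-epoch stage is active in the configuration; the commented-out stages show how more would be added. The helper `expand_schedule` below is hypothetical and not part of NNCF; it simply expands such a list into a per-epoch plan, which makes it easy to see how the stage epochs add up.

```python
from typing import Dict, List


def expand_schedule(stages: List[Dict]) -> List[List[str]]:
    """Hypothetical helper: list the elastic dimensions trained in each epoch."""
    plan: List[List[str]] = []
    for stage in stages:
        plan.extend(list(stage["train_dims"]) for _ in range(stage["epochs"]))
    return plan


# The stage that is active in the configuration above.
active_stages = [{"train_dims": ["depth"], "epochs": 1}]

# The same schedule with the two commented-out stages enabled.
all_stages = active_stages + [
    {"train_dims": ["depth"], "epochs": 1, "depth_indicator": 2},
    {"train_dims": ["depth", "width"], "epochs": 1, "depth_indicator": 2,
     "reorg_weights": True, "width_indicator": 2},
]

print(expand_schedule(active_stages))    # [['depth']]
print(len(expand_schedule(all_stages)))  # 3
```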

### Search for sub-networks
Once the super-network has been trained, use the NSGA-II search algorithm (`"NSGA2"` in the configuration) to find the best sub-network that satisfies the user's requirements. The configuration of the discovered sub-network(s) and their performance metrics are returned to the user.

In [None]:
search_algo = SearchAlgorithm.from_config(nncf_network, elasticity_ctrl, nncf_config)

def validate_model_fn_top1(model, val_loader):
    top1, _, _ = validate(model, device, val_loader, criterion)
    return top1.item()

elasticity_ctrl, best_config, performance_metrics = search_algo.run(validate_model_fn_top1, val_loader,
                                                                    OUTPUT_DIR,
                                                                    tensorboard_writer=None)

print(f"Best config: {best_config}")
print(f"Performance metrics: {performance_metrics}")

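NSGA-II ranks candidate sub-networks by Pareto dominance over several objectives, here assumed to be top-1 accuracy (maximize) and a cost such as MACs (minimize). The sketch below uses made-up numbers and hypothetical names to extract the first non-dominated front, the set of best trade-offs. It covers only this ranking step; the full algorithm also uses crowding distance, crossover and mutation, and NNCF's search implementation is not reproduced here.

```python
from typing import List, NamedTuple


class Candidate(NamedTuple):
    name: str
    top1: float  # accuracy, higher is better
    macs: float  # cost, lower is better (made-up units)


def dominates(a: Candidate, b: Candidate) -> bool:
    """a dominates b if it is no worse on both objectives and strictly better on one."""
    no_worse = a.top1 >= b.top1 and a.macs <= b.macs
    better = a.top1 > b.top1 or a.macs < b.macs
    return no_worse and better


def first_front(cands: List[Candidate]) -> List[Candidate]:
    """Candidates that no other candidate dominates (the Pareto front)."""
    return [c for c in cands if not any(dominates(o, c) for o in cands)]


cands = [
    Candidate("A", 93.0, 300.0),
    Candidate("B", 92.5, 200.0),
    Candidate("C", 92.0, 250.0),  # dominated by B: less accurate and more MACs
    Candidate("D", 90.0, 120.0),
]
print([c.name for c in first_front(cands)])  # ['A', 'B', 'D']
```

A candidate never dominates itself (it is not strictly better on any objective), so the front can be computed without excluding the candidate from its own comparison.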

## Visualization of the search stage
After the search has concluded, visualize the search progression and save it as a PNG file.

In [None]:
search_algo.visualize_search_progression(filename=Path(OUTPUT_DIR / (BASE_MODEL_NAME + "_search")))

Finally, query the full name of the CPU device through OpenVINO Runtime.

In [None]:
ie = Core()
ie.get_property(device_name="CPU", name="FULL_DEVICE_NAME")