# OFA³: Automatic Selection of the Best Non-dominated Sub-networks for Ensembles

- **Description**: 
  - This Jupyter notebook prepares the inputs for the OFA³ search.
  - We first load the 100 models obtained from the output of the OFA² search (file "ofa2_nsga2.pickle"). 
  - Then we evaluate each model on the **<ins>training set (1,281,167 images)</ins>** of the ILSVRC dataset (ImageNet-1k).
  - The output of this notebook provides two tables for each model (200 tables in total):
    - "OFA2_model_XXX_class.csv": table containing, for each image, the top-5 predicted classes of the model followed by the ground-truth label.
    - "OFA2_model_XXX_prob.csv": table containing the softmax probabilities of those top-5 predicted classes.
  - The files are saved in the directory **<ins>"ofa2_models_output"</ins>**.

- **Author**: TBA (hidden due to blind review)
- **email**: TBA (hidden due to blind review)

- **arXiv link**: TBA

# Install packages

In [1]:
#!pip install -q -r requirements.txt
!pip install -q \
    numpy       \
    torch       \
    torchvision \
    ofa2        \
    tqdm        \
    matplotlib

# Imports

In [2]:
# general
import os
import time
import math
import random
import pickle
from tqdm import tqdm

# AI/ML/NN
import pandas as pd
import numpy as np
from torchvision import datasets, transforms
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn

# OFA/OFA²
from ofa2.model_zoo import ofa_net
from ofa2.imagenet_classification.elastic_nn.utils import set_running_statistics
from ofa2.utils import AverageMeter#, accuracy

In [3]:
# One seed shared by every random number generator the notebook relies on
# (Python's `random`, PyTorch and NumPy's global RNG), for reproducible runs.
random_seed = 1
for _seed_fn in (random.seed, torch.manual_seed, np.random.seed):
    _seed_fn(random_seed)
del _seed_fn

In [4]:
# Pick the compute device: use the GPU when PyTorch can see one, otherwise the CPU.
cuda_available = torch.cuda.is_available()
if not cuda_available:
    print("Using CPU.")
else:
    # Enable cuDNN and let it benchmark/select convolution algorithms.
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True
    # Seed the CUDA generator with the same seed as the other RNGs.
    torch.cuda.manual_seed(random_seed)
    print("Using GPU.")

Using GPU.


# Dataset & DataLoader

In [5]:
# ImageNet Full
imagenet_data_path = "~/dataset/imagenet/"
#----------------------------
# ImageNet subset
#imagenet_data_path = "~/dataset/imagenet_1k"

In [6]:
# Pre-trained OFA super-network from which every evaluated sub-network is extracted.
# (The id presumably denotes the MobileNetV3-based space with width 1.2 — confirm.)
ofa_supernet_id = "ofa_mbv3_d234_e346_k357_w1.2"
ofa_network = ofa_net(ofa_supernet_id, pretrained=True)

In [7]:
# The following function builds the deterministic data transform used for evaluation.
def build_val_transform(size, crop_ratio=0.875):
    """Return the evaluation transform for a ``size`` x ``size`` network input.

    The image is resized (an int size makes torchvision resize the shorter side) to
    ``ceil(size / crop_ratio)`` pixels, center-cropped to ``size`` x ``size``,
    converted to a tensor and normalized with the ImageNet channel mean/std.
    The default ``crop_ratio`` of 0.875 reproduces the previous behaviour:
    for example, 224 resizes to 256, then crops 224.

    Args:
        size: side length of the square crop fed to the network.
        crop_ratio: ratio between the crop size and the resize size.

    Raises:
        ValueError: if ``crop_ratio`` is not in (0, 1].
    """
    if not 0 < crop_ratio <= 1:
        raise ValueError(f"crop_ratio must be in (0, 1], got {crop_ratio!r}")
    resize_size = int(math.ceil(size / crop_ratio))
    return transforms.Compose(
        [
            transforms.Resize(resize_size),
            transforms.CenterCrop(size),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )

In [8]:
# Loader over the ImageNet *training* split, in fixed (unshuffled) order; the
# probability tables are generated from its per-image predictions.
# NOTE: validate() (defined below) replaces this dataset's transform with one
# matching each sub-network's input resolution.
train_dataset = datasets.ImageFolder(
    root=os.path.join(imagenet_data_path, "train"),
    transform=build_val_transform(224),
)
data_loader_train = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=4_096,  # images per evaluation batch
    shuffle=False,     # keep the row order of the output tables stable
    num_workers=16,    # background loading processes
    pin_memory=True,
    drop_last=False,   # keep the final partial batch so every image is evaluated
)
print("The ImageNet dataloader for the training set is ready.")

The ImageNet dataloader for the training set is ready.


# Load results from OFA²

In [9]:
# Results of the OFA² search; each entry's encoding is later read with `.get('X')`.
# NOTE: pickle.load can execute arbitrary code — only load files from a trusted source.
nsga2_results_file = 'ofa2_nsga2.pickle'
with open(nsga2_results_file, mode='rb') as handle:
    ofa2_nsga2 = pickle.load(handle)

# Evaluation functions

In [10]:
def ensemble_evaluate_ofa_subnet(
    filename, ofa_net, path, net_config, data_loader, batch_size, device="cuda:0"
):
    """Extract one sub-network from the OFA super-network, recalibrate BN and evaluate it.

    Args:
        filename: output path prefix forwarded to ``validate`` (``None`` = write no tables).
        ofa_net: OFA super-network providing ``set_active_subnet``/``get_active_subnet``.
        path: dataset root used by ``calib_bn`` to draw BN-calibration images.
        net_config: architecture dict with ``"ks"`` and ``"e"`` (20 entries each),
            ``"d"`` (5 entries) and ``"r"`` (input resolution; only ``r[0]`` is used).
        data_loader: evaluation loader forwarded to ``validate``.
        batch_size: batch size forwarded to ``calib_bn`` and ``validate``.
        device: torch device string.

    Returns:
        The top-1 accuracy (in percent) returned by ``validate``.

    Raises:
        AssertionError: if ``net_config`` lacks a required key or has wrong lengths.
    """
    # "r" is required too: its first element selects the input resolution below.
    missing = {"ks", "e", "d", "r"} - set(net_config)
    assert not missing, f"net_config is missing keys: {sorted(missing)}"
    assert (
        len(net_config["ks"]) == 20
        and len(net_config["e"]) == 20
        and len(net_config["d"]) == 5
    ), "net_config must have 20 'ks', 20 'e' and 5 'd' entries"
    assert len(net_config["r"]) > 0, "net_config['r'] must contain the input resolution"

    image_size = net_config["r"][0]
    ofa_net.set_active_subnet(ks=net_config["ks"], d=net_config["d"], e=net_config["e"])
    subnet = ofa_net.get_active_subnet().to(device)
    # BN statistics of the extracted sub-network must be re-estimated before evaluation.
    calib_bn(subnet, path, image_size, batch_size)
    top1 = validate(filename, subnet, path, image_size, data_loader, batch_size, device)
    return top1

In [11]:
def calib_bn(net, path, image_size, batch_size, num_images=2000, replace=True):
    """Re-estimate the BatchNorm running statistics of ``net`` on random training images.

    Args:
        net: network whose BN statistics are reset via ``set_running_statistics``.
        path: dataset root containing a ``train`` directory (ImageFolder layout).
        image_size: output size of the ``RandomResizedCrop`` augmentation.
        batch_size: batch size of the temporary calibration loader.
        num_images: number of training images to draw.
        replace: sample image indices with replacement (default; same draws as the
            original helper) or without. Without replacement, ``num_images`` is capped
            at the dataset size.
    """
    # Calibration uses training-style augmentation on the training split.
    dataset = datasets.ImageFolder(
        os.path.join(path, "train"),
        transforms.Compose(
            [
                transforms.RandomResizedCrop(image_size),
                transforms.RandomHorizontalFlip(),
                transforms.ColorJitter(brightness=32.0 / 255.0, saturation=0.5),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        ),
    )
    n_total = len(dataset)
    if not replace:
        num_images = min(num_images, n_total)
    # Passing the population size instead of list(range(n_total)) avoids materializing a
    # Python list of every index (1.28M for the full training set) on each call. With
    # replace=True NumPy draws exactly the same indices either way.
    chosen_indexes = np.random.choice(n_total, num_images, replace=replace)
    sub_sampler = torch.utils.data.sampler.SubsetRandomSampler(chosen_indexes)
    data_loader = torch.utils.data.DataLoader(
        dataset,
        sampler=sub_sampler,
        batch_size=batch_size,
        num_workers=4,
        pin_memory=True,
        drop_last=False,
    )
    set_running_statistics(net, data_loader)

In [12]:
# Adapted from: once-for-all/ofa/utils/common_tools.py
# (extended to also return the top-k predictions and their softmax probabilities).
def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy and return the top-k predictions with their probabilities.

    Args:
        output: logits with class scores along dim 1 (assumed (batch, num_classes)).
        target: ground-truth class indices, one per row of ``output``.
        topk: the k values for which accuracy is computed.

    Returns:
        ``(res, pred, prob)``:
        - ``res[i]`` is a 1-element tensor with the top-``topk[i]`` accuracy in percent.
        - ``pred`` holds the ``max(topk)`` highest-scoring class indices per sample,
          best first.
        - ``prob[b, j]`` is the softmax probability of class ``pred[b, j]``.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    # (batch, maxk) class indices, sorted by decreasing score.
    _, pred = output.topk(maxk, 1, True, True)
    # Read the probabilities at exactly these indices. A second, independent topk on the
    # softmax could order ties differently (e.g. probabilities that round to the same
    # float), which would misalign the class and probability tables.
    prob = torch.nn.functional.softmax(output, dim=1).gather(1, pred)

    pred_t = pred.t()  # (maxk, batch): row j holds every sample's j-th guess
    correct = pred_t.eq(target.reshape(1, -1).expand_as(pred_t))

    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res, pred, prob

In [13]:
def validate(filename, net, path, image_size, data_loader, batch_size=100, device="cuda:0"):
    """Evaluate ``net`` on ``data_loader`` and optionally save its top-5 predictions.

    Args:
        filename: path prefix for the output tables, or ``None`` to skip writing them.
            Two CSV files are produced. ``<filename>_class.csv`` has one row per image:
            the top-5 predicted class indices followed by the ground-truth label.
            ``<filename>_prob.csv`` has the top-5 softmax probabilities returned by
            ``accuracy``. Existing files with these names are replaced, and the parent
            directory is created if needed.
        net: network to evaluate; wrapped in ``torch.nn.DataParallel`` on CUDA devices.
        path: unused; kept for signature compatibility.
        image_size: input resolution; ``data_loader.dataset.transform`` is replaced by
            ``build_val_transform(image_size)`` (side effect on the shared loader).
        data_loader: loader yielding ``(images, labels)`` batches.
        batch_size: unused; the loader's own batch size applies.
        device: torch device string.

    Returns:
        Average top-1 accuracy in percent over all evaluated images.
    """
    if "cuda" in device:
        net = torch.nn.DataParallel(net)
    net = net.to(device)

    # The same loader is shared by sub-networks with different input resolutions, so
    # point its dataset at the evaluation transform for this resolution.
    data_loader.dataset.transform = build_val_transform(image_size)

    class_csv = prob_csv = None
    if filename is not None:
        class_csv = filename + "_class.csv"
        prob_csv = filename + "_prob.csv"
        out_dir = os.path.dirname(class_csv)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        # Rows are appended batch by batch below. Remove tables left by a previous run so
        # that re-executing the notebook does not silently duplicate rows.
        for csv_path in (class_csv, prob_csv):
            if os.path.exists(csv_path):
                os.remove(csv_path)

    cudnn.benchmark = True
    criterion = nn.CrossEntropyLoss().to(device)

    net.eval()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    with torch.no_grad():
        with tqdm(total=len(data_loader), desc="Validate") as t:
            for images, labels in data_loader:
                images, labels = images.to(device), labels.to(device)
                output = net(images)
                loss = criterion(output, labels)
                # accuracy plus the top-5 predicted classes and their probabilities
                res, pred, prob = accuracy(output, labels, topk=(1, 5))
                acc1, acc5 = res[0], res[1]

                if filename is not None:
                    # Each row: top-5 predicted classes followed by the true label.
                    topk_classification = torch.cat((pred, labels.unsqueeze(dim=1)), dim=1)
                    # Convert to NumPy first. Building a DataFrame straight from a tensor
                    # may store tensor objects rather than numbers in some pandas versions.
                    pd.DataFrame(topk_classification.cpu().numpy()).to_csv(
                        class_csv, mode="a", header=False, index=False
                    )
                    pd.DataFrame(prob.cpu().numpy()).to_csv(
                        prob_csv, encoding="utf-8", mode="a", header=False, index=False
                    )

                n_images = images.size(0)
                losses.update(loss.item(), n_images)
                top1.update(acc1[0].item(), n_images)
                top5.update(acc5[0].item(), n_images)
                t.set_postfix(
                    {
                        "loss": losses.avg,
                        "top1": top1.avg,
                        "top5": top5.avg,
                        "img_size": images.size(2),
                    }
                )
                t.update(1)

    print(
        "Results: loss=%.5f,\t top1=%.1f,\t top5=%.1f"
        % (losses.avg, top1.avg, top5.avg)
    )
    return top1.avg

In [14]:
def individual_to_arch(population, n_blocks):
    """Decode flat NSGA-II genomes into architecture dicts.

    Each genome is laid out as ``[ks * n_blocks | e * n_blocks | d ... | r]``. The last
    gene is the resolution; everything between the expansion genes and it is depth.

    Args:
        population: iterable of sliceable genomes (lists or NumPy arrays).
        n_blocks: number of blocks, i.e. the length of the ``ks`` and ``e`` segments.

    Returns:
        A list with one ``{"ks", "e", "d", "r"}`` dict per genome, holding slices of the
        genome (same container type as the input).
    """
    def _decode(genes):
        return {
            "ks": genes[:n_blocks],
            "e": genes[n_blocks:2 * n_blocks],
            "d": genes[2 * n_blocks:-1],
            "r": genes[-1:],
        }

    return [_decode(genome) for genome in population]

In [15]:
def individual_to_ofa(model):
    """Convert a decoded architecture dict into the plain-Python form passed to OFA.

    Adds a ``'wid'`` entry set to ``None``. Converts the ``'ks'``, ``'e'``, ``'d'`` and
    ``'r'`` entries to built-in lists of Python scalars. The dict is updated in place and
    also returned. Entries that are already lists are accepted, so calling this twice on
    the same dict (e.g. when re-running a cell) no longer fails.
    """
    # insert wid
    model['wid'] = None
    for key in ('ks', 'e', 'd', 'r'):
        # np.asarray(...).tolist() works for NumPy arrays and plain sequences alike and
        # yields built-in ints instead of NumPy scalars.
        model[key] = np.asarray(model[key]).tolist()
    return model

# Generate the probability tables

In [None]:
# start measuring time
start_time = time.time()
#----------------------------
# Master switch for the (hours-long) table generation; set to False to skip it.
debug = True
#debug = False
output_dir = 'ofa2_models_output'
# Top-1 training-set accuracy (%) of each evaluated model, in evaluation order.
top1_per_model = []
if debug:
    # The CSV tables are written into this directory, so make sure it exists.
    os.makedirs(output_dir, exist_ok=True)
    # loop over every candidate sub-network that may form the ensemble
    for idx, individual in enumerate(ofa2_nsga2):
        encoding = individual.get('X')
        model = individual_to_ofa(individual_to_arch([encoding], 20)[0])
        filename = 'OFA2_model_' + str(idx).zfill(3)
        path = os.path.join(output_dir, filename)

        # evaluate on the training set and write the class/probability tables
        top1 = ensemble_evaluate_ofa_subnet(
            path,
            ofa_network,
            imagenet_data_path,
            model,
            data_loader_train,  # dataloader for the training set (sets the eval batch size)
            batch_size=4_096,   # batch size of the BN-calibration loader
            device="cuda:0" if cuda_available else "cpu",
        )
        top1_per_model.append(top1)
#----------------------------
# stop measuring time
end_time = time.time()
#----------------------------
elapsed = end_time - start_time
# Format the duration explicitly: time.strftime("%H", time.gmtime(...)) wraps after 24 h.
elapsed_hours, elapsed_rem = divmod(int(elapsed), 3600)
elapsed_minutes, elapsed_seconds = divmod(elapsed_rem, 60)
print(
    'The generation of the probability tables took',
    f'{elapsed_hours:02d}h{elapsed_minutes:02d}m{elapsed_seconds:02d}s',
    'to finish.',
)

Validate: 100%|██████████| 313/313 [08:18<00:00,  1.59s/it, loss=0.759, top1=83.3, top5=96.6, img_size=160]


Results: loss=0.75945,	 top1=83.3,	 top5=96.6


Validate: 100%|██████████| 313/313 [08:07<00:00,  1.56s/it, loss=0.751, top1=83.4, top5=96.7, img_size=160]


Results: loss=0.75109,	 top1=83.4,	 top5=96.7


Validate: 100%|██████████| 313/313 [08:02<00:00,  1.54s/it, loss=0.74, top1=83.8, top5=96.8, img_size=160] 


Results: loss=0.73967,	 top1=83.8,	 top5=96.8


Validate: 100%|██████████| 313/313 [08:31<00:00,  1.63s/it, loss=0.73, top1=83.9, top5=96.8, img_size=160] 


Results: loss=0.72951,	 top1=83.9,	 top5=96.8


Validate: 100%|██████████| 313/313 [08:47<00:00,  1.69s/it, loss=0.714, top1=84.3, top5=97, img_size=160]  


Results: loss=0.71428,	 top1=84.3,	 top5=97.0


Validate: 100%|██████████| 313/313 [08:42<00:00,  1.67s/it, loss=0.706, top1=84.5, top5=97, img_size=160]  


Results: loss=0.70575,	 top1=84.5,	 top5=97.0


Validate: 100%|██████████| 313/313 [08:44<00:00,  1.67s/it, loss=0.699, top1=84.7, top5=97.1, img_size=160]


Results: loss=0.69878,	 top1=84.7,	 top5=97.1


Validate: 100%|██████████| 313/313 [08:20<00:00,  1.60s/it, loss=0.695, top1=84.8, top5=97.1, img_size=160]


Results: loss=0.69531,	 top1=84.8,	 top5=97.1


Validate: 100%|██████████| 313/313 [08:48<00:00,  1.69s/it, loss=0.686, top1=85, top5=97.2, img_size=160]  


Results: loss=0.68590,	 top1=85.0,	 top5=97.2


Validate: 100%|██████████| 313/313 [08:10<00:00,  1.57s/it, loss=0.672, top1=85.2, top5=97.2, img_size=160]


Results: loss=0.67200,	 top1=85.2,	 top5=97.2


Validate: 100%|██████████| 313/313 [08:45<00:00,  1.68s/it, loss=0.669, top1=85.4, top5=97.3, img_size=160]


Results: loss=0.66918,	 top1=85.4,	 top5=97.3


Validate: 100%|██████████| 313/313 [08:36<00:00,  1.65s/it, loss=0.661, top1=85.6, top5=97.4, img_size=160]


Results: loss=0.66077,	 top1=85.6,	 top5=97.4


Validate: 100%|██████████| 313/313 [08:36<00:00,  1.65s/it, loss=0.656, top1=85.8, top5=97.4, img_size=160]


Results: loss=0.65649,	 top1=85.8,	 top5=97.4


Validate: 100%|██████████| 313/313 [08:35<00:00,  1.65s/it, loss=0.646, top1=86, top5=97.5, img_size=160]  


Results: loss=0.64614,	 top1=86.0,	 top5=97.5


Validate: 100%|██████████| 313/313 [08:29<00:00,  1.63s/it, loss=0.643, top1=86, top5=97.5, img_size=160]  


Results: loss=0.64345,	 top1=86.0,	 top5=97.5


# End of the notebook