In [1]:
import os, sys, time, glob, random, argparse
import numpy as np
from copy import deepcopy
import torch
import torch.nn as nn

# XAutoDL 
from xautodl.config_utils import load_config, dict2config, configure2str
from xautodl.datasets import get_datasets, get_nas_search_loaders
from xautodl.procedures import (
    prepare_seed,
    prepare_logger,
    save_checkpoint,
    copy_checkpoint,
    get_optim_scheduler,
)
from xautodl.utils import get_model_infos, obtain_accuracy
from xautodl.log_utils import AverageMeter, time_string, convert_secs2time
from xautodl.models import get_search_spaces

from custom_models import get_cell_based_tiny_net
from custom_search_cells import NAS201SearchCell as SearchCell
from xautodl.models.cell_searchs.genotypes import Structure

# NB201
from nas_201_api import NASBench201API as API

import scipy.stats as stats

2022-11-02 08:00:09.197152: I tensorflow/core/util/util.cc:169] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.


In [2]:
# Run configuration. An empty argv is parsed below, so every value comes from its default.
parser = argparse.ArgumentParser("Random search for NAS.")
_ARG_SPECS = [
    ("--data_path", dict(type=str, default='../cifar.python', help="The path to dataset")),
    ("--dataset", dict(type=str, default='cifar10', choices=["cifar10", "cifar100", "ImageNet16-120"], help="Choose between Cifar10/100 and ImageNet-16.")),
    # channels and number-of-cells
    ("--search_space_name", dict(type=str, default='nas-bench-201', help="The search space name.")),
    ("--config_path", dict(type=str, default='./MY.config', help="The path to the configuration.")),
    ("--max_nodes", dict(type=int, default=4, help="The maximum number of nodes.")),
    ("--channel", dict(type=int, default=16, help="The number of channels.")),
    ("--num_cells", dict(type=int, default=5, help="The number of cells in one stage.")),
    ("--select_num", dict(type=int, default=100, help="The number of selected architectures to evaluate.")),
    ("--track_running_stats", dict(type=int, default=0, choices=[0, 1], help="Whether use track_running_stats or not in the BN layer.")),
    # log
    ("--workers", dict(type=int, default=4, help="number of data loading workers")),
    ("--save_dir", dict(type=str, default='./op_level-arch_loop-reset_cell_params-loop5_ep1-acc_metric', help="Folder to save checkpoints and log.")),
    # previous default: '../NAS-Bench-201-v1_1-096897.pth'
    ("--arch_nas_dataset", dict(type=str, default=None, help="The path to load the architecture dataset (tiny-nas-benchmark).")),
    ("--print_freq", dict(type=int, default=200, help="print frequency (default: 200)")),
    ("--rand_seed", dict(type=int, default=None, help="manual seed")),
]
for _flag, _options in _ARG_SPECS:
    parser.add_argument(_flag, **_options)

args = parser.parse_args(args=[])
# No seed (or a negative one) requested: draw a fresh one so the run is still logged with a seed.
needs_seed = args.rand_seed is None or args.rand_seed < 0
if needs_seed:
    args.rand_seed = random.randint(1, 100000)

print(args.rand_seed)
print(args)
# Later cells read the configuration through `xargs`; it is the same object as `args`.
xargs = args

85044
Namespace(arch_nas_dataset=None, channel=16, config_path='./MY.config', data_path='../cifar.python', dataset='cifar10', max_nodes=4, num_cells=5, print_freq=200, rand_seed=85044, save_dir='./op_level-arch_loop-reset_cell_params-loop5_ep1-acc_metric', search_space_name='nas-bench-201', select_num=100, track_running_stats=0, workers=4)


In [3]:
# The search network is moved to the GPU later (DataParallel(...).cuda()), so fail fast without one.
assert torch.cuda.is_available(), "CUDA is not available."

# cuDNN stays enabled, but autotuning is off and deterministic kernels are requested.
torch.backends.cudnn.enabled = True
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Seed via xautodl's prepare_seed, cap intra-op CPU threads, then create the run logger
# (`xargs` and `args` are the same Namespace object).
prepare_seed(xargs.rand_seed)
torch.set_num_threads(xargs.workers)
logger = prepare_logger(xargs)

Main Function with logger : Logger(dir=op_level-arch_loop-reset_cell_params-loop5_ep1-acc_metric, use-tf=False, writer=None)
Arguments : -------------------------------
arch_nas_dataset : None
channel          : 16
config_path      : ./MY.config
data_path        : ../cifar.python
dataset          : cifar10
max_nodes        : 4
num_cells        : 5
print_freq       : 200
rand_seed        : 85044
save_dir         : ./op_level-arch_loop-reset_cell_params-loop5_ep1-acc_metric
search_space_name : nas-bench-201
select_num       : 100
track_running_stats : 0
workers          : 4
Python  Version  : 3.7.13 (default, Mar 29 2022, 02:18:16)  [GCC 7.5.0]
Pillow  Version  : 9.0.1
PyTorch Version  : 1.12.0
cuDNN   Version  : 8302
CUDA available   : True
CUDA GPU numbers : 2
CUDA_VISIBLE_DEVICES : None


In [4]:
# Data: the NAS search loader yields (base_inputs, base_targets, arch_inputs, arch_targets)
# batches for weight training; valid_loader yields (inputs, targets) batches used for ranking.
train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
config = load_config(xargs.config_path, {"class_num": class_num, "xshape": xshape}, logger)
search_loader, _, valid_loader = get_nas_search_loaders(
    train_data,
    valid_data,
    xargs.dataset,
    "../configs/nas-benchmark/",
    (config.batch_size, config.test_batch_size),
    xargs.workers,
)
logger.log("||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}".format(
            xargs.dataset, len(search_loader), len(valid_loader), config.batch_size))
logger.log("||||||| {:10s} ||||||| Config={:}".format(xargs.dataset, config))

# Weight-sharing search model over the NAS-Bench-201 cell space (BN without affine params).
search_space = get_search_spaces("cell", xargs.search_space_name)
model_config = dict2config(
    {
        "name": "RANDOM",
        "C": xargs.channel,
        "N": xargs.num_cells,
        "max_nodes": xargs.max_nodes,
        "num_classes": class_num,
        "space": search_space,
        "affine": False,
        "track_running_stats": bool(xargs.track_running_stats),
    },
    None,
)
search_model = get_cell_based_tiny_net(model_config)

w_optimizer, w_scheduler, criterion = get_optim_scheduler(search_model.parameters(), config)

logger.log("w-optimizer : {:}".format(w_optimizer))
logger.log("w-scheduler : {:}".format(w_scheduler))
logger.log("criterion   : {:}".format(criterion))
# Honour --arch_nas_dataset: only build the NAS-Bench-201 API when a benchmark file is given
# (with the default of None this is exactly the previous behaviour).
api = None if xargs.arch_nas_dataset is None else API(xargs.arch_nas_dataset)
logger.log("{:} create API = {:} done".format(time_string(), api))

last_info, model_base_path, model_best_path = (
    logger.path("info"),
    logger.path("model"),
    logger.path("best"),
)
network, criterion = torch.nn.DataParallel(search_model).cuda(), criterion.cuda()

if last_info.exists():  # automatically resume from previous checkpoint
    logger.log(
        "=> loading checkpoint of the last-info '{:}' start".format(last_info)
    )
    # Load into a separate name so `last_info` stays the file path for the log messages below
    # (reassigning it made the second message print the loaded dict instead of the path).
    last_info_state = torch.load(last_info)
    start_epoch = last_info_state["epoch"]
    checkpoint = torch.load(last_info_state["last_checkpoint"])
    genotypes = checkpoint["genotypes"]
    valid_accuracies = checkpoint["valid_accuracies"]
    search_model.load_state_dict(checkpoint["search_model"])
    w_scheduler.load_state_dict(checkpoint["w_scheduler"])
    w_optimizer.load_state_dict(checkpoint["w_optimizer"])
    logger.log(
        "=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(
            last_info, start_epoch
        )
    )
else:
    logger.log("=> do not find the last-info file : {:}".format(last_info))
    start_epoch, valid_accuracies, genotypes = 0, {"best": -1}, {}
# NOTE(review): the search cell rebuilds `genotypes` and never reads `start_epoch` or
# `valid_accuracies`, so a resumed checkpoint only restores weights and optimizer/scheduler state.

Files already downloaded and verified
Files already downloaded and verified
./MY.config
Configure(scheduler='cos', LR=0.025, eta_min=0.001, epochs=50, warmup=0, optim='SGD', decay=0.0005, momentum=0.9, nesterov=True, criterion='Softmax', batch_size=64, test_batch_size=512, class_num=10, xshape=(1, 3, 32, 32))
||||||| cifar10    ||||||| Search-Loader-Num=391, Valid-Loader-Num=49, batch size=64
||||||| cifar10    ||||||| Config=Configure(scheduler='cos', LR=0.025, eta_min=0.001, epochs=50, warmup=0, optim='SGD', decay=0.0005, momentum=0.9, nesterov=True, criterion='Softmax', batch_size=64, test_batch_size=512, class_num=10, xshape=(1, 3, 32, 32))
w-optimizer : SGD (
Parameter Group 0
    dampening: 0
    foreach: None
    initial_lr: 0.025
    lr: 0.025
    maximize: False
    momentum: 0.9
    nesterov: True
    weight_decay: 0.0005
)
w-scheduler : CosineAnnealingLR(warmup=0, max-epoch=50, current::epoch=0, iter=0.00, type=cosine, T-max=50, eta-min=0.001)
criterion   : CrossEntropyLoss(

In [5]:
def acc_confidence_robustness_metrics(network, inputs, targets):
    """Top-1 accuracy (in percent) of ``network`` on a single batch.

    Despite the name, only accuracy is currently computed; the confidence,
    sensitivity and robustness variants remain commented out after this function.

    The network is switched to train mode (as the search loop also does) and run
    under ``torch.no_grad()``, so no autograd graph is recorded.
    NOTE(review): in train mode, BatchNorm layers that track running statistics
    will still update them during this call.

    Args:
        network: module whose forward returns a 2-tuple with the logits second.
        inputs: one input batch.
        targets: class-index targets for ``inputs``.

    Returns:
        float: top-1 accuracy in percent.
    """
    with torch.no_grad():
        # accuracy
        network.train()
        _, logits = network(inputs)
        # Only top-1 is used. Requesting top-5 as well (as before) is wasted work and
        # presumably fails when there are fewer than 5 classes (logits.topk(5) on < 5 columns).
        (val_top1,) = obtain_accuracy(logits.data, targets.data, topk=(1,))
        return val_top1.item()
        
#         # confidence
#         prob = torch.nn.functional.softmax(logits, dim=1)
#         one_hot_idx = torch.nn.functional.one_hot(targets)
#         confidence = (prob[one_hot_idx==1].sum()) / inputs.size(0) * 100 # in percent
        
#         # sensitivity
#         _, noisy_logits = network(inputs + torch.randn_like(inputs)*0.1)
#         kl_loss = torch.nn.KLDivLoss(reduction="batchmean")
#         sensitivity = kl_loss(torch.nn.functional.log_softmax(noisy_logits, dim=1), torch.nn.functional.softmax(logits, dim=1))
        
#         # robustness
#         original_weights = deepcopy(network.state_dict())
#         for m in network.modules():
#             if isinstance(m, SearchCell):
#                 for p in m.parameters():
#                     p.add_(torch.randn_like(p) * p.std()*0.3)
            
#         _, noisy_logits = network(inputs)
#         kl_loss = torch.nn.KLDivLoss(reduction="batchmean")
#         robustness = -kl_loss(torch.nn.functional.log_softmax(noisy_logits, dim=1), torch.nn.functional.softmax(logits, dim=1))
#         network.load_state_dict(original_weights)
                
#         return acc.item(), confidence.item(), sensitivity.item(), robustness.item()
    
def step_sim_metric(network, criterion, inputs, targets, lr=0.025, num_small_steps=3):
    """Compare one large SGD step with several proportionally smaller ones.

    Starting from the same weights and using the same ``(inputs, targets)`` batch,
    the network takes (a) one SGD step with learning rate ``lr`` and (b)
    ``num_small_steps`` SGD steps with learning rate ``lr / num_small_steps``.
    For every 4-D ``weight`` tensor in the state dict that step (a) changed, the
    two weight updates are compared per output channel with cosine similarity.
    The result is the mean, over those tensors, of the per-channel mean, in percent.

    The network's state dict (parameters and buffers) is restored before
    returning *and* before raising.

    Args:
        network: module whose forward returns a 2-tuple with the logits second.
        criterion: loss callable ``criterion(logits, targets)``.
        inputs, targets: one batch, already on the network's device.
        lr: learning rate of the single large step (default 0.025, as before).
        num_small_steps: number of small steps, each using ``lr / num_small_steps``
            (default 3, as before).

    Returns:
        numpy float: mean per-channel cosine similarity in percent.

    Raises:
        RuntimeError: if the large step changed no 4-D ``weight`` tensor.
    """
    original_dict = deepcopy(network.state_dict())
    network.train()
    try:
        # (a) single large step
        optim_large_step = torch.optim.SGD(network.parameters(), lr=lr)
        optim_large_step.zero_grad()
        _, logits = network(inputs)
        criterion(logits, targets).backward()
        optim_large_step.step()
        large_step_dict = deepcopy(network.state_dict())

        # (b) multiple small steps from the same starting point
        network.load_state_dict(original_dict)
        optim_small_step = torch.optim.SGD(network.parameters(), lr=lr / num_small_steps)
        for _ in range(num_small_steps):
            optim_small_step.zero_grad()
            _, logits = network(inputs)
            criterion(logits, targets).backward()
            optim_small_step.step()
        small_step_dict = deepcopy(network.state_dict())
    finally:
        # resume: always put the original weights/buffers back, even if something above raised
        network.load_state_dict(original_dict)

    scores = []
    for key, large_value in large_step_dict.items():
        original_value = original_dict[key]
        # Only conv kernels (4-D tensors named "weight") that the large step actually moved.
        if 'weight' not in key or original_value.dim() != 4:
            continue
        if not (original_value != large_value).any():
            continue
        out_channels = original_value.size(0)
        large_step = (large_value - original_value).view(out_channels, -1)
        small_step = (small_step_dict[key] - original_value).view(out_channels, -1)
        score = torch.nn.functional.cosine_similarity(large_step, small_step, dim=1)
        scores.append(score.mean().item() * 100)  # in percent

    if not scores:
        raise RuntimeError("step_sim_metric: the large step changed no 4-D weight tensor")
    return np.mean(scores)

In [None]:
# start training
start_time, search_time, epoch_time, total_epoch = (
    time.time(),
    AverageMeter(),
    AverageMeter(),
    config.epochs + config.warmup,
)

################# initialize
cells = []
for m in network.modules():
    if isinstance(m, SearchCell):
        cells.append(m)
num_cells = len(cells)
print("total number of nodes:{}".format(num_cells*xargs.max_nodes))
        
op_names = deepcopy(cells[0].op_names)
op_names_wo_none = deepcopy(op_names)
if "none" in op_names_wo_none:
    op_names_wo_none.remove("none")

genotypes = []
for i in range(1, xargs.max_nodes):
    xlist = []
    for j in range(i):
        node_str = "{:}<-{:}".format(i, j)
        if i-j==1:
            op_name = "skip_connect"
        else:
            op_name = "none"
        xlist.append((op_name, j))
    genotypes.append(tuple(xlist))
init_arch = Structure(genotypes)

for c in cells:
    c.arch_cache = init_arch

### gen possible connections of a target node
possible_connections = {}
for target_node_idx in range(1,xargs.max_nodes):
    possible_connections[target_node_idx] = list()
    xlists = []
    for src_node in range(target_node_idx):
        node_str = "{:}<-{:}".format(target_node_idx, src_node)
        # select possible ops
#         if target_node_idx - src_node == 1:
#             op_names_tmp = op_names_wo_none
#         else:
#             op_names_tmp = op_names
        op_names_tmp = op_names
            
        if len(xlists) == 0: # initial iteration
            for op_name in op_names_tmp:
                xlists.append([(op_name, src_node)])
        else:
            new_xlists = []
            for op_name in op_names_tmp:
                for xlist in xlists:
                    new_xlist = deepcopy(xlist)
                    new_xlist.append((op_name, src_node))
                    new_xlists.append(new_xlist)
            xlists = new_xlists
    for xlist in xlists:
        selected_ops = []
        for l in xlist:
            selected_ops.append(l[0])
        if sum(np.array(selected_ops) == "none") == len(selected_ops):
            continue
        possible_connections[target_node_idx].append(tuple(xlist))
    print("target_node:{}".format(target_node_idx), len(possible_connections[target_node_idx]))
        
### train while generating random architectures by mutating connections of a target node

for arch_loop in range(3):
    for target_cell_idx in range(num_cells):
        for cell_loop in range(1):
#             network.module.classifier.reset_parameters()
            target_cell = cells[target_cell_idx]
            print("\n\n Searching with a cell #{}".format(target_cell_idx))
            for target_node_idx in range(1,xargs.max_nodes):
                current_genotypes,_ = target_cell.arch_cache.tolist(None)
                print("\nCurrent target cell:{} / current target node:{}".format(target_cell_idx, target_node_idx))
                ####
                for src_node_idx in range(target_node_idx):
                    node_str = "{:}<-{:}".format(target_node_idx, src_node_idx)
                    for m in target_cell.edges[node_str].modules():
                        if hasattr(m, 'reset_parameters'):
                            m.reset_parameters()
                ####
                ## training
                for ep in range(2):
                    data_time, batch_time = AverageMeter(), AverageMeter()
                    base_losses, base_top1, base_top5 = AverageMeter(), AverageMeter(), AverageMeter()
                    network.train()
                    end = time.time()
                    print_freq = 200
                    for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(search_loader):
                        ######### random generation
                        genotypes = deepcopy(current_genotypes)
                        connection = random.choice(possible_connections[target_node_idx])
                        genotypes[target_node_idx-1] = connection
                        arch = Structure(genotypes)
                        target_cell.arch_cache = arch

                        ######### forward/backward/optim
                        base_targets = base_targets.cuda(non_blocking=True)
                        arch_targets = arch_targets.cuda(non_blocking=True)
                        # measure data loading time
                        data_time.update(time.time() - end)
                        w_optimizer.zero_grad()
                        _, logits = network(base_inputs)
                        base_loss = criterion(logits, base_targets)
                        base_loss.backward()
                        nn.utils.clip_grad_norm_(network.parameters(), 5)
                        w_optimizer.step()

                        ######### logging
                        base_prec1, base_prec5 = obtain_accuracy(logits.data, base_targets.data, topk=(1, 5))
                        base_losses.update(base_loss.item(), base_inputs.size(0))
                        base_top1.update(base_prec1.item(), base_inputs.size(0))
                        base_top5.update(base_prec5.item(), base_inputs.size(0))
                        batch_time.update(time.time() - end)
                        end = time.time()
                        if step % print_freq == 0 or step + 1 == len(search_loader):
                            Sstr = ("*Train* "+ time_string()+" Ep:{:} [{:03d}/{:03d}]".format(ep, step, len(search_loader)))
                            Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format(batch_time=batch_time, data_time=data_time)
                            Wstr = "Base [Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]".format(loss=base_losses, top1=base_top1, top5=base_top5)
                            logger.log(Sstr + " " + Tstr + " " + Wstr)

                    logger.log("Ep:{:} ends : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%".format(ep, base_losses.avg, base_top1.avg, base_top5.avg))
                ## evaluation
                network.train()
                archs, metric_accs = [], []
                loader_iter = iter(valid_loader)
                for connection in possible_connections[target_node_idx]:
                    ###### traverse over possible archs
                    genotypes = deepcopy(current_genotypes)
                    genotypes[target_node_idx-1] = connection
                    arch = Structure(genotypes)
                    target_cell.arch_cache = arch
                    ###### measure metrics
                    try:
                        inputs, targets = next(loader_iter)
                    except:
                        loader_iter = iter(valid_loader)
                        inputs, targets = next(loader_iter)
                    inputs, targets = inputs.cuda(non_blocking=True), targets.cuda(non_blocking=True)
                    valid_acc = acc_confidence_robustness_metrics(network, inputs, targets)
                    archs.append(arch)
                    metric_accs.append(valid_acc)
                rank_accs = stats.rankdata(metric_accs)
                rank_agg = rank_accs
#                 l = len(rank_accs)
#                 rank_agg = np.log(rank_accs/l)+np.log(rank_confidences/l)+np.log(rank_sensitivities/l)+np.log(rank_robustnesses/l)+np.log(rank_step_sims/l)
    #             rank_agg = np.log(rank_accs/l)+np.log(rank_confidences/l)+np.log(rank_sensitivities/l)+np.log(rank_step_sims/l)
                best_idx = np.argmax(rank_agg)
                best_arch, best_acc= archs[best_idx], metric_accs[best_idx]
                logger.log("Found best op for target cell:{} / target node:{}".format(target_cell_idx, target_node_idx))
                logger.log(": {:} with accuracy={:.2f}%".format(best_arch, best_acc))
                target_cell.arch_cache = best_arch
            
best_archs = []
for c in cells:
    best_archs.append(c.arch_cache)
    
torch.save({"model":search_model.state_dict(), "best_archs":best_archs}, os.path.join(xargs.save_dir, "output.pth"))

for m in search_model.modules():
    if isinstance(m, SearchCell):
        logger.log(m.arch_cache)

logger.close()

total number of nodes:60
target_node:1 4
target_node:2 24
target_node:3 124


 Searching with a cell #0

Current target cell:0 / current target node:1
*Train* [2022-11-02 08:00:16] Ep:0 [000/391] Time 2.58 (2.58) Data 0.09 (0.09) Base [Loss 2.323 (2.323)  Prec@1 10.94 (10.94) Prec@5 54.69 (54.69)]
*Train* [2022-11-02 08:00:38] Ep:0 [200/391] Time 0.20 (0.12) Data 0.00 (0.00) Base [Loss 1.777 (1.939)  Prec@1 31.25 (26.83) Prec@5 89.06 (80.86)]
*Train* [2022-11-02 08:01:01] Ep:0 [390/391] Time 0.12 (0.12) Data 0.00 (0.00) Base [Loss 1.616 (1.835)  Prec@1 42.50 (30.93) Prec@5 87.50 (83.97)]
Ep:0 ends : loss=1.84, accuracy@1=30.93%, accuracy@5=83.97%
*Train* [2022-11-02 08:01:01] Ep:1 [000/391] Time 0.21 (0.21) Data 0.12 (0.12) Base [Loss 1.598 (1.598)  Prec@1 37.50 (37.50) Prec@5 92.19 (92.19)]
*Train* [2022-11-02 08:01:21] Ep:1 [200/391] Time 0.07 (0.10) Data 0.00 (0.00) Base [Loss 1.547 (1.594)  Prec@1 50.00 (41.04) Prec@5 93.75 (90.16)]
*Train* [2022-11-02 08:01:38] Ep:1 [390/391] Time

*Train* [2022-11-02 08:10:05] Ep:1 [200/391] Time 0.07 (0.11) Data 0.00 (0.00) Base [Loss 1.200 (0.898)  Prec@1 57.81 (68.26) Prec@5 93.75 (97.61)]
*Train* [2022-11-02 08:10:28] Ep:1 [390/391] Time 0.13 (0.12) Data 0.00 (0.00) Base [Loss 1.351 (0.893)  Prec@1 57.50 (68.64) Prec@5 95.00 (97.51)]
Ep:1 ends : loss=0.89, accuracy@1=68.64%, accuracy@5=97.51%
Found best op for target cell:2 / target node:1
: Structure(4 nodes with |nor_conv_3x3~0|+|none~0|skip_connect~1|+|none~0|none~1|skip_connect~2|) with accuracy=68.55%

Current target cell:2 / current target node:2
*Train* [2022-11-02 08:10:29] Ep:0 [000/391] Time 0.36 (0.36) Data 0.13 (0.13) Base [Loss 0.927 (0.927)  Prec@1 65.62 (65.62) Prec@5 100.00 (100.00)]
*Train* [2022-11-02 08:10:54] Ep:0 [200/391] Time 0.07 (0.13) Data 0.00 (0.00) Base [Loss 0.927 (0.961)  Prec@1 67.19 (66.39) Prec@5 96.88 (97.11)]
*Train* [2022-11-02 08:11:17] Ep:0 [390/391] Time 0.07 (0.12) Data 0.00 (0.00) Base [Loss 0.750 (0.921)  Prec@1 72.50 (67.57) Prec@5

*Train* [2022-11-02 08:20:04] Ep:0 [390/391] Time 0.07 (0.10) Data 0.00 (0.00) Base [Loss 0.894 (0.782)  Prec@1 67.50 (72.77) Prec@5 100.00 (97.99)]
Ep:0 ends : loss=0.78, accuracy@1=72.77%, accuracy@5=97.99%
*Train* [2022-11-02 08:20:04] Ep:1 [000/391] Time 0.22 (0.22) Data 0.14 (0.14) Base [Loss 0.829 (0.829)  Prec@1 67.19 (67.19) Prec@5 98.44 (98.44)]
*Train* [2022-11-02 08:20:22] Ep:1 [200/391] Time 0.08 (0.09) Data 0.00 (0.00) Base [Loss 0.810 (0.725)  Prec@1 65.62 (74.57) Prec@5 96.88 (98.24)]
*Train* [2022-11-02 08:20:45] Ep:1 [390/391] Time 0.12 (0.11) Data 0.00 (0.00) Base [Loss 0.912 (0.724)  Prec@1 67.50 (74.76) Prec@5 100.00 (98.28)]
Ep:1 ends : loss=0.72, accuracy@1=74.76%, accuracy@5=98.28%
Found best op for target cell:4 / target node:2
: Structure(4 nodes with |skip_connect~0|+|none~0|skip_connect~1|+|none~0|none~1|skip_connect~2|) with accuracy=78.32%

Current target cell:4 / current target node:3
*Train* [2022-11-02 08:20:48] Ep:0 [000/391] Time 0.25 (0.25) Data 0.16 

*Train* [2022-11-02 08:29:28] Ep:0 [000/391] Time 0.27 (0.27) Data 0.17 (0.17) Base [Loss 0.589 (0.589)  Prec@1 76.56 (76.56) Prec@5 100.00 (100.00)]
*Train* [2022-11-02 08:29:50] Ep:0 [200/391] Time 0.14 (0.11) Data 0.00 (0.00) Base [Loss 0.558 (0.780)  Prec@1 81.25 (73.17) Prec@5 100.00 (97.71)]
*Train* [2022-11-02 08:30:13] Ep:0 [390/391] Time 0.07 (0.11) Data 0.00 (0.00) Base [Loss 0.902 (0.749)  Prec@1 67.50 (74.00) Prec@5 100.00 (98.10)]
Ep:0 ends : loss=0.75, accuracy@1=74.00%, accuracy@5=98.10%
*Train* [2022-11-02 08:30:13] Ep:1 [000/391] Time 0.21 (0.21) Data 0.13 (0.13) Base [Loss 0.678 (0.678)  Prec@1 79.69 (79.69) Prec@5 98.44 (98.44)]
*Train* [2022-11-02 08:30:33] Ep:1 [200/391] Time 0.21 (0.10) Data 0.00 (0.00) Base [Loss 0.974 (0.672)  Prec@1 62.50 (76.40) Prec@5 98.44 (98.70)]
*Train* [2022-11-02 08:30:58] Ep:1 [390/391] Time 0.13 (0.12) Data 0.00 (0.00) Base [Loss 0.720 (0.661)  Prec@1 75.00 (77.03) Prec@5 97.50 (98.71)]
Ep:1 ends : loss=0.66, accuracy@1=77.03%, accura

*Train* [2022-11-02 08:39:55] Ep:1 [390/391] Time 0.07 (0.09) Data 0.00 (0.00) Base [Loss 0.714 (0.612)  Prec@1 75.00 (78.90) Prec@5 97.50 (98.82)]
Ep:1 ends : loss=0.61, accuracy@1=78.90%, accuracy@5=98.82%
Found best op for target cell:8 / target node:3
: Structure(4 nodes with |skip_connect~0|+|skip_connect~0|skip_connect~1|+|skip_connect~0|nor_conv_1x1~1|nor_conv_1x1~2|) with accuracy=80.86%


 Searching with a cell #9

Current target cell:9 / current target node:1
*Train* [2022-11-02 08:40:06] Ep:0 [000/391] Time 0.26 (0.26) Data 0.17 (0.17) Base [Loss 0.511 (0.511)  Prec@1 82.81 (82.81) Prec@5 95.31 (95.31)]
*Train* [2022-11-02 08:40:27] Ep:0 [200/391] Time 0.07 (0.10) Data 0.00 (0.00) Base [Loss 0.567 (0.716)  Prec@1 82.81 (75.61) Prec@5 100.00 (98.03)]
*Train* [2022-11-02 08:40:49] Ep:0 [390/391] Time 0.08 (0.11) Data 0.00 (0.00) Base [Loss 0.816 (0.668)  Prec@1 70.00 (77.26) Prec@5 97.50 (98.27)]
Ep:0 ends : loss=0.67, accuracy@1=77.26%, accuracy@5=98.27%
*Train* [2022-11-02 0

*Train* [2022-11-02 08:50:08] Ep:0 [390/391] Time 0.12 (0.10) Data 0.00 (0.00) Base [Loss 0.540 (0.578)  Prec@1 82.50 (80.26) Prec@5 100.00 (98.79)]
Ep:0 ends : loss=0.58, accuracy@1=80.26%, accuracy@5=98.79%
*Train* [2022-11-02 08:50:08] Ep:1 [000/391] Time 0.30 (0.30) Data 0.16 (0.16) Base [Loss 0.543 (0.543)  Prec@1 81.25 (81.25) Prec@5 100.00 (100.00)]
*Train* [2022-11-02 08:50:31] Ep:1 [200/391] Time 0.22 (0.11) Data 0.00 (0.00) Base [Loss 0.479 (0.537)  Prec@1 85.94 (81.70) Prec@5 98.44 (98.94)]
*Train* [2022-11-02 08:50:52] Ep:1 [390/391] Time 0.07 (0.11) Data 0.00 (0.00) Base [Loss 0.413 (0.543)  Prec@1 90.00 (81.42) Prec@5 97.50 (99.07)]
Ep:1 ends : loss=0.54, accuracy@1=81.42%, accuracy@5=99.07%
Found best op for target cell:11 / target node:1
: Structure(4 nodes with |avg_pool_3x3~0|+|none~0|skip_connect~1|+|none~0|none~1|skip_connect~2|) with accuracy=80.86%

Current target cell:11 / current target node:2
*Train* [2022-11-02 08:50:53] Ep:0 [000/391] Time 0.22 (0.22) Data 0.

In [None]:
# import matplotlib.pyplot as plt

# plt.scatter(rank_confidences,rank_accs)
# plt.show()

# plt.scatter(rank_sensitivities,rank_accs)
# plt.show()

# plt.scatter(rank_robustnesses,rank_accs)
# plt.show()

# plt.scatter(rank_step_sims,rank_accs)
# plt.show()

# Train the searched architecture from scratch

In [None]:
# Reload the search result and point args at a "train" sub-directory for re-training.
# `xargs` is the same object as `args`, so the assignment below changes both. Resolve the
# search directory first so that re-running this cell neither nests ".../train/train" nor
# looks for output.pth inside the train directory.
# NOTE(review): this assumes the search save_dir itself is not named "train".
_normalized_save_dir = os.path.normpath(xargs.save_dir)
if os.path.basename(_normalized_save_dir) == "train":
    search_save_dir = os.path.dirname(_normalized_save_dir)
else:
    search_save_dir = xargs.save_dir
trained_output = torch.load(os.path.join(search_save_dir, "output.pth"))
print(args)
args.save_dir = os.path.join(search_save_dir, "train")
print(args)

In [None]:
# Show the configuration loaded for the search phase (a later cell reloads `config` for training).
print("{:}".format(config))

In [None]:
# Re-training setup: a fresh logger under args.save_dir, the CIFAR training config, plain
# train/test loaders, and a new network whose cells are fixed to the searched architectures.
logger = prepare_logger(args)

# cifar_train_config_path = "./MY.config"
cifar_train_config_path = "../configs/nas-benchmark/CIFAR.config"
###
train_data, test_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
config = load_config(cifar_train_config_path, {"class_num": class_num, "xshape": xshape}, logger)

train_loader = torch.utils.data.DataLoader(
    train_data,
    batch_size=config.batch_size,
    shuffle=True,
    num_workers=xargs.workers,
    pin_memory=True,
)

# Evaluation does not need shuffling; a fixed order keeps per-batch logs reproducible.
test_loader = torch.utils.data.DataLoader(
    test_data,
    batch_size=config.batch_size,
    shuffle=False,
    num_workers=xargs.workers,
    pin_memory=True,
)

logger.log("||||||| {:10s} ||||||| Train-Loader-Num={:}, Test-Loader-Num={:}, batch size={:}".format(
            xargs.dataset, len(train_loader), len(test_loader), config.batch_size))
logger.log("||||||| {:10s} ||||||| Config={:}".format(xargs.dataset, config))

search_space = get_search_spaces("cell", xargs.search_space_name)
model_config = dict2config(
    {
        "name": "RANDOM",
        "C": xargs.channel,
        "N": xargs.num_cells,
        "max_nodes": xargs.max_nodes,
        "num_classes": class_num,
        "space": search_space,
        "affine": False,
        "track_running_stats": True, # true for eval
    },
    None,
)
search_model = get_cell_based_tiny_net(model_config)

### load the searched per-cell architectures (the searched weights are not loaded)
# trained_output = torch.load(os.path.join(xargs.save_dir, "output.pth"))
# search_model.load_state_dict(trained_output['model'], strict=False)
best_archs = trained_output['best_archs']
train_cells = [m for m in search_model.modules() if isinstance(m, SearchCell)]
assert len(train_cells) == len(best_archs), "expected {:} searched archs, got {:}".format(
    len(train_cells), len(best_archs))
for cell, cell_arch in zip(train_cells, best_archs):
    cell.arch_cache = cell_arch
# Print the architectures of the NEW model (previously this printed the old search network).
for cell in train_cells:
    print(cell.arch_cache)
###

w_optimizer, w_scheduler, criterion = get_optim_scheduler(search_model.parameters(), config)

logger.log("w-optimizer : {:}".format(w_optimizer))
logger.log("w-scheduler : {:}".format(w_scheduler))
logger.log("criterion   : {:}".format(criterion))

network, criterion = torch.nn.DataParallel(search_model).cuda(), criterion.cuda()

last_info, model_base_path, model_best_path = (
    logger.path("info"),
    logger.path("model"),
    logger.path("best"),
)

start_epoch, valid_accuracies, genotypes = 0, {"best": -1}, {}

In [None]:
# def search_func_one_arch(xloader, network, criterion, scheduler, w_optimizer, epoch_str, print_freq, logger):
#     data_time, batch_time = AverageMeter(), AverageMeter()
#     base_losses, base_top1, base_top5 = AverageMeter(), AverageMeter(), AverageMeter()
#     network.train()
#     end = time.time()
#     for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(
#         xloader
#     ):
#         scheduler.update(None, 1.0 * step / len(xloader))
#         base_targets = base_targets.cuda(non_blocking=True)
#         arch_targets = arch_targets.cuda(non_blocking=True)
#         # measure data loading time
#         data_time.update(time.time() - end)

#         w_optimizer.zero_grad()
#         _, logits = network(base_inputs)
#         base_loss = criterion(logits, base_targets)
#         base_loss.backward()
#         nn.utils.clip_grad_norm_(network.parameters(), 5)
#         w_optimizer.step()
#         # record
#         base_prec1, base_prec5 = obtain_accuracy(
#             logits.data, base_targets.data, topk=(1, 5)
#         )
#         base_losses.update(base_loss.item(), base_inputs.size(0))
#         base_top1.update(base_prec1.item(), base_inputs.size(0))
#         base_top5.update(base_prec5.item(), base_inputs.size(0))

#         # measure elapsed time
#         batch_time.update(time.time() - end)
#         end = time.time()

#         if step % print_freq == 0 or step + 1 == len(xloader):
#             Sstr = (
#                 "*SEARCH* "
#                 + time_string()
#                 + " [{:}][{:03d}/{:03d}]".format(epoch_str, step, len(xloader))
#             )
#             Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format(
#                 batch_time=batch_time, data_time=data_time
#             )
#             Wstr = "Base [Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]".format(
#                 loss=base_losses, top1=base_top1, top5=base_top5
#             )
#             logger.log(Sstr + " " + Tstr + " " + Wstr)
#     return base_losses.avg, base_top1.avg, base_top5.avg

def train_func_one_arch(xloader, network, criterion, scheduler, w_optimizer, epoch_str, print_freq, logger):
    """Run one training epoch over the weights of a single architecture.

    Each step updates ``scheduler`` with the fractional progress through the
    epoch, takes one optimizer step (gradient norm clipped to 5), and records
    loss / top-1 / top-5 accuracy weighted by the batch size. A log line is
    emitted every ``print_freq`` steps and on the final step.

    Args:
        xloader: sized iterable yielding ``(inputs, targets)`` batches.
        network: model whose forward returns a 2-tuple ``(_, logits)``.
        criterion: loss callable invoked as ``criterion(logits, targets)``.
        scheduler: object exposing ``update(epoch, progress)``.
        w_optimizer: optimizer over the network weights.
        epoch_str: epoch label used in the log prefix.
        print_freq: logging interval, in steps.
        logger: object exposing ``log(message)``.

    Returns:
        ``(avg_loss, avg_top1, avg_top5)`` accumulated over the epoch.
    """
    data_time, batch_time = AverageMeter(), AverageMeter()
    loss_meter, top1_meter, top5_meter = AverageMeter(), AverageMeter(), AverageMeter()
    num_steps = len(xloader)
    network.train()
    tick = time.time()
    for step, (inputs, targets) in enumerate(xloader):
        scheduler.update(None, 1.0 * step / num_steps)
        # Only the targets are moved explicitly; inputs are passed through as-is.
        # NOTE(review): presumably the network wrapper handles input placement — confirm.
        targets = targets.cuda(non_blocking=True)
        data_time.update(time.time() - tick)  # time spent obtaining this batch

        # Single weight update with gradient-norm clipping.
        w_optimizer.zero_grad()
        _, logits = network(inputs)
        loss = criterion(logits, targets)
        loss.backward()
        nn.utils.clip_grad_norm_(network.parameters(), 5)
        w_optimizer.step()

        # Batch-size-weighted running statistics.
        prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
        batch_size = inputs.size(0)
        for meter, value in ((loss_meter, loss), (top1_meter, prec1), (top5_meter, prec5)):
            meter.update(value.item(), batch_size)

        batch_time.update(time.time() - tick)
        tick = time.time()

        is_last_step = step + 1 == num_steps
        if step % print_freq == 0 or is_last_step:
            header = "*SEARCH* " + time_string() + " [{:}][{:03d}/{:03d}]".format(epoch_str, step, num_steps)
            timing = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format(
                batch_time=batch_time, data_time=data_time
            )
            metrics = "Base [Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]".format(
                loss=loss_meter, top1=top1_meter, top5=top5_meter
            )
            logger.log(" ".join((header, timing, metrics)))
    return loss_meter.avg, top1_meter.avg, top5_meter.avg

def valid_func_one_arch(xloader, network, criterion):
    """Evaluate a single architecture on ``xloader`` without tracking gradients.

    The network is switched to eval mode. Loss and top-1 / top-5 accuracy are
    averaged over all batches, weighted by the batch size.

    Args:
        xloader: iterable yielding ``(inputs, targets)`` batches.
        network: model whose forward returns a 2-tuple ``(_, logits)``.
        criterion: loss callable invoked as ``criterion(logits, targets)``.

    Returns:
        ``(avg_loss, avg_top1, avg_top5)`` over the whole loader.
    """
    # Timing meters were previously updated here but never read or returned;
    # they have been dropped as dead bookkeeping.
    arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter()
    network.eval()
    with torch.no_grad():
        for arch_inputs, arch_targets in xloader:
            # Only the targets are moved explicitly; inputs are passed through as-is.
            # NOTE(review): presumably the network wrapper handles input placement — confirm.
            arch_targets = arch_targets.cuda(non_blocking=True)
            _, logits = network(arch_inputs)
            arch_loss = criterion(logits, arch_targets)
            arch_prec1, arch_prec5 = obtain_accuracy(
                logits.data, arch_targets.data, topk=(1, 5)
            )
            batch_size = arch_inputs.size(0)
            arch_losses.update(arch_loss.item(), batch_size)
            arch_top1.update(arch_prec1.item(), batch_size)
            arch_top5.update(arch_prec5.item(), batch_size)
    return arch_losses.avg, arch_top1.avg, arch_top5.avg

In [None]:
# Train the single architecture for `config.epochs + config.warmup` epochs.
# Each epoch does the following:
#   - train once, then evaluate;
#   - save a checkpoint plus an info file pointing at it;
#   - copy the checkpoint to `model_best_path` whenever top-1 improves.
# Relies on names defined in earlier cells: config, w_scheduler, w_optimizer,
# train_loader, test_loader, network, search_model, criterion, xargs, logger,
# genotypes, valid_accuracies, model_base_path, model_best_path, api.
start_time, search_time, epoch_time, total_epoch = (
    time.time(),
    AverageMeter(),
    AverageMeter(),
    config.epochs + config.warmup,
)
for epoch in range(0, total_epoch):
    w_scheduler.update(epoch, 0.0)
    # Rough ETA: duration of the previous epoch times the epochs remaining.
    need_time = "Time Left: {:}".format(
        convert_secs2time(epoch_time.val * (total_epoch - epoch), True)
    )
    epoch_str = "{:03d}-{:03d}".format(epoch, total_epoch)
    logger.log(
        "\n[Search the {:}-th epoch] {:}, LR={:}".format(
            epoch_str, need_time, min(w_scheduler.get_lr())
        )
    )

    # selected_arch = search_find_best(valid_loader, network, criterion, xargs.select_num)
    search_w_loss, search_w_top1, search_w_top5 = train_func_one_arch(
        train_loader,
        network,
        criterion,
        w_scheduler,
        w_optimizer,
        epoch_str,
        xargs.print_freq,
        logger,
    )
    search_time.update(time.time() - start_time)
    logger.log(
        "[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s".format(
            epoch_str, search_w_loss, search_w_top1, search_w_top5, search_time.sum
        )
    )
    # NOTE(review): evaluation (and therefore "best" checkpoint selection below)
    # uses test_loader; confirm this is intended rather than a validation split.
    valid_a_loss, valid_a_top1, valid_a_top5 = valid_func_one_arch(
        test_loader, network, criterion
    )
    logger.log(
        "[{:}] evaluate  : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%".format(
            epoch_str, valid_a_loss, valid_a_top1, valid_a_top5
        )
    )

    # check the best accuracy
    valid_accuracies[epoch] = valid_a_top1
    find_best = valid_a_top1 > valid_accuracies["best"]
    if find_best:
        valid_accuracies["best"] = valid_a_top1

    # save checkpoint
    save_path = save_checkpoint(
        {
            "epoch": epoch + 1,
            "args": deepcopy(xargs),
            "search_model": search_model.state_dict(),
            "w_optimizer": w_optimizer.state_dict(),
            "w_scheduler": w_scheduler.state_dict(),
            "genotypes": genotypes,
            "valid_accuracies": valid_accuracies,
        },
        model_base_path,
        logger,
    )
    # Use `xargs` (the parsed arguments used throughout this cell) rather than
    # the undefined/inconsistent `args`.
    last_info = save_checkpoint(
        {
            "epoch": epoch + 1,
            "args": deepcopy(xargs),
            "last_checkpoint": save_path,
        },
        logger.path("info"),
        logger,
    )
    if find_best:
        logger.log(
            "<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.".format(
                epoch_str, valid_a_top1
            )
        )
        copy_checkpoint(model_base_path, model_best_path, logger)
    if api is not None:
        # NOTE(review): `genotypes` is not updated in this loop (the per-epoch
        # selection above is commented out); this assumes an entry already
        # exists for every epoch — otherwise it raises KeyError.
        logger.log("{:}".format(api.query_by_arch(genotypes[epoch], "200")))
    # measure elapsed time
    epoch_time.update(time.time() - start_time)
    start_time = time.time()

logger.close()

### best_archs