<a href="https://colab.research.google.com/github/mehravehj/Debiased_supernet_sampling/blob/main/Macro_bench_NAS/MacroBnechNAS_output.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

data_loader

In [None]:
from random import shuffle

import torch
import torch.utils.data
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset
from torch.utils.data.sampler import SubsetRandomSampler


def _data_transforms_cifar10():
    '''
    Build the CIFAR10 preprocessing pipelines.

    The training pipeline applies a padded random crop and a random horizontal
    flip before tensor conversion and per-channel normalization; the test
    pipeline only converts to a tensor and normalizes.
    :return: training set transforms, test set transforms
    '''
    # per-channel statistics used for normalization
    channel_mean = [0.49139968, 0.48215827, 0.44653124]
    channel_std = [0.24703233, 0.24348505, 0.26158768]

    def _to_normalized_tensor():
        # a fresh pair of transform objects for each pipeline
        return [transforms.ToTensor(), transforms.Normalize(channel_mean, channel_std)]

    augmentation = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
    ]
    train_transform = transforms.Compose(augmentation + _to_normalized_tensor())
    test_transform = transforms.Compose(_to_normalized_tensor())
    return train_transform, test_transform
def validation_set_indices(num_train, valid_percent):
    '''
    Randomly split the indices 0 .. num_train-1 into a training and a validation part.
    :param num_train: number of examples to split
    :param valid_percent: what portion of training set to be used for validation
    :return: a list containing [training set indices, validation set indices]
    '''
    val_size = int(valid_percent * num_train)  # number of validation examples
    train_size = num_train - val_size  # number of training examples
    print('training size:', train_size, ', validation size:', val_size)
    indexes = list(range(num_train))  # available indices at training set
    shuffle(indexes)  # in-place shuffle driven by Python's global `random` state
    return [indexes[:train_size], indexes[train_size:]]


def data_loader(valid_percent, batch_size, num_train=0, indices=0, dataset_dir='~/Desktop/codes/multires/data/', workers=2):
    '''
    Build the CIFAR10 train / validation / test loaders.

    Validation examples come from the CIFAR10 training split but are read
    through the test-time transforms (no augmentation).
    :param valid_percent: what portion of training set to be used for validation
    :param batch_size: batch_size
    :param num_train: number of training-split examples to divide; a falsy value means the whole split
    :param indices: [training indices, validation indices] to reuse; a falsy value draws a new random split
    :param dataset_dir: dataset directory
    :param workers: number of workers
    :return: train, validation, test data loader, indices, number of classes
             (the validation loader is 0 when valid_percent is falsy)
    '''
    augmented, plain = _data_transforms_cifar10()

    def _cifar10(use_train_split, transform):
        return torchvision.datasets.CIFAR10(root=dataset_dir, train=use_train_split, download=True, transform=transform)

    trainset = _cifar10(True, augmented)
    valset = _cifar10(True, plain)  # same images as trainset, no augmentation
    testset = _cifar10(False, plain)

    num_train = num_train or len(trainset)
    # split and create indices for training and validation set unless given
    indices = indices or validation_set_indices(num_train, valid_percent)

    # settings shared by all three loaders (note: drop_last also applies to the test loader)
    common = dict(batch_size=batch_size, num_workers=workers, pin_memory=True, drop_last=True)

    train_loader = torch.utils.data.DataLoader(trainset, sampler=SubsetRandomSampler(indices[0]), **common)
    test_loader = torch.utils.data.DataLoader(testset, shuffle=False, **common)

    validation_loader = 0
    if valid_percent:  # load validation set if used
        validation_loader = torch.utils.data.DataLoader(valset, sampler=SubsetRandomSampler(indices[1]), **common)

    return train_loader, validation_loader, test_loader, indices, 10


search_space_design

In [None]:
from itertools import combinations

import torch
from torch.distributions.categorical import Categorical as Categorical


def create_search_space(num_layers=10, num_scales=3):
    '''
    Enumerate every way of placing the pooling layers.

    Each path is a tuple of num_layers flags (0 or 1). The first flag is
    always 0 and exactly num_scales - 1 of the remaining flags are 1.
    :param num_layers: length of every path
    :param num_scales: number of scales, i.e. one more than the number of 1-flags
    :return: tuple of all paths, number of paths
    '''
    num_pooling = num_scales - 1  # number of pooling layers to insert
    slots = num_layers - 1  # positions (after the first layer) that may hold a pooling flag
    paths = tuple(
        tuple([0] + [1 if slot in chosen else 0 for slot in range(slots)])
        for chosen in combinations(range(slots), num_pooling)
    )
    number_paths = len(paths)
    print('all %d paths created: ' %(number_paths))
    print(paths)
    return paths, number_paths


def sample_uniform(paths,num_paths):
    '''
    Draw one path uniformly at random from the first num_paths entries of paths.
    :param paths: indexable collection of paths
    :param num_paths: number of candidates to sample from
    :return: sampled index, sampled path
    '''
    equal_logits = torch.FloatTensor([1] * num_paths)  # identical logits -> uniform distribution
    chosen = int(Categorical(logits=equal_logits).sample().data)
    return chosen, paths[chosen]




utility_functions

In [None]:
def string_to_list(x, leng):
    '''
    Convert a string option to a list of ints.

    A comma separated string yields one int per item; a string without a
    comma is parsed as a single int and repeated leng - 1 times.
    :param x: string to parse
    :param leng: one more than the number of repetitions for a single value
    :return: list of ints
    '''
    if ',' not in x:
        return [int(x) for _ in range(leng - 1)]
    return [int(item) for item in x.split(',')]


lr_scheduler

In [None]:
import math
import torch
from torch.optim.lr_scheduler import _LRScheduler

class CosineAnnealingWarmupRestarts(_LRScheduler):
    """
    Cosine-annealing learning-rate schedule with linear warmup and restarts.

    Within a cycle the learning rate rises linearly from min_lr to the
    cycle's max_lr during the first warmup_steps steps and then follows a
    half cosine back down towards min_lr for the rest of the cycle. When a
    cycle ends, the peak for the next one is base_max_lr * gamma ** cycle.
    The learning rate already stored in the optimizer's param groups is
    overwritten with min_lr on construction.

        optimizer (Optimizer): Wrapped optimizer.
        first_cycle_steps (int): First cycle step size.
        cycle_mult(float): Cycle steps magnification. Default: 1.
        max_lr(float): First cycle's max learning rate. Default: 0.1.
        min_lr(float): Min learning rate. Default: 0.001.
        warmup_steps(int): Linear warmup step size. Default: 0.
        gamma(float): Decrease rate of max learning rate by cycle. Default: 1.
        last_epoch (int): The index of last epoch. Default: -1.
    """

    def __init__(self,
                 optimizer : torch.optim.Optimizer,
                 first_cycle_steps : int,
                 cycle_mult : float = 1.,
                 max_lr : float = 0.1,
                 min_lr : float = 0.001,
                 warmup_steps : int = 0,
                 gamma : float = 1.,
                 last_epoch : int = -1
        ):
        # warmup must be strictly shorter than a cycle, otherwise the cosine
        # phase in get_lr() would divide by zero or a negative length
        assert warmup_steps < first_cycle_steps

        self.first_cycle_steps = first_cycle_steps # first cycle step size
        self.cycle_mult = cycle_mult # cycle steps magnification
        self.base_max_lr = max_lr # first max learning rate
        self.max_lr = max_lr # max learning rate in the current cycle
        self.min_lr = min_lr # min learning rate
        self.warmup_steps = warmup_steps # warmup step size
        self.gamma = gamma # decrease rate of max learning rate by cycle

        self.cur_cycle_steps = first_cycle_steps # first cycle step size
        self.cycle = 0 # cycle count
        self.step_in_cycle = last_epoch # position inside the current cycle

        # NOTE(review): the attributes above are set before the base
        # constructor runs, presumably because it already calls step() once
        # - confirm against the installed torch version
        super(CosineAnnealingWarmupRestarts, self).__init__(optimizer, last_epoch)

        # set learning rate min_lr
        self.init_lr()

    def init_lr(self):
        """Reset every param group's lr, and base_lrs, to min_lr."""
        self.base_lrs = []
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = self.min_lr
            self.base_lrs.append(self.min_lr)

    def get_lr(self):
        """Return one learning rate per param group for the current step_in_cycle."""
        if self.step_in_cycle == -1:
            # no step taken yet
            return self.base_lrs
        elif self.step_in_cycle < self.warmup_steps:
            # linear warmup from base_lr up to max_lr
            return [(self.max_lr - base_lr)*self.step_in_cycle / self.warmup_steps + base_lr for base_lr in self.base_lrs]
        else:
            # half cosine from max_lr down to base_lr over the non-warmup part of the cycle
            return [base_lr + (self.max_lr - base_lr) \
                    * (1 + math.cos(math.pi * (self.step_in_cycle-self.warmup_steps) \
                                    / (self.cur_cycle_steps - self.warmup_steps))) / 2
                    for base_lr in self.base_lrs]

    def step(self, epoch=None):
        """
        Advance the schedule and write the new learning rates to the optimizer.

        epoch: None advances by one step; otherwise the cycle index and the
        position inside the cycle are recomputed from this absolute epoch.
        """
        if epoch is None:
            epoch = self.last_epoch + 1
            self.step_in_cycle = self.step_in_cycle + 1
            if self.step_in_cycle >= self.cur_cycle_steps:
                # cycle finished: restart and stretch only the non-warmup part by cycle_mult
                self.cycle += 1
                self.step_in_cycle = self.step_in_cycle - self.cur_cycle_steps
                self.cur_cycle_steps = int((self.cur_cycle_steps - self.warmup_steps) * self.cycle_mult) + self.warmup_steps
        else:
            if epoch >= self.first_cycle_steps:
                if self.cycle_mult == 1.:
                    # all cycles have the same length
                    self.step_in_cycle = epoch % self.first_cycle_steps
                    self.cycle = epoch // self.first_cycle_steps
                else:
                    # cycle lengths form a geometric series; n is the number of completed cycles
                    n = int(math.log((epoch / self.first_cycle_steps * (self.cycle_mult - 1) + 1), self.cycle_mult))
                    self.cycle = n
                    self.step_in_cycle = epoch - int(self.first_cycle_steps * (self.cycle_mult ** n - 1) / (self.cycle_mult - 1))
                    # NOTE(review): this scales the whole cycle (warmup included) and is not
                    # truncated to int, unlike the epoch=None branch above - confirm intended
                    self.cur_cycle_steps = self.first_cycle_steps * self.cycle_mult ** (n)
            else:
                self.cur_cycle_steps = self.first_cycle_steps
                self.step_in_cycle = epoch

        # peak learning rate of the current cycle
        self.max_lr = self.base_max_lr * (self.gamma**self.cycle)
        self.last_epoch = math.floor(epoch)
        for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
            param_group['lr'] = lr


NAS_trainer

In [None]:
import copy

import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
# from utils import *
# from utils.lr_scheduler import CosineAnnealingWarmupRestarts
# from utils.search_space_design import sample_uniform

def create_optimizer(type, net, lr, m, wd, epochs, m_lr, first_cycle_steps, cycle_mult, warmup_steps, gamma):
    '''
    Create the SGD optimizer for net together with its learning-rate scheduler.
    :param type: scheduler type, 'cosine_anneal' or 'cosine_anneal_wr'
    :param net: model whose parameters are optimized
    :param lr: learning rate (also the max learning rate of 'cosine_anneal_wr')
    :param m: momentum
    :param wd: weight decay
    :param epochs: number of training epochs ('cosine_anneal' anneals over epochs + 1 steps)
    :param m_lr: minimum learning rate
    :param first_cycle_steps: first cycle length ('cosine_anneal_wr' only)
    :param cycle_mult: cycle length multiplier ('cosine_anneal_wr' only)
    :param warmup_steps: linear warmup steps ('cosine_anneal_wr' only)
    :param gamma: per-cycle decay of the max learning rate ('cosine_anneal_wr' only)
    :return: optimizer, scheduler
    :raises ValueError: if type is not one of the supported scheduler types
    '''
    optimizer = optim.SGD(net.parameters(), lr=lr, momentum=m, weight_decay=wd)
    if type == 'cosine_anneal':
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs + 1, eta_min=m_lr)
    elif type == 'cosine_anneal_wr':
        scheduler = CosineAnnealingWarmupRestarts(optimizer,
                                                  first_cycle_steps=first_cycle_steps,
                                                  cycle_mult=cycle_mult,
                                                  max_lr=lr,
                                                  min_lr=m_lr,
                                                  warmup_steps=warmup_steps,
                                                  gamma=gamma)
    else:
        # previously this fell through and failed with an UnboundLocalError on return
        raise ValueError(
            "unknown scheduler type %r; expected 'cosine_anneal' or 'cosine_anneal_wr'" % (type,))

    return optimizer, scheduler


def train_valid(model, train_queue, optimizer, paths, validation_queue, decay, criterion=nn.CrossEntropyLoss()):
    '''
    Train for one pass over train_queue, sampling a new path for every batch.

    After each training step the same path is evaluated on one validation
    batch; the validation loader is restarted whenever it runs out.
    :param model: supernet exposing set_path(path)
    :param train_queue: training data loader
    :param optimizer: optimizer updating the model weights
    :param paths: candidate paths; one is sampled uniformly per batch
    :param validation_queue: validation data loader
    :param decay: currently unused
    :param criterion: loss function
    :return: summed training loss, summed validation loss, and the training and
             validation accuracy counters as [number correct, number seen] tensors
    '''
    # running sums over the epoch (per-batch losses are summed, not averaged)
    train_loss = 0
    validation_loss = 0
    train_accuracy = 0
    validation_accuracy = 0

    num_paths = len(paths)  # sample over exactly the paths that were passed in
    device = next(model.parameters()).device  # inputs follow the model's device
    validation_iterator = iter(validation_queue)  # validation set iterator

    for train_inputs, train_targets in train_queue:
        # sample the path trained (and validated) on this batch
        _, pool = sample_uniform(paths, num_paths)
        model.train()
        model.set_path(pool)  # setting path
        train_inputs, train_targets = train_inputs.to(device), train_targets.to(device)
        optimizer.zero_grad()
        train_outputs = model(train_inputs)
        train_minibatch_loss = criterion(train_outputs, train_targets)
        train_minibatch_loss.backward()
        optimizer.step()
        train_loss += train_minibatch_loss.item()
        train_accuracy += calculate_accuracy(train_outputs, train_targets)

        # validation
        model.eval()
        try:
            validation_inputs, validation_targets = next(validation_iterator)
        except StopIteration:
            # validation set exhausted before the training set: start over
            validation_iterator = iter(validation_queue)
            validation_inputs, validation_targets = next(validation_iterator)
        validation_inputs, validation_targets = validation_inputs.to(device), validation_targets.to(device)
        with torch.no_grad():  # no gradients are needed for the validation batch
            validation_outputs = model(validation_inputs)
            validation_minibatch_loss = criterion(validation_outputs, validation_targets)

        validation_loss += validation_minibatch_loss.item()
        validation_accuracy += calculate_accuracy(validation_outputs, validation_targets)

    return train_loss, validation_loss, train_accuracy, validation_accuracy


def validate_all(model, paths, num_paths, validation_queue):
    '''
    Evaluate the first num_paths paths on the whole validation set.
    :param model: supernet exposing set_path(path)
    :param paths: candidate paths
    :param num_paths: number of paths to evaluate
    :param validation_queue: validation data loader
    :return: accuracy per path (num_paths), accuracy per path and class (num_paths x 10);
             a class that never occurs in the validation data yields nan
    '''
    print('evaluating of all models on all paths....')
    init_acc_mat = torch.zeros((num_paths))  # overall accuracy of every path
    init_acc_mat_per_class = torch.zeros((num_paths, 10))  # per-class accuracy of every path
    device = next(model.parameters()).device  # inputs follow the model's device

    model.eval()
    with torch.no_grad():  # pure evaluation, no autograd graph needed
        for j in range(num_paths):
            model.set_path(paths[j])
            totals = 0  # becomes a [number correct, number seen] tensor
            per_class = 0  # becomes a 2 x 10 tensor: correct / seen per class

            for validation_inputs, validation_targets in validation_queue:
                validation_inputs, validation_targets = validation_inputs.to(device), validation_targets.to(device)
                validation_outputs = model(validation_inputs)
                totals += calculate_accuracy(validation_outputs, validation_targets)
                per_class += accuracy_per_class(validation_outputs, validation_targets)

            init_acc_mat[j] = totals[0] / totals[1]
            init_acc_mat_per_class[j, :] = per_class[0, :] / per_class[1, :]
    return init_acc_mat, init_acc_mat_per_class

def output_batch(model, paths, num_paths, validation_inputs_batch, train_inputs_batch):
    '''
    Record the model outputs of every path on one training and one validation batch.
    :param model: supernet exposing set_path(path); expected to produce 10 outputs per example
    :param paths: candidate paths
    :param num_paths: number of paths to evaluate
    :param validation_inputs_batch: one batch of validation inputs
    :param train_inputs_batch: one batch of training inputs
    :return: training outputs (num_paths x train batch size x 10),
             validation outputs (num_paths x validation batch size x 10), both on the CPU
    '''
    print('Calculate output matrix....')
    # one slice per path; each buffer is sized from its own batch
    train_out = torch.zeros((num_paths, train_inputs_batch.size(0), 10))
    batch_mat_val = torch.zeros((num_paths, validation_inputs_batch.size(0), 10))

    # move the two fixed batches once instead of once per path
    device = next(model.parameters()).device
    train_inputs_batch = train_inputs_batch.to(device)
    validation_inputs_batch = validation_inputs_batch.to(device)

    model.eval()
    with torch.no_grad():  # pure evaluation, no autograd graph needed
        for j in range(num_paths):
            model.set_path(paths[j])
            # write into row j; rebinding train_out here would keep only the last path
            train_out[j, ...] = model(train_inputs_batch)
            batch_mat_val[j, ...] = model(validation_inputs_batch)

    return train_out, batch_mat_val


def accuracy_per_class(logits, target):
    '''
    Count correct predictions and occurrences for each of the 10 classes.
    :param logits: model outputs, one row of class scores per example
    :param target: true class index per example
    :return: 2 x 10 tensor; row 0 holds the correct predictions per class,
             row 1 the number of examples per class
    '''
    predicted = logits.max(1)[1]  # highest-scoring class per example
    hits = predicted.eq(target)
    counts = torch.zeros(2, 10)
    for cls in range(10):
        belongs = target.eq(cls)
        counts[1, cls] = belongs.sum().item()
        counts[0, cls] = (hits * belongs).sum()
    return counts



def calculate_accuracy(logits, target):
    '''
    Count the correct predictions in a batch.
    :param logits: model outputs, one row of class scores per example
    :param target: true class index per example
    :return: tensor [number of correct predictions, batch size]
    '''
    predicted = logits.max(1)[1]  # highest-scoring class per example
    num_correct = predicted.eq(target).sum().item()
    return torch.tensor([num_correct, target.size(0)])



create_model

In [None]:
from collections import OrderedDict
from typing import List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


# Names of the candidate operations, in the order used to index a path entry.
candidate_OP = ['id', 'ir_3x3_t3', 'ir_5x5_t6']

# Factory per candidate name; each builds the operation for the given input
# channels, output channels and stride.
OPS = OrderedDict([
    ('id', lambda inp, oup, stride: Identity(inp=inp, oup=oup, stride=stride)),
    ('ir_3x3_t3', lambda inp, oup, stride: InvertedResidual(inp=inp, oup=oup, t=3, stride=stride, k=3)),
    ('ir_5x5_t6', lambda inp, oup, stride: InvertedResidual(inp=inp, oup=oup, t=6, stride=stride, k=5)),
])



class Identity(nn.Module):
    """
    Pass-through operation.

    When the stride is 1 and the channel count is unchanged the input is
    returned as is; otherwise a strided 1x1 convolution followed by a
    non-affine batch norm adapts the shape.
    """

    def __init__(self, inp, oup, stride):
        super(Identity, self).__init__()
        shape_preserved = stride == 1 and inp == oup
        self.downsample = None if shape_preserved else nn.Sequential(
            nn.Conv2d(inp, oup, kernel_size=1, stride=stride, bias=False),
            nn.BatchNorm2d(oup, affine=False),
        )

    def forward(self, x):
        if self.downsample is None:
            return x
        return self.downsample(x)



class InvertedResidual(nn.Module):
    """
    Inverted residual block: optional 1x1 expansion, depthwise k x k
    convolution and linear 1x1 projection, each followed by a non-affine
    batch norm.

    inp / oup: input / output channels.
    stride: stride of the depthwise convolution, 1 or 2.
    t: expansion ratio; the hidden width is round(inp * t). With t == 1 the
       expansion convolution is omitted.
    k: depthwise kernel size.
    activation: activation class, instantiated with inplace=True.
    use_se: stored on the module but not used by it.
    kwargs: accepted and ignored.

    The input is added to the output when inp == oup and stride == 1.
    """

    def __init__(self, inp, oup, stride, t, k=3, activation=nn.ReLU, use_se=False, **kwargs):
        super(InvertedResidual, self).__init__()
        self.stride = stride
        self.t = t
        self.k = k
        self.use_se = use_se
        assert stride in [1, 2]
        hidden_dim = round(inp * t)
        if t == 1:
            self.conv = nn.Sequential(
                # dw
                nn.Conv2d(hidden_dim, hidden_dim, k, stride, padding=k//2, groups=hidden_dim,
                              bias=False),
                nn.BatchNorm2d(hidden_dim, affine=False),
                activation(inplace=True),
                # pw-linear
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup, affine=False),
            )
        else:
            self.conv = nn.Sequential(
                # pw
                nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
                nn.BatchNorm2d(hidden_dim, affine=False),
                activation(inplace=True),
                # dw
                nn.Conv2d(hidden_dim, hidden_dim, k, stride, padding=k//2, groups=hidden_dim,
                              bias=False),
                nn.BatchNorm2d(hidden_dim, affine=False),
                activation(inplace=True),
                # pw-linear
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup, affine=False),
            )
        # a residual connection is only possible when the shape is preserved
        self.use_shortcut = inp == oup and stride == 1

    def forward(self, x):
        if self.use_shortcut:
            return self.conv(x) + x
        return self.conv(x)


class supernet(nn.Module): # contains all possible architectures
    """
    Weight-sharing supernet: a stem, a stack of layers that each hold one
    instance of every candidate operation, and a classification head.

    The first layer of every stage doubles the channel count with stride 2;
    the remaining layers of the stage keep shape. set_path() selects one
    operation per layer and must be called before forward().

    num_classes: number of outputs of the classifier.
    stages: number of layers in each stage.
    init_channels: channels produced by the stem.
    """

    def __init__(self, num_classes=10, stages=(2, 3, 3), init_channels=32):
        super(supernet, self).__init__()

        self.stem = nn.Sequential(
            nn.Conv2d(3, init_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(init_channels, affine=False),
            nn.ReLU(inplace=True),
        )

        # ModuleList (instead of a plain list) registers every candidate op, so
        # parameters(), state_dict(), cuda()/to() and train()/eval() all reach them.
        ops_layers = nn.ModuleList()

        channels = init_channels
        for stage in stages:
            for idx in range(stage):
                if idx == 0:
                    # first layer of a stage: stride 2, twice the channels
                    out_channels, stride = channels * 2, 2
                else:
                    out_channels, stride = channels, 1
                ops_layers.append(nn.ModuleList(
                    [OPS[o](channels, out_channels, stride) for o in candidate_OP]))
                channels = out_channels

        self.all_ops = ops_layers
        self.chosen_ops = None

        self.out = nn.Sequential(
            nn.Conv2d(channels, 1280, kernel_size=1, bias=False, stride=1),
            nn.BatchNorm2d(1280, affine=False),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d(1),
        )
        self.classifier = nn.Linear(1280, num_classes, bias=False)


    def forward(self, x):
        if self.chosen_ops is None:
            raise RuntimeError('set_path() must be called before forward()')
        x = self.stem(x)
        x = self.chosen_ops(x)
        x = self.out(x)
        out = self.classifier(x.view(x.size(0), -1))
        return out

    def set_path(self,path): # sequence of op indices, one per layer
        """Select operation path[l] for layer l; path itself is not modified."""
        chosen_ops = [self.all_ops[layer][op_index] for layer, op_index in enumerate(path)]
        # Bypass nn.Module.__setattr__ on purpose: the chosen ops are already
        # registered through all_ops, so registering this view as well would
        # duplicate their entries in state_dict().
        object.__setattr__(self, 'chosen_ops', nn.Sequential(*chosen_ops))

    def feature_extractor(self, x: torch.Tensor) -> List[torch.Tensor]:
        """
        Return the feature maps after each operation in chosen_ops followed by
        those after each layer of self.out. The stem output is not included.
        Assumes set_path() has been successfully called.
        """
        feature_maps: List[torch.Tensor] = []

        x = self.stem(x)

        # clone() keeps each recorded map independent of later in-place layers
        for op_module in self.chosen_ops:
            x = op_module(x)
            feature_maps.append(x.clone())

        for out_layer in self.out:
            x = out_layer(x)
            feature_maps.append(x.clone())

        return feature_maps



main

In [None]:
import argparse
import copy
import os
from datetime import datetime
from os import path

import numpy as np

import torch
import torch.nn as nn
import torch.utils.data

import torch.backends.cudnn as cudnn

import pandas as pd

# from create_model import supernet



# from utils.NAS_trainer import create_optimizer, train_valid, validate_all, output_batch
# from utils.data_loader import data_loader
# from utils.search_space_design import create_search_space
# from utils.utility_functions import string_to_list

# Command-line options of the training script. They are parsed once, here at
# module level, and the result is read through the global `args` by main()
# and save_checkpoint().
parser = argparse.ArgumentParser(description='PyTorch Resnet multi model NAS Training')
# run identification, length of training and data split
parser.add_argument('--batchsize', '-b', type=int, default=512, help='batch size')
parser.add_argument('--test_name', '-tn', type=int, default=1000, help='test name for saving model')
# NOTE(review): --seed is only stored in the checkpoint; main() draws its own seed - confirm intended
parser.add_argument('--seed', type=int, default=0, help='random seed')
parser.add_argument('--epochs', '-e', type=int, default=200, help='epochs to train')
parser.add_argument('--validation_percent', '-vp', type=float, default=0.5, help='percent of train data for validation')

# SGD optimizer settings
parser.add_argument('--learning_rate', '-lr', default=0.1, type=float, help='learning rate')
parser.add_argument('--min_learning_rate', '-mlr', default=0.0001, type=float, help='min learning rate')
parser.add_argument('--weight_momentum', '-wm', default=0.9, type=float, help='momentum')
parser.add_argument('--weight_decay', '-wd', default=0.0001, type=float, help='weight decay')

# learning-rate scheduler settings; create_optimizer() handles 'cosine_anneal' and 'cosine_anneal_wr'
parser.add_argument('--sched_type', '-st', default='cosine_anneal', type=str, help='scheduler type, cosine annealing')
parser.add_argument('--first_cycle_steps', '-fcs', type=int, default=100, help='first cycle epochs')
parser.add_argument('--cycle_mult', '-cm', default=1.0, type=float, help='Cycle steps magnification')
parser.add_argument('--warmup_steps', '-ws', type=int, default=0, help='Linear warmup step size')
parser.add_argument('--gamma', '-gm', default=1.0, type=float, help='Decrease rate of max learning rate by cycle')

# data loading
parser.add_argument('--data_dir', '-dd', type=str, default='./data/', help='dataset directory')
parser.add_argument('--workers', '-wr', type=int, default=0, help='number of workers to load data')

# --ema_decay is forwarded to train_valid(), which does not use it;
# --init_logit is not read anywhere in the visible source
parser.add_argument('--ema_decay', '-emd', default=0.9, type=float, help='exponential moving average decay')
parser.add_argument('--init_logit', '-ilog', default=1.0, type=float, help='initial logits for path probabilitites')

# parser.add_argument('--local_rank', '-lrank', type=int, default=0)
# parser.add_argument('-ngpu', type=int, default=4)

# parsed when this module/cell is executed, i.e. before main() runs
args = parser.parse_args()

def main():
    '''
    Pre-train the supernet on CIFAR10, training one uniformly sampled path per batch.

    Every 50 epochs and at the final epoch, the outputs of every path on one
    fixed validation batch are saved, followed by a training checkpoint, both
    under ./supernet_checkpoint/test_<test_name>/. The candidate paths are read
    from ./mat_dir/nas-bench-macro_cifar10_unique.csv. Requires CUDA.
    '''
    print('Supernet pre-training...')
    print('Test: ', args.test_name)
    print('-------------------')
    print(args)
    print('-------------------')
    # NOTE(review): --seed is not applied; a fresh seed is drawn on every run and
    # Python's `random` module (used for the train/validation split) is not seeded.
    seed_np = int(np.random.randint(low=0, high=9999, size=None, dtype=int))
    print('random seed is:', seed_np)
    torch.manual_seed(seed_np)
    np.random.seed(seed_np)
    cudnn.benchmark = True

    startTime = datetime.now()
    epochs = args.epochs
    decay = args.ema_decay

    # create network
    print('creating mobilenet model....', flush=True)
    net = supernet()
    net.cuda()
    print('-------------------', flush=True)
    print(net, flush=True)

    optimizer, scheduler = create_optimizer(
        args.sched_type, net, args.learning_rate, args.weight_momentum, args.weight_decay,
        epochs, args.min_learning_rate, args.first_cycle_steps, args.cycle_mult,
        args.warmup_steps, args.gamma)

    criterion = nn.CrossEntropyLoss().cuda()  # classification loss criterion

    current_epoch = 0

    ### loading data
    dataset_dir = args.data_dir if path.exists(args.data_dir) else '~/Desktop/codes/multires/data/'

    index = 0  # 0 makes data_loader draw a new random train/validation split
    train_loader, validation_loader, test_loader, indices, num_class = data_loader(args.validation_percent,
                                                                                   args.batchsize,
                                                                                   indices=index,
                                                                                   dataset_dir=dataset_dir,
                                                                                   workers=args.workers)

    ### initialize paths: one list of op indices per row of the 'arch' column
    df = pd.read_csv('./mat_dir/nas-bench-macro_cifar10_unique.csv', dtype='str')
    paths = tuple([int(x) for x in path_string] for path_string in df['arch'])
    num_paths = len(paths)  # derived from the file instead of a hard-coded count

    # per-epoch training history
    t_loss = []
    t_acc = []
    v_loss = []
    v_acc = []
    # placeholders stored in the checkpoint, one slot per epoch
    acc_mat = [0 for i in range(epochs+1)]
    acc_mat_per_class = [0 for i in range(epochs+1)]
    val_batch_output_matrix = [0 for i in range(epochs+1)]

    # Take first batch of train and validation for output matrix computation
    validation_inputs_batch, _ = next(iter(validation_loader))
    train_inputs_batch, _ = next(iter(train_loader))

    checkpoint_dir = './supernet_checkpoint/test_' + str(args.test_name) + '/'
    os.makedirs(checkpoint_dir, exist_ok=True)
    save_dir = checkpoint_dir + 'supernet_chpt_' + str(args.test_name) + '.t7'  # checkpoint file

    evaluate_all_paths = False  # full-validation-set evaluation of every path is disabled

    for epoch in range(current_epoch, args.epochs + 1):
        print('epoch ', epoch,flush=True)
        print('net learning rate: ', optimizer.param_groups[0]['lr'], flush=True)
        if (epoch % 50 == 0 or epoch == args.epochs):  # evaluate and save checkpoint every 50 epochs
            print('Calculating output for one batch...')
            # one batch output matrix calculation for train and validation
            _, val_out = output_batch(net, paths, num_paths, validation_inputs_batch, train_inputs_batch)
            output_save_dir = checkpoint_dir + 'supernet_chpt_epoch_' + str(epoch) + '.t7'
            torch.save(val_out.detach(), output_save_dir)
            if evaluate_all_paths:
                print('Calculating output for all validation set...', flush=True)
                # validate_all returns (overall accuracy, per-class accuracy)
                accuracy_mat, accuracy_mat_per_class = validate_all(net, paths, num_paths, validation_loader)
                acc_mat[epoch] = copy.deepcopy(accuracy_mat)
                acc_mat_per_class[epoch] = copy.deepcopy(accuracy_mat_per_class)
            print('Saving models and progress...', flush=True)
            # store the actual train/validation split so that it can be reused
            save_checkpoint(save_dir, net, optimizer, scheduler, epoch, t_acc, v_acc, t_loss, v_loss,
                            indices, acc_mat, acc_mat_per_class, val_batch_output_matrix)

        # train
        train_loss, validation_loss, t_accuracy, v_accuracy = train_valid(net, train_loader, optimizer, paths,
                  validation_loader, decay, criterion=criterion)

        train_acc_epoch = t_accuracy[0]/t_accuracy[1]
        valid_acc_epoch = v_accuracy[0]/v_accuracy[1]
        t_loss.append(train_loss)
        t_acc.append(train_acc_epoch)
        v_loss.append(validation_loss)
        v_acc.append(valid_acc_epoch)
        print('train  acc', train_acc_epoch, flush=True)
        print('v acc', valid_acc_epoch, flush=True)

        scheduler.step()
        print('Training time: ', datetime.now() - startTime, flush=True)





def save_checkpoint(save_dir, model, optimizer, scheduler, epoch,t_acc, v_acc, t_loss, v_loss, index, acc_mat, acc_mat_per_class, val_batch_output_matrix):
    '''
    Write the training state to save_dir with torch.save.

    The stored dictionary holds the parsed command-line arguments (global
    `args`), the training history, the accuracy matrices and the state dicts
    of the model, optimizer and scheduler.
    :param save_dir: path of the checkpoint file
    :param index: value stored under the 'indices' key
    '''
    state = dict(
        test_properties=vars(args),
        seed=args.seed,
        indices=index,
        t_loss=t_loss,
        t_acc=t_acc,
        v_loss=v_loss,
        v_acc=v_acc,
        acc_mat=acc_mat,
        acc_mat_per_class=acc_mat_per_class,
        model=model.state_dict(),
        epoch=epoch,
        weight_optimizer=optimizer.state_dict(),
        scheduler_state=scheduler.state_dict(),
        val_batch_output_matrix=val_batch_output_matrix,
    )
    torch.save(state, save_dir)


# Entry point: start training when this code runs as the main module.
if __name__ == '__main__':
  main()

