In [None]:
import os
import sys
import time
import glob
import math
import numpy as np
import torch
import utils
import logging
import argparse
import torch.nn as nn
import torch.utils
import torch.nn.functional as F
from torchvision import datasets, transforms
import torchvision.datasets as dset
import torch.backends.cudnn as cudnn
import torch.distributions.categorical as cate
import torchvision.utils as vutils
from tqdm import tqdm

from torch.autograd import Variable
from model_search import Network
from architect import Architect

# Let the process continue even if more than one OpenMP runtime gets loaded.
# NOTE(review): presumably a workaround for the Intel OpenMP "libiomp5 already
# initialized" abort (seen when torch and numpy ship separate runtimes); it can
# hide genuine runtime conflicts, so confirm it is still needed.
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"

In [None]:
# Number of output classes. The name is historical: the active loader in this
# notebook uses MNIST, which also has 10 classes.
CIFAR_CLASSES = 10


def _build_search_parser():
    """Create the command-line parser for the architecture-search run.

    Each option is declared as a ``(flag, add_argument keyword arguments)`` pair and
    registered in the listed order, so ``--help`` shows them in this order.
    """
    option_specs = (
        ('--data', dict(type=str, default='../data', help='location of the data corpus')),
        ('--batch_size', dict(type=int, default=6, help='batch size')),
        ('--batch_increase', dict(default=8, type=int,
                                  help='how much does the batch size increase after making a decision')),
        ('--learning_rate', dict(type=float, default=0.025, help='init learning rate')),
        ('--learning_rate_min', dict(type=float, default=0.001, help='min learning rate')),
        ('--momentum', dict(type=float, default=0.9, help='momentum')),
        ('--weight_decay', dict(type=float, default=3e-4, help='weight decay')),
        ('--report_freq', dict(type=float, default=50, help='report frequency')),
        ('--gpu', dict(type=int, default=0, help='gpu device id')),
        ('--init_channels', dict(type=int, default=16, help='num of init channels')),
        ('--layers', dict(type=int, default=8, help='total number of layers')),
        ('--model_path', dict(type=str, default='saved_models', help='path to save the model')),
        ('--cutout', dict(action='store_true', default=False, help='use cutout')),
        ('--cutout_length', dict(type=int, default=16, help='cutout length')),
        ('--drop_path_prob', dict(type=float, default=0.3, help='drop path probability')),
        ('--save', dict(type=str, default='EXP', help='experiment name')),
        ('--seed', dict(type=int, default=2, help='random seed')),
        ('--grad_clip', dict(type=float, default=5, help='gradient clipping')),
        ('--train_portion', dict(type=float, default=0.5, help='portion of training data')),
        ('--unrolled', dict(action='store_true', default=False,
                            help='use one-step unrolled validation loss')),
        ('--arch_learning_rate', dict(type=float, default=3e-4, help='learning rate for arch encoding')),
        ('--arch_weight_decay', dict(type=float, default=1e-3, help='weight decay for arch encoding')),
        ('--warmup_dec_epoch', dict(type=int, default=9, help='warmup decision epoch')),
        ('--decision_freq', dict(type=int, default=5, help='decision freq epoch')),
        ('--use_history', dict(action='store_true', help='use history for decision')),
        ('--history_size', dict(type=int, default=4, help='number of stored epoch scores')),
        ('--post_val', dict(action='store_true', default=False, help='validate after each decision')),
    )
    search_parser = argparse.ArgumentParser(prog='cifar')
    for flag, options in option_specs:
        search_parser.add_argument(flag, **options)
    return search_parser


parser = _build_search_parser()
# parse_known_args rather than parse_args: unrecognised argv entries are returned
# instead of aborting. Presumably this lets the cell run inside a Jupyter kernel,
# whose own command-line flags (e.g. -f <connection file>) would otherwise be rejected.
args, _ = parser.parse_known_args()

In [None]:
# NOTE: the triple-quoted string below is an earlier CIFAR-10 version of `load_data`,
# kept for reference only. As a bare string literal it is evaluated and discarded;
# it defines nothing. The active (MNIST) `load_data` follows it.
'''
def load_data(args):
    torch.cuda.empty_cache()
    train_transform, valid_transform = utils._data_transforms_cifar10(args)
    train_data = dset.CIFAR10(root=args.data, train=True, download=True, transform=train_transform)
    _, train_data, valid_data = torch.utils.data.random_split(train_data, [30000, 10000, 10000])
    test_data = dset.CIFAR10(root=args.data, train=False, download=True, transform=valid_transform)

    train_queue = torch.utils.data.DataLoader(
        train_data, batch_size=args.batch_size,
        shuffle=True,
        pin_memory=True, num_workers=2)

    valid_queue = torch.utils.data.DataLoader(
        valid_data, batch_size=args.batch_size,
        shuffle=True,
        pin_memory=True, num_workers=2)
    
    test_queue = torch.utils.data.DataLoader(
        test_data, batch_size=args.batch_size,
        shuffle=False,
        pin_memory=True, num_workers=2)
    
    return train_queue, valid_queue, test_queue
'''


def load_data(args):
    """Build the three MNIST DataLoaders used by the search.

    Images are converted to tensors and normalised with mean 0.5 / std 0.5. The MNIST
    training split is randomly partitioned 58000 / 1000 / 1000; the 58000-sample part
    is discarded, so only the two 1000-sample subsets are used:

    * ``train_queue`` - 1000 training images, shuffled (weight updates).
    * ``valid_queue`` - the official MNIST test split, shuffled (passed to ``train``
      as the source of the architecture-search minibatches).
    * ``test_queue``  - the other 1000 held-out training images, not shuffled.

    NOTE(review): the official test split drives the architecture search here, the
    reverse of the disabled CIFAR-10 loader above. Confirm this is intended before
    reporting test numbers on MNIST's test split.

    Args:
        args: namespace providing ``data`` (dataset root; downloaded if missing) and
            ``batch_size``.

    Returns:
        ``(train_queue, valid_queue, test_queue)``.
    """
    torch.cuda.empty_cache()
    to_normalized_tensor = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ])

    mnist_train = datasets.MNIST(root=args.data, download=True, train=True, transform=to_normalized_tensor)
    _discarded, train_subset, heldout_subset = torch.utils.data.random_split(
        mnist_train, [58000, 1000, 1000])
    mnist_test = datasets.MNIST(root=args.data, download=True, train=False, transform=to_normalized_tensor)

    def make_loader(dataset, shuffle):
        # All three loaders share these settings; only shuffling differs.
        return torch.utils.data.DataLoader(
            dataset, batch_size=args.batch_size, shuffle=shuffle,
            pin_memory=True, num_workers=2)

    train_queue = make_loader(train_subset, shuffle=True)
    valid_queue = make_loader(mnist_test, shuffle=True)
    test_queue = make_loader(heldout_subset, shuffle=False)
    return train_queue, valid_queue, test_queue

def train(train_queue, valid_queue, model, architect, criterion, optimizer, lr, epoch):
    """Run one search epoch over ``train_queue``.

    For each training minibatch:
      1. ``architect.step`` updates the architecture parameters using a minibatch drawn
         from ``valid_queue``.
      2. The network weights are updated on the training minibatch, with the gradient
         norm clipped to ``args.grad_clip``.
    Reads the module-level ``args`` for ``unrolled`` and ``grad_clip``.

    Args:
        train_queue: DataLoader yielding ``(input, target)`` batches for the weight step.
        valid_queue: DataLoader from which one search batch is drawn per step.
        model: the supernet; ``model(input)`` returns logits.
        architect: object whose ``step`` updates the architecture parameters.
        criterion: loss applied as ``criterion(logits, target)``.
        optimizer: optimizer over the network weights.
        lr: current weight learning rate, forwarded to ``architect.step``.
        epoch: epoch index, used only for the progress display.

    Returns:
        ``(top1_avg, loss_avg)`` as accumulated by ``utils.AverageMeter``, which is
        updated with the batch size as weight.
    """
    objs = utils.AverageMeter()
    top1 = utils.AverageMeter()
    top5 = utils.AverageMeter()

    print(('\n' + '%10s' * 4) % ('Epoch', 'Loss', 'Top1', 'Top5'))
    pbar = tqdm(enumerate(train_queue), total=len(train_queue))

    for step, (input, target) in pbar:
        model.train()
        n = input.size(0)

        # `non_blocking` replaces the old `async` keyword argument: `async` is a reserved
        # word since Python 3.7, so `.cuda(async=True)` does not even parse.
        input = input.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)

        # Get a random minibatch from the search queue with replacement.
        # NOTE: iter() creates a fresh DataLoader iterator every step; with
        # num_workers > 0 that starts new worker processes each time, which is slow.
        input_search, target_search = next(iter(valid_queue))
        input_search = input_search.cuda(non_blocking=True)
        target_search = target_search.cuda(non_blocking=True)

        # Algorithm 1. Update undetermined architecture parameters (only alpha)
        architect.step(input, target, input_search, target_search, lr, optimizer, unrolled=args.unrolled)

        # Algorithm 2. Update weights W
        optimizer.zero_grad()
        logits = model(input)
        loss = criterion(logits, target)

        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
        optimizer.step()

        prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
        objs.update(loss.item(), n)
        top1.update(prec1.item(), n)
        top5.update(prec5.item(), n)

        msg = ('%10s' * 4) % (str(epoch),
                              np.round(objs.avg, 4),
                              np.round(top1.avg, 4),
                              np.round(top5.avg, 4))
        pbar.set_description(msg)

    return top1.avg, objs.avg

def inference(valid_queue, model, criterion, epoch):
    """Evaluate ``model`` on every batch of ``valid_queue`` and report running metrics.

    The model is put into eval mode and left there; ``train`` switches it back with
    ``model.train()`` at every step. Gradients are disabled inside this function, so
    callers need not wrap it in ``torch.no_grad()`` (doing so is harmless).

    Args:
        valid_queue: DataLoader yielding ``(input, target)`` batches.
        model: network returning logits for ``model(input)``.
        criterion: loss applied as ``criterion(logits, target)``.
        epoch: epoch index, used only for the progress display.

    Returns:
        ``(top1_avg, loss_avg)`` as accumulated by ``utils.AverageMeter``.
    """
    objs = utils.AverageMeter()
    top1 = utils.AverageMeter()
    top5 = utils.AverageMeter()

    print(('\n' + '%10s' * 4) % ('Epoch', 'Loss', 'Top1', 'Top5'))
    pbar = tqdm(enumerate(valid_queue), total=len(valid_queue))

    model.eval()
    with torch.no_grad():
        # Iterate the tqdm wrapper itself: previously the loop re-enumerated
        # valid_queue directly, so the progress bar never advanced.
        for step, (input, target) in pbar:
            # `non_blocking` replaces `async`, a reserved word since Python 3.7.
            input = input.cuda(non_blocking=True)
            target = target.cuda(non_blocking=True)

            logits = model(input)
            loss = criterion(logits, target)

            prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
            n = input.size(0)
            objs.update(loss.item(), n)
            top1.update(prec1.item(), n)
            top5.update(prec5.item(), n)

            msg = ('%10s' * 4) % (str(epoch),
                                  np.round(objs.avg, 4),
                                  np.round(top1.avg, 4),
                                  np.round(top5.avg, 4))
            pbar.set_description(msg)

    return top1.avg, objs.avg

def edge_decision(type, alphas, selected_idxs, candidate_flags, probs_history, epoch, model, args):
    """Score the edges of one cell type and, on a decision epoch, fix the best one.

    Scoring (SGAS):
      * Formula 1 - importance: softmax mass of the non-'none' ops on each edge.
      * Formula 2 - entropy of the renormalised op distribution, divided by log(#ops).
      * Formula 4 (Cri.1) = normalize(importance) * normalize(1 - entropy); with
        ``args.use_history`` (Cri.2) this is multiplied by the normalised history
        term of Formula 3, giving Formula 5. In that mode ``probs_history`` is mutated
        in place: the current probs are appended and the oldest entry is dropped once
        the list exceeds ``args.history_size``.

    A decision happens when at least one edge is still a candidate, ``epoch`` has
    reached ``args.warmup_dec_epoch``, and ``epoch`` lies a multiple of
    ``args.decision_freq`` epochs after it. The candidate with the highest score is
    chosen; its strongest non-'none' op (+1 to account for the 'none' column) is
    written into ``selected_idxs``, it stops being a candidate, and its alpha is
    frozen (``requires_grad = False``). ``model.check_edges`` then returns the updated
    flags/indices.

    Args:
        type: 'normal' or 'reduce'; any other value raises ``Exception`` when a
            decision is made.
        alphas: list of per-edge architecture-parameter tensors, stacked along dim 0.
        selected_idxs: per-edge selected op indices (modified in place on a decision).
        candidate_flags: per-edge boolean "still undecided" flags.
        probs_history: list of earlier ``probs`` tensors (used with ``use_history``).
        epoch: current epoch index.
        model: supernet exposing ``check_edges``.
        args: namespace with the decision/history options.

    Returns:
        ``(decided, selected_idxs, candidate_flags)`` where ``decided`` is True iff an
        edge was fixed this call.

    NOTE(review): ``histogram_average`` and ``score_image`` are not defined anywhere in
    this notebook as shown. Unless they are provided elsewhere, every call raises
    NameError at ``score_image``. Confirm where they come from.
    """
    weights = F.softmax(torch.stack(alphas, dim=0), dim=-1).detach()
    op_weights = weights[:, 1:]  # drop column 0, the 'none' op

    # Formula 1: importance of each edge
    importance = op_weights.sum(dim=-1)

    # Formula 2: entropy of the op distribution, scaled into [0, 1] by log(#ops)
    probs = op_weights / importance[:, None]
    num_ops = probs.size(1)
    entropy = cate.Categorical(probs=probs).entropy() / math.log(num_ops)

    if args.use_history:
        # Formula 3 (SGAS Cri.2): compare with stored history, then record this epoch
        histogram_inter = histogram_average(probs_history, probs)
        probs_history.append(probs)
        if len(probs_history) > args.history_size:
            probs_history.pop(0)

    # Formula 4 (SGAS Cri.1); with history it is extended to Formula 5
    score = utils.normalize(importance) * utils.normalize(1 - entropy)
    if args.use_history:
        score = score * utils.normalize(histogram_inter)

    decided = False
    is_decision_epoch = (torch.sum(candidate_flags.int()) > 0
                         and epoch >= args.warmup_dec_epoch
                         and (epoch - args.warmup_dec_epoch) % args.decision_freq == 0)
    if is_decision_epoch:
        # Candidates are capped at +inf (score kept), decided edges forced to -inf.
        masked_score = torch.min(score, (2 * candidate_flags.float() - 1) * np.inf)
        edge = torch.argmax(masked_score)
        selected_idxs[edge] = torch.argmax(probs[edge]) + 1  # +1 skips the 'none' column

        candidate_flags[edge] = False
        alphas[edge].requires_grad = False

        if type not in ('normal', 'reduce'):
            raise Exception('Unknown Cell Type')
        candidate_flags, selected_idxs = model.check_edges(
            candidate_flags, selected_idxs, reduction=(type == 'reduce'))
        decided = True

    print(type + "_candidate_flags {}".format(candidate_flags))
    score_image(type, score, epoch)
    return decided, selected_idxs, candidate_flags




def _make_loader(dataset, batch_size, shuffle=True):
    """Wrap ``dataset`` in a DataLoader with the same settings ``load_data`` uses."""
    return torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=shuffle,
        pin_memory=True, num_workers=2)


def main(args):
    """Run the greedy architecture search end to end.

    Each epoch trains weights and architecture parameters (``train``), evaluates on
    ``test_queue`` (``inference``), then lets ``edge_decision`` possibly fix one edge
    per cell type. Whenever an edge is fixed, the train/search loaders are rebuilt
    with a batch size larger by ``args.batch_increase``.

    Args:
        args: parsed command-line namespace (see the config cell).
    """
    # Seed every RNG the run depends on (random_split in load_data, weight init,
    # DataLoader shuffling). --seed/--gpu were previously parsed but never applied.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.set_device(args.gpu)  # bare .cuda() calls below target the current device
    torch.cuda.manual_seed(args.seed)
    torch.cuda.empty_cache()

    criterion = nn.CrossEntropyLoss().cuda()
    # in_channel=1: the MNIST inputs from load_data are single-channel.
    model = Network(args.init_channels, CIFAR_CLASSES, args.layers, criterion, in_channel=1).cuda()

    optimizer = torch.optim.SGD(
        model.parameters(),
        args.learning_rate,
        momentum=args.momentum,
        weight_decay=args.weight_decay)
    train_queue, valid_queue, test_queue = load_data(args)

    # Enough epochs for one decision every `decision_freq` epochs after the warm-up
    # for `num_edges` edges (presumably the edges kept per cell), followed by
    # `post_train` extra epochs after the last scheduled decision.
    num_edges = model._steps * 2
    post_train = 5
    epochs = args.warmup_dec_epoch + args.decision_freq * (num_edges - 1) + post_train + 1
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, float(epochs), eta_min=args.learning_rate_min)
    print('num_edges:', num_edges, ' epochs:', epochs)

    architect = Architect(model, args)

    # -1 = no op selected yet; True = edge still undecided.
    model.normal_selected_idxs = torch.tensor(
        len(model.alphas_normal) * [-1], requires_grad=False, dtype=torch.int).cuda()
    model.reduce_selected_idxs = torch.tensor(
        len(model.alphas_reduce) * [-1], requires_grad=False, dtype=torch.int).cuda()
    model.normal_candidate_flags = torch.tensor(
        len(model.alphas_normal) * [True], requires_grad=False, dtype=torch.bool).cuda()
    model.reduce_candidate_flags = torch.tensor(
        len(model.alphas_reduce) * [True], requires_grad=False, dtype=torch.bool).cuda()

    count = 0  # number of epochs in which at least one edge was fixed
    normal_probs_history = []
    reduce_probs_history = []

    for epoch in range(epochs):
        # The lr actually in use this epoch. scheduler.get_lr() called outside
        # scheduler.step() does not return it on PyTorch >= 1.4.
        lr = optimizer.param_groups[0]['lr']

        # train
        train(train_queue, valid_queue, model, architect, criterion, optimizer, lr, epoch)

        # validation
        with torch.no_grad():
            inference(test_queue, model, criterion, epoch)

        # greedy decision
        saved_memory_normal, model.normal_selected_idxs, \
        model.normal_candidate_flags = edge_decision('normal',
                                                     model.alphas_normal,
                                                     model.normal_selected_idxs,
                                                     model.normal_candidate_flags,
                                                     normal_probs_history,
                                                     epoch,
                                                     model,
                                                     args)

        saved_memory_reduce, model.reduce_selected_idxs, \
        model.reduce_candidate_flags = edge_decision('reduce',
                                                     model.alphas_reduce,
                                                     model.reduce_selected_idxs,
                                                     model.reduce_candidate_flags,
                                                     reduce_probs_history,
                                                     epoch,
                                                     model,
                                                     args)

        if saved_memory_normal or saved_memory_reduce:
            # Rebuild the loaders over the SAME datasets with a larger batch size. The
            # previous code referenced train_data/indices/split/num_train, which are not
            # defined in this notebook, and crashed with NameError at the first decision.
            train_set, valid_set = train_queue.dataset, valid_queue.dataset
            del train_queue, valid_queue
            torch.cuda.empty_cache()

            count += 1
            new_batch_size = args.batch_size + args.batch_increase * count
            train_queue = _make_loader(train_set, new_batch_size)
            valid_queue = _make_loader(valid_set, new_batch_size)

        # Since PyTorch 1.1 the scheduler must step after the optimizer; stepping at the
        # start of the epoch skipped the first value of the cosine schedule.
        scheduler.step()

In [None]:
# Entry point: run the search with the arguments parsed in the config cell.
# Requires a CUDA device (the model, loss and every batch are moved with .cuda()).
main(args)