Note: this code was adapted from the PyTorch macro-search implementation of ENAS; credit to https://github.com/MengTianjian/enas-pytorch

In [None]:
import os
import sys


import torch
import torch.nn as nn
import numpy as np
import time
from torch.autograd import Variable
from torchvision import datasets, transforms
from torch.utils.data.dataset import Subset
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.distributions.categorical import Categorical
import torch.nn.functional as F

Visualization

In [None]:
import matplotlib.pyplot as plt
% matplotlib inline

Arguments

In [None]:
# Global hyper-parameter table read by the training functions in this notebook.
# Comments describe only the uses visible in this file; keys without a comment
# are presumably consumed by code outside this section (TODO confirm).
args = {
    'description':'ENAS',
    'search_for':'macro',
    'data_path':'export/data/',
    # Checkpoints are written to <output_filename> + '.pth.tar' by train_enas.
    'output_filename':'ENAS',
    'resume':'',
    'batch_size':128,
    # Upper bound (exclusive) of the epoch loops in train_enas / train_fixed.
    'num_epochs':30,
    # train_shared_cnn prints a progress line every `log_every` steps.
    'log_every':50,
    # train_enas calls evaluate_model on epochs divisible by this value.
    'eval_every_epochs':1,
    'seed':69,
    'cutout':0,
    'fixed_arc':False,
    'child_num_layers':12,
    'child_out_filters':36,
    # Max gradient norm for clip_grad_norm_; used for the shared CNN and
    # (despite the name) also for the controller in train_controller.
    'child_grad_bound':5.0,
    # weight_decay of the SGD optimizer built in train_fixed.
    'child_l2_reg':0.00025,
    # NOTE(review): ENASLayer defines six branches while FixedLayer only
    # accepts layer types 0-2 -- confirm this value matches both.
    'child_num_branches':3,
    'child_keep_prob':0.9,
    # lr / eta_min / T_max of the SGD + CosineAnnealingLR pair in train_fixed.
    'child_lr_max':0.05,
    'child_lr_min':0.0005,
    'child_lr_T':10,
    'controller_lstm_size':64,
    'controller_lstm_num_layers':1,
    # Weight of the sampled-architecture entropy added to the reward.
    'controller_entropy_weight':0.0001,
    'controller_train_every':1,
    # train_controller runs controller_train_steps * controller_num_aggregate
    # iterations and steps the optimizer once every controller_num_aggregate.
    'controller_num_aggregate':20,
    'controller_train_steps':50,
    'controller_lr':0.001,
    'controller_tanh_constant':1.5,
    'controller_op_tanh_reduce':2.5,
    'controller_skip_target':0.4,
    # Multiplier of controller.skip_penaltys in the controller loss.
    'controller_skip_weight':0.8,
    # Decay of the moving-average reward baseline in train_controller.
    'controller_bl_dec':0.99
}

Helper Functions


In [None]:
class Cutout(object):
    """Randomly zero out one rectangular patch of an image.

    A centre pixel is drawn uniformly at random; the window reaching
    ``length`` pixels from that centre in every direction (clipped at the
    image border) is set to zero. The masked region is therefore at most
    ``2 * length`` pixels on a side, not ``length``.

    Args:
        length (int): Distance in pixels from the sampled centre to each
            edge of the masked window.
        p (float): The probability of cutout being applied.
    """
    def __init__(self, length, p=0.5):
        self.length = length
        self.p = p

    def __call__(self, img):
        """
        Args:
            img (Tensor): Tensor image of size (C, H, W).
        Returns:
            Tensor: The input itself when cutout is skipped, otherwise a new
            tensor with a single window zeroed in every channel.
        """
        # The order of the random draws (rand, then y, then x) is kept so
        # that seeded runs reproduce the same patches.
        if np.random.rand() >= self.p:
            return img

        h = img.size(1)
        w = img.size(2)

        y = np.random.randint(h)
        x = np.random.randint(w)

        y1 = int(np.clip(y - self.length, 0, h))
        y2 = int(np.clip(y + self.length, 0, h))
        x1 = int(np.clip(x - self.length, 0, w))
        x2 = int(np.clip(x + self.length, 0, w))

        # Build the float32 mask on the image's device so the multiplication
        # also works for tensors that do not live on the CPU.
        mask = torch.ones((h, w), dtype=torch.float32, device=img.device)
        mask[y1:y2, x1:x2] = 0.

        return img * mask.expand_as(img)

In [None]:
# class Logger(object):
#     def __init__(self, filename):
#         self.terminal = sys.stdout
#         self.log = open(filename, 'w')

#     def write(self, message):
#         self.terminal.write(message)
#         self.log.write(message)
#         self.log.flush()

#     def flush(self):
#         #this flush method is needed for python 3 compatibility.
#         #this handles the flush command by doing nothing.
#         #you might want to specify some extra behavior here.
#         pass

class AverageMeter(object):
    """Keeps the most recent value of a metric and its running weighted mean."""
    def __init__(self):
        self.reset()

    def reset(self):
        """Set every tracked statistic back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, counted ``n`` times in the mean."""
        self.val = val
        self.count = self.count + n
        self.sum = self.sum + val * n
        self.avg = self.sum / self.count

Controller

In [None]:
class Controller(nn.Module):
    '''
    LSTM controller that samples a macro-search architecture.

    Each call to ``forward`` samples, for every one of the ``num_layers``
    child layers, one branch id and (from the second layer on) a binary skip
    decision for each earlier layer. Nothing is returned; the results are
    stored on the module:

        sample_arc       dict mapping str(layer_id) to [branch_id] for layer 0
                         and [branch_id, skip] for later layers
        sample_entropy   sum of the entropies of all sampled distributions
        sample_log_prob  sum of the log-probabilities of all sampled decisions
        skip_count       total number of sampled skip connections
        skip_penaltys    mean over layers of the skip penalty term

    Only ``search_whole_channels=True`` is implemented.

    https://github.com/melodyguan/enas/blob/master/src/cifar10/general_controller.py
    '''
    def __init__(self,
                 search_for="macro",
                 search_whole_channels=True,
                 num_layers=12,
                 num_branches=6,
                 out_filters=36,
                 lstm_size=32,
                 lstm_num_layers=2,
                 tanh_constant=1.5,
                 temperature=None,
                 skip_target=0.4,
                 skip_weight=0.8):
        super(Controller, self).__init__()

        self.search_for = search_for
        self.search_whole_channels = search_whole_channels
        self.num_layers = num_layers
        self.num_branches = num_branches
        self.out_filters = out_filters

        self.lstm_size = lstm_size
        self.lstm_num_layers = lstm_num_layers
        self.tanh_constant = tanh_constant
        self.temperature = temperature

        self.skip_target = skip_target
        self.skip_weight = skip_weight

        self._create_params()

    def _create_params(self):
        '''
        Create the LSTM, the embeddings, the branch softmax and the attention
        weights used for skip connections.

        Raises:
            NotImplementedError: if ``search_whole_channels`` is False.

        https://github.com/melodyguan/enas/blob/master/src/cifar10/general_controller.py#L83
        '''
        self.w_lstm = nn.LSTM(input_size=self.lstm_size,
                              hidden_size=self.lstm_size,
                              num_layers=self.lstm_num_layers)

        self.g_emb = nn.Embedding(1, self.lstm_size)  # Learn the starting input

        if not self.search_whole_channels:
            # Raised explicitly: an assert would vanish under ``python -O``.
            raise NotImplementedError("Not implemented error: search_whole_channels = False")

        self.w_emb = nn.Embedding(self.num_branches, self.lstm_size)
        self.w_soft = nn.Linear(self.lstm_size, self.num_branches, bias=False)

        self.w_attn_1 = nn.Linear(self.lstm_size, self.lstm_size, bias=False)
        self.w_attn_2 = nn.Linear(self.lstm_size, self.lstm_size, bias=False)
        self.v_attn = nn.Linear(self.lstm_size, 1, bias=False)

        self._reset_params()

    def _reset_params(self):
        '''Initialise every weight matrix uniformly in [-0.1, 0.1].'''
        for m in self.modules():
            if isinstance(m, (nn.Linear, nn.Embedding)):
                nn.init.uniform_(m.weight, -0.1, 0.1)

        # Cover every stacked LSTM layer, not only layer 0.
        for layer in range(self.lstm_num_layers):
            nn.init.uniform_(getattr(self.w_lstm, 'weight_hh_l{}'.format(layer)), -0.1, 0.1)
            nn.init.uniform_(getattr(self.w_lstm, 'weight_ih_l{}'.format(layer)), -0.1, 0.1)

    def _squash_logit(self, logit):
        '''Apply the optional temperature and tanh scaling to raw logits.'''
        if self.temperature is not None:
            logit = logit / self.temperature
        if self.tanh_constant is not None:
            logit = self.tanh_constant * torch.tanh(logit)
        return logit

    def forward(self):
        '''
        Sample one architecture and store it, together with its entropy,
        log-probability and skip statistics, on the module.

        https://github.com/melodyguan/enas/blob/master/src/cifar10/general_controller.py#L126
        '''
        if not self.search_whole_channels:
            raise NotImplementedError("Not implemented error: search_whole_channels = False")

        # Follow the module's own device instead of assuming CUDA.
        device = self.g_emb.weight.device

        state = None  # None makes nn.LSTM start from an all-zero state

        anchors = []
        anchors_w_1 = []

        arc_seq = {}
        entropys = []
        log_probs = []
        skip_count = []
        skip_penaltys = []

        inputs = self.g_emb.weight
        skip_targets = torch.tensor([1.0 - self.skip_target, self.skip_target],
                                    device=device)

        for layer_id in range(self.num_layers):
            # Step 1: choose the operation (branch) of this layer.
            output, state = self.w_lstm(inputs.unsqueeze(0), state)
            output = output.squeeze(0)

            logit = self._squash_logit(self.w_soft(output))

            branch_id_dist = Categorical(logits=logit)
            branch_id = branch_id_dist.sample()

            arc_seq[str(layer_id)] = [branch_id]

            log_probs.append(branch_id_dist.log_prob(branch_id).view(-1))
            entropys.append(branch_id_dist.entropy().view(-1))

            # Step 2: feed the chosen branch back in. The returned state is
            # carried forward so the next layer sees this step as well.
            inputs = self.w_emb(branch_id).unsqueeze(0)
            output, state = self.w_lstm(inputs, state)
            output = output.squeeze(0)

            if layer_id > 0:
                # Attention over all earlier layers gives one (no-skip, skip)
                # logit pair per earlier layer.
                query = torch.cat(anchors_w_1, dim=0)
                query = torch.tanh(query + self.w_attn_2(output))
                query = self.v_attn(query)
                logit = self._squash_logit(torch.cat([-query, query], dim=1))

                skip_dist = Categorical(logits=logit)
                skip = skip_dist.sample().view(layer_id)

                arc_seq[str(layer_id)].append(skip)

                skip_prob = torch.sigmoid(logit)
                kl = skip_prob * torch.log(skip_prob / skip_targets)
                skip_penaltys.append(torch.sum(kl))

                log_probs.append(torch.sum(skip_dist.log_prob(skip)).view(-1))
                entropys.append(torch.sum(skip_dist.entropy()).view(-1))

                # Calculate average hidden state of all nodes that got skips
                # and use it as input for next step
                skip = skip.type(torch.float).view(1, layer_id)
                skip_count.append(torch.sum(skip))
                inputs = torch.matmul(skip, torch.cat(anchors, dim=0))
                inputs = inputs / (1.0 + torch.sum(skip))
            else:
                inputs = self.g_emb.weight

            anchors.append(output)
            anchors_w_1.append(self.w_attn_1(output))

        self.sample_arc = arc_seq

        self.sample_entropy = torch.sum(torch.cat(entropys))
        self.sample_log_prob = torch.sum(torch.cat(log_probs))

        # A single-layer network has no skip decisions; report zeros rather
        # than failing on an empty torch.stack.
        if skip_count:
            self.skip_count = torch.sum(torch.stack(skip_count))
            self.skip_penaltys = torch.mean(torch.stack(skip_penaltys))
        else:
            self.skip_count = torch.zeros((), device=device)
            self.skip_penaltys = torch.zeros((), device=device)

Child

In [None]:
class FactorizedReduction(nn.Module):
    '''
    Downsample a feature map by ``stride`` and optionally change its channel
    count. With ``stride == 1`` this is a 1x1 convolution followed by batch
    norm; otherwise two strided 1x1-conv paths, the second one looking at the
    input shifted by one pixel, are concatenated and batch-normalised.
    https://github.com/melodyguan/enas/blob/master/src/cifar10/general_child.py#L129
    '''

    def __init__(self, in_planes, out_planes, stride=2):
        super(FactorizedReduction, self).__init__()

        assert out_planes % 2 == 0, (
            "Need even number of filters when using this factorized reduction.")

        self.in_planes = in_planes
        self.out_planes = out_planes
        self.stride = stride

        if stride != 1:
            half_planes = out_planes // 2
            # Each path contributes half of the output channels.
            self.path1 = nn.Sequential(
                nn.AvgPool2d(1, stride=stride),
                nn.Conv2d(in_planes, half_planes, kernel_size=1, bias=False))
            self.path2 = nn.Sequential(
                nn.AvgPool2d(1, stride=stride),
                nn.Conv2d(in_planes, half_planes, kernel_size=1, bias=False))
            self.bn = nn.BatchNorm2d(out_planes, track_running_stats=False)
        else:
            self.fr = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, bias=False),
                nn.BatchNorm2d(out_planes, track_running_stats=False))

    def forward(self, x):
        if self.stride == 1:
            return self.fr(x)

        first = self.path1(x)

        # Zero-pad on the right and bottom, then drop the first row and
        # column, so the second path samples the diagonally shifted pixels.
        shifted = F.pad(x, pad=(0, 1, 0, 1), mode='constant', value=0.)[:, :, 1:, 1:]
        second = self.path2(shifted)

        return self.bn(torch.cat([first, second], dim=1))

class ENASLayer(nn.Module):
    '''
    One searchable layer of the shared CNN. It owns every candidate branch;
    the sampled architecture chooses which branch runs and which earlier
    layer outputs are added to its result before batch norm.
    https://github.com/melodyguan/enas/blob/master/src/cifar10/general_child.py#L245
    '''
    def __init__(self, layer_id, in_planes, out_planes):
        super(ENASLayer, self).__init__()

        self.layer_id = layer_id
        self.in_planes = in_planes
        self.out_planes = out_planes
        # To poison the search space, take out various combinations of the
        # following branches and make the matching changes to the forward
        # function below. We suggest taking out the 3x3 convolutions and one
        # of the pooling branches as an appropriate poisoning.
        # NOTE(review): FixedLayer maps layer types 0-2 to different
        # operations than branches 0-2 here -- confirm the two agree.
        self.branch_0 = ConvBranch(in_planes, out_planes, kernel_size=3)
        self.branch_1 = ConvBranch(in_planes, out_planes, kernel_size=3, separable=True)
        self.branch_2 = ConvBranch(in_planes, out_planes, kernel_size=5)
        self.branch_3 = ConvBranch(in_planes, out_planes, kernel_size=5, separable=True)
        self.branch_4 = PoolBranch(in_planes, out_planes, 'avg')
        self.branch_5 = PoolBranch(in_planes, out_planes, 'max')

        self.bn = nn.BatchNorm2d(out_planes, track_running_stats=False)

    def forward(self, x, prev_layers, sample_arc):
        '''
        Args:
            x: Input feature map.
            prev_layers: Outputs of the earlier layers, indexed by layer id.
            sample_arc: Sequence whose first item is the branch id and, for
                layers after the first, whose second item holds one skip
                flag per earlier layer.
        Returns:
            Batch-normalised sum of the chosen branch output and the selected
            earlier layer outputs.
        Raises:
            ValueError: if the branch id does not name an existing branch.
        '''
        layer_type = sample_arc[0]
        skip_indices = sample_arc[1] if self.layer_id > 0 else []

        # Change this based on the poisoning mentioned above.
        branches = (self.branch_0, self.branch_1, self.branch_2,
                    self.branch_3, self.branch_4, self.branch_5)
        branch_index = int(layer_type)
        if not 0 <= branch_index < len(branches):
            raise ValueError("Unknown layer_type {}".format(layer_type))
        out = branches[branch_index](x)

        for i, skip in enumerate(skip_indices):
            if skip == 1:
                # Out-of-place add: the conv branches end in a ReLU whose
                # output autograd keeps for the backward pass, so it must not
                # be modified in place.
                out = out + prev_layers[i]

        return self.bn(out)

class FixedLayer(nn.Module):
    '''
    Layer of a fixed (already chosen) architecture. Only the selected branch
    is instantiated, and the branch output is concatenated with the selected
    earlier layer outputs and reduced back to ``out_planes`` channels.
    https://github.com/melodyguan/enas/blob/master/src/cifar10/general_child.py#L245
    '''
    def __init__(self, layer_id, in_planes, out_planes, sample_arc):
        super(FixedLayer, self).__init__()

        self.layer_id = layer_id
        self.in_planes = in_planes
        self.out_planes = out_planes
        self.sample_arc = sample_arc

        self.layer_type = sample_arc[0]
        # The first layer has nothing to skip from; a single zero keeps the
        # channel arithmetic below uniform.
        self.skip_indices = sample_arc[1] if self.layer_id > 0 else torch.zeros(1)

        # NOTE(review): ids 0-2 select different operations here than
        # branch_0..branch_2 in ENASLayer -- confirm this is intended.
        if self.layer_type == 0:
            branch = ConvBranch(in_planes, out_planes, kernel_size=5)
        elif self.layer_type == 1:
            branch = ConvBranch(in_planes, out_planes, kernel_size=5, separable=True)
        elif self.layer_type == 2:
            branch = PoolBranch(in_planes, out_planes, 'avg')
        else:
            raise ValueError("Unknown layer_type {}".format(self.layer_type))
        self.branch = branch

        # Skips are concatenated (not added) in the fixed layer, so the 1x1
        # reduction sees one extra group of channels per selected skip.
        num_inputs = torch.sum(self.skip_indices).item() + 1
        concat_planes = int(num_inputs * in_planes)
        self.dim_reduc = nn.Sequential(
            nn.Conv2d(concat_planes, out_planes, kernel_size=1, bias=False),
            nn.ReLU(),
            nn.BatchNorm2d(out_planes, track_running_stats=False))

    def forward(self, x, prev_layers, sample_arc):
        '''
        Run the fixed branch on ``x`` and merge it with the selected entries
        of ``prev_layers``. ``sample_arc`` is accepted for signature
        compatibility and is not read here.
        '''
        branch_out = self.branch(x)
        selected = [prev_layers[i]
                    for i, skip in enumerate(self.skip_indices)
                    if skip == 1]
        merged = torch.cat(selected + [branch_out], dim=1)
        return self.dim_reduc(merged)

class SeparableConv(nn.Module):
    '''
    Depthwise convolution (one filter per input channel, size-preserving
    padding) followed by a 1x1 pointwise convolution.
    '''
    def __init__(self, in_planes, out_planes, kernel_size, bias):
        super(SeparableConv, self).__init__()
        self.depthwise = nn.Conv2d(in_planes,
                                   in_planes,
                                   kernel_size=kernel_size,
                                   padding=(kernel_size - 1) // 2,
                                   groups=in_planes,
                                   bias=bias)
        self.pointwise = nn.Conv2d(in_planes, out_planes, kernel_size=1, bias=bias)

    def forward(self, x):
        return self.pointwise(self.depthwise(x))

class ConvBranch(nn.Module):
    '''
    Convolutional candidate branch: a 1x1 conv + BN + ReLU that maps the
    input to ``out_planes`` channels, followed by a ``kernel_size`` conv
    (regular or depthwise-separable) + BN + ReLU. Spatial size is preserved.
    https://github.com/melodyguan/enas/blob/master/src/cifar10/general_child.py#L483
    '''
    def __init__(self, in_planes, out_planes, kernel_size, separable=False):
        super(ConvBranch, self).__init__()
        assert kernel_size in [3,5], "Kernel size must be either 3 or 5"

        self.in_planes = in_planes
        self.out_planes = out_planes
        self.kernel_size = kernel_size
        self.separable = separable

        self.inp_conv1 = nn.Sequential(
            nn.Conv2d(in_planes, out_planes, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_planes, track_running_stats=False),
            nn.ReLU())

        # out_conv consumes the output of inp_conv1, which already has
        # out_planes channels. Declaring in_planes here only worked when the
        # two counts happened to be equal.
        if separable:
            self.out_conv = nn.Sequential(
                SeparableConv(out_planes, out_planes, kernel_size=kernel_size, bias=False),
                nn.BatchNorm2d(out_planes, track_running_stats=False),
                nn.ReLU())
        else:
            padding = (kernel_size - 1) // 2
            self.out_conv = nn.Sequential(
                nn.Conv2d(out_planes, out_planes, kernel_size=kernel_size,
                          padding=padding, bias=False),
                nn.BatchNorm2d(out_planes, track_running_stats=False),
                nn.ReLU())

    def forward(self, x):
        out = self.inp_conv1(x)
        out = self.out_conv(out)
        return out

class PoolBranch(nn.Module):
    '''
    Pooling candidate branch: a 1x1 conv + BN + ReLU followed by a 3x3,
    stride-1, padding-1 average or max pool, so spatial size is preserved.
    https://github.com/melodyguan/enas/blob/master/src/cifar10/general_child.py#L546
    '''
    def __init__(self, in_planes, out_planes, avg_or_max):
        super(PoolBranch, self).__init__()

        self.in_planes = in_planes
        self.out_planes = out_planes
        self.avg_or_max = avg_or_max

        self.conv1 = nn.Sequential(
            nn.Conv2d(in_planes, out_planes, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_planes, track_running_stats=False),
            nn.ReLU())

        pool_types = {'avg': torch.nn.AvgPool2d, 'max': torch.nn.MaxPool2d}
        if avg_or_max not in pool_types:
            raise ValueError("Unknown pool {}".format(avg_or_max))
        self.pool = pool_types[avg_or_max](kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        return self.pool(self.conv1(x))

class SharedCNN(nn.Module):
    '''
    Child network. Without ``fixed_arc`` every layer is an ENASLayer holding
    all candidate branches (shared weights); with ``fixed_arc`` every layer is
    a FixedLayer built for that architecture. After the layers listed in
    ``pool_layers`` the outputs of all layers so far are downsampled, and in
    the fixed case the filter count doubles at each of those points.
    '''
    def __init__(self,
                 num_layers=12,
                 num_branches=6,
                 out_filters=24,
                 keep_prob=1.0,
                 fixed_arc=None
                 ):
        super(SharedCNN, self).__init__()

        self.num_layers = num_layers
        self.num_branches = num_branches
        self.out_filters = out_filters
        self.keep_prob = keep_prob
        self.fixed_arc = fixed_arc

        # Downsample after one third and after two thirds of the layers.
        third = self.num_layers // 3
        self.pool_layers = [third - 1, 2 * third - 1]

        self.stem_conv = nn.Sequential(
            nn.Conv2d(3, out_filters, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_filters, track_running_stats=False))

        self.layers = nn.ModuleList([])
        self.pooled_layers = nn.ModuleList([])

        searching = self.fixed_arc is None

        for layer_id in range(self.num_layers):
            if searching:
                layer = ENASLayer(layer_id, self.out_filters, self.out_filters)
            else:
                layer = FixedLayer(layer_id,
                                   self.out_filters,
                                   self.out_filters,
                                   self.fixed_arc[str(layer_id)])
            self.layers.append(layer)

            if layer_id not in self.pool_layers:
                continue

            # One reduction per layer built so far, because every earlier
            # output may still be used as a skip input later on.
            reduced_filters = self.out_filters if searching else self.out_filters * 2
            for _ in range(len(self.layers)):
                self.pooled_layers.append(
                    FactorizedReduction(self.out_filters, reduced_filters))
            self.out_filters = reduced_filters

        self.global_avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(p=1. - self.keep_prob)
        self.classify = nn.Linear(self.out_filters, 10)

        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_uniform_(module.weight, mode='fan_in', nonlinearity='relu')

    def forward(self, x, sample_arc):
        '''
        Args:
            x: Batch of input images (the stem convolution expects 3 channels).
            sample_arc: Mapping from str(layer_id) to that layer's decisions.
        Returns:
            Output of the final linear classifier (10 values per sample).
        '''
        x = self.stem_conv(x)

        prev_layers = []
        pool_count = 0
        for layer_id, layer in enumerate(self.layers):
            x = layer(x, prev_layers, sample_arc[str(layer_id)])
            prev_layers.append(x)

            if layer_id in self.pool_layers:
                # Downsample the outputs of all layers so far, each with its
                # own reduction module, consumed in construction order.
                for i in range(len(prev_layers)):
                    prev_layers[i] = self.pooled_layers[pool_count + i](prev_layers[i])
                pool_count += len(prev_layers)
                x = prev_layers[-1]

        x = self.global_avg_pool(x)
        x = x.view(x.shape[0], -1)
        x = self.dropout(x)
        return self.classify(x)

Training

In [None]:
def train_shared_cnn(epoch,
                     controller,
                     shared_cnn,
                     data_loaders,
                     shared_cnn_optimizer,
                     fixed_arc=None):
    """Train shared_cnn for one pass over the training loader.

    A new architecture is sampled from the controller for every batch unless
    ``fixed_arc`` is given. Progress is printed every ``args['log_every']``
    steps and the epoch averages are plotted and saved to
    ``child_training_curve.png``.

    Args:
        epoch: Current epoch.
        controller: Controller module that generates architectures to be trained.
        shared_cnn: CNN that contains all possible architectures, with shared weights.
        data_loaders: Dict containing data loaders ('train_subset' is used
            while searching, 'train_dataset' for a fixed architecture).
        shared_cnn_optimizer: Optimizer for the shared_cnn.
        fixed_arc: Architecture to train, overrides the controller sample

    Returns: Nothing.
    """

    controller.eval()

    if fixed_arc is None:
        # Use a subset of the training set when searching for an architecture
        train_loader = data_loaders['train_subset']
    else:
        # Use the full training set when training a fixed architecture
        train_loader = data_loaders['train_dataset']

    # Send batches wherever the model lives instead of assuming CUDA.
    device = next(shared_cnn.parameters()).device
    criterion = nn.CrossEntropyLoss()

    train_acc_meter = AverageMeter()
    loss_meter = AverageMeter()

    for i, (images, labels) in enumerate(train_loader):
        start = time.time()
        images = images.to(device)
        labels = labels.to(device)

        if fixed_arc is None:
            with torch.no_grad():
                controller()  # perform forward pass to generate a new architecture
            sample_arc = controller.sample_arc
        else:
            sample_arc = fixed_arc

        shared_cnn.zero_grad()
        pred = shared_cnn(images, sample_arc)
        loss = criterion(pred, labels)
        loss.backward()
        grad_norm = torch.nn.utils.clip_grad_norm_(shared_cnn.parameters(), args['child_grad_bound'])
        shared_cnn_optimizer.step()

        train_acc = torch.mean((torch.max(pred, 1)[1] == labels).type(torch.float))

        train_acc_meter.update(train_acc.item())
        loss_meter.update(loss.item())

        end = time.time()

        if i % args['log_every'] == 0:
            learning_rate = shared_cnn_optimizer.param_groups[0]['lr']
            # float() accepts both a tensor and a plain number for the norm;
            # the max() guards against a zero-length timing interval.
            display = 'epoch=' + str(epoch) + \
                      '\tch_step=' + str(i) + \
                      '\tloss=%.6f' % (loss_meter.val) + \
                      '\tlr=%.4f' % (learning_rate) + \
                      '\t|g|=%.4f' % (float(grad_norm)) + \
                      '\tacc=%.4f' % (train_acc_meter.val) + \
                      '\ttime=%.2fit/s' % (1. / max(end - start, 1e-12))
            print(display)

    # One marker per metric: the averages over this epoch.
    plt.figure()
    plt.plot(np.array([epoch]),
             np.array([train_acc_meter.avg]),
             'ro-',
             label='Shared CNN Accuracy')
    plt.plot(np.array([epoch]),
             np.array([loss_meter.avg]),
             'go-',
             label='Shared CNN Loss')
    plt.legend()
    plt.title('Shared CNN')
    plt.xlabel('Iteration')
    plt.savefig("child_training_curve.png")
    plt.show()

    controller.train()

def train_controller(epoch,
                     controller,
                     shared_cnn,
                     data_loaders,
                     controller_optimizer,
                     baseline=None):
    """Train controller to optimize validation accuracy using REINFORCE.

    Args:
        epoch: Current epoch.
        controller: Controller module that generates architectures to be trained.
        shared_cnn: CNN that contains all possible architectures, with shared weights.
        data_loaders: Dict containing data loaders ('valid_subset' is used).
        controller_optimizer: Optimizer for the controller.
        baseline: The baseline score (i.e. average val_acc) from the previous epoch

    Returns:
        baseline: The baseline score (i.e. average val_acc) for the current epoch

    For more stable training we perform weight updates using the average of
    many gradient estimates. controller_num_aggregate indicates how many samples
    we want to average over (default = 20). By default PyTorch will sum gradients
    each time .backward() is called (as long as an optimizer step is not taken),
    so each iteration we divide the loss by controller_num_aggregate to get the
    average.
    https://github.com/melodyguan/enas/blob/master/src/cifar10/general_controller.py#L270
    """
    print('Epoch ' + str(epoch) + ': Training controller')

    shared_cnn.eval()
    valid_loader = data_loaders['valid_subset']

    # Send batches wherever the child model lives instead of assuming CUDA.
    device = next(shared_cnn.parameters()).device

    reward_meter = AverageMeter()
    baseline_meter = AverageMeter()
    val_acc_meter = AverageMeter()
    loss_meter = AverageMeter()

    controller.zero_grad()
    for i in range(args['controller_train_steps'] * args['controller_num_aggregate']):
        start = time.time()
        # NOTE(review): a new loader iterator is created on every step, so
        # each step sees the first batch of a fresh pass. This is kept as-is;
        # consider a persistent iterator if loader start-up is expensive.
        images, labels = next(iter(valid_loader))
        images = images.to(device)
        labels = labels.to(device)

        controller()  # perform forward pass to generate a new architecture
        sample_arc = controller.sample_arc

        with torch.no_grad():
            pred = shared_cnn(images, sample_arc)
            val_acc = torch.mean((torch.max(pred, 1)[1] == labels).type(torch.float))

        # The reward is a constant for REINFORCE: both the accuracy and the
        # entropy bonus are detached. Letting the entropy term carry a
        # gradient would, multiplied by -log_prob, push the entropy down.
        reward = val_acc.detach() \
            + args['controller_entropy_weight'] * controller.sample_entropy.detach()

        if baseline is None:
            baseline = val_acc
        else:
            # Exponential moving average of the reward, computed out of place
            # so no earlier tensor is modified, and kept out of the graph.
            baseline = (baseline - (1 - args['controller_bl_dec']) * (baseline - reward)).detach()

        loss = -1 * controller.sample_log_prob * (reward - baseline)

        if args['controller_skip_weight'] is not None:
            loss += args['controller_skip_weight'] * controller.skip_penaltys

        reward_meter.update(reward.item())
        baseline_meter.update(baseline.item())
        val_acc_meter.update(val_acc.item())
        loss_meter.update(loss.item())

        # Average gradient over controller_num_aggregate samples
        loss = loss / args['controller_num_aggregate']

        # Every step builds a fresh graph from the controller parameters, so
        # nothing has to be retained between backward calls.
        loss.backward()

        end = time.time()

        # Aggregate gradients for controller_num_aggregate iterations, then update weights
        if (i + 1) % args['controller_num_aggregate'] == 0:
            # The child gradient bound is reused here for the controller.
            grad_norm = torch.nn.utils.clip_grad_norm_(controller.parameters(), args['child_grad_bound'])
            controller_optimizer.step()
            controller.zero_grad()

            if (i + 1) % (2 * args['controller_num_aggregate']) == 0:
                learning_rate = controller_optimizer.param_groups[0]['lr']
                display = 'ctrl_step=' + str(i // args['controller_num_aggregate']) + \
                          '\tloss=%.3f' % (loss_meter.val) + \
                          '\tent=%.2f' % (controller.sample_entropy.item()) + \
                          '\tlr=%.4f' % (learning_rate) + \
                          '\t|g|=%.4f' % (float(grad_norm)) + \
                          '\tacc=%.4f' % (val_acc_meter.val) + \
                          '\tbl=%.2f' % (baseline_meter.val) + \
                          '\ttime=%.2fit/s' % (1. / max(end - start, 1e-12))
                print(display)

    # One marker per metric: the averages over this epoch.
    plt.figure()
    plt.plot(np.column_stack([epoch] * 2),
             np.column_stack([reward_meter.avg, baseline_meter.avg]),
             'go-',
             label='Controller Reward')

    plt.plot(np.array([epoch]),
             np.array([val_acc_meter.avg]),
             'ro-',
             label='Controller Accuracy')

    plt.plot(np.array([epoch]),
             np.array([loss_meter.avg]),
             'bo-',
             label='Controller Loss')
    plt.legend()
    plt.title('Controller')
    plt.xlabel('Iteration')
    plt.savefig("controller_training_curve.png")
    plt.show()

    shared_cnn.train()
    return baseline

def train_enas(start_epoch,
               controller,
               shared_cnn,
               data_loaders,
               shared_cnn_optimizer,
               controller_optimizer,
               shared_cnn_scheduler):
    """Run the architecture search loop.

    Each epoch trains the shared CNN, then the controller, optionally
    evaluates, advances the learning-rate schedule and overwrites the
    checkpoint file ``args['output_filename'] + '.pth.tar'``.

    Args:
        start_epoch: Epoch to begin on.
        controller: Controller module that generates architectures to be trained.
        shared_cnn: CNN that contains all possible architectures, with shared weights.
        data_loaders: Dict containing data loaders.
        shared_cnn_optimizer: Optimizer for the shared_cnn.
        controller_optimizer: Optimizer for the controller.
        shared_cnn_scheduler: Learning rate scheduler for shared_cnn_optimizer

    Returns: Nothing.
    """

    # Moving-average reward baseline, carried from one epoch to the next.
    baseline = None

    for epoch in range(start_epoch, args['num_epochs']):
        train_shared_cnn(epoch, controller, shared_cnn, data_loaders, shared_cnn_optimizer)

        baseline = train_controller(
            epoch, controller, shared_cnn, data_loaders, controller_optimizer, baseline)

        if not epoch % args['eval_every_epochs']:
            evaluate_model(epoch, controller, shared_cnn, data_loaders)

        shared_cnn_scheduler.step()

        # 'epoch' stores the next epoch to run, i.e. where a resume starts.
        checkpoint = dict(
            epoch=epoch + 1,
            args=args,
            shared_cnn_state_dict=shared_cnn.state_dict(),
            controller_state_dict=controller.state_dict(),
            shared_cnn_optimizer=shared_cnn_optimizer.state_dict(),
            controller_optimizer=controller_optimizer.state_dict())
        torch.save(checkpoint, args['output_filename'] + '.pth.tar')

def train_fixed(start_epoch,
                controller,
                shared_cnn,
                data_loaders):
    """Train a fixed cnn architecture.

    Given a fully trained controller and shared_cnn, we sample many architectures,
    and then train a new cnn from scratch using the best architecture we found.
    The new cnn is built with out_filters = 512 // 4 instead of
    args['child_out_filters'] (per the original author, so that the final layer
    has 512 channels).

    After every epoch a checkpoint is written to
    args['output_filename'] + '_fixed.pth.tar'.

    Args:
        start_epoch: Unused. Training of the new cnn always runs for
            args['num_epochs'] epochs starting from epoch 0; the parameter is
            kept so existing callers keep working.
        controller: Controller module that generates architectures to be trained.
        shared_cnn: CNN that contains all possible architectures, with shared weights.
            Only used to score the sampled architectures.
        data_loaders: Dict containing data loaders.

    Returns: Nothing.
    """

    best_arc, best_val_acc = get_best_arc(controller, shared_cnn, data_loaders, n_samples=100, verbose=True)
    print('Best architecture:')
    print_arc(best_arc)
    print('Validation accuracy: ' + str(best_val_acc))

    fixed_cnn = SharedCNN(num_layers=args['child_num_layers'],
                          num_branches=args['child_num_branches'],
                          out_filters=512 // 4,  # args.child_out_filters
                          keep_prob=args['child_keep_prob'],
                          fixed_arc=best_arc)
    fixed_cnn = fixed_cnn.cuda()

    fixed_cnn_optimizer = torch.optim.SGD(params=fixed_cnn.parameters(),
                                          lr=args['child_lr_max'],
                                          momentum=0.9,
                                          nesterov=True,
                                          weight_decay=args['child_l2_reg'])

    fixed_cnn_scheduler = CosineAnnealingLR(optimizer=fixed_cnn_optimizer,
                                            T_max=args['child_lr_T'],
                                            eta_min=args['child_lr_min'])

    test_loader = data_loaders['test_dataset']
    filename = args['output_filename'] + '_fixed.pth.tar'

    for epoch in range(args['num_epochs']):

        train_shared_cnn(epoch,
                         controller,  # not actually used in training the fixed_cnn
                         fixed_cnn,
                         data_loaders,
                         fixed_cnn_optimizer,
                         best_arc)

        if epoch % args['eval_every_epochs'] == 0:
            # Measure test accuracy in eval mode so that any layers whose
            # behaviour depends on the training flag (e.g. dropout) are
            # deterministic, then switch back for the next training epoch.
            fixed_cnn.eval()
            test_acc = get_eval_accuracy(test_loader, fixed_cnn, best_arc)
            fixed_cnn.train()
            print('Epoch ' + str(epoch) + ': Eval')
            print('test_accuracy: %.4f' % (test_acc))

        fixed_cnn_scheduler.step()

        # Save the weights of the network that was actually trained here
        # (fixed_cnn), not those of the shared_cnn used for the search.
        state = {'epoch': epoch + 1,
                 'args': args,
                 'best_arc': best_arc,
                 'fixed_cnn_state_dict': fixed_cnn.state_dict(),
                 'fixed_cnn_optimizer': fixed_cnn_optimizer.state_dict()}
        torch.save(state, filename)


Evaluation


In [None]:
def evaluate_model(epoch, controller, shared_cnn, data_loaders, n_samples=10):
    """Print the validation and test accuracy for a controller and shared_cnn.

    Samples n_samples architectures, picks the best one, and reports its
    accuracy over the full 'valid_subset' and 'test_dataset' loaders. Both
    modules are left in training mode when the function returns.

    Args:
        epoch: Current epoch.
        controller: Controller module that generates architectures to be trained.
        shared_cnn: CNN that contains all possible architectures, with shared weights.
        data_loaders: Dict containing data loaders.
        n_samples: Number of architectures to test when looking for the best one.

    Returns: Nothing.
    """

    controller.eval()
    shared_cnn.eval()

    print('Here are ' + str(n_samples) + ' architectures:')
    best_arc, _ = get_best_arc(controller, shared_cnn, data_loaders, n_samples, verbose=True)

    # Re-assert eval mode: get_best_arc manages the modules' train/eval state
    # itself and is not guaranteed to hand them back in eval mode, which would
    # make the accuracies below be measured in training mode.
    controller.eval()
    shared_cnn.eval()

    valid_loader = data_loaders['valid_subset']
    test_loader = data_loaders['test_dataset']

    valid_acc = get_eval_accuracy(valid_loader, shared_cnn, best_arc)
    test_acc = get_eval_accuracy(test_loader, shared_cnn, best_arc)

    print('Epoch ' + str(epoch) + ': Eval')
    print('valid_accuracy: %.4f' % (valid_acc))
    print('test_accuracy: %.4f' % (test_acc))

    # Hand the modules back in training mode for the caller's training loop.
    controller.train()
    shared_cnn.train()

def get_best_arc(controller, shared_cnn, data_loaders, n_samples=10, verbose=False):
    """Evaluate several architectures and return the best performing one.

    All architectures are evaluated on the same minibatch from the validation set
    (the first batch drawn from data_loaders['valid_subset']).

    Both modules are switched to eval mode while sampling and scoring, and are
    returned to whatever train/eval mode they were in when the function was
    called, so a caller that is itself evaluating is not silently flipped back
    to training mode.

    Args:
        controller: Controller module that generates architectures to be trained.
        shared_cnn: CNN that contains all possible architectures, with shared weights.
        data_loaders: Dict containing data loaders.
        n_samples: Number of architectures to test when looking for the best one.
            Must be at least 1.
        verbose: If True, display the architecture and resulting validation accuracy.

    Returns:
        best_arc: The best performing architecture.
        best_val_acc: Accuracy achieved on the best performing architecture.

    Raises:
        ValueError: If n_samples is smaller than 1.
    """
    if n_samples < 1:
        raise ValueError('n_samples must be at least 1, got ' + str(n_samples))

    # Remember the caller's modes so they can be restored on the way out.
    controller_was_training = controller.training
    shared_cnn_was_training = shared_cnn.training

    arcs = []
    val_accs = []
    try:
        controller.eval()
        shared_cnn.eval()

        valid_loader = data_loaders['valid_subset']

        # A single minibatch is reused for every sampled architecture.
        images, labels = next(iter(valid_loader))
        images = images.cuda()
        labels = labels.cuda()

        with torch.no_grad():
            for _ in range(n_samples):
                controller()  # perform forward pass to generate a new architecture
                sample_arc = controller.sample_arc
                arcs.append(sample_arc)

                pred = shared_cnn(images, sample_arc)
                predicted = torch.max(pred, 1)[1]
                val_acc = torch.mean((predicted == labels).type(torch.float)).item()
                val_accs.append(val_acc)

                if verbose:
                    print_arc(sample_arc)
                    print('val_acc=' + str(val_acc))
                    print('-' * 80)
    finally:
        controller.train(controller_was_training)
        shared_cnn.train(shared_cnn_was_training)

    # np.argmax returns the first index on ties, i.e. the earliest sample wins.
    best_iter = np.argmax(val_accs)
    best_arc = arcs[best_iter]
    best_val_acc = val_accs[best_iter]

    return best_arc, best_val_acc

def get_eval_accuracy(loader, shared_cnn, sample_arc):
    """Compute the accuracy of one architecture over an entire loader.

    The train/eval mode of shared_cnn is not touched here; the caller decides
    which mode the network is in.

    Args:
        loader: A single data loader yielding (images, labels) batches.
        shared_cnn: CNN that contains all possible architectures, with shared weights.
        sample_arc: The architecture to use for the evaluation.

    Returns:
        acc: Fraction of correctly classified samples, as a Python float.
    """
    n_seen = 0.
    n_correct = 0.
    for batch_images, batch_labels in loader:
        batch_images = batch_images.cuda()
        batch_labels = batch_labels.cuda()

        # Forward pass only; no gradients are needed for scoring.
        with torch.no_grad():
            logits = shared_cnn(batch_images, sample_arc)

        _, predicted = torch.max(logits, 1)
        hits = (predicted == batch_labels).type(torch.float)
        n_correct = n_correct + torch.sum(hits)
        n_seen = n_seen + logits.shape[0]

    return (n_correct / n_seen).item()

def print_arc(sample_arc):
    """Print an architecture, one bracketed line per entry.

    Each entry of sample_arc is a sequence of tensors. An entry with a single
    tensor is printed as its values; otherwise the values of the first tensor
    are followed by the values of the second one on the same line.

    Args:
        sample_arc: The architecture to display.
    Returns: Nothing.
    """
    for layer_spec in sample_arc.values():
        has_second_part = len(layer_spec) != 1
        tokens = layer_spec[0].cpu().numpy().tolist()
        if has_second_part:
            tokens = tokens + layer_spec[1].cpu().numpy().tolist()
        print('[' + ' '.join(map(str, tokens)) + ']')

Load CIFAR-10 data


In [None]:
def load_datasets():
    """Build the CIFAR-10 data loaders used for search and final training.

    The CIFAR-10 training split is divided by index: the first 45000 samples
    form 'train_subset' and the last 5000 form 'valid_subset'. The training
    and validation pipelines both apply random crop and horizontal flip;
    Cutout is added to the training pipeline only, when args['cutout'] > 0.
    The test pipeline only converts to a tensor and normalizes.

    Returns: Dict containing data loaders, keyed by 'train_subset',
        'valid_subset', 'train_dataset' (the full training split) and
        'test_dataset'.
    """
    channel_mean = [v / 255.0 for v in [125.3, 123.0, 113.9]]
    channel_std = [v / 255.0 for v in [63.0, 62.1, 66.7]]
    normalize = transforms.Normalize(mean=channel_mean, std=channel_std)

    def augmented_steps():
        # Fresh list on every call so the train and valid pipelines stay independent.
        return [transforms.RandomCrop(32, padding=4),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize]

    train_steps = augmented_steps()
    if args['cutout'] > 0:
        train_steps.append(Cutout(length=args['cutout']))
    train_transform = transforms.Compose(train_steps)
    valid_transform = transforms.Compose(augmented_steps())
    test_transform = transforms.Compose([transforms.ToTensor(), normalize])

    data_root = args['data_path']
    train_dataset = datasets.CIFAR10(root=data_root, train=True,
                                     transform=train_transform, download=True)
    # Same underlying training split, but with the validation pipeline.
    valid_dataset = datasets.CIFAR10(root=data_root, train=True,
                                     transform=valid_transform, download=True)
    test_dataset = datasets.CIFAR10(root=data_root, train=False,
                                    transform=test_transform, download=True)

    train_subset = Subset(train_dataset, list(range(0, 45000)))
    valid_subset = Subset(valid_dataset, list(range(45000, 50000)))

    def make_loader(dataset, shuffle, **extra):
        # Settings shared by every loader; `extra` carries per-loader options.
        return torch.utils.data.DataLoader(dataset=dataset,
                                           batch_size=args['batch_size'],
                                           shuffle=shuffle,
                                           pin_memory=True,
                                           num_workers=2,
                                           **extra)

    return {
        'train_subset': make_loader(train_subset, True),
        # Only the validation loader drops its last incomplete batch.
        'valid_subset': make_loader(valid_subset, True, drop_last=True),
        'train_dataset': make_loader(train_dataset, True),
        'test_dataset': make_loader(test_dataset, False),
    }

Set seed

In [None]:
# Seed every random number generator this notebook has in scope.
# torch.manual_seed covers the CPU generator, which the CUDA-only call below
# leaves unseeded.
np.random.seed(args['seed'])
torch.manual_seed(args['seed'])
torch.cuda.manual_seed(args['seed'])

In [None]:
# Optional (disabled): redirect stdout to a per-run log file.
# if args['fixed_arc']:
#     sys.stdout = Logger(filename='logs/' + args['output_filename'] + '_fixed.log')
# else:
#     sys.stdout = Logger(filename='logs/' + args['output_filename'] + '.log')

# Echo the run configuration, one "name<TAB>value" line per setting.
print(*("{}\t{}".format(name, value) for name, value in args.items()), sep="\n")

Create data loaders

In [None]:
# Build the dict of data loaders once; it is passed to the training and evaluation functions below.
data_loaders = load_datasets()

Create controller

In [None]:
# Build the controller from the run configuration and move it to the GPU.
controller = Controller(
    search_for=args['search_for'],
    search_whole_channels=True,
    num_layers=args['child_num_layers'],
    num_branches=args['child_num_branches'],
    out_filters=args['child_out_filters'],
    lstm_size=args['controller_lstm_size'],
    lstm_num_layers=args['controller_lstm_num_layers'],
    tanh_constant=args['controller_tanh_constant'],
    temperature=None,
    skip_target=args['controller_skip_target'],
    skip_weight=args['controller_skip_weight'],
).cuda()

Child architectures

In [None]:
# Build the shared-weight child network (no fixed architecture) on the GPU.
shared_cnn = SharedCNN(
    num_layers=args['child_num_layers'],
    num_branches=args['child_num_branches'],
    out_filters=args['child_out_filters'],
    keep_prob=args['child_keep_prob'],
).cuda()

Adam optimizer for controller

In [None]:
# Adam settings follow the reference implementation:
# https://github.com/melodyguan/enas/blob/master/src/utils.py#L218
controller_optimizer = torch.optim.Adam(
    controller.parameters(),
    lr=args['controller_lr'],
    betas=(0.0, 0.999),
    eps=1e-3,
)

SGD optimizer for the child networks

In [None]:
# Nesterov-momentum SGD with L2 weight decay, following the reference implementation:
# https://github.com/melodyguan/enas/blob/master/src/utils.py#L213
shared_cnn_optimizer = torch.optim.SGD(
    shared_cnn.parameters(),
    lr=args['child_lr_max'],
    momentum=0.9,
    nesterov=True,
    weight_decay=args['child_l2_reg'],
)

Child learning rate

In [None]:
# Cosine annealing of the child learning rate down to child_lr_min, cf.
# https://github.com/melodyguan/enas/blob/master/src/utils.py#L154
shared_cnn_scheduler = CosineAnnealingLR(
    shared_cnn_optimizer,
    T_max=args['child_lr_T'],
    eta_min=args['child_lr_min'],
)

In [None]:
# Optionally restore a previous run, then start either the architecture
# search or the training of a fixed architecture.
if args['resume']:
    if not os.path.isfile(args['resume']):
        raise ValueError("No checkpoint found at '{}'".format(args['resume']))

    print("Loading checkpoint '{}'".format(args['resume']))
    checkpoint = torch.load(args['resume'])
    # The checkpoint stores the index of the next epoch to run.
    start_epoch = checkpoint['epoch']
    # args = checkpoint['args']
    print(checkpoint.keys())
    shared_cnn.load_state_dict(checkpoint['shared_cnn_state_dict'])
    controller.load_state_dict(checkpoint['controller_state_dict'])
    shared_cnn_optimizer.load_state_dict(checkpoint['shared_cnn_optimizer'])
    controller_optimizer.load_state_dict(checkpoint['controller_optimizer'])

    # Keep the scheduler attached to the optimizer that was just restored.
    shared_cnn_scheduler.optimizer = shared_cnn_optimizer
    # Re-attaching alone does not move the schedule: without the lines below
    # the scheduler would count from epoch 0 again after a resume.
    if 'shared_cnn_scheduler' in checkpoint:
        shared_cnn_scheduler.load_state_dict(checkpoint['shared_cnn_scheduler'])
    else:
        # Checkpoint without scheduler state: the search loop steps the
        # scheduler once per epoch, so its position equals the epoch count.
        shared_cnn_scheduler.last_epoch = start_epoch

    print("Loaded checkpoint '{}' (epoch {})"
          .format(args['resume'], checkpoint['epoch']))
else:
    start_epoch = 0

if not args['fixed_arc']:
    train_enas(start_epoch,
               controller,
               shared_cnn,
               data_loaders,
               shared_cnn_optimizer,
               controller_optimizer,
               shared_cnn_scheduler)
else:
    # Explicit check instead of `assert`, which is skipped under `python -O`.
    if not args['resume']:
        raise ValueError('A pretrained model should be used when training a fixed architecture.')
    train_fixed(start_epoch,
                controller,
                shared_cnn,
                data_loaders)