Note:

Reference branch: https://github.com/invisibleForce/ENAS-Pytorch/blob/master/controller_model.py

This draft implements only the macro search of the algorithm. Given the amount of code already needed, implementing both micro and macro search may be too large a scope. We can start small and make sure the macro part works first.

To-dos:
1. Implement infrastructure functions - readImage
2. The current code offers only 4 candidate operations per layer (conv3, conv5, avgpool3, maxpool3) - I am not convinced these are the right choices. We can revisit this later and try different ones.
3. Re-review controller related code and refactor 
4. put everything together
5. Debug

In the context of Efficient Neural Architecture Search (ENAS), a "child" refers to a specific neural network architecture generated during the search process. ENAS is an approach to automate the design of neural networks, where a "controller" network generates potential architectures, known as "child" networks.

Here's a more detailed breakdown:

Controller Network: The controller is typically a recurrent neural network (RNN) that predicts a sequence of actions, each defining a component of the neural network architecture (e.g., type of layer, number of filters, connections).

Child Network: Each sequence generated by the controller corresponds to a unique "child" network architecture. These child networks are trained and evaluated to measure their performance on a specific task. 

Performance Feedback: The performance of the child network (e.g., accuracy, loss) is fed back to the controller, which uses this information to improve its architecture generation process. This feedback loop helps the controller learn which architectural features lead to better performance.

The ENAS process aims to find optimal or near-optimal neural network architectures efficiently by leveraging the controller's ability to learn and refine the architecture search space.

In [28]:
import os
import sys
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as f
import time
import random

# Global switches shared by every cell below.
DEBUG = False  # set to True to print shapes, learning rates and sampled architectures
SEED = 202407240122
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Seed the random number generators. Weight initialisation and architecture
# sampling in this notebook go through torch (nn.* layers, torch.rand,
# torch.multinomial), so seeding Python's `random` alone does not make a run
# reproducible.
random.seed(SEED)
torch.manual_seed(SEED)


print(DEVICE)

cpu


In [20]:
class LayerOperation(nn.Module):
    """
    One candidate operation of a NAS layer.

    The block is always ``1x1 conv -> batch norm -> ReLU -> kernel`` where
    ``kernel`` is the operation named by ``operation``. For 'conv3' and
    'conv5' one more batch norm follows the kernel.

    Note:
        Every kernel uses stride 1 and zero padding chosen so that the output
        feature map has the same spatial size as the input. For a kernel of
        size R x P with padding (px, py), stride (Sx, Sy) and dilation 1
        (see torch.nn.Conv2d for details) the output size is
            E = floor((H + 2 * px - (R - 1) - 1) / Sx + 1)
            F = floor((W + 2 * py - (P - 1) - 1) / Sy + 1)
        Requiring E = H and F = W with Sx = Sy = 1 gives
            Height: px = (R - 1) / 2
            Width:  py = (P - 1) / 2

    Args:
        operation: one of 'conv3', 'conv5', 'avgpool3', 'maxpool3'.
        out_channels: M, the number of filters. It is used as both the input
            and the output channel count of every layer in the block.

    Raises:
        ValueError: if ``operation`` is not one of the supported names.
    """

    def __init__(self, operation, out_channels):
        super(LayerOperation, self).__init__()
        self.operation = operation
        self.out_channels = out_channels
        self.layer_list = self._build_layer()

    def _build_layer(self):
        """Create the layers of the block in execution order."""
        # The creation order (1x1 conv, batch norm, kernel, trailing batch
        # norm) is kept fixed so that seeded weight initialisation does not
        # change.
        layers = [
            nn.Conv2d(
                in_channels=self.out_channels,
                out_channels=self.out_channels,
                kernel_size=1,
                stride=1),
            nn.BatchNorm2d(num_features=self.out_channels),
            nn.ReLU(),
            self.get_defined_operation_kernel(self.operation),
        ]

        # Only the convolution kernels are followed by another batch norm.
        if self.operation in ('conv3', 'conv5'):
            layers.append(nn.BatchNorm2d(num_features=self.out_channels))

        return nn.ModuleList(layers)

    def get_defined_operation_kernel(self, operation):
        """
        Build the size-preserving kernel for ``operation``.

        Args:
            operation: 'maxpool3', 'avgpool3', 'conv3' or 'conv5'.
        Returns:
            The pooling or convolution module (stride 1, "same" padding).
        Raises:
            ValueError: if ``operation`` is not supported.
        """
        if operation == 'maxpool3':
            return nn.MaxPool2d(kernel_size=3, padding=(1, 1), stride=1)
        if operation == 'avgpool3':
            return nn.AvgPool2d(kernel_size=3, padding=(1, 1), stride=1)
        if operation in ('conv3', 'conv5'):
            kernel_size = 3 if operation == 'conv3' else 5
            pad = (kernel_size - 1) // 2  # px = (R - 1) / 2, see class docstring
            return nn.Conv2d(
                in_channels=self.out_channels,
                out_channels=self.out_channels,
                kernel_size=kernel_size,
                padding=(pad, pad),
                stride=1)
        raise ValueError('operation not supported')

    def forward(self, x):
        """
        Run ``x`` through every layer of the block in order.

        Defined as ``forward`` (not ``__call__``) so that ``nn.Module.__call__``
        keeps dispatching hooks; ``layer_op(x)`` works as before.
        """
        for layer in self.layer_list:
            x = layer(x)
        return x
        

class NasLayer(nn.Module):
    """
    One searchable layer of the child network.

    The layer owns all four candidate operations (conv3, conv5, avgpool3,
    maxpool3) plus an output batch norm. At run time the architecture
    description selects which single operation is applied and which earlier
    layer outputs are added to the result.

    Args:
        out_channels: number of channels of every feature map in the layer.
    """

    def __init__(self, out_channels=24):
        super(NasLayer, self).__init__()
        self.out_channels = out_channels
        self.layers = self._build_nas_layer()

    # For the time concerns, we can reduce the search to only one conv and one
    # pooling layer. That should reduce a big chunk of time.
    def _build_nas_layer(self):
        """Create the candidate operations followed by the output batch norm."""
        # The position in this list is the operation id used by
        # layer_operation(); the batch norm is always the last entry.
        return nn.ModuleList([
            LayerOperation('conv3', self.out_channels),
            LayerOperation('conv5', self.out_channels),
            LayerOperation('avgpool3', self.out_channels),
            LayerOperation('maxpool3', self.out_channels),
            nn.BatchNorm2d(num_features=self.out_channels),
        ])

    def layer_operation(self, x, op):
        """
        Run the selected operation of a nas layer
        Args:
            x: ifmap
            op: one-element sequence holding the id of the operation to run
                0 - conv3
                1 - conv5
                2 - avgpool3
                3 - maxpool3
        Returns:
            x: ofmap
        """
        return self.layers[op[0]](x)

    def skip(self, prev_layers, config):
        """
        Sum the selected previous layers of a nas layer
        Args:
            prev_layers: outputs of all previous layers; index 0 is the
                output of root_node_conv and is never selectable
            config: config[i] is truthy when prev_layers[i + 1] is selected
        Returns:
            Element-wise sum of the selected tensors, or a tensor of zeros
            shaped like prev_layers[0] when nothing is selected
        """
        layer_index_offset = 1  # used to skip the root_node_conv
        num_layer = len(prev_layers) - layer_index_offset

        desired_layers = [
            prev_layers[i + layer_index_offset]
            for i in range(num_layer)
            if config[i]  # layer specific config, if layer is selected
        ]

        if desired_layers:
            # stack on a new axis 0 and add along it
            return torch.stack(desired_layers).sum(dim=0)

        # zeros_like keeps the dtype and device of the feature maps; building
        # the zeros on CPU and calling .cuda() whenever CUDA exists breaks as
        # soon as the data itself lives on the CPU.
        return torch.zeros_like(prev_layers[0])

    def forward(self, cnt_layer, prev_layers, layer_config):
        """
        Args:
            cnt_layer: index of this layer among the nas layers (0-based)
            prev_layers: all previous layers; the last entry is the input
            layer_config: [op, skip] - operation and connectivity
        Returns:
            For cnt_layer == 0 the output of the selected operation. For later
            layers the operation output plus the selected skip connections,
            passed through the output batch norm.
        """
        x = self.layer_operation(prev_layers[-1], layer_config[0])

        if cnt_layer > 0:
            x = x + self.skip(prev_layers, layer_config[1])
            x = self.layers[-1](x)  # bn_out

        return x

In [None]:
class ChildModel(nn.Module):
    """
    The shared child network: a root convolution, ``num_layers`` NasLayer
    blocks and a final fully connected classifier. Which operation and which
    skip connections each NasLayer uses is decided per call by the
    architecture description passed to ``forward`` / ``model``.

    Args:
        num_of_class: number of output classes of the classifier.
        num_layers: number of NasLayer blocks.
        out_channels: number of channels produced by the root convolution and
            kept by every NasLayer.
        batch_size: accepted for signature compatibility; not used here.
    """

    def __init__(self,
               num_of_class,
               num_layers=6,
               out_channels=24,
               batch_size=32
              ):

        super(ChildModel, self).__init__()

        self.num_of_class = num_of_class
        self.num_layers = num_layers # We should define this; it should be part of the fine tuning process
        self.out_channels = out_channels
        self.graph = self.build_graph(self.num_of_class)

    def build_graph(self, class_num):
        """
        Build all sub-modules of the child network.

        Layout of the returned ModuleList:
            graph[0]               root convolution (3 -> out_channels)
            graph[1 .. num_layers] NasLayer blocks
            graph[-1]              fully connected layer (out_channels -> class_num)
        """
        graph = [self.build_root_conv()]
        graph.extend(NasLayer(self.out_channels) for _ in range(self.num_layers))
        # fully connected layer
        graph.append(nn.Linear(self.out_channels, class_num, bias=True))

        return nn.ModuleList(graph)  # registers every sub-module

    def build_root_conv(self):
        """
        Build the root 3x3 convolution (3 input channels, stride 1).

        The padding of 1 is (kernel_size - 1) / 2 with kernel_size = 3, which
        keeps the spatial size of the input.
        """
        px = 1
        py = px
        return nn.Conv2d(
                in_channels=3,
                out_channels=self.out_channels,
                kernel_size=3,
                padding=(px, py),
                stride=1)

    def global_avgpool(self, x):
        """
        Average over the H and W axes
        x = [N, C, H, W] -> [N, C]
        """
        return x.mean(dim=(2, 3))

    def forward(self, x, sample_arch):
        """
        Run the child model determined by sample_arch
        Args:
            x: input of the child model
            sample_arch: a list consisting of 2 * num_layers elements
                op_id = sample_arch[2k]: one-element list with the operation id
                skip = sample_arch[2k + 1]: element i of this binary vector
                    tells whether the output of nas layer i is used as an
                    additional input of layer k
        Return:
            x: output of the fully connected layer, [N, num_of_class]
        """
        # root_node_conv
        x = self.graph[0](x)
        prev_layers = [x]
        index_offset = 1  # graph[0] is the root conv, nas layers start at 1
        # nas_layers
        for layer_index in range(self.num_layers):
            layer_config = sample_arch[2 * layer_index : 2 * layer_index + 2]   # [op], [skip]
            x = self.graph[layer_index + index_offset](layer_index, prev_layers, layer_config)
            prev_layers.append(x)
        x = self.global_avgpool(x)
        # go through the fully connected layer
        return self.graph[-1](x)

    def model(self, x, sample_arch):
        """
        Backward-compatible alias of calling the module: ``child.model(x, arch)``
        is the same as ``child(x, arch)``. See ``forward`` for the arguments.
        """
        return self(x, sample_arch)
    
def test_preview_architecture():
    """
    Smoke test: build a small ChildModel, print its structure and push a
    random batch through one hand-written architecture.

    Prints the architecture, the number of parameter tensors, the number of
    graph entries, the graph itself and the size of the model output.
    """
    # child model arch: one ([op], [skip]) pair per layer.
    # op ids: 0 = conv3 (c3), 1 = conv5 (c5), 2 = avgpool3 (ap), 3 = maxpool3 (mp).
    # The skip vector of layer k has k entries; entry i == 1 adds the output
    # of nas layer i to the output of layer k.
    sample_arch = [
        [0], [],               # layer 0: c3, no skip possible
        [1], [1],              # layer 1: c5, l0=1
        [3], [0, 0],           # layer 2: mp, l0=0, l1=0
        [1], [1, 0, 1],        # layer 3: c5, l0=1, l1=0, l2=1
        [0], [0, 0, 0, 0],     # layer 4: c3, l0=0, l1=0, l2=0, l3=0
        [2], [0, 0, 0, 0, 0],  # layer 5: ap, l0=0, l1=0, l2=0, l3=0, l4=0
    ]
    print(sample_arch)
    # instantiate a model
    images = torch.rand([2, 3, 7, 7])
    class_num = 5
    num_layers = 6
    out_channels = 2
    child = ChildModel(class_num, num_layers, out_channels)
    print(len(list(child.parameters())))
    # print(list(child.parameters()))
    print(len(child.graph))
    print(child.graph)
    y = child.model(images, sample_arch)
    print(y.size())

# ------------------
# Run the test to get a preview of the model architecture.
# This executes when the cell runs: it builds a 6-layer ChildModel and prints
# the sample architecture, the parameter count, the graph and the output size.
# ------------------
test_preview_architecture()


In [27]:
class Child(nn.Module):
    """
    Trainer around a ChildModel. It owns the network, the loss, the SGD
    optimizer and a cosine learning-rate scheduler, and provides the training,
    validation and evaluation loops for one sampled architecture.

    Args:
        num_of_class: number of classification classes.
        num_of_layer: number of nas layers of the child network.
        out_channels: number of channels used by the child network.
        batch_size: number of samples per minibatch.
        lr: initial learning rate of the optimizer.
        gamma: stored on the instance; not used by the code of this class.
        lr_cos_lmin: minimum learning rate (eta_min) of the cosine schedule.
        lr_cos_Tmax: T_max of the cosine schedule; the scheduler is stepped
            once per train_epoch call.
        l2_reg_lr: weight decay passed to the optimizer.
        eval_period: number of training steps between two loss printouts.
    """

    def __init__(self,
               num_of_class,
               num_of_layer=6,
               out_channels=24,
               batch_size=32,
               lr=0.05,
               gamma=0.1,
               lr_cos_lmin=0.001,
               lr_cos_Tmax=2,
               l2_reg_lr=1e-4,
               eval_period=100
              ):
        super(Child, self).__init__()
        self.num_of_class = num_of_class # number of classification classes
        self.num_of_layer = num_of_layer
        self.out_channels = out_channels

        self.batch_size = batch_size
        self.eval_period = eval_period

        self.l2_reg_lr = l2_reg_lr

        self.lr = lr
        self.gamma = gamma
        self.lr_cos_lmin = lr_cos_lmin
        self.lr_cos_Tmax = lr_cos_Tmax
        # Only recorded here; this class does not move the network or the
        # data to another device.
        self.device = 'gpu' if torch.cuda.is_available() else 'cpu'

        # build child model
        self.net = ChildModel(num_of_class, num_of_layer, out_channels)
        self.criterion = nn.CrossEntropyLoss()

        if DEBUG: print('#param', len(list(self.net.parameters())))

        self.optimizer = optim.SGD([{'params': self.net.parameters(), 'initial_lr': self.lr}], lr=self.lr, weight_decay=self.l2_reg_lr, momentum=0.9, nesterov=True)

        # learning rate scheduler - Not sure if it's necessary, basically it's a learning rate decay
        self.scheduler = optim.lr_scheduler.CosineAnnealingLR(self.optimizer, self.lr_cos_Tmax, eta_min=self.lr_cos_lmin)

    def get_batch(self, images, classification_results, step):
        """
        Return minibatch number ``step`` of the dataset.

        Batches are taken sequentially: batch ``step`` covers the samples
        [step * batch_size, (step + 1) * batch_size). If this ever turns out
        to be a problem we can switch to random batches.

        Args:
            images: all images, sliceable along the first axis.
            classification_results: all labels, sliceable along the first axis.
            step: 0-based batch index.
        Returns:
            (batch_images, batch_classifications)
        """
        start = step * self.batch_size
        stop = (step + 1) * self.batch_size
        batch_images = images[start:stop]
        batch_classifications = classification_results[start:stop]
        if DEBUG: print('get_batch', type(batch_images))
        if DEBUG: print('get_batch', type(batch_classifications))

        # augment data step - I comment it out because I'm not sure if it's necessary given the amount of data
        # batch_images = augment(batch_images)

        return batch_images, batch_classifications

    def train_epoch(self, sample_arch, images, labels, epoch, train_step):
        """
        Train the architecture ``sample_arch`` for ``train_step`` minibatches
        and then advance the learning-rate scheduler by one step.

        Args:
            sample_arch: architecture description understood by ChildModel.
            images: training images.
            labels: training labels.
            epoch: 0-based epoch index, only used in the loss printout.
            train_step: number of minibatches to train on.
        """
        running_loss = 0.0
        # get_last_lr() replaces the deprecated get_lr() for reading the
        # current learning rate outside of the scheduler itself.
        if DEBUG: print('lr=', self.scheduler.get_last_lr())
        for step in range(train_step):
            batch_inputs, batch_classifications = self.get_batch(images, labels, step)
            self.optimizer.zero_grad()
            outputs = self.net.model(batch_inputs, sample_arch)
            loss = self.criterion(outputs, batch_classifications)
            loss.backward()
            self.optimizer.step()

            running_loss += loss.item()
            # print the mean loss of the last eval_period steps
            if step % self.eval_period == (self.eval_period - 1):
                print('[%d, %5d], loss: %.3f' %
                    (epoch + 1, step + 1, running_loss / self.eval_period))
                running_loss = 0.0

        self.scheduler.step()

    def valid_rl(self, sample_arch, images, labels):
        """
        validate a sampled child model on a random minibatch of validation set

        Returns:
            Accuracy on that minibatch as a 0-dim float tensor.
        """
        max_index = labels.size()[0] // self.batch_size

        # plain Python int, so that get_batch slices with ordinary integers
        batch_idx = torch.randint(max_index, (1, 1)).item()
        batch_inputs, batch_classifications = self.get_batch(images, labels, batch_idx)

        # Only the predictions are needed, so no autograd graph is recorded.
        with torch.no_grad():
            outputs = self.net.model(batch_inputs, sample_arch)

        _, idx = torch.topk(outputs, 1)
        idx = idx.reshape((-1))
        accuracy = (idx == batch_classifications).float().sum()
        accuracy /= self.batch_size

        return accuracy

    def eval(self, sample_arch=None, images=None, labels=None):
        """
        Evaluate ``sample_arch`` on all full minibatches of a dataset.

        This method shadows ``nn.Module.eval``. To keep the standard
        ``module.eval()`` call working, calling it without any argument
        switches the module to evaluation mode and returns ``self``.

        Args:
            sample_arch: architecture description understood by ChildModel.
            images: images to evaluate on.
            labels: matching labels.
        Returns:
            Accuracy over num_of_step * batch_size samples, where
            num_of_step = len(labels) // batch_size.
        """
        if sample_arch is None and images is None and labels is None:
            return super(Child, self).eval()

        num_of_step = labels.size()[0] // self.batch_size
        accuracy = 0
        # Only the predictions are needed, so no autograd graph is recorded.
        with torch.no_grad():
            for i in range(num_of_step):
                batch_inputs, batch_classifications = self.get_batch(images, labels, i)
                outputs = self.net.model(batch_inputs, sample_arch)
                _, idx = torch.topk(outputs, 1) # we can change it to argmax or softmax or gumbel softmax if necessary
                idx = idx.reshape((-1))
                accuracy += (idx == batch_classifications).float().sum()
        accuracy /= (num_of_step * self.batch_size)

        return accuracy

     

def test_child():
    """
    Train one hand-written architecture with Child, twice, and print the
    wall-clock time of the data loading and of each training run.

    NOTE: read_data() is not implemented yet, so this test cannot run until
    it exists. It is expected to return (images, labels); the code below
    assumes ``labels`` is a tensor (it calls labels.size(), as Child does) -
    TODO confirm once read_data is written.
    """
    # obtain datasets
    t = time.time()
    images, labels = read_data()
    t = time.time() - t
    print('read dataset consumes %.2f sec' % t)
    # config of a model
    class_num = 10
    num_layers = 6
    out_channels = 32
    batch_size = 32
    epoch_num = 4
    # sample a child model: one ([op], [skip]) pair per layer.
    # op ids: 0 = conv3 (c3), 1 = conv5 (c5), 2 = avgpool3 (ap), 3 = maxpool3 (mp).
    # The skip vector of layer k has k entries; entry i == 1 adds the output
    # of nas layer i to the output of layer k.
    sample_arch = [
        [0], [],               # layer 0: c3, no skip possible
        [1], [1],              # layer 1: c5, l0=1
        [3], [0, 0],           # layer 2: mp, l0=0, l1=0
        [1], [1, 0, 1],        # layer 3: c5, l0=1, l1=0, l2=1
        [0], [0, 0, 0, 0],     # layer 4: c3, l0=0, l1=0, l2=0, l3=0
        [2], [0, 0, 0, 0, 0],  # layer 5: ap, l0=0, l1=0, l2=0, l3=0, l4=0
    ]
    print(sample_arch)

    # create a child; Child builds its own ChildModel and only takes the
    # model/optimizer configuration, the data is passed to train_epoch
    child = Child(class_num, num_layers, out_channels, batch_size)
    print(len(list(child.net.graph)))
    # print(child.net.graph)

    # number of full minibatches per epoch
    train_step = labels.size()[0] // batch_size

    # Train the same architecture twice. An alternative architecture for the
    # second run, kept for reference:
    #   [[1], [], [0], [1], [3], [1, 0], [0], [1, 0, 1],
    #    [0], [0, 1, 0, 1], [2], [0, 0, 0, 1, 1]]
    for _ in range(2):
        t = time.time()
        for epoch in range(epoch_num):
            child.train_epoch(sample_arch, images, labels, epoch, train_step)
        t = time.time() - t
        print('training time %.2f sec' % t)


# the read data function is not implemented yet
# test_child()

In [31]:

import torch.nn.functional as f

class StackLSTM(nn.Module):
    """
    StackLSTM class.
    It describes a stacked LSTM which only
    run a single step.

    Args:
        input_size: feature size of the input of the first cell.
        hidden_size: size of the hidden and cell state of every cell.
        lstm_num_layers: number of stacked cells.
    """
    def __init__(self, input_size, hidden_size, lstm_num_layers=2):
        super(StackLSTM, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.lstm_num_layers = lstm_num_layers
        self.net = self._build_net()

    def _build_net(self):
        """Create one LSTMCell per layer."""
        # Only the first cell sees the external input. Every cell above it is
        # fed the hidden state of the cell below (see forward), so its input
        # size has to be hidden_size.
        cells = []
        for layer in range(self.lstm_num_layers):
            in_features = self.input_size if layer == 0 else self.hidden_size
            cells.append(nn.LSTMCell(in_features, self.hidden_size))
        return nn.ModuleList(cells)

    def forward(self, inputs, prev_h, prev_c):
        """
        Advance every cell of the stack by one time step.

        Args:
            inputs: input of the first cell.
            prev_h: previous hidden state of each cell (one entry per layer).
            prev_c: previous cell state of each cell (one entry per layer).
        Returns:
            (next_h, next_c): lists with the new hidden and cell state of
            each layer.
        """
        next_h, next_c = [], []
        x = inputs
        for cell, h, c in zip(self.net, prev_h, prev_c):
            x, cur_c = cell(x, (h, c))  # the new hidden state feeds the next cell
            next_h.append(x)
            next_c.append(cur_c)
        return next_h, next_c

class ControllerModel(nn.Module):
    """
    Controller network that samples child architectures.

    For every child layer the controller runs two LSTM steps: one that
    samples the operation of the layer and one that samples which earlier
    layers are connected to it. The results of the last call to net_sample
    are stored on the instance:
        sample_arch: [[op], [skip]] * child_num_layers
        sample_entropy, sample_log_prob, sample_skip_count,
        sample_skip_penaltys: 0-dim tensors summed over all decisions

    Args:
        child_num_layers: number of layers of the child network.
        lstm_hidden_size: hidden size of the LSTM and of all embeddings.
        lstm_num_layers: number of stacked LSTM cells.
        num_operations: number of candidate operations per layer.
        temperature: logits are divided by it (skipped when None).
        tanh_constant: logits become tanh_constant * tanh(logit) (skipped
            when None).
        skip_target: target probability of a skip connection, used by the
            KL penalty.
    """
    def __init__(self,
               child_num_layers=6,
               lstm_hidden_size=32,
               lstm_num_layers=2,
               num_operations=4,
               temperature=5,
               tanh_constant=2.5,
               skip_target=0.4
              ):
        super(ControllerModel, self).__init__()
        self.child_num_layers = child_num_layers

        self.lstm_hidden_size = lstm_hidden_size
        self.lstm_num_layers = lstm_num_layers
        self.num_operations = num_operations
        self.temperature = temperature
        self.tanh_constant = tanh_constant
        self.skip_target = skip_target
        self.device = DEVICE

        self.net = self._build_net()
        # add g_emb as a parameter to ControllerModel
        # initialized by uniform distribution between -0.1 to 0.1
        # 0 <= torch.rand < 1
        g_emb_init = 0.2 * torch.rand(1, self.lstm_hidden_size) - 0.1
        self.register_parameter(name='g_emb', param=torch.nn.Parameter(g_emb_init))

        self.sample_arch = []
        self.sample_entropy = []
        self.sample_log_prob = []
        self.sample_skip_count = []
        self.sample_skip_penaltys = []

    def _build_net(self):
        """Create the LSTM, the operation head/embedding and the attention layers."""
        net = {}
        net['lstm'] = StackLSTM(self.lstm_hidden_size, self.lstm_hidden_size, self.lstm_num_layers)
        net['op_fc'] = nn.Linear(self.lstm_hidden_size, self.num_operations)
        net['op_emb_lookup'] = nn.Embedding(self.num_operations, self.lstm_hidden_size)
        net['skip_attn1'] = nn.Linear(self.lstm_hidden_size, self.lstm_hidden_size) # w_attn1
        net['skip_attn2'] = nn.Linear(self.lstm_hidden_size, self.lstm_hidden_size) # w_attn2
        net['skip_attn3'] = nn.Linear(self.lstm_hidden_size, 1)                     # v_attn

        if DEBUG:
            for name in ['lstm', 'op_fc', 'op_emb_lookup', 'skip_attn1', 'skip_attn2', 'skip_attn3']:
                print(name)
                for p in net[name].parameters():
                    print(p.size())

        return nn.ModuleDict(net)

    def _resolve_device(self):
        """
        Translate self.device into a torch.device.

        Both 'gpu' and 'cuda...' select the GPU; CPU is used when CUDA is not
        available or for any other value. (DEVICE is "cuda" or "cpu", so
        comparing against 'gpu' alone would never select the GPU.)
        """
        name = str(self.device).lower()
        wants_gpu = name == 'gpu' or name.startswith('cuda')
        if wants_gpu and torch.cuda.is_available():
            return torch.device('cuda:0')
        return torch.device('cpu')

    @staticmethod
    def _sum_of(values, device, dtype=torch.float):
        """Sum a list of 0-dim tensors; an empty list gives a 0-dim zero."""
        if not values:
            # torch.stack raises on an empty list, e.g. there are no skip
            # decisions at all when the child has a single layer
            return torch.zeros((), dtype=dtype, device=device)
        return torch.sum(torch.stack(values))

    def _op_sample(self, args):
        """
        sample an op (it is a part of controller's forward)
        Args: a sequence consisting of the following parts
            inputs: input of op_sample
            prev_h & prev_c: the hidden and cell states of the previous step
            arc_seq: architecture sequence; [op_id] is appended
            log_probs: all the log probabilities used for training (recall the gradient calculation of REINFORCE)
            entropys: all the entropys used for training
        Return:
            inputs (embedding of the sampled op), the new hidden and cell
            states, arc_seq, log_probs, entropys
        """
        net = self.net
        inputs, prev_h, prev_c, arc_seq, log_probs, entropys = args

        next_h, next_c = net['lstm'](inputs, prev_h, prev_c)

        logit = net['op_fc'](next_h[-1])    # h state of the last layer

        if self.temperature is not None:
            logit = logit / self.temperature

        if self.tanh_constant is not None:
            logit = self.tanh_constant * torch.tanh(logit)

        prob = f.softmax(logit, dim=1)
        # sample a single op; [0] drops the batch axis -> shape (1,)
        op_id = torch.multinomial(prob, 1)[0]

        inputs = net['op_emb_lookup'](op_id.long())
        # cross entropy of the sampled class = -log(prob of the sampled op)
        log_prob = f.cross_entropy(logit, op_id)
        entropy = log_prob * torch.exp(-log_prob)

        # .item() copies the value to the host whatever device op_id is on
        arc_seq.append([int(op_id.item())])
        log_probs.append(log_prob)
        entropys.append(entropy)

        return inputs, next_h, next_c, arc_seq, log_probs, entropys

    def _skip_sample(self, args):
        """
        sample skip connections for layer_id (it is a part of controller's forward)
        Args: a sequence consisting of the following parts
            layer_id: layer count
            inputs: input of skip_sample
            prev_h & prev_c: the hidden and cell states of the previous step
            arc_seq: architecture sequence; the skip vector is appended
            log_probs: all the log probabilities used for training (recall the gradient calculation of REINFORCE)
            entropys: all the entropys used for training
            anchors & anchors_w_1: anchor points and their weighted values
            skip_targets & skip_penaltys & skip_count: used to enforce the sparsity of skip connections
        Return:
            all args except layer_id
        """
        layer_id, inputs, prev_h, prev_c, arc_seq, log_probs, entropys, anchors, anchors_w_1, skip_targets, skip_penaltys, skip_count = args
        net = self.net

        next_h, next_c = net['lstm'](inputs, prev_h, prev_c)
        if layer_id > 0:
            # use attention mechanism to generate logits
            # concatenate the weighted anchors: one row per previous layer
            query = torch.cat(anchors_w_1, dim=0)
            # attention 2 - fc
            query = torch.tanh(net['skip_attn2'](next_h[-1]) + query)
            # attention 3 - fc
            query = net['skip_attn3'](query)
            # generate logit: column 0 = no connection, column 1 = connection
            logit = torch.cat([-query, query], dim=1)
            # process logit with temperature
            if self.temperature is not None:
                logit = logit / self.temperature
            # process logit with tanh and scale it
            if self.tanh_constant is not None:
                logit = self.tanh_constant * torch.tanh(logit)
            # calculate prob of skip (see NAS paper, Sec3.3)
            skip_prob = torch.sigmoid(logit) # use sigmoid to convert skip to its prob
            # calculate kl as skip penalty
            kl = torch.sum(skip_prob * torch.log(skip_prob / skip_targets))
            skip_penaltys.append(kl)
            # Sample one decision per previous layer: 1 - used as an input,
            # 0 - not an input. The decisions are drawn from softmax(logit),
            # the same distribution cross_entropy scores below (and the way
            # _op_sample samples), so log_prob is the log-probability of what
            # was actually sampled.
            skip = torch.multinomial(f.softmax(logit, dim=1), 1)
            # cal log_prob and append it - used by REINFORCE to calculate gradients of controller (i.e., LSTM)
            # reduction='none' keeps one value per decision; the joint
            # log-probability of the whole skip vector is their sum
            log_prob = f.cross_entropy(logit, skip.squeeze(dim=1), reduction='none')
            log_probs.append(torch.sum(log_prob))
            # cal entropys and append it
            entropys.append(torch.sum(log_prob * torch.exp(-log_prob)))
            # update count of skips
            skip_count.append(skip.sum())
            # add skip to arc_seq; tolist() works for tensors on any device
            arc_seq.append(skip.squeeze(dim=1).tolist())
            # generate inputs for the next time step
            # skip = 1 x layer_id (layer_id > 0)
            # cat_anchors = layer_id x lstm_size
            skip = torch.reshape(skip, (1, layer_id)).float()
            cat_anchors = torch.cat(anchors, dim=0)
            inputs = torch.matmul(skip, cat_anchors)
            inputs = inputs / (1.0 + torch.sum(skip))
        else:
            inputs = self.g_emb
            arc_seq.append([]) # no skip, use empty list to occupy the position

        # the hidden state of this step becomes the anchor of this layer
        anchors.append(next_h[-1])
        # cal attention 1
        anchors_w_1.append(net['skip_attn1'](next_h[-1]))

        return inputs, next_h, next_c, arc_seq, log_probs, entropys, anchors, anchors_w_1, skip_targets, skip_penaltys, skip_count

    def net_sample(self):
        """
        run (like forward) a controller model to sample a neural architecture

        Return:
            None. The sampled architecture and its statistics are stored in
            self.sample_arch, self.sample_entropy, self.sample_log_prob,
            self.sample_skip_count and self.sample_skip_penaltys.
        """
        # net sample
        arc_seq = []
        entropys = []
        log_probs = []
        # skip sample
        anchors = []        # store hidden states of skip lstm; anchor = hidden states of skip lstm (i.e., layer_id)
        anchors_w_1 = []    # store results of attention 1 (input=h, w_attn1)
        skip_count = []
        skip_penaltys = []

        # determine the device used to run the model and move the whole
        # module (sub-networks and g_emb) there
        device = self._resolve_device()
        if DEBUG: print(device)
        self.to(device)
        # init inputs and states
        # init prev cell states to zeros for each layer of the lstm
        prev_c = [torch.zeros((1, self.lstm_hidden_size), device=device) for _ in range(self.lstm_num_layers)]
        # init prev hidden states to zeros for each layer of the lstm
        prev_h = [torch.zeros((1, self.lstm_hidden_size), device=device) for _ in range(self.lstm_num_layers)]
        # inputs
        inputs = self.g_emb
        # skip_target = 0.4 = the prob of a layer used as an input of another layer
        # 1 - skip_target = 0.6; the probability that this layer is not used as an input
        skip_targets = torch.tensor([1.0 - self.skip_target, self.skip_target], dtype=torch.float, device=device)

        # sample an arch
        for layer_id in range(self.child_num_layers):
            inputs, prev_h, prev_c, arc_seq, log_probs, entropys = self._op_sample(
                [inputs, prev_h, prev_c, arc_seq, log_probs, entropys])
            (inputs, prev_h, prev_c, arc_seq, log_probs, entropys, anchors,
             anchors_w_1, skip_targets, skip_penaltys, skip_count) = self._skip_sample(
                [layer_id, inputs, prev_h, prev_c, arc_seq, log_probs, entropys,
                 anchors, anchors_w_1, skip_targets, skip_penaltys, skip_count])

        # generate sample arch
        # [[op], [skip]] * num_layer
        self.sample_arch = arc_seq
        if DEBUG:
            print('sample_arch')
            print('len:', len(self.sample_arch))
            for idx, data in enumerate(self.sample_arch):
                if idx % 2 == 0:
                    print('-' * 15)
                    # two entries (op, skip) per layer
                    print('layer:', idx // 2)
                    print('op:', data)
                else:
                    print('skip:', data)
        # cal sample entropy
        self.sample_entropy = self._sum_of(entropys, device)
        if DEBUG:
            print('sample_entropy: %.3f' % self.sample_entropy.item())

        # cal sample log_probs
        self.sample_log_prob = self._sum_of(log_probs, device)
        if DEBUG:
            print('sample_log_prob: %.3f' % self.sample_log_prob.item())

        # cal skip count
        self.sample_skip_count = self._sum_of(skip_count, device, dtype=torch.long)
        if DEBUG:
            print('sample_skip_count: %.0f' % self.sample_skip_count.item())

        # cal skip penaltys
        self.sample_skip_penaltys = self._sum_of(skip_penaltys, device)
        if DEBUG:
            print('sample_skip_penaltys : %.3f' % self.sample_skip_penaltys.item())
        

def test_model():
    """
    Smoke test for ControllerModel.
    Builds a controller model with its default arguments, prints the
    modules it holds and draws one sample architecture from it.
    """
    model = ControllerModel()
    # hint: list(model.parameters()) can be iterated here to dump the
    # size of every parameter when debugging the model layout
    print(model.net)
    model.net_sample()
    
# ------------------
# Testbench
# ------------------
# Run the ControllerModel smoke test only when this code is executed as the
# main program (not when it is imported as a module).
if __name__ == '__main__':
    test_model()

ModuleDict(
  (lstm): StackLSTM(
    (net): ModuleList(
      (0-1): 2 x LSTMCell(32, 32)
    )
  )
  (op_fc): Linear(in_features=32, out_features=4, bias=True)
  (op_emb_lookup): Embedding(4, 32)
  (skip_attn1): Linear(in_features=32, out_features=32, bias=True)
  (skip_attn2): Linear(in_features=32, out_features=32, bias=True)
  (skip_attn3): Linear(in_features=32, out_features=1, bias=True)
)


In [None]:
class Controller(nn.Module):
    """
    Controller class.
    It wraps a ControllerModel (the architecture sampler) together with its
    optimizer and the running training statistics, and describes how to
        1) train the controller (train_epoch)
        2) evaluate sampled archs (eval)
        3) pick the best sampled arch (derive_best_arch)
    """
    def __init__(self,
                 device='gpu',
                 lstm_size=32,
                 lstm_num_layers=2,
                 child_num_layers=6,
                 num_op=4,
                 train_step_num=50,
                 ctrl_batch_size=20,
                 opt_algo='adam',
                 lr_init=0.00035,
                 lr_gamma=0.1,
                 temperature=5,
                 tanh_constant=2.5,
                 entropy_weight=0.0001,
                 baseline_decay=0.999,
                 skip_target=0.4,
                 skip_weight=0.8
                 ):
        """
        1. init params
        2. allocate the running training statistics on the selected device
        3. build the controller model and its optimizer

        Args:
            device: 'gpu' allocates the statistics on CUDA, any other value
                keeps them on the CPU; also forwarded to ControllerModel
            lstm_size, lstm_num_layers, child_num_layers, num_op,
            temperature, tanh_constant, skip_target:
                stored and forwarded unchanged to ControllerModel
            train_step_num: number of optimizer steps per train_epoch call
            ctrl_batch_size: number of archs sampled per optimizer step
            opt_algo: stored only; Adam is always the optimizer built here
            lr_init: learning rate given to Adam
            lr_gamma: stored only; no lr scheduler is attached
            entropy_weight: weight of the sample entropy added to the reward
            baseline_decay: decay of the moving-average reward baseline
            skip_weight: weight of the skip penalty added to the loss
        """
        super(Controller, self).__init__() # init the parent class, i.e., nn.Module
        # config of controller model
        # child model
        self.child_num_layers = child_num_layers # number of layers of a child arch
        # ctrl model
        self.lstm_size = lstm_size # forwarded to ControllerModel
        self.lstm_num_layers = lstm_num_layers # forwarded to ControllerModel
        self.num_op = num_op # number of candidate op types per layer
        self.temperature = temperature
        self.tanh_constant = tanh_constant
        self.skip_target = skip_target
        # ctrl training
        self.ctrl_batch_size = ctrl_batch_size
        self.opt_algo = opt_algo
        self.lr_init = lr_init
        self.lr_gamma = lr_gamma
        self.train_step_num = train_step_num
        self.entropy_weight = entropy_weight
        self.baseline_decay = baseline_decay
        self.skip_weight = skip_weight
        # device
        self.device = device
        # running training statistics, kept on the device chosen above
        self.reward = self._zeros(1) # reward of the latest sample
        self.baseline = self._zeros(1) # moving-average base line
        self.log_prob = self._zeros(1) # log_prob of the latest sample
        self.entropy = self._zeros(1) # entropy of the latest sample
        self.skip_penalty = self._zeros(1) # skip penalty of the latest sample
        self.loss = self._zeros(1) # loss of the latest training step

        # build controller
        self.ctrl = ControllerModel(child_num_layers=child_num_layers,
                                    lstm_size=lstm_size,
                                    lstm_num_layers=lstm_num_layers,
                                    num_op=num_op,
                                    temperature=temperature,
                                    tanh_constant=tanh_constant,
                                    skip_target=skip_target,
                                    device=device)
        # Optimizer; Adam with its default betas
        # (an earlier variant used betas=(0, 0.999), noting that the ENAS code sets beta1=0)
        if DEBUG: print('#param', len(list(self.ctrl.parameters())))
        self.optimizer = optim.Adam(self.ctrl.parameters(), lr=self.lr_init)
        # learning rate scheduler - not mentioned in paper, so none is attached;
        # lr_gamma is only stored in case one is added later

    def _zeros(self, size):
        """
        Allocate a zero tensor with `size` elements on the device selected
        by self.device ('gpu' -> CUDA, anything else -> CPU).
        """
        target = 'cuda' if self.device == 'gpu' else 'cpu'
        return torch.zeros(size, device=target)

    def train_epoch(self, child_model, images, labels, file):
        """
        train controller for an epoch
        Procedure
        for N train steps:
            for M child archs: - like obtaining a batch of data
                sample a child architecture
                validate the sampled arch via child_model.valid_rl
                obtain reward
                add weighted entropy to reward
                update baseline
                    exponential moving average of previous rewards
                cal loss: add weighted skip penalty to loss
            loss = avg( sample_log_prob * (reward - baseline) + skip_weight * skip_penaltys )
            zero grads
            cal grads
                loss.backward() = REINFORCE
            update params of ctrl

        Args:
            child_model: object providing valid_rl(sample_arch, images, labels)
            images, labels: passed through unchanged to child_model.valid_rl
            file: writable text file; the last sampled arch and the step
                time are logged to it every 10 steps
        Return:
            None
        """
        for step in range(self.train_step_num):
            # a single step of training
            # sample a batch of child archs and obtain their metrics
            loss = self._zeros(self.ctrl_batch_size)
            t_step = time.time()
            for sample_cnt in range(self.ctrl_batch_size):
                # sample a child arch
                self.ctrl.net_sample()
                # valid a sampled arch and obtain reward
                raw_reward = child_model.valid_rl(self.ctrl.sample_arch, images, labels)
                # add weighted entropy to reward; done out of place so the
                # object returned by the child model is never mutated
                self.entropy = self.ctrl.sample_entropy
                self.reward = raw_reward + self.entropy_weight * self.entropy
                # update baseline (no grad must flow through the baseline)
                with torch.no_grad():
                    self.baseline = self.baseline + (1 - self.baseline_decay) * (self.reward - self.baseline)
                # update loss
                self.log_prob = self.ctrl.sample_log_prob
                self.skip_penalty = self.ctrl.sample_skip_penaltys
                loss[sample_cnt] = self.log_prob * (self.reward - self.baseline) + self.skip_weight * self.skip_penalty
            self.loss = loss.sum() / self.ctrl_batch_size # avg loss
            # zero grads
            self.optimizer.zero_grad()
            # cal grads
            self.loss.backward()
            # update weights
            self.optimizer.step()
            # cal time consumed per step
            t_step = time.time() - t_step
            if step % 10 == 0:
                print('step', step)
                display_sample_arch(self.ctrl.sample_arch)
                # every record ends with a newline so log entries do not run together
                file.write('step:' + str(step) + '\n')
                print_sample_arch(self.ctrl.sample_arch, file)
                print('time_per_step', t_step)
                file.write('time_per_step:' + str(t_step) + '\n')

    def get_op_portion(self):
        """
        Count number of each type of ops in the latest sample arch

        Args: none (reads self.ctrl.sample_arch)
        Return:
            list of length num_op; entry k is how many layers use op k
        """
        op_counts = [0] * self.num_op
        sample_arch = self.ctrl.sample_arch
        for layer in range(self.child_num_layers):
            # sample_arch is laid out as [op], [skip] per layer,
            # so the op entry of a layer sits at the even index
            op_counts[sample_arch[2 * layer][0]] += 1

        return op_counts

    def get_op_percent(self, op_histroy):
        """
        Portion of each type of op over a collection of sample archs

        Args:
            op_histroy: sequence of per-arch op count lists
                (one get_op_portion result per sampled arch)
        Return:
            array with one entry per op type; the entries sum to 1
        """
        totals = np.sum(np.stack(op_histroy), axis=0) # counts per op over all archs
        return totals / np.sum(totals) # portion of each type of op

    def eval(self, child_model, arc_num, images, labels, file):
        """
        evaluate controller: sample several archs and evaluate each of them
        via child_model.eval on the given data.

        NOTE: this method shadows nn.Module.eval(), so the usual no-argument
        ctrl.eval() mode switch is not available on this class; call
        ctrl.train(False) for that instead.

        Args:
            child_model: object providing eval(sample_arch, images, labels)
            arc_num: number of archs to sample
            images, labels: passed through unchanged to child_model.eval
            file: writable text file receiving the accuracies and the archs
        Return:
            accuracy: list with the accuracy of every sampled arch
            op_percent: portion of each op type over all sampled archs
        """
        accuracy = []
        arcs = []
        op_counts = []
        for _ in range(arc_num):
            # sample a child arch
            self.ctrl.net_sample()
            arcs.append(self.ctrl.sample_arch)
            # evaluate the sampled arch
            accuracy.append(child_model.eval(self.ctrl.sample_arch, images, labels))
            # get the op analysis
            op_counts.append(self.get_op_portion())
        # obtain averaged op_history
        op_percent = self.get_op_percent(op_counts)
        # print to file
        # accuracy
        file.write('arch \t accuracy\n')
        for i, acc in enumerate(accuracy):
            file.write('%d \t %f\n' % (i, acc))
        # arch
        for i, arc in enumerate(arcs):
            file.write('arch#: %d\n' % i)
            print_sample_arch(arc, file)

        return accuracy, op_percent

    def derive_best_arch(self, child_model, arc_num, images, labels, file):
        """
        derive the final child model using controller
        procedure
            1. sample arc_num archs
            2. evaluate them via child_model.eval on the given data
            3. select the one with highest accuracy as the best arch
        Args:
            child_model: object providing eval(sample_arch, images, labels)
            arc_num: number of archs to sample
            images, labels: passed through unchanged to child_model.eval
            file: writable text file receiving the report
        Return:
            best_accuracy: highest accuracy seen (0 if arc_num is 0)
            best_arch: arch that reached it ([] if arc_num is 0)
        """
        accuracy = []
        arcs = []
        best_arch = []
        best_accuracy = 0
        for _ in range(arc_num):
            # sample a child arch
            self.ctrl.net_sample()
            arcs.append(self.ctrl.sample_arch)
            # evaluate the sampled arch
            eval_acc = child_model.eval(self.ctrl.sample_arch, images, labels)
            accuracy.append(eval_acc)
            # select the best arch; the first sample always becomes the
            # incumbent so best_arch is never empty once something was sampled
            if len(arcs) == 1 or eval_acc > best_accuracy:
                best_accuracy = eval_acc
                best_arch = self.ctrl.sample_arch

        # print to file
        # best accuracy and arc
        file.write('best accuracy: %f\n' % best_accuracy)
        file.write('best arch \n')
        print_sample_arch(best_arch, file)
        # accuracy
        file.write('-' * 30 + '\n')
        file.write(' accuracies \n')
        file.write('-' * 30 + '\n')
        file.write('arch \t accuracy\n')
        for i, acc in enumerate(accuracy):
            file.write('%d \t %f\n' % (i, acc))
        # arch
        file.write('-' * 30 + '\n')
        file.write(' archs \n')
        file.write('-' * 30 + '\n')
        for i, arc in enumerate(arcs):
            file.write('arch#: %d\n' % i)
            print_sample_arch(arc, file)

        return best_accuracy, best_arch

def test_ctrl():
    """
    Testbench for the Controller.
    Reads the dataset, builds a Controller and a Child, then for every epoch
    samples one arch (logged to the child file) and trains the controller for
    one epoch (logged to the controller file).
    """
    # obtain datasets
    t = time.time()
    images, labels = read_data()
    t = time.time() - t
    print('read dataset consumes %.2f sec' % t)
    # config of a model
    class_num = 10
    child_num_layers = 6
    out_channels = 32
    batch_size = 32
    device = 'gpu'
    epoch_num = 4
    # files to print sampled archs
    child_filename = 'child_file.txt'
    ctrl_filename = 'controller_file.txt'
    # both log files are closed on exit, even if training raises
    with open(child_filename, 'w') as child_file, open(ctrl_filename, 'w') as ctrl_file:
        # create a controller on the same device as the child
        ctrl = Controller(child_num_layers=child_num_layers, device=device)
        # create a child, set epoch to 1; later this will be moved to an over epoch
        child = Child(images, labels, class_num, child_num_layers, out_channels, batch_size, device, 1)
        print(len(list(child.net.graph)))
        # train multiple epochs
        for _ in range(epoch_num):
            # sample an arch
            ctrl.ctrl.net_sample()
            sample_arch = ctrl.ctrl.sample_arch
            print_sample_arch(sample_arch, child_file)
            # child training (child.train(sample_arch)) is not run in this testbench

            # train controller
            # Controller defines train_epoch, not train: nn.Module.train only
            # takes the boolean mode, so the former ctrl.train(child, ctrl_file)
            # call could not work
            t = time.time()
            ctrl.train_epoch(child, images, labels, ctrl_file)
            t = time.time() - t
            print('ctrller training time %.2f sec' % t)

# ------------------
# Testbench
# ------------------
# Run the Controller testbench only when this code is executed as the main
# program (not when it is imported as a module).
if __name__ == '__main__':
    test_ctrl()