In [1]:
import tensorflow as tf
from tensorflow.keras import Model
import numpy as np

In [2]:
# TF 1.x API: switch to eager execution so the cells below return concrete
# tensors (the recorded outputs show numpy values) instead of graph nodes.
tf.enable_eager_execution()

# Concat

In [3]:
def _concat(xs):
    """nd tensor to 1d tensor

    Args:
        xs (array): the array of nd tensor

    Returns:
        array: concated array
    """
    return tf.concat([tf.reshape(x, [tf.size(x)]) for x in xs], axis=0, name="_concat")

## Testing

In [4]:
# Rank-4 fixture of shape (2, 2, 3, 1): 12 elements in total.
a = tf.constant(
    [
        [[[1], [2], [3]], [[4], [5], [6]]],
        [[[2], [4], [6]], [[8], [10], [12]]],
    ]
)
# Rank-1 fixture with 2 elements.
b = tf.constant([1, 2])

In [5]:
# Exercise the helper itself rather than re-implementing the flatten inline.
# Mixed ranks: `a` holds 12 elements and `b` holds 2, so the result has 14.
assert _concat([a, b]).shape.as_list() == [14]
# Displayed value: `a` flattened to 1-D.
_concat([a])

<tf.Tensor: id=4, shape=(12,), dtype=int32, numpy=array([ 1,  2,  3,  4,  5,  6,  2,  4,  6,  8, 10, 12], dtype=int32)>

# Architect

In [43]:
class Architect(object):
    """Updates the architecture parameters of a searchable network.

    The wrapped model is expected to provide ``trainable_weights``,
    ``arch_parameters()``, ``new()``, ``_criterion(logits, targets)`` and to be
    callable on an input batch; those are the only members used here.

    Attributes:
      network_momentum (float): ``args.momentum``
      network_weight_decay (float): ``args.weight_decay``
      model (Network): network whose architecture parameters are updated
      arch_learning_rate (float): ``args.arch_learning_rate``
      optimizer (optimiser): Adam (beta1=0.5, beta2=0.999) applied to the
        architecture parameters

    NOTE(review): ``network_momentum`` and ``network_weight_decay`` are stored
    but not read by any method of this class; the virtual weight step is plain
    SGD. Confirm whether that is intended.
    """

    def __init__(self, model, args):
        """Initialises the architect

        Args:
            model (Network): network architecture with cells
            args (object): settings exposing ``momentum``, ``weight_decay``
                and ``arch_learning_rate`` as attributes
        """
        self.network_momentum = args.momentum
        self.network_weight_decay = args.weight_decay
        self.model = model
        self.arch_learning_rate = args.arch_learning_rate
        self.optimizer = tf.train.AdamOptimizer(learning_rate=self.arch_learning_rate,
                                                beta1=0.5,
                                                beta2=0.999)

    def get_model_theta(self, model):
        """Collects the trainable variables that are not architecture parameters

        Args:
            model (Network): model to inspect

        Returns:
            tuple: ``(names, variables)`` for every entry of
            ``model.trainable_weights`` whose name does not contain 'alphas'
        """
        specific_tensor = []
        specific_tensor_name = []
        for var in model.trainable_weights:
            if 'alphas' not in var.name:
                specific_tensor.append(var)
                specific_tensor_name.append(var.name)
        return specific_tensor_name, specific_tensor

    def step(self, input_train, target_train, input_valid, target_valid, lr, unrolled):
        """Computes one update of the architecture parameters

        Args:
            input_train (tensor): training inputs
            target_train (tensor): training targets
            input_valid (tensor): validation inputs
            target_valid (tensor): validation targets
            lr (float): learning rate of the virtual weight step; only used
                when ``unrolled`` is True
            unrolled (bool): True for the unrolled (second-order) update,
                False for the first-order update

        Returns:
            tuple: ``(result of apply_gradients, model)`` where ``model`` is the
            unrolled copy when ``unrolled`` is True and ``self.model`` otherwise
        """
        if unrolled:
            return self._compute_unrolled_step(
                input_train, target_train, input_valid, target_valid,
                self.get_model_theta(self.model)[1], lr)
        # Previously this path silently did nothing and returned None.
        return self._backward_step(input_valid, target_valid), self.model

    def _backward_step(self, input_valid, target_valid):
        """First-order update: descend the validation loss w.r.t. the
        architecture parameters, leaving the weights untouched

        Args:
            input_valid (tensor): validation inputs
            target_valid (tensor): validation targets

        Returns:
            whatever ``self.optimizer.apply_gradients`` returns
        """
        arch_var = list(self.model.arch_parameters())
        with tf.GradientTape() as tape:
            valid_loss = self.model._criterion(self.model(input_valid), target_valid)
        arch_grads = tape.gradient(valid_loss, arch_var)
        return self.optimizer.apply_gradients(zip(arch_grads, arch_var))

    def _compute_unrolled_step(self, x_train, y_train, x_valid, y_valid, w_var, lr):
        """Unrolled update of the architecture parameters

        A copy of the model takes one SGD step on the training loss, the
        validation loss of that copy is differentiated, and the result is
        corrected by a finite-difference Hessian-vector product before being
        applied to the architecture parameters of ``self.model``.

        Args:
            x_train (tensor): training inputs
            y_train (tensor): training targets
            x_valid (tensor): validation inputs
            y_valid (tensor): validation targets
            w_var (list): non-architecture variables of ``self.model``
            lr (float): learning rate of the virtual weight step

        Returns:
            tuple: ``(result of apply_gradients, unrolled model)``

        Raises:
            ValueError: if ``w_var`` is empty or the copy exposes a different
                number of weights than ``w_var``
        """
        w_var = list(w_var)
        if not w_var:
            raise ValueError("w_var is empty; the model has no weights to unroll")
        arch_var = list(self.model.arch_parameters())
        unrolled_model = self.model.new()

        unrolled_w_var = self.get_model_theta(unrolled_model)[1]
        if len(unrolled_w_var) != len(w_var):
            # The copy does not expose a matching weight set yet; a forward
            # pass lets it create its variables before they are overwritten.
            unrolled_model(x_train)
            unrolled_w_var = self.get_model_theta(unrolled_model)[1]
        if len(unrolled_w_var) != len(w_var):
            raise ValueError(
                "unrolled model has %d weights but %d were supplied"
                % (len(unrolled_w_var), len(w_var)))

        # Copy the weights BEFORE the forward pass so the training gradient is
        # evaluated at the current weights, not at the copy's initial values.
        for v, w in zip(unrolled_w_var, w_var):
            v.assign(w)

        with tf.GradientTape() as tape:
            unrolled_train_loss = unrolled_model._criterion(unrolled_model(x_train), y_train)
        grads = tape.gradient(unrolled_train_loss, unrolled_w_var)
        # Virtual step: w' = w - lr * dL_train/dw.
        tf.train.GradientDescentOptimizer(lr).apply_gradients(zip(grads, unrolled_w_var))

        # One validation pass yields both dL_val/dw' and dL_val/dalpha.
        unrolled_arch_var = list(unrolled_model.arch_parameters())
        with tf.GradientTape() as tape:
            valid_loss = unrolled_model._criterion(unrolled_model(x_valid), y_valid)
        all_grads = tape.gradient(valid_loss, unrolled_w_var + unrolled_arch_var)
        valid_grads = all_grads[:len(unrolled_w_var)]
        leader_grads = all_grads[len(unrolled_w_var):]

        implicit_grads = self._hessian_vector_product(
            valid_grads, w_var, arch_var, x_train, y_train)

        # Final gradient: dL_val/dalpha - lr * (second-order correction).
        grads_and_vars = []
        for g, implicit, v in zip(leader_grads, implicit_grads, arch_var):
            if implicit is not None:
                correction = -lr * implicit
                g = correction if g is None else g + correction
            grads_and_vars.append((g, v))

        leader_opt = self.optimizer.apply_gradients(grads_and_vars)
        return leader_opt, unrolled_model

    def _hessian_vector_product(self, vector, w_var, arch_var, x_train, y_train, r=1e-2):
        """Central finite-difference estimate of the second-order term

        Evaluates the training-loss gradient w.r.t. ``arch_var`` with the
        weights of ``self.model`` shifted to ``w + R*vector`` and
        ``w - R*vector`` and returns ``(g_plus - g_minus) / (2*R)``. Under
        eager execution every assignment runs immediately, so each gradient
        has to be taken while its perturbation is in place.

        Args:
            vector (list): one direction per entry of ``w_var`` (None entries
                are left unperturbed)
            w_var (list): non-architecture variables of ``self.model``
            arch_var (list): architecture parameters of ``self.model``
            x_train (tensor): training inputs
            y_train (tensor): training targets
            r (float): numerator of the perturbation size

        Returns:
            list: one tensor (or None) per entry of ``arch_var``
        """
        # The epsilon guards against division by zero for a vanishing vector.
        R = r / (tf.global_norm(vector) + 1e-6)
        originals = [w.read_value() for w in w_var]

        def move_weights(scale):
            # Set w = original + scale * vector, always relative to the snapshot.
            for var, base, direction in zip(w_var, originals, vector):
                if direction is not None:
                    var.assign(base + scale * tf.convert_to_tensor(direction))

        def arch_train_grads():
            with tf.GradientTape() as tape:
                train_loss = self.model._criterion(self.model(x_train), y_train)
            return tape.gradient(train_loss, arch_var)

        try:
            move_weights(R)
            grads_pos = arch_train_grads()
            move_weights(-R)
            grads_neg = arch_train_grads()
        finally:
            # Restore the exact original weights even if a pass above failed.
            for var, base in zip(w_var, originals):
                var.assign(base)

        return [
            None if p is None or n is None else tf.divide(p - n, 2 * R)
            for p, n in zip(grads_pos, grads_neg)
        ]

## Testing

In [7]:
from model_search import Network

In [48]:
# NOTE(review): tf.losses.sigmoid_cross_entropy takes (multi_class_labels, logits),
# whereas Architect invokes `_criterion(logits, targets)`. If Network stores this
# callable as `_criterion` unchanged, the two arguments arrive swapped — confirm
# against model_search.
criterion = tf.losses.sigmoid_cross_entropy
# Positional arguments as required by model_search.Network (meaning not visible here).
model = Network(3, 3, criterion)
# Settings for the Architect; its __init__ reads momentum, weight_decay and
# arch_learning_rate only, so arch_weight_decay is currently unused by it.
args = dict(
    momentum=0.9,
    weight_decay=3e-4,
    arch_learning_rate=3e-1,
    arch_weight_decay=1e-3,
)

class Struct:
    """Exposes keyword arguments as instance attributes (``Struct(a=1).a == 1``)."""

    def __init__(self, **entries):
        for key, value in entries.items():
            setattr(self, key, value)

In [49]:
# Random stand-ins for one training pair and one validation pair.
# Inputs are (1, 16, 16, 3) in [0, 255); targets are (1, 16, 16, 1) in [0, 1).
inp = tf.random_uniform(shape=(1, 16, 16, 3), minval=0, maxval=255)
target = tf.random_uniform(shape=(1, 16, 16, 1), minval=0, maxval=1)
input_valid = tf.random_uniform(shape=(1, 16, 16, 3), minval=0, maxval=255)
target_valid = tf.random_uniform(shape=(1, 16, 16, 1), minval=0, maxval=1)
# Learning rate passed to Architect.step, and the flag selecting its unrolled path.
lr = 0.025
unrolled = True

In [50]:
# Run one forward pass through the network on a random (1, 16, 16, 3) input.
# NOTE(review): presumably this creates the network's variables so that
# `trainable_weights` is populated before the Architect reads it — confirm.
image = tf.random_uniform(shape=(1, 16, 16, 3), minval=0, maxval=255)
res = model(image)

In [51]:
# Architect.__init__ reads its settings as attributes (args.momentum, ...),
# so the plain dict is wrapped in Struct first.
architect = Architect(model, Struct(**args))

In [52]:
# Run one unrolled architecture step. NOTE: this rebinds `model` to the second
# value returned by `step` (the unrolled copy on this code path), while
# `architect.model` still refers to the network built earlier.
opt, model = architect.step(inp, target, input_valid, target_valid, lr, unrolled)

In [53]:
# Inspect the weights of the model returned by the step above
# (displays every variable in full).
model.trainable_weights

[<tf.Variable 'network_14/sequential_28/conv2d_9846/kernel:0' shape=(3, 3, 3, 9) dtype=float32, numpy=
 array([[[[-0.19614969,  0.21278118, -0.17454168,  0.00797643,
           -0.23286544,  0.08676855,  0.06303824, -0.19849542,
           -0.19169842],
          [ 0.12443991,  0.16739284, -0.07762422, -0.03400785,
            0.20620902, -0.0837101 ,  0.14263578, -0.15548506,
           -0.12333809],
          [-0.11410124,  0.17391543,  0.08748208, -0.13477883,
            0.01943351,  0.06097974,  0.01835601,  0.04203694,
           -0.1293444 ]],
 
         [[ 0.08058198,  0.23235674,  0.07736953,  0.09986408,
            0.16971453,  0.13981228,  0.08082457,  0.19745968,
           -0.18798849],
          [-0.10159032, -0.10275497,  0.0115515 ,  0.0863186 ,
            0.05573444, -0.16917048, -0.03755948, -0.15520582,
            0.19206773],
          [ 0.05304568, -0.12769702,  0.10206972,  0.16292377,
            0.11239751, -0.10618792, -0.12093145, -0.0246482 ,
            0