# In[ ]:
class Architect(object):
    """Updates the architecture parameters (alphas) of a DARTS-style network.

    Attributes:
      network_momentum (float): ``args.momentum``; stored, but not used by
        the update rules in this class.
      network_weight_decay (float): ``args.weight_decay``; stored, but not
        used by the update rules in this class.
      model (Network): network architecture with cells.
      arch_learning_rate (float): learning rate of the architecture optimizer.
      optimizer (tf.train.AdamOptimizer): Adam (beta1=0.5, beta2=0.999) used
        to apply the architecture gradients to ``model.arch_parameters()``.
    """

    def __init__(self, model, args):
        """Initialises the architect.

        Args:
            model (Network): network architecture with cells. Must provide
              ``arch_parameters()``, ``new()``, ``_loss(logits, target)``,
              ``_criterion(logits, target)`` and ``trainable_weights``.
            args: parsed command-line arguments (attribute access, e.g. an
              ``argparse.Namespace``) exposing ``momentum``, ``weight_decay``
              and ``arch_learning_rate``.
        """
        self.network_momentum = args.momentum
        self.network_weight_decay = args.weight_decay
        self.model = model
        self.arch_learning_rate = args.arch_learning_rate
        self.optimizer = tf.train.AdamOptimizer(
            learning_rate=self.arch_learning_rate, beta1=0.5, beta2=0.999)

    def get_model_theta(self, model):
        """Select the network weights of ``model`` (everything but the alphas).

        Args:
            model: object exposing ``trainable_weights`` (items with ``.name``).

        Returns:
            tuple: ``(names, variables)`` of every trainable weight whose name
            does not contain ``'alphas'``, in ``trainable_weights`` order.
        """
        weights = [var for var in model.trainable_weights if 'alphas' not in var.name]
        return [var.name for var in weights], weights

    def step(self, input_train, target_train, input_valid, target_valid, lr, unrolled):
        """Perform one architecture (alpha) update.

        Args:
            input_train (tensor): batch of training inputs.
            target_train (tensor): targets for ``input_train``.
            input_valid (tensor): batch of validation inputs.
            target_valid (tensor): targets for ``input_valid``.
            lr (float or tensor): learning rate of the network weights, used
              for the virtual (unrolled) weight step.
            unrolled (bool): True for the second-order (unrolled) update,
              False for the first-order update on the validation loss.
        """
        # Forward pass kept before get_model_theta(): if the model creates its
        # weights lazily on first call, they must exist before being listed.
        train_loss = self.model._loss(self.model(input_train), target_train)
        if unrolled:
            self._compute_unrolled_step(
                input_train, target_train, input_valid, target_valid,
                self.get_model_theta(self.model)[1], train_loss, lr)
        else:
            self._backward_step(input_valid, target_valid)

    def _backward_step(self, input_valid, target_valid):
        """First-order update: descend the validation loss w.r.t. the alphas."""
        arch_var = list(self.model.arch_parameters())
        with tf.GradientTape() as tape:
            tape.watch(arch_var)
            valid_loss = self.model._loss(self.model(input_valid), target_valid)
        grads = tape.gradient(valid_loss, arch_var)
        return self.optimizer.apply_gradients(zip(grads, arch_var))

    def _compute_unrolled_step(self, x_train, y_train, x_valid, y_valid, w_var, train_loss, lr):
        """Second-order (unrolled) architecture update.

        Applies to the alphas the gradient
            dL_val(w', alpha)/dalpha - lr * H,
            w' = w - lr * dL_train(w, alpha)/dw,
        where H is the central-difference estimate
            (dL_train(w+)/dalpha - dL_train(w-)/dalpha) / (2 * eps),
            w+/- = w +/- eps * dL_val(w', alpha)/dw'.

        Args:
            x_train, y_train: training batch and its targets.
            x_valid, y_valid: validation batch and its targets.
            w_var (list): network weights of ``self.model`` (see
              ``get_model_theta``); perturbed temporarily and restored.
            train_loss: unused; kept for call compatibility. Losses have to be
              recomputed inside a ``GradientTape`` to be differentiable.
            lr: learning rate of the virtual weight step.

        Returns:
            tuple: (result of ``self.optimizer.apply_gradients``, unrolled model).
        """
        arch_var = list(self.model.arch_parameters())
        unrolled_model, unrolled_w_var, unrolled_arch_var = self._make_unrolled_copy(
            x_train, w_var, arch_var)

        # Virtual step on the copy: w' = w - lr * grad_w L_train(w, alpha).
        with tf.GradientTape() as tape:
            unrolled_train_loss = unrolled_model._criterion(unrolled_model(x_train), y_train)
        grads = tape.gradient(unrolled_train_loss, unrolled_w_var)
        tf.train.GradientDescentOptimizer(lr).apply_gradients(zip(grads, unrolled_w_var))

        # Gradients of the validation loss at (w', alpha), w.r.t. alphas and w'.
        with tf.GradientTape() as tape:
            tape.watch(unrolled_arch_var)
            valid_loss = unrolled_model._criterion(unrolled_model(x_valid), y_valid)
        valid_grads = tape.gradient(valid_loss, unrolled_arch_var + unrolled_w_var)
        d_alpha = valid_grads[:len(unrolled_arch_var)]
        d_w = valid_grads[len(unrolled_arch_var):]

        implicit_grads = self._hessian_vector_product(d_w, w_var, arch_var, x_train, y_train)

        arch_grads = []
        for g, h, v in zip(d_alpha, implicit_grads, arch_var):
            g = tf.zeros_like(v) if g is None else g
            arch_grads.append((g - lr * h, v))
        leader_opt = self.optimizer.apply_gradients(arch_grads)
        return leader_opt, unrolled_model

    def _make_unrolled_copy(self, x_train, w_var, arch_var):
        """Create ``self.model.new()`` holding the current weights and alphas."""
        unrolled_model = self.model.new()
        # Call once so that, if weights are created lazily on first call,
        # they exist before being copied (the copy must precede any loss).
        unrolled_model(x_train)
        unrolled_w_var = self.get_model_theta(unrolled_model)[1]
        for v, w in zip(unrolled_w_var, w_var):
            v.assign(w)
        # Harmless if new() already copied or shares the alpha variables.
        unrolled_arch_var = list(unrolled_model.arch_parameters())
        for a_new, a in zip(unrolled_arch_var, arch_var):
            a_new.assign(a)
        return unrolled_model, unrolled_w_var, unrolled_arch_var

    def _arch_grads_on_train(self, x_train, y_train, arch_var):
        """Gradient of ``self.model``'s training loss w.r.t. ``arch_var`` (None -> 0)."""
        with tf.GradientTape() as tape:
            tape.watch(arch_var)
            loss = self.model._loss(self.model(x_train), y_train)
        grads = tape.gradient(loss, arch_var)
        return [tf.zeros_like(v) if g is None else g for g, v in zip(grads, arch_var)]

    def _hessian_vector_product(self, vector, w_var, arch_var, x_train, y_train, r=1e-2):
        """Finite-difference estimate of (d^2 L_train / dalpha dw) . ``vector``.

        ``w_var`` is moved to ``w + eps*vector`` and ``w - eps*vector`` with
        ``eps = r / ||vector||`` and is restored to its saved value afterwards,
        even if an exception is raised. Returns one tensor per ``arch_var``.
        """
        eps = r / tf.maximum(tf.global_norm(vector), 1e-12)
        pairs = [(w, v) for w, v in zip(w_var, vector) if v is not None]
        backup = [w.read_value() for w, _ in pairs]
        try:
            for w, v in pairs:
                w.assign_add(eps * v)
            grads_pos = self._arch_grads_on_train(x_train, y_train, arch_var)
            for w, v in pairs:
                w.assign_sub(2. * eps * v)
            grads_neg = self._arch_grads_on_train(x_train, y_train, arch_var)
        finally:
            for (w, _), saved in zip(pairs, backup):
                w.assign(saved)
        return [(gp - gn) / (2. * eps) for gp, gn in zip(grads_pos, grads_neg)]