In [25]:
import tensorflow as tf
from tensorflow.keras import Model
import numpy as np
tf.enable_eager_execution()

In [26]:
from model_search import Network

In [27]:
criterion = tf.losses.sigmoid_cross_entropy
model = Network(3, 3, criterion)
args = {
    "momentum": 0.9,
    "weight_decay": 3e-4,
    "arch_learning_rate": 3e-4,
    "arch_weight_decay": 1e-3
}

class Struct:
    def __init__(self, **entries):
        self.__dict__.update(entries)

In [61]:
inp = tf.random_uniform((1, 16, 16, 3), 0, 255)
target = tf.random_uniform((1, 16, 16, 1), 0, 10)

In [None]:
model(inp)

In [63]:
new_model = model.new()

In [64]:
def get_model_theta(model):
    specific_tensor = []
    specific_tensor_name = []
    for var in model.trainable_weights:
        if not 'alphas' in var.name:
            specific_tensor.append(var)
            specific_tensor_name.append(var.name)
    return specific_tensor_name, specific_tensor

In [65]:
model_weights = get_model_theta(model)[1]

In [66]:
new_model_weights = get_model_theta(model)[1]

In [67]:
for v,w in zip(new_model_weights,model_weights):
    v.assign(w)

In [79]:
with tf.GradientTape() as tape:
    logits = new_model(inp)
    unrolled_train_loss = tf.losses.mean_squared_error(logits, target)
    new_model_weights = get_model_theta(new_model)[1]
    tape.watch(new_model_weights)
    for v,w in zip(new_model_weights,model_weights):
        v.assign(w)
    grads = tape.gradient(unrolled_train_loss, new_model_weights, unconnected_gradients=tf.UnconnectedGradients.NONE)
    

In [69]:
opt = tf.train.GradientDescentOptimizer(0.01)

In [70]:
opt.apply_gradients(zip(grads, new_model_weights))

In [71]:
opt.apply_gradients(zip(grads, model_weights))

In [72]:
unrolled_train_loss

<tf.Tensor: id=782832, shape=(), dtype=float32, numpy=24.817299>

In [75]:
def _concat(xs):
    """nd tensor to 1d tensor

    Args:
        xs (array): the array of nd tensor

    Returns:
        array: concated array
    """
    return tf.concat([tf.reshape(x, [tf.size(x)]) for x in xs], axis=0, name="_concat")

In [80]:
tf.reduce_sum(_concat(grads))

<tf.Tensor: id=865877, shape=(), dtype=float32, numpy=0.0>

In [86]:
x = tf.ones((2, 2))

with tf.GradientTape() as t:
  t.watch(x)
  y = tf.reduce_sum(x)
  z = tf.multiply(y, y)

# Derivative of z with respect to the original input tensor x
dz_dx = t.gradient(z, x)
for i in [0, 1]:
  for j in [0, 1]:
    assert dz_dx[i][j].numpy() == 8.0

In [87]:
dz_dx

<tf.Tensor: id=866067, shape=(2, 2), dtype=float32, numpy=
array([[8., 8.],
       [8., 8.]], dtype=float32)>

In [89]:
x = tf.ones((2, 2))

with tf.GradientTape() as t:
  t.watch(x)
  y = tf.reduce_sum(x)
  z = tf.multiply(y, y)

# Use the tape to compute the derivative of z with respect to the
# intermediate value y.
dz_dy = t.gradient(z, y)
assert dz_dy.numpy() == 8.0


In [90]:
x = tf.constant(3.0)
with tf.GradientTape(persistent=True) as t:
  t.watch(x)
  y = x * x
  z = y * y
dz_dx = t.gradient(z, x)  # 108.0 (4*x^3 at x = 3)
dy_dx = t.gradient(y, x)  # 6.0
del t  # Drop the reference to the tape


In [94]:
def f(x, y):
  output = 1.0
  for i in range(y):
    if i > 1 and i < 5:
      output = tf.multiply(output, x)
  return output

def grad(x, y):
  with tf.GradientTape() as t:
    t.watch(x)
    out = f(x, y)
  return t.gradient(out, x)

x = tf.convert_to_tensor(2.0)

assert grad(x, 6).numpy() == 12.0
assert grad(x, 5).numpy() == 12.0
assert grad(x, 4).numpy() == 4.0


In [96]:
with tf.GradientTape() as tape:
    logits = model(inp)
    unrolled_train_loss = tf.losses.mean_squared_error(logits, target)
    new_model_weights = get_model_theta(model)[1]
    tape.watch(new_model_weights)
    grads = tape.gradient(unrolled_train_loss, new_model_weights, unconnected_gradients=tf.UnconnectedGradients.NONE)

In [97]:
grads

[<tf.Tensor: id=891508, shape=(3, 3, 3, 9), dtype=float32, numpy=
 array([[[[0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.]],
 
         [[0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.]],
 
         [[0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.]]],
 
 
        [[[0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.]],
 
         [[0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.]],
 
         [[0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.]]],
 
 
        [[[0., 0., 0., 0., 0., 0., 0., 0., 0.]