In [1]:
import tensorflow as tf
import numpy as np
import utils

from model_search import Network
from architect_graph import Architect
from genotypes import Genotype

In [2]:
def get_model_theta(model):
    """Return the model's trainable weights, skipping the 'alphas' variables.

    Args:
        model: object exposing ``trainable_weights``, an iterable of
            variables that each have a ``name`` attribute.

    Returns:
        list: the variables whose ``name`` does not contain ``'alphas'``,
        in their original order.
    """
    # Variables are filtered purely by name: anything with 'alphas' in its
    # name (presumably the architecture parameters) is excluded.
    return [var for var in model.trainable_weights if 'alphas' not in var.name]

# Hyper-parameters and run configuration. The dict is wrapped in `Struct`
# further down so values can be read as attributes (args.learning_rate, ...).
args = {
    "momentum": 0.9,
    "weight_decay": 3e-4,
    "arch_learning_rate": 3e-1,
    "grad_clip": 5,
    "learning_rate": 0.025,
    "learning_rate_decay": 0.97,
    "learning_rate_min": 0.0001,
    # Placeholder only: recomputed from num_train_examples right after this dict.
    "num_batches_per_epoch": 2000,

    "unrolled": True,
    "epochs": 10,
    "train_batch_size": 2,
    "eval_batch_size": 2,
    "save": "EXP",
    "init_channels": 3,
    "num_layers": 3,
    "num_classes": 6,
    "crop_size": [8, 8],
    "save_checkpoints_steps": 100,
    "model_dir": 'gs://unet-darts/train-search-ckptss',
    "max_steps": 10000,
    # NEW
    "steps_per_eval": 2,
    "num_train_examples": 16,
    #

    "use_tpu": False,
    "use_host_call": True,
    "tpu": 'unet-darts',
    "zone": 'us-central1-f',
    "project": "isro-nas",
}
# Derive batches-per-epoch from the dataset size (drop_remainder semantics: floor).
args["num_batches_per_epoch"] = args["num_train_examples"] // args["train_batch_size"]

class Struct:
    """Wrap keyword arguments so they can be read as attributes.

    Example: ``Struct(lr=0.1).lr == 0.1``.
    """

    def __init__(self, **entries):
        # vars(self) is the instance __dict__; merging the keywords into it
        # turns each key into an attribute.
        vars(self).update(entries)

args = Struct(**args)

criterion = tf.losses.sigmoid_cross_entropy
# Build the search network from the config instead of repeating literals.
# NOTE(review): assumes Network's first two positional params are
# (init_channels, num_layers); both are 3 here, so the values passed are unchanged.
model = Network(args.init_channels, args.num_layers, criterion, num_classes=args.num_classes)
lr = args.learning_rate
unrolled = args.unrolled
W, H = args.crop_size[0], args.crop_size[1]

# Random stand-in data (not a real dataset).
# NOTE(review): 20 images here vs args.num_train_examples == 16, which drives
# num_batches_per_epoch. Confirm which is intended.
# NOTE(review): labels are class indices in [0, num_classes) while the criterion
# is sigmoid cross-entropy, which expects 0/1 targets. Confirm how Network uses them.
NUM_IMAGES = 20
x_train = np.random.randint(0, 256, (NUM_IMAGES, W, H, 3)).astype(np.float32)
y_train = np.random.randint(0, args.num_classes, (NUM_IMAGES, W, H, 1)).astype(np.float32)
x_valid = np.random.randint(0, 256, (NUM_IMAGES, W, H, 3)).astype(np.float32)
y_valid = np.random.randint(0, args.num_classes, (NUM_IMAGES, W, H, 1)).astype(np.float32)

ds_train = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(args.train_batch_size, drop_remainder=True)
# Validation data is batched with the eval batch size, not the train one.
ds_valid = tf.data.Dataset.from_tensor_slices((x_valid, y_valid)).batch(args.eval_batch_size, drop_remainder=True)

# Non-deprecated spelling recommended by the TF deprecation warning.
it_train = tf.compat.v1.data.make_one_shot_iterator(ds_train)
image, label = it_train.get_next()
it_valid = tf.compat.v1.data.make_one_shot_iterator(ds_valid)
image_valid, label_valid = it_valid.get_next()

# Call the model once so every variable it creates exists in the graph...
_ = model(image)
# ...and only then build the initializer. global_variables_initializer()
# captures the list of global variables at the moment it is created.
init = tf.global_variables_initializer()






Instructions for updating:
Use `for ... in dataset:` to iterate over a dataset. If using `tf.estimator`, return the `Dataset` object directly from your input function. As a last resort, you can use `tf.compat.v1.data.make_one_shot_iterator(dataset)`.

Instructions for updating:
If using Keras pass *_constraint arguments to layers.





## Load the model from a checkpoint

In [4]:
# Network weights to compare around the restore.
# NOTE(review): presumably the non-alpha trainable weights (cf. get_model_theta); confirm.
w_var = model.get_thetas()

arch_params = model.arch_parameters()
alphas_normal = arch_params[0]
alphas_reduce = arch_params[1]

# A Saver built with no arguments covers all saveable global variables;
# it is used in the next cell to restore the checkpoint.
saver = tf.train.Saver()

In [None]:
CKPT_PATH = "./final_model/shit_model.ckpt"

with tf.Session() as sess:
    # A single initializer created here, after the whole graph is built, covers
    # every global variable, including any created lazily by `model(image)`.
    # (Replaces running `init` plus the deprecated tf.initialize_all_variables().)
    sess.run(tf.global_variables_initializer())

    # Snapshot freshly initialized values so they can be compared after restoring.
    w_var_o = sess.run(w_var)
    alphas_normal_bo, alphas_reduce_bo = sess.run([alphas_normal, alphas_reduce])

    saver.restore(sess, CKPT_PATH)

    # Same variables after the checkpoint values have been loaded.
    alphas_normal_ao, alphas_reduce_ao = sess.run([alphas_normal, alphas_reduce])
    w_var_ao = sess.run(w_var)

    print("Model Restored!")

## Check w_var: compare the first weight tensor before and after the restore

In [6]:
# First tensor of `w_var` as initialized, captured before `saver.restore`.
w_var_o[0]

array([[[[-0.21781442, -0.13041043,  0.0720505 ,  0.18408327,
          -0.22147276,  0.1071315 , -0.03190242, -0.13923076,
          -0.1908624 ],
         [-0.12504655, -0.14907487, -0.20249611,  0.18117677,
           0.04531382, -0.01673152,  0.16581379,  0.07224469,
          -0.18228209],
         [-0.15290186, -0.19079895, -0.13119337,  0.21012391,
           0.03492816, -0.18359414,  0.18441956, -0.20286576,
          -0.17108479]],

        [[-0.15176928, -0.1514701 , -0.05060308, -0.17898767,
          -0.02843188,  0.05437095,  0.12146206,  0.07602449,
           0.21362664],
         [-0.19741905,  0.01360591, -0.18136469, -0.1479682 ,
          -0.11745962,  0.12857066,  0.1809571 ,  0.03187774,
           0.06209074],
         [ 0.09203146, -0.16484065,  0.09408928,  0.18601544,
          -0.1664211 , -0.22933337, -0.21667346, -0.18256925,
          -0.02358891]],

        [[ 0.07074387,  0.19907574,  0.02953865, -0.02254531,
           0.1776057 , -0.23015904, -0.0789384

In [8]:
# The same tensor read back after `saver.restore`.
w_var_ao[0]

array([[[[-0.03442837, -0.19600122,  0.12719032, -0.20992821,
          -0.01833246,  0.07180147, -0.03064003, -0.10792386,
          -0.19448635],
         [ 0.00414314,  0.07310251,  0.02488243, -0.17197381,
          -0.0906544 ,  0.06768876,  0.06518925, -0.18494917,
           0.09822389],
         [-0.1974354 ,  0.16325174, -0.13307647,  0.22510232,
          -0.22217709,  0.04119756, -0.07348753, -0.05445288,
          -0.11283895]],

        [[-0.10354199, -0.01992343,  0.19863144,  0.07962783,
          -0.07248777, -0.14840886,  0.10554582, -0.21028695,
           0.22161986],
         [-0.10730677, -0.16625865, -0.22848445,  0.11719748,
           0.18475099,  0.15106677, -0.09111515, -0.14113586,
           0.23229635],
         [ 0.0620399 , -0.13657337,  0.0093889 , -0.0810284 ,
          -0.15057312, -0.23169768, -0.0992375 , -0.08982612,
          -0.08481759]],

        [[ 0.04264848,  0.23551881, -0.1929766 , -0.06323685,
          -0.04846992, -0.09720816,  0.1750691

In [9]:
# Elementwise equality of the first weight tensor before vs. after restore.
# All-False means every element changed when the checkpoint was loaded.
w_var_o[0] == w_var_ao[0]

array([[[[False, False, False, False, False, False, False, False,
          False],
         [False, False, False, False, False, False, False, False,
          False],
         [False, False, False, False, False, False, False, False,
          False]],

        [[False, False, False, False, False, False, False, False,
          False],
         [False, False, False, False, False, False, False, False,
          False],
         [False, False, False, False, False, False, False, False,
          False]],

        [[False, False, False, False, False, False, False, False,
          False],
         [False, False, False, False, False, False, False, False,
          False],
         [False, False, False, False, False, False, False, False,
          False]]],


       [[[False, False, False, False, False, False, False, False,
          False],
         [False, False, False, False, False, False, False, False,
          False],
         [False, False, False, False, False, False, False, False,
  

## Check alphas: compare `alphas_normal` before and after the restore

In [10]:
# `alphas_normal` as initialized, captured before `saver.restore`.
alphas_normal_bo

array([[3.7092878e-04, 1.6795911e-04, 3.5203077e-04, 5.8859424e-04],
       [8.3135022e-04, 2.3740293e-05, 7.0695963e-04, 8.0207159e-04],
       [6.6350105e-05, 5.8585609e-04, 7.0565217e-04, 9.7406999e-04],
       [8.2912232e-04, 8.6010760e-04, 1.5207507e-04, 5.3793169e-04],
       [9.0187887e-04, 9.3262724e-04, 1.3480079e-04, 5.2903593e-04],
       [8.0715242e-04, 4.8329355e-04, 2.6738597e-04, 9.6407108e-04],
       [4.7285523e-04, 6.4862648e-04, 2.4022903e-04, 6.0483482e-04],
       [2.6485135e-04, 8.1331522e-04, 4.1582133e-04, 6.7548995e-04],
       [3.1587674e-04, 4.1381398e-04, 5.0004839e-04, 4.0725508e-04],
       [1.1215842e-04, 5.0315075e-04, 8.8371319e-04, 8.2838204e-04],
       [6.6652382e-04, 7.9944590e-04, 3.0231584e-04, 1.7037035e-05],
       [5.1672105e-04, 6.2618614e-04, 1.7023504e-04, 7.3237316e-04],
       [7.4511052e-05, 7.8145677e-04, 8.9976517e-04, 5.1348150e-04],
       [6.9281424e-04, 2.9768827e-04, 9.5252960e-04, 8.0167630e-04]],
      dtype=float32)

In [11]:
# `alphas_normal` read back after `saver.restore`.
alphas_normal_ao

array([[7.1198615e-04, 7.2746788e-04, 9.6007105e-04, 7.4182759e-04],
       [1.8302909e-04, 2.9186808e-04, 4.3197046e-04, 8.7230606e-04],
       [8.0202968e-04, 7.9980376e-04, 9.1616018e-04, 3.0944683e-04],
       [5.9581187e-04, 4.7895222e-04, 4.9409480e-04, 7.5088424e-04],
       [1.6830792e-04, 1.2588678e-05, 8.1552973e-04, 9.3164120e-04],
       [2.9931555e-04, 1.9713894e-04, 4.2687057e-04, 7.7478675e-04],
       [4.1851398e-04, 8.5141859e-04, 6.6222821e-04, 3.8386270e-04],
       [9.5327944e-04, 5.5400473e-05, 3.5133364e-04, 8.9306623e-04],
       [2.5579846e-04, 1.5495979e-04, 7.2160648e-04, 9.0962322e-04],
       [9.9319368e-06, 6.9085346e-04, 4.3143576e-04, 4.5638229e-04],
       [1.6581672e-04, 2.4000688e-04, 3.2327723e-04, 5.9998449e-04],
       [1.6850508e-04, 9.7663119e-04, 2.3424468e-04, 1.9254586e-04],
       [4.7988590e-04, 9.0659113e-04, 5.0395430e-04, 1.1261527e-04],
       [7.0980861e-04, 2.8840150e-04, 8.8462618e-04, 3.2511406e-04]],
      dtype=float32)

In [12]:
# Elementwise equality of `alphas_normal` before vs. after restore.
# All-False means every alpha was overwritten by the checkpoint.
alphas_normal_bo == alphas_normal_ao

array([[False, False, False, False],
       [False, False, False, False],
       [False, False, False, False],
       [False, False, False, False],
       [False, False, False, False],
       [False, False, False, False],
       [False, False, False, False],
       [False, False, False, False],
       [False, False, False, False],
       [False, False, False, False],
       [False, False, False, False],
       [False, False, False, False],
       [False, False, False, False],
       [False, False, False, False]])