In [1]:
import os
import sys
import time
import glob
import numpy as np
import logging
import argparse
import tensorflow as tf
import numpy as np
import utils
from tqdm import tqdm
import shutil
from model_search import Network
from architect_graph import Architect
import utils
from genotypes import Genotype
# tf.enable_eager_execution()
tf.set_random_seed(6969)
np.random.seed(6969)

In [66]:
args = {
    "momentum": 0.9,
    "weight_decay": 3e-4,
    "arch_learning_rate": 3e-1,
    "momentum": 0.9,
    "grad_clip": 5,
    "learning_rate": 0.025,
    "learning_rate_decay": 0.97,
    "learning_rate_min": 0.0001,
    "num_batches_per_epoch": 2000,
    
    "unrolled": True,
    "epochs": 10,
    "train_batch_size": 8,
    "batch_size": 8,
    "eval_batch_size": 8,
    "save": "EXP",
    "init_channels": 3,
    "num_layers": 3,
    "num_classes": 6,
    "crop_size": [128, 128],
    "save_checkpoints_steps": 100,
    "model_dir": 'gs://unet-darts/train-search-ckpts',
#     "max_steps": 10000,
    "max_steps": None,
    "steps": 20,
    # NEW
    "steps_per_eval": 2,
    "num_train_examples": 16,
    "num_batches_per_epoch": 2,
    #
    
    "train_report_freq": 10,
    "protocol": None,    
    "use_tpu": False,
    "use_host_call": False,
    "tpu": 'unet-darts',
    "zone": 'us-central1-f',
    "project": "isro-nas",
    "zone": None,
    "project": None
}

class Struct:
    def __init__(self, **entries):
        self.__dict__.update(entries)

args = Struct(**args)

In [67]:
def make_inp_fn2(filename, mode, batch_size):
    
    def _input_fn(params):
        W, H = args.crop_size[0], args.crop_size[1]
        NUM_IMAGES = 20
        x_train = np.random.randint(0, 256, (NUM_IMAGES, W, H, 3)).astype(np.float32)
        y_train = np.random.randint(0, args.num_classes, (NUM_IMAGES, W, H, 1)).astype(np.float32)
        x_valid = np.random.randint(0, 256, (NUM_IMAGES, W, H, 3)).astype(np.float32)
        y_valid = np.random.randint(0, args.num_classes, (NUM_IMAGES, W, H, 1)).astype(np.float32)

        ds = (x_train, x_valid), (y_train, y_valid)
        dataset = tf.data.Dataset.from_tensor_slices(ds)
        num_epochs = None # indefinitely
        dataset = dataset.shuffle(buffer_size = 10 * batch_size)

        dataset = dataset.repeat(num_epochs).batch(batch_size, drop_remainder=True)
        return dataset
    
    return _input_fn

In [68]:
def summary_hook(global_step, learning_rate, loss, y, preds):
    # Outfeed supports int32 but global_step is expected to be int64.
    global_step = tf.reduce_mean(global_step)
    global_step = tf.cast(global_step, tf.int64)

    acc = tf.metrics.accuracy(labels=y, 
                                  predictions=tf.round(preds))
    writer = tf.contrib.summary.create_file_writer(args.model_dir)
    with writer.as_default():
        tf.contrib.summary.scalar(
            'learning_rate', tf.reduce_mean(learning_rate),
            step=global_step)
        tf.contrib.summary.scalar(
            'loss', tf.reduce_mean(loss),step=global_step)
        tf.contrib.summary.scalar(
            'acc', acc,step=global_step)
    all_summary_ops = tf.compat.v1.summary.all_v2_summary_ops()

    return writer, all_summary_ops, writer.flush()

In [69]:
def train(input_fn, args):
    g = tf.compat.v1.Graph()
    with g.as_default():
        ds = input_fn(args)
        it = ds.make_one_shot_iterator()

        criterion = tf.losses.softmax_cross_entropy
        model = Network(C=args.init_channels, net_layers=args.num_layers, criterion=criterion, num_classes=args.num_classes)
        global_step = tf.Variable(0, name='global_step',trainable=False)
        learning_rate_min = tf.constant(args.learning_rate_min)

        learning_rate = tf.train.exponential_decay(
            args.learning_rate,
            global_step,
            decay_rate=args.learning_rate_decay,
            decay_steps=args.num_batches_per_epoch,
            staircase=True,
        )

        lr = tf.maximum(learning_rate, learning_rate_min)

        optimizer = tf.train.MomentumOptimizer(lr, args.momentum)

        features, labels = it.get_next()
        (x_train, x_valid) = features
        (y_train, y_valid) = labels
        b, w, h, c = y_train.shape
        y = tf.reshape(tf.cast(y_train, tf.int64), (b, w, h))
        y = tf.one_hot(y, args.num_classes, on_value=1.0, off_value=0.0)
        preds = model(x_train)
#         architect = Architect(model, args)
#         tf.logging.info('Arch Step')
#         architect_step = architect.step(input_train=x_train,
#                                         target_train=y_train,
#                                         input_valid=x_valid,
#                                         target_valid=y_valid,
#                                         unrolled=args.unrolled,
#                                         )
#         with tf.control_dependencies([architect_step]):
        w_var = model.get_thetas()
        loss = model._loss(preds, y_train)
        grads = tf.gradients(loss, w_var)
        tf.logging.info('Model Step')
        train_op = optimizer.apply_gradients(zip(grads, w_var), global_step=tf.train.get_global_step())

        # Summary
        tf.logging.info("Setting up Summary Hook!")
        tf.summary.scalar('loss', loss)
        merged = tf.summary.merge_all()
        
        global_init_op = tf.global_variables_initializer()
        tf.logging.info('Initialized Global Initializer.')
        local_init_op = tf.local_variables_initializer()
        tf.logging.info('Initialized Local Initializer.')

        current_step = tf.train.latest_checkpoint(args.model_dir)
        next_checkpoint = 0

        try:
            current_step = int(current_step.split('-')[-1])
        except:
            current_step = 0

        if(args.max_steps == None and args.steps):
            next_checkpoint = current_step + args.steps
        else:
            next_checkpoint = args.max_steps

        tf.logging.info('Training for %d steps. Current step %d.' % (next_checkpoint, current_step))

        saver = tf.train.Saver()
        tf.logging.info('Initialized Training Saver!')
        
        
    # Training Loop
    with tf.Session(graph=g) as sess:
        train_writer = tf.summary.FileWriter(args.model_dir, sess.graph)
        sess.run([global_init_op, local_init_op])
        # Restoring Checkpoint
        latest_ckpt = tf.train.latest_checkpoint(args.model_dir)
        if(latest_ckpt):
            tf.logging.info("Loading from %s" % latest_ckpt)
            saver.restore(sess, latest_ckpt)
        
        while(current_step <= next_checkpoint):            
            # Summary Operation
            summary_tr = sess.run(merged)
            step = sess.run(global_step)
            train_writer.add_summary(summary_tr, step)
            sess.run(train_op)
            # writer op
#             sess.run(writer_flush)
            
            if(not current_step % args.train_report_freq):
                tf.logging.info("Current Step: %d" % current_step)
                
            current_step += 1
            
        save_path = saver.save(sess, os.path.join(args.model_dir, 'model.ckpt-' + str(current_step)))
        tf.logging.info("Model saved in path: %s" % save_path)

In [58]:
out = train(make_inp_fn2(filename = tf.gfile.Glob('../../datasets/train_search/train-search-*-00008.tfrecords'),
                        mode = tf.estimator.ModeKeys.TRAIN,
                        batch_size = args.train_batch_size), args)

INFO:tensorflow:Model Step


INFO:tensorflow:Model Step


INFO:tensorflow:Setting up Summary Hook!


INFO:tensorflow:Setting up Summary Hook!


INFO:tensorflow:Initialized Global Initializer.


INFO:tensorflow:Initialized Global Initializer.


INFO:tensorflow:Initialized Local Initializer.


INFO:tensorflow:Initialized Local Initializer.


INFO:tensorflow:Training for 62 steps. Current step 42.


INFO:tensorflow:Training for 62 steps. Current step 42.


INFO:tensorflow:Initialized Training Saver!


INFO:tensorflow:Initialized Training Saver!


INFO:tensorflow:Loading from gs://unet-darts/train-search-ckpts/model.ckpt-42


INFO:tensorflow:Loading from gs://unet-darts/train-search-ckpts/model.ckpt-42


INFO:tensorflow:Restoring parameters from gs://unet-darts/train-search-ckpts/model.ckpt-42


INFO:tensorflow:Restoring parameters from gs://unet-darts/train-search-ckpts/model.ckpt-42


INFO:tensorflow:Current Step: 50


INFO:tensorflow:Current Step: 50


INFO:tensorflow:Current Step: 60


INFO:tensorflow:Current Step: 60


INFO:tensorflow:Model saved in path: gs://unet-darts/train-search-ckpts/model.ckpt-63


INFO:tensorflow:Model saved in path: gs://unet-darts/train-search-ckpts/model.ckpt-63


In [70]:
# Soft device placement should be enabled to automatically place
# summary ops to CPU.
tf.config.set_soft_device_placement(True)
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=args.tpu)
tf.config.experimental_connect_to_cluster(resolver, protocol=args.protocol)
logging.info('Remote eager configured')
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.experimental.TPUStrategy(resolver)
with strategy.scope():
    train(make_inp_fn2(filename = tf.gfile.Glob('../../datasets/train_search/train-search-*-00008.tfrecords'),
                        mode = tf.estimator.ModeKeys.TRAIN,
                        batch_size = args.train_batch_size), args)





INFO:tensorflow:Initializing the TPU system: unet-darts


INFO:tensorflow:Initializing the TPU system: unet-darts


INFO:tensorflow:Finished initializing TPU system.


INFO:tensorflow:Finished initializing TPU system.


INFO:tensorflow:Querying Tensorflow master (grpc://10.240.1.18:8470) for TPU system metadata.


INFO:tensorflow:Querying Tensorflow master (grpc://10.240.1.18:8470) for TPU system metadata.


INFO:tensorflow:Found TPU system:


INFO:tensorflow:Found TPU system:


INFO:tensorflow:*** Num TPU Cores: 8


INFO:tensorflow:*** Num TPU Cores: 8


INFO:tensorflow:*** Num TPU Workers: 1


INFO:tensorflow:*** Num TPU Workers: 1


INFO:tensorflow:*** Num TPU Cores Per Worker: 8


INFO:tensorflow:*** Num TPU Cores Per Worker: 8


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, -1, 9263471354727379776)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, -1, 9263471354727379776)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 17179869184, 15960800653966626927)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 17179869184, 15960800653966626927)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 17179869184, 16756564147706039495)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 17179869184, 16756564147706039495)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 17179869184, 15659625517490204949)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 17179869184, 15659625517490204949)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 17179869184, 11387176928378139660)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 17179869184, 11387176928378139660)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:4, TPU, 17179869184, 13672379615462641528)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:4, TPU, 17179869184, 13672379615462641528)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:5, TPU, 17179869184, 17801541048304008048)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:5, TPU, 17179869184, 17801541048304008048)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:6, TPU, 17179869184, 1817974265396846832)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:6, TPU, 17179869184, 1817974265396846832)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:7, TPU, 17179869184, 9513160855619944561)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:7, TPU, 17179869184, 9513160855619944561)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 8589934592, 13542879675874795063)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 8589934592, 13542879675874795063)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 17179869184, 17740901858987369205)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 17179869184, 17740901858987369205)


INFO:tensorflow:Model Step


INFO:tensorflow:Model Step


INFO:tensorflow:Setting up Summary Hook!


INFO:tensorflow:Setting up Summary Hook!


INFO:tensorflow:Initialized Global Initializer.


INFO:tensorflow:Initialized Global Initializer.


INFO:tensorflow:Initialized Local Initializer.


INFO:tensorflow:Initialized Local Initializer.


INFO:tensorflow:Training for 104 steps. Current step 84.


INFO:tensorflow:Training for 104 steps. Current step 84.


INFO:tensorflow:Initialized Training Saver!


INFO:tensorflow:Initialized Training Saver!


INFO:tensorflow:Loading from gs://unet-darts/train-search-ckpts/model.ckpt-84


INFO:tensorflow:Loading from gs://unet-darts/train-search-ckpts/model.ckpt-84


INFO:tensorflow:Restoring parameters from gs://unet-darts/train-search-ckpts/model.ckpt-84


INFO:tensorflow:Restoring parameters from gs://unet-darts/train-search-ckpts/model.ckpt-84


INFO:tensorflow:Current Step: 90


INFO:tensorflow:Current Step: 90


INFO:tensorflow:Current Step: 100


INFO:tensorflow:Current Step: 100


INFO:tensorflow:Model saved in path: gs://unet-darts/train-search-ckpts/model.ckpt-105


INFO:tensorflow:Model saved in path: gs://unet-darts/train-search-ckpts/model.ckpt-105
