In [2]:
import os
import sys
import time
import glob
import numpy as np
# import torch
import logging
import argparse
import tensorflow as tf
from tensorflow.keras import Model
import numpy as np
import utils
from tqdm import tqdm

In [4]:
from architect_graph import Architect

# Train Function

In [5]:
def train(x_train, y_train, x_valid, y_valid, logits, model, architect, optimizer):
    """Builds the ops for one search step: architecture update, then weight update.

    The function only assembles ops; the caller is expected to evaluate the
    returned values with ``sess.run``. It reads ``unrolled`` and ``grad_clip``
    from the module-level ``args`` object, which must exist when it is called.

    Args:
        x_train: Training inputs, forwarded to ``architect.step`` as ``input_train``.
        y_train: Training targets; used for the loss, accuracy and IoU.
        x_valid: Validation inputs, forwarded to ``architect.step`` as ``input_valid``.
        y_valid: Validation targets, forwarded to ``architect.step`` as ``target_valid``.
        logits: Network output the loss and metrics are computed from
            (callers in this notebook pass ``model(x_train)``).
        model (Network): Must provide ``get_thetas()`` and ``_loss(logits, target)``.
        architect (Architect): Must provide ``step(...)``.
        optimizer: Optimizer exposing ``apply_gradients``.

    Returns:
        tuple: ``(opt_op, loss, acc, iou)`` where ``opt_op`` applies the clipped
        gradients, ``loss`` comes from ``model._loss`` and ``acc`` / ``iou``
        come from ``utils.accuracy`` / ``utils.iou``.
    """

    # Architecture step; the weight update below is made to wait for it.
    architect_step = architect.step(input_train=x_train,
                                    target_train=y_train,
                                    input_valid=x_valid,
                                    target_valid=y_valid,
                                    unrolled=args.unrolled
                                    )
    # Variables that the weight optimizer updates.
    w_var = model.get_thetas()

    # Metrics are built outside the control-dependency scope, so evaluating
    # them alone does not force the architecture step.
    acc = utils.accuracy(logits, y_train)
    iou = utils.iou(logits, y_train)

    with tf.control_dependencies([architect_step]):
        loss = model._loss(logits, y_train)
        grads = tf.gradients(loss, w_var)
        # The global norm returned alongside the clipped gradients is not needed.
        clipped_gradients, _ = tf.clip_by_global_norm(grads, args.grad_clip)
        opt_op = optimizer.apply_gradients(zip(clipped_gradients, w_var))

    return opt_op, loss, acc, iou

In [6]:
# Hyperparameters for the search run. Each key is listed exactly once
# (a repeated key in a dict literal silently overrides the earlier one).
args_dict = {
    "momentum": 0.9,
    "weight_decay": 3e-4,
    "arch_learning_rate": 3e-1,
    "arch_weight_decay": 1e-3,
    "grad_clip": 5,
    "learning_rate_min": 0.001,
    "learning_rate": 0.025,
    "unrolled": True,
    "epochs": 10,
    "batch_size": 4,
    # `main` reads args.report_freq to decide how often to refresh the progress bar.
    "report_freq": 1,
    "save": "EXP"
}


class Struct:
    """Exposes the entries of a dict as attributes (``Struct(a=1).a == 1``)."""

    def __init__(self, **entries):
        self.__dict__.update(entries)


# Attribute-style view of the hyperparameters, used as `args.<name>` elsewhere.
args = Struct(**args_dict)

In [17]:
# Random stand-in data: 20 samples of 16x16 single-channel inputs with values
# in [0, 256) and 16x16x12 binary targets, all stored as float32.
np_ds_train = (
    np.random.randint(0, 256, (20, 16, 16, 1)).astype(np.float32),
    np.random.randint(0, 2, (20, 16, 16, 12)).astype(np.float32),
)
np_ds_valid = (
    np.random.randint(0, 256, (20, 16, 16, 1)).astype(np.float32),
    np.random.randint(0, 2, (20, 16, 16, 12)).astype(np.float32),
)

# Slice the arrays sample-wise and group them into batches of 4.
ds_train = tf.data.Dataset.from_tensor_slices(np_ds_train)
ds_train = ds_train.batch(4)
ds_valid = tf.data.Dataset.from_tensor_slices(np_ds_valid)
ds_valid = ds_valid.batch(4)

# One-shot iterators: each can be consumed for a single pass over its dataset.
train_it = ds_train.make_one_shot_iterator()
valid_it = ds_valid.make_one_shot_iterator()

In [18]:
from model_search import Network

# Loss function handed to the network.
criterion = tf.losses.sigmoid_cross_entropy
model = Network(3, 3, criterion)

# NOTE(review): the step size is `learning_rate_min`, not `learning_rate` —
# confirm this is intended.
optimizer = tf.train.MomentumOptimizer(
    learning_rate=args.learning_rate_min,
    momentum=args.momentum,
)

In [19]:
# Build the Architect around the search network and the hyperparameters;
# `train` later calls its `step` method to build the architecture update.
architect = Architect(model, args)

In [20]:
# Session used only to materialise the first batches as NumPy arrays.
sess = tf.Session()

# NOTE(review): this initializer is created before `model(x_train)` below, so
# it only covers variables that already exist at this point.
init = tf.global_variables_initializer()

# Pull one batch from each iterator in a single run.
(x_train, y_train), (x_valid, y_valid) = sess.run(
    [train_it.get_next(), valid_it.get_next()]
)

# Forward pass of the network on the training batch.
net_out = model(x_train)

In [None]:
# `train` returns four values (optimizer step, loss, accuracy, IoU), so four
# names are needed here; unpacking into three raises ValueError.
train_op, loss_op, acc_op, iou_op = train(x_train=x_train,
                                          y_train=y_train,
                                          x_valid=x_valid,
                                          y_valid=y_valid,
                                          logits=net_out,
                                          model=model,
                                          architect=architect,
                                          optimizer=optimizer
                                          )

In [None]:
with tf.Session() as sess:
    # Create the initializer here, after the whole graph has been built, so it
    # covers every variable that exists now. This replaces the deprecated
    # `tf.initialize_all_variables()` and the earlier `init` op, which was
    # created before the later ops (and any variables they add) existed.
    sess.run(tf.global_variables_initializer())
    # One search step; the value fetched for the optimizer op is discarded.
    logits, _, acc, iou = sess.run([net_out, train_op, acc_op, iou_op])

# Infer

In [None]:
def infer(x_valid, y_valid, logits, model, criterion):
  """Builds the evaluation values for one batch.

  `x_valid` and `criterion` are accepted but not used here: the loss comes
  from `model._loss` and the metrics from `utils`.

  Args:
    x_valid: Validation inputs (unused).
    y_valid: Validation targets.
    logits: Network output to evaluate against `y_valid`.
    model: Object providing `_loss(logits, target)`.
    criterion: Loss function (unused).

  Returns:
    tuple: (loss, accuracy, iou).
  """
  return (
      model._loss(logits, y_valid),
      utils.accuracy(logits, y_valid),
      utils.iou(logits, y_valid),
  )

# Main

In [None]:
def _mean_iou(intersections, unions):
  """Combines per-batch intersections and unions into a single mean IoU.

  Entries whose union is zero are ignored. Returns NaN when no entry is left.
  """
  unions = np.array(unions)
  intersections = np.array(intersections)
  non_zero_mask = unions != 0
  if not np.any(non_zero_mask):
    return float("nan")
  return np.mean(intersections[non_zero_mask]) / (np.mean(unions[non_zero_mask]) + 1e-6)


def main(args):
  """Runs the search on a small random dataset and saves per-epoch validation mIoU.

  The graph (network, train ops, evaluation ops) is built once; the epoch
  loops only call `sess.run`, which pulls a fresh batch from the repeating
  dataset iterators on every call.

  Args:
    args: Object with attributes `batch_size`, `learning_rate_min`,
      `momentum`, `epochs`, `save` and, optionally, `report_freq`
      (defaults to 1). It is also passed on to `Architect`.

  Note:
    `train` reads the module-level `args` (for `unrolled` and `grad_clip`),
    not this parameter.
  """
  np_ds_train = (np.random.randint(0, 256, (20, 16, 16, 3)).astype(np.float32), np.random.randint(0, 2, (20, 16, 16, 1)).astype(np.float32))
  np_ds_valid = (np.random.randint(0, 256, (20, 16, 16, 3)).astype(np.float32), np.random.randint(0, 2, (20, 16, 16, 1)).astype(np.float32))
  # repeat() keeps the one-shot iterators alive for more than one epoch.
  ds_train = tf.data.Dataset.from_tensor_slices(np_ds_train).batch(args.batch_size).repeat()
  ds_valid = tf.data.Dataset.from_tensor_slices(np_ds_valid).batch(args.batch_size).repeat()
  train_it = ds_train.make_one_shot_iterator()
  valid_it = ds_valid.make_one_shot_iterator()

  # Batches per epoch: number of samples divided by the batch size, rounded up.
  num_samples = np_ds_train[0].shape[0]
  num_iterations = (num_samples + args.batch_size - 1) // args.batch_size

  criterion = tf.losses.sigmoid_cross_entropy
  model = Network(3, 3, criterion)

  # Optimizer
  optimizer = tf.train.MomentumOptimizer(args.learning_rate_min, args.momentum)

  architect = Architect(model, args)

  # Build the graph once. Creating these ops inside the loops would add new
  # nodes to the graph on every iteration.
  # NOTE(review): `architect.step` now receives symbolic batch tensors instead
  # of NumPy arrays — confirm Architect supports that.
  x_train, y_train = train_it.get_next()
  x_valid, y_valid = valid_it.get_next()

  train_logits = model(x_train)
  train_op, train_loss_op, train_acc_op, train_iou_op = train(x_train=x_train,
                                                              y_train=y_train,
                                                              x_valid=x_valid,
                                                              y_valid=y_valid,
                                                              logits=train_logits,
                                                              model=model,
                                                              architect=architect,
                                                              optimizer=optimizer
                                                              )

  valid_logits = model(x_valid)
  valid_loss_op, valid_acc_op, valid_iou_op = infer(x_valid=x_valid,
                                                    y_valid=y_valid,
                                                    logits=valid_logits,
                                                    model=model,
                                                    criterion=criterion
                                                    )

  # Fall back to reporting every iteration when the option is absent.
  report_freq = max(1, getattr(args, "report_freq", 1))

  mious = []
  with tf.Session() as sess:
    # Initialise only after the full graph exists, so that every variable
    # created above is covered.
    sess.run(tf.global_variables_initializer())

    for e in range(args.epochs):
      print("Epoch {}".format(e))

      genotype = model.genotype()
      logging.info("genotype = %s", genotype)

      # Train loop
      train_unions, train_intersections = [], []
      tq1 = tqdm(range(num_iterations))
      for i in tq1:
        _, train_loss, train_acc, train_iou = sess.run(
            [train_op, train_loss_op, train_acc_op, train_iou_op])
        if i % report_freq == 0:
          tq1.set_postfix({
              "Train Loss": train_loss,
              "Train Acc": train_acc,
              "Train IoU": train_iou[0]
          })
        train_unions.append(train_iou[1])
        train_intersections.append(train_iou[2])

      # Calculation of train miou
      train_miou = _mean_iou(train_intersections, train_unions)

      # Log train miou
      print(train_miou)

      # Validation loop
      valid_unions, valid_intersections = [], []
      tq2 = tqdm(range(num_iterations))
      for i in tq2:
        valid_loss, valid_acc, valid_iou = sess.run(
            [valid_loss_op, valid_acc_op, valid_iou_op])
        if i % report_freq == 0:
          tq2.set_postfix({
              "Valid Loss": valid_loss,
              "Valid Acc": valid_acc,
              "Valid IoU": valid_iou[0]
          })
        # Collected on every iteration, not only on the reporting ones.
        valid_unions.append(valid_iou[1])
        valid_intersections.append(valid_iou[2])

      # Calculation of valid miou
      valid_miou = _mean_iou(valid_intersections, valid_unions)

      # Log valid miou
      print(valid_miou)

      mious.append(valid_miou)

  # Make sure the output directory exists before writing the results.
  if not os.path.isdir(args.save):
    os.makedirs(args.save)
  np.save(os.path.join(args.save, "mIoUs.npy"), mious)

## Test

In [None]:
# Hyperparameters for the test run. Each key is listed exactly once
# (a repeated key in a dict literal silently overrides the earlier one).
args_dict = {
    "momentum": 0.9,
    "weight_decay": 3e-4,
    "arch_learning_rate": 3e-1,
    "arch_weight_decay": 1e-3,
    "grad_clip": 5,
    "learning_rate_min": 0.001,
    "learning_rate": 0.025,
    "unrolled": True,
    "epochs": 10,
    "batch_size": 4,
    # `main` reads args.report_freq to decide how often to refresh the progress bar.
    "report_freq": 1,
    "save": "EXP"
}


class Struct:
    """Exposes the entries of a dict as attributes (``Struct(a=1).a == 1``)."""

    def __init__(self, **entries):
        self.__dict__.update(entries)


# Attribute-style view of the hyperparameters, used as `args.<name>` elsewhere.
args = Struct(**args_dict)

In [None]:
if __name__ == '__main__':
    # `parse_args` is not defined in any visible cell, so the run uses the
    # `args` object configured in the cell above. A copy is taken so that
    # re-running this cell does not keep prefixing `args.save`.
    run_args = argparse.Namespace(**vars(args))

    run_args.save = 'search-{}-{}'.format(args.save,
                                          time.strftime("%Y%m%d-%H%M%S"))
    utils.create_exp_dir(run_args.save, scripts_to_save=glob.glob('*.py'))

    # Log to stdout and to <save>/log.txt with the same format.
    log_format = '%(asctime)s %(message)s'
    logging.basicConfig(stream=sys.stdout, level=logging.INFO,
                        format=log_format, datefmt='%m/%d %I:%M:%S %p')
    fh = logging.FileHandler(os.path.join(run_args.save, 'log.txt'))
    fh.setFormatter(logging.Formatter(log_format))
    logging.getLogger().addHandler(fh)

    main(run_args)