In [1]:
! pip install objax

Collecting objax
  Downloading https://files.pythonhosted.org/packages/88/bd/b105855a6093bb0c05d42723ce5680d6149ed36f407f8f444917cf3e458b/objax-1.0.2.tar.gz
Collecting parameterized
  Downloading https://files.pythonhosted.org/packages/ba/6b/73dfed0ab5299070cf98451af50130989901f50de41fe85d605437a0210f/parameterized-0.7.4-py2.py3-none-any.whl
Building wheels for collected packages: objax
  Building wheel for objax (setup.py) ... [?25l[?25hdone
  Created wheel for objax: filename=objax-1.0.2-cp36-none-any.whl size=57983 sha256=ae7e201dae509b80a46594aa17967ac16ebd5fa7303ea33c077bd89bdd1b1bfc
  Stored in directory: /root/.cache/pip/wheels/a0/73/a8/df38d6aa0448d888b2111d2493e6aaf6b6320289250c83e891
Successfully built objax
Installing collected packages: parameterized, objax
Successfully installed objax-1.0.2 parameterized-0.7.4


In [2]:
import os
import time

import tensorflow
import objax
from objax.zoo.resnet_v2 import ResNet50
import jax
import jax.numpy as jn
import numpy as np

In [3]:
! ls ./drive/My\ Drive/mlin/flower-imgs

test  train  val


In [3]:
# Paths of every entry in the train / val folders, in os.listdir order.
fnames = [
    f'./drive/My Drive/mlin/flower-imgs/train/{fn}'
    for fn in os.listdir('./drive/My Drive/mlin/flower-imgs/train')
]
test_names = [
    f'./drive/My Drive/mlin/flower-imgs/val/{fn}'
    for fn in os.listdir('./drive/My Drive/mlin/flower-imgs/val')
]

In [4]:
from absl import app, flags, logging

# Hyper-parameters are declared as absl flags so that the Experiment class can
# read them through FLAGS.<name>.
flags.DEFINE_string('model_dir', '', 'Model directory.')
flags.DEFINE_integer('train_device_batch_size', 64, 'Per-device training batch size.')
flags.DEFINE_integer('eval_device_batch_size', 250, 'Per-device eval batch size.')
flags.DEFINE_integer('max_eval_batches', -1, 'Maximum number of batches used for evaluation, '
                                             'zero or negative number means use all batches.')
flags.DEFINE_integer('eval_every_n_steps', 1000, 'How often to run eval.')
flags.DEFINE_float('num_train_epochs', 90, 'Number of training epochs.')
flags.DEFINE_float('base_learning_rate', 0.1, 'Base learning rate.')
flags.DEFINE_float('weight_decay', 1e-4, 'Weight decay (L2 loss) coefficient.')
flags.DEFINE_boolean('use_sync_bn', True, 'If true then use synchronized batch normalization, '
                                          'otherwise use per-replica batch normalization.')
flags.DEFINE_string('tfds_data_dir', None, 'Optional TFDS data directory.')

FLAGS = flags.FLAGS

# The notebook is not launched from a command line, so parse a stand-in argv
# (program name only): every flag then resolves to its declared default.
FLAGS(['notebook'])
FLAGS.train_device_batch_size

64

In [5]:
# Spatial size that decode_image reshapes every decoded image to.
IMAGE_SIZE = [192, 192]
# Batch size used by every tf.data pipeline in this notebook; taken from the
# train_device_batch_size flag.
BATCH_SIZE = FLAGS.train_device_batch_size
# tf.data AUTOTUNE sentinel, passed below as num_parallel_reads /
# num_parallel_calls / prefetch argument.
AUTO = tensorflow.data.experimental.AUTOTUNE

def decode_image(image_data):
    """Decode JPEG bytes into a float32 image scaled by 1/255.

    Args:
      image_data: encoded JPEG content of a single image.
    Returns:
      Tensor reshaped to [*IMAGE_SIZE, 3].
    """
    pixels = tensorflow.image.decode_jpeg(image_data, channels=3)
    scaled = tensorflow.cast(pixels, tensorflow.float32) / 255.0
    # The reshape pins the static shape to the configured image size.
    return tensorflow.reshape(scaled, [*IMAGE_SIZE, 3])


def read_labeled_tfrecord(example):
    """Parse one serialized labeled record into an (image, label) pair.

    Args:
      example: serialized record holding an "image" string feature and a
        "class" int64 feature.
    Returns:
      Tuple (decoded image, label cast to int32).
    """
    feature_spec = {
        "image": tensorflow.io.FixedLenFeature([], tensorflow.string),
        "class": tensorflow.io.FixedLenFeature([], tensorflow.int64),
    }
    parsed = tensorflow.io.parse_single_example(example, feature_spec)
    return (
        decode_image(parsed['image']),
        tensorflow.cast(parsed['class'], tensorflow.int32),
    )


def read_unlabeled_tfrecord(example):
    """Parse one serialized unlabeled record into an (image, id) pair.

    Args:
      example: serialized record holding "image" and "id" string features.
    Returns:
      Tuple (decoded image, id string exactly as stored in the record).
    """
    feature_spec = {
        "image": tensorflow.io.FixedLenFeature([], tensorflow.string),
        "id": tensorflow.io.FixedLenFeature([], tensorflow.string),
    }
    parsed = tensorflow.io.parse_single_example(example, feature_spec)
    return decode_image(parsed['image']), parsed['id']


def load_dataset(filenames, labeled=True, ordered=False):
    """Build a parsed tf.data dataset from TFRecord files.

    Args:
      filenames: TFRecord file paths to read.
      labeled: parse with read_labeled_tfrecord when true, otherwise with
        read_unlabeled_tfrecord.
      ordered: when false, experimental_deterministic is switched off on the
        dataset options.
    Returns:
      Dataset of parsed examples.
    """
    options = tensorflow.data.Options()
    if not ordered:
        options.experimental_deterministic = False

    parse_fn = read_labeled_tfrecord if labeled else read_unlabeled_tfrecord
    records = tensorflow.data.TFRecordDataset(filenames, num_parallel_reads=AUTO)
    return records.with_options(options).map(parse_fn, num_parallel_calls=AUTO)


def data_augment(image, label):
    """Apply a random left-right flip to the image; the label is passed through."""
    flipped = tensorflow.image.random_flip_left_right(image)
    return flipped, label

def get_training_dataset(dataset, do_aug=True):
    """Turn a parsed dataset into an endless stream of shuffled batches.

    Args:
      dataset: dataset of (image, label) pairs.
      do_aug: when true, `transform` is mapped over the data in addition to
        the random flip, which is always applied.
    Returns:
      Dataset that is repeated, shuffled with a 2048-element buffer and
      batched with BATCH_SIZE. No prefetch is applied.
    """
    pipeline = dataset.map(data_augment, num_parallel_calls=AUTO)
    if do_aug:
        # NOTE(review): `transform` is not defined in the visible part of this
        # notebook; confirm it exists before calling with do_aug=True.
        pipeline = pipeline.map(transform, num_parallel_calls=AUTO)
    return pipeline.repeat().shuffle(2048).batch(BATCH_SIZE)

def get_validation_dataset(dataset):
    """Batch the dataset with BATCH_SIZE, then cache and prefetch it."""
    return dataset.batch(BATCH_SIZE).cache().prefetch(AUTO)

def get_test_dataset(ordered=False):
    """Load the unlabeled test records, batched with BATCH_SIZE and prefetched.

    Args:
      ordered: forwarded to load_dataset.
    """
    # NOTE(review): TEST_FILENAMES is not defined in the visible part of this
    # notebook; confirm it is set before calling this function.
    unlabeled = load_dataset(TEST_FILENAMES, labeled=False, ordered=ordered)
    return unlabeled.batch(BATCH_SIZE).prefetch(AUTO)

In [6]:
# Endless stream of training batches, consumed by Experiment.train_and_eval
# through the notebook-level iterator `itrrr`.
all_elements = get_training_dataset(
    load_dataset(fnames),
    do_aug=False,
)
itrrr = all_elements.as_numpy_iterator()

# Evaluation data: the first 64 * 6 batches drawn from the val files.
# NOTE(review): this goes through the training pipeline (random flip, repeat,
# shuffle) rather than get_validation_dataset; confirm that is intended.
test_elements = get_training_dataset(
    load_dataset(test_names),
    do_aug=False,
).take(64 * 6)
test_iter = test_elements.as_numpy_iterator()

In [None]:
#batch = next(itrrr)
#tensorflow.transpose(batch[0], [0, 3, 1, 2]).shape

TensorShape([64, 3, 192, 192])

In [7]:
NUM_CLASSES = 104

class Experiment:
    """Training / evaluation driver for a ResNet50 classifier.

    Based on https://github.com/google/objax/blob/master/examples/classify/img/imagenet/imagenet_train.py

    Hyper-parameters are read from FLAGS, training batches from the
    notebook-level iterator `itrrr` and evaluation batches from the
    notebook-level dataset `test_elements`.
    """

    def __init__(self):
        """Builds the model, optimizer, parallel train/eval ops and summary writer."""
        # Some constants
        # NOTE(review): total_batch_size multiplies the flag by jax.device_count(),
        # while the input pipeline batches with BATCH_SIZE (the unmultiplied flag);
        # confirm which value the learning-rate scaling should use.
        total_batch_size = FLAGS.train_device_batch_size * jax.device_count()
        self.eval_batch_index = -1
        self.base_learning_rate = FLAGS.base_learning_rate * total_batch_size / 256
        # Create model
        bn_cls = objax.nn.SyncedBatchNorm2D if FLAGS.use_sync_bn else objax.nn.BatchNorm2D
        self.model = ResNet50(in_channels=3, num_classes=NUM_CLASSES, normalization_fn=bn_cls)
        self.model_vars = self.model.vars()
        # Create parallel eval op
        self.evaluate_batch_parallel = objax.Parallel(self.evaluate_batch, self.model_vars,
                                                      reduce=lambda x: x.sum(0))
        # Create parallel training op
        self.optimizer = objax.optimizer.Momentum(self.model_vars, momentum=0.9, nesterov=True)
        self.compute_grads_loss = objax.GradValues(self.loss_fn, self.model_vars)
        self.all_vars = self.model_vars + self.optimizer.vars()
        self.train_op_parallel = objax.Parallel(
            self.train_op, self.all_vars, reduce=lambda x: x[0])
        # Summary writer
        self.summary_writer = objax.jaxboard.SummaryWriter(os.path.join(
            FLAGS.model_dir, 'tb'))

    def evaluate_batch(self, images, labels):
        """Returns the number of images whose argmax logit equals the label."""
        logits = self.model(images, training=False)
        num_correct = jn.count_nonzero(jn.equal(jn.argmax(logits, axis=1), labels))
        return num_correct

    def run_eval(self):
        """Runs evaluation and returns top-1 accuracy."""
        print("eval:")
        correct_pred = 0
        total_examples = 0
        # Restart the batch counter on every pass; otherwise, once
        # max_eval_batches has been reached, every later evaluation would stop
        # after its first batch.
        self.eval_batch_index = -1
        for batch in test_elements.as_numpy_iterator():
            self.eval_batch_index += 1
            # Move the channel axis from last to second position before
            # feeding the model.
            imgs = jn.array(tensorflow.transpose(batch[0], [0, 3, 1, 2]))

            correct_pred += self.evaluate_batch_parallel(imgs,
                                                         batch[1])
            total_examples += imgs.shape[0]
            if ((FLAGS.max_eval_batches > 0)
                    and (self.eval_batch_index + 1 >= FLAGS.max_eval_batches)):
                break

        return correct_pred / total_examples

    def loss_fn(self, images, labels):
        """Computes loss function.
        Args:
          images: tensor with images NCHW
          labels: tensors with dense labels, shape (batch_size,)
        Returns:
          Tuple (total_loss, losses_dictionary).
        """
        logits = self.model(images, training=True)
        xent_loss = objax.functional.loss.cross_entropy_logits_sparse(logits, labels).mean()
        # Weight decay is applied only to variables whose name ends with '.w'.
        wd_loss = FLAGS.weight_decay * 0.5 * sum((v.value ** 2).sum()
                                                 for k, v in self.model_vars.items()
                                                 if k.endswith('.w'))
        total_loss = xent_loss + wd_loss
        return total_loss, {'total_loss': total_loss,
                            'xent_loss': xent_loss,
                            'wd_loss': wd_loss}

    def learning_rate(self, epoch: float):
        """Computes learning rate for given fractional epoch."""
        # Linear warm up to base_learning_rate value for first 5 epochs.
        # Then use 1.0 * base_learning_rate until epoch 30
        # Then use 0.1 * base_learning_rate until epoch 60
        # Then use 0.01 * base_learning_rate until epoch 80
        # Then use 0.001 * base_learning_rate until the end of training.
        lr_linear_till = 5
        boundaries = jn.array((30, 60, 80))
        values = jn.array([1., 0.1, 0.01, 0.001]) * self.base_learning_rate

        # Number of boundaries already passed selects the schedule segment.
        index = jn.sum(boundaries < epoch)
        lr = jn.take(values, index)
        return lr * jn.minimum(1., epoch / lr_linear_till)

    def train_op(self, images, labels, cur_epoch):
        """Runs one optimizer step and returns the monitored scalars.

        Args:
          images: batch of images passed to loss_fn.
          labels: batch of labels passed to loss_fn.
          cur_epoch: array of size 1 holding the fractional epoch.
        Returns:
          Dict with the loss terms, the learning rate and the epoch.
        """
        cur_epoch = cur_epoch[0]  # because cur_epoch is array of size 1
        grads, (_, losses_dict) = self.compute_grads_loss(images, labels)
        grads = objax.functional.parallel.pmean(grads)
        losses_dict = objax.functional.parallel.pmean(losses_dict)
        learning_rate = self.learning_rate(cur_epoch)
        self.optimizer(learning_rate, grads)
        return dict(**losses_dict, learning_rate=learning_rate, epoch=cur_epoch)

    def train_and_eval(self):
        """Runs training and evaluation."""

        # Hard-coded; the derived value would be
        # num_examples / (FLAGS.train_device_batch_size * jax.device_count()).
        steps_per_epoch = 100
        total_train_steps = int(steps_per_epoch * FLAGS.num_train_epochs)
        eval_every_n_steps = FLAGS.eval_every_n_steps

        checkpoint = objax.io.Checkpoint(FLAGS.model_dir, keep_ckpts=10)
        start_step, _ = checkpoint.restore(self.all_vars)
        cur_epoch = np.zeros([jax.local_device_count()], dtype=np.float32)
        for big_step in range(start_step, total_train_steps, eval_every_n_steps):
            print(f'Running training steps {big_step + 1} - {big_step + eval_every_n_steps}')
            with self.all_vars.replicate():
                # training
                start_time = time.time()
                for cur_step in range(big_step + 1, big_step + eval_every_n_steps + 1):
                    batch = next(itrrr)
                    cur_epoch[:] = cur_step / steps_per_epoch
                    # Move the channel axis from last to second position.
                    imgs = jn.array(tensorflow.transpose(batch[0], [0, 3, 1, 2]))
                    monitors = self.train_op_parallel(imgs, batch[1], cur_epoch)
                elapsed_train_time = time.time() - start_time
                # eval
                start_time = time.time()
                accuracy = self.run_eval()
                elapsed_eval_time = time.time() - start_time
            # save summary
            summary = objax.jaxboard.Summary()
            for k, v in monitors.items():
                summary.scalar(f'train/{k}', v)
            summary.scalar('test/accuracy', accuracy * 100)
            self.summary_writer.write(summary, step=cur_step)
            # save checkpoint
            checkpoint.save(self.all_vars, cur_step)
            # print info
            print('Step %d -- Epoch %.2f -- Loss %.2f  Accuracy %.2f'
                  % (cur_step, cur_step / steps_per_epoch,
                     monitors['total_loss'], accuracy * 100))
            print('    Training took %.1f seconds, eval took %.1f seconds'
                  % (elapsed_train_time, elapsed_eval_time), flush=True)

In [8]:
### Optional cell - activate TPU backend
# Points JAX at the Colab TPU whose address is read from the COLAB_TPU_ADDR
# environment variable.
import requests
import os
# TPU_DRIVER_MODE is a notebook-level marker: the version request below is
# sent only the first time this cell runs in a given kernel.
if 'TPU_DRIVER_MODE' not in globals():
  # Same host as COLAB_TPU_ADDR, port 8475; requests TPU driver version
  # tpu_driver0.1-dev20191206.
  url = 'http://' + os.environ['COLAB_TPU_ADDR'].split(':')[0] + ':8475/requestversion/tpu_driver0.1-dev20191206'
  resp = requests.post(url)
  TPU_DRIVER_MODE = 1

# The following is required to use TPU Driver as JAX's backend.
from jax.config import config
config.FLAGS.jax_xla_backend = "tpu_driver"
config.FLAGS.jax_backend_target = "grpc://" + os.environ['COLAB_TPU_ADDR']
# Echo the resolved backend target as a quick sanity check.
print(config.FLAGS.jax_backend_target)

grpc://10.7.42.194:8470


In [9]:
# Build the experiment, then run the full train / eval loop.
# train_and_eval restores from and saves checkpoints under FLAGS.model_dir.
experiment = Experiment()
experiment.train_and_eval()

Running training steps 1 - 1000
eval:
Step 1000 -- Epoch 10.00 -- Loss 4.05  Accuracy 1.01
    Training took 590.5 seconds, eval took 193.4 seconds
Running training steps 1001 - 2000
eval:
Step 2000 -- Epoch 20.00 -- Loss 2.86  Accuracy 7.84
    Training took 577.7 seconds, eval took 186.2 seconds
Running training steps 2001 - 3000
eval:
Step 3000 -- Epoch 30.00 -- Loss 2.10  Accuracy 34.55
    Training took 582.6 seconds, eval took 185.4 seconds
Running training steps 3001 - 4000
eval:
Step 4000 -- Epoch 40.00 -- Loss 1.30  Accuracy 61.31
    Training took 583.7 seconds, eval took 187.2 seconds
Running training steps 4001 - 5000
eval:
Step 5000 -- Epoch 50.00 -- Loss 1.22  Accuracy 60.05
    Training took 585.2 seconds, eval took 185.6 seconds
Running training steps 5001 - 6000
eval:
Step 6000 -- Epoch 60.00 -- Loss 1.19  Accuracy 61.80
    Training took 583.4 seconds, eval took 185.3 seconds
Running training steps 6001 - 7000
eval:
Step 7000 -- Epoch 70.00 -- Loss 1.15  Accuracy 69.1