In [4]:
%load_ext autoreload
%autoreload 2
import dataset
import models
import pickle
from glob import glob

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [5]:
import tensorflow as tf

# Later cells use TF1-style APIs (tf.ConfigProto, tf.estimator.RunConfig /
# Estimator), so record which TensorFlow build this notebook was run with.
tf_version = tf.__version__
tf_version

'1.7.0'

In [6]:
# --- Training configuration -------------------------------------------------
# Where the Estimator writes checkpoints / summaries.
model_dir = 'models/test_model'

# Input image shape fed to the model; the logged input tensor is
# (?, 48, 48, 1), i.e. height, width, channels.
image_shape = [48, 48, 1]

batch_size = 32

# Each entry is (num_epochs, value); the epoch counts are summed into
# train_epochs. The second element is presumably the learning rate used for
# that phase by models.learning_rate_scheduler -- confirm in models.py.
lr_schedule = [
    (1, 0.001),
    (5, 0.002),
    (5, 0.001),
    (5, 0.0001),
]

In [7]:
# NOTE(review): the "48_48" in these file names presumably encodes the image
# resolution and should agree with image_shape[:2] -- confirm.
train_data_path = "data/omniglot_training_48_48.pkl"
valid_data_path = "data/omniglot_validation_48_48.pkl"

# Augmentation is enabled for the training split only.
train_input_func, num_train_examples = dataset.get_input_function(
    train_data_path, batch_size, do_augmentation=True)
valid_input_func, num_valid_examples = dataset.get_input_function(
    valid_data_path, batch_size, do_augmentation=False)

In [8]:
# Total epochs = sum of the epoch counts across all schedule phases.
train_epochs = sum(phase_epochs for phase_epochs, _phase_value in lr_schedule)

# Integer division: a final partial batch is not counted as a step.
steps_per_epoch = num_train_examples // batch_size

# Fixed number of evaluation batches. With batch_size=32 this evaluates
# 32 * 32 = 1024 examples, not the whole validation set.
# TODO(review): if the validation data is not shuffled, these may always be
# the same (possibly class-skewed) examples; consider
# num_valid_examples // batch_size for a full pass.
steps_per_validation = 32

(train_epochs, steps_per_epoch, steps_per_validation)

(16, 602, 32)

In [9]:
# Load the class-name -> integer-label mapping. The file is opened in a
# `with` block so the handle is closed instead of leaking.
# NOTE: pickle.load can execute arbitrary code; only load trusted files.
with open('data/classes_level0_mapping.pkl', 'rb') as mapping_file:
    class_mapping = pickle.load(mapping_file)
num_classes = len(class_mapping)

# The classifier is built with `num_classes` outputs, so label ids must be
# exactly 0..num_classes-1 with no gaps or duplicates.
assert sorted(class_mapping.values()) == list(range(num_classes)), \
    'class ids in the mapping are not contiguous 0..num_classes-1'

class_mapping, num_train_examples, num_valid_examples, num_classes

({'Angelic': 0,
  'Atemayar_Qelisayer': 1,
  'Atlantean': 2,
  'Aurek-Besh': 3,
  'Avesta': 4,
  'Ge_ez': 5,
  'Glagolitic': 6,
  'Gurmukhi': 7,
  'Kannada': 8,
  'Keble': 9,
  'Malayalam': 10,
  'Manipuri': 11,
  'Mongolian': 12,
  'Old_Church_Slavonic_(Cyrillic)': 13,
  'Oriya': 14,
  'Sylheti': 15,
  'Syriac_(Serto)': 16,
  'Tengwar': 17,
  'Tibetan': 18,
  'ULOG': 19},
 19280,
 13180,
 20)

In [10]:
# Optimizer factory and learning-rate function driven by the epoch-based
# lr_schedule; batch_size and num_examples are passed along, presumably to
# turn epochs into global steps -- confirm in models.py.
optimizer_fn = models.adam_optimizer_fn()
learning_rate_fn = models.learning_rate_scheduler(
    schedule=lr_schedule,
    batch_size=batch_size,
    num_examples=num_train_examples,
)

# Extra loss terms added to the model loss; 0.001 is presumably the L2
# regularization weight -- confirm in models.global_averaged_l2_loss.
global_losses_fns = [models.global_averaged_l2_loss(0.001)]

# Assemble the Estimator model_fn from the ResNet feature extractor plus a
# classification head, loss and metrics.
model_func = models.resnet_model_fn(
    feature_extractor_fn=models.resnet_feature_extractor,
    head_fn=models.classification_head,
    loss_fn=models.classification_loss,
    metrics_fn=models.classification_metrics,
    optimizer_fn=optimizer_fn,
    learning_rate_fn=learning_rate_fn,
    global_losses_fns=global_losses_fns,
)

In [11]:
import hooks
from tqdm import tqdm

In [12]:
session_config = tf.ConfigProto(
    inter_op_parallelism_threads=8,
    intra_op_parallelism_threads=8,
    allow_soft_placement=True)

# save_checkpoints_secs=1e9 effectively disables time-based checkpointing;
# Estimator.train() still writes a checkpoint when it returns, which is what
# each evaluate() call below restores from.
run_config = tf.estimator.RunConfig().replace(
    save_checkpoints_secs=1e9,
    session_config=session_config)

classifier = tf.estimator.Estimator(
    model_fn=model_func,
    model_dir=model_dir,
    config=run_config,
    params={
        'num_classes': num_classes,
        'image_shape': image_shape
    }
)


def _make_train_hooks():
    """Build a fresh list of training hooks for a single ``classifier.train`` call.

    Hooks are recreated for every epoch rather than reused across train()
    calls, since they may hold per-run state.

    NOTE(review): the recorded run failed with "No space left on device" while
    writing into model_dir; if get_profiler_hook dumps trace files every epoch
    it may contribute to disk usage -- confirm in hooks.py.
    """
    return [
        hooks.get_examples_per_second_hook(batch_size=batch_size),
        hooks.get_profiler_hook(),
        hooks.get_logging_tensor_hook(tensors_to_log=[
            'head_accuracy', 'cross_entropy', 'learning_rate']),
    ]


for epoch in tqdm(range(train_epochs)):
    print(f'Starting a training epoch [{epoch + 1}/{train_epochs}]')

    # `steps` is incremental: train steps_per_epoch MORE steps from the
    # current global step. `max_steps` is an absolute cap on the global step,
    # so max_steps=steps_per_epoch made every call after the first a no-op
    # (training silently stopped after one epoch).
    classifier.train(
        input_fn=train_input_func,
        hooks=_make_train_hooks(),
        steps=steps_per_epoch)

    print('Starting to evaluate')

    eval_results = classifier.evaluate(
        input_fn=valid_input_func,
        steps=steps_per_validation
    )
    print(eval_results)

INFO:tensorflow:Using config: {'_model_dir': 'models/test_model', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 1000000000.0, '_session_config': intra_op_parallelism_threads: 8
inter_op_parallelism_threads: 8
allow_soft_placement: true
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x7f69b758bb70>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}


  0%|          | 0/16 [00:00<?, ?it/s]

Starting a training epoch [0/16]
INFO:tensorflow:Calling model_fn.
input_image: Tensor("Reshape:0", shape=(?, 48, 48, 1), dtype=float32)
group 0 with shape: (?, 22, 22, 32)
group 1 with shape: (?, 11, 11, 64)
group 2 with shape: (?, 6, 6, 128)
Trainable params: 729984
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.


ResourceExhaustedError: models/test_model/graph.pbtxt.tmp92254be3217b40198c97e45d3d3d06da; No space left on device