In [None]:
!nvidia-smi

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import tensorflow as tf
import numpy as np
import cv2
import os
import importlib
import time
import albumentations as albu

from utils import cpm_utils

# CONFIG

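All settings for this run (model, training and graph-freezing options) are collected in the `FLAGS` class below.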

In [None]:
class FLAGS(object):
    """ """
    # NOTE: the (blank) class docstring above is kept verbatim because the
    # config-saving cell serialises FLAGS.__dict__ -- which includes __doc__ --
    # into config.json.

    # ------------------------------------------------------------------
    # General settings
    # ------------------------------------------------------------------
    input_size = (1024, 1024)        # the two spatial dims of the (3-channel) input image
    heatmap_size = 128               # heatmaps are heatmap_size x heatmap_size
    cpm_stages = 3
    joint_gaussian_variance = 1.0
    center_radius = 21
    num_of_joints = 8                # must match len(keypoints_order)
    color_channel = 'RGB'
    normalize = True
    use_gpu = True
    gpu_id = 0

    gradient_clipping = True         # forwarded to model.build_loss(clipping=...)

    keypoints_order = [
        "TAIL_NOTCH",
        "ADIPOSE_FIN",
        "UPPER_LIP",
        "ANAL_FIN",
        "PELVIC_FIN",
        "EYE",
        "PECTORAL_FIN",
        "DORSAL_FIN",
    ]

    # ------------------------------------------------------------------
    # Training settings
    # ------------------------------------------------------------------
    network_def = 'fish_test'        # module name looked up under models.nets
    train_img_dir = ''
    val_img_dir = ''
    bg_img_dir = ''
    pretrained_model = 'fish_test'
    batch_size = 4
    init_lr = 0.001
    lr_decay_rate = 0.5
    lr_decay_step = 10000
    steps_per_epoch = 1000
    val_steps_per_epochs = 250

    # Augmentation is disabled. To enable it, assign an albumentations
    # pipeline built with keypoint_params={'format': 'xy'}, e.g.
    #   albu.Compose([albu.HorizontalFlip(p=0.5)], p=1.0,
    #                keypoint_params={'format': 'xy'})
    augmentation = None
    buffer_range = list(range(100, 600, 100))  # useless if crop = False
    crop = False                     # crop input image based on keypoints - for GTSF only

    epochs = 200

    hnm = True                       # Make sure generate hnm files first
    do_cropping = True

    # ------------------------------------------------------------------
    # Graph freezing
    # ------------------------------------------------------------------
    output_node_names = 'stage_3/mid_conv7/BiasAdd:0'
    validation_files = None          # filled in once the train/val split is made
    training_files = None            # filled in once the train/val split is made

In [None]:
# Import the network definition dynamically so the architecture can be swapped
# by changing FLAGS.network_def (resolved as a module under models.nets).
cpm_model = importlib.import_module('models.nets.' + FLAGS.network_def)

# MODEL CREATION

Setting up the output folders for this run.

In [None]:
from datetime import datetime

In [None]:
# Timestamp tag for this run, formatted YYYY_MM_DD_HH_MM_SS (no fractional seconds).
datenow = datetime.now().strftime("%Y_%m_%d_%H_%M_%S")

In [None]:
# Root output directory for this run; the timestamp suffix gives every run its
# own folder (weights, config.json and history.json are written underneath it).
base_dir = "/root/data/models/keypoints_detection/{}".format(datenow)

In [None]:
# Restrict TensorFlow to the GPU selected in the config instead of a
# hard-coded device index, so changing FLAGS.gpu_id actually takes effect.
os.environ['CUDA_VISIBLE_DEVICES'] = str(FLAGS.gpu_id)

# Descriptive sub-path encoding the main hyper-parameters of the run.
# NOTE(review): not referenced by any other cell visible in this notebook;
# kept in case external code relies on it.
model_path_suffix = os.path.join(FLAGS.network_def,
                                 'input_{}_output_{}'.format(FLAGS.input_size, FLAGS.heatmap_size),
                                 'joints_{}'.format(FLAGS.num_of_joints),
                                 'stages_{}'.format(FLAGS.cpm_stages),
                                 'init_{}_rate_{}_step_{}'.format(FLAGS.init_lr, FLAGS.lr_decay_rate,
                                                                  FLAGS.lr_decay_step)
                                 )

# Checkpoints for this run are written to <base_dir>/weights.
model_save_dir = os.path.join(base_dir,
                              'weights')

Build network graph

In [None]:
# Instantiate the network in training mode, sized from the FLAGS config.
model_kwargs = dict(
    input_size=FLAGS.input_size,
    heatmap_size=FLAGS.heatmap_size,
    stages=FLAGS.cpm_stages,
    joints=FLAGS.num_of_joints,
    img_type=FLAGS.color_channel,
    is_training=True,
)
model = cpm_model.CPM_Model(**model_kwargs)

# Attach the loss with RMSProp and the configured learning-rate schedule.
# NOTE(review): presumably this is what creates model.train_op / model.total_loss
# used by the training loop -- confirm against the network module.
model.build_loss(
    FLAGS.init_lr,
    FLAGS.lr_decay_rate,
    FLAGS.lr_decay_step,
    optimizer='RMSProp',
    clipping=FLAGS.gradient_clipping,
)
print('=====Model Build=====\n')

# DATA GENERATOR

Creating data generator. 

In [None]:
import glob
import json
import random

import matplotlib.pyplot as plt
from PIL import Image, ImageDraw

from utils.utils import DataGenerator

In [None]:
# Seed Python's `random` module so the shuffle used for the train/val split is
# repeatable. NOTE: this does not seed TensorFlow; the tf.data shuffle in
# create_generator is created without a seed.
random.seed(193)

Load all the GTSF sessions

In [None]:
# glob.glob returns paths in arbitrary, filesystem-dependent order. Sort them so
# the seeded shuffle in the split cell yields the same train/val split on every
# machine and every run.
annotations = sorted(glob.glob('/root/data/bati/tfrecords/*'))
print("Total number of days: {}".format(len(annotations)))

Train/validation split. Let's split by experiment, which is better practice.

In [None]:
# Shuffle a copy so the `annotations` list built in the previous cell is not
# reordered in place (the permutation is the same as shuffling the original
# list with the same RNG state).
shuffled_files = list(annotations)
random.shuffle(shuffled_files)

# 80% of the files go to training, the remainder to validation.
cutoff = int(len(shuffled_files) * 0.8)
train_files = shuffled_files[:cutoff]
val_files = shuffled_files[cutoff:]
print("Number of training files: {}".format(len(train_files)))
print("Number of validation files: {}".format(len(val_files)))

In [None]:
# Record the split on FLAGS so it is written out with the rest of the
# configuration when FLAGS.__dict__ is dumped to config.json.
FLAGS.validation_files = val_files
FLAGS.training_files = train_files

Create generator

In [None]:
def extract_fn(data_record):
    """Decode one serialized TFRecord example into fixed-shape tensors.

    Args:
        data_record: serialized example containing the byte-string features
            'image', 'heatmaps' and 'keypoints'.

    Returns:
        Tuple ``(image, keypoints, heatmaps)`` where
        image is uint8 of shape (input_size[0], input_size[1], 3),
        keypoints is int64 of shape (num_of_joints, 2) and
        heatmaps is float64 of shape
        (heatmap_size, heatmap_size, num_of_joints + 1).
    """
    dim0 = FLAGS.input_size[0]
    dim1 = FLAGS.input_size[1]
    hm_side = FLAGS.heatmap_size
    n_joints = FLAGS.num_of_joints
    hm_channels = n_joints + 1

    # All three features are stored as raw byte strings.
    feature_spec = {name: tf.FixedLenFeature([], tf.string)
                    for name in ('image', 'heatmaps', 'keypoints')}
    parsed = tf.parse_single_example(data_record, feature_spec)

    # decode_raw yields a flat vector of unknown length: pin its size, then reshape.
    flat_image = tf.decode_raw(parsed['image'], tf.uint8)
    flat_image.set_shape([dim0 * dim1 * 3])
    image = tf.reshape(flat_image, [dim0, dim1, 3])

    flat_heatmaps = tf.decode_raw(parsed['heatmaps'], tf.float64)
    flat_heatmaps.set_shape([hm_side * hm_side * hm_channels])
    heatmaps = tf.reshape(flat_heatmaps, [hm_side, hm_side, hm_channels])

    flat_keypoints = tf.decode_raw(parsed['keypoints'], tf.int64)
    flat_keypoints.set_shape([n_joints * 2])
    keypoints = tf.reshape(flat_keypoints, [n_joints, 2])

    return image, keypoints, heatmaps

In [None]:
def create_generator(files):
    """Build the tf.data input pipeline for a list of TFRecord files.

    Records are shuffled (buffer of 100) and repeated, decoded with
    ``extract_fn`` using 12 parallel calls, grouped into batches of
    ``FLAGS.batch_size`` and prefetched (4 batches).

    Args:
        files: TFRecord file paths to read from.

    Returns:
        The ``get_next()`` output of a one-shot iterator over the pipeline,
        i.e. the batched ``(image, keypoints, heatmaps)`` tensors; evaluating
        them in a session yields the next batch.
    """
    pipeline = (
        tf.data.TFRecordDataset(files)
        .apply(tf.contrib.data.shuffle_and_repeat(100))
        .map(extract_fn, num_parallel_calls=12)
        .batch(FLAGS.batch_size)
        .prefetch(4)
    )
    return pipeline.make_one_shot_iterator().get_next()


In [None]:
# Despite the names, these hold the `get_next()` output tensors returned by
# create_generator (an (image, keypoints, heatmaps) tuple), not iterator
# objects: they must be evaluated with sess.run(...) to obtain a numpy batch.
train_iterator = create_generator(train_files)
val_iterator = create_generator(val_files)

# TRAINING

In [None]:
# Make sure the checkpoint directory (and its parents) exists before training.
os.makedirs(model_save_dir, exist_ok=True)
print(model_save_dir)

In [None]:
# Persist the run configuration under base_dir. Class-machinery entries and the
# augmentation object are left out of the JSON dump.
excluded_config_keys = ("__dict__", '__weakref__', 'augmentation')
config_to_save = {name: value
                  for name, value in FLAGS.__dict__.items()
                  if name not in excluded_config_keys}
with open(os.path.join(base_dir, "config.json"), "w") as f:
    json.dump(config_to_save, f)

In [None]:
def print_current_training_stats(global_step, cur_lr, stage_losses, total_loss, time_elapsed):
    """Print a two-line progress report for the current training step.

    The first line shows the step counter (out of
    ``FLAGS.steps_per_epoch * FLAGS.epochs``), the learning rate and the step
    duration; the second line shows the loss of each of the
    ``FLAGS.cpm_stages`` stages followed by the total loss.

    Args:
        global_step: current global step value.
        cur_lr: current learning rate.
        stage_losses: indexable collection with one loss per stage.
        total_loss: total loss value for the step.
        time_elapsed: duration of the step in seconds.
    """
    total_steps = FLAGS.steps_per_epoch * FLAGS.epochs
    header = 'Step: {}/{} ----- Cur_lr: {:1.7f} ----- Time: {:>2.2f} sec.'.format(
        global_step, total_steps, cur_lr, time_elapsed)

    per_stage = []
    for idx in range(FLAGS.cpm_stages):
        per_stage.append('S{} loss: {:>7.2f}'.format(idx + 1, stage_losses[idx]))
    loss_line = ' | '.join(per_stage) + ' | Total loss: {}'.format(total_loss)

    print(header)
    print(loss_line + '\n')

In [None]:
# NOTE(review): tf.summary.merge_all() returns None when the graph contains no
# summary ops; the fetches below assume the model registers at least one.
merged_summary = tf.summary.merge_all()
device_count = {'GPU': 1} if FLAGS.use_gpu else {'GPU': 0}

# In-memory record of the training curves, dumped to history.json after every
# epoch (used here instead of a TensorBoard summary writer).
history = {"train_stages_loss": [],
           "train_total_loss": [],
           "val_total_loss": [],
           "learning_rate": []}

# Checkpoint files are written as <model_save_dir>/<network_def>-<epoch>.
checkpoint_prefix = os.path.join(model_save_dir, FLAGS.network_def.split('.py')[0])

with tf.Session(config=tf.ConfigProto(device_count=device_count,
                                      allow_soft_placement=True)) as sess:
    # max_to_keep=None keeps every checkpoint instead of only the most recent ones.
    saver = tf.train.Saver(max_to_keep=None)

    # Initialise all variables
    sess.run(tf.global_variables_initializer())

    for epoch in range(FLAGS.epochs):
        print("Epoch {} starts. Number of training steps {}".format(epoch, FLAGS.steps_per_epoch))

        # ------------------------------ training ------------------------------
        for training_itr in range(FLAGS.steps_per_epoch):
            t1 = time.time()
            # Pull the next numpy batch out of the tf.data pipeline
            # (the keypoints are not needed for training).
            batch_x_np, _, batch_gt_heatmap_np = sess.run(train_iterator)

            # Forward pass + weight update
            stage_losses_np, total_loss_np, _, summaries, current_lr, \
            stage_heatmap_np, global_step = sess.run([model.stage_loss,
                                                      model.total_loss,
                                                      model.train_op,
                                                      merged_summary,
                                                      model.cur_lr,
                                                      model.stage_heatmap,
                                                      model.global_step
                                                      ],
                                                     feed_dict={model.input_images: batch_x_np,
                                                                model.gt_hmap_placeholder: batch_gt_heatmap_np})

            # Store plain Python floats so the history stays JSON-serialisable.
            history["train_stages_loss"].append([float(s) for s in stage_losses_np])
            history["train_total_loss"].append(float(total_loss_np))
            history['learning_rate'].append(float(current_lr))

            # Show training info every 10 global steps
            if global_step % 10 == 0:
                print_current_training_stats(global_step, current_lr, stage_losses_np, total_loss_np,
                                             time.time() - t1)

        saver.save(sess=sess, save_path=checkpoint_prefix, global_step=epoch)
        print('\nModel checkpoint saved...\n')

        # ----------------------------- validation -----------------------------
        val_loss_sum = 0.0
        for val_itr in range(FLAGS.val_steps_per_epochs):
            # `val_iterator` holds symbolic tensors: they have to be evaluated
            # with sess.run to get numpy arrays. Feeding the tensors themselves
            # through feed_dict is rejected by TensorFlow.
            batch_x_np, _, batch_gt_heatmap_np = sess.run(val_iterator)

            # Only the loss is evaluated here (no train_op), so weights are not updated.
            val_total_loss, summaries = sess.run([model.total_loss, merged_summary],
                                                 feed_dict={model.input_images: batch_x_np,
                                                            model.gt_hmap_placeholder: batch_gt_heatmap_np})
            val_loss_sum += val_total_loss

        val_mean_loss = val_loss_sum / FLAGS.val_steps_per_epochs
        history["val_total_loss"].append(float(val_mean_loss))
        print('\nValidation loss: {:>7.2f}\n'.format(val_mean_loss))

        # Rewrite the full history after every epoch so an interrupted run
        # still leaves usable curves on disk.
        with open(os.path.join(base_dir, "history.json"), "w") as f:
            json.dump(history, f)

        print("#"*100)