In [None]:
# Show the visible GPUs and their current memory usage before training starts.
!nvidia-smi

In [None]:
# Re-import edited project modules (e.g. utils, models.nets.*) automatically
# before each cell executes, so code changes are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
import tensorflow as tf
import numpy as np
import cv2
import os
import importlib
import time
import albumentations as albu

from utils import cpm_utils

# CONFIG

Configuration: every tunable setting (model, data, training) lives in the `FLAGS` class below.

In [None]:
class FLAGS(object):
    """Hyper-parameters and paths for fish keypoint detection with a CPM network.

    Used as a plain namespace: attributes are read directly from the class (it is
    never instantiated) and the class ``__dict__`` is written to ``config.json``
    by the training section.
    """

    # ---- General settings ----
    input_size = (512, 512)
    heatmap_size = 64
    cpm_stages = 3
    joint_gaussian_variance = 4.0
    center_radius = 21
    num_of_joints = 8
    color_channel = 'RGB'
    normalize = True
    use_gpu = True
    # NOTE(review): the notebook pins CUDA_VISIBLE_DEVICES to "1" instead of
    # using this value — confirm which GPU is intended.
    gpu_id = 0

    # ---- Demo settings (not referenced directly by this notebook's cells) ----
    # 'MULTI': show multiple stage heatmaps
    # 'SINGLE': show last stage heatmap
    # 'Joint_HM': show last stage heatmap for each joint
    # 'image or video path': show detection on single image or video
    DEMO_TYPE = 'SINGLE'

    model_path = 'cpm_hand'
    cam_id = 0

    webcam_height = 480
    webcam_width = 640

    use_kalman = True
    kalman_noise = 0.03
    # Keypoint names; index i labels the i-th keypoint returned by the data loader.
    keypoints_order = ["TAIL_NOTCH",
                       "ADIPOSE_FIN",
                       "UPPER_LIP",
                       "ANAL_FIN",
                       "PELVIC_FIN",
                       "EYE",
                       "PECTORAL_FIN",
                       "DORSAL_FIN"]

    # ---- Training settings ----
    network_def = 'fish_test'  # module name under models.nets
    train_img_dir = ''
    val_img_dir = ''
    bg_img_dir = ''
    pretrained_model = 'fish_test'
    batch_size = 8
    init_lr = 0.001
    lr_decay_rate = 0.5
    lr_decay_step = 10000
    # Augmentation pipeline; presumably applied by DataGenerator, which receives FLAGS.
    # keypoint_params 'xy' means keypoints are passed as (x, y) coordinates.
    augmentation = albu.Compose([#albu.RandomContrast(limit=0.3, p=0.3),
                                 #albu.RandomBrightness(limit=0.4, p=0.3),
                                 albu.Rotate(limit=10, p=1.0)],
                                p=1.0,
                                keypoint_params={'format': 'xy'})

    epochs = 300
    # Older, disabled augmentation ranges kept for reference:
    # augmentation_config = {'hue_shift_limit': (-5, 5),
    #                        'sat_shift_limit': (-10, 10),
    #                        'val_shift_limit': (-15, 15),
    #                        'translation_limit': (-0.15, 0.15),
    #                        'scale_limit': (-0.3, 0.5),
    #                        'rotate_limit': (-90, 90)}
    hnm = True  # presumably hard-negative mining; make sure the HNM files are generated first
    do_cropping = True

    # ---- Graph freezing ----
    output_node_names = 'stage_3/mid_conv7/BiasAdd:0'

In [None]:
# Resolve the network definition (module models.nets.<network_def>) at runtime.
cpm_model = importlib.import_module('models.nets.{}'.format(FLAGS.network_def))

# MODEL CREATION

Create a timestamped output folder for this run.

In [None]:
from datetime import datetime

In [None]:
# Run timestamp used as the output folder name, e.g. "2019_03_27_01_52_17".
datenow = datetime.now().strftime("%Y_%m_%d_%H_%M_%S")

In [None]:
# Root folder for everything this run writes (weights/, config.json, history.json).
base_dir = os.path.join("/root/data/models/keypoints_detection", datenow)

In [None]:
# Expose only physical GPU 1 to this process. This only takes effect if it is set
# before TensorFlow initialises CUDA (i.e. before the first session is created).
os.environ['CUDA_VISIBLE_DEVICES'] = "1"

# Hyper-parameter-encoded sub-path, e.g. "fish_test/input_(512, 512)_output_64/...".
# NOTE(review): not used by any cell below; checkpoints go to model_save_dir only.
model_path_suffix = os.path.join(
    FLAGS.network_def,
    'input_%s_output_%s' % (FLAGS.input_size, FLAGS.heatmap_size),
    'joints_%s' % (FLAGS.num_of_joints,),
    'stages_%s' % (FLAGS.cpm_stages,),
    'init_%s_rate_%s_step_%s' % (FLAGS.init_lr, FLAGS.lr_decay_rate, FLAGS.lr_decay_step),
)
# Checkpoints are written here by the training loop.
model_save_dir = os.path.join(base_dir, 'weights')

Build the network graph (model, loss and optimizer).

In [None]:
# Instantiate the CPM network in training mode, then attach its loss and RMSProp optimizer.
model = cpm_model.CPM_Model(
    input_size=FLAGS.input_size,
    heatmap_size=FLAGS.heatmap_size,
    stages=FLAGS.cpm_stages,
    joints=FLAGS.num_of_joints,
    img_type=FLAGS.color_channel,
    is_training=True,
)
model.build_loss(
    FLAGS.init_lr,
    FLAGS.lr_decay_rate,
    FLAGS.lr_decay_step,
    optimizer='RMSProp',
)
print('=====Model Build=====\n')

# DATA GENERATOR

Create the data generators.

In [None]:
import glob
import json
import random

import matplotlib.pyplot as plt
from PIL import Image, ImageDraw

from utils.utils import DataGenerator

In [None]:
# Seed every RNG the notebook relies on: `random` drives the train/val shuffle and
# `np.random` picks the sample for the sanity-check plot (and may be used by the generators).
SEED = 258
random.seed(SEED)
np.random.seed(SEED)

Load the annotation files of all GTSF sessions.

In [None]:
# One annotation file per GTSF session. glob() returns files in arbitrary,
# filesystem-dependent order; sorting makes the seeded shuffle of the split reproducible.
annotations = sorted(glob.glob("/root/data/gtsf_phase_I/*/*_cogito_annotations.json"))
print("Total number of gtsf sessions: {}".format(len(annotations)))

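Train/validation split. We split by experiment (session) rather than by image, so images from the same session never end up in both sets. This avoids leaking near-duplicate images into validation.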

In [None]:
# Split by session file (not by image): the first 80% of the shuffled sessions
# are used for training, the remainder for validation.
random.shuffle(annotations)
cutoff = int(len(annotations) * 0.8)
train_files, val_files = annotations[:cutoff], annotations[cutoff:]
print("Number of train files: {}".format(len(train_files)))
print("Number of validation files: {}".format(len(val_files)))

In [None]:
# Concatenate the annotation records of all training sessions (each file is expected
# to contain a JSON list of records), then keep only salmon.
train_annotations = []
for jpath in train_files:
    with open(jpath) as f:  # close each file instead of leaking the handle
        train_annotations.extend(json.load(f))
print("Number of training data (all species): {}".format(len(train_annotations)))
train_annotations = [ann for ann in train_annotations if ann["species"] == "salmon"]
print("Number of training data (salmon only): {}".format(len(train_annotations)))

In [None]:
# Concatenate the annotation records of all validation sessions (each file is expected
# to contain a JSON list of records), then keep only salmon.
val_annotations = []
for jpath in val_files:
    with open(jpath) as f:  # close each file instead of leaking the handle
        val_annotations.extend(json.load(f))
print("Number of validation data (all species): {}".format(len(val_annotations)))
val_annotations = [ann for ann in val_annotations if ann["species"] == "salmon"]
print("Number of validation data (salmon only): {}".format(len(val_annotations)))

In [None]:
# Disabled sanity check: print every annotation whose image is missing from the
# local mirror under /root/data/gtsf_phase_I/ (local path rebuilt from the
# "Labeled Data" field by dropping its first 7 "/"-separated components).
# for ann in val_annotations + train_annotations:
#     local_path = os.path.join("/root/data/gtsf_phase_I/", 
#           "/".join(ann["Labeled Data"].split("/")[7:]))
#     if not os.path.isfile(local_path):
#         print(local_path)

Create generator

In [None]:
import os
import cv2
import numpy as np
from utils.utils import load_image_keypoints, DataGenerator

Sanity-check `load_image_keypoints` on a random validation sample.

In [None]:
# Load one random validation sample (image + 2-D keypoint array) to eyeball the loader.
(im, kps) = load_image_keypoints(
    np.random.choice(val_annotations),
    FLAGS,
)

In [None]:
# Overlay the loaded keypoints, each labelled with its name, on the sample image.
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(im)
ax.scatter(kps[:, 0], kps[:, 1])
# zip instead of a hard-coded range(8): labels exactly as many points as there are
# names and keypoints, so it follows FLAGS.keypoints_order.
for name, kp in zip(FLAGS.keypoints_order, kps):
    ax.text(kp[0], kp[1], name)
ax.set_title("Sample image with ground-truth keypoints")
plt.show()

Create the training and validation generators.

In [None]:
# Validation batches must come from the held-out sessions; building the validation
# generator from train_annotations made the "validation" data identical to training data.
train_generator = DataGenerator(train_annotations, FLAGS)
val_generator = DataGenerator(val_annotations, FLAGS)

In [None]:
# Fetch one batch for visual inspection: images, keypoints and target heatmaps.
SAMPLE_BATCH = 4  # arbitrary batch index
xb, yb, heatmaps = val_generator[SAMPLE_BATCH]

In [None]:
# Show the 2nd image of the inspected batch with its keypoints and their names.
fig, ax = plt.subplots(figsize=(15, 15))
ax.imshow(xb[1, ...])
for i, kpp in enumerate(yb[1]):
    ax.scatter(kpp[0], kpp[1], c="r")
    ax.text(kpp[0], kpp[1], FLAGS.keypoints_order[i], {"color": "w"})
plt.show()

Augmentations

In [None]:
# Disabled preview: apply FLAGS.augmentation to the sample image and plot the result.
# NOTE(review): `kps` is indexed as a 2-D array elsewhere (kps[:, 0]), so
# `kps.values()` below looks stale — it probably needs `kps` / `kps.tolist()` if re-enabled.
# plt.imshow(im)
# plt.show()
# results = FLAGS.augmentation(image=im, keypoints=list(kps.values()))
# nkps = np.array(results['keypoints'])
# plt.imshow(results["image"])
# plt.scatter(nkps[:, 0], nkps[:, 1])
# plt.show()

# TRAINING

In [None]:
# Create <base_dir>/weights (and base_dir itself); a no-op when it already exists.
os.makedirs(model_save_dir, exist_ok=True)

In [None]:
# save config
with open(os.path.join(base_dir, "config.json"), "w") as f:
    json.dump({k:v for (k,v) in FLAGS.__dict__.items() if k not in  ["__dict__", '__weakref__', 'augmentation']}, f)

In [None]:
def print_current_training_stats(global_step, cur_lr, stage_losses, total_loss, time_elapsed,
                                 total_steps=None):
    """Print a two-line progress report for one training step.

    Args:
        global_step: current optimizer step.
        cur_lr: current learning rate.
        stage_losses: sequence with one loss value per CPM stage.
        total_loss: total loss of the step.
        time_elapsed: wall-clock seconds spent on the step.
        total_steps: planned number of steps. When None, it falls back to
            ``len(train_generator) * FLAGS.epochs`` (notebook globals).
    """
    if total_steps is None:
        total_steps = len(train_generator) * FLAGS.epochs
    stats = 'Step: {}/{} ----- Cur_lr: {:1.7f} ----- Time: {:>2.2f} sec.'.format(
        global_step, total_steps, cur_lr, time_elapsed)
    # One entry per stage actually reported, numbered from 1.
    losses = ' | '.join('S{} loss: {:>7.2f}'.format(stage_num, loss)
                        for stage_num, loss in enumerate(stage_losses, start=1))
    # Same fixed-width format as the stage losses, so columns line up across steps.
    losses += ' | Total loss: {:>7.2f}'.format(total_loss)
    print(stats)
    print(losses + '\n')

In [None]:
merged_summary = tf.summary.merge_all()
device_count = {'GPU': 1} if FLAGS.use_gpu else {'GPU': 0}

# Loss curves are kept in a plain dict and dumped to history.json after every
# epoch (TensorBoard logging is disabled below).
history = {"train_stages_loss": [],
           "train_total_loss": [],
           "val_total_loss": []}
with tf.Session(config=tf.ConfigProto(device_count=device_count,
                                      allow_soft_placement=True)) as sess:
    # TensorBoard writers (disabled):
    # train_writer = tf.summary.FileWriter(train_log_save_dir, sess.graph)
    # test_writer = tf.summary.FileWriter(test_log_save_dir, sess.graph)

    # Model saver; max_to_keep=None keeps every per-epoch checkpoint.
    saver = tf.train.Saver(max_to_keep=None)

    # Init all vars
    sess.run(tf.global_variables_initializer())
    # To resume a previous run instead, restore a checkpoint here, e.g.:
    # saver.restore(sess, '/root/data/models/keypoints_detection/2019_03_27_01_52_17/weights/fish_test-39')

    # Fresh generators so the loop does not depend on the exploration cells above.
    # Validation must read the held-out sessions (val_annotations), not the training set.
    train_generator = DataGenerator(train_annotations, FLAGS)
    val_generator = DataGenerator(val_annotations, FLAGS)

    for epoch in range(FLAGS.epochs):
        print("Epoch {} starts".format(epoch))

        for training_itr in range(len(train_generator)):
            t1 = time.time()
            # load input images + target heatmaps (middle element is not needed here)
            batch_x_np, _, batch_gt_heatmap_np = train_generator[training_itr]

            # Forward and update weights
            stage_losses_np, total_loss_np, _, summaries, current_lr, \
            stage_heatmap_np, global_step = sess.run([model.stage_loss,
                                                      model.total_loss,
                                                      model.train_op,
                                                      merged_summary,
                                                      model.cur_lr,
                                                      model.stage_heatmap,
                                                      model.global_step
                                                      ],
                                                     feed_dict={model.input_images: batch_x_np,
                                                                model.gt_hmap_placeholder: batch_gt_heatmap_np})

            history["train_stages_loss"].append([float(s) for s in stage_losses_np])
            history["train_total_loss"].append(float(total_loss_np))
            # Show training info
            print_current_training_stats(global_step, current_lr, stage_losses_np, total_loss_np, time.time() - t1)

            # Write logs (disabled)
            # train_writer.add_summary(summaries, epoch*training_itr)

        # shuffle on epoch end
        train_generator.on_epoch_end()

        saver.save(sess=sess, save_path=model_save_dir + '/' + FLAGS.network_def.split('.py')[0],
                   global_step=epoch)
        print('\nModel checkpoint saved...\n')

        # Validation: mean total loss over all validation batches. train_op is not
        # fetched, so weights are not updated. Separate variable names keep
        # batch_x_np / stage_heatmap_np referring to the same (last training) batch.
        n_val_batches = len(val_generator)
        if n_val_batches > 0:
            val_loss_sum = 0.0
            for val_itr in range(n_val_batches):
                val_batch_x_np, _, val_batch_gt_heatmap_np = val_generator[val_itr]
                val_loss_sum += sess.run(model.total_loss,
                                         feed_dict={model.input_images: val_batch_x_np,
                                                    model.gt_hmap_placeholder: val_batch_gt_heatmap_np})
            val_mean_loss = float(val_loss_sum / n_val_batches)
            print('\nValidation loss: {:>7.2f}\n'.format(val_mean_loss))
        else:
            # Avoid a ZeroDivisionError when the validation split is empty.
            val_mean_loss = None
            print('\nValidation skipped: no validation batches\n')
        history["val_total_loss"].append(val_mean_loss)

        # save history
        with open(os.path.join(base_dir, "history.json"), "w") as f:
            json.dump(history, f)

        print("#" * 100)

Visualize the final-stage heatmaps overlaid on an input image.

In [None]:
# Visualise the final-stage heatmaps of the trained network on xb[0].
# The heatmaps are recomputed here, from the latest checkpoint, on the batch that is
# actually displayed. stage_heatmap_np left over from the training cell belongs to the
# last *training* batch, so overlaying it on xb paired two different images.
latest_ckpt = tf.train.latest_checkpoint(model_save_dir)
if latest_ckpt is None:
    raise RuntimeError("No checkpoint found in {}; run the training cell first.".format(model_save_dir))
vis_saver = tf.train.Saver()
with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as vis_sess:
    vis_saver.restore(vis_sess, latest_ckpt)
    vis_stage_heatmaps = vis_sess.run(model.stage_heatmap,
                                      feed_dict={model.input_images: xb})

final_stage_heatmap = vis_stage_heatmaps[-1][0, ...]  # last stage, first image of the batch
print(final_stage_heatmap.shape)

# One panel per heatmap channel, two panels per row (grid derived from the channel
# count instead of a hard-coded 5x2 grid with a magic `c == 9` skip).
n_channels = final_stage_heatmap.shape[-1]
n_cols = 2
n_rows = (n_channels + n_cols - 1) // n_cols
f, ax = plt.subplots(n_rows, n_cols, figsize=(20, 6 * n_rows), squeeze=False)
for c, axis in enumerate(ax.flat):
    axis.axis("off")
    if c >= n_channels:
        continue  # unused grid cell
    hm = cv2.resize(final_stage_heatmap[..., c], FLAGS.input_size)
    hm_max = np.where(hm == hm.max())  # (rows, cols) of the peak response
    axis.imshow(xb[0, ...])
    axis.imshow(hm, alpha=0.5)
    axis.scatter(hm_max[1], hm_max[0], c="r")
    if c < len(FLAGS.keypoints_order):
        axis.set_title(FLAGS.keypoints_order[c])
    else:
        axis.set_title("channel {}".format(c))
plt.show()