Create TFRecord files for faster training

In [None]:
import cv2
# import cpm_utils
import numpy as np
import math
import tensorflow as tf
import time
import random
import json
import os
from tqdm import tqdm
import matplotlib.pyplot as plt
from utils.utils import load_image_keypoints
import glob

In [None]:
class FLAGS(object):
    """Static configuration namespace; every option is a plain class attribute."""

    # ------------------------------------------------------------------
    # General settings
    # ------------------------------------------------------------------
    input_size = (1024, 1024)
    heatmap_size = 128
    cpm_stages = 4
    joint_gaussian_variance = 1.0
    center_radius = 21
    num_of_joints = 8
    color_channel = 'RGB'
    normalize = True
    use_gpu = True
    gpu_id = 0

    gradient_clipping = True  # gradient clipping

    # Keypoint names; the list has one entry per joint (num_of_joints).
    keypoints_order = [
        "TAIL_NOTCH",
        "ADIPOSE_FIN",
        "UPPER_LIP",
        "ANAL_FIN",
        "PELVIC_FIN",
        "EYE",
        "PECTORAL_FIN",
        "DORSAL_FIN",
    ]

    # ------------------------------------------------------------------
    # Training settings
    # ------------------------------------------------------------------
    network_def = 'fish_test'
    train_img_dir = ''
    val_img_dir = ''
    bg_img_dir = ''
    pretrained_model = 'fish_test'
    batch_size = 4
    init_lr = 0.001
    lr_decay_rate = 0.45
    lr_decay_step = 8000
    augmentation = None
    # 100, 200, ..., 500 as plain Python ints; useless if crop = False
    buffer_range = list(range(100, 600, 100))
    crop = False  # crop input image based on keypoints - for GTSF only

    epochs = 200

    hnm = True  # Make sure generate hnm files first
    do_cropping = True

In [None]:
# Directory and file-name constants.
# NOTE(review): neither name is referenced by the cells visible in this
# notebook, and the "512x512" in the file name does not match
# FLAGS.input_size (1024, 1024) - confirm whether they are still needed.
dataset_dir = 'utils/dataset/training/'
tfr_file = 'cpm_sample_dataset_512x512.tfrecords'

Helper functions

In [None]:
def _bytes_feature(value):
    """Wrap a single bytes value in a ``tf.train.Feature``.

    The value is stored as a one-element ``BytesList``.
    """
    wrapped = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=wrapped)


def _int64_feature(value):
    """Wrap a single integer in a ``tf.train.Feature``.

    The value is stored as a one-element ``Int64List``.
    """
    wrapped = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=wrapped)


def _float64_feature(value):
    """Wrap one float, or an iterable of floats, in a ``tf.train.Feature``.

    Args:
        value: a single Python/NumPy number, or an iterable of numbers.
            A scalar is wrapped in a one-element list, matching how
            ``_bytes_feature`` and ``_int64_feature`` treat their argument;
            an iterable is passed through unchanged, as before.

    Returns:
        A ``tf.train.Feature`` whose ``float_list`` holds the value(s).

    Note:
        Despite the function name, ``tf.train.FloatList`` stores 32-bit
        floats, so float64 inputs lose precision when serialized.
    """
    if isinstance(value, (int, float, np.number)):
        # FloatList needs an iterable; a bare scalar would raise TypeError.
        value = [value]
    return tf.train.Feature(float_list=tf.train.FloatList(value=value))

In [None]:
# Label JSON files to convert. glob returns matches in arbitrary order, so
# sort them to make the processing order (and the progress output) repeatable.
# To convert a single file instead, use e.g.:
# labels = ['/root/data/bati/labels/labels_2019-04-16.json']
labels = sorted(glob.glob('/root/data/bati/labels/labels_2019-0*.json'))

Create the TFRecord files here

In [None]:
from utils.cpm_utils import make_heatmap_single
from tqdm import tqdm

In [None]:
# Convert every label file into one TFRecord file. A label file whose output
# already exists is skipped, so this cell can be re-run to resume a conversion.
for jsonfile in tqdm(labels):
    date = os.path.basename(jsonfile).split('.')[0]
    record_path = '/root/data/bati/tfrecords/{}.records'.format(date)
    if os.path.isfile(record_path):
        # Already converted: skip before paying for the JSON parse.
        print('{}. Records file already exists, skipping'.format(date))
        continue

    # Context manager so the label file handle is closed deterministically.
    with open(jsonfile) as label_file:
        annotations = json.load(label_file)
    print('{}. Total number of annotations: {}'.format(date, len(annotations)))

    # Write to a temporary path and rename only once the file is complete.
    # Otherwise an interrupted run leaves a truncated .records file that the
    # existence check above would treat as finished on the next run.
    tmp_path = record_path + '.tmp'
    num_failed = 0
    with tf.python_io.TFRecordWriter(tmp_path) as writer:
        for ann in tqdm(annotations):
            try:
                image, keypoints = load_image_keypoints(ann, FLAGS)
                heatmap = make_heatmap_single(FLAGS.input_size[0],
                                              FLAGS.heatmap_size,
                                              FLAGS.joint_gaussian_variance,
                                              keypoints)

                # tobytes() replaces the deprecated tostring() alias. It emits
                # the array contents in C order, which is exactly what
                # flatten().tostring() produced, so no explicit flatten is
                # needed.
                # NOTE(review): the reader cell decodes image as uint8,
                # heatmaps as float64 and keypoints as int64 - confirm the
                # arrays produced here really have those dtypes.
                data = {'image': _bytes_feature(image.tobytes()),
                        'heatmaps': _bytes_feature(heatmap.tobytes()),
                        'keypoints': _bytes_feature(keypoints.tobytes())}
                example = tf.train.Example(
                    features=tf.train.Features(feature=data))
                writer.write(example.SerializeToString())
            except Exception as e:
                # Best effort: one bad annotation must not abort the whole
                # file, but report what failed and keep a count.
                num_failed += 1
                print('Skipping annotation ({}): {}'.format(type(e).__name__, e))

    # record_path did not exist when this iteration started, so a plain
    # rename publishes the finished file.
    os.rename(tmp_path, record_path)
    if num_failed:
        print('{}. Skipped {} of {} annotations'.format(
            date, num_failed, len(annotations)))

Check files

In [None]:
# Geometry used by extract_fn to reshape the decoded image bytes.
# NOTE(review): these must match the images that were serialized;
# FLAGS.input_size is also (1024, 1024) - confirm they are meant to track it.
HEIGHT = WIDTH = 1024
DEPTH = 3

In [None]:
def extract_fn(data_record):
    """Parse one serialized example into image, keypoints and heatmaps tensors.

    The record is expected to carry three bytes features named 'image',
    'heatmaps' and 'keypoints'. Each is decoded from raw bytes and reshaped:

    * image:     uint8,   [HEIGHT, WIDTH, DEPTH]
    * heatmaps:  float64, [heatmap_size, heatmap_size, num_of_joints + 1]
    * keypoints: int64,   [num_of_joints, 2]

    Returns:
        The tuple ``(image, keypoints, heatmaps)`` - note the order.
    """
    bytes_spec = tf.FixedLenFeature([], tf.string)
    parsed = tf.parse_single_example(
        data_record,
        {'image': bytes_spec,
         'heatmaps': bytes_spec,
         'keypoints': bytes_spec})

    def _decode(key, dtype, shape):
        # Decode the raw bytes, pin the flat length, then restore the shape.
        flat_len = 1
        for dim in shape:
            flat_len *= dim
        flat = tf.decode_raw(parsed[key], dtype)
        flat.set_shape([flat_len])
        return tf.reshape(flat, shape)

    hm_size = FLAGS.heatmap_size
    num_joints = FLAGS.num_of_joints

    image = _decode('image', tf.uint8, [HEIGHT, WIDTH, DEPTH])
    heatmaps = _decode('heatmaps', tf.float64,
                       [hm_size, hm_size, (num_joints + 1)])
    keypoints = _decode('keypoints', tf.int64, [num_joints, 2])

    return image, keypoints, heatmaps

In [None]:
# Only pick up the '<date>.records' files produced by the writer cell; a bare
# '*' would also feed any stray file in the directory to TFRecordDataset.
# Sorted because glob returns matches in arbitrary order.
files = sorted(glob.glob('/root/data/bati/tfrecords/*.records'))

In [None]:
# Build the input pipeline over every TFRecord path as one chained
# expression: repeat forever -> decode -> shuffle -> batch -> prefetch.
dataset = (
    tf.data.TFRecordDataset(files)
    .repeat()
    .map(extract_fn, num_parallel_calls=4)
    .shuffle(1000)
    .batch(FLAGS.batch_size)
    .prefetch(4)
)

# Tensors yielding the next batch each time they are evaluated in a session.
iterator = dataset.make_one_shot_iterator()
image, keypoints, heatmaps = iterator.get_next()

In [None]:
from time import time

In [None]:
# Pull a few batches through the pipeline to check the files decode and to
# time how long one batch takes.
# The dataset repeats forever, so the loop must be bounded or this cell never
# finishes; raise the number to read more batches.
num_check_batches = 10
with tf.Session() as sess:
    for count in range(num_check_batches):
        print(count)
        start = time()
        # Fetch all three tensors in ONE run call. Every sess.run that touches
        # the iterator's outputs advances it, so three separate calls would
        # return image, keypoints and heatmaps from three different batches.
        out, kps, hms = sess.run([image, keypoints, heatmaps])
        end = time()
        print(end - start)
        # Optional visual check of the first sample of the batch:
        # plt.imshow(out[0])
        # plt.scatter(kps[0][:, 0], kps[0][:, 1])
        # plt.show()