In [1]:
import os
import time
from collections import defaultdict
import json
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt

from PIL import Image
from skimage.io import imread, imsave
from skimage.transform import resize

import tensorflow as tf

# Seed NumPy's legacy global random generator so that np.random.* calls in
# this notebook are reproducible on a fresh Restart & Run All.
SEED = 10
np.random.seed(seed=SEED)

In [2]:
def get_partition(dataset_root, partition="train"):
    """Get the (img, annotation) pairs for the given partition.

    Parameters
    ----------
    dataset_root : str or os.PathLike
        Dataset root directory. Images are looked up under
        ``edges/imgs/...`` and annotations under ``edges/edge_maps/...``.
    partition : {"train", "test"}
        Which split to list. ``"train"`` reads the ``train/rgbr/aug``
        folders; ``"test"`` reads the ``test/rgbr`` folders.

    Returns
    -------
    (img_set, annot_set) : tuple of two lists of str
        Paths of all ``.jpg`` files found recursively under the image
        directory, sorted by their path relative to that directory, and the
        matching annotation paths in the same order. Each annotation path
        mirrors the image's relative path under the edge-map directory with
        the ``.jpg`` extension replaced by ``.png``. Annotation files are not
        checked for existence here.

    Raises
    ------
    ValueError
        If ``partition`` is neither ``"train"`` nor ``"test"``.
    """
    if partition == "train":
        split_subdir = "train/rgbr/aug"
    elif partition == "test":
        split_subdir = "test/rgbr"
    else:
        # An explicit error survives `python -O`, unlike an assert.
        raise ValueError(
            f"partition must be 'train' or 'test', got {partition!r}")

    base_img_dir = os.path.join(dataset_root, "edges/imgs", split_subdir)
    base_annot_dir = os.path.join(dataset_root, "edges/edge_maps", split_subdir)

    root = Path(base_img_dir)
    # Use relative_to() instead of stripping a string prefix: Path normalises
    # "./" prefixes and duplicate slashes, so the text of a globbed path need
    # not start with base_img_dir even though the path lies under it.
    img_names = sorted(
        str(img_path.relative_to(root))
        for img_path in root.glob("**/*.jpg")
        if img_path.is_file()  # glob also matches directories named "*.jpg"
    )

    img_set = [os.path.join(base_img_dir, name) for name in img_names]
    # Swap only the extension; str.replace would also rewrite ".jpg" that
    # appears elsewhere in the relative path.
    annot_set = [
        os.path.join(base_annot_dir, os.path.splitext(name)[0] + ".png")
        for name in img_names
    ]
    return (img_set, annot_set)

# Helper functions for defining tf types
def _bytes_feature(value):
    """Return a ``tf.train.Feature`` whose bytes list holds just ``value``."""
    single_item = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=single_item)

def _int64_feature(value):
    """Return a ``tf.train.Feature`` whose int64 list holds just ``value``."""
    single_item = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=single_item)

In [3]:
def create_tf_record(x_set, y_set, tfrecords_filename):
    """Write image/annotation pairs to a TFRecord file.

    Each image file and its annotation file are read as raw bytes (no
    decoding) and stored together in one ``tf.train.Example`` with two bytes
    features, ``'image'`` and ``'segmentation_mask'``.

    Parameters
    ----------
    x_set : iterable of str
        Image file paths.
    y_set : iterable of str
        Annotation file paths, aligned element-wise with ``x_set``.
    tfrecords_filename : str
        Path of the TFRecord file to write.

    Raises
    ------
    ValueError
        If ``x_set`` and ``y_set`` have different lengths. Plain ``zip``
        would otherwise silently drop the unmatched tail.
    """
    x_list = list(x_set)
    y_list = list(y_set)
    if len(x_list) != len(y_list):
        raise ValueError(
            f"x_set and y_set differ in length: {len(x_list)} != {len(y_list)}")

    # The context managers close the writer even if reading an input fails,
    # and close every input file handle instead of leaving it to the GC.
    with tf.io.TFRecordWriter(tfrecords_filename) as writer:
        for img_path, annotation_path in zip(x_list, y_list):
            with tf.io.gfile.GFile(img_path, 'rb') as img_file:
                img_data = img_file.read()
            with tf.io.gfile.GFile(annotation_path, 'rb') as annot_file:
                annotation = annot_file.read()

            example = tf.train.Example(features=tf.train.Features(feature={
                'image': _bytes_feature(img_data),
                'segmentation_mask': _bytes_feature(annotation),
            }))
            writer.write(example.SerializeToString())

    print(f"Wrote tfrecord file to {tfrecords_filename}")

In [4]:
def create_metadata(trn_x, val_x, tst_x, metadata_filepath):
    """Write the size of each data split to a JSON file for the training script.

    Parameters
    ----------
    trn_x, val_x, tst_x : sized collections
        Training, validation and test samples; only ``len()`` is used.
    metadata_filepath : str
        Destination path; an existing file is overwritten.

    The written JSON object has the keys ``train_length``, ``val_length``
    and ``test_length``, pretty-printed with a 4-space indent.
    """
    split_sizes = dict([
        ("train_length", len(trn_x)),
        ("val_length", len(val_x)),
        ("test_length", len(tst_x)),
    ])
    serialized = json.dumps(split_sizes, indent=4)
    with open(metadata_filepath, 'w') as out_file:
        out_file.write(serialized)

## Create TF Record Files

In [5]:
OUT_DIR = "./datasets/"
# Override with the BIPED_DIR environment variable on other machines.
DATASET_DIR = os.environ.get("BIPED_DIR", "/home/mxs8x15/datasets/BIPED")
# Fraction of the (shuffled) augmented training images kept for training;
# the remainder becomes the validation split.
TRAIN_FRACTION = 0.85

# Test only
(img_set_tst, annot_set_tst) = get_partition(DATASET_DIR, 'test')

# All augmented samples
(img_set_full, annot_set_full) = get_partition(DATASET_DIR, 'train')

# Fail fast on a wrong DATASET_DIR or incomplete dataset, rather than after
# minutes of TFRecord writing.
assert img_set_full and img_set_tst, f"no .jpg images found under {DATASET_DIR}"
missing_annots = [p for p in annot_set_full + annot_set_tst if not os.path.exists(p)]
assert not missing_annots, f"{len(missing_annots)} annotations missing, e.g. {missing_annots[:3]}"

# Shuffle with an RNG seeded in this cell so re-running it reproduces the same
# split. RandomState(SEED) gives the same permutation that the global
# np.random.seed(SEED) + np.random.shuffle produced on a fresh Run All.
rng = np.random.RandomState(SEED)
indices = np.arange(len(img_set_full))
rng.shuffle(indices)

# Index into the original arrays in this order
img_set_full = [img_set_full[i] for i in indices]
annot_set_full = [annot_set_full[i] for i in indices]

# Create training and validation sets
CUTOFF = int(TRAIN_FRACTION * len(indices))

img_set_trn = img_set_full[:CUTOFF]
annot_set_trn = annot_set_full[:CUTOFF]

img_set_val = img_set_full[CUTOFF:]
annot_set_val = annot_set_full[CUTOFF:]

In [6]:
%%time
create_tf_record(img_set_trn, annot_set_trn, os.path.join(OUT_DIR,
                                                          "biped_trn.tfrecord"))

create_tf_record(img_set_val, annot_set_val, os.path.join(OUT_DIR,
                                                          "biped_val.tfrecord"))

create_tf_record(img_set_tst, annot_set_tst, os.path.join(OUT_DIR,
                                                          "biped_tst.tfrecord"))

metadata_filepath = os.path.join(OUT_DIR, "meta.json")
create_metadata(img_set_trn, img_set_val, img_set_tst, metadata_filepath)

Wrote tfrecord file to ./datasets/biped_trn.tfrecord
Wrote tfrecord file to ./datasets/biped_val.tfrecord
Wrote tfrecord file to ./datasets/biped_tst.tfrecord
CPU times: user 17.4 s, sys: 17 s, total: 34.3 s
Wall time: 5min 30s
