In [2]:
import os
import json
import pprint
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import inflection as inf
from collections import defaultdict

In [7]:
#######################################################################################
############################# CHANGE FILENAMES HERE ###################################
#######################################################################################

# Directory holding the per-split annotation JSON files to convert.
annotation_dir = "../annotations/splits"

# Caption determiners. Order matters: a determiner's position in this list is
# its index in the determiner part of the one-hot caption encoding, so only
# ever append new entries.
determiners = [
    "a", "an", "all", "any", "every", "my", "your", "this", "that", "these",
    "those", "some", "many", "few", "both", "neither", "little", "much",
    "either", "our", "no", "several", "half", "each", "the",
]

# Annotation files (inside annotation_dir) to convert.
# Full splits instead: ["annotations_val.json", "annotations_train.json", "annotations_test.json"]
filenames = [
    "train_001.json",
    "train_005.json",
    "train_010.json",
    "train_025.json",
    "train_050.json",
]

# Root directory for the generated .tfrec shards (one sub-directory per file).
save_dir = "../annotations/tfrecords/splits/"


In [8]:
# Load the category list (id/name dicts) from the validation annotations.
# NOTE(review): assumes every split shares this same category list — confirm.
with open("../annotations/annotations_val.json", 'r') as val_file:
    temp = json.load(val_file)
categories = temp["categories"]
n_categories = len(categories)

In [9]:
def image_feature(value):
    """PNG-encode an image tensor and wrap the bytes in a bytes_list ``tf.train.Feature``.

    NOTE(review): ``tf.io.encode_png`` expects a 3-D uint8/uint16 image tensor
    and ``.numpy()`` requires eager execution — confirm callers satisfy both.
    """
    png_bytes = tf.io.encode_png(value).numpy()
    png_list = tf.train.BytesList(value=[png_bytes])
    return tf.train.Feature(bytes_list=png_list)

def bytes_feature(value):
    """Wrap a single ``str`` or ``bytes`` value in a bytes_list ``tf.train.Feature``.

    ``str`` values are UTF-8 encoded first; ``bytes`` values are stored unchanged.
    """
    if isinstance(value, str):
        value = value.encode()
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def float_feature(value):
    """Wrap one float/double scalar in a single-element float_list ``tf.train.Feature``."""
    single = [value]
    return tf.train.Feature(float_list=tf.train.FloatList(value=single))

def int64_feature(value):
    """Wrap one bool/enum/int scalar in a single-element int64_list ``tf.train.Feature``."""
    single = [value]
    return tf.train.Feature(int64_list=tf.train.Int64List(value=single))

def int64_feature_list(value):
    """Wrap an iterable of ints in ONE ``tf.train.Feature`` whose int64_list holds them all."""
    ints = tf.train.Int64List(value=value)
    return tf.train.Feature(int64_list=ints)

def float_feature_list(value):
    """Wrap an iterable of floats in ONE ``tf.train.Feature`` whose float_list holds them all."""
    floats = tf.train.FloatList(value=value)
    return tf.train.Feature(float_list=floats)

def object_features_list(objects):
    """Build a ``tf.train.FeatureList`` with one int64 feature per object's ``"id"``, in order."""
    return tf.train.FeatureList(
        feature=[int64_feature(item["id"]) for item in objects]
    )
    

def _one_hot(index, size):
    """Return a length-``size`` list of ints that is 0 everywhere except 1 at ``index``."""
    vector = [0] * size
    vector[index] = 1
    return vector


def create_example(image_id, sample, max_bboxes=20):
    """Serialize one image's caption and annotations into a ``tf.train.SequenceExample``.

    Args:
        image_id: id stored in the ``image_id`` context feature.
        sample: mapping with ``"image"`` (a dict providing ``"caption"`` and
            ``"file_name"``) plus ``"inputs"`` and ``"outputs"`` lists of
            annotation dicts. Every annotation needs ``"bbox"``,
            ``"category_id"`` and ``"area"``; input annotations also need
            ``"liqLevel"``.
        max_bboxes: the ``input_one_hot`` / ``output_one_hot`` sequences are
            zero-padded up to this many rows. Longer lists are NOT truncated.

    Returns:
        A ``tf.train.SequenceExample``. Its context holds the file name, id,
        caption, caption one-hot and per-object area/category lists. Its
        feature lists hold the raw bboxes and per-object rows of the form
        ``bbox + [flag] + category one-hot`` (input rows additionally end with
        ``liqLevel``).

    Raises:
        ValueError: if the caption's first word is not in ``determiners``.
    """
    caption = sample["image"]["caption"]
    file_name = sample["image"]["file_name"]
    words = caption.split()
    det = words[0]

    # Caption encoding = determiner one-hot followed by noun-category one-hot.
    # The dict keeps the LAST category with a given name, matching a full scan.
    # NOTE(review): a noun with no matching category name silently maps to
    # category id 0 — confirm that fallback is intended.
    noun = inf.singularize(" ".join(words[1:]))
    name_to_id = {cat["name"]: cat["id"] for cat in categories}
    noun_id = name_to_id.get(noun, 0)
    caption_one_hot = (_one_hot(determiners.index(det), len(determiners))
                       + _one_hot(noun_id, len(categories)))

    inputs = sample["inputs"]
    outputs = sample["outputs"]

    # Output rows: bbox (4) + flag (1 = real object, 0 = padding) + category one-hot.
    output_one_hot = [ann["bbox"] + [1] + _one_hot(ann["category_id"], n_categories)
                      for ann in outputs]
    output_one_hot += [[0] * (4 + 1 + n_categories)
                       for _ in range(len(output_one_hot), max_bboxes)]

    # Input rows additionally end with ``liqLevel``; padding rows carry a
    # trailing 0 in that slot so every row has the same length.
    input_one_hot = [ann["bbox"] + [1] + _one_hot(ann["category_id"], n_categories)
                     + [ann["liqLevel"]]
                     for ann in inputs]
    input_one_hot += [[0] * (4 + 1 + n_categories + 1)
                      for _ in range(len(input_one_hot), max_bboxes)]

    context = {
        # "image": image_feature(image),
        "file_name": bytes_feature(file_name),
        "image_id": int64_feature(image_id),
        "caption": bytes_feature(caption),
        "caption_one_hot": int64_feature_list(caption_one_hot),
        "areas": int64_feature_list([ann['area'] for ann in inputs]),
        "category_ids": int64_feature_list([ann['category_id'] for ann in inputs]),
        "output_category_ids": int64_feature_list([ann['category_id'] for ann in outputs]),
        "output_areas": int64_feature_list([ann['area'] for ann in outputs]),
    }

    feature_list = {
        "input_bboxes": tf.train.FeatureList(feature=[tf.train.Feature(int64_list=tf.train.Int64List(value=ann['bbox'])) for ann in inputs]),
        "output_bboxes": tf.train.FeatureList(feature=[tf.train.Feature(int64_list=tf.train.Int64List(value=ann['bbox'])) for ann in outputs]),
        # Float list because input rows include the (presumably float) liqLevel.
        "input_one_hot": tf.train.FeatureList(feature=[tf.train.Feature(float_list=tf.train.FloatList(value=val)) for val in input_one_hot]),
        "output_one_hot": tf.train.FeatureList(feature=[tf.train.Feature(int64_list=tf.train.Int64List(value=val)) for val in output_one_hot]),
    }

    return tf.train.SequenceExample(context=tf.train.Features(feature=context), feature_lists=tf.train.FeatureLists(feature_list=feature_list))

def parse_tfrecord_fn(example, labeled=True,
                      image_dir="../DetermiNetProject/Assets/StreamingAssets/dataset/"):
    """Parse one serialized SequenceExample produced by ``create_example``.

    Args:
        example: scalar string tensor holding a serialized ``tf.train.SequenceExample``.
        labeled: unused; kept so existing call sites keep working.
        image_dir: prefix prepended to the stored ``file_name`` to locate the PNG.

    Returns:
        dict of feature name -> dense tensor. Variable-length features are
        densified (numeric padding 0, string padding b""). Also includes
        ``"image"``: the PNG read from disk and decoded with 3 channels.
    """
    context_features = {
        "file_name": tf.io.FixedLenFeature([], tf.string),
        # "image": tf.io.FixedLenFeature([], tf.string),
        "image_id": tf.io.FixedLenFeature([], tf.int64),
        "caption": tf.io.VarLenFeature(tf.string),
        "caption_one_hot": tf.io.VarLenFeature(tf.int64),
        "areas": tf.io.VarLenFeature(tf.int64),
        "category_ids": tf.io.VarLenFeature(tf.int64),
        "output_category_ids": tf.io.VarLenFeature(tf.int64),
        "output_areas": tf.io.VarLenFeature(tf.int64),
    }

    sequence_features = {
        "input_bboxes": tf.io.VarLenFeature(tf.int64),
        "output_bboxes": tf.io.VarLenFeature(tf.int64),
        "input_one_hot": tf.io.VarLenFeature(tf.float32),
        "output_one_hot": tf.io.VarLenFeature(tf.int64),
    }
    context, sequence = tf.io.parse_single_sequence_example(
        example,
        context_features=context_features,
        sequence_features=sequence_features,
    )

    parsed = {**context, **sequence}
    # Overwriting values of existing keys while iterating is safe (size unchanged).
    for key, value in parsed.items():
        if isinstance(value, tf.sparse.SparseTensor):
            if value.dtype == tf.string:
                # Pad strings with the empty string rather than a literal "b".
                parsed[key] = tf.sparse.to_dense(value, default_value=b"")
            else:
                parsed[key] = tf.sparse.to_dense(value)

    raw = tf.io.read_file(image_dir + parsed["file_name"])
    parsed["image"] = tf.io.decode_png(raw, channels=3)

    return parsed

In [10]:

# Convert each annotation JSON into sharded .tfrec files under save_dir/<split>.
for filename in filenames:
    annotation_filepath = os.path.join(annotation_dir, filename)
    with open(annotation_filepath, 'r') as annotation_file:
        dataset = json.load(annotation_file)
    images = dataset["images"]
    input_annotations = dataset["input_oracle_annotations"]
    output_annotations = dataset["annotations"]
    n_samples = 4096  # maximum number of samples written per .tfrec shard

    # e.g. "train_001.json" -> <save_dir>/001
    split_dir = os.path.join(save_dir, filename.split(".")[0].split("_")[1])
    print(split_dir)
    os.makedirs(split_dir, exist_ok=True)

    # Group everything by image id: {"inputs": [...], "outputs": [...], "image": {...}}.
    dataset_samples = defaultdict(lambda: defaultdict(list))
    for ann in input_annotations:
        dataset_samples[ann["image_id"]]["inputs"].append(ann)
    for ann in output_annotations:
        dataset_samples[ann["image_id"]]["outputs"].append(ann)
    for image in images:
        dataset_samples[image["id"]]["image"] = image

    keys = list(dataset_samples.keys())
    # Every sample needs its image entry; fail clearly instead of deep inside create_example.
    missing_images = [key for key in keys if "image" not in dataset_samples[key]]
    assert not missing_images, (
        f"{filename}: {len(missing_images)} annotated image ids have no 'images' entry "
        f"(e.g. {missing_images[:5]})"
    )

    # Shard count is derived from the samples actually iterated, so none are dropped.
    n_tfrecords = (len(keys) + n_samples - 1) // n_samples
    for tfrec_num in range(n_tfrecords):
        sample_keys = keys[tfrec_num * n_samples:(tfrec_num + 1) * n_samples]
        shard_path = os.path.join(split_dir, "file_%.2i-%i.tfrec" % (tfrec_num, len(sample_keys)))
        with tf.io.TFRecordWriter(shard_path) as writer:
            for key in sample_keys:
                example = create_example(key, dataset_samples[key])
                writer.write(example.SerializeToString())

../annotations/tfrecords/splits/001
../annotations/tfrecords/splits/005
../annotations/tfrecords/splits/010
../annotations/tfrecords/splits/025
../annotations/tfrecords/splits/050
