In [1]:
import warnings
warnings.filterwarnings("ignore")
import hashlib
import io
import logging
import os
import random
import re
import cv2

import contextlib2

import numpy as np
import PIL.Image
import tensorflow as tf

from object_detection.dataset_tools import tf_record_creation_util
from object_detection.utils import dataset_util
from object_detection.utils import label_map_util

In [2]:
def dict_to_tf_example(filename,
                       mask_path,
                       label_map_dict,
                       img_path,
                       mask_pixel_map=None):
    """Build a ``tf.train.Example`` for one JPEG image and its PNG class mask.

    The mask is converted to 8-bit grayscale. For every class in
    ``mask_pixel_map`` the pixels equal to that class's gray value are
    selected, their contours are extracted with OpenCV, and one bounding box
    (normalized to [0, 1] by the image size) is emitted per contour together
    with a PNG-encoded 0/1 mask.

    Args:
        filename: Identifier stored in 'image/filename' and 'image/source_id'.
        mask_path: Path to the PNG segmentation mask.
        label_map_dict: Mapping from class name to integer label id.
        img_path: Path to the JPEG image.
        mask_pixel_map: Optional mapping from class name to the grayscale
            value that marks that class in the mask. When omitted, the
            notebook-level ``mask_pixel`` dict is used.

    Returns:
        A ``tf.train.Example`` holding the encoded image, boxes, class labels
        and (when at least one object was found) 'image/object/mask'.

    Raises:
        ValueError: If the image is not a JPEG or the mask is not a PNG.
        KeyError: If a class in ``mask_pixel_map`` is missing from
            ``label_map_dict`` and that class occurs in the mask.
    """
    if mask_pixel_map is None:
        # Backward compatibility: earlier callers rely on the global dict
        # defined in the configuration cell.
        mask_pixel_map = mask_pixel

    with tf.gfile.GFile(img_path, 'rb') as fid:
        encoded_jpg = fid.read()
    image = PIL.Image.open(io.BytesIO(encoded_jpg))
    # Validate the container format before doing any further work.
    if image.format != 'JPEG':
        raise ValueError('Image format not JPEG')
    # PIL reports (width, height); no need to decode the pixels to an array.
    width, height = image.size
    key = hashlib.sha256(encoded_jpg).hexdigest()

    with tf.gfile.GFile(mask_path, 'rb') as fid:
        encoded_mask_png = fid.read()
    mask = PIL.Image.open(io.BytesIO(encoded_mask_png))
    if mask.format != 'PNG':
        raise ValueError('Mask format not PNG')
    mask_np = np.asarray(mask.convert('L'))

    xmins = []
    ymins = []
    xmaxs = []
    ymaxs = []
    classes = []
    classes_text = []
    encoded_mask_png_list = []
    # These per-object attributes are not derived from the mask; they are
    # written as empty lists so the feature keys are still present.
    truncated = []
    poses = []
    difficult_obj = []

    for class_name, pixel_val in mask_pixel_map.items():
        # Select ONLY the pixels of this class. Overwriting a copy of the mask
        # with 255 and thresholding would also pick up pixels that were
        # already 255 in the source mask.
        selected = mask_np == pixel_val
        if not selected.any():
            continue

        binary = selected.astype(np.uint8) * 255
        # findContours returns (contours, hierarchy) on OpenCV 2.x/4.x and
        # (image, contours, hierarchy) on 3.x; contours is always second to last.
        contours = cv2.findContours(
            binary, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2]

        # The mask is identical for every contour of this class, so encode it
        # once and reuse the bytes.
        # NOTE(review): every box of a class is paired with the whole-class
        # mask rather than a per-instance mask, and RETR_TREE also yields
        # nested (hole) contours as separate objects. Kept as-is; confirm
        # this is what the downstream model expects.
        png_buffer = io.BytesIO()
        PIL.Image.fromarray(selected.astype(np.uint8)).save(
            png_buffer, format='PNG')
        class_mask_png = png_buffer.getvalue()

        for contour in contours:
            x, y, w, h = cv2.boundingRect(contour)
            xmins.append(float(x) / width)
            ymins.append(float(y) / height)
            xmaxs.append(float(x + w) / width)
            ymaxs.append(float(y + h) / height)

            classes_text.append(class_name.encode('utf8'))
            classes.append(label_map_dict[class_name])
            encoded_mask_png_list.append(class_mask_png)

    feature_dict = {
        'image/height': dataset_util.int64_feature(height),
        'image/width': dataset_util.int64_feature(width),
        'image/filename': dataset_util.bytes_feature(
            filename.encode('utf8')),
        'image/source_id': dataset_util.bytes_feature(
            filename.encode('utf8')),
        'image/key/sha256': dataset_util.bytes_feature(key.encode('utf8')),
        'image/encoded': dataset_util.bytes_feature(encoded_jpg),
        'image/format': dataset_util.bytes_feature('jpeg'.encode('utf8')),
        'image/object/bbox/xmin': dataset_util.float_list_feature(xmins),
        'image/object/bbox/xmax': dataset_util.float_list_feature(xmaxs),
        'image/object/bbox/ymin': dataset_util.float_list_feature(ymins),
        'image/object/bbox/ymax': dataset_util.float_list_feature(ymaxs),
        'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
        'image/object/class/label': dataset_util.int64_list_feature(classes),
        'image/object/difficult': dataset_util.int64_list_feature(difficult_obj),
        'image/object/truncated': dataset_util.int64_list_feature(truncated),
        'image/object/view': dataset_util.bytes_list_feature(poses),
    }

    # As before, the mask feature is only present when at least one object
    # was found; it is now assigned once instead of on every iteration.
    if encoded_mask_png_list:
        feature_dict['image/object/mask'] = (
            dataset_util.bytes_list_feature(encoded_mask_png_list))

    example = tf.train.Example(features=tf.train.Features(feature=feature_dict))
    return example

In [3]:
def create_tf_record(output_filename,
                     num_shards,
                     label_map_dict,
                     annotations_dir,
                     image_dir,
                     examples):
    """Write one TFRecord entry per example, spread across ``num_shards`` files.

    For each name in ``examples`` the image ``<image_dir>/<name>.jpg`` and the
    mask ``<annotations_dir>/<name>.png`` are converted with
    ``dict_to_tf_example`` and written to shard ``idx % num_shards``.

    Args:
        output_filename: Base path of the sharded output TFRecord files.
        num_shards: Number of output shards.
        label_map_dict: Mapping from class name to integer label id.
        annotations_dir: Directory containing the PNG masks.
        image_dir: Directory containing the JPEG images.
        examples: Sequence of example names (file names without extension).

    Examples for which ``dict_to_tf_example`` raises ``ValueError`` are
    logged and skipped; any other exception propagates.
    """
    with contextlib2.ExitStack() as tf_record_close_stack:
        output_tfrecords = tf_record_creation_util.open_sharded_output_tfrecords(
            tf_record_close_stack, output_filename, num_shards)
        for idx, example in enumerate(examples):
            if idx % 100 == 0:
                logging.info('On image %d of %d', idx, len(examples))
            mask_path = os.path.join(annotations_dir, example + '.png')
            image_path = os.path.join(image_dir, example + '.jpg')
            try:
                tf_example = dict_to_tf_example(example,
                                                mask_path,
                                                label_map_dict,
                                                image_path)
                if tf_example:
                    shard_idx = idx % num_shards
                    output_tfrecords[shard_idx].write(tf_example.SerializeToString())
                    print("done")
            except ValueError:
                # Report the example that actually failed; the previous code
                # logged an unrelated module-level path variable.
                logging.warning('Invalid example: %s, ignoring.', example)

In [5]:

# Grayscale value (after converting the mask PNG to mode 'L') that marks each
# lane-marking class in the annotation masks.
mask_pixel = {'solid white':38,'broken white':75,'solid yellow':113,'broken yellow':14,'crosswalk':128,'double yellow':52,'double white':89}
# Not used in this cell; kept because code in other cells may reference it.
xml_path='/Users/fionliang/bdd/labeltest/data_dataset_voc/xml'
# NOTE(review): absolute, machine-specific paths; consider deriving them from
# a single configurable data directory.
output_dir='/Users/fionliang/bdd/data/val1.record'
image_dir='/Users/fionliang/bdd/image/val_jpg'
annotations_dir='/Users/fionliang/bdd/image/val_png'
label_map_path='/Users/fionliang/bdd/data/label_map.pbtxt'
label_map_dict = label_map_util.get_label_map_dict(label_map_path)
num_shards=1

logging.info('Reading from dataset.')
# Keep only '.jpg' files and strip the extension. This is built in a single
# pass because deleting from a list while iterating over it skips the entry
# after each deletion. Sorted so the record order is reproducible, since
# os.listdir returns names in arbitrary order.
examples_list = sorted(
    os.path.splitext(name)[0]
    for name in os.listdir(image_dir)
    if name.endswith('.jpg')
)

create_tf_record(output_dir,
                 num_shards,
                 label_map_dict,
                 annotations_dir,
                 image_dir,
                 examples_list)

done
done
done
done
done
done
done
done
done
done
done
done
done
done
done
done
done
done
done
done
done
done
done
done
done
done
done
done
done
done
done
done
done
done
done
done
done
done
done
done
done
done
done
done
done
done
done
done
done
done
done
done
done
done
done
done
done
done
done
