In [None]:
import os
import sys

# Make the local geopix package importable. Set the GEOPIX_PATH environment
# variable to point elsewhere instead of editing this hard-coded default.
GEOPIX_PATH = os.environ.get('GEOPIX_PATH', '/home/dan/utils/dspix/geopix/')
if GEOPIX_PATH not in sys.path:  # re-running the cell must not add duplicates
    sys.path.append(GEOPIX_PATH)

First, pair each image file with its associated GeoJSON file, which contains the building footprints from which the bounding boxes are derived.

In [None]:
import glob
import ogr
import re

tifs = glob.glob('AOI_5_Khartoum_Train/RGB-PanSharpen/*.tif')
geojsons = glob.glob('AOI_5_Khartoum_Train/geojson/buildings/*.geojson')

tif_ids = [int(re.findall(r'\d+', tif)[2]) for tif in tifs]
geojson_ids = [int(re.findall(r'\d+', js)[2]) for js in geojsons]

files = [[tifs[i], geojsons[geojson_ids.index(id_)]] for i, id_ in enumerate(tif_ids)]

files[:3]

In [None]:
# Keep `geo` referenced: the layer below is only usable while its data source is alive.
geo = ogr.Open(files[0][1])  # The first geojson
layer = geo.GetLayer()

# One envelope per feature, as returned by GetEnvelope: (minX, maxX, minY, maxY)
bboxes = [feature.GetGeometryRef().GetEnvelope() for feature in layer]

bboxes[:2]

To convert the projected coordinates to pixel coordinates we need the georeferencing of the source image (its geotransform, which relates pixel positions to projected coordinates). We also need to normalize the boxes so they are scaled between 0 and 1 (i.e. the *relative* position on the image). Repeating the previous step with the coordinate transform gives:

In [None]:
import geopix as gp
import gdal

image = gdal.Open(files[0][0])  # The first image

# Raster size in pixels, used to normalize the pixel boxes to [0, 1]
n_x, n_y = image.RasterXSize, image.RasterYSize

gt = image.GetGeoTransform()
gref = gp.Georeference(gt)

# The previous cell exhausted the layer's feature iterator; rewind it
layer.ResetReading()

# Pixel corners come out as (x0, y0, x1, y1); reorder to (ymin, xmin, ymax, xmax) and normalize
order = [1, 0, 3, 2]
norm = [n_y, n_x, n_y, n_x]
bboxes = []
for feature in layer:
    x_lo, x_hi, y_lo, y_hi = feature.GetGeometryRef().GetEnvelope()
    # Pixel coordinates have their origin at the top-left corner, so for a
    # north-up image (x_lo, y_hi) is the top-left and (x_hi, y_lo) the bottom-right
    corners = gref.world2pix([[x_lo, y_hi], [x_hi, y_lo]])
    bboxes.append(corners.ravel()[order] / norm)

In [None]:
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import numpy as np
%matplotlib notebook

def scale_to_8bit(pix):
    '''Rescale each band so that its maximum maps to 255.

    Args:
        pix: array of shape (bands, rows, cols), the layout GDAL's ReadAsArray
            returns for a multi-band raster. A 2-D (rows, cols) array is treated
            as a single band.

    Returns:
        float array with the same shape as ``pix``. Each band is divided by its
        own maximum and multiplied by 255; bands whose maximum is 0 come back as
        all zeros instead of NaN.
    '''
    pix = np.asarray(pix, dtype=float)
    if pix.ndim == 2:
        # Single band: scale the whole image, not each row separately
        return scale_to_8bit(pix[np.newaxis])[0]
    if pix.size == 0:
        return np.zeros(pix.shape, dtype=float)

    # One maximum per band, kept broadcastable against that band's pixels
    band_max = pix.max(axis=tuple(range(1, pix.ndim)), keepdims=True)
    # A zero maximum would give 0/0 = NaN, which cannot be meaningfully cast to uint8 later
    scaled = np.divide(pix, band_max, out=np.zeros_like(pix), where=band_max != 0)
    return scaled * 255

# Rescale every band, then keep the first three in (rows, cols, bands) order for display;
# GDAL returns bands first while matplotlib expects channels last.
bands = scale_to_8bit(image.ReadAsArray())
tcc = np.transpose(bands[:3], (1, 2, 0))

def draw_boxes(bboxes, ax, nx=None, ny=None):
    '''Draw normalized bounding boxes as red, unfilled rectangles on an axis.

    Args:
        bboxes: iterable of boxes (ymin, xmin, ymax, xmax), normalized to the
            image size.
        ax: matplotlib axis to draw on.
        nx: image width in pixels used to scale the x coordinates. If omitted,
            it is taken from the most recent image shown on ``ax``.
        ny: image height in pixels used to scale the y coordinates. If omitted,
            it is taken from the most recent image shown on ``ax``.

    Raises:
        ValueError: if ``nx``/``ny`` are omitted and ``ax`` shows no image.
    '''
    if nx is None or ny is None:
        # Scale by the image actually on the axis rather than by whatever the
        # global n_x/n_y happen to hold at call time
        if not ax.images:
            raise ValueError('pass nx and ny, or call ax.imshow(...) before draw_boxes')
        rows, cols = ax.images[-1].get_array().shape[:2]
        nx = cols if nx is None else nx
        ny = rows if ny is None else ny

    for box in bboxes:
        ymin, xmin, ymax, xmax = np.asarray(box, dtype=float)
        rect = patches.Rectangle(
            (xmin * nx, ymin * ny),   # anchor at the box's top-left in pixel space
            (xmax - xmin) * nx,       # width
            (ymax - ymin) * ny,       # height
            lw=1, ec='r', fc='none')
        ax.add_patch(rect)

# Overlay the normalized boxes on the first image to check the coordinate transform
fig, ax = plt.subplots()
ax.imshow(tcc.astype(int))
draw_boxes(bboxes, ax)
ax.set_title('Building bounding boxes on the first image');

In [None]:
def geojson_to_bboxes(geojson_fn, georef, nx, ny):
    '''Read the features of a vector file and return their normalized bounding boxes.

    Each feature's envelope is transformed from world (projected) coordinates
    to pixel coordinates with ``georef.world2pix`` and divided by the image
    size, giving boxes relative to the image.

    Args:
        geojson_fn: path to a vector file OGR can open (here a .geojson).
        georef: object with a ``world2pix`` method mapping [[x, y], ...] world
            coordinates to pixel coordinates (a ``geopix.Georeference`` here).
        nx: image width in pixels (RasterXSize).
        ny: image height in pixels (RasterYSize).

    Returns:
        list of numpy arrays ``[ymin, xmin, ymax, xmax]`` normalized by the
        image size; an empty list when the file has no features.

    Raises:
        IOError: if OGR cannot open ``geojson_fn``.
    '''
    # Keep the data source referenced for the whole loop: an OGR layer is only
    # valid while its data source is alive.
    datasource = ogr.Open(geojson_fn)
    if datasource is None:
        raise IOError('OGR could not open {!r}'.format(geojson_fn))
    layer = datasource.GetLayer()

    scale = [ny, nx, ny, nx]
    bboxes = []
    for feat in layer:
        geom = feat.GetGeometryRef()
        if geom is None:  # a feature with a null geometry has no envelope
            continue
        min_x, max_x, min_y, max_y = geom.GetEnvelope()
        # Pixel coordinates have their origin at the top-left corner, so for a
        # north-up image (min_x, max_y) is the top-left and (max_x, min_y) the bottom-right
        bbox = georef.world2pix([[min_x, max_y], [max_x, min_y]]).ravel()  # minx, miny, maxx, maxy
        bboxes.append(bbox[[1, 0, 3, 2]] / scale)  # ymin, xmin, ymax, xmax

    return bboxes

np.stack(geojson_to_bboxes(files[0][1], gref, n_x, n_y))[:5]

Now write functions that store the image data (a `uint8` array with shape (rows, columns, bands), compressed to JPEG) together with the bounding boxes of each object in a `tf.train.Example`. See https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/using_your_own_dataset.md

In [None]:
import tensorflow as tf

In [None]:
def _int64(value):
    '''Wrap a single integer as an int64 tf.train.Feature.'''
    return _int64_list([value])

def _bytes(value):
    '''Wrap a single bytes value as a bytes tf.train.Feature.'''
    return _bytes_list([value])

def _float_list(value):
    '''Wrap an iterable of floats as a float tf.train.Feature.'''
    floats = tf.train.FloatList(value=value)
    return tf.train.Feature(float_list=floats)

def _int64_list(value):
    '''Wrap an iterable of integers as an int64 tf.train.Feature.'''
    ints = tf.train.Int64List(value=value)
    return tf.train.Feature(int64_list=ints)

def _bytes_list(value):
    '''Wrap an iterable of bytes values as a bytes tf.train.Feature.'''
    raw = tf.train.BytesList(value=value)
    return tf.train.Feature(bytes_list=raw)

def encode_sample(name, pixels, boxes):
    '''Creates a tf.Example proto from image data and boxes.

    Field names follow the TensorFlow Object Detection API tf.Example layout;
    every object is labelled as class 1, 'building'.

    Args:
        name: The image filename; stored as both image/source_id and image/filename.
        pixels: A numpy array of image data of shape (rows, columns, bands) with
            values in [0, 255]. Values are clipped to that range and cast to uint8
            before JPEG encoding (tf.image.encode_jpeg accepts 1 or 3 bands).
        boxes: Array-like of n boxes with shape (n, 4), each row
            (ymin, xmin, ymax, xmax) normalized to the image size. Coordinates are
            clipped to [0, 1]; an empty sequence of boxes is accepted.

    Returns:
        example: The tf.Example created.
    '''
    rows, cols, bands = pixels.shape

    # reshape(-1, 4) lets an empty box list through as a (0, 4) array
    boxes = np.asarray(boxes, dtype=np.float32).reshape(-1, 4)
    # Footprints extending past the raster edge would give coordinates outside [0, 1]
    boxes = np.clip(boxes, 0.0, 1.0)
    ymins, xmins, ymaxs, xmaxs = boxes.T
    classes_text = ['building'.encode('utf8')] * len(boxes)
    classes = [1] * len(boxes)

    # Clip before casting: astype(np.uint8) does not saturate out-of-range values
    pixels_u8 = np.clip(pixels, 0, 255).astype(np.uint8)
    encode_pixels_as_jpeg = tf.image.encode_jpeg(
        tf.convert_to_tensor(pixels_u8, dtype=tf.uint8)
    ).numpy()  # JPEG bytes; .numpy() requires eager execution

    example = tf.train.Example(
        features=tf.train.Features(
            feature={
                'image/source_id': _bytes(name.encode('utf8')),
                'image/filename': _bytes(name.encode('utf8')),
                'image/height': _int64(rows),
                'image/width': _int64(cols),
                'image/depth': _int64(bands),
                'image/encoded': _bytes(encode_pixels_as_jpeg),
                'image/format': _bytes('JPEG'.encode('utf8')),  # matches encode_jpeg above
                'image/object/bbox/xmin': _float_list(xmins.tolist()),
                'image/object/bbox/xmax': _float_list(xmaxs.tolist()),
                'image/object/bbox/ymin': _float_list(ymins.tolist()),
                'image/object/bbox/ymax': _float_list(ymaxs.tolist()),
                'image/object/class/text': _bytes_list(classes_text),
                'image/object/class/label': _int64_list(classes),
            }))

    return example

Putting it all together: serialize the data to TFRecord files, starting a new file every 200 images.

In [None]:
import os

out_path_template = 'AOI_5_Khartoum_Train/train_records/aoi5_khartoum_{:04d}.tfrecords'
records_per_file = 200  # start a new shard every this many input images

# Make sure the output directory exists before any writer is opened
os.makedirs(os.path.dirname(out_path_template), exist_ok=True)

writer = None
try:
    for i, (image_fn, geojson_fn) in enumerate(files):
        # Shards are named after the index of their first image
        if i % records_per_file == 0:
            if writer is not None:
                writer.close()
            writer = tf.io.TFRecordWriter(out_path_template.format(i))

        image = gdal.Open(image_fn)
        n_x = image.RasterXSize
        n_y = image.RasterYSize
        gref = gp.Georeference(image.GetGeoTransform())

        # Parse the boxes; images without any buildings are not written
        boxes = geojson_to_bboxes(geojson_fn, gref, n_x, n_y)
        if len(boxes) == 0:
            continue
        boxes = np.stack(boxes).astype('float32')

        # Name each record after its own image file, without directory or extension
        name = os.path.splitext(os.path.basename(image_fn))[0]

        # Load and scale the pixel data
        pixels = scale_to_8bit(image.ReadAsArray())
        pixels = pixels.transpose(1, 2, 0)  # GDAL is (bands, rows, cols); encode_sample wants (rows, cols, bands)

        # Serialize and write to file
        record = encode_sample(name, pixels, boxes)
        writer.write(record.SerializeToString())
finally:
    # Closes the last shard, and also the current one if the loop raised
    if writer is not None:
        writer.close()
    

In [None]:
# Disk usage of the dataset directory, which includes the train_records shards written above
!du -h AOI_5_Khartoum_Train/

In [None]:
# List the written shards with human-readable sizes, sorted by access time (-u -t)
!ls -hault AOI_5_Khartoum_Train/train_records

In [None]:
#!rm -f AOI_5_Khartoum_Train/train_records/*

Check the data by reading a record back from file.

In [None]:
# Shards are named after the index of their first image, so this one exists
# only if more than 1000 images were serialized above.
record_fn = 'AOI_5_Khartoum_Train/train_records/aoi5_khartoum_1000.tfrecords'
# Fail here with a clear message; otherwise a missing file only surfaces when the dataset is iterated
assert os.path.exists(record_fn), '{} does not exist; choose another shard'.format(record_fn)
raw_dataset = tf.data.TFRecordDataset(record_fn)
raw_dataset

In [None]:
# Create a dictionary describing the features.
# How to parse each serialized tf.train.Example written by encode_sample.
# Per-object fields hold one value per box, so they are parsed with
# VarLenFeature and come back as SparseTensors.
image_feature_description = {
    'image/encoded': tf.io.FixedLenFeature([], tf.string),
    'image/filename': tf.io.FixedLenFeature([], tf.string),
    'image/source_id': tf.io.FixedLenFeature([], tf.string),
    'image/format': tf.io.FixedLenFeature([], tf.string),
    'image/height': tf.io.FixedLenFeature([], tf.int64),
    'image/width': tf.io.FixedLenFeature([], tf.int64),
    'image/depth': tf.io.FixedLenFeature([], tf.int64),
    'image/object/bbox/xmin': tf.io.VarLenFeature(tf.float32),
    'image/object/bbox/xmax': tf.io.VarLenFeature(tf.float32),
    'image/object/bbox/ymin': tf.io.VarLenFeature(tf.float32),
    'image/object/bbox/ymax': tf.io.VarLenFeature(tf.float32),
    'image/object/class/text': tf.io.VarLenFeature(tf.string),
    'image/object/class/label': tf.io.VarLenFeature(tf.int64),
}

# NOTE: not used below. FixedLenSequenceFeature is only accepted by
# tf.io.parse_single_example / parse_example when allow_missing=True.
sequence_features = {
    'image/object/bbox/xmin': tf.io.FixedLenSequenceFeature([], tf.float32),
    'image/object/bbox/xmax': tf.io.FixedLenSequenceFeature([], tf.float32),
    'image/object/bbox/ymin': tf.io.FixedLenSequenceFeature([], tf.float32),
    'image/object/bbox/ymax': tf.io.FixedLenSequenceFeature([], tf.float32),
}

def _parse_image_function(example_proto):
    '''Parse one serialized tf.train.Example using image_feature_description.'''
    return tf.io.parse_single_example(example_proto, image_feature_description)

parsed_image_dataset = raw_dataset.map(_parse_image_function)
parsed_image_dataset

In [None]:
import IPython.display as display

# Take just the first record; list(...) would parse every record in the shard first
image_feature = next(iter(parsed_image_dataset))
height = int(image_feature['image/height'].numpy())
width = int(image_feature['image/width'].numpy())
depth = int(image_feature['image/depth'].numpy())

# decode_jpeg already returns a (height, width, channels) uint8 array; no reshape needed
image = tf.image.decode_jpeg(image_feature['image/encoded']).numpy()
assert image.shape == (height, width, depth), image.shape

# VarLenFeature fields parse to SparseTensors; densify each coordinate list
ymins, xmins, ymaxs, xmaxs = (
    tf.sparse.to_dense(image_feature['image/object/bbox/' + coord], default_value=0).numpy()
    for coord in ('ymin', 'xmin', 'ymax', 'xmax'))

bs = np.column_stack([ymins, xmins, ymaxs, xmaxs])

# draw_boxes may scale by the globals n_x/n_y, which otherwise still hold the
# size of the last image written above; point them at this image instead
n_x, n_y = width, height

fig, ax = plt.subplots()
ax.imshow(image)
draw_boxes(bs, ax)
ax.set_title(image_feature['image/filename'].numpy().decode('utf8'));