In [None]:
import itertools
import time
import zipfile
import shutil
import json
import os
import sys
import logging
import glob

import numpy
import tensorflow as tf

# Installed in addition

#from osgeo import gdal
#import shapely.wkb
#import shapely.prepared
#import requests
#from retrying import retry

In [4]:
# Dataset paths
# NOTE(review): these paths are relative to the kernel's working directory,
# so the notebook must be started from its own folder.

path_to_charlie_root = "../../.."
NOT_A_DAM_IMAGE_DIR = os.path.join(
    path_to_charlie_root, "data/imagery-6-7-2019/not_a_dam_images")
DAM_IMAGE_DIR = os.path.join(
    path_to_charlie_root, "data/imagery-6-7-2019/dam_images")

# World country borders; downloaded into WORKSPACE_DIR and used to find
# dams located in South Africa.
TM_WORLD_BORDERS_URL = 'https://storage.googleapis.com/ecoshard-root/ipbes/TM_WORLD_BORDERS_SIMPL-0.3_md5_15057f7b17752048f9bd2e2e607fe99c.zip'

# Fail fast if either imagery directory is missing (or is not a directory).
if not os.path.isdir(NOT_A_DAM_IMAGE_DIR):
    raise ValueError("can't find %s" % NOT_A_DAM_IMAGE_DIR)
if not os.path.isdir(DAM_IMAGE_DIR):
    raise ValueError("can't find %s" % DAM_IMAGE_DIR)

# Scratch directory for downloads and the generated .record files.
WORKSPACE_DIR = os.path.join(
    path_to_charlie_root, "data/making_TFRecords_workspace")
    

#### Parameters

In [None]:
# Probability that an image not held out for South Africa goes to the
# validation split; each image is drawn independently with numpy.random.
dev_set_portion = 0.2
# Number of images (dam and not-a-dam together) written per TFRecord shard.
DAMS_PER_RECORD = 500


In [None]:
# GDAL cache limit of 2**30 (bytes per the GDAL API, i.e. 1 GiB).
# NOTE(review): `gdal` is only imported on a commented-out line in the imports
# cell; uncomment `from osgeo import gdal` there or this raises NameError.
gdal.SetCacheMax(1 << 30)

# Verbose logging to stdout so messages show up in the notebook output,
# with timestamps and call-site information on every line.
logging.basicConfig(
    stream=sys.stdout,
    level=logging.DEBUG,
    format='%(asctime)s (%(relativeCreated)d) %(levelname)s %(name)s [%(funcName)s:%(lineno)d] %(message)s',
)
LOGGER = logging.getLogger(__name__)

# Timeout (seconds) passed to requests.get by download_url_to_file.
REQUEST_TIMEOUT = 1.0


def int64_feature(value):
    """Return a tf.train.Feature whose Int64List holds the single ``value``."""
    packed = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=packed)

def int64_list_feature(value):
    """Return a tf.train.Feature whose Int64List holds every item of ``value``."""
    packed = tf.train.Int64List(value=value)
    return tf.train.Feature(int64_list=packed)

def bytes_feature(value):
    """Return a tf.train.Feature whose BytesList holds the single ``value``."""
    packed = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=packed)

def bytes_list_feature(value):
    """Return a tf.train.Feature whose BytesList holds every item of ``value``."""
    packed = tf.train.BytesList(value=value)
    return tf.train.Feature(bytes_list=packed)

def float_list_feature(value):
    """Return a tf.train.Feature whose FloatList holds every item of ``value``."""
    packed = tf.train.FloatList(value=value)
    return tf.train.Feature(float_list=packed)


In [None]:
### Save aside South Africa
# Builds `sa_geom_prep`, a prepared shapely geometry of South Africa's border.
# `main` uses it to route dam images whose centroid falls inside it to a
# separate South Africa record set.
# NOTE(review): this cell calls `download_url_to_file`, which is defined in a
# later cell, and uses `gdal`/`shapely`, whose imports are commented out in the
# imports cell. Run/enable those first.

tm_world_borders_zip_path = os.path.join(
    WORKSPACE_DIR, os.path.basename(TM_WORLD_BORDERS_URL))
tm_world_borders_vector_path = os.path.join(
    WORKSPACE_DIR, 'TM_WORLD_BORDERS-0.3.shp')

if not os.path.exists(tm_world_borders_zip_path):
    download_url_to_file(TM_WORLD_BORDERS_URL, tm_world_borders_zip_path)
# Extract whenever the shapefile is missing, not only right after a fresh
# download, so an earlier interrupted extraction gets repaired.
if not os.path.exists(tm_world_borders_vector_path):
    with zipfile.ZipFile(tm_world_borders_zip_path, 'r') as zip_ref:
        zip_ref.extractall(WORKSPACE_DIR)

tm_world_borders_vector = gdal.OpenEx(
    tm_world_borders_vector_path, gdal.OF_VECTOR)
if tm_world_borders_vector is None:
    raise ValueError("can't open %s" % tm_world_borders_vector_path)
tm_world_borders_layer = tm_world_borders_vector.GetLayer()

# Reset first, so a re-run can never silently reuse a geometry from an
# earlier run.
sa_geom_prep = None
for border_feature in tm_world_borders_layer:
    if border_feature.GetField('NAME') == 'South Africa':
        sa_geom = border_feature.GetGeometryRef()
        sa_geom_prep = shapely.prepared.prep(
            shapely.wkb.loads(sa_geom.ExportToWkb()))
        break
if sa_geom_prep is None:
    raise ValueError(
        "no 'South Africa' feature in %s" % tm_world_borders_vector_path)
LOGGER.debug(sa_geom_prep)

In [5]:
# (Removed a leftover `download_url_to_file?` help lookup. It ran before the
# cell that defines the function, so it only printed "not found".)


In [None]:
def _labelled_image_paths():
    """Return ``(path, label)`` pairs alternating dam and not-a-dam images.

    Both directory listings are sorted, so the pairing and the shard contents
    are reproducible. Images are paired one-for-one with ``zip``. Surplus
    images of the larger class are dropped, and the imbalance is logged.
    """
    dam_paths = sorted(glob.glob(os.path.join(DAM_IMAGE_DIR, '*pped.png')))
    not_a_dam_paths = sorted(
        glob.glob(os.path.join(NOT_A_DAM_IMAGE_DIR, '*.png')))
    if len(dam_paths) != len(not_a_dam_paths):
        LOGGER.warning(
            'unbalanced classes: %d dam vs %d not_a_dam images; using %d of '
            'each', len(dam_paths), len(not_a_dam_paths),
            min(len(dam_paths), len(not_a_dam_paths)))
    labelled = []
    for dam_path, not_a_dam_path in zip(dam_paths, not_a_dam_paths):
        # Label each path when it is paired, so labels cannot drift out of
        # step with the paths. A per-shard cycle() labelling only stayed
        # aligned when DAMS_PER_RECORD was even.
        labelled.append((dam_path, b'dam'))
        labelled.append((not_a_dam_path, b'not_a_dam'))
    return labelled


def _image_example(sess, png_bytes, png_shape, image_path, dam_type):
    """Build the ``tf.train.Example`` for one image.

    ``png_bytes`` is a scalar string placeholder and ``png_shape`` is the
    shape of its decoded PNG; both live in ``sess``'s graph.

    Returns:
        tuple: ``(example, centroid)``. ``centroid`` is ``lng_lat_centroid``
        from the dam's sidecar JSON, or ``None`` for not-a-dam images.
    """
    with open(image_path, 'rb') as image_file:
        image_string = image_file.read()
    decoded_shape = sess.run(png_shape, feed_dict={png_bytes: image_string})
    # decode_png yields [height, width, channels].
    height, width = int(decoded_shape[0]), int(decoded_shape[1])
    feature_dict = {
        'image/height': int64_feature(height),
        'image/width': int64_feature(width),
        'image/filename': bytes_feature(bytes(image_path, 'utf8')),
        'image/source_id': bytes_feature(bytes(image_path, 'utf8')),
        'image/encoded': bytes_feature(image_string),
        'image/format': bytes_feature(b'png'),
    }
    centroid = None
    if dam_type == b'dam':
        json_path = image_path.replace('.png', '.json')
        with open(json_path, 'r') as json_file:
            image_metadata = json.load(json_file)
        bbox = image_metadata['pixel_bounding_box']
        # Normalise x by the image width and y by the height. sorted() keeps
        # each pair ordered as (min, max).
        xmin, xmax = sorted((bbox[0] / float(width), bbox[2] / float(width)))
        ymin, ymax = sorted(
            (bbox[3] / float(height), bbox[1] / float(height)))
        if xmin < 0 or ymin < 0 or xmax > 1 or ymax > 1:
            LOGGER.warning(
                'bounding box out of bounds %s %s %s %s',
                xmin, xmax, ymin, ymax)
            xmin, ymin = max(0.0, xmin), max(0.0, ymin)
            xmax, ymax = min(xmax, 1.0), min(ymax, 1.0)
        feature_dict.update({
            'image/object/bbox/xmin': float_list_feature([xmin]),
            'image/object/bbox/xmax': float_list_feature([xmax]),
            'image/object/bbox/ymin': float_list_feature([ymin]),
            'image/object/bbox/ymax': float_list_feature([ymax]),
            'image/object/class/label': int64_list_feature(
                [1]),  # the '1' is type 1 which is a dam
            'image/object/class/text': bytes_list_feature([b'dam']),
        })
        centroid = image_metadata['lng_lat_centroid']
    example = tf.train.Example(
        features=tf.train.Features(feature=feature_dict))
    return example, centroid


def main():
    """Entry point: write the dam / not-a-dam imagery as sharded TFRecords.

    Dam and not-a-dam PNGs are interleaved one-for-one and written in shards
    of ``DAMS_PER_RECORD`` images. A dam image whose ``lng_lat_centroid`` lies
    inside ``sa_geom_prep`` goes to ``south_africa_<i>.record``. Every other
    image goes to ``dam_training_<i>.record`` or ``dam_validation_<i>.record``,
    chosen by a random draw against ``dev_set_portion``. Record files are
    written under ``WORKSPACE_DIR``; per-split counts go to
    ``write_stats.txt`` in the current directory.
    """
    import contextlib

    os.makedirs(WORKSPACE_DIR, exist_ok=True)

    training_log = {b'dam': 0, b'not_a_dam': 0}
    validation_log = {b'dam': 0, b'not_a_dam': 0}
    south_africa_log = {b'dam': 0, b'not_a_dam': 0}

    labelled_paths = _labelled_image_paths()
    last_time = time.time()

    # One graph and session for the whole run. The PNG decode op is built
    # once and fed each file's bytes, instead of adding new read/decode ops to
    # the graph for every image.
    with tf.Graph().as_default(), tf.Session() as sess:
        png_bytes = tf.placeholder(tf.string, shape=[])
        png_shape = tf.shape(tf.image.decode_png(png_bytes))

        shard_starts = range(0, len(labelled_paths), DAMS_PER_RECORD)
        for tf_record_iteration, shard_start in enumerate(shard_starts):
            dam_file_list = labelled_paths[
                shard_start:shard_start + DAMS_PER_RECORD]
            LOGGER.debug(dam_file_list)

            with contextlib.ExitStack() as stack:
                # closing() guarantees every writer is closed, even on error.
                training_writer, validation_writer, south_africa_writer = [
                    stack.enter_context(contextlib.closing(
                        tf.python_io.TFRecordWriter(os.path.join(
                            WORKSPACE_DIR,
                            '%s_%d.record' % (prefix, tf_record_iteration)))))
                    for prefix in (
                        'dam_training', 'dam_validation', 'south_africa')]

                for image_path, dam_type in dam_file_list:
                    current_time = time.time()
                    if current_time - last_time > 5.0:
                        LOGGER.info('training_log: %s', training_log)
                        LOGGER.info('validation_log: %s', validation_log)
                        LOGGER.info('south_africa_log: %s', south_africa_log)
                        last_time = current_time

                    tf_record, centroid = _image_example(
                        sess, png_bytes, png_shape, image_path, dam_type)
                    # Only dams carry a centroid; those inside South Africa
                    # are held out, everything else is split randomly.
                    if centroid is not None and sa_geom_prep.contains(
                            shapely.geometry.Point(centroid[0], centroid[1])):
                        writer, log = south_africa_writer, south_africa_log
                    elif numpy.random.random() > dev_set_portion:
                        writer, log = training_writer, training_log
                    else:
                        writer, log = validation_writer, validation_log
                    writer.write(tf_record.SerializeToString())
                    log[dam_type] += 1

                LOGGER.info(
                    'training writer full creating %d instance',
                    tf_record_iteration)

    with open('write_stats.txt', 'w') as write_stats_file:
        for split_name, split_log in (
                ('validation', validation_log),
                ('training', training_log),
                ('south_africa', south_africa_log)):
            write_stats_file.write('%s: dam(%d) not_a_dam(%d)\n' % (
                split_name, split_log[b'dam'], split_log[b'not_a_dam']))


# Utils

In [None]:
def make_TFRecords(dam_file_list=(), tf_record_iteration=0,
                   training_log=None, validation_log=None,
                   south_africa_log=None):
    """Write one shard of labelled PNG images to TFRecord files.

    A ``b'dam'`` image whose sidecar-JSON ``lng_lat_centroid`` falls inside
    ``sa_geom_prep`` goes to ``south_africa_<i>.record``. Every other image
    goes to ``dam_training_<i>.record`` or ``dam_validation_<i>.record``,
    chosen by a random draw against ``dev_set_portion``. Files are written
    under ``WORKSPACE_DIR``.

    The shard's inputs and counters used to be free names taken from
    ``main``, which made every call fail with ``UnboundLocalError``. They are
    now explicit parameters.

    Parameters:
        dam_file_list (sequence): ``(image_path, dam_type)`` pairs, with
            ``dam_type`` either ``b'dam'`` or ``b'not_a_dam'``. When empty,
            nothing is written, so a bare call cannot truncate an existing
            shard.
        tf_record_iteration (int): shard index ``<i>`` used in the file names.
        training_log, validation_log, south_africa_log (dict or None):
            ``{b'dam': int, b'not_a_dam': int}`` counters incremented in
            place. A fresh zeroed dict is used for any left as ``None``.

    Returns:
        tuple: ``(training_log, validation_log, south_africa_log)``.
    """
    import contextlib

    def _zeroed(counts):
        # A fresh dict per call, never a shared mutable default.
        return counts if counts is not None else {b'dam': 0, b'not_a_dam': 0}

    training_log = _zeroed(training_log)
    validation_log = _zeroed(validation_log)
    south_africa_log = _zeroed(south_africa_log)
    logs = (training_log, validation_log, south_africa_log)
    if not dam_file_list:
        return logs

    last_time = time.time()
    with tf.Graph().as_default(), tf.Session() as sess:
        # Build the decode op once and feed each image's bytes, rather than
        # adding new read/decode ops to the graph for every image.
        png_bytes = tf.placeholder(tf.string, shape=[])
        png_shape = tf.shape(tf.image.decode_png(png_bytes))

        with contextlib.ExitStack() as stack:
            # closing() guarantees every writer is closed, even on error.
            training_writer, validation_writer, south_africa_writer = [
                stack.enter_context(contextlib.closing(
                    tf.python_io.TFRecordWriter(os.path.join(
                        WORKSPACE_DIR,
                        '%s_%d.record' % (prefix, tf_record_iteration)))))
                for prefix in (
                    'dam_training', 'dam_validation', 'south_africa')]

            for image_path, dam_type in dam_file_list:
                current_time = time.time()
                if current_time - last_time > 5.0:
                    LOGGER.info('training_log: %s', training_log)
                    LOGGER.info('validation_log: %s', validation_log)
                    LOGGER.info('south_africa_log: %s', south_africa_log)
                    last_time = current_time

                with open(image_path, 'rb') as image_file:
                    image_string = image_file.read()
                decoded_shape = sess.run(
                    png_shape, feed_dict={png_bytes: image_string})
                # decode_png yields [height, width, channels].
                height = int(decoded_shape[0])
                width = int(decoded_shape[1])
                feature_dict = {
                    'image/height': int64_feature(height),
                    'image/width': int64_feature(width),
                    'image/filename': bytes_feature(
                        bytes(image_path, 'utf8')),
                    'image/source_id': bytes_feature(
                        bytes(image_path, 'utf8')),
                    'image/encoded': bytes_feature(image_string),
                    'image/format': bytes_feature(b'png'),
                }
                centroid = None
                if dam_type == b'dam':
                    json_path = image_path.replace('.png', '.json')
                    with open(json_path, 'r') as json_file:
                        image_metadata = json.load(json_file)
                    bbox = image_metadata['pixel_bounding_box']
                    # Normalise x by the width and y by the height; sorted()
                    # keeps each pair ordered as (min, max).
                    xmin, xmax = sorted(
                        (bbox[0] / float(width), bbox[2] / float(width)))
                    ymin, ymax = sorted(
                        (bbox[3] / float(height), bbox[1] / float(height)))
                    if xmin < 0 or ymin < 0 or xmax > 1 or ymax > 1:
                        LOGGER.warning(
                            'bounding box out of bounds %s %s %s %s',
                            xmin, xmax, ymin, ymax)
                        xmin, ymin = max(0.0, xmin), max(0.0, ymin)
                        xmax, ymax = min(xmax, 1.0), min(ymax, 1.0)
                    feature_dict.update({
                        'image/object/bbox/xmin': float_list_feature([xmin]),
                        'image/object/bbox/xmax': float_list_feature([xmax]),
                        'image/object/bbox/ymin': float_list_feature([ymin]),
                        'image/object/bbox/ymax': float_list_feature([ymax]),
                        'image/object/class/label': int64_list_feature(
                            [1]),  # the '1' is type 1 which is a dam
                        'image/object/class/text': bytes_list_feature(
                            [b'dam']),
                    })
                    centroid = image_metadata['lng_lat_centroid']
                tf_record = tf.train.Example(
                    features=tf.train.Features(feature=feature_dict))

                if centroid is not None and sa_geom_prep.contains(
                        shapely.geometry.Point(centroid[0], centroid[1])):
                    writer, log = south_africa_writer, south_africa_log
                elif numpy.random.random() > dev_set_portion:
                    writer, log = training_writer, training_log
                else:
                    writer, log = validation_writer, validation_log
                writer.write(tf_record.SerializeToString())
                log[dam_type] += 1

            LOGGER.info(
                'training writer full creating %d instance',
                tf_record_iteration)
    return logs

In [9]:
def download_url_to_file(url, target_file_path):
    """Use requests to download a file.

    The body is streamed into ``<target_file_path>.part``. That file is
    renamed onto ``target_file_path`` only once the whole download has
    succeeded, so a failed download never leaves a truncated file at the
    target path. This matters because callers skip the download when the
    target already exists.

    Parameters:
        url (string): url to file.
        target_file_path (string): local path to download the file.

    Returns:
        None.

    Raises:
        Any exception from ``requests`` (including the HTTP error raised by
        ``raise_for_status`` for 4xx/5xx responses) or from writing the file.
        It is logged, the partial file is removed, and the exception is
        re-raised.
    """
    partial_path = target_file_path + '.part'
    try:
        response = requests.get(url, stream=True, timeout=REQUEST_TIMEOUT)
        try:
            # Fail on HTTP errors instead of saving the error page as the file.
            response.raise_for_status()
            target_dir = os.path.dirname(target_file_path)
            if target_dir:
                os.makedirs(target_dir, exist_ok=True)
            # Let urllib3 undo any gzip/deflate Content-Encoding while copying
            # the raw stream.
            response.raw.decode_content = True
            with open(partial_path, 'wb') as target_file:
                shutil.copyfileobj(response.raw, target_file)
        finally:
            response.close()
        os.replace(partial_path, target_file_path)
    except Exception:
        LOGGER.exception('download of %s to %s failed', url, target_file_path)
        if os.path.exists(partial_path):
            os.remove(partial_path)
        raise