In [None]:
import tensorflow as tf
from tensorflow.python.ops import init_ops
import numpy as np
from scipy import misc
#import matplotlib.pyplot as plt
from copy import deepcopy, copy
import random
import os
import logging
import pickle
import cv2
from datetime import datetime
from tqdm import tqdm_notebook
import tqdm
#%matplotlib inline
print(tf.__version__)

In [None]:
def get_cleaned_data(schema_root='files/grayschema/',
                     sat_root='files/mapbox.satellite/',
                     zoom='18'):
    """Scan the schema tiles and build a shuffled list of labelled tile pairs.

    Every file below ``<schema_root>/<zoom>/`` is read with OpenCV. Tiles whose
    pixels are all <= 128 are counted as empty and skipped. A remaining tile is
    kept only if a readable file with the same relative path exists under
    ``sat_root``.

    For each kept tile two samples are produced:
      * (satellite tile, its own schema tile) with label ``(1, 0)``
      * (satellite tile, schema tile of a *different* kept tile) with label ``(0, 1)``

    Args:
        schema_root: directory holding the schema (map) tiles.
        sat_root: directory holding the satellite tiles, laid out like
            ``schema_root``.
        zoom: sub-directory of ``schema_root`` to walk.

    Returns:
        Three parallel lists ``(sat_file_paths, schema_file_paths, labels)``
        in random order; each label is a 2-tuple.
    """
    empty = 0
    total = 0
    error = 0

    images = []
    # The trailing '' keeps the trailing separator, like the literal 'files/grayschema/18/'.
    walk_root = os.path.join(schema_root, zoom, '')
    for d, dirs, files in tqdm_notebook(list(os.walk(walk_root))):
        for f in files:
            total += 1
            schema_path = os.path.join(d, f)
            # Best-effort scan: a bad tile is reported and counted, not fatal.
            try:
                if not os.path.isfile(schema_path):
                    raise RuntimeError("{} doesn't exist".format(schema_path))
                img = cv2.imread(schema_path)
                # cv2.imread signals an unreadable file by returning None.
                if img is None:
                    raise RuntimeError("{} can't be decoded".format(schema_path))
                if (img <= 128).all():
                    empty += 1
                    continue
                relpath = os.path.relpath(schema_path, schema_root)
                sat_path = os.path.join(sat_root, relpath)
                if not os.path.isfile(sat_path):
                    raise RuntimeError("{} doesn't exist".format(sat_path))
                # Actually check the decode result so corrupt satellite tiles are rejected.
                if cv2.imread(sat_path) is None:
                    raise RuntimeError("{} can't be decoded".format(sat_path))
                images.append(relpath)
            except Exception as e:
                print(e)
                error += 1

    print("Total {}, empty {} {:.2%}, error {}".format(
        total,
        empty, (empty * 1.0 / total) if total else 0.0,
        error))

    # Pick the "wrong" schema tile for each image by walking one random cycle
    # through all images: every image is mapped to its successor in a shuffled
    # order, so (with more than one image) a tile is never paired with itself
    # and mislabelled as a mismatch.
    order = list(images)
    random.shuffle(order)
    wrong_for = {}
    if len(order) > 1:
        for i, p in enumerate(order):
            wrong_for[p] = order[(i + 1) % len(order)]

    data = []
    for p in images:
        sat_path = os.path.join(sat_root, p)
        data.append((sat_path, os.path.join(schema_root, p), 1, 0))
        wp = wrong_for.get(p)
        if wp is not None:
            data.append((sat_path, os.path.join(schema_root, wp), 0, 1))
    random.shuffle(data)

    sat_file_paths = [item[0] for item in data]
    schema_file_paths = [item[1] for item in data]
    labels = [item[2:4] for item in data]

    return sat_file_paths, schema_file_paths, labels

In [None]:
# Build the labelled pair lists once and cache them on disk so later sessions
# can skip the (slow) scan of the tile directories.
cleaned = get_cleaned_data()
sat_file_paths, schema_file_paths, labels = cleaned
with open("cleaned_data.pickle", "wb+") as cache_file:
    pickle.dump(cleaned, cache_file)

In [None]:
# Reload the cached pair lists written by the previous cell.
# NOTE: unpickling can execute arbitrary code; only load a cache file that
# was produced locally by this notebook.
with open("cleaned_data.pickle", "rb") as cache_file:
    cached = pickle.load(cache_file)
sat_file_paths, schema_file_paths, labels = cached

In [None]:
def splittvt(arr, frac):
    """Split a sequence into contiguous train / validation / test slices.

    The last ``frac`` of ``arr`` becomes the test slice, the ``frac`` before it
    the validation slice, and everything in front of that the training slice.
    No shuffling is done here.

    Args:
        arr: any sliceable sequence (list, tuple, string, ...).
        frac: fraction of the data used for validation and, separately, for
            test; must lie in ``[0, 0.5]``.

    Returns:
        A ``(train, valid, test)`` tuple of slices of ``arr``.

    Raises:
        ValueError: if ``frac`` is outside ``[0, 0.5]``. Beyond 0.5 the
            validation start index would turn negative and the training slice
            would silently overlap the test slice.
    """
    if not 0 <= frac <= 0.5:
        raise ValueError("frac must be between 0 and 0.5, got {!r}".format(frac))
    count = len(arr)
    valid_start = int(count * (1-2*frac))
    test_start = int(count * (1-frac))
    return arr[:valid_start], arr[valid_start: test_start], arr[test_start:]

# Fraction of the samples reserved for validation and, separately, for test.
TST_FRAC = 0.2

# The three lists are parallel, so cutting each of them at the same positions
# keeps every sample's satellite path, map path and label in the same subset.
(
    (train_sat_paths, valid_sat_paths, test_sat_paths),
    (train_map_paths, valid_map_paths, test_map_paths),
    (train_labels, valid_labels, test_labels),
) = (
    splittvt(column, TST_FRAC)
    for column in (sat_file_paths, schema_file_paths, labels)
)
print(len(train_sat_paths), len(valid_sat_paths), len(test_sat_paths))

In [None]:
# Height and width (in pixels) the stacked input tensor is reshaped to when the
# graph is built.
IMG_SIZE = 256
# Samples per step; the graph's placeholders are created with this fixed size.
batch_size = 32
# Base number of convolution filters; the network layers use multiples of it.
DF_DIM = 64

def read_img(path, f):
    """Build the ops that load one image file and scale it to roughly [-1, 1).

    Args:
        path: scalar string tensor holding the file path.
        f: ``'png'`` to decode a single-channel PNG; any other value decodes
            the file as a JPEG.

    Returns:
        A float32 tensor ``pixel / 128.0 - 1`` of shape (height, width, 1) for
        PNG input and (height, width, 3) for JPEG input.
    """
    contents = tf.read_file(path)
    if f == 'png':
        pixels = tf.image.decode_png(contents, channels=1)
    else:
        # Request 3 channels explicitly: without it the channel dimension is
        # unknown at graph-build time and a grayscale JPEG would decode to a
        # single channel, breaking the 1 + 3 channel stacking done later.
        pixels = tf.image.decode_jpeg(contents, channels=3)
    return -1 + tf.cast(pixels, tf.float32)/128.0

In [None]:
def encode(data, num_outputs, stride, is_training):
    """Down-sampling stage: 5x5 convolution -> batch norm -> leaky ReLU.

    Args:
        data: input tensor for the convolution.
        num_outputs: number of convolution filters (output channels).
        stride: convolution stride.
        is_training: bool or scalar bool tensor. When true, batch
            normalization uses the statistics of the current batch and updates
            its moving averages; when false it uses the moving averages.

    Returns:
        The activated output tensor.
    """
    d = tf.layers.Conv2D(
            num_outputs,
            (5, 5),
            strides=stride,
            padding='SAME'
        )(data)
    batch_norm = tf.layers.BatchNormalization()
    # The layer defaults to training=False, so the flag must be passed or the
    # layer always runs in inference mode.
    d = batch_norm(d, training=is_training)
    # Tie the moving-average updates to the output so they run whenever the
    # layer is evaluated in training mode, without requiring the caller to
    # wire the UPDATE_OPS collection into the train op.
    with tf.control_dependencies(batch_norm.updates):
        d = tf.identity(d)
    d = tf.nn.leaky_relu(d)
    return d


def decode(data, num_outputs, stride, is_training):
    """Up-sampling stage: 5x5 transposed convolution -> batch norm -> leaky ReLU.

    Args:
        data: input tensor for the transposed convolution.
        num_outputs: number of filters (output channels).
        stride: stride of the transposed convolution.
        is_training: bool or scalar bool tensor. When true, batch
            normalization uses the statistics of the current batch and updates
            its moving averages; when false it uses the moving averages.

    Returns:
        The activated output tensor.
    """
    # The layer constructor takes the filter count first; the input tensor is
    # only supplied when the layer object is called.
    d = tf.layers.Conv2DTranspose(
            num_outputs,
            (5, 5),
            strides=stride,
            padding='SAME'
        )(data)
    batch_norm = tf.layers.BatchNormalization()
    # The layer defaults to training=False, so the flag must be passed or the
    # layer always runs in inference mode.
    d = batch_norm(d, training=is_training)
    # Tie the moving-average updates to the output so they run whenever the
    # layer is evaluated in training mode, without requiring the caller to
    # wire the UPDATE_OPS collection into the train op.
    with tf.control_dependencies(batch_norm.updates):
        d = tf.identity(d)
    d = tf.nn.leaky_relu(d)
    return d

    
def generator(data, is_training):
    """Encoder/decoder network with skip connections between mirrored stages.

    Eight stride-2 ``encode`` stages are followed by eight stride-2 ``decode``
    stages. From the second decoder stage on, each decoder input is the
    previous decoder output concatenated (channel axis 3) with the encoder
    output of the same resolution.

    The shape comments below assume a 256 x 256 input and DF_DIM == 64.

    Args:
        data: 4-D input tensor; axis 3 is treated as the channel axis.
        is_training: forwarded unchanged to every encode/decode stage.

    Returns:
        The output of the last decode stage (3 channels).
    """
    enc_1 = encode(data, DF_DIM, 2, is_training)      # 128 x 128 x 64
    enc_2 = encode(enc_1, DF_DIM * 2, 2, is_training)  # 64 x 64 x 128
    enc_3 = encode(enc_2, DF_DIM * 4, 2, is_training)  # 32 x 32 x 256
    enc_4 = encode(enc_3, DF_DIM * 8, 2, is_training)  # 16 x 16 x 512
    enc_5 = encode(enc_4, DF_DIM * 8, 2, is_training)  # 8 x 8 x 512
    enc_6 = encode(enc_5, DF_DIM * 8, 2, is_training)  # 4 x 4 x 512
    enc_7 = encode(enc_6, DF_DIM * 8, 2, is_training)  # 2 x 2 x 512
    enc_8 = encode(enc_7, DF_DIM * 8, 2, is_training)  # 1 x 1 x 512

    dec_7 = decode(enc_8, DF_DIM * 8, 2, is_training)  # 2 x 2 x 512

    dec_6 = decode(tf.concat([dec_7, enc_7], 3), DF_DIM * 8, 2, is_training)  # 4 x 4 x 512
    dec_5 = decode(tf.concat([dec_6, enc_6], 3), DF_DIM * 8, 2, is_training)  # 8 x 8 x 512
    dec_4 = decode(tf.concat([dec_5, enc_5], 3), DF_DIM * 8, 2, is_training)  # 16 x 16 x 512
    dec_3 = decode(tf.concat([dec_4, enc_4], 3), DF_DIM * 4, 2, is_training)  # 32 x 32 x 256
    dec_2 = decode(tf.concat([dec_3, enc_3], 3), DF_DIM * 2, 2, is_training)  # 64 x 64 x 128
    dec_1 = decode(tf.concat([dec_2, enc_2], 3), DF_DIM , 2, is_training)  # 128 x 128 x 64
    # The final stage needs all four arguments: 3 output channels, stride 2.
    # NOTE(review): this output still passes through decode's batch norm and
    # leaky ReLU -- confirm that is intended for the last layer.
    result = decode(tf.concat([dec_1, enc_1], 3), 3, 2, is_training)  # 256 x 256 x 3

    return result

def discriminator(data, is_training):
    """Classify a stacked image tensor into two classes.

    Four ``encode`` stages (filter counts DF_DIM, 2x, 4x and 8x DF_DIM; the
    first three with stride 2, the last with stride 1) are followed by a
    flatten and a dense layer with two outputs.

    Args:
        data: input tensor handed to the first encode stage.
        is_training: forwarded unchanged to every encode stage.

    Returns:
        The two un-normalised class scores (logits) per sample.
    """
    # (number of filters, stride) for each convolutional stage, in order.
    stages = (
        (DF_DIM, 2),
        (DF_DIM * 2, 2),
        (DF_DIM * 4, 2),
        (DF_DIM * 8, 1),
    )

    features = data
    for width, stride in stages:
        features = encode(features, width, stride, is_training)

    flattened = tf.layers.Flatten()(features)
    return tf.layers.Dense(2)(flattened)


# Build the whole model in its own graph; the later cells open their sessions
# on this graph object.
graph = tf.Graph()
with graph.as_default():
    # Training flag: True unless a feed overrides it (the evaluation code
    # feeds False).
    tf_is_training = tf.placeholder_with_default(True, shape=())
    # One batch of file paths for the map and satellite tiles, and the
    # matching two-element labels ((1, 0) = matching pair, (0, 1) = mismatch).
    tf_map_files = tf.placeholder(tf.string, shape=[batch_size])
    tf_sat_files = tf.placeholder(tf.string, shape=[batch_size])
    tf_labels = tf.placeholder(tf.float32, shape=[batch_size, 2])
    
    # Images are read and decoded inside the graph, one read_img call per
    # batch slot, then stacked back into a batch.
    map_imgs = tf.stack([read_img(path, 'png') for path in tf.unstack(tf_map_files)])
    print("Map imgs, shape", map_imgs.shape)
    sat_imgs = tf.stack([read_img(path, 'jpg') for path in tf.unstack(tf_sat_files)])
    print("Sat imgs, shape", sat_imgs.shape)

    # Join map and satellite images along the channel axis and pin the static
    # shape. The 4 presumably is 1 map channel + 3 satellite channels --
    # TODO confirm every satellite JPEG decodes to 3 channels.
    tf_data = tf.reshape(tf.concat([
            map_imgs,
            sat_imgs
        ],3), [batch_size, IMG_SIZE, IMG_SIZE, 4])
    
    print("Input shape: ", tf_data.shape)
    
    logits = discriminator(tf_data, tf_is_training)
    
    # Mean softmax cross-entropy between the logits and the two-element labels.
    loss = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits_v2(
            logits=logits, labels=tf_labels))
    tf.summary.scalar('loss', loss)
    global_step = tf.Variable(0)
    # Despite its name, `optimizer` holds the op returned by minimize(), i.e.
    # the training step; global_step is passed so minimize() advances it.
    optimizer = tf.train.AdamOptimizer(0.0001).minimize(loss, global_step=global_step)
    
    prediction = tf.nn.softmax(logits)
    # A sample is correct when its label's class index is the top-1 logit.
    correct = tf.nn.in_top_k(logits, tf.math.argmax(tf_labels, 1), 1)
    count_correct = tf.count_nonzero(correct)
    tf.summary.scalar("acc", count_correct/batch_size)
    # Single op that evaluates every summary registered above.
    merged = tf.summary.merge_all()
    saver = tf.train.Saver()

In [None]:
def doEpoch(session, epoch, writer):
    """Run one pass over the training data in whole batches.

    Only complete batches of ``batch_size`` samples are fed (the placeholders
    have a fixed batch dimension), so a trailing partial batch is skipped.

    Args:
        session: session to run the training step in.
        epoch: index of the current epoch (not used by the computation).
        writer: summary writer receiving the merged summaries of every step.

    Returns:
        ``(last_loss, mean_loss, correct)``: loss of the last batch, mean loss
        over all batches, and the number of correctly classified samples. If
        there are fewer samples than one batch, both losses are NaN and
        ``correct`` is 0.
    """
    train_loss = 0.0
    train_correct = 0
    last_loss = float('nan')
    # Same value as (len - batch_size) // batch_size + 1, i.e. the number of
    # complete batches.
    num_batches = len(train_map_paths) // batch_size
    for step in range(num_batches):
        offset = step * batch_size
        feed_dict={
            tf_map_files: train_map_paths[offset:offset+batch_size],
            tf_sat_files: train_sat_paths[offset:offset+batch_size],
            tf_labels: train_labels[offset:offset+batch_size]
        }
        _, last_loss, batch_correct, summaries, step_no = session.run(
            [optimizer, loss, count_correct, merged, global_step],
            feed_dict=feed_dict)
        train_loss += last_loss
        train_correct += batch_correct

        writer.add_summary(summaries, step_no)

    if num_batches == 0:
        # Nothing was trained on; report NaN instead of failing on an unbound
        # loss or a division by zero.
        return last_loss, float('nan'), train_correct
    return last_loss, train_loss / num_batches, train_correct
    
def doValidate(session, writer):
    """Measure accuracy on the validation data, evaluated in whole batches.

    Only complete batches of ``batch_size`` samples are fed (the placeholders
    have a fixed batch dimension), so a trailing partial batch is skipped. The
    accuracy is therefore computed over the samples that were actually
    evaluated, not over the full length of the validation list.

    Args:
        session: session to evaluate in.
        writer: summary writer receiving the merged summaries of every batch.

    Returns:
        Fraction of evaluated validation samples classified correctly, or NaN
        when there are fewer validation samples than one batch.
    """
    valid_correct = 0
    # Same value as (len - batch_size) // batch_size + 1, i.e. the number of
    # complete batches.
    num_batches = len(valid_map_paths) // batch_size
    for step in range(num_batches):
        offset = step * batch_size
        feed_dict={
            tf_map_files: valid_map_paths[offset:offset+batch_size],
            tf_sat_files: valid_sat_paths[offset:offset+batch_size],
            tf_labels: valid_labels[offset:offset+batch_size],
            tf_is_training: False
        }

        _, batch_correct, summaries, step_no = session.run(
            [loss, count_correct, merged, global_step], feed_dict=feed_dict)
        valid_correct += batch_correct
        # NOTE: global_step does not advance during validation, so every batch
        # of one validation pass is logged under the same step value.
        writer.add_summary(summaries, step_no)

    evaluated = num_batches * batch_size
    if evaluated == 0:
        return float('nan')
    return (1.0 * valid_correct) / evaluated

# Show TensorFlow log messages of level INFO and above.
tf.logging.set_verbosity(tf.logging.INFO)
# Ask TensorFlow to log the device each op is placed on.
# NOTE(review): this runs after the graph cell -- confirm it still takes
# effect for sessions created on that already-built graph.
tf.debugging.set_log_device_placement(True)

In [None]:
EPOCHS = 20

run_id="disk_1"

# Make sure the checkpoint directory exists up front: saving fails with
# "Parent directory ... doesn't exist" when it is missing, which would
# otherwise only surface after the first full epoch.
checkpoint_path = "checkpoints/checkpoint_{}".format(run_id)
if not os.path.isdir(os.path.dirname(checkpoint_path)):
    os.makedirs(os.path.dirname(checkpoint_path))

train_writer = tf.summary.FileWriter("./summary/nopipeline_train_{}".format(run_id), graph=graph)
valid_writer = tf.summary.FileWriter("./summary/nopipeline_valid_{}".format(run_id))

# doEpoch only feeds complete batches, so training accuracy is measured against
# the number of samples actually seen per epoch (at least 1 to avoid dividing
# by zero).
train_seen = max(1, (len(train_sat_paths) // batch_size) * batch_size)

# Defined before the loop so the final report also works when EPOCHS == 0.
train_loss = accuracy = float('NaN')
# Named train_correct (not `correct`) so the graph tensor `correct` created in
# the graph cell is not overwritten in the notebook namespace.
train_correct = 0

try:
    with tf.Session(graph=graph) as session:
        tf.global_variables_initializer().run()
        print(datetime.now(), "Initialized.")
        for epoch in tqdm_notebook(range(EPOCHS)):
            last_loss, train_loss, train_correct = doEpoch(session, epoch, train_writer)
            accuracy = doValidate(session, valid_writer)
            if epoch % 10 == 0:
                print(datetime.now(), "{} loss: {:.6f}, acc: {:.6f} {:.6f}".format(
                    epoch, train_loss, accuracy, 1.0*train_correct/train_seen))
            # Overwrite the checkpoint after every epoch.
            saver.save(session, checkpoint_path)
        print(datetime.now(),
              "Final loss: {:.6f}, acc: {:.6f} {:.6f}".format(
                  train_loss, accuracy, 1.0*train_correct/train_seen))
finally:
    # Flush and release the event files even if training is interrupted.
    train_writer.close()
    valid_writer.close()

In [None]:
# Restore the saved weights and inspect the predictions for the first training
# batch with the training flag switched off.
with tf.Session(graph=graph) as session:
    saver.restore(session, "checkpoints/checkpoint_{}".format(run_id))

    offset = 0
    first_batch = slice(offset, offset + batch_size)
    feed_dict = {
        tf_map_files: train_map_paths[first_batch],
        tf_sat_files: train_sat_paths[first_batch],
        tf_labels: train_labels[first_batch],
        tf_is_training: False,
    }
    fetches = [loss, count_correct, merged, prediction]
    l, c, ms, p = session.run(fetches, feed_dict=feed_dict)
    print("Batch acc:", 1.0*c/batch_size)
    print("Prediction:", p)
    print("GT:", train_labels[first_batch])

In [None]:
# Sanity check: evaluate the first training batch with freshly initialised
# (untrained) weights and the training flag switched off.
with tf.Session(graph=graph) as session: 
    tf.global_variables_initializer().run()
    # print() call instead of the Python 2 print statement, which is a
    # SyntaxError on Python 3 and inconsistent with the rest of the notebook.
    print("Initialized.")

    offset = 0
    feed_dict={
        tf_map_files: train_map_paths[offset:offset+batch_size],
        tf_sat_files: train_sat_paths[offset:offset+batch_size],
        tf_labels: train_labels[offset:offset+batch_size],
        tf_is_training: False
    }
    p,l,c = session.run([prediction, loss, count_correct], feed_dict=feed_dict)