# Preprocess spectrogram images into TFRecords

## Adapted from the demo tfos_mnist_preprocessing notebook

In [None]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
from array import array
from hops import hdfs

import numpy as np
from PIL import Image
import io

def toTFExample(image, label):
    """Serialize an image/label pair as a ``tf.train.Example`` byte string.

    Both inputs are converted to int64, flattened, and stored as Int64List
    features under the ``'image'`` and ``'label'`` keys.

    Args:
        image: Pixel values; array-like of any shape (flattened row-major).
        label: Label values, e.g. a one-hot vector; array-like of any shape.

    Returns:
        The serialized ``tf.train.Example`` (``bytes``).
    """
    # Int64List needs a flat sequence of integers. Iterating a 2-D array would
    # yield rows, so flatten first. tolist() hands protobuf plain Python ints.
    label_values = np.asarray(label).astype("int64").ravel().tolist()
    image_values = np.asarray(image).astype("int64").ravel().tolist()
    example = tf.train.Example(
        features=tf.train.Features(
            feature={
                'label': tf.train.Feature(int64_list=tf.train.Int64List(value=label_values)),
                'image': tf.train.Feature(int64_list=tf.train.Int64List(value=image_values)),
            }
        )
    )
    return example.SerializeToString()


def fromTFExample(bytestr):
    """Rebuild a ``tf.train.Example`` message from its serialized bytes.

    Args:
        bytestr: Bytes as produced by ``Example.SerializeToString()``.

    Returns:
        A new ``tf.train.Example`` populated from ``bytestr``.
    """
    # FromString constructs a fresh message and parses the bytes into it.
    return tf.train.Example.FromString(bytestr)


def _decode_image(data):
    """Decode raw image-file bytes into a NumPy array using PIL."""
    # The context manager closes the PIL image once its pixels are copied.
    with Image.open(io.BytesIO(data)) as img:
        return np.array(img)


def readDataToRDD(sc, folder, image_shape=(128, 128)):
    """Load PNG images under *folder* as an RDD of ``(label, pixels)`` pairs.

    Args:
        sc: SparkContext used to read the files.
        folder: Base directory/URI; files matching ``folder + "/**/*.png"``
            are loaded.
        image_shape: Exact shape a decoded image must have to be kept
            (default ``(128, 128)``). Images of any other shape are silently
            dropped.

    Returns:
        RDD of ``(label_one_hot, flat_pixels)``. ``label_one_hot`` is
        ``extract_label_one_hot(file_path)`` and ``flat_pixels`` is the
        decoded image flattened to 1-D in row-major order.
    """
    # Normalize so a list such as [128, 128] also compares equal to .shape.
    expected_shape = tuple(image_shape)
    # binaryFiles yields (file_path, file_content_bytes) pairs.
    # NOTE(review): Hadoop globs presumably treat "**" like "*" within a single
    # path component (no recursive descent). If so, this only matches PNGs
    # exactly one directory below `folder` — confirm.
    rdd = sc.binaryFiles(folder + "/**/*.png")
    rdd = rdd.map(lambda x: (extract_label_one_hot(x[0]), _decode_image(x[1])))
    rdd = rdd.filter(lambda x: x[1].shape == expected_shape)
    # Flatten each image to a 1-D vector.
    rdd = rdd.map(lambda x: (x[0], x[1].reshape(-1)))
    return rdd

# Music-genre labels: genre name (file-name prefix before the first "_") -> index.
GENRE_LABELS = {
    'Classical': 0,
    'Techno': 1,
    'Pop': 2,
    'HipHop': 3,
    'Metal': 4,
    'Rock': 5,
}


def extract_label_one_hot(path, label_dict=None):
    """Return the one-hot genre label for an image file path.

    File names are expected to look like ``<genre>_X_X.png``. The genre is
    the part of the final path component before the first underscore.

    Args:
        path: File path or URI; only the component after the last "/" is used.
        label_dict: Optional mapping of genre name to label index. Defaults
            to ``GENRE_LABELS``.

    Returns:
        ``np.uint8`` array of length ``len(label_dict)`` with a single 1 at
        the genre's index.

    Raises:
        ValueError: If the genre prefix is not a key of ``label_dict``.
    """
    if label_dict is None:
        label_dict = GENRE_LABELS
    filename = path.rsplit("/", 1)[-1]
    genre = filename.split("_", 1)[0]
    if genre not in label_dict:
        raise ValueError(
            "Unknown genre {!r} in file name {!r}; expected one of {}".format(
                genre, filename, sorted(label_dict)))
    label_one_hot = np.zeros(len(label_dict), dtype=np.uint8)
    label_one_hot[int(label_dict[genre])] = 1
    return label_one_hot
    

def write_tf_records(sc, input_dir,  output_dir):
    """Convert the PNG images under *input_dir* into TFRecord files in *output_dir*.

    Each ``(label, pixels)`` pair from ``readDataToRDD`` is serialized with
    ``toTFExample`` and written through the TensorFlow Hadoop output format.

    Requires the TensorFlow Hadoop connector on the Spark classpath
    (``--jars tensorflow-hadoop-1.0-SNAPSHOT.jar``). Without it,
    ``org.tensorflow.hadoop.io.TFRecordFileOutputFormat`` presumably cannot
    be loaded and the save call fails.
    NOTE(review): Hadoop output formats usually refuse to write into an
    existing directory — confirm *output_dir* is cleared before re-running.
    """
    rdd = readDataToRDD(sc, input_dir)
    # Key: serialized tf.train.Example as bytes (BytesWritable); value: None (NullWritable).
    tfRDD = rdd.map(lambda x: (bytearray(toTFExample(x[1], x[0])), None))
    # The save is the only action here. Without it the RDD above is never
    # evaluated and nothing is written.
    tfRDD.saveAsNewAPIHadoopFile(
        output_dir,
        "org.tensorflow.hadoop.io.TFRecordFileOutputFormat",
        keyClass="org.apache.hadoop.io.BytesWritable",
        valueClass="org.apache.hadoop.io.NullWritable",
    )
    
    


In [None]:
from pyspark.context import SparkContext
from pyspark.conf import SparkConf

# Reuse the notebook's active SparkContext (the one behind the `spark` session,
# when present) or start one, rather than depending on an injected `spark` global.
sc = SparkContext.getOrCreate()
hdfs_project_path = "hdfs:///Projects/genre_classifier_2/"
dataset_path = hdfs_project_path + "Spectrograms/"
output_base_folder = dataset_path + "tfrecords/"

# Convert the testing split.
input_folder = dataset_path + "testing"
output_folder = output_base_folder + "testing"
write_tf_records(sc, input_folder, output_folder)