In [None]:
from pathlib import Path
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import os
import json
import subprocess
from google.cloud import storage
# Only let TensorFlow's Python logger emit messages at ERROR level or above.
tf.get_logger().setLevel('ERROR')

### Read from a Google Cloud Storage bucket

In [None]:
def list_blobs(bucket_name):
    """Return the names of all objects stored in ``bucket_name``.

    A ``storage.Client()`` is created with no explicit credentials. The
    listing it returns only issues requests as it is iterated, so it is
    drained into a list here before returning.

    Note: ``Client.list_blobs`` requires google-cloud-storage >= 1.17.0.
    """
    names = []
    for blob in storage.Client().list_blobs(bucket_name):
        names.append(blob.name)
    return names


def get_file_list(bucket_name):
    """List the top-level objects of a GCS bucket using ``gsutil ls``.

    Args:
        bucket_name: bucket name without the ``gs://`` scheme.

    Returns:
        list[str]: the lines printed by ``gsutil ls gs://<bucket_name>``
        (one object path per line). Empty if the command succeeds but
        prints nothing.

    Raises:
        Exception: if ``gsutil`` cannot be started or exits with a non-zero
            status. The message includes gsutil's stderr.
    """
    # argv as a list (no shell) so the bucket name is never shell-parsed.
    # stdout and stderr are captured separately: gsutil may print warnings on
    # stderr even when it succeeds, and those must not leak into the list.
    cmd = ["gsutil", "ls", f"gs://{bucket_name}"]
    try:
        result = subprocess.run(cmd, capture_output=True, text=True)
    except OSError as err:  # e.g. gsutil not installed / not on PATH
        raise Exception(f"no file was found: could not run {' '.join(cmd)}: {err}") from err
    if result.returncode != 0:
        raise Exception(
            f"no file was found: {' '.join(cmd)} exited with status "
            f"{result.returncode}: {result.stderr.strip()}")
    return result.stdout.splitlines()

def get_tfrecord_files(bucket_file_list, suffix='tfrecord'):
    """Pick the TFRecord tile files out of a bucket listing.

    A path is kept when it ends with ``suffix`` or with ``'gz'`` (matched
    without a leading dot). Input order is preserved.
    """
    def _is_tile_file(path):
        return path.endswith(suffix) or path.endswith('gz')

    return list(filter(_is_tile_file, bucket_file_list))

def get_json_file(gsfiles,file_prefix="mixer.json"):
    """Download and decode the first JSON object found in a bucket listing.

    Args:
        gsfiles: iterable of object paths (e.g. the output of ``get_file_list``).
        file_prefix: substring that identifies the JSON object. Despite the
            name, it is matched anywhere in the path, not only as a prefix.

    Returns:
        The decoded JSON document (the notebook indexes it as a dict).

    Raises:
        Exception: if no path contains ``file_prefix``, or if ``gsutil cat``
            cannot be run or exits with a non-zero status.
        json.JSONDecodeError: if the object's contents are not valid JSON.
    """
    json_file = next((f for f in gsfiles if file_prefix in f), None)
    if json_file is None:
        # Fail here instead of returning None: the caller indexes the result
        # immediately, which would otherwise surface as an unrelated TypeError.
        raise Exception(f"no json file was found matching {file_prefix!r}")

    # argv list (no shell) and separate stderr: gsutil may print warnings on
    # stderr even on success, which would otherwise corrupt the JSON text.
    cmd = ["gsutil", "cat", json_file]
    try:
        result = subprocess.run(cmd, capture_output=True, text=True)
    except OSError as err:  # e.g. gsutil not installed / not on PATH
        raise Exception(f"could not run {' '.join(cmd)}: {err}") from err
    if result.returncode != 0:
        raise Exception(
            f"could not read {json_file}: gsutil exited with status "
            f"{result.returncode}: {result.stderr.strip()}")
    return json.loads(result.stdout)

# Parsing function.
def parse_image(example_proto, features_dict=None):
    """Parse one serialized ``tf.train.Example`` into a dict of tensors.

    Args:
        example_proto: one serialized Example, e.g. an element of a
            ``tf.data.TFRecordDataset``.
        features_dict: mapping of feature name -> ``tf.io.FixedLenFeature``
            spec. Defaults to the module-level ``image_features_dict``, which
            must already be defined when this function is called.

    Returns:
        The dict produced by ``tf.io.parse_single_example`` for that spec.

    To map with a non-default spec:
        ds = image_dataset.map(lambda proto: parse_image(proto, my_features))
    """
    if features_dict is None:
        features_dict = image_features_dict
    return tf.io.parse_single_example(example_proto, features_dict)


def select_tiles_on_classRatio(ds_np_gen, img_size=256*256, class_ratio=0.5):
    """Lazily yield only the tiles whose ``'cwf'`` band is dense enough.

    A tile (a dict of arrays) passes when the number of non-zero entries in
    ``tile['cwf']`` is at least ``class_ratio * img_size``.

    Args:
        ds_np_gen: iterable of tile dicts, e.g. ``ds.as_numpy_iterator()``.
        img_size: number of pixels in one tile.
        class_ratio: minimum fraction of non-zero ``'cwf'`` pixels.

    Yields:
        The passing tile dicts, unchanged and in input order.
    """
    min_nonzero = class_ratio * img_size
    yield from (
        tile for tile in ds_np_gen
        if np.count_nonzero(tile['cwf']) >= min_nonzero
    )

def np_to_tfr(ds_gen, file_name="./test_tfRecord.gz"):
    """Write dicts of arrays to a GZIP-compressed TFRecord file.

    Each element of ``ds_gen`` becomes one ``tf.train.Example``. The ``'cwf'``
    entry is encoded as an ``int64_list`` feature; every other entry is
    encoded as a ``float_list`` feature. Values are flattened before encoding,
    so they must support ``.flatten()`` (numpy arrays do).

    Args:
        ds_gen: iterable of ``{feature_name: array}`` dicts.
        file_name: destination path of the TFRecord file.
    """
    gzip_options = tf.io.TFRecordOptions(compression_type='GZIP')

    def _encode(key, array):
        # Pick the proto list type from the feature name.
        flat = array.flatten()
        if key == 'cwf':
            return tf.train.Feature(int64_list=tf.train.Int64List(value=flat))
        return tf.train.Feature(float_list=tf.train.FloatList(value=flat))

    with tf.io.TFRecordWriter(file_name, options=gzip_options) as writer:
        for record in ds_gen:
            features = tf.train.Features(
                feature={key: _encode(key, array) for key, array in record.items()})
            writer.write(tf.train.Example(features=features).SerializeToString())

def write_file_to_gs(file_name,bucket_name,blob_name, service_account_json=None):
    """Upload a local file to ``gs://<bucket_name>/<blob_name>``.

    Args:
        file_name: path of the local file to upload.
        bucket_name: destination bucket (no ``gs://`` scheme).
        blob_name: destination object name inside the bucket.
        service_account_json: path to a service-account key file. When None,
            the ``GOOGLE_APPLICATION_CREDENTIALS`` environment variable is used
            if set, otherwise the legacy default path.
    """
    if service_account_json is None:
        # Don't bake a personal absolute path into the notebook: let the
        # environment (or the caller) choose the credentials file.
        service_account_json = os.environ.get(
            'GOOGLE_APPLICATION_CREDENTIALS', '/home/layla/service_account.json')
    client = storage.Client.from_service_account_json(service_account_json)
    bucket = client.get_bucket(bucket_name)
    bucket.blob(blob_name).upload_from_filename(file_name)

def delete_file_from_disk(file_name):
    """Delete a local file.

    Args:
        file_name: path of the file to remove.

    Returns:
        int: 0 on success, 1 if the file could not be removed (missing, a
        directory, permission denied, ...). This follows the same convention
        as a shell exit status.
    """
    # os.remove instead of shelling out to `rm {file_name}`: no subprocess,
    # and paths with spaces or shell metacharacters are handled safely.
    try:
        os.remove(file_name)
    except OSError:
        return 1
    return 0


def select_and_export_tiles_to_gs(files_list, blob_prefix="test_tfRecord",
                                  class_ratio=0.3, img_size=256 * 256,
                                  dest_bucket=None):
    """Filter each TFRecord tile file by class ratio and upload the result.

    For every input file, in order:
    1. Parse its GZIP TFRecords with ``parse_image``.
    2. Keep only tiles whose ``'cwf'`` band has at least
       ``class_ratio * img_size`` non-zero pixels.
    3. Write them to ``./<blob_prefix><n>.gz`` (n starts at 1).
    4. Upload that file under the same name.
    5. Delete the local copy.

    Args:
        files_list: TFRecord paths to process.
        blob_prefix: prefix for the output object / local file names.
        class_ratio: minimum fraction of non-zero ``'cwf'`` pixels per tile.
        img_size: pixels per tile; should equal ``patch_width * patch_height``.
        dest_bucket: destination bucket. Defaults to the notebook-level
            ``bucket_name`` global.
    """
    target_bucket = bucket_name if dest_bucket is None else dest_bucket
    for idx, tfrecord_path in enumerate(files_list):
        blob_name = f"{blob_prefix}{idx + 1}.gz"
        file_name = f"./{blob_name}"
        print("processing tile file", idx)
        image_dataset = tf.data.TFRecordDataset(tfrecord_path, compression_type='GZIP')
        ds = image_dataset.map(parse_image, num_parallel_calls=5)
        ds_np_gen = ds.as_numpy_iterator()
        # Pass the ratio by keyword: positionally, 0.3 lands in img_size and the
        # threshold collapses to 0.15 pixels, keeping almost every tile.
        ds_gen = select_tiles_on_classRatio(ds_np_gen, img_size=img_size,
                                            class_ratio=class_ratio)
        try:
            np_to_tfr(ds_gen, file_name=file_name)
            write_file_to_gs(file_name, target_bucket, blob_name)
        finally:
            # Always remove the temporary local file, even when writing or
            # uploading failed.
            delete_file_from_disk(file_name)

In [None]:
# Source bucket holding the exported image tiles and the mixer JSON.
bucket_name = "image_tiles_us_florida"
# Object paths at the bucket's top level (via `gsutil ls`).
gsfiles = get_file_list(bucket_name)
# Keep only the tile files (names ending in 'tfrecord' or 'gz').
files_list = get_tfrecord_files(gsfiles)
# Decoded mixer JSON; its patch dimensions and patch count are read below.
mixer = get_json_file(gsfiles)

In [None]:
# Get relevant info from the JSON mixer file.
patch_width = mixer['patchDimensions'][0]
patch_height = mixer['patchDimensions'][1]
patches = mixer['totalPatches']
# tf.io.FixedLenFeature shapes are [rows, cols]. patchDimensions is read above
# as [width, height], and Earth Engine writes each band patch row by row, so
# the per-band tensor shape is [height, width]. The [width, height] order
# only works when the patch is square and scrambles non-square patches.
# (The name is kept for compatibility even though the shape is 2-D.)
patch_dimensions_flat = [patch_height, patch_width]
patch_size = patch_width * patch_height

# Float32 input bands. The "_1" names repeat the same band set a second time
# (presumably a second acquisition -- TODO confirm).
feature_bands = ['SR_B2','SR_B3','SR_B4','SR_B5','SR_B6','SR_B7','ST_B10','NDVI','NDWI','SR','EVI','OSAVI',
    'SR_B2_1','SR_B3_1','SR_B4_1','SR_B5_1','SR_B6_1','SR_B7_1','ST_B10_1','NDVI_1','NDWI_1','SR_1','EVI_1','OSAVI_1']
# Integer-valued band that select_tiles_on_classRatio thresholds on.
label_band = 'cwf'

# Lists are rebuilt from scratch on every run, so re-executing this cell is
# idempotent.
bands = feature_bands + [label_band]
image_columns = (
    [tf.io.FixedLenFeature(shape=patch_dimensions_flat, dtype=tf.float32) for _ in feature_bands]
    + [tf.io.FixedLenFeature(shape=patch_dimensions_flat, dtype=tf.int64)]
)

# Parsing dictionary: band name -> feature spec, read by parse_image.
image_features_dict = dict(zip(bands, image_columns))


In [None]:
# Sanity check before the export: fail fast when the listing contained no
# TFRecord files, then preview the count and the first few paths rather than
# dumping the whole list.
assert files_list, "no TFRecord files found in the bucket listing"
len(files_list), files_list[:5]

In [None]:
# Filter every tile file by its 'cwf' class ratio and upload the filtered copies.
select_and_export_tiles_to_gs(files_list)

In [None]:
    
# This cell held a commented-out copy of the loop inside
# select_and_export_tiles_to_gs. To try the export on only a few files,
# call that function on a slice instead, e.g.:
# select_and_export_tiles_to_gs(files_list[:2])
    