# Prepare CNN Data

## Setup

In [None]:
# Earth Engine client: authenticate first, then initialise the session.
# Both calls must run before any other `ee` call in this notebook.
import ee
ee.Authenticate()
ee.Initialize()

In [16]:
import numpy as np
import geetools
from geetools import ui, cloud_mask
import os, datetime
import pandas as pd
import itertools
import tensorflow as tf
import config as cf
import ee_utils as utils

# Cloud-mask objects from geetools, built once per imagery source.
cloud_mask_landsatSR = cloud_mask.landsatSR()
cloud_mask_sentinel2 = cloud_mask.sentinel2()

# Expose the ee_utils feature helpers under short module-level names.
_bytes_feature, _float_feature, _int64_feature = (
    utils._bytes_feature,
    utils._float_feature,
    utils._int64_feature,
)

# Expose the remaining ee_utils helpers used by this notebook.
survey_to_fc, normalized_diff, ee_to_np_daytime = (
    utils.survey_to_fc,
    utils.normalized_diff,
    utils.ee_to_np_daytime,
)
prep_cnn_np, chunk_ids = utils.prep_cnn_np, utils.chunk_ids

## Parameters

In [17]:
# Survey whose data folder is read, and the imagery source code used both in
# the output folder name ('cnn_' + SATELLITE) and passed to prep_cnn_np.
SURVEY_NAME = 'DHS'
SATELLITE = 's2'

# Image size handed to prep_cnn_np (presumably pixels per side -- confirm).
KERNEL_SIZE = 224

# Number of observations to scrape in GEE at any given time.
CHUNK_SIZE = 5

# When True, tfrecords already present in the output folder are not re-scraped.
SKIP_IF_SCRAPED = True

# Root of the project tree, taken from the local config module.
PROJECT_DIR = cf.GOOGLEDRIVE_DIRECTORY

In [18]:
# Folder where the generated tfrecord files are written:
# <PROJECT_DIR>/Data/<SURVEY_NAME>/FinalData/Individual Datasets/cnn_<SATELLITE>/tfrecords
_out_path_parts = [
    PROJECT_DIR,
    'Data',
    SURVEY_NAME,
    'FinalData',
    'Individual Datasets',
    'cnn_' + SATELLITE,
    'tfrecords',
]
out_path = os.path.join(*_out_path_parts)

## Implement

In [49]:
## Load Survey
# Read the survey table, then collect the distinct tfrecord names it refers to
# (np.unique returns them sorted).
survey_csv_path = os.path.join(
    PROJECT_DIR,
    'Data',
    SURVEY_NAME,
    'FinalData',
    'Individual Datasets',
    'survey_socioeconomic.csv',
)
survey_df = pd.read_csv(survey_csv_path)
tf_record_list = list(np.unique(survey_df['tfrecord_name']))

In [50]:
# If skipping already-scraped records, remove from tf_record_list every name
# whose tfrecord file already exists in out_path.
if SKIP_IF_SCRAPED:
    # A missing output folder simply means nothing has been scraped yet.
    if os.path.isdir(out_path):
        tf_records_exist = set(os.listdir(out_path))
    else:
        tf_records_exist = set()

    # Files are saved as '<tfrecord_name>.tfrecord', so the bare name never
    # appears in the directory listing; compare against the saved filename.
    # The bare-name check is kept so anything skipped before is still skipped.
    tf_record_list = [
        x for x in tf_record_list
        if x not in tf_records_exist
        and (x + '.tfrecord') not in tf_records_exist
    ]

In [None]:
### Loop through all tfrecords
# For each tfrecord name: select its survey rows, pull imagery in chunks via
# prep_cnn_np, and write the collected examples to '<name>.tfrecord' in out_path.

# Make sure the output folder exists so the writer cannot fail after scraping.
os.makedirs(out_path, exist_ok=True)

for tfr_i in tf_record_list:

    print(tfr_i)

    # Copy the subset so adding 'chunk_id' below modifies an independent frame
    # instead of a slice of survey_df (avoids SettingWithCopyWarning).
    survey_df_yeari = survey_df[survey_df['tfrecord_name'] == tfr_i].copy()

    # Year is taken from the first row of the subset
    # (assumes all rows of one tfrecord share a year -- TODO confirm).
    year_i = survey_df_yeari['year'].iloc[0]

    ### Loop through chunks within tfrecord (can only pull so much data from GEE at a time)
    survey_df_yeari['chunk_id'] = chunk_ids(survey_df_yeari.shape[0], CHUNK_SIZE)

    print(survey_df_yeari.shape)

    proto_examples_all = []
    for chunk_i in np.unique(survey_df_yeari['chunk_id']):

        survey_df_yeari_chunki = survey_df_yeari[survey_df_yeari['chunk_id'] == chunk_i]

        proto_examples_i = prep_cnn_np(survey_df_yeari_chunki, SATELLITE, KERNEL_SIZE, year_i)
        proto_examples_all.extend(proto_examples_i)

        # Running count of examples collected so far for this tfrecord.
        print(len(proto_examples_all))

    ### Save data as tf record
    # out_path is the tfrecords folder defined earlier in the notebook.
    out_path_i = os.path.join(out_path, tfr_i + '.tfrecord')
    print(out_path_i)
    with tf.io.TFRecordWriter(out_path_i) as writer:
        for tf_example in proto_examples_all:
            writer.write(tf_example.SerializeToString())