<a href="https://colab.research.google.com/github/heromiya/GEE-LCM/blob/master/Landcover_using_ANN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install the package that provides `keras_radam` (RAdam is imported from it further down).
!pip install keras-rectified-adam

In [None]:
# Mount Google Drive under /content/drive; exported files are later read
# from paths below '/content/drive/MyDrive'.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Authenticate and initialise the Earth Engine client before any ee.* call.
import ee
ee.Authenticate()
ee.Initialize()

In [None]:
import os
import glob
import shutil
import time
import json
from pprint import pprint

import numpy as np
import tensorflow as tf
import folium

import gdal
import osr
import matplotlib.pyplot as plt

In [None]:
def reduce_class_value(feat):
  """Return `feat` with its 'class' property decreased by one."""
  shifted_class = ee.Number(feat.get('class')).subtract(1)
  return feat.set('class', shifted_class)

def cloudMask(img):
  """Return `img` masked to the pixels whose 'BQA' band value is below 64."""
  keep_pixels = img.select('BQA').lt(64)
  return img.updateMask(keep_pixels)

def select_landsat(year):
  """Return the Landsat collection id and band lists configured for `year`.

  A new dict (with new lists) is built on every call, because callers extend
  'bandsClassify' in place. Keys of the result:
    'bands'         -- band names of the composite
    'bandsClassify' -- 'bands' followed by '<texture band>_1'
    'FCCbands'      -- three bands used for the false-colour composite
    'TextureBand'   -- one-element list naming the band used for texture
    'Landsat'       -- Earth Engine collection id

  Args:
    year: calendar year, 1972 or later.

  Returns:
    The configuration dict described above.

  Raises:
    ValueError: if no collection is configured for `year` (before 1972).
      The previous implementation fell through and returned None here.
  """
  def build(collection, bands, texture_band, fcc_bands):
    # Copy every list so that in-place edits by one caller never leak into
    # the result of a later call.
    return {
        'bands': list(bands),
        'bandsClassify': list(bands) + [texture_band + '_1'],
        'FCCbands': list(fcc_bands),
        'TextureBand': [texture_band],
        'Landsat': collection,
    }

  def oli(collection):
    return build(collection, ['B2', 'B3', 'B4', 'B5', 'B6', 'B7'], 'B5', ['B5', 'B4', 'B3'])

  def tm(collection):
    return build(collection, ['B1', 'B2', 'B3', 'B4', 'B5', 'B7'], 'B4', ['B4', 'B3', 'B2'])

  def mss(collection):
    return build(
        collection,
        ['B4_median', 'B5_median', 'B6_median', 'B7_median'],
        'B6_median',
        ['B6_median', 'B5_median', 'B4_median'])

  if year >= 2013:
    return oli('LANDSAT/LC08/C01/T1')
  if 1999 <= year <= 2002:
    return tm('LANDSAT/LE07/C01/T1')
  if 1984 <= year <= 1998 or 2003 <= year <= 2012:
    return tm('LANDSAT/LT05/C01/T1')
  if year == 1983:
    return tm('LANDSAT/LT04/C01/T1')
  if 1979 <= year <= 1982:
    return mss('LANDSAT/LM03/C01/T1')
  if 1975 <= year <= 1978:
    return mss('LANDSAT/LM02/C01/T2')
  if 1972 <= year <= 1974:
    return mss('LANDSAT/LM01/C01/T2')

  raise ValueError('No Landsat collection configured for year ' + str(year))

In [None]:
def landsat_product(landsat_params):
  """Build the Landsat composite described by `landsat_params`.

  Reads the module-level names `gtYearBegin`, `gtYearEnd`, `ROI`, `roi`,
  `out_ext` and `cloud`, so those must be set before calling.

  Args:
    landsat_params: dict as returned by select_landsat(); the 'Landsat' and
      'bands' entries are used.

  Returns:
    dict with
      'region'    -- bounds of the filtered collection, fetched with getInfo()
      'image'     -- composite restricted to landsat_params['bands']
      'image_int' -- MSS: the composite cast with toByte();
                     otherwise: simpleComposite with asFloat=False
  """
  mss_collections = (
      'LANDSAT/LM01/C01/T2',
      'LANDSAT/LM02/C01/T2',
      'LANDSAT/LM03/C01/T1',
  )
  landsat_id = landsat_params['Landsat']
  bands = landsat_params['bands']

  # filterDate() excludes its end date, so stop at 1 January of the following
  # year; ending on '<gtYearEnd>-12-31' silently dropped 31 December scenes.
  image_col = (ee.ImageCollection(landsat_id)
               .filterDate(str(gtYearBegin) + '-01-01', str(gtYearEnd + 1) + '-01-01')
               .filter(ROI[roi]['doyFilterLandsat'])
               .filterBounds(out_ext))

  if landsat_id in mss_collections:
    image_col = image_col.filterMetadata('CLOUD_COVER_LAND', 'less_than', 20)
    # Keep only the configured bands: the median reduction also emits a
    # band for BQA, and add_indices_mss() renames the image using
    # landsat_params['bands'], which requires the band counts to match.
    image = image_col.map(cloudMask).reduce(ee.Reducer.median()).select(bands)
    return {
      # Take the footprint from the collection, as the branch below does,
      # rather than from the reduced image.
      'region': image_col.geometry().bounds().getInfo(),
      'image': image,
      'image_int': image.toByte()
    }

  return {
    'region': image_col.geometry().bounds().getInfo(),
    'image': ee.Algorithms.Landsat.simpleComposite(image_col, 50, cloud, 40, True).select(bands),
    'image_int': ee.Algorithms.Landsat.simpleComposite(image_col, 50, cloud, 40, False).select(bands)
  }

In [None]:
def add_indices(params, product):
  """Append NDVI, NDBI, BI and NDWI bands to product['image'].

  Both arguments are modified in place:
    * product['image'] gets four extra bands and is renamed to
      params['bands'] + ['ndvi', 'ndbi', 'bi', 'ndwi'];
    * params['bandsClassify'] is extended with the four index names.

  Args:
    params: dict as returned by select_landsat().
    product: dict as returned by landsat_product().

  Raises:
    ValueError: if params['Landsat'] is not one of the supported collections.
  """
  # (swir, nir, red, green) band names for each supported collection.
  oli_bands = ('B6', 'B5', 'B4', 'B3')
  tm_bands = ('B5', 'B4', 'B3', 'B2')
  index_bands = {
      'LANDSAT/LC08/C01/T1': oli_bands,
      'LANDSAT/LE07/C01/T1': tm_bands,
      'LANDSAT/LT05/C01/T1': tm_bands,
      # select_landsat() returns the LT04 id for 1983. Only 'LM04' used to be
      # recognised here, which left the band variables unassigned for 1983.
      'LANDSAT/LT04/C01/T1': tm_bands,
      # Kept so that callers passing the previously accepted id still work.
      'LANDSAT/LM04/C01/T1': tm_bands,
  }

  L = params['Landsat']
  if L not in index_bands:
    raise ValueError('add_indices does not support collection ' + str(L))

  image = product['image']
  swir, nir, red, green = (image.select(name) for name in index_bands[L])

  ndvi  = nir.subtract(red).divide(nir.add(red))
  ndbi  = swir.subtract(nir).divide(swir.add(nir))
  bi    = ndbi.subtract(ndvi)
  ndwi  = green.subtract(swir).divide(green.add(swir))

  index_names = ['ndvi', 'ndbi', 'bi', 'ndwi']
  product['image'] = image.addBands([ndvi, ndbi, bi, ndwi]).rename(ee.List(params['bands'] + index_names))
  params['bandsClassify'].extend(index_names)

def add_indices_mss(params, product):
  """Append an NDVI band to product['image'] for the '*_median' band set.

  NDVI is computed from 'B6_median' (used as NIR) and 'B5_median' (used as
  red). product['image'] is replaced by the image renamed to
  params['bands'] + ['ndvi'], and 'ndvi' is appended to
  params['bandsClassify'].
  """
  source = product['image']
  nir_band = source.select('B6_median')
  red_band = source.select('B5_median')
  ndvi_band = nir_band.subtract(red_band).divide(nir_band.add(red_band))

  with_ndvi = source.addBands(ndvi_band)
  product['image'] = with_ndvi.rename(ee.List(params['bands'] + ['ndvi']))
  params['bandsClassify'].append('ndvi')
  

In [None]:
##### User defined parameters ######

# Integer timestamp of this run; it is embedded in the output folder and
# sample file names built in the ROI setup cell.
ts = time.time()
timestamp = int(ts)

# `cloud` is passed to simpleComposite() in landsat_product();
# `spatial_resolution` is the scale used for sampling and exports.
cloud = 30
spatial_resolution = 30
# NOTE(review): n_sample and distance are not referenced anywhere else in the
# cells shown here.
n_sample = 10000
distance = 370000

# Kernel sizes handed to ee.Kernel.gaussian() when texture bands are built.
kernel_size_list = [1, 3, 5, 7, 9]

# Buffer added around the ROI: half a patch (spatial_resolution * patch_size / 2)
# expressed in whole multiples of base_distance, plus one extra multiple.
base_distance = 5000
patch_size = 256
buffer_distance = base_distance * (round((spatial_resolution * 0.5 * patch_size)/base_distance) +1)

USER_NAME = 'heromiya'
# Earth Engine asset holding the labelled points.
GT_DATA = 'users/heromiya/gt-pt-170-52-2019-2021'

# One dict per region of interest: bounding box, year range, a filter on the
# label points' 'doy' property and a day-of-year filter for the Landsat scenes.
ROI = [
       {'cityName': 'Ethiopia-2-1_lower-left_res120_spring', 'LatMax': 11.0, 'LatMin': 10.0, 'LonMax': 37.85, 'LonMin': 36.5, 'yearBegin':2019, 'yearEnd':2021, 'doyFilter':ee.Filter.And(ee.Filter.greaterThanOrEquals('doy',  1), ee.Filter.lessThanOrEquals('doy',  366)), 'doyFilterLandsat': ee.Filter.dayOfYear(244,335)},
]

MY_DRIVE_PATH = '/content/drive/MyDrive'
# Destination the finished output folder is moved to in the last cell.
SHARED_DRIVE_PATH = '/content/drive/Shareddrives/Miyazaki Lab./' + USER_NAME + '/'
OUTPUT_ASSET_ID = 'users/' + USER_NAME + '/demo'

In [None]:
# `roi` and `out_ext` are read as module-level names by landsat_product();
# they receive their real values inside the loop below.
roi = ''
out_ext = ''

# Iterate by index (not by item) because landsat_product() looks up ROI[roi].
for roi in range(len(ROI)):
  city_name = ROI[roi]['cityName']
  year_begin = ROI[roi]['yearBegin']
  year_end = ROI[roi]['yearEnd']

  # Output locations on Google Drive for this ROI and run.
  IMAGE_FILE_PREFIX = city_name + '_' + str(year_begin) + '_' + str(year_end)
  FOLDER_NAME = USER_NAME + '_' + IMAGE_FILE_PREFIX + '_' + str(timestamp)
  FOLDER_PATH = MY_DRIVE_PATH + '/' + FOLDER_NAME
  TRAIN_FILE_PREFIX = 'Training_' + str(timestamp)
  TEST_FILE_PREFIX = 'Testing_' + str(timestamp)

  file_extension = '.tfrecord.gz'
  TRAIN_FILE_PATH = FOLDER_PATH + '/sample/' + TRAIN_FILE_PREFIX + file_extension
  TEST_FILE_PATH = FOLDER_PATH + '/sample/' + TEST_FILE_PREFIX + file_extension

  OUTPUT_IMAGE_TFR = FOLDER_PATH + '/' + IMAGE_FILE_PREFIX + '.TFRecord'
  OUTPUT_IMAGE_TIF = FOLDER_PATH + '/' + IMAGE_FILE_PREFIX + '.tif'

  # The statements below used to sit inside
  # `for year in range(year_begin, year_end + 1)`, but `year` was never read:
  # every pass rebuilt identical objects and repeated the getInfo()
  # round-trips. They now run once per ROI.
  nSampleClass = []
  out_ext = ee.Geometry.Rectangle([ROI[roi]['LonMin'], ROI[roi]['LatMin'], ROI[roi]['LonMax'], ROI[roi]['LatMax']])
  EXPORT_REGION = out_ext
  out_ext_center = out_ext.centroid()
  out_ext_buffer = out_ext.buffer(buffer_distance).bounds()

  # One getInfo() call instead of two; coordinates are [lon, lat].
  center_coordinates = out_ext_center.getInfo()['coordinates']
  center_lon = center_coordinates[0]
  center_lat = center_coordinates[1]

  landsat_params = select_landsat(year_begin)

  gtYearBegin = year_begin
  gtYearEnd = year_end

  product = landsat_product(landsat_params)

  # Label points: class values shifted by reduce_class_value(), restricted to
  # the year range, the ROI's 'doy' filter and the footprint of the imagery.
  ref_point = ee.FeatureCollection(GT_DATA).map(reduce_class_value)
  LABEL_DATA = ref_point.filter(ee.Filter.And(ee.Filter.greaterThanOrEquals('year', gtYearBegin), ee.Filter.lessThanOrEquals('year', gtYearEnd), ROI[roi]['doyFilter']))
  LABEL_DATA = LABEL_DATA.filterBounds(product['region'])

  # The sensor is chosen from year_begin, so the index function must be chosen
  # from year_begin too. Testing year_end sent a range such as 1982-1984
  # (MSS bands) to add_indices(), which cannot handle it.
  if (year_begin > 1982):
    add_indices(landsat_params, product)
  else:
    add_indices_mss(landsat_params, product)

  # The previous loops over kernel_size_list reassigned the result on every
  # pass, so only the last kernel size ever took effect. Use it directly;
  # 'bandsClassify' lists exactly one texture band.
  texture_kernel = ee.Kernel.gaussian(kernel_size_list[-1])
  texture_bands = landsat_params['TextureBand']

  gt_image = product['image']
  gt_image_texture = gt_image.addBands(product['image_int'].select(texture_bands).entropy(texture_kernel))

  out_image = product['image'].clip(out_ext_buffer)
  out_image_texture = out_image.addBands(product['image_int'].clip(out_ext_buffer).select(texture_bands).entropy(texture_kernel))

In [None]:
# Fetch and print the description of the band stack built above (debug aid).
print(gt_image.getInfo())

In [None]:
# Classifier inputs, label property and number of output classes.
LABEL = 'class'
N_CLASSES = 4
BANDS = landsat_params['bandsClassify']

# Columns written to the sample files: every input band, then the label.
FEATURE_NAMES = [*BANDS, LABEL]

# Sample the band stack at the label points, then add a 'random' column that
# drives the training/testing split.
sampled_points = gt_image_texture.sampleRegions(
    collection=LABEL_DATA,
    properties=[LABEL],
    scale=spatial_resolution,
)
sample = sampled_points.randomColumn()

is_training = ee.Filter.lt('random', 0.8)
is_testing = ee.Filter.gte('random', 0.8)
training = sample.filter(is_training)
testing = sample.filter(is_testing)

#pprint({'training': training.first().getInfo()})
#pprint({'testing': testing.first().getInfo()})

In [None]:
# Options shared by both sample exports (same folder, format and columns).
sample_export_options = {
    'folder': FOLDER_NAME,
    'fileFormat': 'TFRecord',
    'selectors': FEATURE_NAMES,
}

# Export of the training subset to Google Drive.
training_task = ee.batch.Export.table.toDrive(
    collection=training,
    description='Training Export',
    fileNamePrefix=TRAIN_FILE_PREFIX,
    **sample_export_options)

# Export of the testing subset to Google Drive.
testing_task = ee.batch.Export.table.toDrive(
    collection=testing,
    description='Testing Export',
    fileNamePrefix=TEST_FILE_PREFIX,
    **sample_export_options)

In [None]:
# Submit both sample exports, training first.
for export_task in (training_task, testing_task):
  export_task.start()

In [None]:
# Wait for each export in turn, checking every 20 seconds while it is active.
for export_task, done_message in (
    (training_task, 'Done with training export.'),
    (testing_task, 'Done with testing export.'),
):
  while export_task.active():
    print('Polling for task (id: {}).'.format(export_task.id))
    time.sleep(20)
  print(done_message)

In [None]:
# Wait, then force a remount of Drive before the exported sample files are
# looked up in the following cells.
time.sleep(20)
drive.mount('/content/drive', force_remount=True)

In [None]:
# Move everything currently in the run folder into its 'sample' sub-folder,
# which is where TRAIN_FILE_PATH and TEST_FILE_PATH point.
src_dir = MY_DRIVE_PATH + '/' + FOLDER_NAME
dest_dir = os.path.join(src_dir, 'sample')

os.makedirs(dest_dir, exist_ok = True)

# Skip the destination itself: when this cell is re-run, 'sample' already
# exists inside src_dir and moving a directory into itself raises shutil.Error.
src_files = [name for name in os.listdir(src_dir) if name != 'sample']

for file in src_files:
    shutil.move(os.path.join(src_dir, file), dest_dir)

In [None]:
# Report whether the exported sample files are visible at their expected paths.
if tf.io.gfile.exists(TRAIN_FILE_PATH):
  print('Found training file.')
else:
  print('No training file found.')

if tf.io.gfile.exists(TEST_FILE_PATH):
  print('Found testing file.')
else:
  print('No testing file found.')

In [None]:
# Open the gzip-compressed sample TFRecord files as datasets of serialized records.
train_dataset = tf.data.TFRecordDataset(TRAIN_FILE_PATH, compression_type='GZIP')
test_dataset = tf.data.TFRecordDataset(TEST_FILE_PATH, compression_type='GZIP')

# print(iter(train_dataset).next())
# print(iter(test_dataset).next())

In [None]:
# One float32 feature of shape [1] per exported column (bands and label alike).
columns = [tf.io.FixedLenFeature(shape=[1], dtype=tf.float32) for _ in FEATURE_NAMES]

# Parsing spec: column name -> feature description.
features_dict = {name: spec for name, spec in zip(FEATURE_NAMES, columns)}

pprint(features_dict)

In [None]:
def parse_tfrecord(example_proto):
  """Parse one serialized example into (features, label).

  The record is decoded with `features_dict`; the LABEL entry is removed from
  the feature dict and returned separately, cast to int32.
  """
  features = tf.io.parse_single_example(example_proto, features_dict)
  label = tf.cast(features.pop(LABEL), tf.int32)
  return features, label

# Decode both sample datasets with five parallel calls.
parsed_trainset = train_dataset.map(parse_tfrecord, num_parallel_calls=5)
parsed_testset = test_dataset.map(parse_tfrecord, num_parallel_calls=5)

# pprint(iter(parsed_trainset))
# pprint(iter(parsed_testset))

In [None]:
batch_size = 50

def to_tuple(inputs, label):
  """Convert a (feature dict, label) pair into (input tensor, one-hot label).

  The dict values are stacked and transposed; the label is one-hot encoded
  with N_CLASSES entries.
  """
  stacked_inputs = tf.transpose(list(inputs.values()))
  one_hot_label = tf.one_hot(indices=label, depth=N_CLASSES)
  return (stacked_inputs, one_hot_label)

# Batched datasets used for fitting and for validation.
input_dataset = parsed_trainset.map(to_tuple).batch(batch_size)
validate_dataset = parsed_testset.map(to_tuple).batch(batch_size)


In [None]:
drop = 0.05
N_UNITS=64

# First Dense layer, then eight (Dropout, Dense) pairs, all N_UNITS wide with
# ReLU activation. Layers are created in the order in which they are stacked.
hidden_layers = [tf.keras.layers.Dense(N_UNITS, activation='relu')]
for _ in range(8):
  hidden_layers.append(tf.keras.layers.Dropout(drop))
  hidden_layers.append(tf.keras.layers.Dense(N_UNITS, activation='relu'))

model = tf.keras.models.Sequential()
for layer in hidden_layers:
  model.add(layer)

# Softmax output with one unit per class.
model.add(tf.keras.layers.Dense(N_CLASSES, activation=tf.nn.softmax))

In [None]:
from keras_radam import RAdam

# Compile with the RAdam optimizer; categorical cross-entropy matches the
# one-hot labels produced by to_tuple().
model.compile(
    optimizer=RAdam(),
    loss='categorical_crossentropy',
    metrics=['accuracy']
)

In [None]:
from keras.callbacks import ModelCheckpoint
from keras.callbacks import LearningRateScheduler

# Directory on Drive that receives the model checkpoint.
log_d = MY_DRIVE_PATH + '/' + FOLDER_NAME + '/' + 'model'
os.makedirs(log_d, exist_ok=True)

def build_callbacks():
    """Return the callbacks for fit(): a checkpoint that keeps the best val_loss model."""
    best_model_checkpoint = ModelCheckpoint(
        filepath = log_d + '/best_model.h5',
        verbose=0,
        save_best_only=True,
        monitor='val_loss',
    )
    return [best_model_checkpoint]

# Train for 100 epochs, validating on the testing split after each epoch.
history = model.fit(
    x=input_dataset,
    validation_data=(validate_dataset),
    epochs=100,
    callbacks=build_callbacks(),
)

In [None]:
import matplotlib.pyplot as plt

# One figure per metric (accuracy first, then loss), each showing the training
# curve and the validation curve over the epochs.
history_plots = (
    ('accuracy', 'val_accuracy', 'model accuracy'),
    ('loss', 'val_loss', 'model loss'),
)
for train_key, validation_key, plot_title in history_plots:
  plt.plot(history.history[train_key])
  plt.plot(history.history[validation_key])
  plt.title(plot_title)
  plt.ylabel(train_key)
  plt.xlabel('epoch')
  plt.legend(['train', 'test'], loc='upper left')
  plt.show()

In [None]:
from tensorflow import keras
# Reload the checkpoint written by ModelCheckpoint (RAdam has to be passed as
# a custom object) and evaluate it on the validation dataset.
model = keras.models.load_model(log_d + '/best_model.h5', custom_objects={'RAdam': RAdam})
model.evaluate(validate_dataset)

In [None]:
# TFRecord format options: square patches of patch_size pixels, a file size
# cap of 104857600 (= 100 * 1024 * 1024) and compressed output.
image_export_options = {
    'patchDimensions': [patch_size, patch_size],
    'maxFileSize': 104857600,
    'compressed': True
}

# Export the buffered band stack (bands plus texture) to Drive as TFRecord
# patches for prediction.
image_task = ee.batch.Export.image.toDrive(
  image=out_image_texture,
  description='Image Export',
  fileNamePrefix= IMAGE_FILE_PREFIX,
  folder= FOLDER_NAME,
  scale=spatial_resolution,
  fileFormat='TFRecord',
  region=out_ext_buffer,
  formatOptions=image_export_options,
)

image_task.start()

In [None]:
# Check every 20 seconds until the image export task is no longer active.
# NOTE(review): the loop also ends when the task is inactive for another
# reason (e.g. failed); the task state is not inspected here.
while image_task.active():
  print('Polling for task (id: {}).'.format(image_task.id))
  time.sleep(20)
print('Done with image export.')

In [None]:
# Wait, then force a remount of Drive before the exported image files are listed.
time.sleep(20)
drive.mount('/content/drive', force_remount=True)

In [None]:
# Image patch files directly inside the run folder, in sorted order.
tfrecord_pattern = MY_DRIVE_PATH +'/' + FOLDER_NAME + '/' + '*.tfrecord.gz'
tfrecord_list = sorted(glob.glob(tfrecord_pattern))
print(tfrecord_list)

# JSON file(s) written alongside the patches.
json_pattern = MY_DRIVE_PATH +'/' + FOLDER_NAME + '/' + '*.json'
json_list = glob.glob(json_pattern)
print(json_list)

In [None]:
# Load the first JSON file found; later cells read 'patchDimensions',
# 'totalPatches', 'patchesPerRow' and 'projection' from it.
# NOTE(review): json_list[0] raises IndexError when no JSON file was found.
with open(json_list[0], "r") as read_file:
   mixer = json.load(read_file)

mixer

In [None]:
# Patch geometry as recorded in the mixer file.
patch_dimensions = mixer['patchDimensions']
patch_width = patch_dimensions[0]
patch_height = patch_dimensions[1]
patches = mixer['totalPatches']
patch_dimensions_flat = [patch_width*patch_height, 1]

# Every band of a patch is parsed as a float32 column with one value per pixel.
image_columns = [
    tf.io.FixedLenFeature(shape=patch_dimensions_flat, dtype=tf.float32)
    for _ in BANDS
]

image_features_dict = {band: spec for band, spec in zip(BANDS, image_columns)}

# Serialized patches from all exported files, in file-list order.
image_dataset = tf.data.TFRecordDataset(tfrecord_list, compression_type='GZIP')


In [None]:
def parse_image(example_proto):
  """Parse one serialized patch into a dict of band tensors using image_features_dict."""
  return tf.io.parse_single_example(example_proto, image_features_dict)

In [None]:
# Decode each serialized patch into a dict of band tensors.
image_dataset = image_dataset.map(parse_image, num_parallel_calls=5)

# Split every patch into per-pixel elements (one dict of band values each).
image_dataset = image_dataset.flat_map(
    lambda features: tf.data.Dataset.from_tensor_slices(features)
)

# Turn each per-pixel dict into a 1-tuple holding the stacked, transposed band
# values, mirroring the input built by to_tuple() for training.
image_dataset = image_dataset.map(
  lambda data_dict: (tf.transpose(list(data_dict.values())), )
)

# Re-group the pixels so that one batch corresponds to one patch.
image_dataset = image_dataset.batch(patch_width * patch_height)

In [None]:
# Run the model over all patches (one step per patch) and show the first prediction.
predictions = model.predict(image_dataset, steps=patches, verbose=1)
print(predictions[0])

In [None]:
# Write one record per patch to OUTPUT_IMAGE_TFR; each record holds the
# predicted class index of every pixel of that patch under 'prediction'.
patch_pixels = patch_width * patch_height

# Class index of every pixel computed in a single vectorised call. The
# previous version ran one tf.argmax op per pixel.
predicted_classes = np.argmax(predictions, axis=-1).ravel()

cur_patch = 1
# The context manager closes the writer even if a write fails part-way.
with tf.io.TFRecordWriter(OUTPUT_IMAGE_TFR) as writer:
  # Only complete patches are written, as before.
  for start in range(0, len(predicted_classes) - patch_pixels + 1, patch_pixels):
    patch_values = predicted_classes[start:start + patch_pixels].tolist()
    print('Done with Patch ' + str(cur_patch) + ' of ' + str(patches) + '...')
    example = tf.train.Example(
        features=tf.train.Features(
            feature={
                'prediction': tf.train.Feature(
                    int64_list=tf.train.Int64List(
                        value=patch_values
                    )
                )
            }
        )
    )

    writer.write(example.SerializeToString())
    cur_patch += 1

In [None]:
# Read the per-patch predictions back and mosaic them into one 2-D array:
# patches fill a row left to right, rows are stacked top to bottom.
record_iterator = tf.compat.v1.python_io.tf_record_iterator(path=OUTPUT_IMAGE_TFR)

n_col = mixer['patchesPerRow']
# Integer division: this is a count of patch rows (it used to be a float).
n_row = patches // n_col

strips = []       # completed horizontal strips, top to bottom
row_patches = []  # patches of the strip currently being filled
counter = 0

for string_record in record_iterator:
  example = tf.train.Example()
  example.ParseFromString(string_record)
  # NOTE(review): the reshape order (patch_width, patch_height) is kept from
  # the original; it only matters for non-square patches - confirm.
  values = np.array(example.features.feature['prediction'].int64_list.value).reshape(patch_width, patch_height).astype(np.int8)
  row_patches.append(values)

  if len(row_patches) == n_col:
    # Join the patches of a finished row once instead of growing the strip
    # with one concatenate call per patch.
    strips.append(np.concatenate(row_patches, axis=1))
    row_patches = []
    print(counter)
  counter += 1
print(counter)

if not strips:
  raise ValueError('No complete row of patches was read from ' + OUTPUT_IMAGE_TFR)

# Stack all strips in a single call; this avoids copying the growing image
# once per row.
image = np.concatenate(strips, axis=0)

In [None]:
# Predicted class indices run from 0 to N_CLASSES - 1. With vmax=1 every
# class above 1 was drawn in the same colour as class 1.
imgplot = plt.imshow(image, cmap=plt.get_cmap('jet'), vmin=0, vmax=N_CLASSES - 1)
plt.show()

In [None]:
# Reorder the mixer's affine matrix into the 6-tuple passed to
# SetGeoTransform() below.
affine = mixer['projection']['affine']['doubleMatrix']
geotransform = (affine[2], affine[0], affine[1], affine[5], affine[3], affine[4])

# Take everything after the last ':' as the EPSG code. Slicing the last four
# characters gave a wrong code for 5-digit values such as 'EPSG:32637'.
crs = int(mixer['projection']['crs'].split(':')[-1])
print(affine)
print(geotransform)
print(crs)

In [None]:
# Write the class map as a single-band byte GeoTIFF with the georeferencing
# derived from the mixer file.
ny, nx = image.shape
dst_ds = gdal.GetDriverByName('GTiff').Create(OUTPUT_IMAGE_TIF, nx, ny, 1, gdal.GDT_Byte)
if dst_ds is None:
  # Create() returns None on failure when GDAL exceptions are not enabled;
  # stop here instead of failing with an AttributeError on the next line.
  raise RuntimeError('Could not create ' + OUTPUT_IMAGE_TIF)

dst_ds.SetGeoTransform(tuple(geotransform))
srs = osr.SpatialReference()
srs.ImportFromEPSG(crs)
dst_ds.SetProjection(srs.ExportToWkt())
dst_ds.GetRasterBand(1).WriteArray(image)
dst_ds.FlushCache()
# Dropping the reference closes the dataset.
dst_ds = None

print("Exporting " + OUTPUT_IMAGE_TIF + " completed.")

In [None]:
# False-colour composite: the three 'FCCbands' of the buffered image.
fcc_prefix = city_name + '_' + str(year_begin) + '_' + str(year_end)+'_fcc'
fcc_image = out_image.select(landsat_params['FCCbands'])

# Export it to the run folder on Drive, limited to the unbuffered ROI.
fcc_task = ee.batch.Export.image.toDrive(
    image=fcc_image,
    description=fcc_prefix,
    folder=FOLDER_NAME,
    fileNamePrefix=fcc_prefix,
    scale=spatial_resolution,
    region=out_ext,
)
fcc_task.start()

In [None]:
# Check every 20 seconds until the FCC export task is no longer active.
# NOTE(review): the task state is not inspected, so a failed task also ends the loop.
while fcc_task.active():
  print('Polling for task (id: {}).'.format(fcc_task.id))
  time.sleep(20)
print('Done with fcc image export.')

In [None]:
# Wait, then force a remount of Drive before the exported FCC image is moved.
time.sleep(20)
drive.mount('/content/drive', force_remount=True)

In [None]:
# Sort the outputs inside the run folder: the class map (GeoTIFF and TFRecord)
# goes to 'lc_map', the false-colour composite to 'ref_map'.
lc_dir = MY_DRIVE_PATH + '/' + FOLDER_NAME + '/lc_map'
fcc_path = MY_DRIVE_PATH + '/' + FOLDER_NAME + '/'+ fcc_prefix+'.tif'
ref_dir = MY_DRIVE_PATH + '/' + FOLDER_NAME + '/ref_map'

os.makedirs(lc_dir, exist_ok = True)
os.makedirs(ref_dir, exist_ok = True)

shutil.move(OUTPUT_IMAGE_TIF, lc_dir)
shutil.move(OUTPUT_IMAGE_TFR, lc_dir)
shutil.move(fcc_path, ref_dir)

In [None]:
# Remount Drive, then move the whole run folder into the shared drive location.
drive.mount('/content/drive', force_remount=True)
shutil.move(FOLDER_PATH, SHARED_DRIVE_PATH)