In [1]:
import tensorflow as tf
from Unet import get_model

In [2]:
# Authenticate with the service-account key file, then copy the training
# TFRecord file and the landcover.zip archive from the bucket into the
# working directory.
!gcloud auth activate-service-account --key-file YOUR_BUCKET_KEY.json
!gsutil cp gs://YOUR_BUCKET_NAME/Landcover.tfrecord.gz .
!gsutil cp gs://YOUR_BUCKET_NAME/landcover.zip .
# -o overwrites already-extracted files without prompting, so re-running this
# cell does not stall on unzip's interactive "replace?" question.
!unzip -o landcover.zip

In [3]:
# Copy the evaluation TFRecord file from the bucket into the working directory.
!gsutil cp gs://YOUR_BUCKET_NAME/Landcover_Eval.tfrecord.gz .

In [5]:
# Build the network with get_model(), imported from the local Unet module.
# NOTE(review): whether get_model() also compiles the model is not visible
# here — confirm before relying on the training cell.
model = get_model()

In [6]:
# Display the model's symbolic output tensor to check its shape and dtype.
model.output

<tf.Tensor 'conv2d_91/truediv:0' shape=(None, None, None, 4) dtype=float32>

# Dataset Parsing Helpers

In [7]:
# Height and width of each patch stored in the TFRecord files.
KERNEL_SHAPE = [256, 256]

# Names of the per-example features, in the order to_tuple() stacks them:
# the first four become the label channels, the remaining six the inputs.
# Keeping a single list avoids the name list and the spec list drifting apart
# (e.g. one of them silently missing 'B4').
FEATURE_KEYS = ['water', 'dev', 'vegi', 'wetland', 'B4', 'B3', 'B2', 'B7', 'B6', 'B5']

# One fixed-length float32 spec of shape KERNEL_SHAPE per feature name.
COLUMNS = [
    tf.io.FixedLenFeature(shape=KERNEL_SHAPE, dtype=tf.float32)
    for _ in FEATURE_KEYS
]

# Parsing schema handed to tf.io.parse_single_example: feature name -> spec.
FEATURES_DICT = dict(zip(FEATURE_KEYS, COLUMNS))
def parse_tfrecord(example_proto):
    """Decode one serialized example using the FEATURES_DICT schema.

    Args:
        example_proto: A single serialized example, as yielded by the
            TFRecord dataset built in get_dataset().

    Returns:
        Whatever tf.io.parse_single_example returns for FEATURES_DICT:
        a dict keyed by the feature names in that schema.
    """
    parsed_features = tf.io.parse_single_example(example_proto, FEATURES_DICT)
    return parsed_features

def to_tuple(inputs):
    """Turn a parsed example into an (inputs, labels) pair of HWC tensors.

    Args:
        inputs: Mapping from feature name to a 2-D tensor, as produced by
            parse_tfrecord().

    Returns:
        A 2-tuple ``(bands, labels)``. ``bands`` holds the last six stacked
        features (B4, B3, B2, B7, B6, B5) and ``labels`` the first four
        (water, dev, vegi, wetland), both with channels as the last axis.
    """
    ordered_names = ('water', 'dev', 'vegi', 'wetland', 'B4', 'B3', 'B2', "B7", "B6", "B5")
    planes = []
    for name in ordered_names:
        planes.append(inputs.get(name))

    # Stacking on axis 0 gives channels-first (CHW); move channels last (HWC).
    channels_first = tf.stack(planes, axis=0)
    channels_last = tf.transpose(channels_first, [1, 2, 0])

    labels = channels_last[:, :, :4]
    bands = channels_last[:, :, 4:]
    return bands, labels

def get_dataset(filename):
    """Build an unbatched dataset of (inputs, labels) pairs from a TFRecord file.

    Args:
        filename: Path of a GZIP-compressed TFRecord file.

    Returns:
        A tf.data dataset whose elements are the tuples produced by
        to_tuple() applied to each record parsed by parse_tfrecord().
    """
    records = tf.data.TFRecordDataset(filename, compression_type='GZIP')
    # Decode each serialized record, then split it into (inputs, labels).
    return records.map(parse_tfrecord).map(to_tuple)

In [8]:
# Number of examples per batch, shared by the training and evaluation
# pipelines so the two cannot drift apart.
BATCH_SIZE = 10

dataset = get_dataset('Landcover.tfrecord.gz').batch(BATCH_SIZE)
eval_dataset = get_dataset('Landcover_Eval.tfrecord.gz').batch(BATCH_SIZE)

In [9]:
# Sanity check: look at one batch only and print the shapes of its input
# and label tensors.
for features, labels in dataset.take(1):
    print(features.shape)
    print(labels.shape)

(10, 256, 256, 6)
(10, 256, 256, 4)


In [4]:
# Model.fit accepts tf.data datasets directly; fit_generator is deprecated
# and no longer exists in newer Keras releases.
model.fit(dataset, validation_data=eval_dataset, epochs=50)

In [None]:
# Write the model to the ./landcover/ directory.
model.save("./landcover/")

INFO:tensorflow:Assets written to: ./landcover/assets


In [None]:
# zip -r updates an existing archive in place, and a landcover.zip downloaded
# earlier may still be in the working directory. Remove it first so the new
# archive contains only the files currently under ./landcover.
!rm -f landcover.zip && zip -r landcover.zip ./landcover

  adding: landcover/ (stored 0%)
  adding: landcover/variables/ (stored 0%)
  adding: landcover/variables/variables.index (deflated 79%)
  adding: landcover/variables/variables.data-00000-of-00001 (deflated 8%)
  adding: landcover/saved_model.pb (deflated 92%)
  adding: landcover/assets/ (stored 0%)
