<a href="https://colab.research.google.com/github/goodakai/GEE/blob/main/Build%2C_train%2C_and_save_the_U_Net_model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# python code for mapping leucogranite based on GEE
# author: Dakai Guo,Ziye Wang
# contact: Ziye Wang (Email: ziyewang@cug.edu.cn)

In [None]:
# Mount Google Drive inside the Colab runtime, then authenticate and
# initialize the Google Earth Engine (GEE) Python API.
from google.colab import drive
drive.mount('/content/drive')
import ee
ee.Authenticate()
ee.Initialize(project='my-project')# placeholder: replace 'my-project' with your own project

In [None]:
# Import necessary packages
import tensorflow as tf
from tensorflow.keras import layers, models

In [None]:
# Google Cloud project and Cloud Storage bucket (placeholders to be replaced)
PROJECT = 'GCP project ID'
OUTPUT_BUCKET = 'GCS BUCKET ID'

# Predictor bands fed to the network and the band that carries the label
INPUT_BANDS = ['a3n', 'a04', 'a06', 'a07', 's11']
OUTPUT_BANDS = ['lithology']
N_CLASSES = len(OUTPUT_BANDS)
BANDS = [*INPUT_BANDS, *OUTPUT_BANDS]

# Every band is parsed from the TFRecords as a fixed-size 256 x 256 float32 patch
COLUMNS = [tf.io.FixedLenFeature(shape=[256, 256], dtype=tf.float32) for _ in BANDS]

# Band name -> parsing spec, as expected by tf.io.parse_single_example
FEATURES_DICT = {band: column for band, column in zip(BANDS, COLUMNS)}

In [None]:
# Load the training-label vector asset stored in GEE
labelfc = ee.FeatureCollection("Asset ID")

# Rasterize the 'Lithology' property of the label features (mean reducer)
label_im = labelfc.reduceToImage(['Lithology'], 'mean')

# The training area is the bounding box of the label geometries
aoi = labelfc.geometry().bounds()

# Load the pre-processed remote sensing image and keep only the predictor bands
rsimage = ee.Image("Asset ID").select(INPUT_BANDS)

# Append the label raster (band 'mean' renamed to 'lithology') and clip to the AOI
trainimage = rsimage.addBands(
    label_im.select(['mean'], ['lithology'])
).clip(aoi)

# Print the image metadata to check the band list before exporting
print(trainimage.getInfo())

In [None]:
# Configuration of the task that exports the training image to Cloud Storage
# as gzip-compressed TFRecord patches.
task_config = {
    'image': trainimage,
    'description': '5b_tirtraindata256',
    'bucket': OUTPUT_BUCKET,
    'fileNamePrefix': '5b_tirtraining256',
    'region': aoi,
    'scale': 30,
    'fileFormat': 'TFRecord',
    'maxPixels': 1e13,
    'formatOptions': {
        'patchDimensions': [32, 32],# Core (unbuffered) size of one exported tile
        'kernelSize': [224, 224],# Buffer size of one tile; presumably chosen so that 32 + 224 = 256 matches the [256, 256] feature shape parsed later - confirm
        'compressed': True,
        'maxFileSize': 104857600 * 10
    }
}

In [None]:
# Create the image-to-Cloud-Storage export task from task_config and start it
task = ee.batch.Export.image.toCloudStorage(**task_config)
task.start()

In [None]:
# Print the current status of the export task; re-run this cell to poll it
print(task.status())

In [None]:
# Load training data from GCS and Augment
class Augment(tf.keras.layers.Layer):
    """Geometric augmentation that expands a patch into six variants.

    For the given image(s) and the matching label(s) the layer produces, in
    this order: the original, a left-right flip, an up-down flip, and
    rotations by 90, 180 and 270 degrees. Inputs and labels receive exactly
    the same transforms so they stay pixel-aligned.
    """

    def __init__(self):
        super().__init__()

    def call(self, inputs, labels):
        """Return the augmented ``(inputs, labels)`` pair.

        Args:
            inputs: image tensor, either a single ``(H, W, C)`` patch or a
                batch ``(N, H, W, C)``.
            labels: label tensor with the same layout as ``inputs``.

        Returns:
            A tuple of two rank-4 tensors whose leading axis enumerates the
            variants (6 per input patch), ready for ``Dataset.unbatch()``.
            Patches must be square, otherwise the rotated variants cannot be
            concatenated with the unrotated ones.
        """
        # A single patch gets a leading axis first. Concatenating rank-3
        # tensors on axis 0 would glue the variants together along the height
        # axis, and a later unbatch() would then emit image rows, not images.
        if inputs.shape.rank == 3:
            inputs = tf.expand_dims(inputs, axis=0)
        if labels.shape.rank == 3:
            labels = tf.expand_dims(labels, axis=0)

        def build_variants(images):
            # Original plus flips along the X and Y axes
            variants = [
                images,
                tf.image.flip_left_right(images),
                tf.image.flip_up_down(images),
            ]
            # Three rotations in 90-degree steps: k = 1, 2, 3
            variants.extend(tf.image.rot90(images, k=k) for k in range(1, 4))
            # Merge all variants along the leading axis
            return tf.concat(variants, axis=0)

        return build_variants(inputs), build_variants(labels)

# Parses a single TFRecord example into a dictionary of features
def parse_tfrecord(example_proto):
    """Decode one serialized example into a dict of tensors keyed by band name.

    The parsing schema is the module-level ``FEATURES_DICT``.
    """
    parsed = tf.io.parse_single_example(
        serialized=example_proto,
        features=FEATURES_DICT,
    )
    return parsed

def to_tuple(inputs):
    """Turn a parsed feature dict into an ``(image, label)`` pair.

    Args:
        inputs: dict mapping band names to per-band tensors.

    Returns:
        A tuple ``(image, label)``: the ``INPUT_BANDS`` tensors stacked along
        a new last axis, and the first ``OUTPUT_BANDS`` tensor cast to float32
        with a trailing singleton channel axis.
    """
    # Stacking on a new last axis equals expanding each band and concatenating
    image = tf.stack([inputs[band] for band in INPUT_BANDS], axis=-1)

    # Only the first output band is used as the label
    label = tf.cast(inputs[OUTPUT_BANDS[0]], tf.float32)[..., tf.newaxis]

    return image, label

def filter_black_borders(inputs, labels):
    """Predicate for ``Dataset.filter``: keep only patches without empty pixels.

    A pixel counts as empty when it is zero in every channel. ``labels`` is
    not inspected; it is accepted only because the dataset elements are
    ``(inputs, labels)`` pairs.

    Returns:
        Scalar boolean tensor, True when no pixel of ``inputs`` is empty.
    """
    # True where a pixel is zero across the whole channel axis
    empty_pixel = tf.reduce_all(tf.equal(inputs, 0), axis=-1)
    # The patch passes only if not a single pixel is empty
    return tf.logical_not(tf.reduce_any(empty_pixel))

def get_dataset(pattern, batch_size):
    """Build the batched, augmented training pipeline from TFRecord files.

    Args:
        pattern: file pattern of the GZIP-compressed TFRecord files.
        batch_size: number of samples per emitted batch.

    Returns:
        A ``tf.data.Dataset`` of ``(inputs, labels)`` batches.
    """
    augment = Augment()

    def read_gzip_records(filename):
        return tf.data.TFRecordDataset(filename, compression_type='GZIP')

    records = tf.data.Dataset.list_files(pattern).interleave(read_gzip_records)

    # Parse, assemble (image, label) pairs, drop patches with empty pixels,
    # then cache the cleaned patches before any randomness is applied.
    patches = (
        records
        .map(parse_tfrecord, num_parallel_calls=tf.data.AUTOTUNE)
        .map(to_tuple, num_parallel_calls=tf.data.AUTOTUNE)
        .filter(filter_black_borders)
        .cache()
        .shuffle(512)
    )

    # NOTE(review): augmentation runs after the shuffle, so the variants of one
    # patch come out consecutively once unbatched - confirm this is intended.
    augmented = patches.map(
        lambda inputs, labels: augment(inputs, labels),
        num_parallel_calls=tf.data.AUTOTUNE,
    )

    return (
        augmented
        .unbatch()
        .batch(batch_size)
        .prefetch(buffer_size=tf.data.AUTOTUNE)
    )

# Create the training dataset with a batch size of 6
batch_size = 6
training_dataset = get_dataset("Path of training data", batch_size)

# Inspect the first batch of the training dataset: print the dtype and shape
# of the inputs and of the labels as a quick sanity check of the pipeline.
for inputs, outputs in training_dataset.take(1):
    print("inputs:")
    print(f"  {inputs.dtype.name} {inputs.shape}")
    print(f"outputs: {outputs.dtype.name} {outputs.shape}")


In [None]:
# Calculate the sizes of the training and validation sets (in batches).
# The batches are counted by streaming through the dataset once, so the whole
# dataset is never held in memory at the same time.
DATASET_SIZE = sum(1 for _ in training_dataset)
train_size = int(0.8 * DATASET_SIZE)
# Use the remainder rather than int(0.2 * DATASET_SIZE): the two truncated
# products can add up to DATASET_SIZE - 1, which would silently drop a batch.
val_size = DATASET_SIZE - train_size

# NOTE(review): training_dataset is built with shuffle(), which by default
# reshuffles on every iteration, so take()/skip() may select different batches
# in each epoch and training and validation data can overlap - confirm whether
# a fixed split (e.g. shuffle(..., reshuffle_each_iteration=False)) is needed.

# First 80% of the batches: training set
train_dataset = training_dataset.take(train_size)

# Remaining batches: validation set
val_dataset = training_dataset.skip(train_size).take(val_size)

In [None]:
# Building the U-Net model
def get_unet_model(input_shape, num_classes):
    """Build a U-Net with four down-sampling and four up-sampling stages.

    Args:
        input_shape: shape of one input patch as ``[height, width, channels]``.
            Only the channel count (last entry) is used; height and width are
            left unspecified so the same network can be applied to tiles of
            another size. Because of the four 2x2 poolings, height and width
            must be multiples of 16 for the skip connections to line up.
        num_classes: number of output channels of the final 1x1 convolution.

    Returns:
        An uncompiled ``tf.keras.Model`` whose output has ``num_classes``
        sigmoid-activated channels at the input resolution.
    """
    # Take the channel count from the argument instead of a module-level global
    inputs = tf.keras.Input(shape=[None, None, input_shape[-1]])

    def double_conv(tensor, filters):
        # Two consecutive 3x3 ReLU convolutions with 'same' padding
        for _ in range(2):
            tensor = layers.Conv2D(
                filters, (3, 3), activation="relu", padding="same")(tensor)
        return tensor

    # Encoder: keep the output of each stage for the skip connections
    skips = []
    x = inputs
    for filters in (32, 64, 128, 256):
        x = double_conv(x, filters)
        skips.append(x)
        x = layers.MaxPooling2D(pool_size=(2, 2))(x)

    # Bottleneck
    x = double_conv(x, 512)

    # Decoder: upsample, concatenate with the matching encoder stage, convolve
    for filters, skip in zip((256, 128, 64, 32), reversed(skips)):
        upsampled = layers.UpSampling2D(size=(2, 2))(x)
        x = layers.concatenate([upsampled, skip])
        x = double_conv(x, filters)

    # Per-pixel prediction head
    outputs = layers.Conv2D(num_classes, (1, 1), activation="sigmoid")(x)

    return models.Model(inputs, outputs=outputs)


model = get_unet_model([256, 256, len(INPUT_BANDS)], 1)

In [None]:
# Set the learning rate and create the Adam optimizer
learning_rate = 0.0001
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
# Compile the model with binary cross-entropy loss and the accuracy metric.
# NOTE(review): binary cross-entropy assumes label values in [0, 1] - confirm
# that the 'lithology' band satisfies this.
model.compile(
    optimizer=optimizer,
    loss="binary_crossentropy",
    metrics=['accuracy']
)
# Train the model for 200 epochs, evaluating on val_dataset after each epoch
model.fit(
    train_dataset,
    validation_data=val_dataset,
    epochs=200,
)

In [None]:
# Deserialize and package the trained model
class DeSerializeInput(tf.keras.layers.Layer):
  """Layer that decodes base64-encoded serialized tensors back into float32.

  It receives a dict of string tensors and returns a dict with the same keys
  whose values are the decoded float32 tensors.
  """

  def __init__(self, **kwargs):
    super().__init__(**kwargs)

  def call(self, inputs_dict):
    def decode_batch(encoded):
      # Undo the base64 encoding, then parse every element into float32
      raw = tf.io.decode_base64(encoded)
      return tf.map_fn(
          lambda item: tf.io.parse_tensor(item, tf.float32),
          raw,
          fn_output_signature=tf.float32)

    decoded = {}
    for key, value in inputs_dict.items():
      decoded[key] = decode_batch(value)
    return decoded

  def get_config(self):
    # No extra constructor arguments: the base configuration is sufficient
    return super().get_config()


class ReSerializeOutput(tf.keras.layers.Layer):
  """Layer that serializes each output element and base64-encodes it.

  It is the counterpart of the input decoding step: every element along the
  first axis of the output tensor becomes one base64 string.
  """

  def __init__(self, **kwargs):
    super().__init__(**kwargs)

  def call(self, output_tensor):
    def encode_one(element):
      # Serialize the tensor first, then base64-encode the resulting bytes
      serialized = tf.io.serialize_tensor(element)
      return tf.io.encode_base64(serialized)

    return tf.map_fn(
        encode_one,
        output_tensor,
        fn_output_signature=tf.string)

  def get_config(self):
    # No extra constructor arguments: the base configuration is sufficient
    return super().get_config()

# Wrap the trained model so that it accepts and returns base64-encoded
# serialized tensors instead of raw float arrays.
input_deserializer = DeSerializeInput()
# Despite its name, this layer re-serializes (encodes) the model output
output_deserializer = ReSerializeOutput()

# One scalar string input per example, keyed by the name of the model's input
serialized_inputs = {
    model.inputs[0].name: tf.keras.Input(shape=[], dtype='string', name='array_image')
}

# Decode -> predict -> encode, then package the whole chain as a new model
updated_model_input = input_deserializer(serialized_inputs)
updated_model = tf.keras.Model(
    serialized_inputs,
    output_deserializer(model(updated_model_input)),
)

In [None]:
# Name under which the wrapped model is stored
MODEL_NAME = 'MODEL_NAME'
# Destination in Cloud Storage; replace your-bucket with your bucket name.
MODEL_DIR = f'gs://your-bucket/{MODEL_NAME}'
# Save the wrapped (de/re-serializing) model to MODEL_DIR
updated_model.save(MODEL_DIR)