In [None]:
import os
import glob
import time
import random
import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt

import config

from utils.img_aug import img_aug
from utils.imgShow import imgShow, imsShow
from utils.tfrecord_io import parse_image,parse_shape,toPatchPair, image_example
from utils.geotif_io import readTiff, writeTiff
from utils.path_io import crop_patch

from model.seg_model.watnet import watnet
from model.seg_model.deeplabv3_plus import deeplabv3_plus
from model.seg_model.resunet32 import ResUNet34
from model.seg_model.resnet18 import ResNet18


## TF Records
### _Write the tfrecord data_

In [None]:
# Output TFRecord files, and the input GeoTIFF scene/truth pairs to serialize.
path_tra_tfrecord_scene = 'data/tfrecords/train_scene.tfrecords'
path_test_tfrecord_scene = 'data/tfrecords/test_scene.tfrecords'

# Scenes and truths are paired purely by sorted filename order, so each pair of
# folders must contain the same number of files in corresponding order.
# zip() silently truncates to the shorter list; fail loudly instead so a missing
# or extra file is caught rather than silently dropping scenes.
tra_scene_paths = sorted(glob.glob('data/train_test_data/train_scene/*.tif'))
tra_truth_paths = sorted(glob.glob('data/train_test_data/train_truth/*.tif'))
assert len(tra_scene_paths) == len(tra_truth_paths), (
    f'train scene/truth count mismatch: {len(tra_scene_paths)} vs {len(tra_truth_paths)}')
tra_pair_data = list(zip(tra_scene_paths, tra_truth_paths))
print('tra_data length:', len(tra_pair_data))

test_scene_paths = sorted(glob.glob('data/train_test_data/test_scene/*.tif'))
test_truth_paths = sorted(glob.glob('data/train_test_data/test_truth/*.tif'))
assert len(test_scene_paths) == len(test_truth_paths), (
    f'test scene/truth count mismatch: {len(test_scene_paths)} vs {len(test_truth_paths)}')
test_pair_data = list(zip(test_scene_paths, test_truth_paths))
print('test_data length:', len(test_pair_data))

In [None]:
# Training data: write one serialized example per (scene, truth) pair to a `.tfrecords` file.
# TFRecordWriter does not create missing parent directories, so create them first.
_tra_out_dir = os.path.dirname(path_tra_tfrecord_scene)
if _tra_out_dir:
    os.makedirs(_tra_out_dir, exist_ok=True)
with tf.io.TFRecordWriter(path_tra_tfrecord_scene) as writer:
    for path_scene, path_truth in tra_pair_data:
        scene, _ = readTiff(path_scene)
        truth, _ = readTiff(path_truth)
        # Scale raw pixel values by 1/10000 and clip into [0, 1]
        # (NOTE(review): presumably a reflectance scale factor — confirm).
        scene = np.clip(scene / 10000, 0, 1)
        patch, truth = crop_patch(img=scene, truth=truth)
        tf_example = image_example(patch, truth)
        writer.write(tf_example.SerializeToString())
        print(path_scene)

In [None]:
# Test data: write one serialized example per (scene, truth) pair to the test `.tfrecords` file.
# TFRecordWriter does not create missing parent directories, so create them first.
_test_out_dir = os.path.dirname(path_test_tfrecord_scene)
if _test_out_dir:
    os.makedirs(_test_out_dir, exist_ok=True)
with tf.io.TFRecordWriter(path_test_tfrecord_scene) as writer:
    for path_scene, path_truth in test_pair_data:
        scene, _ = readTiff(path_scene)
        truth, _ = readTiff(path_truth)
        # Same scaling as the training data: divide by 10000 and clip into [0, 1].
        scene = np.clip(scene / 10000, 0, 1)
        patch, truth = crop_patch(img=scene, truth=truth)
        tf_example = image_example(patch, truth)
        writer.write(tf_example.SerializeToString())
        print(path_scene)

## Data loading
### _Load and parse the tfrecord data_

In [None]:
### data loading from .tfrecord file
# Same locations the writer cells use. Forward slashes are valid on every OS,
# whereas the previous raw backslash strings only resolved on Windows.
path_train_data_scene = 'data/tfrecords/train_scene.tfrecords'
path_test_data_scene = 'data/tfrecords/test_scene.tfrecords'

## training data
# cache() sits before toPatchPair/img_aug: the parsed scenes are cached, while
# patch extraction and augmentation are re-executed on every pass over the data.
tra_dset = (
    tf.data.TFRecordDataset(path_train_data_scene)
    .map(parse_image)
    .map(parse_shape)
    .cache()
    .map(toPatchPair)
    .map(img_aug)
    .shuffle(config.buffer_size)
    .batch(config.batch_size)
)

## Test data
# cache() sits after toPatchPair: test patches are computed once on the first
# pass and reused unchanged for every later evaluation pass.
test_batch = 8  # number of test samples per batch
test_dset = (
    tf.data.TFRecordDataset(path_test_data_scene)
    .map(parse_image)
    .map(parse_shape)
    .map(toPatchPair)
    .cache()
    .batch(test_batch)
)


In [None]:
## check: one full pass over the training dataset — count batches and scenes,
## then display the first sample of the last batch.
start = time.time()
i_batch = i_scene = 0
for patch, truth in tra_dset:
    i_batch += 1
    i_scene += patch.shape[0]
# Stop the clock before plotting: plt.show() can block and would inflate the timing.
elapsed = time.time() - start

# Only plot when the loop actually ran; otherwise `patch`/`truth` would be
# undefined or, worse, stale values left over from another cell.
if i_batch > 0:
    imsShow(img_list=[patch[0], truth[0]],
            img_name_list=['patch', 'truth'],
            clip_list=[2, 0])
    plt.show()
else:
    print('dataset is empty: nothing to display')

print('number of batches:', i_batch)
print('num of scenes:', i_scene)
print('time:', elapsed)


In [None]:
## check: one full pass over the test dataset — count batches and scenes,
## then display the first sample of the last batch.
start = time.time()
i_batch = i_scene = 0
for patch, truth in test_dset:
    i_batch += 1
    i_scene += patch.shape[0]
# Stop the clock before plotting: plt.show() can block and would inflate the timing.
elapsed = time.time() - start

# Only plot when the loop actually ran; otherwise `patch`/`truth` could still
# hold the *training* batch from the previous check cell and be shown mislabeled.
if i_batch > 0:
    imsShow(img_list=[patch[0], truth[0]],
            img_name_list=['patch', 'truth'],
            clip_list=[2, 0])
    plt.show()
else:
    print('dataset is empty: nothing to display')

print('number of batches:', i_batch)
print('num of scenes:', i_scene)
print('time:', elapsed)

# Model Training

In [None]:
## model configuration
# Every candidate architecture below is built for the same (patch, patch, bands) input.
model_input_shape = (config.patch_size, config.patch_size, config.num_bands)
# Alternatives:
#   model = watnet(input_shape=model_input_shape, nclasses=2)
#   model = deeplabv3_plus(nclasses=2, input_shape=model_input_shape)
#   model = ResUNet34(input_shape=model_input_shape)
model = ResNet18(input_shape=model_input_shape)


In [None]:
'''------1. train step------'''
@tf.function
def train_step(model, loss_fun, optimizer, x, y):
    """Run one optimisation step on a single batch and update the training metrics.

    Args:
        model: Keras model, called with ``training=True``.
        loss_fun: callable ``loss_fun(y_true, y_pred)`` returning the loss.
        optimizer: optimizer whose ``apply_gradients`` updates the weights.
        x: input batch.
        y: ground-truth batch.

    Returns:
        ``(config.tra_loss.result(), config.tra_oa.result(), config.tra_miou.result())``
        after this batch has been folded into each metric.
    """
    with tf.GradientTape() as tape:
        y_pred = model(x, training=True)
        batch_loss = loss_fun(y, y_pred)
    # Read the variable list only after the forward pass, so models that build
    # their weights lazily on first call are handled too.
    variables = model.trainable_weights
    gradients = tape.gradient(batch_loss, variables)
    optimizer.apply_gradients(zip(gradients, variables))

    config.tra_loss.update_state(batch_loss)
    config.tra_oa.update_state(y, y_pred)
    config.tra_miou.update_state(y, y_pred)
    tra_metrics = (config.tra_loss, config.tra_oa, config.tra_miou)
    return tuple(metric.result() for metric in tra_metrics)

'''------2. test step------'''
@tf.function
def test_step(model, loss_fun, x, y):
    """Evaluate a single batch and update the test metrics.

    Args:
        model: Keras model, called with ``training=False``.
        loss_fun: callable ``loss_fun(y_true, y_pred)`` returning the loss.
        x: input batch.
        y: ground-truth batch.

    Returns:
        ``(config.test_loss.result(), config.test_oa.result(), config.test_miou.result())``
        after this batch has been folded into each metric.
    """
    # No GradientTape here: nothing in evaluation computes gradients, so the
    # tape the original opened was dead code.
    y_pre = model(x, training=False)
    loss = loss_fun(y, y_pre)
    config.test_loss.update_state(loss)
    config.test_oa.update_state(y, y_pre)
    config.test_miou.update_state(y, y_pre)
    return config.test_loss.result(), config.test_oa.result(), config.test_miou.result()

'''------3. train loops------'''
def train_loops(model, loss_fun, optimizer, tra_dset, test_dset, epochs, output, vis_interval=10):
    """Train `model` for `epochs` epochs, evaluating on `test_dset` after each one.

    Args:
        model: Keras model to train.
        loss_fun: callable ``loss_fun(y_true, y_pred)`` passed to the step functions.
        optimizer: optimizer passed to ``train_step``.
        tra_dset: batched training dataset yielding ``(x, y)`` pairs.
        test_dset: batched test dataset yielding ``(x, y)`` pairs.
        epochs: number of epochs to run; also the total shown in the log and
            the epoch after which the model and metrics are saved.
        output: run name. After the final epoch the model is saved to
            ``model/pretrained/<output>.h5py`` and the per-epoch metrics to
            ``results/metrics_<output>.csv``.
        vis_interval: plot a random test prediction every ``vis_interval``
            epochs, starting with the first. Defaults to 10.

    Returns:
        Six per-epoch lists: test mIoU, test loss, test OA, train mIoU,
        train loss, train OA.

    Raises:
        ValueError: if either dataset yields no batches in an epoch.
    """
    # Per-epoch metric history.
    test_miou, test_loss, test_oa, tra_miou, tra_loss, tra_oa = [], [], [], [], [], []
    all_metrics = (config.tra_loss, config.tra_oa, config.tra_miou,
                   config.test_loss, config.test_oa, config.test_miou)
    log_format = ('Ep {}/{}: traLoss:{:.3f},traOA:{:.3f},traMIoU:{:.3f},'
                  'testLoss:{:.3f},testOA:{:.3f},testMIoU:{:.3f},time:{:.1f}s')

    # The metrics are shared objects in `config`; clear any state left by an
    # earlier (e.g. interrupted) run so it is not folded into the first epoch.
    for metric in all_metrics:
        metric.reset_states()

    for epoch in range(epochs):
        start = time.time()
        ###--- train the model ---
        tra_results = None
        for x_batch, y_batch in tra_dset:
            tra_results = train_step(model, loss_fun, optimizer, x_batch, y_batch)
        ### --- test the model ---
        test_results = None
        for x_batch, y_batch in test_dset:
            test_results = test_step(model, loss_fun, x_batch, y_batch)
        if tra_results is None or test_results is None:
            raise ValueError('tra_dset and test_dset must each yield at least one batch per epoch')
        # Each step returns metric.result() after update_state, i.e. the value
        # accumulated since the last reset, so the last batch's values are the epoch's.
        tra_loss_epoch, tra_oa_epoch, tra_miou_epoch = tra_results
        test_loss_epoch, test_oa_epoch, test_miou_epoch = test_results

        ### --- reset the metrics for the next epoch ---
        for metric in all_metrics:
            metric.reset_states()
        print(log_format.format(epoch + 1, epochs, tra_loss_epoch, tra_oa_epoch, tra_miou_epoch,
                                test_loss_epoch, test_oa_epoch, test_miou_epoch, time.time() - start))
        test_miou.append(test_miou_epoch.numpy())
        test_loss.append(test_loss_epoch.numpy())
        test_oa.append(test_oa_epoch.numpy())
        tra_miou.append(tra_miou_epoch.numpy())
        tra_loss.append(tra_loss_epoch.numpy())
        tra_oa.append(tra_oa_epoch.numpy())

        # Save once, after the final epoch. Compare against `epochs` (not
        # config.epochs) so saving also works when a different count is passed.
        if epoch + 1 == epochs:
            os.makedirs('model/pretrained', exist_ok=True)
            path_save = 'model/pretrained/' + output + '.h5py'
            # NOTE(review): Keras chooses the save format from the file extension;
            # '.h5py' is neither '.h5' nor '.keras' — confirm the intended format.
            model.save(path_save)
            os.makedirs('results', exist_ok=True)
            metric_path = 'results/metrics_' + output + '.csv'
            dataframe = pd.DataFrame({'test_miou': test_miou, 'test_oa': test_oa, 'test_loss': test_loss,
                                      'tra_miou': tra_miou, 'tra_oa': tra_oa, 'tra_loss': tra_loss})
            dataframe.to_csv(metric_path, index=False, sep=',')

        ## --- visualize the results ---
        if epoch % vis_interval == 0:
            for test_patch, test_truth in test_dset.take(1):
                # Index within the batch actually returned; it can be smaller than
                # the configured test batch size (e.g. a short dataset).
                i = np.random.randint(test_patch.shape[0])
                plt.figure(figsize=(10, 4))
                pre = model(test_patch, training=False)
                imsShow(img_list=[test_patch.numpy()[i], test_truth.numpy()[i], pre.numpy()[i]],
                        img_name_list=['test_patch', 'test_truth', 'prediction'],
                        clip_list=[2, 0, 0],
                        color_bands_list=None)
                plt.show()

    return test_miou, test_loss, test_oa, tra_miou, tra_loss, tra_oa


In [None]:
## Model training
# Ops are placed on '/device:GPU:1'. Other losses referenced in earlier runs:
# config.loss_focal, config.loss_focal_dice, config.loss_bce_dice, config.loss_bce.
with tf.device('/device:GPU:1'):
    test_miou, test_loss, test_oa, tra_miou, tra_loss, tra_oa = train_loops(
        model=model,
        loss_fun=config.loss_dice,
        optimizer=config.opt_adam,
        tra_dset=tra_dset,
        test_dset=test_dset,
        epochs=config.epochs,
        output='test_resnet_dice',
    )
