In [None]:
!pip install git+https://github.com/tdrobbins/unet.git
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import unet
from glob import glob
import numpy as np
import sen2lte_utils as utils

# Parameters

In [None]:
# Session settings
TAG  = "ca_1280-resize_256crop"  # descriptive name for the run; also the log sub-directory name
SEED = 1                         # random seed for the train/dev split and data shuffling; also seeds numpy below

# Preprocessing parameters
RESIZE    = 1280    # downsample raw satellite image to RESIZE x RESIZE
CROP      = 256     # split downsampled image into CROP x CROP images
N_CLASSES = 5       # [2, 5, or 10] how many access classes to train
STATES    = ["ca"]  # which states' data to use (e.g. "ca", "co", or any other available state)

TRAIN_SPLIT   = 0.8     # proportion of data to use for training
CLASS_WEIGHTS = "auto"  # "auto" (rounded inverse class frequency), an explicit per-class list, or None
NORMALIZE     = False   # whether the data generators apply featurewise centering / std normalization

# Model parameters
CONTINUE      = None    # [None, "/path/model"] previous model to continue fitting (overrides model settings)
LEARNING_RATE = 0.0001
LAYER_DEPTH   = 5
ROOT_FILTERS  = 64
DROPOUT       = 0.0

# Training parameters
EPOCHS     = 250
BATCH_SIZE = 16
TRANSFER   = None       # [None, "/path/weights"] weights to load into the model before fitting

# Output path settings
SAVE_ROOT  = "/kaggle/working/logs/"
SAVE_PATH  = SAVE_ROOT + TAG + "/"
CP_PATH    = SAVE_PATH + "checkpoint/"
MODEL_PATH = SAVE_PATH + "model/"

# Seed numpy's global RNG so any np.random-based sampling later in the notebook
# is reproducible under Restart & Run All.
np.random.seed(SEED)


# Data loading and preprocessing

In [None]:
# Load input images (satellite composites + cell-tower images) and LTE masks for
# every state in STATES, then stack all states into single X / Y arrays.
DATA_ROOT = "/kaggle/input/sen2lte/Data"

X_parts, Y_parts = [], []
sat_parts, cell_parts, lte_parts = [], [], []
for s in STATES:
    state_sat = np.sort(glob("{}/sentinel2/{}/composite/*.tif".format(DATA_ROOT, s)))
    state_cell = np.sort(glob("{}/celltowers/sentinel2/{}/*.jpg".format(DATA_ROOT, s)))
    state_lte = np.sort(glob("{}/fcc477actual/sentinel2/{}/cat{}*.jpg".format(DATA_ROOT, s, N_CLASSES)))

    # The three sorted lists are treated as paired by position (the example plot at the
    # end indexes all three with the same indices), so fail loudly on a missing glob or
    # a count mismatch instead of silently misaligning inputs and masks.
    assert len(state_sat) > 0, "no satellite composites found for state {!r}".format(s)
    assert len(state_sat) == len(state_cell) == len(state_lte), (
        "file count mismatch for state {!r}: {} composites, {} cell-tower images, {} masks".format(
            s, len(state_sat), len(state_cell), len(state_lte)))

    X_parts.append(utils.load_sat_imgs(state_sat, state_cell, resize=RESIZE, crop_size=CROP))
    Y_parts.append(utils.load_masks(state_lte, N_CLASSES, crop_size=CROP, resize=RESIZE, onehot=True))
    sat_parts.append(state_sat)
    cell_parts.append(state_cell)
    lte_parts.append(state_lte)

X = np.concatenate(X_parts)
Y = np.concatenate(Y_parts)

# File lists across ALL states (not just the last one iterated); kept as numpy arrays
# so they can be fancy-indexed when plotting uncropped examples.
sat_files = np.concatenate(sat_parts)
cell_files = np.concatenate(cell_parts)
lte_files = np.concatenate(lte_parts)

assert len(X) == len(Y), "inputs ({}) and masks ({}) have different sample counts".format(len(X), len(Y))
m, nx, ny, nchannels = X.shape

In [None]:
# Splitting into train and dev sets
(X_train, Y_train), (X_dev, Y_dev) = utils.split_train_dev(X, Y, split=TRAIN_SPLIT, batch_size=BATCH_SIZE, seed=SEED)

# Resolve CLASS_WEIGHTS into a per-class weight vector WITHOUT overwriting the config
# constant, so this cell can be re-run (an ndarray compared to "auto" is ambiguous).
if isinstance(CLASS_WEIGHTS, str) and CLASS_WEIGHTS == "auto":
    # Rounded inverse-frequency weights relative to the most common class.
    # bincount(minlength=N_CLASSES) keeps one slot per class even if a class is absent
    # from the training crops (np.unique would drop it and misalign the weights);
    # absent classes get weight 1 instead of dividing by zero.
    train_labels = Y_train.argmax(-1).ravel()
    counts = np.bincount(train_labels, minlength=N_CLASSES)
    present = counts > 0
    class_weights = np.ones(len(counts), dtype=np.uint)
    class_weights[present] = np.round(counts.max() / counts[present]).astype(np.uint)
elif CLASS_WEIGHTS is not None:
    class_weights = np.asarray(CLASS_WEIGHTS)
else:
    class_weights = None

# Scaling the one-hot targets by class weight scales each class's cross-entropy term.
if class_weights is not None:
    Y_train = Y_train * class_weights
    Y_dev = Y_dev * class_weights

# Create generators using split datasets
train_datagen = ImageDataGenerator(featurewise_center=NORMALIZE, featurewise_std_normalization=NORMALIZE)
dev_datagen = ImageDataGenerator(featurewise_center=NORMALIZE, featurewise_std_normalization=NORMALIZE)

if NORMALIZE:
    # Both generators are fit on TRAINING statistics so dev data is normalized exactly
    # like the data the model is trained on.
    train_datagen.fit(X_train)
    dev_datagen.fit(X_train)

# seed=SEED makes the training shuffle order reproducible.
train_set = train_datagen.flow(X_train, Y_train, batch_size=BATCH_SIZE, seed=SEED)
# Validation needs no shuffling; match the training batch size (default would be 32).
dev_set = dev_datagen.flow(X_dev, Y_dev, batch_size=BATCH_SIZE, shuffle=False)

# Model creation

In [None]:
# Make weight initialisation and other TF-side randomness reproducible.
tf.random.set_seed(SEED)

if CONTINUE:
    # Resume from a saved model; its architecture overrides the model parameters.
    unet_model = tf.keras.models.load_model(CONTINUE)
else:
    # Build a fresh U-Net sized to the cropped input images.
    unet_model = unet.build_model(nx, ny,
                                  channels=nchannels,
                                  num_classes=N_CLASSES,
                                  layer_depth=LAYER_DEPTH,
                                  filters_root=ROOT_FILTERS,
                                  dropout_rate=DROPOUT,
                                  padding="same")

# Batches per epoch (same formula as the fit call) so "every N epochs" can be expressed
# as a batch count: ModelCheckpoint's deprecated `period` is rejected by newer Keras.
CHECKPOINT_EVERY_EPOCHS = 50
checkpoint_steps_per_epoch = max(1, len(X_train) // BATCH_SIZE)

tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=SAVE_PATH, histogram_freq=1)
# Keep the model with the best validation recall. The metric objects below are named
# explicitly so the logged key is exactly "val_recall"; Keras never logs "val_Recall",
# and default-named metrics can become "recall_1", ... when this cell is re-run.
best_checkpoint = tf.keras.callbacks.ModelCheckpoint(filepath=CP_PATH, save_weights_only=False,
                                                     save_best_only=True, monitor="val_recall",
                                                     mode="max", verbose=True)
epoch_checkpoint = tf.keras.callbacks.ModelCheckpoint(filepath=MODEL_PATH, save_weights_only=False,
                                                      save_freq=CHECKPOINT_EVERY_EPOCHS * checkpoint_steps_per_epoch,
                                                      verbose=True)
callbacks = [tensorboard_callback, best_checkpoint, epoch_checkpoint]

# NOTE(review): from_logits=True assumes unet.build_model's final layer emits raw logits.
# If it ends in a softmax, this should be from_logits=False; the default 0.5-threshold
# Precision/Recall metrics also assume probability-like outputs — confirm against the model.
# compile() is also applied to a model loaded via CONTINUE, giving it a fresh Adam optimizer.
unet_model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE),
                   loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
                   metrics=[tf.keras.metrics.Precision(name="precision"),
                            tf.keras.metrics.Recall(name="recall"),
                            tf.keras.metrics.CategoricalAccuracy(name="categorical_accuracy")])


# Training

In [None]:
# Optionally warm-start the compiled model from previously saved weights.
if TRANSFER:
    unet_model.load_weights(TRANSFER)

# Fit on whole batches only (integer division), validating on the dev generator each epoch.
n_train_batches = len(X_train) // BATCH_SIZE
fit_kwargs = dict(validation_data=dev_set,
                  epochs=EPOCHS,
                  steps_per_epoch=n_train_batches,
                  callbacks=callbacks)
model_history = unet_model.fit(train_set, **fit_kwargs)

# Persist the final model to MODEL_PATH (the same path the periodic checkpoint callback uses).
unet_model.save(MODEL_PATH)

In [None]:
# Plot the metrics recorded during fitting; model_history is the History object returned
# by unet_model.fit (presumably train vs. validation curves — see sen2lte_utils.plot_metrics).
utils.plot_metrics(model_history)

In [None]:
# Show model predictions on cropped images from the train and dev splits.
# NOTE(review): when class weighting is enabled, Y_train / Y_dev were multiplied by the
# class weights earlier, so they are no longer strict one-hot masks — confirm that
# plot_cropped_examples only uses argmax over the class axis.
utils.plot_cropped_examples(unet_model, (X_train,Y_train), (X_dev,Y_dev), RESIZE, CROP, N_CLASSES)

In [None]:
# Plot full-size (uncropped) predictions for a few randomly chosen source images.
# Sample WITHOUT replacement so no image is shown twice, cap at the number of files
# available, and use a SEED-ed generator so the selection is reproducible.
n_examples = min(3, len(sat_files))
select = np.random.default_rng(SEED).choice(len(sat_files), size=n_examples, replace=False)
utils.plot_uncropped_examples(unet_model, sat_files[select], cell_files[select], lte_files[select], RESIZE, CROP, N_CLASSES)