In [None]:
!pip install git+https://github.com/tdrobbins/unet.git@nchannels

In [None]:
%load_ext tensorboard
import datetime
from glob import glob

import matplotlib.pyplot as pp
import numpy as np
import pandas as pd
import seaborn as sns
import tensorflow as tf
from google.colab import files

import unet
import utils

# Preprocessing

In [None]:
# Mount Google Drive at /content/drive, remounting even if it is already mounted,
# so the dataset .tif files under "My Drive" can be read.
from google.colab import drive
drive.mount(mountpoint='/content/drive', force_remount=True)

In [None]:
# Loading and preprocessing data

# Preprocessing parameters
CROP        = 1024   # split raw images into many CROPxCROP images
RESIZE      = 256    # downsample cropped image to RESIZExRESIZE
N_CLASSES   = 4      # 2, 4, or 10 — selects the lte/cat{N_CLASSES}*.tif masks
# Append cell-tower rasters to each input tile as extra channel(s).
# Originally noted as "currently doesn't work... may require forking unet due to
# dumb bug" — TODO confirm it works with the installed unet fork.
CELLTOWERS  = True

DATA_ROOT = "/content/drive/My Drive/cs230"

# Satellite composites (inputs, scaled by 1/256) and LTE masks (labels) are paired
# purely by sorted filename order.
sat_files = np.sort(glob(DATA_ROOT + "/sen2_composite/*.tif"))
lte_files = np.sort(glob(DATA_ROOT + "/lte/cat{}*.tif".format(N_CLASSES)))

# An empty glob (e.g. Drive not mounted) would otherwise only surface much later.
assert len(sat_files) > 0, "no satellite images found"
assert len(lte_files) > 0, "no LTE mask images found"

X = utils.load_data(sat_files,crop_size=CROP,resize=RESIZE,scale=1./256.)
Y = utils.load_data(lte_files,crop_size=CROP,resize=RESIZE)
assert len(X) == len(Y), "satellite / LTE tile counts differ"

# Optionally stack the cell-tower raster onto each satellite tile as extra channel(s).
if CELLTOWERS:
  cell_files = np.sort(glob(DATA_ROOT + "/celltowers/*.tif"))
  cell_arr = utils.load_data(cell_files,crop_size=CROP,resize=RESIZE)
  # zip() would silently drop tiles on a length mismatch and misalign inputs; fail instead.
  assert len(cell_arr) == len(X), "cell-tower / satellite tile counts differ"
  # NOTE(review): the cell-tower band is not rescaled like the satellite bands
  # (scale=1./256.) — confirm that is intended.
  X = np.asarray([np.dstack((x,c)) for x,c in zip(X,cell_arr)])
  

# Model Training

In [None]:
# Launch the inline TensorBoard viewer on /content/unet/, the directory the
# training cell's TensorBoard callback writes its timestamped runs into.
%tensorboard --logdir /content/unet/

In [None]:
 # Setting up the datasets for tf
TRAIN_SPLIT   = 0.8
BATCH_SIZE    = 8
CLASS_WEIGHTS = [5,5,5,1]

split = np.round(X.shape[0]*TRAIN_SPLIT)

sat_lte_dataset = tf.data.Dataset.from_tensor_slices((X,tf.one_hot(Y,N_CLASSES)*CLASS_WEIGHTS)).shuffle(1000)
train_set = sat_lte_dataset.take(split).batch(BATCH_SIZE)
test_set = sat_lte_dataset.skip(split).batch(BATCH_SIZE)

In [None]:
# Model hyperparameters
LEARNING_RATE = 0.0001
LAYER_DEPTH   = 5
ROOT_FILTERS  = 64
DROPOUT       = 0.0
LOG_ROOT      = "/content/unet/"   # same directory the %tensorboard cell points at

# Building the model. Input height, width and channel count are read off the
# preprocessed X, so the extra cell-tower channel (when CELLTOWERS is on) is
# picked up automatically.
unet_model = unet.build_model(X.shape[1], X.shape[2],
                              channels=X.shape[3],
                              num_classes=N_CLASSES,
                              layer_depth=LAYER_DEPTH,
                              filters_root=ROOT_FILTERS,
                              dropout_rate=DROPOUT,
                              padding="same")

# One TensorBoard run directory per session, named by start time (yymmdd-HHMM).
tensorboard_callback = tf.keras.callbacks.TensorBoard(
    log_dir=LOG_ROOT + datetime.datetime.now().strftime("%y%m%d-%H%M"),
    histogram_freq=1)

# NOTE(review): upstream jakeret/unet's build_model ends in a softmax activation,
# so the model emits probabilities, not logits; from_logits=True would apply a
# second softmax inside the loss. Confirm the nchannels fork keeps that output layer.
unet_model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE),
                   loss=tf.keras.losses.CategoricalCrossentropy(from_logits=False),
                   metrics=[tf.keras.metrics.CategoricalAccuracy()])

In [None]:
# Training parameters
EPOCHS        = 5

# train_set / test_set are already-batched tf.data pipelines, so batch_size is
# not passed here — Keras documents that it must not be set for dataset inputs.
model_history = unet_model.fit(train_set,
                               validation_data=test_set,
                               epochs=EPOCHS,
                               callbacks=[tensorboard_callback])