In [0]:
!pip install git+https://github.com/tdrobbins/unet.git@nchannels

In [0]:
%load_ext tensorboard
import datetime
from glob import glob

import matplotlib.pyplot as pp
import numpy as np
import pandas as pd
import seaborn as sns
import tensorflow as tf
from google.colab import files
from PIL import Image

import unet

# Preprocessing

In [0]:
from google.colab import drive

# Mount Google Drive at /content/drive so the cs230 data folders can be read;
# force_remount=True remounts even when Drive is already mounted.
drive.mount(mountpoint='/content/drive', force_remount=True)

In [0]:
def image_crop(img, size):
  """Tile an image into non-overlapping ``size`` x ``size`` squares.

  Tiles are produced row-major (left to right, then top to bottom). Any right
  or bottom margin narrower than ``size`` is dropped.

  Args:
    img: object exposing PIL-style ``.size`` -> (width, height) and
      ``.crop(box)`` with box = (left, upper, right, lower).
    size: tile edge length in pixels; must be positive.

  Returns:
    list of ``img.crop(...)`` results, ``(height // size) * (width // size)``
    entries long.

  Raises:
    ValueError: if ``size`` is not positive.
  """
  if size <= 0:
    raise ValueError("size must be positive, got {}".format(size))

  img_width, img_height = img.size
  # Floor division replaces np.int(a / b): np.int was removed in NumPy 1.24.
  n_rows = int(img_height // size)
  n_cols = int(img_width // size)
  return [
      img.crop((col * size, row * size, (col + 1) * size, (row + 1) * size))
      for row in range(n_rows)
      for col in range(n_cols)
  ]

def load_data(img_files, crop_size=None, resize=None, scale=None,
              base_size=5120, resample=None):
  """Load image files into one stacked numpy array.

  Every file is opened with PIL and resized to ``base_size`` x ``base_size``
  (aspect ratio is not preserved). It is then optionally tiled with
  ``image_crop``, each image/tile optionally resized to ``resize`` x
  ``resize``, and the stacked array optionally multiplied by ``scale``.

  Args:
    img_files: iterable of paths accepted by ``Image.open``.
    crop_size: tile edge in pixels passed to ``image_crop``; falsy = no tiling.
    resize: final edge length of every image/tile; falsy = keep size.
    scale: multiplier applied to the pixel array; falsy = leave unscaled.
    base_size: edge length every raw image is resized to before tiling
      (default 5120, the value previously hard-coded).
    resample: PIL resampling filter used for every resize. None selects
      Lanczos, the same filter the former ``Image.ANTIALIAS`` alias named.
      Pass ``Image.NEAREST`` for label masks so class ids are not blended.

  Returns:
    numpy array with one leading entry per image (or per tile when
    cropping); the trailing dimensions depend on the images' PIL mode.
  """
  if resample is None:
    # Image.ANTIALIAS was an alias of LANCZOS and was removed in Pillow 10.
    resample = Image.LANCZOS

  imgs = []
  for path in img_files:
    # `with` closes the file handle; resize() returns a new in-memory image.
    with Image.open(path) as raw:
      imgs.append(raw.resize((base_size, base_size), resample))

  if crop_size:
    imgs = [tile for img in imgs for tile in image_crop(img, crop_size)]

  if resize:
    imgs = [img.resize((resize, resize), resample) for img in imgs]

  imgs = np.asarray([np.asarray(img) for img in imgs])

  if scale:
    imgs = imgs * scale

  return imgs

In [0]:
# Loading and preprocessing data

# Preprocessing parameters
CROP        = 1024   # split raw images into many CROPxCROP images
RESIZE      = 256    # downsample cropped image to RESIZExRESIZE
N_CLASSES   = 4      # 2, 4, or 10
CELLTOWERS  = True   # stack cell-tower rasters onto X as extra input channel(s)
# NOTE(review): an earlier comment said CELLTOWERS "currently doesn't work" and
# might need a unet fork; the nchannels fork is installed at the top — confirm.

DATA_DIR = "/content/drive/My Drive/cs230"

# Inputs, masks and tower rasters are paired purely by sorted filename order,
# so the three lists must have the same length.
sat_files = sorted(glob(DATA_DIR + "/sen2_composite/*.tif"))
lte_files = sorted(glob(DATA_DIR + "/lte/cat{}*.tif".format(N_CLASSES)))
assert sat_files, "no satellite images found under " + DATA_DIR
assert len(sat_files) == len(lte_files), \
    "{} satellite images vs {} LTE masks".format(len(sat_files), len(lte_files))

# Loading satellite images into an array, scaled from 0-255 to [0, 1)
X = load_data(sat_files, crop_size=CROP, resize=RESIZE, scale=1./256.)

# Loading the N_CLASSES version of the LTE mask images into an array.
# NOTE(review): masks go through load_data's default Lanczos (antialias)
# resampling, which can blend neighbouring class ids at region boundaries;
# nearest-neighbour resampling would keep the labels exact — confirm and fix.
Y = load_data(lte_files, crop_size=CROP, resize=RESIZE)
assert X.shape[0] == Y.shape[0], "inputs and masks produced different tile counts"

# Optionally append the cell-tower raster of each tile to the input channels
if CELLTOWERS:
  cell_files = sorted(glob(DATA_DIR + "/celltowers/*.tif"))
  assert len(cell_files) == len(sat_files), \
      "{} cell-tower rasters vs {} satellite images".format(len(cell_files), len(sat_files))
  cell_arr = load_data(cell_files, crop_size=CROP, resize=RESIZE)
  # NOTE(review): X is scaled by 1/256 but the tower channel is not — confirm intended.
  X = np.asarray([np.dstack((x, c)) for x, c in zip(X, cell_arr)])
  

# Model Training

In [0]:
# Start TensorBoard inline on /content/unet/, where the TensorBoard callback in
# the model cell writes one timestamped run directory per build.
%tensorboard --logdir /content/unet/

In [0]:
 # Setting up the datasets for tf
TRAIN_SPLIT   = 0.8
BATCH_SIZE    = 8
CLASS_WEIGHTS = [5,5,5,1]

split = np.round(X.shape[0]*TRAIN_SPLIT)

sat_lte_dataset = tf.data.Dataset.from_tensor_slices((X,tf.one_hot(Y,N_CLASSES)*CLASS_WEIGHTS)).shuffle(1000)
train_set = sat_lte_dataset.take(split).batch(BATCH_SIZE)
test_set = sat_lte_dataset.skip(split).batch(BATCH_SIZE)

In [0]:
# Model hyperparameters
LEARNING_RATE = 0.0001
LAYER_DEPTH   = 5
ROOT_FILTERS  = 64
DROPOUT       = 0.0

# Building the model; input size and channel count come from the data arrays.
unet_model = unet.build_model(X.shape[1], X.shape[2],
                              channels=X.shape[3],
                              num_classes=N_CLASSES,
                              layer_depth=LAYER_DEPTH,
                              filters_root=ROOT_FILTERS,
                              dropout_rate=DROPOUT,
                              padding="same")

# One TensorBoard run directory per build, e.g. /content/unet/240101-1200.
tensorboard_callback = tf.keras.callbacks.TensorBoard(
    log_dir="/content/unet/" + datetime.datetime.now().strftime("%y%m%d-%H%M"),
    histogram_freq=1)

# NOTE(review): upstream jakeret/unet's build_model ends in a softmax layer, so
# the model outputs probabilities; from_logits=True would apply softmax a second
# time. Confirm the nchannels fork keeps that final softmax.
unet_model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE),
                   loss=tf.keras.losses.CategoricalCrossentropy(from_logits=False),
                   metrics=[tf.keras.metrics.CategoricalAccuracy()])

In [0]:
# Training parameters
EPOCHS        = 5

# batch_size is not passed: train_set/test_set are already-batched tf.data
# datasets, for which Keras says batch_size must not be specified.
model_history = unet_model.fit(train_set,
                               validation_data=test_set,
                               epochs=EPOCHS,
                               callbacks=[tensorboard_callback])

In [0]:
# Visual check: input | ground-truth mask | predicted mask for the first examples.
# Assumes the dev set is the trailing examples by index (split taken before any
# shuffling); with a shuffled split the train/dev titles would be wrong.
N_SHOW = 15
n_train = int(round(X.shape[0] * TRAIN_SPLIT))  # index of the first dev example
n_show = min(N_SHOW, X.shape[0])

# One batched predict call; argmax over ALL class channels yields class ids.
pred_masks = unet_model.predict(X[:n_show]).argmax(-1)

for i in range(n_show):
  pp.figure(figsize=(10, 10))

  pp.subplot(131)
  # Only the first three channels: extra ones (e.g. cell towers) would be drawn as alpha.
  pp.imshow(X[i][:, :, :3])
  pp.title("{} set image {}".format("dev" if i >= n_train else "train", i))

  # Fixed vmin/vmax so each class id gets the same colour in both mask panels.
  pp.subplot(132)
  pp.imshow(Y[i], vmin=0, vmax=N_CLASSES - 1)
  pp.title("label")

  pp.subplot(133)
  pp.imshow(pred_masks[i], vmin=0, vmax=N_CLASSES - 1)
  pp.title("prediction")

  pp.show()