# Autoencoder Model

**Authors:** Christopher Sun, Jai Sharma, Milind Maiti

**Date:** 2022.06.16

**Description:** This module concerns the elevation data for the satellite images (from the IEEE dataport). The tasks here include:

1. Filter tif file elevation masks based on good and bad data.
2. Define a U-Net Deep Learning framework for predicting the pixelwise elevation of an RGB satellite image.
3. Interpolate the bad data discovered in number (1) using the U-Net Deep Learning model trained in number (2).

## Import Libraries

In [None]:
# Import necessary libraries
import numpy as np
import pandas as pd
import os
import matplotlib.pyplot as plt
import matplotlib as mpl
from matplotlib.pyplot import figure
from mpl_toolkits import mplot3d
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, callbacks, Sequential, Input, Model
from tensorflow.keras.layers import *
from sklearn.cluster import KMeans
from sklearn.model_selection import train_test_split
import math
import scipy
import pickle
import gc 
from IPython.display import display

# Print a confirmation once every import in this cell has run without error
print("Setup Complete")

## Load Data and Determine Good Data vs Bad Data

The following is the number of images from each location:

**San Fernando**: 2325

**Atlanta**: 705

**Jacksonville**: 1098

**Omaha**: 1796

In [None]:
# Load Data
# Each file is opened in a `with` block so its handle is closed as soon as the
# pickle has been read.
# NOTE: pickle.load can execute arbitrary code embedded in the file; only load
# these .dat files if they come from a source you trust.
with open("/kaggle/input/geocentric-pose-analysis-of-satellite-imagery/j2k_imgs.dat", "rb") as j2k_file:
    j2ks = pickle.load(j2k_file)  # RGB satellite images
with open("/kaggle/input/geocentric-pose-analysis-of-satellite-imagery/tif_imgs.dat", "rb") as tif_file:
    tifs = pickle.load(tif_file)  # matching elevation masks

In [None]:
# San Fernando (the first 2325 images)
# For masks that contain no missing pixel (value 65535), overwrite every pixel
# at or above the threshold with the median of that mask.
sf_threshold = 3000
for idx in range(2325):
    image = tifs[idx]
    if np.max(image) < 65535:
        too_high = (image >= sf_threshold) & (image < 65535)
        image[too_high] = np.median(image)

In [None]:
# Atlanta (images 2325-3029)
atl_threshold = 4000
max_num = 100  # A mask with fewer than this many pixels at/above the threshold
               # is treated as having outliers; with more, the high values are
               # assumed to be genuine and the mask is left untouched.
for i in range(2325, 3030):
    tif = tifs[i]
    # The mask must contain no missing pixel (65535). The previous
    # `np.max(tif < 65535)` took the max of a boolean array, which is True as
    # soon as ANY pixel is valid, so masks with missing data were processed too.
    has_no_missing = np.max(tif) < 65535
    few_high_pixels = np.sum(tif >= atl_threshold) < max_num
    if has_no_missing and few_high_pixels:
        tif[(tif >= atl_threshold) & (tif < 65535)] = np.median(tif)

In [None]:
# Sanity Check
threshold = 10000
for i in range(tifs.shape[0]):
    tif = tifs[i]

    # Pass 1: masks without missing pixels (65535) and with fewer than 50
    # pixels at/above the threshold get those pixels replaced by the median.
    # The previous `np.max(tif < 65535)` was True whenever ANY pixel was valid,
    # so it did not actually exclude masks with missing data.
    if (np.sum(tif >= threshold) < 50) and (np.max(tif) < 65535):
        tif[(tif >= threshold) & (tif < 65535)] = np.median(tif)

    # Pass 2: values of 20000 or more (other than the 65535 marker) are always
    # replaced, whatever the mask contains.
    # NOTE(review): this median includes any 65535 pixels still in the mask, so
    # a mask that is mostly missing gets its outliers set to 65535 - confirm
    # that this is acceptable.
    tif[(tif >= 20000) & (tif < 65535)] = np.median(tif)

In [None]:
# Keep only the images whose elevation mask has no missing pixel (good data).
# Missing pixels are stored as 65535; count them per image over both spatial
# axes (assumes tifs is a 3-D array of shape (images, rows, cols)).
missing_per_image = np.sum(tifs == 65535, axis=(1, 2))
idxs_keep = np.flatnonzero(missing_per_image == 0)
tifs, j2ks = tifs[idxs_keep], j2ks[idxs_keep]

In [None]:
# Histogram of the remaining high elevation values: pixels of 10000 or more,
# excluding the 65535 missing-data marker.
tifs_flattened = tifs.flatten()
in_outlier_range = (tifs_flattened >= 10000) & (tifs_flattened < 65535)

plt.figure(figsize=(8, 5))
plt.hist(tifs_flattened[in_outlier_range])
plt.title("Outlier Analysis")
plt.xlabel("Pixel Elevation (cm)", fontsize=12)
plt.ylabel("No. of Pixels", fontsize=12)
plt.show()

In [None]:
# (Disabled) Replace pixels at or above the threshold with the image median when fewer than 10 pixels reach it
# threshold = 10000
# for i in range(tifs.shape[0]):
#     if (np.sum(tifs[i] >= threshold) < 10) and (np.max(tifs[i] >= threshold)):
#         tifs[i][tifs[i] >= threshold] = np.median(tifs[i])
#     else:
#         pass

In [None]:
# (Disabled) Keep only the images whose mask contains at least one missing (65535) pixel
# idxs_keep = np.argwhere(np.sum(np.sum(tifs==65535, axis=1), axis=1) != 0).reshape(-1)
# tifs = tifs[idxs_keep]
# j2ks = j2ks[idxs_keep]

In [None]:
# Count the missing (65535) pixels in every mask, then plot a histogram of the
# counts for the masks that have at least one missing pixel.
num_nans = (tifs == 65535).sum(axis=(1, 2))
has_missing = num_nans > 0
plt.hist(num_nans[has_missing])
plt.show()

In [3]:
# Show two RGB images (indices 72 and 59) side by side with their elevation masks
fig, axs = plt.subplots(1, 4, figsize=(14, 7))
panels = [
    (j2ks[72], "RGB Image"),
    (tifs[72], "Elevation Mask"),
    (j2ks[59], "RGB Image"),
    (tifs[59], "Elevation Mask"),
]
for panel_ax, (image, title) in zip(axs, panels):
    panel_ax.imshow(image)
    panel_ax.axis("off")
    panel_ax.set_title(title)
fig.show()

In [None]:
# Median pixel value of every mask: flatten each image to one row, then take
# the median along that row (assumes tifs has shape (images, rows, cols)).
tif_medians = np.median(tifs.reshape(tifs.shape[0], -1), axis=1)

In [None]:
# Count the masks whose median equals their own minimum or maximum value.
# These could be examples of bad data.
per_image_min = tifs.min(axis=(1, 2))
per_image_max = tifs.max(axis=(1, 2))
median_at_extreme = (tif_medians == per_image_min) | (tif_medians == per_image_max)
count = int(np.count_nonzero(median_at_extreme))
print(count)

In [None]:
# Turn every mask into a two-class (0/1) image, using its median as the split:
#   * median == minimum -> 1 wherever the pixel differs from the minimum
#   * median == maximum -> 1 wherever the pixel equals the maximum
#   * otherwise         -> 1 wherever the pixel is at or above the median
# After this cell tifs holds 0/1 values rather than elevations.
for idx, median in enumerate(tif_medians):
    image = tifs[idx]
    lowest, highest = np.min(image), np.max(image)
    if median == lowest:
        binary = image != lowest
    elif median == highest:
        binary = image == highest
    else:
        binary = image >= median
    tifs[idx] = binary

tifs = tifs.astype("uint8")

In [None]:
# Show a random RGB image and the corresponding tif file.
# At this point, all tif files should be good data.
# randint draws an integer index from [0, number of images). The previous
# np.random.uniform(n) passed n as the *lower* bound, so index 0 was never
# picked and the out-of-range index n was possible.
q = int(np.random.randint(tifs.shape[0]))
print(q)
plt.imshow(j2ks[q])
plt.show()
plt.imshow(tifs[q])
plt.show()

## Build and Train U-Net Model

In [None]:
# Collapse all elevation masks into one 1-D array (flatten() returns a copy)
tifs_flattened = tifs.flatten()

In [None]:
# Define model architecture
def _conv_block(x, filters, last_filters=None):
    """Apply two 5x5 'same'-padded Conv2D -> BatchNormalization -> ReLU stages.

    The first stage uses ``filters`` channels; the second uses ``last_filters``
    when given, otherwise ``filters`` again.
    """
    second_filters = filters if last_filters is None else last_filters
    for n_filters in (filters, second_filters):
        x = Conv2D(n_filters, (5, 5), padding='same')(x)
        x = BatchNormalization()(x)
        x = Activation("relu")(x)
    return x


def _up_block(x, skip, filters, last_filters=None):
    """Upsample ``x`` 2x, concatenate the encoder tensor ``skip``, then convolve."""
    upsampled = Conv2DTranspose(filters, (5, 5), strides=(2, 2), padding='same')(x)
    merged = concatenate([upsampled, skip], axis=3)
    return _conv_block(merged, filters, last_filters)


def build_model():
    """Build the (uncompiled) U-Net mapping an RGB image to an elevation map.

    The network takes a 256x256x3 input. The encoder has five convolution
    levels (32, 64, 128, 256, 512 filters) separated by 2x2 max pooling; the
    decoder has four levels that each upsample with a strided transposed
    convolution and concatenate the encoder output of the same resolution.
    A final 1x1 convolution with ReLU produces one non-negative value per pixel.

    Returns:
        A Keras ``Model`` with a single input and a single output.
    """
    inputs = Input(shape=(256, 256, 3))

    # Encoder: each level halves the resolution of the previous level's output.
    conv1 = _conv_block(inputs, 32)
    conv2 = _conv_block(MaxPooling2D(pool_size=(2, 2))(conv1), 64)
    conv3 = _conv_block(MaxPooling2D(pool_size=(2, 2))(conv2), 128)
    conv4 = _conv_block(MaxPooling2D(pool_size=(2, 2))(conv3), 256)
    # Bottleneck. The original also pooled this tensor into an unused `pool5`;
    # that layer was never connected to the output, so it has been dropped.
    conv5 = _conv_block(MaxPooling2D(pool_size=(2, 2))(conv4), 512)

    # Decoder: upsample and merge with the matching encoder level.
    conv6 = _up_block(conv5, conv4, 256)
    conv7 = _up_block(conv6, conv3, 128)
    conv8 = _up_block(conv7, conv2, 64)
    # The last decoder level narrows to 16 channels in its second convolution.
    conv9 = _up_block(conv8, conv1, 32, last_filters=16)

    outputs = Conv2D(1, (1, 1), activation="relu")(conv9)

    return Model(inputs=[inputs], outputs=[outputs])

In [None]:
# Instantiate the network and configure training: Adam optimizer, mean absolute
# error as the loss, mean squared error tracked as an extra metric.
compile_options = {"optimizer": "adam", "loss": "mae", "metrics": ["mse"]}
model = build_model()
model.compile(**compile_options)

In [None]:
# Print a layer-by-layer summary of the compiled model
model.summary()

In [None]:
# Create the train and validation sets.
# 15% of the (RGB image, elevation mask) pairs are held out for validation;
# random_state=0 fixes the shuffle so the split is the same on every run.
X_t, X_v, y_t, y_v = train_test_split(j2ks, tifs, test_size=0.15, random_state=0)

In [None]:
# Train the model for 100 epochs, evaluating on the validation set as it goes.
# The returned History object holds the per-epoch loss/metric values that the
# learning-curve plots read.
history = model.fit(X_t, y_t, validation_data=(X_v, y_v), epochs=100)

In [None]:
# Plot training vs. validation curves: one figure for the loss, one for the MSE
for train_key, val_key in (("loss", "val_loss"), ("mse", "val_mse")):
    plt.plot(history.history[train_key])
    plt.plot(history.history[val_key])
    plt.show()

## Visualize Accuracy

In [None]:
# Visualize the accuracy of the U-Net model.
# Each row shows: RGB input | reference elevation mask | predicted mask.
#
# Predict on j2ks directly. j2ks and tifs were already indexed with idxs_keep
# when the good data was selected, so indexing j2ks with idxs_keep a second
# time applies positions of the unfiltered array to the filtered one (an
# IndexError, or rows that no longer line up with tifs).
tifs_test = model.predict(j2ks).reshape(j2ks.shape[0], 256, 256)

fig, ax = plt.subplots(5, 3, figsize=(9, 16))
for row in range(5):
    if row == 0:
        # The first row always shows image 1 with a wider color range.
        q, vmax = 1, 8000
    else:
        # randint draws from [0, n); np.random.uniform(n) would use n as the
        # lower bound, never yield 0, and could yield the invalid index n.
        q, vmax = int(np.random.randint(tifs_test.shape[0])), 5000

    columns = (
        (j2ks[q], {}),
        (tifs[q], {"vmax": vmax}),
        (tifs_test[q], {}),
    )
    for col, (image, imshow_kwargs) in enumerate(columns):
        ax[row, col].imshow(image, **imshow_kwargs)
        ax[row, col].axis("off")
fig.show()

## Interpolate Bad Data

In [None]:
# Find tifs which need to be interpolated: masks with at least one missing
# pixel (stored as 65535).
idxs_keep = np.argwhere(np.sum(np.sum(tifs == 65535, axis=1), axis=1) != 0).reshape(-1)

# Predict elevations for exactly these images so that tifs_test[i] is the
# prediction for tifs[idxs_keep[i]]. Relying on a tifs_test computed earlier
# would pair each bad mask with the prediction of an unrelated image.
# Skipped when nothing needs interpolating, to avoid predicting on an empty batch.
if idxs_keep.size:
    tifs_test = model.predict(j2ks[idxs_keep]).reshape(idxs_keep.shape[0], 256, 256)

# tifs_test contains only the predictions for the bad tifs
# idxs_keep contains only the indices of the bad tifs

for i, bad_idx in enumerate(idxs_keep):
    tif = tifs[bad_idx]  # view into tifs: the writes below modify tifs in place
    valid = tif[tif != 65535]

    # First, interpolating all pixels other than the NaNs that are considered
    # statistical outliers (at or above Q3 + 1.5 * IQR of the valid pixels).
    # A mask with no valid pixel has no quartiles; it is handled entirely by
    # the NaN step below.
    if valid.size:
        first_quartile, third_quartile = np.percentile(valid, [25, 75])
        iqr = third_quartile - first_quartile
        fence = third_quartile + 1.5 * iqr
        fence_condition = tif >= fence
        tif[fence_condition] = tifs_test[i][fence_condition]

    # Interpolating the NaNs
    nan_condition = tif == 65535
    tif[nan_condition] = tifs_test[i][nan_condition]

In [None]:
# Show a random example of an interpolated tif file
if idxs_keep.size:
    # randint draws a position in [0, len(idxs_keep)); np.random.uniform(n)
    # would treat n as the lower bound and could produce an invalid position.
    q = idxs_keep[np.random.randint(idxs_keep.shape[0])]
    plt.imshow(tifs[q])
    plt.show()
    plt.imshow(j2ks[q])
    plt.show()
else:
    # Nothing was interpolated, so there is no example to draw.
    print("No interpolated tifs to display")

In [None]:
# Save the new tifs.
# The `with` block flushes and closes the file even if pickling fails part-way.
with open("new_tifs.dat", "wb") as out_file:
    pickle.dump(tifs, out_file)

## Calculate Metrics

Here are the results for the U-Net model:

**Train $R^2$:** 0.9264

**Validation $R^2$:**   0.8655 

In [None]:
# R^2 = 1 - (residual sum of squares) / (total sum of squares), computed over
# all pixels of the validation set first and of the training set second.
pred = model.predict(X_v).reshape(X_v.shape[0], 256, 256)
x_pred = model.predict(X_t).reshape(X_t.shape[0], 256, 256)

val_residual_ss = np.sum((y_v - pred) ** 2)
val_total_ss = np.sum((y_v - np.mean(y_v)) ** 2)
R2 = 1 - val_residual_ss / val_total_ss
print(R2)

train_residual_ss = np.sum((y_t - x_pred) ** 2)
train_total_ss = np.sum((y_t - np.mean(y_t)) ** 2)
R2_train = 1 - train_residual_ss / train_total_ss
print(R2_train)