# Setup

### 1. Libraries

In [None]:
import numpy as np
import tensorflow as tf

# import rasterio
import json
import geojson
import glob
import math
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import Conv3D, Reshape, Conv2D, Flatten, Dense, Dropout
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# import tensorflow_data_validation as tfdv
from sklearn.metrics import r2_score
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error
import matplotlib.pyplot as plt

### 2. Training parameters

### NOTE FOR ME: check which of these parameters are actually used during model fitting and which aren't. Currently `create_model_function` is called with no arguments, so it uses all of its own defaults. RECONCILE.

##### These are the parameters I landed on that work well for canopy cover with the simple model architecture in the function below. They are the result of a series of GridSearchCV iterations.

In [None]:
# Training hyperparameters (a batch size of 2048 was used in earlier runs).
BATCH_SIZE, NUM_EPOCHS = 64, 20

# Geometry of one training patch: HEIGHT x WIDTH pixels, NUM_CHANNELS bands.
HEIGHT = WIDTH = 30
NUM_CHANNELS = 4

# Number of leading records held out for validation.
NUM_VALIDATION_RECORDS = 5000
# Intended shuffle buffer size for the training pipeline.
BUFFER_SIZE = 10000

# Location of the serialized (image, label) training pairs.
tfrecord_path = "/FILL/THIS/IN/WITH/THE/PATH/TO/YOUR/CNN/training_data/NAIP_x_LIDAR_training_30k_v4.TFrecord"

### 3. Some GPU management / checking

##### See what's available 

In [None]:
# Show which physical devices TensorFlow can see, then the installed TF version.
visible_devices = tf.config.list_physical_devices()
print(visible_devices)
print(tf.__version__)

##### If you run into trouble where your GPU-utilizing code won't run because the GPU RAM is allocated (even though you're pretty sure it should be free), try tf.keras.backend.clear_session(), which sometimes helps 

In [None]:
# Reset the global Keras state built up earlier in this session.
# NOTE(review): this sometimes frees GPU memory held by earlier models, but it is
# not guaranteed to release everything that was allocated.
tf.keras.backend.clear_session()

# Functions

### 1. parse_training_tfrecord

##### uses an example of what each training pair (image + label) should look like to unpack the training pairs from the training dataset file

In [None]:
# Function to parse the training TFRecord data
def parse_training_tfrecord(example_proto):
    """Decode one serialized training record into an ``(image, label)`` pair.

    Both fields are stored as raw float32 byte strings. The image is reshaped to
    ``[HEIGHT, WIDTH, NUM_CHANNELS]``. The label is returned as the flat float32
    vector decoded from its bytes (its length is not fixed here).
    """
    field_names = ("image", "label")
    schema = {name: tf.io.FixedLenFeature([], tf.string) for name in field_names}

    parsed = tf.io.parse_single_example(example_proto, schema)

    raw_pixels, label = (tf.io.decode_raw(parsed[name], tf.float32) for name in field_names)
    return tf.reshape(raw_pixels, [HEIGHT, WIDTH, NUM_CHANNELS]), label

### 2. create_model_function

##### contains the method for building the neural network. This particular network will have a 3D convolution layer (the filters have dimensions of kernel_dim x kernel_dim x number of bands), a flatten, and three dense layers (with dropouts between the layers). This is a simple formulation of a convolutional neural network. 

In [None]:
# Function to construct the CNN
def create_model_function(
    init_mode="uniform",
    optimiz0r="adam",
    batch_size=1024,
    kernel_dim=5,
    filter_no=32,
    Dense1_no=256,
    Dense2_no=32,
    learn_rate=0.001,
    momentum=0.2,
):
    """Build and compile the canopy-cover CNN.

    Architecture:
        - One Conv3D layer whose kernels span ``kernel_dim x kernel_dim`` pixels
          and all ``NUM_CHANNELS`` bands.
        - Flatten.
        - Two ReLU Dense layers, each followed by 40% dropout.
        - A single linear output unit.

    Args:
        init_mode: Unused by this function. Layers use Keras' default
            initializers. Accepted so grid-search parameter sets can pass it.
        optimiz0r: Optimizer name, ``"adam"`` (default) or ``"sgd"``
            (case-insensitive).
        batch_size: Unused by this function. The batch size is set on the
            dataset or in ``fit``.
        kernel_dim: Spatial size of the Conv3D kernel.
        filter_no: Number of Conv3D filters.
        Dense1_no: Units in the first hidden Dense layer.
        Dense2_no: Units in the second hidden Dense layer.
        learn_rate: Optimizer learning rate.
        momentum: Momentum for ``"sgd"``. Ignored by ``"adam"``.

    Returns:
        A ``tf.keras.Sequential`` model compiled with mean-squared-error loss.

    Raises:
        ValueError: If ``optimiz0r`` is not a supported optimizer name.
    """
    # The kernel depth equals the band count, so each filter sees every band at once.
    kernel_size = (kernel_dim, kernel_dim, NUM_CHANNELS)

    model = tf.keras.models.Sequential(
        [
            Conv3D(
                filters=filter_no,
                kernel_size=kernel_size,
                # Trailing singleton axis: the band stack is treated as a 1-channel 3-D volume.
                input_shape=(HEIGHT, WIDTH, NUM_CHANNELS, 1),
                padding="valid",
                activation="relu",
                use_bias=True,
            ),
            Flatten(),
            Dense(Dense1_no, activation="relu"),
            Dropout(0.4),
            Dense(Dense2_no, activation="relu"),
            Dropout(0.4),
            # Single linear unit: this is a regression model.
            Dense(units=1, activation="linear"),
        ]
    )

    # Map the optimizer name to an instance. Unknown names raise a clear error
    # instead of leaving `optimizer` unbound (which used to raise UnboundLocalError).
    optimizer_name = str(optimiz0r).lower()
    if optimizer_name == "adam":
        optimizer = Adam(learning_rate=learn_rate)
    elif optimizer_name == "sgd":
        optimizer = tf.keras.optimizers.SGD(learning_rate=learn_rate, momentum=momentum)
    else:
        raise ValueError(
            f"Unsupported optimizer {optimiz0r!r}; expected 'adam' or 'sgd'."
        )

    model.compile(loss="mean_squared_error", optimizer=optimizer)

    return model

### 3. parse_imagery_tfrecord

##### This function is used when it's time to make predictions (maps). It is used to parse the series of .tfrecord files that contain the image data from our area of interest. Data will be streamed from those .tfrecord files, fed into the model, and in the end we will have a number of predictions that matches the number of 30m pixels contained in our area of interest. 

In [None]:
# Function to map over the NAIP data (in TFRecords) to get it in the format that our model takes
def parse_imagery_tfrecord(serialized_example):
    """Decode one imagery patch record into a model-ready tensor.

    Each record holds four float32 bands ("B", "G", "N", "R"), each a flat
    vector of ``HEIGHT * WIDTH`` pixels. The bands are reshaped to
    ``[HEIGHT, WIDTH]`` and stacked in B, G, R, N order. A trailing singleton
    axis is then added, giving shape ``[HEIGHT, WIDTH, 4, 1]``.
    """
    # Use the same patch geometry as the training parser instead of
    # hard-coding 30/900, so the two parsers cannot silently diverge.
    pixels_per_band = HEIGHT * WIDTH
    feature = {
        band: tf.io.FixedLenFeature([pixels_per_band], tf.float32)
        for band in ("B", "G", "N", "R")
    }
    example = tf.io.parse_single_example(serialized_example, feature)

    # Convert the input features to the format expected by the model.
    # NOTE(review): B, G, R, N is assumed to match the channel order of the
    # training images — confirm against the training-data export.
    bands = [
        tf.reshape(example[band], [HEIGHT, WIDTH]) for band in ("B", "G", "R", "N")
    ]
    image = tf.stack(bands, axis=-1)

    # The Conv3D model takes inputs shaped (HEIGHT, WIDTH, NUM_CHANNELS, 1).
    return tf.expand_dims(image, axis=-1)

# Read in the training data. Create training and validation datasets

In [None]:
# Connection object
full_dataset_con = tf.data.TFRecordDataset(tfrecord_path)

# Map the parsing function over the connection object
full_dataset = full_dataset_con.map(
    parse_training_tfrecord, num_parallel_calls=tf.data.AUTOTUNE
)

# Split the training data in train and validation sets
validation_dataset = full_dataset.take(NUM_VALIDATION_RECORDS)
train_dataset = full_dataset.skip(NUM_VALIDATION_RECORDS)


# function to augment the images
def augment(image):
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_flip_up_down(image)
    return image


# apply data augmentation to the images in the dataset
train_dataset = train_dataset.map(lambda x, y: (augment(x), y))

# Do some other mysterious things to the data
train_dataset = train_dataset.shuffle(buffer_size=10000)
train_dataset = train_dataset.batch(BATCH_SIZE)
train_dataset = train_dataset.repeat(NUM_EPOCHS)

validation_dataset = validation_dataset.batch(BATCH_SIZE)

# Training pt. 1: Run a gridsearchCV on the model hyperparameters

##### CAUTION: This code belongs at this point in the process, but is ***not*** written for the current data format. I ran GridSearchCV on training data stored in a different format (not in .tfrecords, and therefore parsed differently). The logic in the cell below should still apply, but it will need small changes to accommodate the .tfrecord workflow.

##### SEE THE NOTES next to the lists that go into the grid for the model performance scores I got while adjusting the parameters (partly by hand).

##### Why did I do the GridSearchCV partly by hand? I ran multiple small, successive searches (instead of setting up one big grid search and letting it rip) because of a memory leak in the GridSearchCV process. System memory usage crept up over the hours, and by hour 7 or so memory would be exhausted and the process would crash. So I had to break the job down into smaller searches.

##### Has TensorFlow fixed the memory leak bug? Would the bug be absent on a Windows system (as opposed to Linux)? Hard to tell, but it's worth a try! Just keep an eye on the process.

##### The general wisdom I gleaned from various sources was to tune the optimizer and learning rate first, then adjust the model architecture (kernel size and number of filters on the Conv3D, number of nodes in the dense layers, etc.).


In [None]:
def create_model_2(
    init_mode="uniform",
    optimiz0r="adam",
    batch_size=10,
    kernel_dim=2,
    filter_no=16,
    Dense1_no=128,
    Dense2_no=64,
    learn_rate=0.01,
    momentum=0.2,
):
    """Build and compile the grid-search variant of the CNN.

    Same layer stack as the main model, but with ``padding="same"`` and a
    hard-coded 30 x 30 x 4 (x 1) input. This matches the older data format used
    for the grid search.

    Args:
        init_mode: Unused by this function. Layers use Keras' default initializers.
        optimiz0r: Optimizer name, ``"adam"`` (default) or ``"sgd"``
            (case-insensitive).
        batch_size: Unused by this function. The wrapper/fit controls batching.
        kernel_dim: Spatial size of the Conv3D kernel.
        filter_no: Number of Conv3D filters.
        Dense1_no: Units in the first hidden Dense layer.
        Dense2_no: Units in the second hidden Dense layer.
        learn_rate: Optimizer learning rate.
        momentum: Momentum for ``"sgd"``. Ignored by ``"adam"``.

    Returns:
        A ``tf.keras.Sequential`` model compiled with mean-squared-error loss.

    Raises:
        ValueError: If ``optimiz0r`` is not a supported optimizer name.
    """
    # The kernel depth spans all 4 bands.
    kernel_size = (kernel_dim, kernel_dim, 4)

    model = tf.keras.models.Sequential(
        [
            Conv3D(
                filters=filter_no,
                kernel_size=kernel_size,
                input_shape=(30, 30, 4, 1),
                padding="same",
                activation="relu",
                use_bias=True,
            ),
            Flatten(),
            Dense(Dense1_no, activation="relu"),
            Dropout(0.4),
            Dense(Dense2_no, activation="relu"),
            Dropout(0.4),
            Dense(units=1, activation="linear"),
        ]
    )

    # Unknown optimizer names raise instead of leaving `optimizer` unbound.
    optimizer_name = str(optimiz0r).lower()
    if optimizer_name == "adam":
        optimizer = Adam(learning_rate=learn_rate)
    elif optimizer_name == "sgd":
        optimizer = tf.keras.optimizers.SGD(learning_rate=learn_rate, momentum=momentum)
    else:
        raise ValueError(
            f"Unsupported optimizer {optimiz0r!r}; expected 'adam' or 'sgd'."
        )

    model.compile(loss="mean_squared_error", optimizer=optimizer)

    return model


# GridSearchCV is used below but was never imported. scikit-learn is already a
# dependency of this notebook (sklearn.metrics / sklearn.linear_model).
from sklearn.model_selection import GridSearchCV

# NOTE(review): KerasRegressor is not imported anywhere either. The `model=`
# keyword matches the scikeras wrapper (presumably
# `from scikeras.wrappers import KerasRegressor`) — confirm and import it before
# running. X_train / y_train are also not defined in this notebook; they must be
# in-memory arrays from the older data format (see the CAUTION note for this cell).

# Random seeds. np.random.seed alone does not seed TensorFlow's RNG (used for
# weight initialization and dropout), so seed TensorFlow as well.
np.random.seed(seed=42)
tf.random.set_seed(42)

# create the sklearn model for the network
model_init = KerasRegressor(
    model=create_model_2,
    epochs=20,
    verbose=2,
    kernel_dim=2,
    filter_no=16,
    Dense1_no=32,
    Dense2_no=16,
    learn_rate=0.01,
    momentum=0.2,
)

# we choose the initializers that came at the top in our previous cross-validation!!
# init_list = ['uniform', 'lecun_uniform', 'normal', 'zero', 'glorot_normal', 'glorot_uniform', 'he_normal', 'he_uniform']

# Grid values. Lists with a single entry are settings already fixed by earlier
# searches; the notes record the scores that led to them.
batch_list = [1024]  # Batch size of 1000 had the best R2 of 0.57 when epochs was 20
epoch_list = [20]  # Best R2 of 0.57 when epochs was 20 (tested [20,30,40,50,100,200]).
kernel_dim_list = [
    5
]  # Best R2 of 0.58 when kernel_dim was 5 (tested [2,3,4,5,6,7,8,9,10]).
filter_list = [
    32
]  # Best R2 of 0.5755 at 64 filters, but 32 filters not too far behind (0.5721) and much faster (tested [16,32,64,128,256]).
Dense1_list = [256]
Dense2_list = [32]
learn_list = [0.0001, 0.001, 0.002, 0.005, 0.01, 0.02]
momentum_list = [0.2]
param_grid = dict(
    epochs=epoch_list,
    batch_size=batch_list,
    kernel_dim=kernel_dim_list,
    filter_no=filter_list,
    Dense1_no=Dense1_list,
    Dense2_no=Dense2_list,
    learn_rate=learn_list,
    momentum=momentum_list,
)
grid = GridSearchCV(
    estimator=model_init, param_grid=param_grid, cv=2, error_score="raise"
)

# Print the keys of the estimator object "grid", to see what all is involved
# print(grid.get_params().keys())

grid_result = grid.fit(X=X_train, y=y_train)

# Training pt. 2: Fit the model with the selected parameters

In [None]:
# Build the network with the tuned defaults baked into create_model_function.
model = create_model_function()

# Re-compile only to attach an RMSE metric. Reuse the optimizer the builder
# already configured (rather than a fresh "adam" string), so any learn_rate
# passed to create_model_function is actually honoured.
model.compile(
    loss="mean_squared_error",
    optimizer=model.optimizer,
    metrics=[tf.keras.metrics.RootMeanSquaredError()],
)

# With no steps_per_epoch, each epoch iterates train_dataset to exhaustion, so
# train_dataset must not be pre-repeated.
history = model.fit(
    train_dataset, verbose=1, validation_data=validation_dataset, epochs=NUM_EPOCHS
)
# history = model.fit(train_dataset,verbose=1,validation_data = validation_dataset, steps_per_epoch = 29,epochs = NUM_EPOCHS)

# With batch size 1024, we got an r squared of 0.42
# With batch size 2048, we got an r squared of 0.49
# With batch size 1500, we got an r squared of
# Then, expanding validation dataset to 5000.....maybe a little better?

#### Save the model

In [None]:
# Persist the trained model in TensorFlow SavedModel format.
# (An earlier version was saved to '/md1/data/NAIP/trained_models/CNN_canCov_v1.1'.)
saved_model_dir = "/FILL/THIS/IN/WITH/THE/PATH/TO/YOUR/trained_models/CNN_canCov_v1.1_batch64"
model.save(saved_model_dir, save_format="tf")

# Validation

### Use the validation data to test the model predictions

In [None]:
# Make predicted_Labels
predictions = model.predict(validation_dataset)
# print(f'The first ten predicted values {predictions[0:10]}')

# Make val_labels
val_labels = []
for _, label in validation_dataset:
    val_labels.append(label.numpy())
val_labels = np.concatenate(val_labels)
# print(f'The first ten labels in the validation dataset {val_labels[0:10]}')

# Compute R squared
r_squared = r2_score(val_labels, predictions)

# Reformating predictions (x_forfit) and val_labels (y_forfit) for what LinearRegression wants
x_forfit = np.array(predictions).reshape((-1, 1))
y_forfit = val_labels

# Do the regression
lm = LinearRegression().fit(X=x_forfit, y=y_forfit)
r_sq = lm.score(x_forfit, y_forfit)
m = lm.coef_[0]
b = lm.intercept_

# Print regression info
print(f"R squared value: {r_sq: .3f}")
print(f"intercept: {b[0]: .2f}")
print(f"slope: {m[0]: .2f}")

# Compute RMSE and print it
RMSE = math.sqrt(mean_squared_error(val_labels, predictions))
print(f"RMSE for the predicted canopy height: {RMSE: .2f}")

# Heatmap scatterplot
fig = plt.figure(figsize=(10, 8))
plt.hexbin(x=predictions, y=val_labels, gridsize=20)
plt.plot(x_forfit, m * x_forfit + b, color="red")
plt.title(f"Canopy Cover - Rsquared: {r_sq: .3f}, RMSE {RMSE: .2f}", fontsize=20)
plt.xlim([5, 70])
plt.ylim([5, 70])
plt.plot([5, 70], [5, 70], "--", color="black")
plt.colorbar()
plt.ylabel("Actual Cover (%)", fontsize=20)
plt.xlabel("Predicted cover (%)", fontsize=20)
plt.show()

### Optional: make plots to sanity-check the data. Clear the session afterwards.

In [None]:
# labels = []
# image_list = []

# for image, label in full_dataset:
#     labels.append(label.numpy())
#     image_list.append(image.numpy())

# labels = np.concatenate(labels)
# X = np.array(image_list)

In [None]:
# #i_list = [1414,20000,24020,0,500,1500,]
# i_list = [24000,130,24020,29870,120,1321]

# fig, axs = plt.subplots(nrows = 2, ncols = 3, figsize = (36,24))
# axs_flat = axs.flatten()

# for i in range(0,len(i_list)):
#     ax = axs_flat[i]
#     samp = i_list[i]
#     ax.imshow(np.flip(X[samp,:,:,0:3], 2))
#     ax.set_title(f'canopy cover: {labels[samp]} %', fontsize = 24)
# plt.show()

# Make a map

In [None]:
batch_size = 10000
imagery_glob = "/FILL/THIS/IN/WITH/THE/PATH/TO/YOUR/imagery_for_prediction/Malheur_small_bbox_2011_nativeProj-0000*.tfrecord"
# NOTE(review): "-0000*" only matches shard names whose index starts with 0000
# (e.g. 00000-00009 if indices are 5-digit zero-padded). Any further shards would
# be silently skipped and the raster reshape would fail — confirm the shard count.
# Sorting keeps zero-padded shard names in index order, so predictions stay in
# patch order.
file_patterns = sorted(glob.glob(imagery_glob))
if not file_patterns:
    raise FileNotFoundError(f"No imagery TFRecord files match {imagery_glob!r}")

dataset = tf.data.TFRecordDataset(file_patterns)
# Parallel parsing is safe here: tf.data preserves element order by default.
# Order matters because the predictions are reshaped into a raster afterwards.
dataset = dataset.map(
    parse_imagery_tfrecord, num_parallel_calls=tf.data.AUTOTUNE
).batch(batch_size)
predictions = model.predict(dataset)
np.savetxt(
    "/FILL/THIS/IN/WITH/THE/PATH/TO/YOUR/predicted_vectors/Malheur_small_bbox_2011.csv",
    predictions,
    delimiter=",",
)

In [None]:
# Inspect the shape of the map predictions (`predictions_smallest` was never defined).
np.shape(predictions)

#### Write the predictions out as a GeoTIFF

In [None]:
# Get the metadata
import json

# The top-of-notebook rasterio import is commented out, so import it where it is used.
import rasterio

with open(
    "/FILL/THIS/IN/WITH/THE/PATH/TO/YOUR/imagery_for_prediction/Malheur_small_bbox_2011_nativeProj-mixer.json",
    "r",
) as f:
    meta = json.load(f)

ncol = int(meta["patchesPerRow"])
total_patches = int(meta["totalPatches"])
# rasterio needs integer height/width. A remainder means the metadata does not
# describe a complete rectangular grid of patches.
nrow, leftover = divmod(total_patches, ncol)
if leftover:
    raise ValueError(
        f"totalPatches ({total_patches}) is not a multiple of patchesPerRow ({ncol})."
    )

# Copy before scaling so `meta` itself stays untouched (it is printed below).
# NOTE(review): assumes doubleMatrix is in Affine order [a, b, c, d, e, f], i.e.
# indices 0 and 4 are the x and y pixel sizes — confirm against the export.
affine = list(meta["projection"]["affine"]["doubleMatrix"])
# Each prediction summarizes a 30 x 30 block of input pixels, so output pixels are 30x larger.
affine[0] = affine[0] * 30
affine[4] = affine[4] * 30
crs = meta["projection"]["crs"]
print(meta)

if predictions.size != total_patches:
    raise ValueError(
        f"Got {predictions.size} predictions but the metadata describes {total_patches} patches."
    )
pred_raster = predictions.reshape(nrow, ncol).astype(np.float32)

profile = dict(
    dtype=rasterio.float32,
    count=1,
    compress="lzw",
    height=nrow,
    width=ncol,
    driver="GTiff",
    crs=crs,
    transform=affine,
)

with rasterio.open(
    "/FILL/THIS/IN/WITH/THE/PATH/TO/YOUR/CNN/predicted_vectors/test.tif", "w", **profile
) as fh:
    fh.write(pred_raster, 1)

In [None]:
# Display the predicted raster.
fig, ax = plt.subplots(figsize=(12, 12))
ax.imshow(pred_raster)
plt.show()