# 1. Setup

Import the libraries, set up the global variables, and authenticate with Google Earth Engine.

In [None]:
# Import libraries
import os
import rasterio
import glob
import json
import tensorflow as tf
import folium
import ee
from pprint import pprint
from google.oauth2 import service_account
from google.cloud import storage

# Ask TensorFlow to use the asynchronous CUDA allocator for GPU memory.
os.environ["TF_GPU_ALLOCATOR"] = "cuda_malloc_async"

# ---- Google Cloud -----------------------------------------------------------
# Service-account key file for accessing the Google Cloud project.
KEY_PATH = "/home/climber/.keys/ee-maptheforests-access-key.json"

# Cloud Storage bucket used by this project.
BUCKET_NAME = "naip_structure"

# Folder inside the bucket that holds the exported NAIP TFRecords.
GCLOUD_NAIP_DIR = "naip_Malheur_TFrecords/"

# ---- Local (Treenet) copies -------------------------------------------------
# Where the TFRecords live on Treenet.
TREENET_NAIP_DIR = "/scratch/NAIP_data/malheur/2011/"

# ---- Outputs ----------------------------------------------------------------
# File name of the predicted raster.
TIF_NAME = "Malheur_bbox_predicted_canCov_v1.tif"

# Assets folder in Google Earth Engine.
GEE_USER_FOLDER = "users/forestMapper"

# ---- Model ------------------------------------------------------------------
# Input bands, in the order the model expects them.
BANDS = ["B", "G", "R", "N"]

# [rows, cols] of the image patches the model consumes.
KERNEL_SHAPE = [30, 30]

# Location of the model file on each platform.
MODEL_ON_GCLOUD = f"gs://{BUCKET_NAME}/CNN_models/CNN_canCov_v1"
MODEL_ON_TREENET = "/scratch/CNN/models/canCov_model_2011_v1/"

# Authenticate with, then initialize, Earth Engine.
ee.Authenticate()
ee.Initialize()

# Bounding box (lon, lat corners) of the area to map over, with planar
# (non-geodesic) edges.
MAP_REGION = ee.Geometry.Polygon(
    [
        [
            [-119.37239004502233, 44.48079613290612],
            [-118.57725454697545, 44.48079613290612],
            [-118.57725454697545, 44.81785572318615],
            [-119.37239004502233, 44.81785572318615],
        ]
    ],
    proj=None,
    geodesic=False,
)

# Uncomment to inspect the runtime environment:
# print(f'Available hardware: {tf.config.list_physical_devices()}')
# print(f'Tensorflow version {tf.__version__}')
# print(f'folium version {folium.__version__}')

# 2. Inspect the NAIP imagery for MAP_REGION and the year of interest

In [None]:
# The NAIP image collection
naip = ee.ImageCollection("USDA/NAIP/DOQQ")

# Per-pixel median of all NAIP images acquired in 2011, clipped to MAP_REGION.
# filterDate's end date is exclusive, so Jan 1 of the following year is used
# to keep images acquired on Dec 31.
image = naip.filterDate("2011-01-01", "2012-01-01").median().clip(MAP_REGION)

# Use folium to visualize the imagery as true colour (R, G, B), stretched 0-255.
mapid = image.getMapId({"bands": ["R", "G", "B"], "min": 0, "max": 255})

# Named naip_map (not `map`) so the built-in map() is not shadowed in the
# notebook namespace.
naip_map = folium.Map(location=[44.62157, -118.98257])
folium.TileLayer(
    tiles=mapid["tile_fetcher"].url_format,
    attr='Map Data &copy; <a href="https://earthengine.google.com/">Google Earth Engine</a>',
    overlay=True,
    name="median composite",
).add_to(naip_map)

naip_map

# 3. Run the predictions

In [None]:
# Reset Keras' global state before (re)loading the model, e.g. when cells are
# re-run. Under TF2 eager execution this replaces the TF1 graph-mode call
# tf.compat.v1.reset_default_graph().
# NOTE: this does not by itself guarantee that GPU memory already reserved by
# TensorFlow is returned; restarting the kernel is the reliable way to do that.
tf.keras.backend.clear_session()

In [None]:
# Directory of the trained canopy-cover CNN (presumably a TF SavedModel).
# NOTE(review): this is v1.1 on /md1, not MODEL_ON_TREENET (v1) from the setup
# cell -- confirm which model is intended.
# To force CPU execution, wrap the load in `with tf.device('/cpu:0'):`.
model_dir = "/md1/data/NAIP/trained_models/CNN_canCov_v1.1/"
m = tf.keras.models.load_model(model_dir)
m.summary()

In [None]:
# Discover the exported TFRecords and their mixer.json in the local Treenet copy.
flist = os.listdir(TREENET_NAIP_DIR)

# Sorted so patches are read back in export order (the predictions are later
# reshaped onto the mixer's patch grid).
imageFilesList = sorted(
    os.path.join(TREENET_NAIP_DIR, s) for s in flist if ".tfrecord.gz" in s
)
if not imageFilesList:
    raise FileNotFoundError(f"No .tfrecord.gz files found in {TREENET_NAIP_DIR}")

# Pick the mixer deterministically (os.listdir order is arbitrary).
mixerFiles = sorted(s for s in flist if "mixer.json" in s)
if not mixerFiles:
    raise FileNotFoundError(f"No mixer.json file found in {TREENET_NAIP_DIR}")
jsonFile = mixerFiles[0]

with open(os.path.join(TREENET_NAIP_DIR, jsonFile), "r") as f:
    mixer = json.load(f)
pprint(mixer)

# Number of patches in the export; one prediction is expected per patch.
patches = mixer["totalPatches"]


def parse_imagery_tfrecord(serialized_example, bands=None, kernel_shape=None):
    """Parse one serialized NAIP patch into the tensor layout the model takes.

    Intended to be mapped over a ``tf.data.TFRecordDataset``. Each band is
    stored as a flat float feature of ``kernel_shape[0] * kernel_shape[1]``
    values. Every band is divided by 255 (presumably 8-bit DNs -> [0, 1]),
    reshaped to ``kernel_shape``, and stacked along a new last axis in
    ``bands`` order. A trailing singleton axis is then appended.

    Args:
        serialized_example: Scalar string tensor holding one tf.train.Example.
        bands: Band feature names in model-input order. Defaults to BANDS.
        kernel_shape: [rows, cols] of one patch. Defaults to KERNEL_SHAPE.

    Returns:
        float32 tensor of shape (rows, cols, len(bands), 1).
    """
    # Default to the notebook-wide settings so the parser cannot drift out of
    # sync with the model configuration.
    if bands is None:
        bands = BANDS
    if kernel_shape is None:
        kernel_shape = KERNEL_SHAPE
    pixels_per_patch = kernel_shape[0] * kernel_shape[1]

    features = {
        band: tf.io.FixedLenFeature([pixels_per_patch], tf.float32) for band in bands
    }
    example = tf.io.parse_single_example(serialized_example, features)

    band_images = [tf.reshape(example[band] / 255, kernel_shape) for band in bands]
    image = tf.stack(band_images, axis=-1)

    return tf.expand_dims(image, axis=-1)


# Build the input pipeline from the local (Treenet) TFRecord files.
# Order matters: predictions are later reshaped onto the mixer's patch grid.
# The files are therefore read sequentially (no num_parallel_reads), and map()
# runs in parallel but is explicitly kept deterministic.
BATCH_SIZE = 500
imageDataset = tf.data.TFRecordDataset(imageFilesList, compression_type="GZIP")
imageDataset = (
    imageDataset.map(
        parse_imagery_tfrecord,
        num_parallel_calls=tf.data.AUTOTUNE,
        deterministic=True,
    )
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)

print("Running predictions...")
predictions = m.predict(imageDataset, verbose=1)

In [None]:
# Sanity check before writing the raster. There should be exactly one prediction
# per exported patch (the mixer's "totalPatches", stored as `patches`);
# otherwise reshaping the predictions onto the patch grid fails.
print(len(predictions))
if len(predictions) != patches:
    print(
        f"WARNING: got {len(predictions)} predictions but the mixer lists "
        f"{patches} patches"
    )

In [None]:
# OLD: used to be a function called doPrediction()
# predictions = doPrediction()

# 4. Write the predictions as a .tif

In [None]:
def _load_mixer_from_gcloud(bucket_name, prefix):
    """Download and parse the first (by name) mixer.json directly under gs://bucket_name/prefix."""
    # Use the project's service-account key, as configured in the setup cell.
    credentials = service_account.Credentials.from_service_account_file(KEY_PATH)
    client = storage.Client(project=credentials.project_id, credentials=credentials)
    # delimiter="/" lists only the immediate level, like a plain `gsutil ls`.
    mixer_blobs = sorted(
        (
            blob
            for blob in client.list_blobs(bucket_name, prefix=prefix, delimiter="/")
            if "mixer.json" in blob.name
        ),
        key=lambda blob: blob.name,
    )
    if not mixer_blobs:
        raise FileNotFoundError(f"No mixer.json found under gs://{bucket_name}/{prefix}")
    return json.loads(mixer_blobs[0].download_as_bytes())


def writeRaster(predictions, GDRIVE_FOLDER, TIF_NAME, mixer=None):
    """Write patch-level predictions to a single-band float32 GeoTIFF.

    Each prediction covers one KERNEL_SHAPE patch of the exported imagery. The
    output grid is therefore ``patchesPerRow`` columns wide. Its pixel size is
    the mixer's pixel size scaled by the patch size.

    Args:
        predictions: numpy array of model outputs, one value per patch, in the
            order the patches were exported (e.g. the result of ``m.predict``).
        GDRIVE_FOLDER: Folder under ``/gdrive/My Drive/`` to write into.
        TIF_NAME: File name of the output GeoTIFF.
        mixer: Optional already-loaded mixer.json dict describing the export.
            If omitted, it is read from ``gs://BUCKET_NAME/GCLOUD_NAIP_DIR``.

    Raises:
        FileNotFoundError: If no mixer.json can be found in Cloud Storage.
        ValueError: If the mixer's patch grid is not rectangular or does not
            match the number of predictions.
    """
    if mixer is None:
        mixer = _load_mixer_from_gcloud(BUCKET_NAME, GCLOUD_NAIP_DIR)

    # Integer grid dimensions: rasterio needs int height/width.
    ncol = int(mixer["patchesPerRow"])
    nrow, leftover = divmod(int(mixer["totalPatches"]), ncol)
    if leftover:
        raise ValueError(
            f"totalPatches={mixer['totalPatches']} is not a multiple of "
            f"patchesPerRow={ncol}"
        )
    if predictions.size != nrow * ncol:
        raise ValueError(
            f"{predictions.size} predictions do not fill a {nrow}x{ncol} patch grid"
        )

    # Copy so the caller's mixer dict is not modified. Scale the pixel size
    # (affine[0] = x scale, affine[4] = y scale) by the patch size, since one
    # output pixel covers one whole patch.
    affine = list(mixer["projection"]["affine"]["doubleMatrix"])
    affine[0] = affine[0] * KERNEL_SHAPE[1]
    affine[4] = affine[4] * KERNEL_SHAPE[0]
    crs = mixer["projection"]["crs"]

    pred_raster = predictions.reshape(nrow, ncol).astype(rasterio.float32, copy=False)

    profile = dict(
        dtype=rasterio.float32,
        count=1,
        compress="lzw",
        height=nrow,
        width=ncol,
        driver="GTiff",
        crs=crs,
        transform=affine,
    )

    with rasterio.open(
        f"/gdrive/My Drive/{GDRIVE_FOLDER}/{TIF_NAME}", "w", **profile
    ) as fh:
        fh.write(pred_raster, 1)

In [None]:
# GDRIVE_FOLDER (the output folder under "/gdrive/My Drive/") is not one of the
# setup-cell globals; it must be defined before running this cell.
if "GDRIVE_FOLDER" not in globals():
    raise NameError(
        "GDRIVE_FOLDER is not defined: set it to the output folder under "
        "'/gdrive/My Drive/' before writing the raster."
    )
writeRaster(predictions, GDRIVE_FOLDER, TIF_NAME)