# Generate CONUS LFMC Image
#### Description
Creates a GeoTiff image of LFMC predictions that can be used to produce LFMC maps.

#### Input Images
1. An image of auxiliary data - latitude, longitude, elevation, slope, aspect and climate zone
2. Images of MODIS data for at least 1 year prior to the mapping date
3. Images of PRISM data for at least 1 year prior to the mapping date  
Note: Band names for PRISM data are confusing. When GEE converts an image collection to an image, it includes the date in the band name. Timestamps on PRISM data are midday, so when they are converted (rounded) to a date, this becomes the following day. So in the extracted GeoTiffs, bands for 2016-10-01 data will be named for 2016-10-02!

#### Other Inputs
1. Model directory - this should contain "run" directories - one for each model in the ensemble.
2. Data used to train the model - The notebook extracts normalisation bounds and one-hot encodings needed to prepare the input data. The normalisation bounds are saved to csv files, so if these files already exist, bounds can be loaded from these instead.
3. Legend file for Koppen climate zones. This should be a CSV as created by the "Extract Auxiliary Data.ipynb" notebook. Used to convert the climate zone numbers in the auxiliary input into climate zone codes


In [None]:
import glob
import numpy as np
import os
import pandas as pd
import time

from sklearn.preprocessing import OneHotEncoder
from datetime import date, datetime, timedelta

In [None]:
import initialise
import common
from model_list import ModelList
from model_parameters import ModelParams
from data_prep_utils import reshape_data, normalise, load_bounds, create_onehot_enc

### Parameters
The next two cells set up the parameters for a specific map. Uncomment and run the cell for the required map.

In [None]:
# Scenario label; it is only used below to build the output file name.
LFMC_SCENARIO = 'Nowcasting'
# Directory holding the trained model(s), model_params.json and the normalisation bounds.
MODEL_DIR = os.path.join(common.MODELS_DIR, 'final_models', 'test0')
# Number of days between the last day of the input time series and MAP_DATE.
TS_OFFSET = 1

In [None]:
# Alternative parameter set for the 'Projection' scenario.
# Uncomment these lines (and skip the 'Nowcasting' cell) to generate a projection map.
# LFMC_SCENARIO = 'Projection'
# MODEL_DIR = os.path.join(common.MODELS_DIR, 'final_models', 'test1')
# TS_OFFSET = 91

Set the mapping date. In the paper, maps for the dates '2018-04-01' and '2018-10-01' are provided

In [None]:
# Date the map is generated for; it is parsed later using FN_DATE_FORMAT ('%Y-%m-%d').
MAP_DATE = '2018-04-01'
# MAP_DATE = '2018-10-01'

### Other settings

In [None]:
# Name of the derived model that is loaded from each run and used for the predictions.
DERIVED_MODEL = common.ANALYSIS_MODEL
# Sub-directory of common.GEE_MAPS_DIR that contains the input images.
IMAGE_DIR = 'GEE_EPSG-4326_2000'
# Length of the input time series, in days.
TS_DAYS = 365
# Arguments passed to pandas DataFrame.interpolate when filling gaps in the time series.
INTERP_METHOD = 'linear'
INTERP_MAXGAP = None
INTERP_DIRECTION = 'both'
# Nodata value set on, and written to, both output bands.
NODATA = -999
# Number of decimal places that auxiliary values are rounded to.
FLOAT_PRE = 5
# Slice (start, stop) locating the date within an image file name.
# NOTE(review): (-14, -4) assumes a 10-character date directly before a 4-character
# extension such as '.tif' or '.vrt' - confirm against the actual file names.
FN_DATE_POS = (-14, -4)
# Format of dates in file names and of MAP_DATE.
FN_DATE_FORMAT = '%Y-%m-%d'

### Directories and Files

In [None]:
# Training samples file; it is only read if the model parameters specify the auxiliary
# columns as a count rather than as a list of names.
aux_train_file = os.path.join(common.DATASETS_DIR, 'samples_730days.csv')
# Climate zone legend; read with the first column as index and expected to have a 'Code' column.
KOPPEN_LEGEND = os.path.join(common.SOURCE_DIR, 'Climate_zones.csv')
# Output GeoTIFF, named after the scenario, derived model and mapping date.
OUTPUT_FILE = os.path.join(common.MAPS_DIR, 'LFMC_maps', f'{LFMC_SCENARIO}_{DERIVED_MODEL}_{MAP_DATE}.tif')
# Display the output path as the cell result.
OUTPUT_FILE

In [None]:
# Load the parameters that were saved with the trained model; ModelParams is built
# from the open JSON file and is indexed like a dictionary in the cells below.
with open(os.path.join(MODEL_DIR, 'model_params.json'), 'r') as f:
    model_params = ModelParams(f)

### Load the models
Use the "base" model and ensemble the predictions

In [None]:
# load_model_set may return either a ModelList (an ensemble) or a single model.
# Either way `models` ends up as an iterable of models with DERIVED_MODEL loaded.
models = ModelList.load_model_set(MODEL_DIR)
if not isinstance(models, ModelList):
    print("Loading single model")
    models = [models]
    models[0].load_model(DERIVED_MODEL)
else:
    for i, model in enumerate(models):
        print(f"Loading model {i}")
        model.load_model(DERIVED_MODEL)

## Get Transformation parameters
The transformation parameters need to be obtained from the training data, so the mapping data is transformed the same way
- Get the normalisation ranges from the MODIS and PRISM data
- Get the one-hot encodings from the auxiliary data

### Load normalisation bounds

In [None]:
# Normalisation bounds for each time-series source, loaded from MODEL_DIR.
# The trailing comments show a direct CSV read of the bounds files
# (presumably the approach used before load_bounds existed - TODO confirm).
modis_bounds = load_bounds('modis', MODEL_DIR) #np.genfromtxt(os.path.join(MODEL_DIR, 'modis_bounds.csv'), delimiter=',')
prism_bounds = load_bounds('prism', MODEL_DIR) #np.genfromtxt(os.path.join(MODEL_DIR, 'prism_bounds.csv'), delimiter=',')

### Auxiliary One-hot Encoder
Create the one-hot encoder and encode the climate zones

In [None]:
# Default to "no one-hot encoding"; both names are replaced below when the model
# parameters list one-hot columns.
onehot_enc = None
czones_enc = None
if model_params['auxOneHotCols']:
    onehot_enc = create_onehot_enc(model_params['auxOneHotCols'], model_dir=MODEL_DIR)
    # Legend is indexed by its first column; the 'Code' column is what gets encoded
    czones = pd.read_csv(KOPPEN_LEGEND, index_col=0)
    # One row of encoded values per legend entry, keeping the legend's index so the
    # encodings can be joined to the auxiliary data by climate zone
    czones_enc = pd.DataFrame(
        onehot_enc.transform(czones[['Code']].to_numpy()),
        columns=onehot_enc.get_feature_names_out(['Czone']),
        index=czones.index,
    )

### If auxiliary variables are specified by number, get the names

In [None]:
# The model parameters may give the auxiliary inputs as a count instead of a list of
# names; the names are then taken as the last N columns of the training samples file.
if isinstance(model_params['auxColumns'], int):
    num_aux = model_params['auxColumns']
    # Only the header is needed for the column names, so don't load any data rows
    sample_columns = pd.read_csv(aux_train_file, index_col=0, nrows=0).columns
    # Guard the zero case: columns[-0:] would select every column rather than none
    model_params['auxColumns'] = list(sample_columns[-num_aux:]) if num_aux != 0 else []
    del sample_columns, num_aux
print('Auxiliaries:', model_params['auxColumns'])

## Prepare the mapping data

In [None]:
def days_between(date1, date2):
    """Return the number of whole days from date1 to date2.

    Both arguments are date strings in FN_DATE_FORMAT. The result is negative
    when date2 is earlier than date1.
    """
    later = datetime.strptime(date2, FN_DATE_FORMAT)
    earlier = datetime.strptime(date1, FN_DATE_FORMAT)
    elapsed = later - earlier
    return elapsed.days

In [None]:
def file_after(fn, date_):
    """Return True if the date embedded in file name fn is later than date_.

    The date is sliced out of fn using FN_DATE_POS and compared with date_ as a
    string, so both need to use the same sortable format.
    """
    date_begin, date_end = FN_DATE_POS
    file_date = fn[date_begin:date_end]
    return file_date > date_

In [None]:
def select_files(file_list, start_date, end_date):
    """Select the files needed to cover the period start_date to end_date.

    Parameters
    ----------
    file_list : list of str
        File names in date order. The date in each name (see file_after) is taken
        to be the first date the file contains.
    start_date, end_date : str
        First and last dates required, in the same format as the file name dates.

    Returns
    -------
    list of str
        The contiguous slice of file_list running from the last file that starts
        on or before start_date (or the first file, if all start later) up to, but
        not including, the first file that starts after end_date. An empty list is
        returned if file_list is empty.
    """
    if not file_list:
        # Nothing to select; previously this raised UnboundLocalError
        return []

    # If no file starts after start_date, the last file is the one covering it
    first = len(file_list) - 1
    for pos, fn in enumerate(file_list):
        if file_after(fn, start_date):
            # The file before the first later-starting file covers start_date
            first = max(pos - 1, 0)
            break

    # Keep everything to the end unless a file starting after end_date is found
    last = len(file_list)
    for pos in range(first, len(file_list)):
        if file_after(file_list[pos], end_date):
            last = pos
            break
    return file_list[first:last]

#### Function to prepare auxiliary data
- Reads the auxiliary data into a dataframe with column names set to band descriptions. Adds normalised longitude and latitude (unnormalised values are retained for referencing/alignment). Climate zones are one-hot encoded if the model uses one-hot columns. Elevation, slope, aspect and the climate zone number are dropped from the returned data.

In [None]:
def get_aux_data(aux_image, bands, offsets, sizes):
    """Read a block of the auxiliary image and prepare it as model input.

    Parameters
    ----------
    aux_image : gdal.Dataset
        The open auxiliary image.
    bands : list of str
        Column names for the image bands, in band order. Must include
        'longitude', 'latitude', 'elevation', 'slope', 'aspect' and
        'climate_zone'.
    offsets : tuple
        (x, y) pixel offsets of the block within the image.
    sizes : tuple
        (x, y) size of the block in pixels.

    Returns
    -------
    pandas.DataFrame
        One row per pixel that has data, indexed by the pixel's row-major
        position within the block. Contains the bands other than elevation,
        slope, aspect and climate_zone, the one-hot climate zone columns (if the
        model uses them), and the derived Long_sin, Long_cos and Lat_norm columns.
    """
    aux_data = aux_image.ReadAsArray(xoff=offsets[0], yoff=offsets[1], xsize=sizes[0], ysize=sizes[1]).round(FLOAT_PRE)
    # (bands, rows, cols) -> (pixels, bands), so each pixel becomes a dataframe row
    aux_data = aux_data.transpose(1, 2, 0).reshape(aux_data.shape[1] * aux_data.shape[2], aux_data.shape[0])
    aux_df = pd.DataFrame(aux_data, columns=bands)
    # Drop pixels with nodata in any band. np.nan is used because the np.NAN alias
    # was removed in NumPy 2.0.
    aux_df = aux_df.replace(aux_image.GetRasterBand(1).GetNoDataValue(), np.nan).dropna()
    if model_params['auxOneHotCols']:
        # Attach the one-hot columns for each pixel's climate zone
        aux_df = aux_df.merge(czones_enc, left_on='climate_zone', right_index=True)
    aux_norm = aux_df.drop(['elevation', 'slope', 'aspect', 'climate_zone'], axis=1)
    # NOTE: day-of-year, elevation, slope and aspect features are not generated here.
    # Longitude is scaled to an angle so it can be encoded as a sine/cosine pair
    longitude = normalise(aux_df.longitude, method='range', data_range=(-180, 180), scaled_range=(-np.pi, np.pi))
    aux_norm["Long_sin"] = longitude.transform(np.sin).round(FLOAT_PRE)
    aux_norm["Long_cos"] = longitude.transform(np.cos).round(FLOAT_PRE)
    aux_norm["Lat_norm"] = normalise(aux_df.latitude, method='range', data_range=(-90, 90)).round(FLOAT_PRE)
    return aux_norm

#### Function to compute number of pixels between two locations
- Input locations can be single values or list-like, but should have the same dimensions
- Pixel_size can be a single value, or a value for each pair of location elements
- Return value has the same shape as the input locations

In [None]:
def num_pixels(start_loc, end_loc, pixel_size, convert=np.round):
    """Return the number of pixels between two locations.

    Parameters
    ----------
    start_loc, end_loc : number or list-like
        The locations; both should have the same dimensions.
    pixel_size : number or list-like
        A single pixel size, or one for each pair of location elements.
    convert : callable or None
        Applied to the computed pixel counts (default rounds them). Pass None to
        return the fractional values.

    Returns
    -------
    Same type as start_loc
        The pixel counts. NumPy arrays are returned as arrays; other types are
        rebuilt by calling type(start_loc) on the result.
    """
    pixels = (np.asarray(end_loc) - np.asarray(start_loc)) / np.asarray(pixel_size)
    if convert:
        pixels = convert(pixels)
    if isinstance(start_loc, np.ndarray):
        # np.ndarray(values) would treat the values as a shape, so return as-is
        return pixels
    return type(start_loc)(pixels)

#### Function to prepare time-series data
- Reads data from all images (assumes they are in date order and the first image starts on ts_start)
- Extracts data for the relevant days and pixels
  - index parameter indicates required pixels
- Interpolates along day axis to fill missing values
- Normalises the data using the bounds

In [None]:
def get_ts_data(images, offsets, sizes, channels, ts_start, start_date, index, bounds, scaled_range):
    """Read, gap-fill and normalise a block of time-series data.

    Parameters
    ----------
    images : list of gdal.Dataset
        Images in date order; their bands are concatenated in image order.
    offsets, sizes : tuple
        (x, y) pixel offsets and sizes of the block to read.
    channels : int
        Number of bands per day.
    ts_start : str
        Date of the first day in the first image.
    start_date : str
        First day required; TS_DAYS days are extracted from here.
    index : index-like
        Row-major positions of the block's pixels to keep.
    bounds, scaled_range
        Passed to normalise as data_range and scaled_range.

    Returns
    -------
    The result of normalise for the selected pixels; the last axis of the data
    passed to it indexes the channels.
    """
    blocks = [
        image.ReadAsArray(xoff=offsets[0], yoff=offsets[1], xsize=sizes[0], ysize=sizes[1])
        for image in images
    ]
    # Band range covering the required days
    first_band = channels * days_between(ts_start, start_date)
    last_band = first_band + channels * TS_DAYS
    num_block_pixels = blocks[0].shape[1] * blocks[0].shape[2]
    flat_shape = (num_block_pixels, last_band - first_band)
    # (bands, rows, cols) -> (pixels, bands), then keep only the required pixels
    window = np.concatenate(blocks, axis=0)[first_band:last_band]
    pixel_rows = window.transpose((1, 2, 0)).reshape(flat_shape)[index]
    series = reshape_data(pixel_rows, channels)
    # Fill gaps in each channel separately, interpolating along each pixel's series
    filled = [
        pd.DataFrame(series[:, :, channel]).interpolate(
            axis=1,
            method=INTERP_METHOD,
            limit=INTERP_MAXGAP,
            limit_direction=INTERP_DIRECTION,
        )
        for channel in range(channels)
    ]
    stacked = np.stack(filled, axis=-1)
    return normalise(stacked, method='range', data_range=bounds, scaled_range=scaled_range)

#### Function to predict LFMC
- Models is a ModelList, with trained model loaded
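- Returns two series - the ensemble mean and the standard deviation of the predictions - each with one value per predicted pixel, indexed by the block's pixel index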

In [None]:
def predict(models, derived_model, x_aux, x_modis, x_prism, index):
    """Make LFMC predictions with every model and summarise the ensemble.

    Parameters
    ----------
    models : iterable
        Models with the trained derived model loaded. Each must provide
        predict(X, derived_model) returning one prediction per sample.
    derived_model : str
        Name of the derived model to predict with.
    x_aux, x_modis, x_prism
        Model inputs, passed to the models under the keys 'aux', 'modis' and
        'prism'.
    index : index-like
        Labels for the samples; used to index the returned series.

    Returns
    -------
    list of pandas.Series
        [mean, standard deviation] of the predictions across the models, each
        indexed by index.
    """
    X = {'modis': x_modis, 'prism': x_prism, 'aux': x_aux}
    start_time = time.time()
    preds = [model.predict(X, derived_model) for model in models]
    pred_time = round(time.time() - start_time, 2)
    print('Prediction time:', pred_time)
    # One row per model and one column per sample. The columns come from the index
    # parameter; the global aux_df was read here previously.
    preds = pd.DataFrame(preds, columns=index)
    return [preds.mean(axis=0), preds.std(axis=0)]

### Create VRTs
Mosaic multiple images for the same date into a VRT

In [None]:
def mosaic_images(prefix):
    """Build one VRT per date from the GeoTIFFs whose paths start with prefix.

    The date is the text that directly follows prefix in each file path, with
    the length implied by FN_DATE_POS. Each VRT is written to
    '<prefix><date>.vrt' and mosaics every '<prefix><date>*.tif' file.
    """
    date_length = FN_DATE_POS[1] - FN_DATE_POS[0]
    date_begin = len(prefix)
    tif_files = sorted(glob.glob(prefix + "*.tif"))
    # Distinct dates found in the file names, in ascending order
    unique_days = {tif[date_begin : date_begin + date_length] for tif in tif_files}
    for day in sorted(unique_days):
        print(prefix, day)
        day_files = glob.glob(prefix + day + "*.tif")
        gdal.BuildVRT(prefix + day + ".vrt", day_files)

In [None]:
# Set the GDAL/PROJ data directories before importing the GDAL bindings below
# (presumably needed so the bindings can locate their support files - TODO confirm).
os.environ['GDAL_DATA'] = common.GDAL_DATA
os.environ['PROJ_LIB'] = common.PROJ_LIB
from osgeo import gdal
# Build one VRT per date from the MODIS GeoTIFFs in the image directory
mosaic_images(os.path.join(common.GEE_MAPS_DIR, IMAGE_DIR, f'MODIS_'))

### Open Image Files
- Allows for multiple MODIS and PRISM images
- Assumes each MODIS and PRISM image file name includes the first date in the image
- Assumes images are contiguous with no overlaps

In [None]:
# Auxiliary image; its band descriptions become the auxiliary dataframe column names
aux_file = os.path.join(common.GEE_MAPS_DIR, IMAGE_DIR, 'conus_aux.tif')
aux_image = gdal.Open(aux_file, gdal.GA_ReadOnly)
bands = [
    aux_image.GetRasterBand(band_num).GetDescription()
    for band_num in range(1, aux_image.RasterCount + 1)
]

# The time series ends TS_OFFSET days before the mapping date and spans TS_DAYS days
last_ts_day = datetime.strptime(MAP_DATE, FN_DATE_FORMAT) - timedelta(TS_OFFSET)
start_date = str((last_ts_day - timedelta(TS_DAYS - 1)).date())
end_date = str(last_ts_day.date())

# MODIS mosaics covering the time series period
modis_pattern = os.path.join(common.GEE_MAPS_DIR, IMAGE_DIR, 'MODIS_*.vrt')
modis_files = select_files(sorted(glob.glob(modis_pattern)), start_date, end_date)
modis_images = [gdal.Open(fn, gdal.GA_ReadOnly) for fn in modis_files]

# The first MODIS image defines the mapping grid
modis_geotrans = modis_images[0].GetGeoTransform()
origin = (modis_geotrans[0], modis_geotrans[3])
pixel_size = (modis_geotrans[1], modis_geotrans[5])
raster_size = (modis_images[0].RasterXSize, modis_images[0].RasterYSize)
# Pixel offset of the mapping grid origin within the auxiliary image
aux_geotrans = aux_image.GetGeoTransform()
origin_aux = num_pixels((aux_geotrans[0], aux_geotrans[3]), origin, pixel_size)
# Process by block - assumes all images and bands have the same block size
batch_size = aux_image.GetRasterBand(1).GetBlockSize()

#### Convert PRISM data to mapping resolution

In [None]:
if 'prism' in model_params['dataSources']:
    prism_pattern = os.path.join(common.GEE_MAPS_DIR, IMAGE_DIR, 'PRISM_*.tif')
    prism_files = select_files(sorted(glob.glob(prism_pattern)), start_date, end_date)
    prism_images = [gdal.Open(fn, gdal.GA_ReadOnly) for fn in prism_files]

    # The target grid is taken from the first MODIS image
    proj = modis_images[0].GetProjection()
    geotrans = modis_images[0].GetGeoTransform()
    x_size = modis_images[0].RasterXSize
    y_size = modis_images[0].RasterYSize

    for num, prism_file in enumerate(prism_files):
        print(prism_file)
        # Resample into an in-memory raster on the MODIS grid, then use that raster
        # in place of the source image
        dst = gdal.GetDriverByName('MEM').Create(
            "", x_size, y_size, prism_images[num].RasterCount, gdal.GDT_Float32
        )
        dst.SetGeoTransform(geotrans)
        dst.SetProjection(proj)
        gdal.ReprojectImage(prism_images[num], dst, proj, proj, gdal.GRA_NearestNeighbour)
        prism_images[num] = dst

    # Pixel offset of the mapping grid origin within the (resampled) PRISM images
    prism_geotrans = prism_images[0].GetGeoTransform()
    origin_prism = num_pixels((prism_geotrans[0], prism_geotrans[3]), origin, pixel_size)
    print(origin, origin_aux, origin_prism)

### Create output Geotiff

In [None]:
# Two-band float GeoTIFF on the MODIS grid: band 1 holds the LFMC estimates and
# band 2 their standard deviation
driver = gdal.GetDriverByName('GTiff')
out_map_raster = driver.Create(
    OUTPUT_FILE,
    modis_images[0].RasterXSize,
    modis_images[0].RasterYSize,
    2,
    gdal.GDT_Float32,
)
out_map_raster.SetGeoTransform(modis_images[0].GetGeoTransform())
out_map_raster.SetProjection(modis_images[0].GetProjectionRef())
lfmc_band, std_band = (out_map_raster.GetRasterBand(band_num) for band_num in (1, 2))
for band in (lfmc_band, std_band):
    band.SetNoDataValue(NODATA)

## Generate LFMC estimates
Loop through the images by raster block, prepare the data, make LFMC predictions and save to output raster.
- Nodata pixels are removed before making predictions
- Indexes used to link between dataframes and arrays
- Block processing skipped if all pixels in aux block are nodata

In [None]:
def _augment_aux(x_aux, x_ts, source, augment):
    """Append one day of a source's time series to the auxiliary inputs, if requested.

    augment may be True (use the last day of every source), a list of source names
    (use the last day of the listed sources) or a dict mapping source names to the
    day to use, counted back from the end of the series (a falsy value means 1,
    i.e. the last day). x_aux is returned unchanged if source is not selected.
    """
    if (augment is True) or (isinstance(augment, list) and source in augment):
        return np.concatenate([x_aux, x_ts[:, -1, :]], axis=1)
    if isinstance(augment, dict) and source in augment:
        # Look up the offset by source name (the undefined name `input_name` was used before)
        offset = augment[source] or 1
        return np.concatenate([x_aux, x_ts[:, -offset, :]], axis=1)
    return x_aux


# Date of the first day in the first image of each time-series source
modis_start = modis_files[0][FN_DATE_POS[0]:FN_DATE_POS[1]]
if 'prism' in model_params['dataSources']:
    prism_start = prism_files[0][FN_DATE_POS[0]:FN_DATE_POS[1]]
augment = model_params['auxAugment']
for y_offset in range(0, raster_size[1], batch_size[1]):
    for x_offset in range(0, raster_size[0], batch_size[0]):
        start_time = time.time()
        # Blocks on the right and bottom edges may be smaller than the block size
        x_size = min(batch_size[0], raster_size[0] - x_offset)
        y_size = min(batch_size[1], raster_size[1] - y_offset)
        # Row-major position of every pixel in the block
        lfmc_index = pd.Index(range(y_size * x_size))
        # aux_df holds only the pixels with data; its index gives their positions
        aux_df = get_aux_data(aux_image, bands, (int(x_offset + origin_aux[0]), int(y_offset + origin_aux[1])), (x_size, y_size))
        print(f'Processing block at ({x_offset}, {y_offset}), size ({x_size}, {y_size}), {len(aux_df)} predictions')

        if len(aux_df) == 0:
            print('No data in block - skipping')
            empty = np.full([y_size, x_size], NODATA)
            lfmc_band.WriteArray(empty, xoff=x_offset, yoff=y_offset)
            lfmc_band.FlushCache()
            std_band.WriteArray(empty, xoff=x_offset, yoff=y_offset)
            std_band.FlushCache()
            continue

        aux_columns = model_params['auxColumns']
        if model_params['auxOneHotCols']:
            aux_columns = aux_columns + list(czones_enc.columns)
        x_aux = aux_df[aux_columns].to_numpy()
        aux_time = time.time()
        print('Aux processing:', round(aux_time - start_time, 2))
        # Time of the end of the previous step, so each source is timed on its own
        # whichever sources are used and in whatever order they are listed
        step_time = aux_time
        x_modis = None
        x_prism = None

        for source in model_params['dataSources']:
            if source == 'modis':
                modis_params = model_params['inputs']['modis']
                x_modis = get_ts_data(
                    modis_images,
                    (x_offset, y_offset),
                    (x_size, y_size),
                    modis_params['channels'],
                    modis_start,
                    start_date,
                    aux_df.index,
                    modis_bounds,
                    modis_params['normalise'].get('scaled_range', (0, 1)),
                )
                x_aux = _augment_aux(x_aux, x_modis, 'modis', augment)
                modis_time = time.time()
                print('Modis processing:', round(modis_time - step_time, 2))
                step_time = modis_time

            if source == 'prism':
                prism_params = model_params['inputs']['prism']
                x_prism = get_ts_data(
                    prism_images,
                    (int(x_offset + origin_prism[0]), int(y_offset + origin_prism[1])),
                    (x_size, y_size),
                    prism_params['channels'],
                    prism_start,
                    start_date,
                    aux_df.index,
                    prism_bounds,
                    prism_params['normalise'].get('scaled_range', (0, 1)),
                )
                x_aux = _augment_aux(x_aux, x_prism, 'prism', augment)
                prism_time = time.time()
                print('Prism processing:', round(prism_time - step_time, 2))
                step_time = prism_time

        lfmc, std_dev = predict(models, DERIVED_MODEL, x_aux, x_modis, x_prism, aux_df.index)
        for out_band, values in ((lfmc_band, lfmc), (std_band, std_dev)):
            # Expand to every pixel in the block; pixels without a value become NODATA
            block = values.reindex(lfmc_index).to_numpy().reshape(y_size, x_size)
            block[np.isnan(block)] = NODATA
            out_band.WriteArray(block, xoff=x_offset, yoff=y_offset)
            out_band.FlushCache()

In [None]:
# Drop the reference to the output raster. With GDAL's Python bindings this is
# expected to flush and close the file - TODO confirm no other reference keeps it open.
del out_map_raster