# DL4DS - UNSW test 

This notebook utilises DL4DS to downscale daily maximum temperature (tasmax) data from the NCI CORDEX-CMIP5 data collection.

In [None]:
from numba import cuda 
import os 

import numpy as np
import xarray as xr
import ecubevis as ecv
import dl4ds as dds
import scipy as sp
import netCDF4 as nc
import climetlab as cml

import xarray as xr
import cartopy.crs as ccrs  # CRS stands for "coordinate reference system"
import matplotlib.pyplot as plt
from datetime import datetime
#import pyresample
import yaml

import tensorflow as tf 
from tensorflow import keras
from tensorflow.keras import models

In [None]:
# Reset the current CUDA device through numba before TensorFlow uses the GPU.
# NOTE(review): presumably this frees GPU memory left over from a previous run
# of the notebook in the same kernel — confirm.
device = cuda.get_current_device()
device.reset()

# Placeholder path: replace with your own working directory before running.
os.chdir("YOUR_OWN_WORKDING_DIRECTORY")

# Query TensorFlow's default GPU once and reuse the result for both the check
# and the message. An empty string (falsy) means no GPU device is visible.
gpu_name = tf.test.gpu_device_name()
if gpu_name:
    print(f'Default GPU Device: {gpu_name}')
else:
    print("Please install GPU version of TF")

In [None]:
# Directory holding the UNSW-WRF360K daily tasmax files driven by
# CSIRO-BOM-ACCESS1-0 under RCP4.5 (r1i1p1), NCI CORDEX AUS-44 collection.
# Despite its name, `filename` is a directory (it ends with '/').
filename = (
    "/g/data/rr3/publications/CORDEX/output/AUS-44/UNSW/"
    "CSIRO-BOM-ACCESS1-0/rcp45/r1i1p1/UNSW-WRF360K/v1/day/tasmax/files/d20210629/"
)
# Glob over every file in that directory; open_mfdataset combines the matches.
file_dir = f"{filename}*"
print(file_dir)

data = xr.open_mfdataset(file_dir, parallel=False)
lat, lon, time = data.lat, data.lon, data.time
tasmax = data.tasmax
print(tasmax.shape, tasmax.dims)

In [None]:
# Print the longitude and time coordinates (one after the other) as a sanity check.
print(lon, time, sep="\n")

In [None]:
# Crop the domain along the `rlon` dimension to the window [min_lon, max_lon].
# NOTE(review): in CORDEX files `rlon` is often a rotated-pole longitude, so these
# bounds may not be geographic degrees east — confirm against the 2-D `lon` field.
min_lon, max_lon = 139.0, 195.0
cropped_tasmax = tasmax.sel(rlon=slice(min_lon, max_lon))
cropped_tasmax.shape  # last expression: displayed as the cell output

In [None]:
# Chronological, contiguous, non-overlapping split of the record
# (label-based slices include both endpoints):
#   train 2006-01-19 .. 2077-12-31
#   val   2078-01-01 .. 2089-07-01
#   test  2089-07-02 .. 2100-12-31
train_data = cropped_tasmax.sel(time=slice('2006-01-19', '2077-12-31'))
val_data = cropped_tasmax.sel(time=slice('2078-01-01', '2089-07-01'))
test_data = cropped_tasmax.sel(time=slice('2089-07-02', '2100-12-31'))

print(test_data.shape, val_data.shape, train_data.shape)

In [None]:
# Fit the standardisation statistics on the training period only, then apply
# that same fitted transform to every split, so val/test statistics never leak
# into the scaling.
# NOTE(review): axis=None presumably means one global mean/std over all
# dimensions — confirm in the DL4DS docs.
scaler_train = dds.StandardScaler(axis=None)
scaler_train.fit(train_data)
y_train, y_test, y_val = (
    scaler_train.transform(split) for split in (train_data, test_data, val_data)
)

In [None]:
# Append a trailing singleton 'channel' dimension to each split
# (presumably the channels-last layout DL4DS expects — confirm).
y_train, y_test, y_val = (
    split.expand_dims(dim='channel', axis=-1) for split in (y_train, y_test, y_val)
)

In [None]:
# Debug check: build one HR/LR pair from the first element of the scaled test
# array (presumably the first time step) with the same 'spc' upsampling,
# scale 8 and 'inter_area' interpolation used for training below.
# debug=True presumably prints/plots the generated pair; the result is discarded.
_ = dds.create_pair_hr_lr(
    array=y_test.values[0],
    array_lr=None,
    upsampling='spc',
    scale=8,
    interpolation='inter_area',
    patch_size=None,
    static_vars=None,
    predictors=None,
    season=None,
    debug=True,
)

In [None]:
# Backbone hyper-parameters, forwarded to the trainer as keyword arguments.
ARCH_PARAMS = {
    'n_filters': 8,
    'n_blocks': 8,
    'normalization': None,
    'dropout_rate': 0.0,
    'dropout_variant': 'spatial',
    'attention': False,
    'activation': 'relu',
    'localcon_layer': True,
}

tasmax_trainer = dds.SupervisedTrainer(
    # Model: ResNet backbone with 'spc' upsampling, 8x resolution increase.
    backbone='resnet',
    upsampling='spc',
    scale=8,
    # Data: HR splits only. No LR arrays are passed, so presumably LR inputs
    # are derived by coarsening with `interpolation` — confirm.
    data_train=y_train,
    data_val=y_val,
    data_test=y_test,
    data_train_lr=None,
    data_val_lr=None,
    data_test_lr=None,
    interpolation='inter_area',
    patch_size=None,
    time_window=None,
    static_vars=None,
    predictors_train=None,
    predictors_val=None,
    predictors_test=None,
    # Optimisation.
    loss='mae',
    batch_size=60,
    epochs=100,
    steps_per_epoch=None,
    validation_steps=None,
    test_steps=None,
    learning_rate=(1e-3, 1e-4),
    lr_decay_after=1e4,
    early_stopping=False,
    patience=6,
    min_delta=0,
    # Output / runtime.
    save=False,
    save_path=None,
    show_plot=True,
    verbose=True,
    device='GPU',
    **ARCH_PARAMS,
)

tasmax_trainer.run()

In [None]:
# Run the trained model on the scaled test period (given at HR, array_in_hr=True).
# Because `scaler` is passed, the output is presumably inverse-transformed back
# to the original units, hence the "unscaled" name. With return_lr=True, run()
# yields a second array, the coarsened (LR) input.
tasmax_pred = dds.Predictor(
    tasmax_trainer,
    y_test,
    scale=8,
    array_in_hr=True,
    interpolation='inter_area',
    static_vars=None,
    predictors=None,
    time_window=None,
    scaler=scaler_train,
    batch_size=8,
    save_path=None,
    save_fname=None,
    return_lr=True,
    device='GPU',
)

tasmax_unscaled_y_pred, tasmax_coarsened_array = tasmax_pred.run()


In [None]:
# Re-apply the training scaler so the prediction (presumably returned in original
# units by the Predictor, since a scaler was passed) is on the same standardised
# scale as `y_test`, for a like-for-like visual comparison.
tasmax_scaled_y_pred = scaler_train.transform(tasmax_unscaled_y_pred)

In [None]:
# Plot the coarsened (low-resolution) input returned by the Predictor.
ecv.plot(tasmax_coarsened_array,plot_size_px=400 )

In [None]:
# Scaled ground-truth test data next to the re-scaled prediction. The two plots are
# combined with `+` (presumably a HoloViews layout — confirm ecubevis return type).
ecv.plot(y_test.values,plot_size_px=400 ) +   ecv.plot(tasmax_scaled_y_pred,plot_size_px=400)  

In [None]:
# Close the dataset opened with xr.open_mfdataset to release its file handles.
data.close()