In [None]:
import netCDF4 as nc
import xarray as xr
import numpy as np
import datetime
import pandas as pd
from dateutil.relativedelta import relativedelta
import matplotlib.pyplot as plt
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_squared_error
from sklearn.linear_model import LinearRegression

#download
!curl https://zenodo.org/record/6532501/files/CESM_EA_SPI.nc?download=1 --output CESM_EA_SPI.nc

# Run Main file

In [None]:
# Enable IPython's autoreload extension; mode 2 re-imports all modules before
# each cell is executed, so edits to local .py files are picked up.
%load_ext autoreload
%autoreload 2

# Run train_siren.py through IPython's %run magic with the given command-line flag.
# NOTE(review): `--fast_dev_run 1` is presumably a quick smoke-test switch of the
# training script — confirm against train_siren.py.
%run train_siren.py --fast_dev_run 1

## RAW DATA SET

In [None]:
# Read the SPI variable into memory while the file is open.
# xarray opens netCDF files lazily, so the values are loaded explicitly before
# the dataset is closed; later cells plot and convert `spi` after this point.
file_name = 'CESM_EA_SPI.nc'
with xr.open_dataset(file_name) as ds:
    spi = ds['spi'].load()
spi

# Experiment month

In [None]:
# Encode the calendar month of every time step as a point on the unit circle
# (January maps to angle 0), so that December and January end up adjacent.
dim_month = pd.DataFrame(spi["time"].to_series().values, columns=['time'])
# `datetime` is imported as a module, so the class has to be addressed as datetime.datetime.
month = dim_month['time'].apply(lambda x: datetime.datetime.strptime(str(x), '%Y-%m-%dT%H:%M:%S').month).values
fract = (month-1)/12*np.pi*2
month_sin = np.sin(fract)
month_cos = np.cos(fract)

In [None]:
# Inspect the cosine component computed for every time step.
month_cos

In [None]:
# Check the encoding by hand for a single month (12).
# NOTE: this rebinds `month`, which held the per-time-step month array before.
month = 12
tmp=(month-1)/12*np.pi*2

In [None]:
# Sine component for month 12.
np.sin(tmp)

In [None]:
# Cosine component for month 12.
np.cos(tmp)

# Experiment Padding Window

In [None]:
# Take the first entry along the leading axis of `train_predictands` as a sample image.
# NOTE(review): `train_predictands` is not assigned in any cell visible here —
# confirm where it is supposed to come from.
img = train_predictands[0,:,:]
img.shape

In [None]:
# Pad the image by `window` cells on every side, mirroring the values at the
# border (mode='symmetric'), and check the resulting shape.
window=1
img_pad = np.pad(img, pad_width=window, mode='symmetric')
img_pad.shape

In [None]:
# Collapse the three leading axes of X into one and keep the fourth axis as columns.
X.reshape(int(np.prod(np.shape(X)[:3])), np.shape(X)[3])

In [None]:
# Prototype of the patch extraction: build window-sized samples from the first
# time step only (note the `break` at the end of the outer loop).
window=5

data_x = []
data_y = []

for time_idx in range(num_samples):
    img_y = y[time_idx,:,:]
    img_pad_y = np.pad(img_y, pad_width=window, mode='symmetric')

    img_x = X[time_idx,:,:,:]
    # Pad the predictor image itself (not the already padded target) and only
    # along its two spatial axes; the last (channel) axis is left untouched.
    img_pad_x = np.pad(img_x, pad_width=((window, window), (window, window), (0, 0)), mode='symmetric')

    # NOTE(review): the upper bounds 13 and 20 are hard-coded; presumably they
    # are the grid dimensions — confirm against the data.
    for lat in range(window, 13):
        for lon in range(window, 20):
            # y: flattened window x window patch of the padded target
            sample_y = img_pad_y[lat:lat+window,lon:lon+window]
            sample_y = sample_y.reshape(np.shape(sample_y)[0]*np.shape(sample_y)[1])
            data_y.append(sample_y)

            # x: flattened window x window x channels patch of the padded predictors
            sample = img_pad_x[lat:lat+window,lon:lon+window,:]
            sample = sample.reshape(np.shape(sample)[0]*np.shape(sample)[1]*np.shape(sample)[2])
            data_x.append(sample)
    break

In [None]:
# Positional indexing by dimension name: index 0 along time and the first two
# entries along lat and lon.
spi[dict(time=0, lat=slice(None, 2), lon=slice(None, 2))]

In [None]:
# Label-based selection along lat using the coordinate slice 0..1.
spi.sel(lat=slice(0,1))

In [None]:
# Show the spatial SPI field at one randomly drawn time index.
time = np.random.randint(len(spi))
spi2d = spi.isel({'time': time})
spi2d.plot()

In [None]:
# Plot a 1000-step excerpt of the SPI series at a randomly drawn grid cell.
# The three random draws are made in the same order as before (lat, lon, start index).
lat = np.random.randint(spi.shape[1])
lon = np.random.randint(spi.shape[2])
k = np.random.randint(spi.shape[0] - 1000)
plt.plot(spi[slice(k, k + 1000), lat, lon])
plt.ylabel('SPI')

In [None]:
# Histogram of all SPI values, pooled over every time step and grid cell.
plt.hist(np.array(spi).reshape(-1), bins=100)
plt.title('SPI distribution')
plt.show()

## Prepare data

In [None]:
def _zero_padded_date(moment):
    '''Format a datetime as YYYY-MM-DD with the year always padded to four digits.'''
    # strftime('%Y') pads years below 1000 on some platforms but not on others,
    # so the fields are formatted explicitly instead of prepending a '0'.
    return '{:04d}-{:02d}-{:02d}'.format(moment.year, moment.month, moment.day)


def _month_of(stamp):
    '''Return the calendar month (1-12) of one time-coordinate value.'''
    # Timestamp-like objects expose the month directly; anything else is parsed
    # from its string form with the ISO-like format used originally.
    month = getattr(stamp, 'month', None)
    if month is not None:
        return int(month)
    return datetime.datetime.strptime(str(stamp), '%Y-%m-%dT%H:%M:%S').month


def helge_assemble_predictors_predictands(start_date, end_date, lead_time, dataset, num_input_time_steps, window=5):
    '''
    Build a flat table of (local SPI patch, month encoding, grid position, target SPI) samples.

    One row is produced for every time step and every grid cell. A row contains, in this order:
      * the flattened spatial patch of past SPI values centred on the cell; the patch is
        (2*(window//2)+1) cells wide in both directions and has num_input_time_steps channels,
        with values beyond the grid edge filled in by symmetric padding,
      * sin and cos of the month of the target time step,
      * the row and column index of the cell in the padded image (grid index + window),
      * the target SPI value of the cell (last column).
    Rows are ordered by time step, then latitude index, then longitude index.
    NaN values in predictors and targets are replaced by 0.

    Args
    ----
    start_date (str): The start date for extraction. Important, put a leading 0 at the beginning of year for dates before 1000 (e.g., '0400')
    end_date (str): The end date for extraction
    lead_time (int): The number of months between the predictor/predictand
    dataset (str): Either 'CESM' or 'ECMWF'
    num_input_time_steps (int): The number of time steps to use for each predictor samples
    window (int): Padding width and nominal patch size in grid cells

    Returns
    -------
    np.ndarray: 2-D array with one row per (time step, grid cell) sample

    Raises
    ------
    KeyError: if dataset is not 'CESM' or 'ECMWF'
    ValueError: if the selected period holds fewer than num_input_time_steps time steps
    '''
    file_name = {'CESM': 'CESM_EA_SPI.nc', 'ECMWF': 'ECMWF_EA_SPI.nc'}[dataset]

    # select Y: shift the extraction period so that target i lies lead_time months
    # after the last input time step of predictor sample i
    if dataset == 'ECMWF':
        start_date_plus_lead = pd.to_datetime(start_date) + pd.DateOffset(months=lead_time+num_input_time_steps-1)
        end_date_plus_lead = pd.to_datetime(end_date) + pd.DateOffset(months=lead_time)
    elif dataset == 'CESM':
        t_start = datetime.datetime(int(start_date.split('-')[0]), int(start_date.split('-')[1]), int(start_date.split('-')[2]))
        t_end = datetime.datetime(int(end_date.split('-')[0]), int(end_date.split('-')[1]), int(end_date.split('-')[2]))
        start_date_plus_lead = _zero_padded_date(t_start + relativedelta(months=lead_time+num_input_time_steps-1))
        end_date_plus_lead = _zero_padded_date(t_end + relativedelta(months=lead_time))

    # The context manager releases the file handle even if a selection fails.
    with xr.open_dataset(file_name) as ds:
        spi = ds['spi'].sel(time=slice(start_date, end_date)).values
        subsetted_ds = ds['spi'].sel(time=slice(start_date_plus_lead, end_date_plus_lead))
        y = subsetted_ds.values.astype(np.float32)
        target_times = subsetted_ds["time"].to_series().values

    if spi.shape[0] < num_input_time_steps:
        raise ValueError('The selected period contains {} time steps, but num_input_time_steps={} are required'.format(
            spi.shape[0], num_input_time_steps))

    # Stack sliding windows of num_input_time_steps consecutive time steps and remove nans
    X = np.stack([spi[n-num_input_time_steps:n] for n in range(num_input_time_steps, spi.shape[0]+1)])
    X[np.isnan(X)] = 0
    # make sure we have floats in there, then move the time-step axis last: (sample, lat, lon, step)
    X = np.moveaxis(X.astype(np.float32), 1, 3)
    y[np.isnan(y)] = 0

    # add month feature: month of the target, mapped onto the unit circle
    month = np.array([_month_of(stamp) for stamp in pd.Series(target_times)])
    fract = (month-1)/12*np.pi*2
    month_sin = np.sin(fract)
    month_cos = np.cos(fract)

    # Predictor sample i and target i are aligned from the start; if the targets run
    # past the end of the file there are fewer of them, so only complete pairs are used.
    num_samples = min(X.shape[0], y.shape[0])
    n_lat, n_lon = y.shape[1], y.shape[2]
    halfwindow = int(window/2)
    # Pad the two spatial axes only; the time-step (channel) axis stays as it is.
    spatial_pad = ((window, window), (window, window), (0, 0))

    data_set = []
    for time_idx in range(num_samples):
        img_pad_y = np.pad(y[time_idx], pad_width=window, mode='symmetric')
        img_pad_x = np.pad(X[time_idx], pad_width=spatial_pad, mode='symmetric')
        # lat/lon run over the original grid cells, expressed in padded-image indices
        for lat in range(window, n_lat+window):
            for lon in range(window, n_lon+window):
                patch = img_pad_x[lat-halfwindow:lat+halfwindow+1, lon-halfwindow:lon+halfwindow+1, :]
                extras = (month_sin[time_idx], month_cos[time_idx], lat, lon, img_pad_y[lat, lon])
                # one concatenation per sample: patch values, month encoding, position, target
                data_set.append(np.concatenate((
                    patch.ravel().astype(np.float64),
                    np.asarray(extras, dtype=np.float64),
                )))
    return np.array(data_set)



In [None]:
# Experiment configuration used by the cells below.
# Number of time steps per predictor sample (stored in the output file attributes).
num_input_time_steps = 3 
# Number of months between predictor and predictand (stored in the output file attributes).
lead_time = 3
# Spatial window size; also part of the output file names.
window = 3
# Version tag that is part of the output file names.
ver=1

# Key of the source dataset ('CESM' or 'ECMWF' are the values the assembly function accepts).
climate_model = 'CESM'

# Overall extraction period; the year is written with a leading zero.
all_start_date = '0400-01-01'
all_end_date = '2021-12-31'

#train_start_date = '0400-01-01'
#train_end_date = '1800-12-31'

#test_start_date = '1801-01-01'
#test_end_date = '1978-12-31'

In [None]:
# Assemble the sample table from the configuration above, so that `data_set`
# exists when the notebook is run top to bottom on a fresh kernel.
data_set = helge_assemble_predictors_predictands(
    all_start_date, all_end_date, lead_time, climate_model, num_input_time_steps, window=window
)
# Look at the first sample row.
data_set[0]

# Store new dataset as netcdf files

In [None]:
def store_data(data_set, fn: str)->None:
    """Wrap a 2-D sample table in an xarray Dataset, print it and write it to ``fn`` as netCDF.

    The table becomes the variable ``samples`` with dimensions ``sample_dim`` x ``input_dim``,
    each indexed by consecutive integers. The attributes are taken from the notebook-level
    variables ``all_start_date``, ``all_end_date``, ``climate_model``,
    ``num_input_time_steps`` and ``lead_time``, which must exist when this is called.

    Args:
        data_set: 2-D array of samples (rows) by inputs (columns).
        fn: path of the netCDF file to write.
    """
    import xarray as xr

    axis_names = ["sample_dim", "input_dim"]
    axis_ticks = [np.arange(data_set.shape[0]), np.arange(data_set.shape[1])]
    metadata = {
        "begin": all_start_date,
        "end": all_end_date,
        "climate_model": climate_model,
        "num_input_time_steps": num_input_time_steps,
        "lead_time": lead_time,
        "unit": "Standard Precipitation Index (SPI)",
    }

    samples = xr.DataArray(
        data_set,
        coords=axis_ticks,
        dims=axis_names,
        name="samples",
        attrs=metadata,
    )

    # Merging the single named array turns it into a Dataset.
    xr_dataset = xr.merge([samples], compat='override')
    print(xr_dataset)

    xr_dataset.to_netcdf(fn)
    
# set seed
seed = 1
train_fract = 0.8
import pytorch_lightning as pl
pl.seed_everything(seed)

# split in train and test data: every sample goes to the training set
# independently with probability train_fract, the rest to the test set
import numpy as np
rng = np.random.default_rng(seed)
# use the configured fraction instead of a second hard-coded 0.8
msk = rng.random(data_set.shape[0]) < train_fract
data_set_train = data_set[msk,:]
data_set_test = data_set[~msk,:]

# write both splits; window and version are part of the file names
ver = 1
fn = "helge_TRAIN_dataset_window{}_sincostime_ver{}.nc".format(window,ver)
store_data(data_set_train, fn)

fn = "helge_TEST_dataset_window{}_sincostime_ver{}.nc".format(window,ver)
store_data(data_set_test, fn)

# Store mean and standard deviation of the training data

In [None]:
# File name of the training split, built from `window` and `ver` with the same
# pattern that is used when the split is written.
fn_train = "helge_TRAIN_dataset_window{}_sincostime_ver{}.nc".format(window,ver)

import xarray as xr
# Open the stored training samples from disk.
xr_dataset = xr.open_dataset(fn_train)

In [None]:
# Display the dataset summary.
xr_dataset

In [None]:
# Number of columns per sample. Dataset.sizes is the name -> length mapping;
# using Dataset.dims as a mapping is deprecated in xarray.
xr_dataset.sizes['input_dim']

In [None]:
import xarray as xr
# One integer label per column of the sample table. Dataset.sizes is the
# name -> length mapping; using Dataset.dims as a mapping is deprecated in xarray.
coords = np.arange(xr_dataset.sizes['input_dim'])

# Column-wise mean over all training samples, re-labelled along a "features" dimension.
xr_mean = xr.DataArray(xr_dataset.mean(dim='sample_dim').samples, 
             coords=[coords], 
             dims=["features"],
             name="mean",
             attrs={
                 "features":'[3*window^2+month_sin+month_cos+lat+lon+y]',
                  }) 

# Column-wise standard deviation over all training samples, labelled the same way.
xr_std = xr.DataArray(xr_dataset.std(dim='sample_dim').samples, 
             coords=[coords], 
             dims=["features"],
             name="stddev",
             attrs={
                 "features":'[3*window^2+month_sin+month_cos+lat+lon+y]',
                  }) 

# Combine both statistics into one dataset and show it.
xr_dataset_mean_std = xr.merge( [xr_mean, xr_std], compat='override', combine_attrs='override', )
print(xr_dataset_mean_std)


# Write the statistics next to the train/test files, using the same naming pattern.
fn = "helge_MEAN_STDDEV_dataset_window{}_sincostime_ver{}.nc".format(window,ver)
xr_dataset_mean_std.to_netcdf(fn)

## Train model

In [None]:
#regr = RandomForestRegressor(max_depth=4, n_jobs=-1, max_samples=0.1)
# Fit a linear-regression baseline on the training predictors and targets.
# NOTE(review): `train_predictors` and `train_predictands` are not assigned in
# any cell visible here — confirm how they are derived from the stored samples.
regr = LinearRegression()
regr.fit(train_predictors, train_predictands)


## Predict

In [None]:
# Predict the targets for the test predictors with the fitted model.
# NOTE(review): `test_predictors` is not assigned in any cell visible here — confirm its origin.
pred = regr.predict(test_predictors)

## Evaluate model

In [None]:
# Mean squared error between truth and prediction. The arguments follow
# scikit-learn's (y_true, y_pred) order; the value is the same either way.
mse = mean_squared_error(y_true=test_predictands, y_pred=pred)
print('MSE:', mse)

In [None]:
#reshape prediction to initial shape
# Both arrays are reshaped to `orig_shape_ytest`; the plots below index them as [time, lat, lon].
# NOTE(review): `orig_shape_ytest` is not assigned in any cell visible here — confirm its origin.
pred = pred.reshape(orig_shape_ytest)
test_predictands = test_predictands.reshape(orig_shape_ytest)

In [None]:
# Compare prediction, truth and their difference side by side for one random time step.
time = np.random.randint(pred.shape[0])
plt.figure(figsize=(30,6))
panels = (
    ('Prediction', pred[time,:,:]),
    ('Truth', test_predictands[time,:,:]),
    ('Error', test_predictands[time,:,:]-pred[time,:,:]),
)
for position, (title, field) in enumerate(panels, start=1):
    plt.subplot(1,3,position)
    plt.title(title)
    plt.imshow(field)
    plt.colorbar()


In [None]:
# Plot a 100-step excerpt of predicted and true SPI at a randomly drawn grid cell.
# The three random draws are made in the same order as before (lat, lon, start index).
lat = np.random.randint(pred.shape[1])
lon = np.random.randint(pred.shape[2])
k = np.random.randint(pred.shape[0]-100)
for series, label in ((pred, 'Pred'), (test_predictands, 'Truth')):
    plt.plot(series[k:k+100,lat,lon], label=label)
plt.legend()
