In [None]:
import datetime
import os

import cartopy.crs as ccrs
from keras import initializers
from keras import layers
from keras import models
from keras.callbacks import EarlyStopping
from keras.callbacks import ModelCheckpoint
from keras.layers import Dense, SimpleRNN, LSTM, GRU, Reshape
from keras.models import Sequential
from keras.optimizers import RMSprop
from matplotlib import cm
import matplotlib.pyplot as plt
from netCDF4 import Dataset, num2date, date2num
import numpy as np
import pandas as pd
import scipy
import tensorflow.keras as keras
from tensorflow.keras import models
import xarray as xr

## Authors
* Martin Wegmann (martin.wegmann@giub.unibe.ch)
* Fernando Jaume Santero (fernando.jaume@unige.ch)

### Read in Pseudo-Location data

In [None]:
# Station ("pseudo-location") table. The file is space-delimited; the Lat and
# Lon columns are the ones used further down.
df2 = pd.read_csv(
    "/Volumes/SPARK/ISTI/EKF400_v1_assim_ISTI_less1831_smallest.txt",
    sep=" ",
)

df2

In [None]:
# Number of station latitudes / longitudes, read straight from the station
# table. `lats` and `lons` are only created further down in the notebook, so
# referring to them here fails on a fresh kernel (Restart & Run All).
nlats = len(df2.Lat.values)
nlons = len(df2.Lon.values)

### Read in Gridded Data

In [None]:
# Folder holding the 20CRv3 file that is opened as ds1.
pathTo20CR = '/Volumes/SPARK/20crv3/'
# Folder holding the EKF400 ensemble-mean file that is opened as ds2.
pathToEKF = '/Volumes/SPARK/ekf400v2/ensmean/' 
# Folder that receives the model checkpoints (.h5) and reconstructions (.nc).
save_folder="/Volumes/SPARK/RNN_savestates/"

In [None]:
# 20CRv3 file: source of the training fields (variable "air").
ds1 = xr.open_dataset(pathTo20CR + 'air.2m.mon.mean_18512015_anoms_remap.nc')
ds1_var = ds1["air"]

# EKF400 ensemble-mean file: the period to reconstruct (variable "air_temperature").
ds2 = xr.open_dataset(pathToEKF + 'EKF400_ensmean_v2.0_t2m_anoms.nc')
ds2_var = ds2["air_temperature"]

In [None]:
# Grid size of the training data: axis 1 is used as latitude and axis 2 as
# longitude throughout the notebook.
lat_dim, lon_dim = ds1_var.shape[1:3]
print(lon_dim)
print(lat_dim)

In [None]:
# Coordinate vectors of the training grid; reused as the lat/lon axes of the
# NetCDF files written at the end of the notebook.
latitudes=ds1_var.lat.values
longitudes=ds1_var.lon.values

### Convert latitude and longitude data of the stations to fit the grid

In [None]:
# Grid spacing in degrees: 360 / 180 degrees divided by the number of grid
# columns / rows.
resolution_lon = 360 / lon_dim
resolution_lat = 180 / lat_dim
print(resolution_lon)
print(resolution_lat)

# Station coordinates as fractional grid indices: longitude is shifted so that
# -180 maps to 0, latitude is flipped so that +90 maps to 0.
# NOTE(review): this assumes the grid starts at 180W / 90N -- confirm against
# the lon/lat coordinates of ds1_var.
lons = (df2.Lon.values + 180) / resolution_lon
lats = (df2.Lat.values - 90) / -resolution_lat

In [None]:
# number of timesteps we have in the training data
timesteps_in_data=len(ds1_var.time.values)
timesteps_in_data

In [None]:
# number of timesteps we want to reconstruct
timesteps_in_testdata=len(ds2_var.time.values)
timesteps_in_testdata

### Set the training sample size and the number of channels we want to train with

In [None]:
# One training sample fewer than there are time steps in the training data.
sample_size = int(timesteps_in_data - 1)
# One input feature per station.
amount_locations = len(lats)
# Three input rows per sample: latitude index, longitude index, sampled value.
amount_channels = 3

### Set output location and file name 

In [None]:
# Tags that make up the names of the checkpoint and reconstruction files.
model = "20cr"
member = "det"
resolution = "lowres"
output = "anoms"

# Name part shared by all four files:
# <output>_<model>_<resolution>_<number of stations>_<number of samples>
_run_tag = "_".join([output, model, resolution, str(amount_locations), str(sample_size)])

# Best-model checkpoints written by the ModelCheckpoint callbacks.
RNN1_path = os.path.join(save_folder, "best_model_50p" + _run_tag + "_RNN1_" + member + ".h5")
RNN1lstm_path = os.path.join(save_folder, "best_model_50p" + _run_tag + "_RNN1lstm_" + member + ".h5")

# NetCDF files that receive the reconstructed fields.
RNN1_path_nc = os.path.join(save_folder, "best_model_ekf400_50p_" + _run_tag + "_RNN1_" + member + ".nc")
RNN1lstm_path_nc = os.path.join(save_folder, "best_model_ekf400_50p_" + _run_tag + "_RNN1lstm_" + member + ".nc")




### Let's have a look at the gridded data

In [None]:
# Map of the first time step of the training data (ds1_var).
fig = plt.figure(figsize=(10, 5))
ax = plt.axes(projection=ccrs.Mollweide(central_longitude=0, globe=None))

colorbar_options = {'orientation': 'vertical', 'fraction': 0.012, 'pad': 0.015, 'aspect': 35}
first_field = ds1_var.isel(time=0)
tplot = first_field.plot.contourf(
    ax=ax,
    levels=17,
    transform=ccrs.PlateCarree(),
    cmap=cm.seismic,
    cbar_kwargs=colorbar_options,
)

tplot.colorbar.set_label('Temperature at 2 meters', size=16)
tplot.ylabel_style = {'size': 16}
ax.set_global()
ax.coastlines();

In [None]:
# Map of the first time step of the data to reconstruct (ds2_var).
fig = plt.figure(figsize=(10, 5))
ax = plt.axes(projection=ccrs.Mollweide(central_longitude=0, globe=None))

colorbar_options = {'orientation': 'vertical', 'fraction': 0.012, 'pad': 0.015, 'aspect': 35}
first_field = ds2_var.isel(time=0)
tplot = first_field.plot.contourf(
    ax=ax,
    levels=17,
    transform=ccrs.PlateCarree(),
    cmap=cm.seismic,
    cbar_kwargs=colorbar_options,
)

tplot.colorbar.set_label('Temperature at 2 meters', size=16)
tplot.ylabel_style = {'size': 16}
ax.set_global()
ax.coastlines();

### Define our Checkpoints

In [None]:
# Stop training once the validation loss has not improved for 25 epochs.
ess = EarlyStopping(monitor='val_loss', mode='min', verbose=1, patience=25)

# Keep only the model with the lowest validation loss, one file per network.
checkpoint_options = dict(monitor='val_loss', mode='min', save_best_only=True, verbose=1)
mc_RNN1 = ModelCheckpoint(RNN1_path, **checkpoint_options)
mc_RNN1lstm = ModelCheckpoint(RNN1lstm_path, **checkpoint_options)



### Define our RNN models

In [None]:
# SimpleRNN network: input samples of shape (amount_channels, amount_locations)
# go through 50 recurrent units and a linear dense layer with one output per
# grid cell, which is reshaped into a (lat_dim, lon_dim) field.
N1 = Sequential([
    SimpleRNN(50, input_shape=(amount_channels, amount_locations), activation='tanh', unroll=True),
    Dense(lat_dim * lon_dim, activation='linear'),
    Reshape((lat_dim, lon_dim)),
])
N1.summary()

In [None]:
# LSTM counterpart of N1: same input shape and output head, with the
# recurrent layer swapped for an LSTM of 50 units.
N1_lstm = Sequential([
    LSTM(50, input_shape=(amount_channels, amount_locations), activation='tanh', unroll=True),
    Dense(lat_dim * lon_dim, activation='linear'),
    Reshape((lat_dim, lon_dim)),
])
N1_lstm.summary()

### Compile our RNN models

In [None]:
# Both networks are compiled with Adam (learning rate 1e-4) on the mean squared
# error; the mean absolute error is tracked as an additional metric.
# The RMSprop instance that used to be created here was never passed to
# compile() and relied on the deprecated `lr` / `decay` arguments, so it has
# been dropped.
N1.compile(keras.optimizers.Adam(1e-4), loss='mse', metrics=["mae"])
N1_lstm.compile(keras.optimizers.Adam(1e-4), loss='mse', metrics=["mae"])


### Create our latitude and longitude channels

In [None]:
# Random longitude indices in [0, number of grid columns), one row per training sample.
# NOTE: both arrays are replaced further down by the station longitudes
# (`lon_points=lons`, `lon_points_test=geo_locations_zero_test`).
lon_points = np.random.randint(ds1_var.shape[2], size=(sample_size,1,amount_locations))
# _test is the full EKF400v2 time frame we want to reconstruct
lon_points_test = np.random.randint(ds1_var.shape[2], size=(timesteps_in_testdata,1,amount_locations))

In [None]:
# Random latitude indices in [0, number of grid rows), one row per training sample.
# NOTE: both arrays are replaced further down by the station latitudes
# (`lat_points=lats`, `lat_points_test=geo_locations_zero_test`).
lat_points = np.random.randint(ds1_var.shape[1], size=(sample_size,1,amount_locations))
lat_points_test = np.random.randint(ds1_var.shape[1], size=(timesteps_in_testdata,1,amount_locations))

In [None]:
# Use the station positions (fractional grid indices computed from df2)
# instead of the random positions drawn above.
lon_points=lons
lat_points=lats

In [None]:
# Turn the fractional grid positions into integer indices (the cast truncates
# towards zero).
lat_points = np.asarray(lat_points).astype(int)
lon_points = np.asarray(lon_points).astype(int)

In [None]:
# Buffers for the longitude channel: shape (sample_size, 1, nlons) for training
# and (timesteps_in_testdata, 1, nlons) for the reconstruction period.
# In the visible notebook the random arrays are only used as shape templates
# for the zero-filled float32 buffers.
geo_locations_rand = np.random.rand(sample_size,1,nlons) # 49 for 7*7 locations
geo_locations_zero = np.zeros_like(geo_locations_rand, dtype=np.float32)

geo_locations_rand_test = np.random.rand(timesteps_in_testdata,1,nlons) # 49 for 7*7 locations
geo_locations_zero_test = np.zeros_like(geo_locations_rand_test, dtype=np.float32)

In [None]:
# Repeat the station longitude indices for every time step of the
# reconstruction period (broadcast over the first axis).
geo_locations_zero_test[:, 0, :] = lon_points
lon_points_test = geo_locations_zero_test

In [None]:
# Repeat the station longitude indices for every training sample
# (broadcast over the first axis).
geo_locations_zero[:, 0, :] = lon_points
lon_points = geo_locations_zero

In [None]:
# Fresh buffers for the latitude channel: shape (sample_size, 1, nlats) for
# training and (timesteps_in_testdata, 1, nlats) for the reconstruction period.
# In the visible notebook the random arrays are only used as shape templates
# for the zero-filled float32 buffers.
geo_locations_rand = np.random.rand(sample_size,1,nlats) # 49 for 7*7 locations
geo_locations_zero = np.zeros_like(geo_locations_rand, dtype=np.float32)

geo_locations_rand_test = np.random.rand(timesteps_in_testdata,1,nlats) # 49 for 7*7 locations
geo_locations_zero_test = np.zeros_like(geo_locations_rand_test, dtype=np.float32)

In [None]:
# Repeat the station latitude indices for every time step of the
# reconstruction period (broadcast over the first axis).
geo_locations_zero_test[:, 0, :] = lat_points
lat_points_test = geo_locations_zero_test

In [None]:
# Repeat the station latitude indices for every training sample
# (broadcast over the first axis).
geo_locations_zero[:, 0, :] = lat_points
lat_points = geo_locations_zero

### create the time domain we want to sample

In [None]:
# Random time indices for the training samples, shape (sample_size, 1, 1).
# Indices are drawn with replacement from [0, ds1_var.shape[0]-1), so the last
# time step is never selected and a time step can occur more than once.
# NOTE(review): no random seed is set, so the selection differs between runs.
timesteps = np.random.randint(ds1_var.shape[0]-1, size=(sample_size,1,1))


In [None]:
# The reconstruction uses every time step of ds2_var, in order.
all_timesteps_test = list(range(timesteps_in_testdata))
timesteps_in_testdata

### sample our data according to the time domain

In [None]:
# Training targets: the full grid at each randomly drawn time index.
y1_array = ds1_var[timesteps.ravel()]

In [None]:
# Fields of the reconstruction period, all time steps in their original order.
y2_array_nonrandom = ds2_var[all_timesteps_test]



In [None]:
# Flatten every (lat, lon) field into one row: result is (time steps, lat*lon).
n_cells_train = ds1_var.shape[2] * ds1_var.shape[1]
n_cells_test = ds2_var.shape[2] * ds2_var.shape[1]
y1_matrix = y1_array.values.reshape((len(timesteps), n_cells_train))
y2_matrix = y2_array_nonrandom.values.reshape((len(all_timesteps_test), n_cells_test))

In [None]:
# Zero-filled float32 buffer for the station values of the training samples,
# same shape as lon_points.
X1 = np.zeros(lon_points.shape, dtype=np.float32)

In [None]:
# Zero-filled float32 buffer for the station values of the reconstruction
# period, same shape as lon_points_test.
X2 = np.zeros(lon_points_test.shape, dtype=np.float32)

In [None]:
# The tiled position arrays are float32 buffers; convert them back to integer
# indices so they can be used for indexing.
lat_points = np.asarray(lat_points).astype(int)
lon_points = np.asarray(lon_points).astype(int)

lat_points_test = np.asarray(lat_points_test).astype(int)
lon_points_test = np.asarray(lon_points_test).astype(int)

### sample our data according to the space domain

In [None]:
# Read the grid value at every station for each training sample.
# The rows of y1_matrix are (lat, lon) fields flattened in C order, so grid
# cell (lat_idx, lon_idx) sits at flat position lat_idx * n_lon + lon_idx.
# Multiplying the two indices, as was done before, addresses unrelated cells.
# mode="clip" keeps a station that lies exactly on the grid edge
# (index == dimension size) inside the grid instead of raising.
train_grid_shape = (ds1_var.shape[1], ds1_var.shape[2])
for i in range(len(timesteps)):
    station_cells = np.ravel_multi_index(
        (lat_points[i, 0, :], lon_points[i, 0, :]),
        train_grid_shape,
        mode="clip",
    )
    X1[i, 0, :] = y1_matrix[i, station_cells]

In [None]:
# Read the grid value at every station for each time step of the
# reconstruction period.
# The rows of y2_matrix are (lat, lon) fields flattened in C order, so grid
# cell (lat_idx, lon_idx) sits at flat position lat_idx * n_lon + lon_idx.
# Multiplying the two indices, as was done before, addresses unrelated cells.
# mode="clip" keeps a station that lies exactly on the grid edge
# (index == dimension size) inside the grid instead of raising.
test_grid_shape = (ds2_var.shape[1], ds2_var.shape[2])
for i in range(len(all_timesteps_test)):
    station_cells = np.ravel_multi_index(
        (lat_points_test[i, 0, :], lon_points_test[i, 0, :]),
        test_grid_shape,
        mode="clip",
    )
    X2[i, 0, :] = y2_matrix[i, station_cells]

### normalize our data with the maximum values

In [None]:
# Scale the training position indices by the grid size.
lon_points_normmax=lon_points/lon_dim

lat_points_normmax=lat_points/lat_dim


In [None]:
# Same scaling for the position indices of the reconstruction period.
lon_points_test_normmax=lon_points_test/lon_dim

lat_points_test_normmax=lat_points_test/lat_dim


In [None]:
# Scale the training station values by their overall maximum.
X1_normmax=X1/X1.max()


In [None]:
# The reconstruction inputs are scaled with the *training* maximum (X1.max()),
# i.e. with the same factor as X1.
X2_normmax=X2/X1.max()


In [None]:
# Scale the target fields by the maximum of the training targets.
y1_array_normmax=y1_array/y1_array.max()

# The reconstruction-period fields are scaled with the same training maximum.
y2_array_nonrandom_normax=y2_array_nonrandom/y1_array.max()



### concatenate our X or input data

In [None]:
# Stack the three channels along axis 1 in the order latitude, longitude,
# station value; this yields (samples, 3, stations).
input_nn = np.concatenate([lat_points_normmax, lon_points_normmax, X1_normmax], axis=1)
# input_nn_test keeps the original (non-random) time order of the EKF400v2 data.
input_nn_test = np.concatenate(
    [lat_points_test_normmax, lon_points_test_normmax, X2_normmax],
    axis=1,
)





### define our Y or output data

In [None]:
# Training targets as a plain NumPy array: the max-scaled training fields.
y_values=y1_array_normmax.values

### train our models

In [None]:
# Train the SimpleRNN network. A 20 % validation split drives both early
# stopping (ess) and the best-model checkpoint (mc_RNN1).
# NOTE(review): the training time steps were drawn with replacement, so the
# validation split may contain fields that also occur in the training part --
# confirm this is acceptable.
rnn_fit_options = dict(batch_size=128, epochs=1000, verbose=1, validation_split=0.2)
N1.fit(input_nn, y_values, callbacks=[ess, mc_RNN1], **rnn_fit_options)

In [None]:
# Train the LSTM network with the same settings as N1; its best model is
# written by the mc_RNN1lstm checkpoint.
lstm_fit_options = dict(batch_size=128, epochs=1000, verbose=1, validation_split=0.2)
N1_lstm.fit(input_nn, y_values, callbacks=[ess, mc_RNN1lstm], **lstm_fit_options)

In [None]:
# Predictions: SimpleRNN on the training inputs, then SimpleRNN and LSTM on the
# inputs of the reconstruction period.
# NOTE(review): the models are used as they are after fit(); the best
# checkpoint written by ModelCheckpoint is not reloaded in this notebook --
# confirm that this is intended.
Est_RNN1= N1.predict(input_nn)
Est_RNN1_test= N1.predict(input_nn_test)
Est_RNN1_test_lstm=N1_lstm.predict(input_nn_test)

In [None]:
# Factor that undoes the scaling of the targets (they were divided by y1_array.max()).
unnormalize=y1_array.max().values

### Let's have a look at our prediction

In [None]:
# Reference field of the first training sample, scaled back from the
# max-normalised targets, with the station positions marked.
# The cell used to rely on `modulator`, `max` and `half_data`, none of which is
# defined in this notebook (`max` is the Python builtin, so `-max` raises a
# TypeError). They are replaced by values derived here:
#   * no offset is added -- NOTE(review): assumes the anomaly fields need no
#     offset; confirm
#   * the colour range is symmetric around zero and taken from the reference field
#   * the station markers come from sample 0, the sample that is plotted
truth_field = np.asarray(y1_array_normmax[0, :, :] * y1_array.max())
colour_limit = float(np.nanmax(np.abs(truth_field)))
plotted_sample = 0

fig = plt.figure(figsize=(13, 7))
plt.imshow(truth_field, vmin=-colour_limit, vmax=colour_limit, cmap='seismic', origin='upper', interpolation="none")
plt.plot(input_nn[plotted_sample, 1, :] * lon_dim, input_nn[plotted_sample, 0, :] * lat_dim, '|k', markersize=7)
plt.title('Training target, sample 0')
plt.colorbar();

In [None]:
# SimpleRNN prediction for the first training sample, scaled back with
# `unnormalize`, with the station positions marked.
# The cell used to rely on `modulator`, `max` and `half_data`, none of which is
# defined in this notebook (`max` is the Python builtin, so `-max` raises a
# TypeError). They are replaced by values derived here:
#   * no offset is added -- NOTE(review): assumes the anomaly fields need no
#     offset; confirm
#   * the colour range is symmetric around zero and taken from the matching
#     target field, so prediction and target share one scale
#   * the station markers come from sample 0, the sample that is plotted
colour_limit = float(np.nanmax(np.abs(y1_array[0, :, :].values)))
plotted_sample = 0

fig = plt.figure(figsize=(13, 7))
plt.imshow(Est_RNN1[0, :, :] * unnormalize, vmin=-colour_limit, vmax=colour_limit, cmap='seismic', origin='upper', interpolation="none")
plt.plot(input_nn[plotted_sample, 1, :] * lon_dim, input_nn[plotted_sample, 0, :] * lat_dim, '|k', markersize=7)
plt.title('SimpleRNN prediction, sample 0')
plt.colorbar();

### write our reconstructed fields out as netcdf

In [None]:
# Write the SimpleRNN reconstruction of the EKF400 period to NetCDF.
# (datetime and the netCDF4 names come from the import cell at the top.)
# A dedicated name is used for the target path so that the `output` tag that
# goes into the file names is not overwritten.
nc_path = RNN1_path_nc
time_units = 'days since 1900-01-01 00:00:00'
n_steps = Est_RNN1_test.shape[0]

# Undo the max-normalisation of the network output.
dataout = Est_RNN1_test[:, :, :] * unnormalize

# NOTE(review): placeholder time axis -- every field is stamped with 1 January
# of consecutive years starting in 1900. It does not reproduce the time
# coordinate of ds2_var; confirm before relying on the time axis of the file.
datesout = [datetime.datetime(1900 + step, 1, 1) for step in range(n_steps)]

# `format` must be given by keyword: the third positional argument of Dataset
# is `clobber`. The with-block closes the file even if a write fails.
with Dataset(nc_path, 'w', format='NETCDF4') as ncout:
    ncout.createDimension('lon', lon_dim)
    ncout.createDimension('lat', lat_dim)
    ncout.createDimension('time', n_steps)

    lonvar = ncout.createVariable('lon', 'float32', ('lon',))
    lonvar[:] = longitudes

    latvar = ncout.createVariable('lat', 'float32', ('lat',))
    latvar[:] = latitudes

    timevar = ncout.createVariable('time', 'float64', ('time',))
    timevar.setncattr('units', time_units)
    timevar[:] = date2num(datesout, time_units)

    myvar = ncout.createVariable("t2m", 'float32', ('time', 'lat', 'lon'))
    myvar.setncattr('units', "K")
    myvar[:] = dataout

In [None]:
# Write the LSTM reconstruction of the EKF400 period to NetCDF.
# (datetime and the netCDF4 names come from the import cell at the top.)
# A dedicated name is used for the target path so that the `output` tag that
# goes into the file names is not overwritten.
nc_path_lstm = RNN1lstm_path_nc
time_units = 'days since 1900-01-01 00:00:00'
n_steps = Est_RNN1_test_lstm.shape[0]

# Undo the max-normalisation of the network output.
dataout = Est_RNN1_test_lstm[:, :, :] * unnormalize

# NOTE(review): placeholder time axis -- every field is stamped with 1 January
# of consecutive years starting in 1900. It does not reproduce the time
# coordinate of ds2_var; confirm before relying on the time axis of the file.
datesout = [datetime.datetime(1900 + step, 1, 1) for step in range(n_steps)]

# `format` must be given by keyword: the third positional argument of Dataset
# is `clobber`. The with-block closes the file even if a write fails.
with Dataset(nc_path_lstm, 'w', format='NETCDF4') as ncout:
    ncout.createDimension('lon', lon_dim)
    ncout.createDimension('lat', lat_dim)
    ncout.createDimension('time', n_steps)

    lonvar = ncout.createVariable('lon', 'float32', ('lon',))
    lonvar[:] = longitudes

    latvar = ncout.createVariable('lat', 'float32', ('lat',))
    latvar[:] = latitudes

    timevar = ncout.createVariable('time', 'float64', ('time',))
    timevar.setncattr('units', time_units)
    timevar[:] = date2num(datesout, time_units)

    myvar = ncout.createVariable("t2m", 'float32', ('time', 'lat', 'lon'))
    myvar.setncattr('units', "K")
    myvar[:] = dataout