# Process model prediction

In [1]:
import glob
import os

import numpy as np
import pandas as pd
import xarray as xr

import tensorflow as tf

## Open dataset for prediction

In [2]:
# Relative path of the NetCDF file that the model input is drawn from
path_file = "../../data/validate/2019_04_ERA5.nc"
# Keep the dataset handle in `ds`; later cells index into it for the input
# slice, the grid coordinates and the prediction timestamp.
ds = xr.open_dataset(path_file)

Here we manually recreate the functionality of the `slice_generator` class,
just to pull out a single input sample for the model.

In [15]:
# Choose arbitrary slice
# `start`/`end` are positional indices along the time axis. They are used as
# slice(start, end), so `end` is exclusive and this selects one time step.
start = 427
end   = 428
# Name of the dataset variable to extract
vars_ = 't2m'

In [19]:
# Extract the selected time steps of the chosen variable as a NumPy array.
# For the data used here the result is (time, 21, 21).
array = ds[vars_].isel(time=slice(start, end)).values
print(array.shape)

# Add a singleton axis after the sample axis so each sample matches the
# model input layout (None, 1, 21, 21). The two trailing (spatial) dimensions
# are taken from the data itself rather than hard-coded, so an array from a
# different grid is never silently re-chunked into extra samples.
array = array.reshape(-1, 1, *array.shape[-2:])
print(array.shape)

(1, 21, 21)
(1, 1, 21, 21)


## Import model

In [7]:
# Collect every saved Keras model (*.h5) found in the model directory,
# sorted so that indexing into the list is stable between runs.
model_dir = "./"
h5_pattern = f"{model_dir}*.h5"
models_list = sorted(glob.glob(h5_pattern))
print(models_list)

['./less_deconv_1F.h5']


In [10]:
# Pick which of the discovered model files to load
file_index = 0
model_path = models_list[file_index]

# Restore the saved Keras model and print its layer-by-layer summary
model = tf.keras.models.load_model(model_path)
model.summary()

Model: "Full_stack"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
model_input (InputLayer)        [(None, 1, 21, 21)]  0                                            
__________________________________________________________________________________________________
gaussian_noise_2 (GaussianNoise (None, 1, 21, 21)    0           model_input[0][0]                
__________________________________________________________________________________________________
convA1 (Conv2D)                 (None, 8, 17, 17)    208         gaussian_noise_2[0][0]           
__________________________________________________________________________________________________
convA1_bn (BatchNormalization)  (None, 8, 17, 17)    32          convA1[0][0]                     
_________________________________________________________________________________________

In [11]:
# NOTE(review): no output was recorded for this cell. Presumably `history` is
# only populated after training via fit() — confirm, and consider removing
# this leftover check from the notebook.
model.history

## Make prediction here:

In [20]:
# Run the loaded model on the prepared input; verbose=1 asks Keras to report progress
pred = model.predict(array, verbose=1)
# Echo the raw prediction array as the cell output.
# NOTE(review): this prints every value; `pred.shape` would be a more compact check.
pred



array([[[[[2.01120768e-02, 1.96349137e-02, 5.12968749e-02,
           0.00000000e+00, 3.42606939e-02, 0.00000000e+00,
           7.30598941e-02, 4.66653965e-02, 4.69413660e-02,
           0.00000000e+00, 5.13737649e-02, 0.00000000e+00,
           2.70042364e-02, 1.63345058e-02, 6.96390271e-02,
           5.54146394e-02, 1.06679499e-01, 0.00000000e+00,
           0.00000000e+00, 0.00000000e+00, 2.66902987e-02],
          [0.00000000e+00, 5.99247031e-03, 0.00000000e+00,
           0.00000000e+00, 3.38783190e-02, 0.00000000e+00,
           2.10062396e-02, 0.00000000e+00, 0.00000000e+00,
           0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
           9.39025544e-03, 2.15871949e-02, 4.49323021e-02,
           0.00000000e+00, 7.85134137e-02, 0.00000000e+00,
           0.00000000e+00, 1.49672516e-02, 2.38494240e-02],
          [0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
           1.96719635e-02, 0.00000000e+00, 9.74641591e-02,
           3.18774581e-02, 0.00000000e+00, 0.00000000e

## Convert the prediction back to a `netCDF` file

Reshape model output

In [21]:
# The model returns extra singleton axes between the sample axis and the
# spatial grid (shape (1, 1, 1, 21, 21) here). Collapse them to
# (samples, rows, cols) using the sizes reported by `pred` itself instead of
# hard-coded numbers. reshape raises if the dropped middle axes are not all
# of length 1, so nothing is merged silently.
print(pred.shape)
out = pred.reshape(pred.shape[0], *pred.shape[-2:])
print(out.shape)

(1, 1, 1, 21, 21)
(1, 21, 21)


Get timestamp of prediction

In [22]:
# Timestamp of the last time step in the selected window (index end - 1)
last_step = ds['t2m'].isel(time=slice(end - 1, end))
pred_time = last_step['time'].values[0]
ts = pd.to_datetime(str(pred_time))

# Filename-friendly stamp: year_month_day_hour
d = ts.strftime('%Y_%m_%d_%H')
d

'2019_04_18_19'

Create a new xarray `Dataset` with the same dimensions and coordinates as the original, and save it as a `netCDF` file in `../../data/pred/`.

In [24]:
# Wrap the reshaped prediction in a DataArray that carries the source grid
# (latitude/longitude taken from `ds`) and the timestamp of the predicted step.
# A separate name is used for the DataArray so that `data` only ever refers
# to the final Dataset.
pred_da = xr.DataArray(
    data=out,
    dims=('time', 'latitude', 'longitude'),
    coords={
        'latitude': ds['latitude'].values,
        'longitude': ds['longitude'].values,
        'time': ds[vars_].isel(time=slice(end - 1, end))['time'].values,
    },
)

# Store the prediction under the same variable name that was read from the
# input (`vars_`), so the two cannot drift apart.
data = xr.Dataset(data_vars={vars_: pred_da})

# Make sure the output directory exists before writing; to_netcdf does not
# create missing directories.
out_dir = "../../data/pred/"
os.makedirs(out_dir, exist_ok=True)
data.to_netcdf(out_dir + d + "pred.nc", format="netcdf4")
data