# Process model prediction

In [73]:
import glob
import numpy as np
import pandas as pd 
import xarray as xr

import tensorflow as tf

## Open dataset for prediction

In [3]:
path_file = "../data/validate/2019_04"#_ERA5.nc"
ds = xr.open_dataset(path_file)

We are going to manually recreate the functionality of the ```slice_generator``` class here,
just to pull out an input for the model.

In [9]:
# Choose arbitrary slice
start = 427
end   = 429

In [22]:
# Extract slice and reshape
array = ds['t2m'].isel(time=slice(start, end)).values
print(array.shape)
array = np.moveaxis(array, 0, 1)
print(array.shape)
array = array.reshape(-1, 2, 1, 21, 21)
print(array.shape)

(2, 21, 21)
(21, 2, 21)
(1, 2, 1, 21, 21)


## Import model and predict

In [54]:
model_dir = "../models/"
models_list = sorted(glob.glob(model_dir+"*.h5"))
print(models_list)

['../models/full_stack_1f_1c_21x_21y.h5', '../models/t_full_stack_1f_1c_21x_21y.h5']


In [79]:
# choose a model
file_index = 1
models_list[file_index]

model = tf.keras.models.load_model(models_list[file_index])
# model.summary()

Make prediction here:

In [68]:
pred = model.predict(array, verbose=1)
pred.shape



(1, 1, 1, 21, 21)

## Convert prediction back to ```netcdf``` file

Reshape model output

In [80]:
print(pred.shape)
out = pred.reshape(1,21,21)
print(out.shape)

(1, 1, 1, 21, 21)
(1, 21, 21)


Get timestamp of prediction

In [77]:
pred_time = ds['t2m'].isel(time=slice(end-1, end))['time'].values[0]
ts = pd.to_datetime(str(pred_time)) 
d = ts.strftime('%Y_%m_%d_%H')
d

'2019_04_18_20'

 Create new xarray DataSet with same dimensions and coordinates as original, and save it as a ```netcdf``` file  ```../data/pred/```

In [78]:
data = xr.DataArray(data = out,
                     dims=('time', 'latitude', 'longitude',),
                     coords={'latitude': ds['latitude'].values,
                            'longitude': ds['longitude'].values,
                            'time': ds['t2m'].isel(time=slice(end-1, end))['time'].values,
                            })
data = xr.Dataset(data_vars={'t2m':data})
data.to_netcdf("../data/pred/"+d+"pred.nc")
data