# Process model prediction

In [53]:
import glob
import numpy as np
import xarray as xr

import tensorflow as tf

## Open dataset for prediction

In [3]:
path_file = "../data/validate/2019_04"#_ERA5.nc"
ds = xr.open_dataset(path_file)

In [40]:
ds

We are going to manually recreate the functionality of the ```slice_generator``` class here,
just to pull out an input for the model

In [9]:
# Choose arbitrary slice
start = 427
end   = 429

In [22]:
# Extract slice and process
array = ds['t2m'].isel(time=slice(start, end)).values
print(array.shape)
array = np.moveaxis(array, 0, 1)
print(array.shape)
array = array.reshape(-1, 2, 1, 21, 21)
print(array.shape)

(2, 21, 21)
(21, 2, 21)
(1, 2, 1, 21, 21)


## Import model and predict

In [54]:
model_dir = "../models/"
models_list = sorted(glob.glob(model_dir+"*.h5"))
print(models_list)

['../models/full_stack_1f_1c_21x_21y.h5', '../models/t_full_stack_1f_1c_21x_21y.h5']


In [55]:
# choose a model
file_index = 1
models_list[file_index]
model_name = 'full_stack_1f_1c_21x_21y' # todo: use regex to parse filename

model = tf.keras.models.load_model(models_list[file_index])
model.summary()

Model: "Full_stack"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
model_input (InputLayer)        [(None, 2, 1, 21, 21 0                                            
__________________________________________________________________________________________________
tf_op_layer_unstack_1 (TensorFl [(None, 1, 21, 21),  0           model_input[0][0]                
__________________________________________________________________________________________________
gaussian_noise_1 (GaussianNoise (None, 1, 21, 21)    0           tf_op_layer_unstack_1[0][0]      
__________________________________________________________________________________________________
convB1 (Conv2D)                 (None, 8, 9, 9)      208         tf_op_layer_unstack_1[0][1]      
_________________________________________________________________________________________

In [56]:
pred = model.predict(array, verbose=1)
pred



array([[[[[90.887764, 90.887764, 90.887764, 91.39598 , 91.51119 ,
           90.887764, 90.887764, 90.887764, 93.77147 , 90.887764,
           91.389565, 90.887764, 90.887764, 90.887764, 92.01836 ,
           90.887764, 90.887764, 90.887764, 91.58931 , 90.887764,
           90.887764],
          [90.887764, 90.887764, 91.75855 , 90.887764, 90.887764,
           90.887764, 90.887764, 90.887764, 90.887764, 94.156   ,
           90.887764, 92.41626 , 90.887764, 90.887764, 90.887764,
           90.887764, 90.887764, 92.00379 , 90.887764, 90.887764,
           90.887764],
          [99.02188 , 95.13313 , 90.887764, 90.887764, 90.887764,
           93.748886, 90.887764, 91.87696 , 90.887764, 90.887764,
           93.07386 , 90.887764, 90.887764, 90.887764, 90.887764,
           90.887764, 93.98111 , 90.931885, 90.887764, 90.887764,
           91.122955],
          [98.57535 , 95.99892 , 90.887764, 90.887764, 90.887764,
           90.887764, 90.887764, 94.597786, 90.887764, 90.887764,
       

## Convert prediction back to ```netcdf``` file

In [58]:
out = pred #np.zeros(shape=(1,1,1,21,21))
print(out.shape)
out = out.reshape(1,21,21)
print(out.shape)


(1, 1, 1, 21, 21)
(1, 21, 21)


In [61]:
data = xr.DataArray(data = out,
                     dims=('time', 'latitude', 'longitude',),
                     coords={'latitude': ds['latitude'].values,
                            'longitude': ds['longitude'].values,
                            'time': ds['t2m'].isel(time=slice(end-1, end))['time'].values,
                            })
data