In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np

In [None]:
import xarray as xr

In [None]:
# Root of the coarse-grained output for this simulation run.
RUN_DIR = "../data/raw/2/NG_5120x2560x34_4km_10s_QOBS_EQX/coarse"

fields_3d = xr.open_mfdataset(f"{RUN_DIR}/3d/*.nc")
fields_2d = xr.open_dataset(f"{RUN_DIR}/2d/all.nc")

# The time values in the 2-D file are stored out of order, so sort them by
# time before aligning with the 3-D fields.
fields_2d = fields_2d.sortby('time')

# join='inner' keeps only coordinates present in both sources. If the two time
# axes do not overlap, it silently yields an empty dataset, so guard against that.
# NOTE(review): y=24..39 is a hard-coded band of rows (presumably the region of
# interest) — confirm and consider promoting it to a named constant.
data = xr.merge((fields_2d, fields_3d), join='inner').isel(y=slice(24, 40))
assert data['time'].size > 0, "2-D and 3-D fields share no time steps"

In [None]:
# Overview of precipitation averaged over y, plotted against the remaining
# dimensions (presumably x and time).
fig, ax = plt.subplots(figsize=(4, 10))
data.Prec.mean('y').plot(ax=ax)
ax.set_title('Prec averaged over y');

In [None]:
# Predictor fields (kept together as a Dataset) and the regression target.
inputs = data[['QV', 'TABS']]
output = data['Prec']

In [None]:
# Dimensions that index individual samples. Every other dimension
# (variable, z) is folded into the feature axis.
SAMPLE_DIMS = ['x', 'y', 'time']

# Normalisation statistics.
# - `mu` is a per-level mean: only the sample dims are reduced, so z is kept.
# - `sig` averages the squared deviations over *all* dimensions, which gives
#   a single scale per variable.
# `mu` is computed eagerly. The inputs may be lazily loaded (open_mfdataset),
# and a lazy `mu` would be re-reduced over the full dataset on every call
# to prepare_inputs.
mu = inputs.mean(SAMPLE_DIMS).compute()
sig2 = ((inputs - mu)**2).mean()
sig = sig2.compute().apply(np.sqrt)


def prepare_inputs(inputs):
    """Normalise `inputs` and flatten it into a (samples, features) array.

    Parameters
    ----------
    inputs : xarray.Dataset
        Must contain the variables used to build the module-level `mu` and
        `sig`. Any subset of SAMPLE_DIMS may be present (e.g. after `isel`
        selects a single column). The ones present are stacked into `samples`.

    Returns
    -------
    xarray.DataArray
        Computed array with dims ('samples', 'features'). `features` stacks
        ('variable', 'z').
    """
    sample_dims = [dim for dim in SAMPLE_DIMS if dim in inputs.dims]
    normalized = (inputs - mu) / sig
    return (normalized
            .to_array()
            .stack(samples=sample_dims, features=['variable', 'z'])
            .compute())


X = prepare_inputs(inputs)
# Stack the target with the same sample-dim order as X, so rows line up.
y = output.stack(samples=SAMPLE_DIMS).compute()

In [None]:
import keras
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense, Dropout
from keras.optimizers import RMSprop


# Random training subset.
# - A fixed seed makes the subset reproducible.
# - Sampling without replacement avoids feeding duplicate rows to the network.
# - The subset size is capped at the number of available samples.
n_train = min(100000, X.shape[0])
rng = np.random.RandomState(0)
inds = rng.choice(X.shape[0], n_train, replace=False)

# Fully connected regression network: features -> 200 -> 200 -> 1.
# The input width comes from the prepared feature matrix instead of a
# hard-coded 68, so changing the input variables or levels keeps it valid.
model = Sequential()
model.add(Dense(200, activation='relu', input_shape=(X.shape[1],)))
model.add(Dropout(0.2))
model.add(Dense(200, activation='relu'))
model.add(Dropout(0.2))
# ReLU on the output keeps predictions non-negative.
model.add(Dense(1, activation='relu'))

# 'accuracy' is a classification metric and meaningless for a continuous
# target, so mean absolute error is reported instead.
model.compile(loss='mse',
              optimizer=RMSprop(),
              metrics=['mae'])

# Pass plain numpy arrays so training does not rely on Keras handling
# xarray objects.
history = model.fit(X.values[inds], y.values[inds], epochs=5)

In [None]:
# Run the trained network on the single column at x=0, y=8 (all time steps).
pred = model.predict(
    prepare_inputs(inputs.isel(y=8, x=0))
)

In [None]:
# Wrap the network output (one value per time step; the model has a single
# output unit) as a DataArray that shares the truth column's coordinates.
#
# Why `dims` is given explicitly: the coords of an isel'd column may also carry
# scalar x/y coordinates, so xarray cannot reliably infer the single
# dimension from them.
#
# Why np.ravel: it flattens both the raw (n, 1) Keras output and an
# already-wrapped 1-D result, so re-running this cell is safe.
truth_column = data.Prec.isel(x=0, y=8)
pred = xr.DataArray(np.ravel(pred), coords=truth_column.coords, dims=truth_column.dims)

In [None]:
# Side-by-side dataset of ground truth and model prediction.
# NOTE(review): `pred` covers only the (x=0, y=8) column while `truth` spans
# the full x/y grid. xarray aligns the two when building the Dataset, so only
# compare them at that column.
comp = xr.Dataset(data_vars={'truth': data['Prec'], 'pred': pred})

In [None]:
import holoviews as hv
# Activate HoloViews' Bokeh plotting backend for this notebook.
hv.extension('bokeh')

In [None]:
%%opts Curve[width=600]
# Fold truth and prediction for the (x=0, y=8) column into one "Precip"
# variable with an extra 'variable' dimension.
column_long = comp.isel(x=0, y=8).to_array().to_dataset(name="Precip")
ds = hv.Dataset(column_long)
# Curves over time, overlaid (presumably one per 'variable': truth vs. pred).
ds.to.curve("time").overlay()