In [None]:
import attr
from uwnet.timestepper import Batch, predict_multiple_steps

def _convert_dataset_to_dict(dataset):
    """Collect the time-dependent data variables of ``dataset`` in a dict.

    Parameters
    ----------
    dataset
        object exposing ``data_vars`` (iterable of variable names) and
        item access by name; each item must have a ``dims`` attribute

    Returns
    -------
    dict
        maps each name in ``dataset.data_vars`` whose variable has
        ``'time'`` among its ``dims`` to that variable; variables without
        a ``'time'`` dimension are left out
    """
    time_dependent = {}
    for name in dataset.data_vars:
        variable = dataset[name]
        if 'time' not in variable.dims:
            continue
        time_dependent[name] = variable
    return time_dependent


class XarrayBatch(Batch):
    """A ``Batch`` that is built from an xarray dataset.

    Only the data variables of the dataset that have a ``time`` dimension
    are passed on to the parent ``Batch``.
    """

    def __init__(self, dataset, **kwargs):
        """Wrap ``dataset``; ``kwargs`` are forwarded to ``Batch.__init__``."""
        time_dependent = _convert_dataset_to_dict(dataset)
        super().__init__(time_dependent, **kwargs)

    def get_model_inputs(self, t, state):
        """Return the parent's model inputs for ``(t, state)`` as a Dataset.

        ``'time'`` is dropped from every input before the inputs are
        combined into an ``xr.Dataset``.
        """
        inputs = super().get_model_inputs(t, state)
        # Iterate over a snapshot of the keys; values are replaced in place.
        for name in list(inputs):
            try:
                stripped = inputs[name].drop('time')
            except ValueError:
                # The drop was rejected for this input; keep it unchanged.
                continue
            inputs[name] = stripped
        return xr.Dataset(inputs)
    
    
def get_time_step(ds):
    """Return the first increment of ``ds.time`` multiplied by 86400.

    Parameters
    ----------
    ds
        dataset with a ``time`` coordinate holding at least two entries

    Returns
    -------
    float
        ``(time[1] - time[0]) * 86400``
        NOTE(review): presumably ``time`` is stored in days and the result
        is in seconds -- confirm against the dataset's metadata.
    """
    first_increment = ds.time.diff('time')[0]
    return float(first_increment * 86400)


def single_column_simulation(model, dataset, interval=(0, 10),
                             prognostics=('QT', 'SLI')):
    """Run a single column model simulation with a model for the source terms

    The batch that drives the simulation is built from ``dataset`` itself, so
    the function does not depend on any notebook-level ``batch`` variable.

    Parameters
    ----------
    model
        pytorch model for producing the apparent sources; its
        ``call_with_xr`` method is used to evaluate it
    dataset : xr.Dataset
        input dataset in the same format as the training data
    interval : tuple
        (start_time, end_time) interval
    prognostics : sequence of str
        names of the prognostic variables handed to ``XarrayBatch``

    Returns
    -------
    xr.Dataset
        the predicted states concatenated along ``time``
    """
    start, end = interval
    time_step = get_time_step(dataset)
    batch = XarrayBatch(dataset, prognostics=list(prognostics))
    # NOTE(review): ``end`` is passed as ``prediction_length`` rather than
    # ``end - start``; the two agree only when ``start`` is 0 -- confirm
    # against ``predict_multiple_steps``.
    pred_generator = predict_multiple_steps(
        model.call_with_xr, batch, initial_time=start,
        prediction_length=end, time_step=time_step)
    # Each yielded ``k`` is used to index ``dataset.time`` and label the state.
    datasets = [
        xr.Dataset(state).assign_coords(time=dataset.time[k])
        for k, state in pred_generator
    ]
    return xr.concat(datasets, dim='time')


In [None]:
import xarray as xr
from uwnet.model import ApparentSource

# Build a model and a batch from the processed training data, then evaluate
# the model once on the inputs at time 0.
# NOTE(review): the meaning of the two empty dicts passed to ApparentSource is
# not visible here -- confirm against uwnet.model.
model = ApparentSource({},{})
ds = xr.open_dataset("../data/processed/training.nc")
# QT and SLI are passed as the prognostic variables of the batch.
batch = XarrayBatch(ds, prognostics=['QT', 'SLI'])

# Prognostic state at time 0 and the corresponding model inputs.
state = batch.get_prognostics_at_time(0)
ins = batch.get_model_inputs(0, state)
# ins = ins.expand_dims('time')
outs = model.call_with_xr(ins)

In [None]:
# Run the single column simulation on the training dataset with
# interval=(0, 10).
output = single_column_simulation(model, ds, interval=(0, 10))

In [None]:
# Display the result of the simulation cell.
output