# Run model inference on an array

---

## Imports

In [1]:
import torch
import xbatcher
import xarray as xr
import numpy as np

## Toy data

In [2]:
# Toy input: a 100 x 100 x 10 cube of uniform random values with named
# dimensions ("x", "y", "t"). The values are unseeded, so they differ per run.
data = xr.DataArray(np.random.rand(100, 100, 10), dims=("x", "y", "t"))
data

## Simple model

In [3]:
class MeanAlongDim(torch.nn.Module):
    """Parameter-free module that averages its input over one fixed axis."""

    def __init__(self, ax):
        """Store ``ax``, the dimension that ``forward`` reduces over."""
        super().__init__()
        self.ax = ax

    def forward(self, x):
        """Return the mean of ``x`` taken over ``self.ax``."""
        return torch.mean(x, dim=self.ax)

## Batch generator, dataset

In [4]:
from xbatcher.loaders.torch import MapDataset

# Batch generator over the toy array: windows of size 10 along "x" and "y",
# with an overlap of 5 along each of those dimensions.
bgen = xbatcher.BatchGenerator(
    data, input_dims={"x": 10, "y": 10}, input_overlap={"x": 5, "y": 5}
)
ds = MapDataset(bgen)

# Take the first example from the dataset and build the model that averages
# over the last axis.
inp = next(iter(ds))
mad = MeanAlongDim(-1)

# Check the input/output size of the first example
print("Input shape:", inp.shape)
print("Output shape:", mad(inp).shape)

Input shape: torch.Size([10, 10, 10])
Output shape: torch.Size([10, 10])


In [5]:
# Sanity check: the module output equals a direct mean over the last axis.
assert torch.allclose(mad(inp), inp.mean(dim=-1))

## Inference function

In [6]:
# Execute the sibling notebook so the names it defines become available here.
# NOTE(review): presumably this is where `predict_on_array` (called in the
# next cell) is defined — confirm against functions.ipynb.
%run ./functions.ipynb

In [7]:
# Run the model over the whole dataset and assemble a single output array.
# NOTE(review): `predict_on_array` is not defined in this notebook; the
# meaning of `output_tensor_dim`, `new_dim` and `resample_dim` should be
# confirmed against its definition. The key order of `output_tensor_dim` is
# kept as (y, x, x_new) in case the function depends on it.
out = predict_on_array(
    dataset=ds,
    model=mad,
    output_tensor_dim={"y": 20, "x": 10, "x_new": 5},
    new_dim=["x_new"],
    resample_dim=["x", "y"],
)
print(out.shape)

(200, 100, 5)
