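# Model inference on an `xarray` array

This notebook splits a toy `xarray.DataArray` into overlapping batches with `xbatcher`, applies a simple PyTorch model to a batch, and checks `_get_output_array_size`, a helper that computes the expected size of the full output array from the model's per-batch output dimensions.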

---

## Imports

In [1]:
import torch
import xbatcher
import xarray as xr
import numpy as np

## Toy data

In [2]:
# A seeded generator makes the toy array — and every number derived from it
# below — identical on each "Restart & Run All".
rng = np.random.default_rng(seed=0)

# Uniform [0, 1) float64 values, same distribution and dtype as np.random.rand.
data = xr.DataArray(
    data=rng.random((100, 100, 10)),
    dims=("x", "y", "t"),
)
data

## Simple model

In [3]:
class MeanAlongDim(torch.nn.Module):
    """Parameter-free module that averages its input along one dimension.

    Parameters
    ----------
    ax : int (or tuple of ints, as accepted by ``torch.mean``)
        Dimension to reduce over; negative values count from the end,
        so ``-1`` is the last axis.
    """

    def __init__(self, ax):
        # Zero-argument super() is the Python 3 idiom; the explicit
        # two-argument form is a Python 2 holdover.
        super().__init__()
        self.ax = ax

    def forward(self, x):
        """Return the mean of ``x`` over dimension ``self.ax``.

        The reduced dimension is dropped (``torch.mean`` default), so the
        output has one fewer dimension than ``x``.
        """
        return torch.mean(x, self.ax)

## Batch generator and dataset

In [4]:
from xbatcher.loaders.torch import MapDataset

# Slide a 10x10 window over (x, y), overlapping neighbouring windows by 5.
# "t" is not listed in input_dims; the printed shape below shows each
# example keeping all 10 time steps.
bgen = xbatcher.BatchGenerator(
    data,
    input_dims={"x": 10, "y": 10},
    input_overlap={"x": 5, "y": 5},
)
ds = MapDataset(bgen)

# Check the input/output size of the first example
inp = ds[0]
print("Input shape:", inp.shape)

mad = MeanAlongDim(-1)
print("Output shape:", mad(inp).shape)

Input shape: torch.Size([10, 10, 10])
Output shape: torch.Size([10, 10])


In [5]:
# Sanity check: the module's forward pass must match a direct mean over the last axis.
assert torch.allclose(mad(inp), inp.mean(dim=-1))

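## Inference function

The inference helpers (including `_get_output_array_size`, used below) live in `functions.ipynb`, which is executed in this kernel with `%run`.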

In [6]:
# Execute the helper notebook and load its definitions into this kernel.
# `_get_output_array_size`, called below, is not defined in this notebook,
# so it is presumably provided here — this cell must run before the tests.
%run ./functions.ipynb

In [17]:
# Expected size of the full output array when each batch yields a
# (y=10, x=10, t=5) tensor.
# NOTE(review): "t" is passed in both new_dim and resample_dim here, while the
# "New dim added" test case below uses resample_dim=["y", "x"] — confirm which
# form is intended (both currently give the same result).
out_size_dict = _get_output_array_size(
    bgen=ds.X_generator,
    output_tensor_dim={"y": 10, "x": 10, "t": 5},
    new_dim=["t"],
    resample_dim=["y", "x", "t"],
)
print(out_size_dict)

{'y': 100, 'x': 100, 't': 5}


In [11]:
from functools import partial

In [21]:
# Bind the batch generator once, so each test case below only needs to
# supply the output-tensor arguments.
get_array_size_partial = partial(
    _get_output_array_size,
    bgen=ds.X_generator,
)

In [18]:
# Each case: keyword arguments for get_array_size_partial, plus the output
# dict it should return.
test_cases = [
    dict(
        name="Same dims and same input sizes 1",
        function_inputs=dict(
            output_tensor_dim=dict(y=10, x=10),
            new_dim=[],
            resample_dim=["x", "y"],
        ),
        expected_output=dict(y=100, x=100),
    ),
    dict(
        name="New dim added",
        function_inputs=dict(
            output_tensor_dim=dict(y=10, x=10, t=5),
            new_dim=["t"],
            resample_dim=["y", "x"],
        ),
        expected_output=dict(y=100, x=100, t=5),
    ),
]

In [22]:
# Run every case through the partially-applied helper and report each result.
# The final assert makes a failing case fail "Restart & Run All" too.
# Previously a failure only printed "failed" and the notebook kept going.
failed_cases = []
for i, test_case in enumerate(test_cases):
    actual_output = get_array_size_partial(**test_case["function_inputs"])
    success = actual_output == test_case["expected_output"]
    message = "passed" if success else "failed"
    print(f"Test case {i} {message}")
    if not success:
        # Show the mismatch so the failure can be diagnosed from the output alone.
        print(
            f"    {test_case['name']}: "
            f"expected {test_case['expected_output']}, got {actual_output}"
        )
        failed_cases.append(test_case["name"])

assert not failed_cases, f"Failing test cases: {failed_cases}"

Test case 0 passed
Test case 1 passed
