# Infer model on array

---

## Imports

In [59]:
import torch
import xbatcher
import xarray as xr
import numpy as np
import pytest

## Testing the array size function

In [62]:
%run ./functions.ipynb

In [63]:
%%writefile test_get_array_size.py
import torch
import xbatcher
import xarray as xr
import numpy as np
import pytest

@pytest.fixture
def bgen_fixture() -> xbatcher.BatchGenerator:
    data = xr.DataArray(
        data=np.random.rand(100, 100, 10),
        dims=("x", "y", "t"),
        coords={
            "x": np.arange(100),
            "y": np.arange(100),
            "t": np.arange(10),
        }
    )
    
    bgen = xbatcher.BatchGenerator(
        data,
        input_dims=dict(x=10, y=10),
        input_overlap=dict(x=5, y=5),
    )
    return bgen

@pytest.mark.parametrize(
    "case_description, output_tensor_dim, new_dim, resample_dim, expected_output",
    [
        (
            "Resampling only: Downsample x, Upsample y",
            # window=10 -> tensor=5 (0.5x); window=10 -> tensor=20 (2x)
            {'x': 5, 'y': 20},  
            [],
            ['x', 'y'],
            # ds_size=100 * 0.5 = 50; ds_size=100 * 2 = 200
            {'x': 50, 'y': 200} 
        ),
        (
            "New dimensions only: Add a 'channel' dimension",
            {'channel': 3},
            ['channel'],
            [],
            {'channel': 3}
        ),
        (
            "Mixed: Resample x and add new channel dimension",
            # window=10 -> tensor=30 (3x)
            {'x': 30, 'channel': 12}, 
            ['channel'],
            ['x'],
            # ds_size=100 * 3 = 300
            {'x': 300, 'channel': 12} 
        ),
        (
            "Identity resampling (ratio=1)",
            {'x': 10, 'y': 10},
            [],
            ['x', 'y'],
            # ds_size * 1 = ds_size
            {'x': 100, 'y': 100} 
        ),
        (
            "Dimension not in batcher is treated as new",
            # 't' is in the dataset but not in `input_dims`, so it's not a resample dim.
            # The logic should treat it as a new dimension.
            {'t': 5},
            ['t'],
            [],
            {'t': 5}
        )
    ]
)
def test_get_output_array_size_scenarios(
    bgen_fixture,  # The fixture is passed as an argument
    case_description,
    output_tensor_dim,
    new_dim,
    resample_dim,
    expected_output
):
    """
    Tests various valid scenarios for calculating the output array size.
    The `case_description` parameter is not used in the code but helps make
    test results more readable.
    """
    # The `bgen_fixture` argument is the BatchGenerator instance created by our fixture
    result = _get_output_array_size(
        bgen=bgen_fixture,
        output_tensor_dim=output_tensor_dim,
        new_dim=new_dim,
        resample_dim=resample_dim
    )
    
    assert result == expected_output, f"Failed on case: {case_description}"

def test_get_output_array_size_raises_assertion_error_on_non_integer_size():
    """
    Tests that the function raises an AssertionError when the resampling
    calculation results in a non-integer output dimension size.
    """
    # Create a dataset where the total size is not a clean multiple
    # for the resampling ratio we will test.
    # DataArray size for 'x' is 101.
    data_for_error = xr.DataArray(
        data=np.random.rand(101, 100, 10),
        dims=("x", "y", "t")
    )
    
    # The batch window size for 'x' is 10.
    bgen = xbatcher.BatchGenerator(data_for_error, input_dims={'x': 10})
    
    # The resampling logic will be: 101 * (5 / 10) = 50.5, which is not an integer.
    output_tensor_dim = {'x': 5}
    
    # Use pytest.raises to assert that an AssertionError is thrown.
    with pytest.raises(AssertionError):
        _get_output_array_size(
            bgen=bgen,
            output_tensor_dim=output_tensor_dim,
            new_dim=[],
            resample_dim=['x']
        )

Overwriting test_get_array_size.py


In [64]:
!pytest test_get_array_size.py

platform linux -- Python 3.12.11, pytest-8.4.1, pluggy-1.6.0
rootdir: /home/jovyan/xbatcher-deep-learning/notebooks
plugins: anyio-4.9.0, hydra-core-1.3.2, jaxtyping-0.3.2
collected 6 items                                                              [0m[1m

test_get_array_size.py [31mF[0m[31mF[0m[31mF[0m[31mF[0m[31mF[0m[31mF[0m[31m                                            [100%][0m

[31m[1m_ test_get_output_array_size_scenarios[Resampling only: Downsample x, Upsample y-output_tensor_dim0-new_dim0-resample_dim0-expected_output0] _[0m

bgen_fixture = <xbatcher.generators.BatchGenerator object at 0x7f1a1a713560>
case_description = 'Resampling only: Downsample x, Upsample y'
output_tensor_dim = {'x': 5, 'y': 20}, new_dim = [], resample_dim = ['x', 'y']
expected_output = {'x': 50, 'y': 200}

    [0m[37m@pytest[39;49;00m.mark.parametrize([90m[39;49;00m
        [33m"[39;49;00m[33mcase_description, output_tensor_dim, new_dim, resample_dim, expected_output[39;49;0

## Toy data

In [44]:
data = xr.DataArray(
    data=np.random.rand(100, 100, 10),
    dims=("x", "y", "t")
)
data

## Simple model

In [45]:
class MeanAlongDim(torch.nn.Module):
    def __init__(self, ax):
        super(MeanAlongDim, self).__init__()
        self.ax = ax

    def forward(self, x):
        return torch.mean(x, self.ax)

## Batch generator, dataset

In [46]:
from xbatcher.loaders.torch import MapDataset

bgen = xbatcher.BatchGenerator(
    data,
    input_dims=dict(x=10, y=10),
    input_overlap=dict(x=5, y=5),
)

ds = MapDataset(bgen)

inp = next(iter(ds))

# Check the input/output size of the first example
print("Input shape:", inp.shape)

mad = MeanAlongDim(-1)
print("Output shape:", mad(inp).shape)

Input shape: torch.Size([10, 10, 10])
Output shape: torch.Size([10, 10])


In [47]:
assert torch.allclose(mad(inp), torch.mean(inp, -1))

## Inference function

In [48]:
%run ./functions.ipynb

## Pytest

In [54]:
@pytest.fixture
def bgen_fixture() -> xbatcher.BatchGenerator:
    data = xr.DataArray(
        data=np.random.rand(100, 100, 10),
        dims=("x", "y", "t"),
        coords={
            "x": np.arange(100),
            "y": np.arange(100),
            "t": np.arange(10),
        }
    )
    
    # 2. Create the BatchGenerator with input dimensions and overlap
    bgen = xbatcher.BatchGenerator(
        data,
        input_dims=dict(x=10, y=10),
        input_overlap=dict(x=5, y=5),
    )
    return bgen

In [56]:
@pytest.mark.parametrize(
    "case_description, output_tensor_dim, new_dim, resample_dim, expected_output",
    [
        (
            "Resampling only: Downsample x, Upsample y",
            # window=10 -> tensor=5 (0.5x); window=10 -> tensor=20 (2x)
            {'x': 5, 'y': 20},  
            [],
            ['x', 'y'],
            # ds_size=100 * 0.5 = 50; ds_size=100 * 2 = 200
            {'x': 50, 'y': 200} 
        ),
        (
            "New dimensions only: Add a 'channel' dimension",
            {'channel': 3},
            ['channel'],
            [],
            {'channel': 3}
        ),
        (
            "Mixed: Resample x and add new channel dimension",
            # window=10 -> tensor=30 (3x)
            {'x': 30, 'channel': 12}, 
            ['channel'],
            ['x'],
            # ds_size=100 * 3 = 300
            {'x': 300, 'channel': 12} 
        ),
        (
            "Identity resampling (ratio=1)",
            {'x': 10, 'y': 10},
            [],
            ['x', 'y'],
            # ds_size * 1 = ds_size
            {'x': 100, 'y': 100} 
        ),
        (
            "Dimension not in batcher is treated as new",
            # 't' is in the dataset but not in `input_dims`, so it's not a resample dim.
            # The logic should treat it as a new dimension.
            {'t': 5},
            ['t'],
            [],
            {'t': 5}
        )
    ]
)
def test_get_output_array_size_scenarios(
    bgen_fixture,  # The fixture is passed as an argument
    case_description,
    output_tensor_dim,
    new_dim,
    resample_dim,
    expected_output
):
    """
    Tests various valid scenarios for calculating the output array size.
    The `case_description` parameter is not used in the code but helps make
    test results more readable.
    """
    # The `bgen_fixture` argument is the BatchGenerator instance created by our fixture
    result = _get_output_array_size(
        bgen=bgen_fixture,
        output_tensor_dim=output_tensor_dim,
        new_dim=new_dim,
        resample_dim=resample_dim
    )
    
    assert result == expected_output, f"Failed on case: {case_description}"

In [58]:
def test_get_output_array_size_raises_assertion_error_on_non_integer_size():
    """
    Tests that the function raises an AssertionError when the resampling
    calculation results in a non-integer output dimension size.
    """
    # Create a dataset where the total size is not a clean multiple
    # for the resampling ratio we will test.
    # DataArray size for 'x' is 101.
    data_for_error = xr.DataArray(
        data=np.random.rand(101, 100, 10),
        dims=("x", "y", "t")
    )
    
    # The batch window size for 'x' is 10.
    bgen = xbatcher.BatchGenerator(data_for_error, input_dims={'x': 10})
    
    # The resampling logic will be: 101 * (5 / 10) = 50.5, which is not an integer.
    output_tensor_dim = {'x': 5}
    
    # Use pytest.raises to assert that an AssertionError is thrown.
    with pytest.raises(AssertionError):
        _get_output_array_size(
            bgen=bgen,
            output_tensor_dim=output_tensor_dim,
            new_dim=[],
            resample_dim=['x']
        )
pytest.main(['-v'])

platform linux -- Python 3.12.11, pytest-8.4.1, pluggy-1.6.0 -- /srv/conda/envs/notebook/bin/python
cachedir: .pytest_cache
rootdir: /home/jovyan/xbatcher-deep-learning/notebooks
plugins: anyio-4.9.0, hydra-core-1.3.2, jaxtyping-0.3.2
[1mcollecting ... [0mcollected 0 items



<ExitCode.NO_TESTS_COLLECTED: 5>