# Infer model on array

---

## Imports

In [9]:
import torch
import xbatcher
import xarray as xr
import numpy as np
import pytest

from functions import _get_output_array_size

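## Testing `_get_output_array_size`

The tests below are written to `test_get_array_size.py` with `%%writefile` (the first cell creates the file, later cells append to it with `-a`) and are then run with pytest.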

In [10]:
%%writefile test_get_array_size.py
import torch
import xbatcher
import xarray as xr
import numpy as np
import pytest

from functions import _get_output_array_size

Overwriting test_get_array_size.py


In [11]:
%%writefile -a test_get_array_size.py

@pytest.fixture
def bgen_fixture() -> xbatcher.BatchGenerator:
    """Build a BatchGenerator over a random 100 x 100 x 10 DataArray.

    The array has dims ("x", "y", "t") with integer coordinates 0..N-1 on
    every dim. Batches are 10 x 10 windows in ``x``/``y`` overlapping by 5
    in each direction; ``t`` is not listed in ``input_dims``.
    """
    # Single source of truth for dim names and lengths (insertion order
    # gives the dims tuple ("x", "y", "t")).
    sizes = {"x": 100, "y": 100, "t": 10}
    source = xr.DataArray(
        data=np.random.rand(*sizes.values()),
        dims=tuple(sizes),
        coords={name: np.arange(length) for name, length in sizes.items()},
    )
    return xbatcher.BatchGenerator(
        source,
        input_dims={"x": 10, "y": 10},
        input_overlap={"x": 5, "y": 5},
    )

# Each case: (description, output_tensor_dim, new_dim, resample_dim, expected_output).
# With the fixture's 100 x 100 x 10 array and 10 x 10 input windows, a resampled
# dimension is expected to scale as full_size * output_size / input_size
# (e.g. x: 100 * 5 / 10 = 50); a new dimension keeps the size given in
# output_tensor_dim.
_OUTPUT_SIZE_CASES = [
    (
        "Resampling only: Downsample x, Upsample y",
        {'x': 5, 'y': 20},
        [],
        ['x', 'y'],
        {'x': 50, 'y': 200},
    ),
    (
        "New dimensions only: Add a 'channel' dimension",
        {'channel': 3},
        ['channel'],
        [],
        {'channel': 3},
    ),
    (
        "Mixed: Resample x and add new channel dimension",
        {'x': 30, 'channel': 12},
        ['channel'],
        ['x'],
        {'x': 300, 'channel': 12},
    ),
    (
        "Identity resampling (ratio=1)",
        {'x': 10, 'y': 10},
        [],
        ['x', 'y'],
        {'x': 100, 'y': 100},
    ),
    (
        # 't' exists in the fixture data but is not one of its input_dims.
        "Dimension not in batcher is treated as new",
        {'t': 5},
        ['t'],
        [],
        {'t': 5},
    ),
]


@pytest.mark.parametrize(
    "case_description, output_tensor_dim, new_dim, resample_dim, expected_output",
    _OUTPUT_SIZE_CASES,
    # Use the human-readable description as the test ID instead of pytest's
    # auto-generated "output_tensor_dim0-new_dim0-..." suffixes.
    ids=[case[0] for case in _OUTPUT_SIZE_CASES],
)
def test_get_output_array_size_scenarios(
    bgen_fixture,
    case_description,
    output_tensor_dim,
    new_dim,
    resample_dim,
    expected_output,
):
    """
    Check `_get_output_array_size` against each valid scenario in
    `_OUTPUT_SIZE_CASES`.

    `bgen_fixture` supplies the BatchGenerator under test. `case_description`
    is used as the pytest test ID and is repeated in the assertion message so
    a failure names the scenario that broke.
    """
    result = _get_output_array_size(
        bgen=bgen_fixture,
        output_tensor_dim=output_tensor_dim,
        new_dim=new_dim,
        resample_dim=resample_dim,
    )

    assert result == expected_output, f"Failed on case: {case_description}"

Appending to test_get_array_size.py


In [12]:
%%writefile -a test_get_array_size.py

def test_get_output_array_size_raises_assertion_error_on_non_integer_size():
    """
    `_get_output_array_size` must raise AssertionError when resampling would
    yield a fractional output size.

    'x' has 101 samples and is resampled from 10 to 5 per batch, so the
    implied output length is 101 * (5 / 10) = 50.5, which is not an integer.
    """
    # Odd-length 'x' axis; only 'x' is resampled below.
    odd_sized = xr.DataArray(
        data=np.random.rand(101, 100, 10),
        dims=("x", "y", "t"),
    )
    generator = xbatcher.BatchGenerator(odd_sized, input_dims={'x': 10})

    with pytest.raises(AssertionError):
        _get_output_array_size(
            bgen=generator,
            output_tensor_dim={'x': 5},
            new_dim=[],
            resample_dim=['x'],
        )

Appending to test_get_array_size.py


In [13]:
!pytest -v

platform linux -- Python 3.12.11, pytest-8.4.1, pluggy-1.6.0 -- /srv/conda/envs/notebook/bin/python3.12
cachedir: .pytest_cache
rootdir: /home/jovyan/xbatcher-deep-learning/notebooks
plugins: anyio-4.9.0, hydra-core-1.3.2, jaxtyping-0.3.2
collected 6 items                                                              [0m[1m

test_get_array_size.py::test_get_output_array_size_scenarios[Resampling only: Downsample x, Upsample y-output_tensor_dim0-new_dim0-resample_dim0-expected_output0] [32mPASSED[0m[32m [ 16%][0m
test_get_array_size.py::test_get_output_array_size_scenarios[New dimensions only: Add a 'channel' dimension-output_tensor_dim1-new_dim1-resample_dim1-expected_output1] [32mPASSED[0m[32m [ 33%][0m
test_get_array_size.py::test_get_output_array_size_scenarios[Mixed: Resample x and add new channel dimension-output_tensor_dim2-new_dim2-resample_dim2-expected_output2] [32mPASSED[0m[32m [ 50%][0m
test_get_array_size.py::test_get_output_array_size_scenarios[Identity resamp