In [42]:
%load_ext autoreload
%autoreload 2
# Add parent directory into system path
import sys, os
sys.path.insert(1, os.path.abspath(os.path.normpath('..')))

In [79]:
def batch_loader(*args, batch_size=None, num_batches:int=10):
    """
    Fast batch loader without collate_fn
    
    Parameters
    ----------
    - :attr:`*args`: `torch.Tensor` or `numpy.ndarray` \n
        The first dimension (`shape[0]`) of all input must be the same
    - :attr:`batch_size`: int | None \n
        The number of samples per batch (must be > 0). If it equals to `None`, the `batch_size` will be calculated from `num_batches`
    - :attr:`num_batches`: int \n
        The requested number of batches (default=10), used only when `batch_size` is `None`.
        The batch size is then `ceil(N / num_batches)` (at least 1), so at most `num_batches`
        batches are produced (possibly fewer, e.g. when `N < num_batches`).

    Returns
    -------
    Generator of tuples of a batch (one slice per input, in input order).
    Empty inputs (`N == 0`) yield no batches.

    Raises
    ------
    AssertionError if no input is given, an input has no `shape`, the first
    dimensions differ, or `batch_size` / `num_batches` is not positive.

    Example
    -------
    ```python
    >> bl = batch_loader(np.ones((100,3)), np.ones(100), batch_size=30)
    >> for x, y in bl:
    >>    print(x.shape, y.shape)
    (30, 3) (30,)
    (30, 3) (30,)
    (30, 3) (30,)
    (10, 3) (10,)
    ```
    """
    assert len(args) > 0, 'Missing input'
    assert all(hasattr(x, 'shape') for x in args), 'arguments must be torch.tensor or numpy.array'
    total_length = [x.shape[0] for x in args]
    assert len(set(total_length)) == 1, f'The first dimension of every tensor or array must be the same: {total_length}'
    n_samples = total_length[0]

    if batch_size is None:
        assert num_batches > 0, f'num_batches must be a positive integer: {num_batches}'
        # Ceiling division: never produces more than `num_batches` batches, and the
        # max(1, ...) avoids a zero step (range() error) when n_samples < num_batches.
        batch_size = max(1, -(-n_samples // num_batches))
    assert batch_size > 0, f'batch_size must be a positive integer: {batch_size}'

    return (
        tuple(x[start:start + batch_size] for x in args)
        for start in range(0, n_samples, batch_size)
    )

def run_batch(callback, *args, reducer=None, **kwargs):
    """
    Run `callback` with `batch_loader`
    see `utils/dataset_generator.py`:`batch_loader` for more information

    Parameters
    ----------
    - :attr:`callback`: callable invoked once per batch as `callback(*batch)` \n
        It should accept as many positional arguments as there are `*args`.
    - :attr:`*args`: `torch.Tensor` or `numpy.ndarray` \n
        The first dimension (`shape[0]`) of all input must be the same
    - :attr:`reducer`: function to combine the per-batch results (optional)
    - :attr:`batch_size`: int | None \n
        Forwarded to `batch_loader`. If it equals to `None`, the `batch_size` will be calculated from `num_batches`
    - :attr:`num_batches`: int \n
        Forwarded to `batch_loader` (default=10). If `batch_size` is not `None`, `num_batches` will be ignored.

    Returns
    -------
    Any
        Without `reducer`: the per-batch results. For numpy inputs this is a list.
        For torch inputs they are gathered into one tensor on `args[0].device`
        (stacked when every result is a tensor of the same shape). Per-batch tensors
        of differing shapes cannot be combined and are returned as a list.
        With `reducer`: `reducer(results)`.

    Notes
    -----
    If the callback's signature cannot accept `len(args)` positional arguments, a
    `UserWarning` is emitted before running (the call itself will then likely fail).

    Example
    -------
    ```
    >> class A:
    >>    def f1(self, x, y):
    >>        return np.mean(np.mean(x) + y)
    >> a = A()
    >> run_batch(a.f1, np.ones((100, 3)), np.ones(100), reducer=np.mean, batch_size=30)
    2.0

    >> class B:
    >>    def f1(self, x, y):
    >>        return torch.mean(torch.mean(x) + y)
    >> b = B()
    >> run_batch(b.f1, torch.ones((100, 3), device='cuda'), torch.ones(100, device='cuda'), reducer=torch.mean, batch_size=30)
    tensor(2., device='cuda:0')
    ```
    """
    import inspect
    import sys
    import warnings

    # inspect.signature drops the bound `self` of methods and follows `__wrapped__`
    # (e.g. torch.no_grad-decorated methods), unlike raw `__code__` introspection.
    try:
        signature = inspect.signature(callback)
    except (TypeError, ValueError):
        signature = None  # builtins / C callables may expose no signature: skip check
    if signature is not None:
        try:
            signature.bind(*range(len(args)))
        except TypeError as err:
            warnings.warn(
                f'The arguments of callback do not match the {len(args)} input arguments: {err}',
                stacklevel=2,
            )

    results = [callback(*batch) for batch in batch_loader(*args, **kwargs)]

    # Look torch up lazily so numpy-only callers do not need torch imported;
    # if torch was never imported, args[0] cannot be a torch.Tensor.
    torch_module = sys.modules.get('torch')
    if torch_module is not None and isinstance(args[0], torch_module.Tensor):
        device = args[0].device
        if results and all(isinstance(r, torch_module.Tensor) for r in results):
            if all(r.shape == results[0].shape for r in results):
                # torch.tensor() cannot build from multi-element tensors; stack them
                # instead (this also keeps their dtype).
                results = torch_module.stack(results).to(device)
            # else: differently-shaped per-batch tensors are returned as a list.
        else:
            results = torch_module.tensor(results, device=device)

    return results if reducer is None else reducer(results)
    

In [81]:
import torch
from models.MLP import Davies2021

# Fall back to CPU so this cell also runs on machines without a CUDA GPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

net = Davies2021(N_layers=8, width=28, activation=torch.nn.Softplus(30), last_activation=torch.nn.Softplus(30)).to(DEVICE)
# Alternative configuration (uses only names defined in this cell):
# net = Davies2021(N_layers=8, width=28, activation=torch.nn.SiLU(), last_activation=torch.nn.Identity()).to(DEVICE)

run_batch(net.test, torch.ones((100, 3), device=DEVICE), torch.ones((100,), device=DEVICE), reducer=torch.mean, batch_size=30)



tensor(1.1628, device='cuda:0')

In [74]:
# Debugging aid: show which source file `net.test`'s code object comes from.
# The recorded output points into torch/autograd/grad_mode.py, i.e. `test` is
# presumably wrapped by a torch grad-mode decorator (e.g. @torch.no_grad()), so
# `__code__` here describes that wrapper rather than `test` itself.
getattr(net.test, '__code__').co_filename

'C:\\Users\\Jirawat\\miniconda3\\envs\\sdf\\lib\\site-packages\\torch\\autograd\\grad_mode.py'