# Batching

MagPy supports tensor-valued coefficient functions and tensor-valued time arguments, and evaluates Hamiltonians in batches accordingly.

For a scalar coefficient function and a tensor of times, the function is evaluated elementwise, producing one matrix per time:

```python
from torch import tensor, stack, sin, cos
from magpy import X, Y, FunctionProduct as FP

# A scalar coefficient function multiplying the Pauli X operator
H = sin*X()

# One matrix per time: sin(1)X, sin(2)X, sin(3)X
H(tensor([1,2,3]))
```

```
tensor([[[0.0000+0.j, 0.8415+0.j],
         [0.8415+0.j, 0.0000+0.j]],

        [[0.0000+0.j, 0.9093+0.j],
         [0.9093+0.j, 0.0000+0.j]],

        [[0.0000+0.j, 0.1411+0.j],
         [0.1411+0.j, 0.0000+0.j]]], dtype=torch.complex128)
```
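As a rough mental model, the elementwise rule can be mimicked in plain Python. This is only a sketch, not MagPy's implementation: the helpers `scale` and `batch_over_time` are hypothetical names, and matrices are nested lists of complex numbers so the example runs without PyTorch.

```python
import math

# Plain-Python stand-in for the Pauli X operator returned by X()
PAULI_X = [[0j, 1 + 0j], [1 + 0j, 0j]]


def scale(c, m):
    """Multiply every entry of the 2x2 matrix m by the scalar c."""
    return [[c * entry for entry in row] for row in m]


def batch_over_time(f, times, m):
    """Hypothetical helper: evaluate f(t) * m once per time, in order."""
    return [scale(f(t), m) for t in times]


batch = batch_over_time(math.sin, [1, 2, 3], PAULI_X)
for mat in batch:
    # Off-diagonal entry of each matrix: sin(1), sin(2), sin(3)
    print(round(mat[0][1].real, 4))
```

Each time yields its own matrix, matching the three matrices in the output above.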

For a tensor-valued coefficient function and a scalar time, the time argument is repeated across the function's outputs, producing one matrix per output:

```python
# A coefficient function returning two values per time
H = (lambda t: stack((sin(t), t))) * X()

H(tensor(1))
```

```
tensor([[[0.0000+0.j, 0.8415+0.j],
         [0.8415+0.j, 0.0000+0.j]],

        [[0.0000+0.j, 1.0000+0.j],
         [1.0000+0.j, 0.0000+0.j]]], dtype=torch.complex128)
```
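A minimal plain-Python sketch of this rule, assuming the function's outputs can be iterated as a sequence: the same time feeds every output, and each output scales its own copy of the operator. `batch_over_outputs` and `coeffs` are hypothetical names used only for illustration.

```python
import math

# Plain-Python stand-in for the Pauli X operator returned by X()
PAULI_X = [[0j, 1 + 0j], [1 + 0j, 0j]]


def scale(c, m):
    """Multiply every entry of the 2x2 matrix m by the scalar c."""
    return [[c * entry for entry in row] for row in m]


def batch_over_outputs(f, t, m):
    """Hypothetical helper: f(t) returns a sequence of coefficients, all
    evaluated at the same t; each one scales its own copy of m."""
    return [scale(c, m) for c in f(t)]


def coeffs(t):
    # Plain-Python analogue of `lambda t: stack((sin(t), t))`
    return (math.sin(t), t)


batch = batch_over_outputs(coeffs, 1, PAULI_X)
print([round(mat[0][1].real, 4) for mat in batch])
```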

For a tensor-valued coefficient function and a tensor of times, the result is batched over both, giving one matrix per (output, time) pair. Note how the first three matrices correspond to `sin` and the last three to `t`.

```python
H = (lambda t: stack((sin(t), t))) * X()

# Two outputs times three times gives six matrices
H(tensor([1,2,3]))
```

```
tensor([[[0.0000+0.j, 0.8415+0.j],
         [0.8415+0.j, 0.0000+0.j]],

        [[0.0000+0.j, 0.9093+0.j],
         [0.9093+0.j, 0.0000+0.j]],

        [[0.0000+0.j, 0.1411+0.j],
         [0.1411+0.j, 0.0000+0.j]],

        [[0.0000+0.j, 1.0000+0.j],
         [1.0000+0.j, 0.0000+0.j]],

        [[0.0000+0.j, 2.0000+0.j],
         [2.0000+0.j, 0.0000+0.j]],

        [[0.0000+0.j, 3.0000+0.j],
         [3.0000+0.j, 0.0000+0.j]]], dtype=torch.complex128)
```
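The ordering in the output above can be reproduced by flattening the function's results row by row, one row per coefficient function and one column per time. The sketch below is a plain-Python illustration under that assumption; `batch_both` and `coeffs` are hypothetical names, not part of MagPy.

```python
import math

# Plain-Python stand-in for the Pauli X operator returned by X()
PAULI_X = [[0j, 1 + 0j], [1 + 0j, 0j]]


def scale(c, m):
    """Multiply every entry of the 2x2 matrix m by the scalar c."""
    return [[c * entry for entry in row] for row in m]


def coeffs(times):
    # Analogue of `stack((sin(t), t))` for a batch of times:
    # one row per coefficient function, one column per time.
    return [[math.sin(t) for t in times], list(times)]


def batch_both(f, times, m):
    """Hypothetical helper: flatten f(times) row by row, so every value
    of the first function precedes every value of the second."""
    flat = [c for row in f(times) for c in row]
    return [scale(c, m) for c in flat]


batch = batch_both(coeffs, [1, 2, 3], PAULI_X)
print([round(mat[0][1].real, 4) for mat in batch])
# [0.8415, 0.9093, 0.1411, 1.0, 2.0, 3.0]
```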

When the coefficient functions have different output sizes, MagPy repeats the smaller outputs to match the largest, provided the largest size is a multiple of every other size. For example, here the first term has one output and the second has two:

```python
# 2t multiplies X (one output); the two-valued function multiplies Y
H = (lambda t: 2*t)*X() + (lambda t: stack((sin(t), t)))*Y()

H(tensor(1))
```

```
tensor([[[0.+0.0000j, 2.-0.8415j],
         [2.+0.8415j, 0.+0.0000j]],

        [[0.+0.0000j, 2.-1.0000j],
         [2.+1.0000j, 0.+0.0000j]]], dtype=torch.complex128)
```
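The repetition rule can be sketched in plain Python as follows. Two details here are assumptions rather than documented MagPy behaviour: smaller outputs are tiled (the whole sequence repeated end to end), and incompatible sizes raise a `ValueError`. With a single-valued first term, as in the example above, tiling and any other repetition order give the same result. `tile_to` and `batch_sum` are hypothetical names.

```python
import math

# Plain-Python stand-ins for the Pauli X and Y operators
PAULI_X = [[0j, 1 + 0j], [1 + 0j, 0j]]
PAULI_Y = [[0j, -1j], [1j, 0j]]


def tile_to(values, size):
    """Repeat `values` end to end until it has `size` entries.

    Assumption: tiling order, and rejection of sizes that do not divide."""
    if size % len(values) != 0:
        raise ValueError(f"{len(values)} does not divide {size}")
    return list(values) * (size // len(values))


def batch_sum(terms, t):
    """Hypothetical helper: terms is a list of (f, matrix) pairs where f(t)
    returns a list of coefficients. Returns one 2x2 matrix per batch entry."""
    outputs = [f(t) for f, _ in terms]
    size = max(len(out) for out in outputs)
    tiled = [tile_to(out, size) for out in outputs]
    batch = []
    for k in range(size):
        mat = [[0j, 0j], [0j, 0j]]
        # Sum coefficient * operator over all terms for batch entry k
        for coeffs, (_, m) in zip(tiled, terms):
            for i in range(2):
                for j in range(2):
                    mat[i][j] += coeffs[k] * m[i][j]
        batch.append(mat)
    return batch


terms = [
    (lambda t: [2 * t], PAULI_X),           # one output
    (lambda t: [math.sin(t), t], PAULI_Y),  # two outputs
]
batch = batch_sum(terms, 1)
print(batch[0][0][1], batch[1][0][1])
```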

Function products may also be batched. Here a constant tensor multiplies the product of `sin` and `cos`, giving one value per constant:

```python
# Constant tensor [1, 2] times sin(t)cos(t)
f = FP() * tensor([1,2]) * sin * cos

f(tensor(1))
```

```
tensor([0.4546, 0.9093])
```
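A plain-Python sketch of the same computation, assuming the batched product simply multiplies each constant by the product of the scalar functions at `t`. `batched_product` is a hypothetical helper, not MagPy's `FunctionProduct`. Since sin(1)cos(1) = sin(2)/2, the second entry equals sin(2).

```python
import math


def batched_product(consts, *funcs):
    """Hypothetical helper: return a function of t that multiplies each
    constant by the product of the scalar functions evaluated at t."""
    def evaluate(t):
        prod = 1.0
        for func in funcs:
            prod *= func(t)
        return [c * prod for c in consts]
    return evaluate


f = batched_product([1, 2], math.sin, math.cos)
print([round(v, 4) for v in f(1)])  # [0.4546, 0.9093]
```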