In [1]:
import torch

In [2]:
# Loop-based implementation for comparison
def parallel_scan_torch(a, b, v0):
    """
    Perform the parallel scan for vt = at ⊙ vt-1 + bt using a for loop (PyTorch).

    Parameters:
        a (torch.Tensor): Tensor of coefficients `a`.
        b (torch.Tensor): Tensor of coefficients `b`.
        v0 (float): Initial value of `v`.

    Returns:
        torch.Tensor: Tensor of computed `v` values.
    """
    n = len(a)
    v = torch.zeros_like(a)
    v[0] = a[0] * v0 + b[0]
    for t in range(1, n):
        v[t] = a[t] * v[t - 1] + b[t]
    return v

In [4]:
# Example usage
a = torch.tensor([0.5, 0.6, 0.7, 0.8])
b = torch.tensor([1.0, 2.0, 3.0, 4.0])

v0 = 0

# Results
result_with_loop = parallel_scan_torch(a, b, v0)

print("Result with loop (PyTorch):")
print(result_with_loop)


Result with loop (PyTorch):
tensor([1.0000, 2.6000, 4.8200, 7.8560])


In [None]:
def parallel_scan_no_loop_torch(a: torch.Tensor, b: torch.Tensor, v0: torch.Tensor):
    """
    Perform the parallel scan for vt = at ⊙ vt-1 + bt without using a for loop (PyTorch).

    Parameters:
        a (torch.Tensor): Tensor of coefficients `a`.
        b (torch.Tensor): Tensor of coefficients `b`.
        v0 (float): Initial value of `v`.

    Returns:
        torch.Tensor: Tensor of computed `v` values.
    """

    seq_length = a.shape[1]
    v = torch.zeros_like(a)
    v[0] = a[0] * v0 + b[0]
    print(f"seq_length: {seq_length}")
    for i in range(seq_length):
        v[:,i,:] = a[:,i,:] * v[:, i - 1, :] + b[:,i,:]

    return v

In [10]:
a = torch.rand((1, 4, 3))
b = torch.rand((1, 4, 3))
v0 = 0

result = parallel_scan_no_loop_torch(a, b, v0)

4
