In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_parallel_scan as tps

In [None]:
import torch

# Loop-based implementation for comparison
def parallel_scan_torch(a, b, v0):
    """
    Perform the parallel scan for vt = at ⊙ vt-1 + bt using a for loop (PyTorch).

    Parameters:
        a (torch.Tensor): Tensor of coefficients `a`.
        b (torch.Tensor): Tensor of coefficients `b`.
        v0 (float): Initial value of `v`.

    Returns:
        torch.Tensor: Tensor of computed `v` values.
    """
    n = len(a)
    v = torch.zeros_like(a)
    v[0] = a[0] * v0 + b[0]
    for t in range(1, n):
        v[t] = a[t] * v[t - 1] + b[t]
    return v

Result with loop (PyTorch):
tensor([1.0000, 2.6000, 4.8200, 7.8560])


In [None]:
def parallel_scan_no_loop_torch(a, b, v0):
    """
    Perform the parallel scan for vt = at ⊙ vt-1 + bt without using a for loop (PyTorch).

    Parameters:
        a (torch.Tensor): Tensor of coefficients `a`.
        b (torch.Tensor): Tensor of coefficients `b`.
        v0 (float): Initial value of `v`.

    Returns:
        torch.Tensor: Tensor of computed `v` values.
    """
    # Compute the cumulative product of `a`
    cumulative_a = torch.cumprod(a, dim=-1)
    
    # Compute the scaling factors for `b`
    # scaling_factors = torch.cat((torch.tensor([1.0], device=a.device), cumulative_a[:-1]))
    
    # Compute the contributions from `b`
    # weighted_b = b * scaling_factors
    b_contributions = torch.cumsum(b - cumulative_a, dim=-1)
    
    # Compute the final result
    v = a * b
    return v

In [58]:
# Example usage
a = torch.tensor([0.5, 0.6, 0.7, 0.8, 0.9, 1.2, 0.3])
b = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0])

a = torch.rand((1, 5))
b = torch.rand((1, 5))
v0 = 0

# Results
result_with_loop = parallel_scan_torch(a, b, v0)
result_no_loop = parallel_scan_no_loop_torch(a, b, v0)

print("Result with loop (PyTorch):")
print(result_with_loop)

print("Result with no loop (PyTorch):")
print(result_no_loop)

Result with loop (PyTorch):
tensor([[0.7824, 0.7945, 0.7638, 0.6933, 0.0084]])
Result with no loop (PyTorch):
tensor([[2.4126e-02, 2.9535e-01, 4.1106e-02, 5.1348e-01, 1.7030e-04]])
