<a href="https://colab.research.google.com/github/igorvere/tests/blob/main/torch_cumsum.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import time
import jax


def cumsum_torch(x, y0):
    """Running sum over the rows of ``x`` starting from ``y0`` (eager loop).

    Args:
        x: 2-D tensor; its shape is unpacked as ``(n_times, n_batch)`` and
            row ``i`` is added to the accumulator at step ``i``.
        y0: 2-D tensor holding the initial accumulator value. It must
            broadcast against a single row of ``x``.

    Returns:
        The accumulator after every step, concatenated along dim 0. With
        ``y0`` of shape ``(1, n_batch)`` the result is ``(n_times, n_batch)``.
        When ``x`` has zero rows an empty tensor with zero rows is returned.
    """
    # Unpacking into two names also enforces that x is 2-D; n_batch itself
    # is not needed below.
    n_times, n_batch = x.shape
    assert len(y0.shape) == 2

    if n_times == 0:
        # torch.cat rejects an empty list. Build the empty result through the
        # same broadcasting/type promotion the loop would have applied so the
        # column count, dtype and device match the non-empty case.
        return (x.unsqueeze(1) + y0).flatten(0, 1)

    all_y = []
    y = y0
    for i in range(n_times):
        # Deliberately sequential: each step depends on the previous one.
        y = x[i] + y
        all_y.append(y)

    return torch.cat(all_y)
    
@torch.jit.script
def cumsum_torch_jit(x, y0):
    """Running sum over the rows of ``x`` starting from ``y0`` (TorchScript).

    Same computation as the eager loop version, compiled with
    ``torch.jit.script``.

    Args:
        x: 2-D tensor; its shape is unpacked as ``(n_times, n_batch)`` and
            row ``i`` is added to the accumulator at step ``i``.
        y0: 2-D tensor holding the initial accumulator value. It must
            broadcast against a single row of ``x``.

    Returns:
        The accumulator after every step, concatenated along dim 0. With
        ``y0`` of shape ``(1, n_batch)`` the result is ``(n_times, n_batch)``.
        When ``x`` has zero rows an empty tensor with zero rows is returned.
    """
    # Unpacking into two names also enforces that x is 2-D; n_batch itself
    # is not needed below.
    n_times, n_batch = x.shape
    assert len(y0.shape) == 2

    if n_times == 0:
        # torch.cat rejects an empty list. Build the empty result through the
        # same broadcasting/type promotion the loop would have applied so the
        # column count, dtype and device match the non-empty case.
        return (x.unsqueeze(1) + y0).flatten(0, 1)

    all_y = []
    y = y0
    for i in range(n_times):
        # Deliberately sequential: each step depends on the previous one.
        y = x[i] + y
        all_y.append(y)

    return torch.cat(all_y)




# Benchmark configuration shared by the torch and JAX runs below.
seq_len = 100_000   # number of sequential steps (10 ** 5)
batch_size = 1024   # width of each step (2 ** 10)
nruns = 5           # timed repetitions per implementation

def get_time():
    """Return a timestamp in seconds suitable for measuring elapsed time.

    Pending CUDA work is waited for before the clock is read, so the
    difference between two readings covers the GPU work issued in between
    rather than just the time needed to enqueue it. On machines without CUDA
    the synchronization is skipped instead of raising.

    The value comes from ``time.perf_counter()``: only differences between
    two readings are meaningful, not the absolute value.
    """
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    # perf_counter is monotonic and high-resolution, unlike wall-clock
    # time.time(), which can jump if the system clock is adjusted.
    return time.perf_counter()

# Benchmark inputs, moved to the GPU.
# Alternative input that tracks gradients (currently disabled):
#x = torch.zeros(seq_len, batch_size).cuda().requires_grad_(True)
x = torch.zeros(seq_len, batch_size).cuda()
y0 = torch.zeros(batch_size).cuda()

# Time the TorchScript implementation. y0 gets a leading axis because the
# cumsum functions assert a 2-D initial value. The first iteration is
# expected to be slower than the rest (script compilation / warm-up).
with torch.no_grad():
    for _ in range(nruns):
        t = get_time()
        y2 = cumsum_torch_jit(x, y0.unsqueeze(0))
        print(f'jit elapsed: {get_time()-t} sec')

# Time the plain eager-mode loop on the same inputs.
with torch.no_grad():
    for _ in range(nruns):
        t = get_time()
        y2 = cumsum_torch(x, y0.unsqueeze(0))
        print(f'py elapsed: {get_time()-t} sec')


@jax.jit
def cumsum_jax(x, y0):
  """Running sum over the leading axis of ``x`` starting from ``y0``.

  Implemented with ``jax.lax.scan``: at every step the current slice of
  ``x`` is added to the carry, and the updated carry is also emitted as that
  step's output. The scan is unrolled by a factor of 256.

  Returns the stacked per-step sums; the final carry is discarded.
  """
  def step(carry, row):
    total = carry + row
    # The new carry doubles as the value recorded for this step.
    return total, total

  _final, sums = jax.lax.scan(step, y0, x, unroll=256)
  return sums


# Rebind the benchmark inputs as JAX arrays (this replaces the torch tensors
# previously bound to x and y0).
y0 = jax.numpy.zeros((batch_size, ))
x = jax.numpy.zeros((seq_len, batch_size))

print(jax.devices())

# Time the jitted scan. The first iteration is expected to include
# tracing/compilation of cumsum_jax.
for _ in range(nruns):
  t = get_time()
  # Wait for the result so the measurement covers the computation itself and
  # not only its dispatch.
  y = cumsum_jax(x, y0).block_until_ready()
  print(f'jax elapsed: {get_time()-t} sec')


jit elapsed: 2.705786943435669 sec
jit elapsed: 0.7739927768707275 sec
jit elapsed: 0.5666718482971191 sec
jit elapsed: 0.5901196002960205 sec
jit elapsed: 0.6054491996765137 sec
py elapsed: 1.7551472187042236 sec
py elapsed: 1.7276153564453125 sec
py elapsed: 1.7300012111663818 sec
py elapsed: 1.7638764381408691 sec
py elapsed: 1.676008939743042 sec
[GpuDevice(id=0, process_index=0)]
jax elapsed: 4.911156892776489 sec
jax elapsed: 0.14863204956054688 sec
jax elapsed: 0.13698363304138184 sec
jax elapsed: 0.12814617156982422 sec
jax elapsed: 0.12300992012023926 sec
