In [62]:
import timeit

import numpy as np
import torch

In [None]:
inp = torch.randn(1024).to('cuda')
weights = [torch.randn(1024, 1024).to('cuda') for _ in range(50)]
n_runs = 10
n_repeat = 50


### Single-layer version ###

def single_layer_pass():
    torch.cuda.empty_cache()
    results = inp @ weights[0]
    torch.cuda.synchronize()
    return results

# Warmup
for _ in range(3):
    results = inp @ weights[0]

single_time = np.mean(timeit.repeat(single_layer_pass, number=n_runs, repeat=n_repeat)) / n_runs


### Multi-layer version ###

def all_layers_pass():
    torch.cuda.empty_cache()
    results = []
    for w in enumerate(weights):
        results.append(inp @ w)
    torch.cuda.synchronize()
    return results

# Warmup
for _ in range(3):
    for w in weights:
        results = inp @ w

multi_time = np.mean(timeit.repeat(all_layers_pass, number=n_runs, repeat=n_repeat)) / n_runs / len(weights)


### Multi-stream version ###

def all_layers_stream_pass():
    torch.cuda.empty_cache()
    streams = [torch.cuda.Stream() for _ in range(len(weights))]
    results = []
    for w, stream in zip(weights, streams):
        with torch.cuda.stream(stream):
            results.append(inp @ w)
    torch.cuda.synchronize()
    return results

# Warmup
for _ in range(3):
    for w in weights:
        results = inp @ w

multi_stream_time = np.mean(timeit.repeat(all_layers_stream_pass, number=n_runs, repeat=n_repeat)) / n_runs / len(weights)
        

print(f"Single layer time: {single_time}")
print(f"All layers time: {multi_time}")
print(f"All layers stream time: {multi_stream_time}")
print(f"Speedup: {single_time / multi_time}")
print(f"All layers stream speedup: {single_time / multi_stream_time}")


Single layer time: 0.00011067800200044076
All layers time: 2.1284480039994376e-05
All layers stream time: 5.6550740200009385e-05
Speedup: 5.199939194778188
All layers stream speedup: 1.9571450631590919


In [291]:
import jax
import jax.numpy as jnp
from jax import jit

key = jax.random.PRNGKey(0)
inp = jax.random.normal(key, (1024,))
keys = jax.random.split(key, 50)
# Convert list of weights to a single array
weights = jnp.stack([jax.random.normal(k, (1024, 1024)) for k in keys])
n_runs = 10
n_repeat = 50


### Single-layer version ###

@jit
def single_layer_pass(inp, weight):
    return jnp.dot(inp, weight)

# Warmup
for _ in range(3):
    jax.block_until_ready(single_layer_pass(inp, weights[0]))

single_time = np.mean(timeit.repeat(lambda: jax.block_until_ready(single_layer_pass(inp, weights[0])), 
                                  number=n_runs, repeat=n_repeat)) / n_runs


### Multi-layer version ###

@jit
def all_layers_pass(inp, weights):
    return jax.vmap(lambda w: jnp.dot(inp, w))(weights)

# Warmup
for _ in range(3):
    jax.block_until_ready(all_layers_pass(inp, weights))

multi_time = np.mean(timeit.repeat(lambda: jax.block_until_ready(all_layers_pass(inp, weights)),
                                 number=n_runs, repeat=n_repeat)) / n_runs / len(weights)


### Multi-layer version with only jitted dot ###

@jit
def jitted_dot(inp, weight):
    return jnp.dot(inp, weight)

def all_layers_pass_jitted_dot(inp, weights):
    result = jax.vmap(lambda w: jitted_dot(inp, w))(weights)
    return jax.block_until_ready(result)

# Warmup
for _ in range(3):
    all_layers_pass_jitted_dot(inp, weights)

multi_time_jitted_dot = np.mean(timeit.repeat(lambda: all_layers_pass_jitted_dot(inp, weights),
                                            number=n_runs, repeat=n_repeat)) / n_runs / len(weights)


print(f"Single layer time: {single_time}")
print(f"All layers time: {multi_time}")
print(f"All layers time (jitted dot): {multi_time_jitted_dot}")
print(f"Speedup: {single_time / multi_time}")
print(f"Speedup (jitted dot): {single_time / multi_time_jitted_dot}")


Single layer time: 0.0008562467960009599
All layers time: 2.982936385999892e-05
All layers time (jitted dot): 3.021226001997093e-05
Speedup: 28.704829242073917
Speedup (jitted dot): 28.341037560082
