In [62]:
import timeit

import numpy as np
import torch

In [None]:
# Benchmark inputs: one 1024-element vector and 50 square 1024x1024 matrices.
# Allocate them directly on the GPU via ``device=`` instead of building them on
# the host and copying each one over with ``.to('cuda')``.
inp = torch.randn(1024, device='cuda')
weights = [torch.randn(1024, 1024, device='cuda') for _ in range(50)]
n_runs = 10    # calls per timeit measurement (timeit's ``number``)
n_repeat = 50  # timeit measurements taken and averaged (timeit's ``repeat``)


### Single-layer version ###

def single_layer_pass():
    """Timed unit: one vector-matrix product of ``inp`` with ``weights[0]``.

    Releases the CUDA caching allocator's unused blocks, issues the matmul,
    then waits for the device to finish so the timer measures the kernel's
    execution rather than only its asynchronous launch.

    NOTE(review): ``empty_cache`` is inside the timed region. Each call likely
    also pays to re-acquire device memory for its output. The all-layers
    variants amortize that over ``len(weights)`` matmuls, which may inflate the
    reported speedup. Consider moving it out of the timed function.

    Returns:
        The result tensor ``inp @ weights[0]``.
    """
    torch.cuda.empty_cache()
    first_weight = weights[0]
    product = torch.matmul(inp, first_weight)
    torch.cuda.synchronize()
    return product

# Warm up with the exact code path that is timed below. The call ends with
# torch.cuda.synchronize(), so no warmup kernels are still queued on the GPU
# when the first timed repeat starts (otherwise their cost would be billed to it).
for _ in range(3):
    results = single_layer_pass()

# Mean over ``n_repeat`` measurements, each of ``n_runs`` calls, reduced to
# seconds per single matmul.
single_time = np.mean(timeit.repeat(single_layer_pass, number=n_runs, repeat=n_repeat)) / n_runs


### Multi-layer version ###

def all_layers_pass():
    """Timed unit: multiply ``inp`` by every matrix in ``weights``.

    All matmuls are issued back to back on the current stream. The function
    then waits for the device to finish, so the timer covers their execution.

    NOTE(review): as in the single-layer pass, ``empty_cache`` is inside the
    timed region. Here its cost is spread over ``len(weights)`` matmuls.

    Returns:
        A list of result tensors, one per matrix, in ``weights`` order.
    """
    torch.cuda.empty_cache()
    # Iterate the tensors themselves. ``enumerate`` would yield
    # ``(index, tensor)`` tuples, and ``inp @ tuple`` raises TypeError.
    results = [inp @ w for w in weights]
    torch.cuda.synchronize()
    return results

# Warm up with the exact code path that is timed below. The call ends with
# torch.cuda.synchronize(), so no warmup kernels leak into the first timed repeat.
for _ in range(3):
    results = all_layers_pass()

# Mean over ``n_repeat`` measurements, reduced to seconds per individual matmul
# (divided by both the call count and the number of layers per call).
multi_time = np.mean(timeit.repeat(all_layers_pass, number=n_runs, repeat=n_repeat)) / n_runs / len(weights)


### Multi-stream version ###

def all_layers_stream_pass():
    """Timed unit: multiply ``inp`` by every matrix, one CUDA stream per matrix.

    Each matmul is launched under its own ``torch.cuda.Stream`` so the kernels
    may overlap. The final ``torch.cuda.synchronize()`` waits for all streams
    on the device, so every result is complete when returned.

    NOTE(review): stream objects are requested on every call, inside the timed
    region. PyTorch may serve them from a limited internal pool, so 50 Stream
    objects need not mean 50 distinct hardware streams. Confirm if the stream
    count matters for the experiment.

    Returns:
        A list of result tensors, one per matrix, in ``weights`` order.
    """
    torch.cuda.empty_cache()
    streams = [torch.cuda.Stream() for _ in range(len(weights))]
    # Work on a side stream is not ordered after work already queued on the
    # current stream. Make each side stream wait so it cannot read ``inp`` or
    # ``w`` before any pending producer kernels have finished.
    producer = torch.cuda.current_stream()
    results = []
    for w, stream in zip(weights, streams):
        stream.wait_stream(producer)
        with torch.cuda.stream(stream):
            results.append(inp @ w)
    torch.cuda.synchronize()
    return results

# Warm up the stream path itself (stream acquisition plus side-stream kernels),
# not just plain matmuls on the current stream. The call ends with
# torch.cuda.synchronize(), so no warmup work leaks into the first timed repeat.
for _ in range(3):
    results = all_layers_stream_pass()

# Mean over ``n_repeat`` measurements, reduced to seconds per individual matmul.
multi_stream_time = np.mean(timeit.repeat(all_layers_stream_pass, number=n_runs, repeat=n_repeat)) / n_runs / len(weights)

# All times are seconds per matmul. Both speedups are relative to the
# single-layer run (values > 1 mean faster per matmul).
print(f"Single layer time: {single_time}")
print(f"All layers time: {multi_time}")
print(f"All layers stream time: {multi_stream_time}")
print(f"Speedup: {single_time / multi_time}")
print(f"All layers stream speedup: {single_time / multi_stream_time}")


Single layer time: 0.00011067800200044076
All layers time: 2.1284480039994376e-05
All layers stream time: 5.6550740200009385e-05
Speedup: 5.199939194778188
All layers stream speedup: 1.9571450631590919
