In [1]:
import time
import torch
torch.set_num_threads(1)
from clebsch_gordan import get_real_clebsch_gordan, ClebschGordan
from sparse_accumulation_plain_torch import sparse_accumulation_loops, sparse_accumulation_index_add
import sparse_accumulation, sparse_accumulation_active_dim_first,  sparse_accumulation_active_dim_middle
import numpy as np

In [2]:
L_MAX = 8
clebsch = ClebschGordan(L_MAX).precomputed_
indices = get_real_clebsch_gordan(clebsch[L_MAX, L_MAX, L_MAX], L_MAX, L_MAX, L_MAX)

In [3]:
m1_aligned, m2_aligned = [], []
multipliers, mu_aligned = [], []
for mu in range(0, 2 * L_MAX + 1):
    for el in indices[mu]:
        m1, m2, multiplier = el
        m1_aligned.append(m1)
        m2_aligned.append(m2)
        multipliers.append(multiplier)
        mu_aligned.append(mu)
m1_aligned = torch.LongTensor(m1_aligned)
m2_aligned = torch.LongTensor(m2_aligned)
mu_aligned = torch.LongTensor(mu_aligned)
multipliers = torch.FloatTensor(multipliers)

In [4]:
def benchmark_forward(BATCH_SIZE, N_FEATURES, active_dim, function, n_trials):
    if active_dim == 0:
        X1 = torch.randn(2 * L_MAX + 1, BATCH_SIZE, N_FEATURES)
        X2 = torch.randn(2 * L_MAX + 1, BATCH_SIZE, N_FEATURES)
    
    if active_dim == 1:
        X1 = torch.randn(BATCH_SIZE, 2 * L_MAX + 1, N_FEATURES)
        X2 = torch.randn(BATCH_SIZE, 2 * L_MAX + 1, N_FEATURES)
        
    if active_dim == 2:
        X1 = torch.randn(BATCH_SIZE, N_FEATURES, 2 * L_MAX + 1)
        X2 = torch.randn(BATCH_SIZE, N_FEATURES, 2 * L_MAX + 1)
    
   
        
    if (active_dim != 0) and (active_dim != 2) and (active_dim != 1):
        raise ValueError("active dim should be one of 0, 1, 2")
    times = []
    for _ in range(n_trials):
        begin = time.time()
        output = function(X1, X2, mu_aligned, 2 * L_MAX + 1, m1_aligned, m2_aligned, multipliers)
        times.append(time.time() - begin)
    return times


def benchmark_backward(BATCH_SIZE, N_FEATURES, active_dim, function, n_trials):
    if active_dim == 0:
        X1 = torch.randn(2 * L_MAX + 1, BATCH_SIZE, N_FEATURES)
        X2 = torch.randn(2 * L_MAX + 1, BATCH_SIZE, N_FEATURES)
    
    if active_dim == 1:
        X1 = torch.randn(BATCH_SIZE, 2 * L_MAX + 1, N_FEATURES)
        X2 = torch.randn(BATCH_SIZE, 2 * L_MAX + 1, N_FEATURES)
        
    if active_dim == 2:
        X1 = torch.randn(BATCH_SIZE, N_FEATURES, 2 * L_MAX + 1)
        X2 = torch.randn(BATCH_SIZE, N_FEATURES, 2 * L_MAX + 1)
        
    if (active_dim != 0) and (active_dim != 2) and (active_dim != 1):
        raise ValueError("active dim should be one of 0, 1, 2")
        
    X1.requires_grad = True
    X2.requires_grad = True
    times = []
    for _ in range(n_trials):
        begin = time.time()
        output = function(X1, X2, mu_aligned, 2 * L_MAX + 1, m1_aligned, m2_aligned, multipliers)
        output.backward(gradient=torch.ones_like(output))
        times.append(time.time() - begin)
    return np.array(times)

In [5]:
def get_func_fixed_dim(func, active_dim):
    def func_fixed_dim(*args):
        return func(*args, active_dim = active_dim)
    return func_fixed_dim

In [6]:
BATCH_SIZE = 1000
N_FEATURES = 100
times = benchmark_forward(BATCH_SIZE, N_FEATURES, 1, 
                          get_func_fixed_dim(sparse_accumulation_index_add, 1), 10)
print("torch index_add_; active dim 1; forward: ", np.mean(times[1:]))

torch index_add_; active dim 1; forward:  0.2133361498514811


In [7]:
BATCH_SIZE = 1000
N_FEATURES = 100

times = benchmark_forward(BATCH_SIZE, N_FEATURES, 0, 
                          get_func_fixed_dim(sparse_accumulation_loops, 0), 10)
print("python loops; active dim 0; forward: ", np.mean(times[1:]))
times = benchmark_forward(BATCH_SIZE, N_FEATURES, 0, 
                          get_func_fixed_dim(sparse_accumulation_index_add, 0), 10)
print("torch index_add_; active dim 0; forward: ", np.mean(times[1:]))
times = benchmark_forward(BATCH_SIZE, N_FEATURES, 0,
                          sparse_accumulation_active_dim_first.SparseAccumulationActiveDimFirst.apply, 10)
print("cpp; active dim 0; forward: ", np.mean(times[1:]))

print()
times = benchmark_forward(BATCH_SIZE, N_FEATURES, 1, 
                          get_func_fixed_dim(sparse_accumulation_loops, 1), 10)
print("python loops; active dim 1; forward: ", np.mean(times[1:]))
times = benchmark_forward(BATCH_SIZE, N_FEATURES, 1, 
                          get_func_fixed_dim(sparse_accumulation_index_add, 1), 10)
print("torch index_add_; active dim 1; forward: ", np.mean(times[1:]))
times = benchmark_forward(BATCH_SIZE, N_FEATURES, 1,
                          sparse_accumulation_active_dim_middle.SparseAccumulationActiveDimMiddle.apply, 10)
print("cpp; active dim 1; forward: ", np.mean(times[1:]))

print()
times = benchmark_forward(BATCH_SIZE, N_FEATURES, 2, get_func_fixed_dim(sparse_accumulation_loops, 2), 10)
print("python loops; active dim 2; forward: ", np.mean(times[1:]))
times = benchmark_forward(BATCH_SIZE, N_FEATURES, 2, get_func_fixed_dim(sparse_accumulation_index_add, 2), 10)
print("torch index_add_; active dim 2; forward: ", np.mean(times[1:]))
times = benchmark_forward(BATCH_SIZE, N_FEATURES, 2, sparse_accumulation.SparseAccumulation.apply, 10)
print("cpp; active dim 2; forward: ", np.mean(times[1:]))



python loops; active dim 0; forward:  0.04021528032090929
torch index_add_; active dim 0; forward:  0.205434693230523
cpp; active dim 0; forward:  0.02603549427456326

python loops; active dim 1; forward:  0.06679489877488878
torch index_add_; active dim 1; forward:  0.2108128865559896
cpp; active dim 1; forward:  0.022336641947428387

python loops; active dim 2; forward:  0.40611039267645943
torch index_add_; active dim 2; forward:  0.7698346243964301
cpp; active dim 2; forward:  0.06612857182820638


In [8]:
times = benchmark_backward(BATCH_SIZE, N_FEATURES, 0, 
                           get_func_fixed_dim(sparse_accumulation_loops, 0), 10)
print("python loops; active dim 0; backward: ", np.mean(times[1:]))
times = benchmark_backward(BATCH_SIZE, N_FEATURES, 0, 
                           get_func_fixed_dim(sparse_accumulation_index_add, 0), 10)
print("torch index_add_; active dim 0; backward: ", np.mean(times[1:]))
times = benchmark_backward(BATCH_SIZE, N_FEATURES, 0,
                           sparse_accumulation_active_dim_first.SparseAccumulationActiveDimFirst.apply, 10)
print("cpp; active dim 0; backward: ", np.mean(times[1:]))

print()

times = benchmark_backward(BATCH_SIZE, N_FEATURES, 1, 
                           get_func_fixed_dim(sparse_accumulation_loops, 1), 10)
print("python loops; active dim 1; backward: ", np.mean(times[1:]))
times = benchmark_backward(BATCH_SIZE, N_FEATURES, 1, 
                           get_func_fixed_dim(sparse_accumulation_index_add, 1), 10)
print("torch index_add_; active dim 1; backward: ", np.mean(times[1:]))
times = benchmark_backward(BATCH_SIZE, N_FEATURES, 1,
                           sparse_accumulation_active_dim_middle.SparseAccumulationActiveDimMiddle.apply, 10)
print("cpp; active dim 1; backward: ", np.mean(times[1:]))


print()
times = benchmark_backward(BATCH_SIZE, N_FEATURES, 2, 
                           get_func_fixed_dim(sparse_accumulation_index_add, 2), 10)
print("python loops; active dim 2; backward: ", np.mean(times[1:]))
times = benchmark_backward(BATCH_SIZE, N_FEATURES, 2, 
                           get_func_fixed_dim(sparse_accumulation_index_add, 2), 10)
print("torch index_add_; active dim 2; backward: ", np.mean(times[1:]))
times = benchmark_backward(BATCH_SIZE, N_FEATURES, 2, sparse_accumulation.SparseAccumulation.apply, 10)
print("cpp; active dim 2; backward: ", np.mean(times[1:]))





python loops; active dim 0; backward:  2.34902556737264
torch index_add_; active dim 0; backward:  0.3824967013465034
cpp; active dim 0; backward:  0.07314239607916938

python loops; active dim 1; backward:  2.920709212621053
torch index_add_; active dim 1; backward:  0.405303160349528
cpp; active dim 1; backward:  0.06241893768310547

python loops; active dim 2; backward:  1.6288407113817003
torch index_add_; active dim 2; backward:  1.4928550985124376
cpp; active dim 2; backward:  0.1249254544576009
