In [6]:
import time
import torch
from clebsch_gordan import get_real_clebsch_gordan, ClebschGordan
from sparse_accumulation_plain_torch import sparse_accumulation_loops, sparse_accumulation_index_add
import sparse_accumulation
import numpy as np

In [7]:
# Maximum angular momentum used throughout the benchmark; it is passed for all
# three angular-momentum arguments below, and output tensors later use
# 2 * L_MAX + 1 components along their last axis.
L_MAX = 8
# Precomputed Clebsch-Gordan table; indexed below as clebsch[L_MAX, L_MAX, L_MAX].
# NOTE(review): presumably indexed as [l1, l2, l] — confirm in clebsch_gordan.
clebsch = ClebschGordan(L_MAX).precomputed_
# The next cell iterates indices[mu] for mu in range(2 * L_MAX + 1) and unpacks
# each entry as an (m1, m2, multiplier) triple.
# NOTE(review): presumably these are the nonzero real-basis coefficients —
# confirm against get_real_clebsch_gordan.
indices = get_real_clebsch_gordan(clebsch[L_MAX, L_MAX, L_MAX], L_MAX, L_MAX, L_MAX)

In [8]:
# Flatten indices[mu] (lists of (m1, m2, multiplier) triples) into four
# parallel 1-D tensors with one entry per triple; mu_aligned records which
# output component mu each triple belongs to.
m1_list, m2_list, multiplier_list, mu_list = [], [], [], []
for mu in range(0, 2 * L_MAX + 1):
    for m1, m2, multiplier in indices[mu]:
        m1_list.append(m1)
        m2_list.append(m2)
        multiplier_list.append(multiplier)
        mu_list.append(mu)

# torch.tensor with an explicit dtype replaces the legacy torch.LongTensor /
# torch.FloatTensor constructors and yields the same int64 / float32 tensors.
m1_aligned = torch.tensor(m1_list, dtype=torch.long)
m2_aligned = torch.tensor(m2_list, dtype=torch.long)
mu_aligned = torch.tensor(mu_list, dtype=torch.long)
multipliers = torch.tensor(multiplier_list, dtype=torch.float32)

# The sparse-accumulation kernels consume these arrays in lockstep.
assert m1_aligned.shape == m2_aligned.shape == mu_aligned.shape == multipliers.shape

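## Forward pass

Each implementation is run 10 times on the same inputs; the first run is treated as a warm-up and excluded, and the mean of the remaining runs is printed in seconds.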

In [9]:
# Random inputs shared by every benchmark below: shape
# (BATCH_SIZE, N_FEATURES, 2 * L_MAX + 1).
BATCH_SIZE = 1000
N_FEATURES = 100
torch.manual_seed(0)  # reproducible benchmark inputs across kernel restarts
X1 = torch.randn(BATCH_SIZE, N_FEATURES, 2 * L_MAX + 1)
X2 = torch.randn(BATCH_SIZE, N_FEATURES, 2 * L_MAX + 1)

print("python loops implementation forward:")
times = []
for _ in range(10):
    # perf_counter is monotonic and high-resolution; time.time() is a wall
    # clock that can jump and is coarser on some platforms.
    begin = time.perf_counter()
    output = sparse_accumulation_loops(X1, X2, mu_aligned, 2 * L_MAX + 1, m1_aligned, m2_aligned, multipliers)
    times.append(time.perf_counter() - begin)
# The first run is treated as a warm-up and excluded from the mean.
print(np.mean(times[1:]))

python loops implementation forward:
0.31460971302456325


In [10]:
# Same inputs and protocol as the python-loops forward benchmark.
print("pytorch index_add_ implementation forward:")
times = []
for _ in range(10):
    # perf_counter: monotonic, high-resolution clock intended for benchmarking.
    begin = time.perf_counter()
    output = sparse_accumulation_index_add(X1, X2, mu_aligned, 2 * L_MAX + 1, m1_aligned, m2_aligned, multipliers)
    times.append(time.perf_counter() - begin)
# The first run is treated as a warm-up and excluded from the mean.
print(np.mean(times[1:]))

pytorch index_add_ implementation forward:
0.338148381974962


In [11]:
# Same inputs and protocol as the other forward benchmarks, calling the
# compiled extension through its autograd Function.
print("cpp implementation forward:")
times = []
for _ in range(10):
    # perf_counter: monotonic, high-resolution clock intended for benchmarking.
    begin = time.perf_counter()
    output = sparse_accumulation.SparseAccumulation.apply(X1, X2, mu_aligned,
                                                          2 * L_MAX + 1, m1_aligned, m2_aligned, multipliers)
    times.append(time.perf_counter() - begin)
# The first run is treated as a warm-up and excluded from the mean.
print(np.mean(times[1:]))

cpp implementation forward:
0.08500968085394965


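## Backward pass

Each timed iteration runs the forward call and then `backward`, so the numbers below are forward + backward time, not backward alone.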

In [12]:
# Enable gradient tracking on the shared inputs. This mutates X1/X2 in place:
# re-running the forward cells afterwards will also record autograd graphs.
X1.requires_grad = True
X2.requires_grad = True

print("python loops implementation backward:")
times = []
for _ in range(10):
    # Clear gradients so each backward stores a fresh gradient instead of
    # accumulating into the previous one (an extra in-place add inside the
    # timed region, and a meaningless running sum left in X1.grad / X2.grad).
    X1.grad = None
    X2.grad = None
    # The forward call is inside the timed region: this measures forward + backward.
    begin = time.perf_counter()
    output = sparse_accumulation_loops(X1, X2, mu_aligned, 2 * L_MAX + 1, m1_aligned, m2_aligned, multipliers)
    output.backward(gradient=torch.ones_like(output))
    times.append(time.perf_counter() - begin)
# The first run is treated as a warm-up and excluded from the mean.
print(np.mean(times[1:]))

python loops implementation backward:
4.437259833017985


In [13]:
# Ensure gradient tracking is on even if this cell is run out of order;
# otherwise backward() would fail or time a different code path.
X1.requires_grad_(True)
X2.requires_grad_(True)

print("pytorch index_add_ implementation backward:")
times = []
for _ in range(10):
    # Reset gradients so they do not accumulate across iterations/cells.
    X1.grad = None
    X2.grad = None
    # The forward call is inside the timed region: this measures forward + backward.
    begin = time.perf_counter()
    output = sparse_accumulation_index_add(X1, X2, mu_aligned, 2 * L_MAX + 1, m1_aligned, m2_aligned, multipliers)
    output.backward(gradient=torch.ones_like(output))
    times.append(time.perf_counter() - begin)
# The first run is treated as a warm-up and excluded from the mean.
print(np.mean(times[1:]))

pytorch index_add_ implementation backward:
0.8081594838036431


In [14]:
# Ensure gradient tracking is on even if this cell is run out of order.
X1.requires_grad_(True)
X2.requires_grad_(True)

print("cpp implementation backward:")
times = []
for _ in range(10):
    # Reset gradients so they do not accumulate across iterations/cells.
    X1.grad = None
    X2.grad = None
    # The forward call is inside the timed region: this measures forward + backward.
    begin = time.perf_counter()
    output = sparse_accumulation.SparseAccumulation.apply(X1, X2, mu_aligned,
                                                            2 * L_MAX + 1, m1_aligned, m2_aligned, multipliers)
    output.backward(gradient=torch.ones_like(output))
    times.append(time.perf_counter() - begin)
# The first run is treated as a warm-up and excluded from the mean.
print(np.mean(times[1:]))

cpp implementation backward:
0.20572347111172146
