In [None]:
import torch

import numpy as np

import time

In [None]:
def to_sparse(x):
    """Convert a dense tensor to a sparse COO tensor.

    Only the non-zero entries of ``x`` are stored. The result has the same
    shape, dtype and device as ``x``. An all-zero input yields a sparse tensor
    with no stored entries.

    Args:
        x: dense ``torch.Tensor`` with at least one dimension.

    Returns:
        Sparse COO ``torch.Tensor`` whose ``to_dense()`` equals ``x``.
    """
    # nonzero() gives an (nnz, ndim) coordinate matrix. It is (0, ndim), not
    # 0-dimensional, when x has no non-zeros, so no special case is needed.
    indices = torch.nonzero(x).t()
    # Gather the value stored at each coordinate column of `indices`.
    values = x[tuple(indices)]
    # sparse_coo_tensor replaces the deprecated torch.sparse.<Type>Tensor
    # constructors (which were looked up via torch.typename and always
    # produced CPU tensors) and keeps the input's dtype and device.
    return torch.sparse_coo_tensor(indices, values, x.size(),
                                   dtype=x.dtype, device=x.device)

In [None]:
# Time dense vs. sparse matrix-vector products while the fraction of zero
# entries in a random N x N matrix sweeps from 99.8% to 100%.
N = 1000          # A is N x N, x is N x 1
N_LEVELS = 101    # number of sparsity levels in the sweep
N_REPEATS = 1000  # products averaged per timing measurement
SEED = 0

# Seed both RNGs so the random matrices and zero patterns are reproducible.
torch.manual_seed(SEED)
np.random.seed(SEED)

x = torch.rand((N, 1))

sparse_time = torch.zeros(N_LEVELS)
dense_time = torch.zeros(N_LEVELS)


def _mean_seconds(op, a, b, n_repeats):
    """Return the mean wall-clock seconds of ``op(a, b)`` over ``n_repeats`` calls.

    One untimed call is made first so any first-call overhead is excluded.
    time.perf_counter() is used because it is monotonic and high-resolution,
    unlike time.time().
    """
    op(a, b)
    start = time.perf_counter()
    for _ in range(n_repeats):
        op(a, b)
    return (time.perf_counter() - start) / n_repeats


for i, beta in enumerate(np.linspace(0.998, 1, N_LEVELS)):
    # beta is the fraction of A's entries forced to zero.
    A = torch.rand(N * N)
    # round() instead of int() truncation: N*N*beta can land just below an integer.
    n_zero = int(round(N * N * beta))
    indices = np.random.choice(N * N, n_zero, replace=False)
    A[indices] = 0.

    dense_A = A.reshape(N, N)
    sparse_A = to_sparse(dense_A)

    dense_time[i] = _mean_seconds(torch.mm, dense_A, x, N_REPEATS)
    sparse_time[i] = _mean_seconds(torch.sparse.mm, sparse_A, x, N_REPEATS)

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Mean time per product vs. percentage of zero entries. The x-axis mirrors the
# benchmark's np.linspace(0.998, 1, ...) sweep of beta, expressed in percent.
sparsity_pct = np.linspace(99.8, 100, len(dense_time))
fig, ax = plt.subplots()
ax.plot(sparsity_pct, dense_time.numpy(), label="dense operation")
ax.plot(sparsity_pct, sparse_time.numpy(), label="sparse operation")
ax.set(xlabel="zero entries (%)",
       ylabel="mean time per matrix-vector product (s)",
       title="Dense vs. sparse matrix-vector product time")
ax.legend();

In [None]:
# Scratch arithmetic (displays 512); unrelated to the benchmark above.
# NOTE(review): leftover exploration cell -- consider deleting it.
1 << 9