In [4]:
import numpy as np
import torch
import time
import torch.nn as nn

torch.set_num_threads(1)

from threadpoolctl import threadpool_limits
np.show_config()

openblas64__info:
    libraries = ['openblas64_', 'openblas64_']
    library_dirs = ['/usr/local/lib']
    language = c
    define_macros = [('HAVE_CBLAS', None), ('BLAS_SYMBOL_SUFFIX', '64_'), ('HAVE_BLAS_ILP64', None)]
    runtime_library_dirs = ['/usr/local/lib']
blas_ilp64_opt_info:
    libraries = ['openblas64_', 'openblas64_']
    library_dirs = ['/usr/local/lib']
    language = c
    define_macros = [('HAVE_CBLAS', None), ('BLAS_SYMBOL_SUFFIX', '64_'), ('HAVE_BLAS_ILP64', None)]
    runtime_library_dirs = ['/usr/local/lib']
openblas64__lapack_info:
    libraries = ['openblas64_', 'openblas64_']
    library_dirs = ['/usr/local/lib']
    language = c
    define_macros = [('HAVE_CBLAS', None), ('BLAS_SYMBOL_SUFFIX', '64_'), ('HAVE_BLAS_ILP64', None), ('HAVE_LAPACKE', None)]
    runtime_library_dirs = ['/usr/local/lib']
lapack_ilp64_opt_info:
    libraries = ['openblas64_', 'openblas64_']
    library_dirs = ['/usr/local/lib']
    language = c
    define_macros = [('HAVE_CBLAS', None

In [2]:
def perf_matmul(ashape, bshape):
    A = torch.rand(ashape)
    B = torch.rand(bshape)
    start = time.time()
    C = torch.matmul(A, B)
    print(C.size())
    print(time.time() - start)
    
def np_perf(ashape, bshape):
    A = np.random.randn(*ashape)
    B = np.random.randn(*bshape)
    start = time.time()
    with threadpool_limits(limits=1, user_api='blas'):
        C = A @ B
        print(C.shape)
    print(time.time() - start)

In [3]:
perf_matmul((1008, 64, 512), (512, 1536))

np_perf((1008, 64, 512), (512, 1536))

torch.Size([1008, 64, 1536])
2.458440065383911
(1008, 64, 1536)
5.642819404602051


In [4]:
perf_matmul((128, 1008, 1008), (128, 1008, 256))

np_perf((128, 1008, 1008), (128, 1008, 256))

torch.Size([128, 1008, 256])
1.9739768505096436
(128, 1008, 256)
3.0647637844085693


In [3]:
perf_matmul((64512, 512), (512, 512))

np_perf((64512, 512), (512, 512))

torch.Size([64512, 512])
0.7357625961303711
(64512, 512)
1.724242925643921


In [6]:
perf_matmul((1008, 64, 512), (512, 2048))

np_perf((1008, 64, 512), (512, 2048))

torch.Size([1008, 64, 2048])
2.8759877681732178
(1008, 64, 2048)
7.270247936248779


In [7]:
perf_matmul((1008, 64, 2048), (2048, 512))
np_perf((1008, 64, 2048), (2048, 512))

torch.Size([1008, 64, 512])
2.827690601348877
(1008, 64, 512)
7.1652514934539795


In [8]:
w = torch.nn.Linear(512, 2048)

x = torch.rand((64512, 512))

In [9]:
%%time

y = w(x)
y.size()

CPU times: user 2.96 s, sys: 660 ms, total: 3.62 s
Wall time: 3.62 s


torch.Size([64512, 2048])

In [10]:
input_names = ["x"]
output_names = ["y"]
dummy_input = x
torch.onnx.export(w,
                  dummy_input,
                  "./gemm.onnx",
                  verbose=False,
                  input_names=input_names,
                  output_names=output_names)