In [5]:
import torch
import torch.nn as nn
import time

def einsum(a,b): return torch.einsum('ij, jk -> ik', a, b)


X = torch.exp(1j * 2 * 3.1426 * torch.randn(15000, 15000, device='cpu'))
torch.manual_seed(1)

t0 = time.time()
netofmodel = nn.Linear(in_features=15000, out_features=1, bias=False, device='cpu', dtype=torch.cfloat)
Y = netofmodel(X)
t1 = time.time()

print("\nX * X^T [ (15000 x 15000) * (15000 x 15000)^T matrice di valori complessi] \n ")
print("Time using Neural Network Fully Connected: ", t1 - t0, " [sec]")
print("First element:", X[0][0])
print("Last element:", X[14999][14999])

t0 = time.time()
Y = einsum(X,X)
t1 = time.time()

print("\nTime using torch.matmul default function: ", t1 - t0, " [sec]")

print("First element:", X[0][0])
print("Last element:", X[14999][14999])

t0 = time.time()
Y = torch.matmul(X,X)
t1 = time.time()

print("\nTime using einsum optimized** function: ", t1 - t0, " [sec]")
print("First element:", X[0][0])
print("Last element:", X[14999][14999])

print("\n** Einsum reference: https://arxiv.org/pdf/2204.06045.pdf \n")


X * X^T [ (15000 x 15000) * (15000 x 15000)^T matrice di valori complessi] 
 
Time using Neural Network Fully Connected:  0.7437911033630371  [sec]
First element: tensor(-0.7220+0.6919j)
Last element: tensor(0.8627+0.5058j)

Time using torch.matmul default function:  40.42633318901062  [sec]
First element: tensor(-0.7220+0.6919j)
Last element: tensor(0.8627+0.5058j)

Time using einsum optimized** function:  46.65758204460144  [sec]
First element: tensor(-0.7220+0.6919j)
Last element: tensor(0.8627+0.5058j)

** Einsum reference: https://arxiv.org/pdf/2204.06045.pdf 

