## Performance Comparison of `torch.einsum` and `opt_einsum`

### *Why Compare?*

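When working with large tensors and complex contraction patterns, the speed of tensor contractions can dominate the overall running time of your code. `torch.einsum` (part of PyTorch) and `opt_einsum` (a standalone library) both evaluate Einstein-summation expressions, but they use different strategies. `opt_einsum` searches for an efficient order in which to contract the operands pairwise and hands each pairwise step to the array backend (for PyTorch tensors, PyTorch itself). `torch.einsum` evaluates the expression directly with PyTorch's own kernels, although recent PyTorch versions can also use `opt_einsum`, when it is installed, to choose the contraction order. With only two operands, as in the example below, there is a single possible contraction order, so any difference mostly reflects per-call overhead. Comparing the two on representative shapes and hardware lets you pick the faster option for your workload.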
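To see what `opt_einsum` actually plans for an expression, `opt_einsum.contract_path` returns the chosen pairwise contraction order together with a printable `PathInfo` report. The sketch below inspects the two-operand pattern used in this comparison; the shapes are arbitrary illustrative choices. With two operands the only possible path is a single contraction of operands 0 and 1.

```python
import opt_einsum as oe
import torch

# Illustrative shapes (H, M, N) = (10, 20, 30); any compatible sizes work
A = torch.randn(10, 20, 30)
B = torch.randn(10, 30)

# contract_path returns the pairwise contraction order and a PathInfo report
path, path_info = oe.contract_path("hmn,hn->hm", A, B)

print(path)       # list of operand-index tuples, one per pairwise contraction
print(path_info)  # summary: scaling, estimated FLOP count, largest intermediate
```

For expressions with three or more operands, the same call shows how much the optimized order saves compared with a naive left-to-right evaluation.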

### *Comparison Example*
#### Contraction Pattern
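   $$C_{hm} = \sum_{n=1}^{N} A_{hmn} B_{hn}, \qquad h = 1, \dots, H, \quad m = 1, \dots, M$$
   - i.e., `hmn,hn->hm` in einsum notation: $n$ is summed over, while $h$ and $m$ are free indices that appear in the output.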
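As a quick check of what the pattern computes, the sketch below (with small, arbitrary sizes) evaluates it three ways: with `torch.einsum`, with an explicit broadcast-and-sum over $n$, and with `torch.bmm`. For each fixed $h$, $C_{hm} = \sum_n A_{hmn} B_{hn}$ is a matrix-vector product of the $M \times N$ slice of $A$ with the length-$N$ vector $B_h$, which is why a batched matrix multiply gives the same result.

```python
import torch

H, M, N = 3, 4, 5  # small illustrative sizes
A = torch.randn(H, M, N)
B = torch.randn(H, N)

# einsum form: sum over n, keep h and m
C_einsum = torch.einsum("hmn,hn->hm", A, B)

# Explicit form: broadcast B across m, multiply element-wise, then sum over n
C_explicit = (A * B[:, None, :]).sum(dim=-1)

# Batched matrix-vector product: for each h, (M x N) @ (N x 1) -> (M x 1)
C_bmm = torch.bmm(A, B.unsqueeze(-1)).squeeze(-1)

print(C_einsum.shape)  # torch.Size([3, 4])
print(torch.allclose(C_einsum, C_explicit, atol=1e-6))
print(torch.allclose(C_einsum, C_bmm, atol=1e-6))
```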

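#### Tensor Size
   - $H \in \{3, 10, 30, 100, 300, 1000\}$
   - $M \in \{4, 10, 30, 100, 300, 1000\}$
   - $N \in \{4, 10, 30, 100, 300, 1000\}$
   - In the code below, $H$ and $N$ are paired element-wise rather than swept as a full grid, $M$ is set equal to $N$, and only the first three pairs are run; the larger sizes are commented out.
#### Iterations
   - 1000 calls per timing measurement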
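The size lists above span $6 \times 6 \times 6 = 216$ possible $(H, M, N)$ combinations, and the largest ones are expensive: a float32 tensor $A$ of shape $1000 \times 1000 \times 1000$ holds $10^9$ elements, i.e. $4 \times 10^9$ bytes (4 GB), and each contraction performs $H \cdot M \cdot N = 10^9$ multiply-adds. The sketch below enumerates the full grid with `itertools.product` and estimates these costs per configuration. It assumes float32 (4 bytes per element, the default dtype of `torch.randn`) and is only a planning aid, not part of the benchmark itself.

```python
import itertools

# Size sets from the specification above
H_values = [3, 10, 30, 100, 300, 1000]
M_values = [4, 10, 30, 100, 300, 1000]
N_values = [4, 10, 30, 100, 300, 1000]

BYTES_PER_FLOAT32 = 4  # assumption: default torch.randn dtype (float32)

# Full Cartesian grid of (H, M, N) configurations
grid = list(itertools.product(H_values, M_values, N_values))


def a_bytes(h: int, m: int, n: int) -> int:
    """Memory needed for A with shape (h, m, n) in float32."""
    return h * m * n * BYTES_PER_FLOAT32


def multiply_adds(h: int, m: int, n: int) -> int:
    """Each of the h*m outputs C[h, m] sums n products."""
    return h * m * n


largest = max(grid, key=lambda s: a_bytes(*s))
print(len(grid))                   # 216
print(largest, a_bytes(*largest))  # (1000, 1000, 1000) 4000000000
print(multiply_adds(*largest))     # 1000000000
```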




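```python
import statistics
import timeit

import torch
from opt_einsum import contract

# Tensor sizes to benchmark. Larger sizes are commented out to keep the run short.
# H and N are paired element-wise by zip(), and M is set equal to N.
H_values = [3, 10, 30]  # , 100, 300, 1000]
N_values = [4, 10, 30]  # , 100, 300, 1000]
iterations = 1000  # calls per timing measurement
repeats = 5        # independent measurements per size (used for the error bars)
contraction_pattern = "hmn,hn->hm"

# Mean and standard deviation (over `repeats`) of the total time for `iterations` calls
torch_einsum_times, torch_einsum_stds = [], []
opt_einsum_times, opt_einsum_stds = [], []

for H, N in zip(H_values, N_values):
    M = N
    A = torch.randn(H, M, N)
    B = torch.randn(H, N)

    # Warm up both functions so one-time setup costs are not timed
    torch.einsum(contraction_pattern, A, B)
    contract(contraction_pattern, A, B)

    torch_runs = timeit.repeat(
        lambda: torch.einsum(contraction_pattern, A, B),
        number=iterations,
        repeat=repeats,
    )
    opt_runs = timeit.repeat(
        lambda: contract(contraction_pattern, A, B),
        number=iterations,
        repeat=repeats,
    )

    torch_einsum_times.append(statistics.mean(torch_runs))
    torch_einsum_stds.append(statistics.stdev(torch_runs))
    opt_einsum_times.append(statistics.mean(opt_runs))
    opt_einsum_stds.append(statistics.stdev(opt_runs))
```

```python
import matplotlib.pyplot as plt
import seaborn as sns

sns.set_style("whitegrid")
colors = sns.color_palette("husl", 2)

# Mean time per size, with one standard deviation across repeats as error bars
plt.errorbar(
    x=H_values,
    y=torch_einsum_times,
    yerr=torch_einsum_stds,
    label="torch.einsum",
    capsize=5,
    fmt="-o",
    markersize=5,
    color=colors[0],
)
plt.errorbar(
    x=H_values,
    y=opt_einsum_times,
    yerr=opt_einsum_stds,
    label="opt_einsum.contract",
    capsize=5,
    fmt="-o",
    markersize=5,
    color=colors[1],
)

plt.xscale("log")  # sizes grow roughly geometrically
plt.xlabel("Tensor size H (paired with N = M)", fontsize=12)
plt.ylabel(f"Total time for {iterations} calls (s)", fontsize=12)
plt.title(f"Performance comparison: {contraction_pattern}", fontsize=14)

plt.legend(fontsize=10)
plt.grid(True, linestyle="--", alpha=0.7)
plt.show()
```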
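As an alternative to `timeit`, PyTorch ships `torch.utils.benchmark.Timer`, which is designed for timing PyTorch code: it performs warm-up runs, fixes the thread count, and synchronizes asynchronous CUDA work when necessary. The sketch below swaps it in for a single illustrative size ($H = M = N = 30$, an arbitrary choice) as one possible cross-check of the `timeit` numbers; it is not the method used in the benchmark. `Timer` runs with one thread by default, so the sketch passes `num_threads=torch.get_num_threads()` to match what `timeit` sees.

```python
import torch
import torch.utils.benchmark as benchmark
from opt_einsum import contract

# One illustrative size (an assumption; any compatible shapes work)
H, M, N = 30, 30, 30
A = torch.randn(H, M, N)
B = torch.randn(H, N)
pattern = "hmn,hn->hm"
threads = torch.get_num_threads()  # Timer defaults to 1 thread otherwise

t_torch = benchmark.Timer(
    stmt="torch.einsum(pattern, A, B)",
    globals={"torch": torch, "pattern": pattern, "A": A, "B": B},
    num_threads=threads,
    label=pattern,
    description="torch.einsum",
)
t_opt = benchmark.Timer(
    stmt="contract(pattern, A, B)",
    globals={"contract": contract, "pattern": pattern, "A": A, "B": B},
    num_threads=threads,
    label=pattern,
    description="opt_einsum.contract",
)

# timeit(number) returns a Measurement; .mean is the average time per call in seconds
m_torch = t_torch.timeit(1000)
m_opt = t_opt.timeit(1000)
print(m_torch)
print(m_opt)
print(f"per-call ratio (opt_einsum / torch.einsum): {m_opt.mean / m_torch.mean:.2f}")
```

Unlike the `timeit` totals, `Measurement.mean` is already a per-call time, so divide the `timeit` results by `iterations` before comparing the two.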
