In [1]:
import os
import torch.nn as nn
import torch
import time
import numpy as np
import matplotlib.pyplot as plt


def gen_linear(inc, oc):
    """Build the dense baseline layer used as the latency reference.

    Args:
        inc: Number of input features.
        oc: Number of output features.

    Returns:
        A freshly initialised ``nn.Linear`` mapping ``inc`` features to ``oc``
        features, with PyTorch's default settings (bias enabled).
    """
    dense = nn.Linear(in_features=inc, out_features=oc)
    return dense


def gen_svd_linear(inc, oc, k):
    """Build a rank-``k`` factorised stand-in for an ``inc -> oc`` linear layer.

    The result is two stacked linear layers, ``inc -> k`` followed by
    ``k -> oc``. No SVD is computed here: both layers are randomly
    initialised, so only the shape (and therefore the cost) of a low-rank
    layer is modelled, not its values.

    Both layers keep the default bias, so the total parameter count is
    ``k * (inc + oc) + k + oc``.

    Args:
        inc: Number of input features.
        oc: Number of output features.
        k: Width of the intermediate (bottleneck) dimension.

    Returns:
        An ``nn.Sequential`` holding the two linear layers in application order.
    """
    return nn.Sequential(
        nn.Linear(inc, k),  # project down to the rank-k bottleneck
        nn.Linear(k, oc),   # project back up to the output width
    )




In [11]:
def _mean_forward_latency(module, x, n_iters, n_warmup):
    """Return the mean wall-clock seconds of one ``module(x)`` forward pass.

    Args:
        module: Callable module to time.
        x: Input tensor passed to ``module`` on every call.
        n_iters: Number of timed forward passes that are averaged.
        n_warmup: Number of untimed forward passes run first, so one-off
            start-up costs are not charged to the measurement.

    Returns:
        Elapsed time of the timed loop divided by ``n_iters``.
    """
    # Inference-only benchmark: skip autograd bookkeeping so it is not timed.
    with torch.no_grad():
        for _ in range(n_warmup):
            module(x)
        # perf_counter is monotonic and high resolution, unlike time.time().
        start = time.perf_counter()
        for _ in range(n_iters):
            module(x)
        elapsed = time.perf_counter() - start
    return elapsed / n_iters


def _latency_ratios(inc, oc, param_ratios, batch, n_iters, n_warmup):
    """Measure factorised-vs-dense latency ratios for one layer shape.

    For every entry of ``param_ratios`` a two-layer factorised module is built
    whose bottleneck width ``k`` satisfies ``k * (inc + oc) ~= ratio * inc * oc``
    (weight parameters only, biases ignored). The dense layer is re-timed next
    to each factorised module so both numbers of a pair are taken under
    similar machine conditions.

    Args:
        inc: Input feature count of the layer.
        oc: Output feature count of the layer.
        param_ratios: Iterable of target parameter ratios.
        batch: Number of rows in the random input.
        n_iters: Timed forward passes per measurement.
        n_warmup: Untimed forward passes before each measurement.

    Returns:
        List of ``svd_time / raw_time``, one entry per parameter ratio.
    """
    raw_linear = gen_linear(inc, oc)
    latency_ratios = []
    for param_ratio in param_ratios:
        k = int(inc * oc * param_ratio / (inc + oc))
        svd_linear = gen_svd_linear(inc, oc, k)
        x = torch.randn(batch, inc)

        # Both modules get the same warm-up; previously only the dense layer
        # was warmed, which charged cold-start cost to the factorised one.
        raw_time = _mean_forward_latency(raw_linear, x, n_iters, n_warmup)
        svd_time = _mean_forward_latency(svd_linear, x, n_iters, n_warmup)
        print(f"param_ratio={param_ratio} raw_time={raw_time} svd_time={svd_time}")
        latency_ratios.append(svd_time / raw_time)
    return latency_ratios


# Benchmark configuration.
ratios = [0.75, 0.8, 0.85, 0.9, 0.95]
N = 100          # timed forward passes per measurement
N_WARMUP = 10    # untimed forward passes before each measurement
batchsize = 1

# Latency ratios keyed by CPU thread count; each value is a list aligned with
# `ratios`. The two shapes were labelled "MHA" and "MLP" in the original
# (disabled) plotting code.
results_mha = {}
results_mlp = {}
for num_threads in [1, 4, 16]:
    # NOTE: the thread setting is process-wide and stays at the last value
    # once the loop finishes.
    torch.set_num_threads(num_threads)
    results_mha[num_threads] = _latency_ratios(4096, 4096, ratios, batchsize, N, N_WARMUP)
    results_mlp[num_threads] = _latency_ratios(4096, 11008, ratios, batchsize, N, N_WARMUP)




param_ratio=0.75 raw_time=0.005185337066650391 svd_time=0.004072341918945312
param_ratio=0.8 raw_time=0.005265669822692871 svd_time=0.004499285221099853
param_ratio=0.85 raw_time=0.005416505336761475 svd_time=0.004612176418304443
param_ratio=0.9 raw_time=0.005344386100769043 svd_time=0.004829003810882569
param_ratio=0.95 raw_time=0.005350379943847656 svd_time=0.005088210105895996
param_ratio=0.75 raw_time=0.014723942279815674 svd_time=0.011187601089477538
param_ratio=0.8 raw_time=0.01471515655517578 svd_time=0.01181518316268921
param_ratio=0.85 raw_time=0.014762394428253174 svd_time=0.01264045000076294
param_ratio=0.9 raw_time=0.015127789974212647 svd_time=0.01362471580505371
param_ratio=0.95 raw_time=0.014756803512573241 svd_time=0.014102323055267334
param_ratio=0.75 raw_time=0.0016720247268676757 svd_time=0.0012467050552368165
param_ratio=0.8 raw_time=0.0015723776817321777 svd_time=0.0013126659393310547
param_ratio=0.85 raw_time=0.0017038536071777343 svd_time=0.00142578125
param_rati

<Figure size 800x600 with 0 Axes>

<Figure size 800x600 with 0 Axes>

<Figure size 800x600 with 0 Axes>

In [12]:

# Dump the results as comma-separated text: one header line with the
# parameter ratios, then one row per thread count for the MHA-shaped results
# followed by one row per thread count for the MLP-shaped results. Each row
# starts with the thread count and every value is followed by a comma.
print(",".join(map(str, ratios)))
for table in (results_mha, results_mlp):
    for num_threads in (1, 4, 16):
        cells = "".join(f"{value:.2f}," for value in table[num_threads])
        print(f"{num_threads}," + cells)

0.75,0.8,0.85,0.9,0.95
1,0.79,0.85,0.85,0.90,0.95,
4,0.75,0.83,0.84,0.92,1.00,
16,0.70,0.77,0.85,0.88,0.98,
1,0.76,0.80,0.86,0.90,0.96,
4,0.74,0.82,0.87,0.91,0.96,
16,0.73,0.78,0.67,0.93,0.96,


In [13]:
# Write each result table to a CSV file: rows are indexed by parameter ratio,
# columns are the thread counts used as keys of the result dicts.
import pandas as pd

for csv_name, table in (("mha.csv", results_mha), ("mlp.csv", results_mlp)):
    df = pd.DataFrame(table, index=ratios)
    df.to_csv(csv_name)