In [18]:
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import itertools
from tabulate import tabulate
from tqdm.auto import tqdm

from models.bitlinear import BitLinear

In [19]:
def test_forward_speed(layer_class, input_sizes, num_runs=100):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    results = {}

    for size in tqdm(input_sizes):
        in_features = size[-1]
        out_features = size[-1]
        layer = layer_class(in_features, out_features).to(device)
        input_tensor = torch.randn(size).to(device)

        # Warm-up run
        _ = layer(input_tensor)

        # Timing runs
        times = []
        for _ in range(num_runs):
            start_time = time.perf_counter()
            _ = layer(input_tensor)
            end_time = time.perf_counter()
            times.append(end_time - start_time)

        avg_time = sum(times) / num_runs
        results[size] = avg_time

    return results

# Example usage: benchmark every (first dim, last dim) combination below.
# The last dimension is what test_forward_speed uses as the layer width.
input_sizes = [
    (rows, width)
    for rows in (1, 8, 128)
    for width in (64, 640, 6400)
]

# Add more layer classes as needed
layer_classes = [nn.Linear, BitLinear]

# Timings keyed first by class name, then by input shape tuple.
layer_results = {}

for layer_class in layer_classes:
    layer_results[layer_class.__name__] = test_forward_speed(layer_class, input_sizes)

print("Speed comparison matrix:")
headers = ["Input Size"] + [cls.__name__ for cls in layer_classes]

# One row per input shape: the shape label followed by the mean forward time
# (seconds) of each layer class. The timings stay as floats and are formatted
# by tabulate via ``floatfmt``; pre-formatted numeric strings would be
# re-parsed by tabulate and printed with its default "g" format (e.g. 3.4e-05).
table_data = [
    [str(size)] + [layer_results[cls.__name__][size] for cls in layer_classes]
    for size in input_sizes
]

print(tabulate(table_data, headers, tablefmt="grid", floatfmt=".6f"))

# Total measured time per layer class, summed over all input shapes.
# Computed once instead of being re-summed for every cell of the matrix.
total_times = {
    cls.__name__: sum(layer_results[cls.__name__].values())
    for cls in layer_classes
}

# Pairwise ratio matrix: cell (row, column) = total time of the row layer
# divided by total time of the column layer; the diagonal is 1 by definition.
print("Relative total time (row / column):")
ratio_headers = ["Layer"] + [cls.__name__ for cls in layer_classes]
table_data = []
for layer1 in layer_classes:
    row = [layer1.__name__]
    for layer2 in layer_classes:
        if layer1 == layer2:
            row.append(1.0)
        else:
            row.append(total_times[layer1.__name__] / total_times[layer2.__name__])
    table_data.append(row)

# floatfmt keeps two decimals; numeric-looking strings would be re-parsed by
# tabulate and lose their formatting ("1.00" would print as "1").
print(tabulate(table_data, ratio_headers, tablefmt="grid", floatfmt=".2f"))

  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

Speed comparison matrix:
+--------------+----------+-------------+
| Input Size   |   Linear |   BitLinear |
| (1, 64)      |  3.4e-05 |    0.000438 |
+--------------+----------+-------------+
| (1, 640)     |  3.7e-05 |    0.000496 |
+--------------+----------+-------------+
| (1, 6400)    |  3.3e-05 |    0.000529 |
+--------------+----------+-------------+
| (8, 64)      |  3.6e-05 |    0.000451 |
+--------------+----------+-------------+
| (8, 640)     |  3.7e-05 |    0.000463 |
+--------------+----------+-------------+
| (8, 6400)    |  3.8e-05 |    0.000463 |
+--------------+----------+-------------+
| (128, 64)    |  3.3e-05 |    0.00045  |
+--------------+----------+-------------+
| (128, 640)   |  3.8e-05 |    0.000478 |
+--------------+----------+-------------+
| (128, 6400)  |  3.5e-05 |    0.000531 |
+--------------+----------+-------------+
+--------------+----------+-------------+
| Input Size   |   Linear |   BitLinear |
| Linear       |     1    |        0.07 |
+--------