In [43]:
from deepspeed.profiling.flops_profiler import get_model_profile
from utils import calc_bspline_flops
from models import LinearReLU
from models import LinearBSpline
import linspline
from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import train_test_split
import time
import torch

def profile_model(arch, layers, ctrl=3, range_=1, input_size=8, batch_size=10):

    if(arch == "ReLU"):
        model = LinearReLU(layers)
    elif(arch == "BSpline"):
        model = LinearBSpline(layers, ctrl, range_)
    elif(arch == "LSpline"):
        model = linspline.LSplineFromBSpline(LinearBSpline(layers, ctrl, range_).get_layers())

    base_flops, macs, params = get_model_profile(model=model, # model
            input_shape=(batch_size, input_size), # input shape to the model. If specified, the model takes a tensor with this shape as the only positional argument.
            args=None, # list of positional arguments to the model.
            kwargs=None, # dictionary of keyword arguments to the model.
            print_profile=False, #! prints the model graph with the measured profile attached to each module
            detailed=True, # print the detailed profile
            module_depth=-1, # depth into the nested modules, with -1 being the inner most modules
            top_modules=1, # the number of top modules to print aggregated profile
            warm_up=10, # the number of warm-ups before measuring the time of each module
            as_string=True, # print raw numbers (e.g. 1000) or as human-readable strings (e.g. 1k)
            output_file=None, # path to the output file. If None, the profiler prints to stdout.
            ignore_modules=None) # the list of modules to ignore in the profiling
    
    flops = float(base_flops.replace("K", "").strip()) * 1000 / batch_size

    if(arch == "BSpline"):
        flops += calc_bspline_flops(model)

    housing = fetch_california_housing()
    input = torch.tensor(housing.data, dtype=torch.float32)

    start_time = time.perf_counter()
    _ = model(input) # model output is irrelevant
    end_time = time.perf_counter()
    fwd_lat = (end_time - start_time) * 1000 * 1000 / len(input) # per sample latency: seconds -> microsec per input
    fwd_lat = round(fwd_lat, 4)

    return flops, params, fwd_lat

In [51]:
from IPython.display import clear_output
from tabulate import tabulate

print("FLOPS may be inaccurate for LSpline")
print("LSpline params should prob just be = bspline params")

archs_layers = [
    # ("ReLU", [8]),
    # ("ReLU", [24, 8]),
    # ("ReLU", [24, 48, 24, 8]),
    # ("ReLU", [24, 48, 96, 24, 8]),

    # ("BSpline", [8]),
    # ("BSpline", [24, 8]),
    # ("BSpline", [24, 48, 24, 8]),
    # ("BSpline", [24, 48, 96, 24, 8]),

    # ("LSpline", [8]),
    # ("LSpline", [24, 8]),
    # ("LSpline", [24, 48, 24, 8]),
    # ("LSpline", [24, 48, 96, 24, 8]),
]
store = []

for (arch, layers) in archs_layers:
    flops, params, fwd_lat = profile_model(arch, layers)
    store.append([f"{arch} {layers}", params, flops, fwd_lat])
clear_output()

headers = ["Model", "Params", "FLOPs/input", "μs/input"]
print(tabulate(store, headers=headers, tablefmt="grid"))

+-----------------+----------+---------------+------------+
| Model           |   Params |   FLOPs/input |   μs/input |
| ReLU [24, 8]    |      425 |           816 |     0.0604 |
+-----------------+----------+---------------+------------+
| BSpline [24, 8] |      521 |          1936 |     0.2434 |
+-----------------+----------+---------------+------------+
| LSpline [24, 8] |      425 |           784 |     0.5252 |
+-----------------+----------+---------------+------------+
