In [14]:
from deepspeed.profiling.flops_profiler import get_model_profile
import utils, importlib
from models import LinearReLU, LinearBSpline
import linspline
from sklearn.datasets import fetch_california_housing
import time, torch
from torch.utils.data import TensorDataset, DataLoader

importlib.reload(utils)

def _parse_profiler_count(value):
    """Convert a profiler count such as "1.28 K" or "432" into a plain float.

    Accepts an optional trailing magnitude suffix (K, M, G or T); a value with
    no suffix is taken as-is. Numeric inputs are passed through as floats.

    Raises:
        ValueError: if the text is not a number with an optional suffix.
    """
    if isinstance(value, (int, float)):
        return float(value)

    multipliers = {"K": 10**3, "M": 10**6, "G": 10**9, "T": 10**12}
    text = str(value).strip()
    if text and text[-1] in multipliers:
        return float(text[:-1].strip()) * multipliers[text[-1]]
    return float(text)


def profile_model(arch, layers, ctrl=3, range_=1, input_size=8, batch_size=10):
    """Profile one model: parameter count, FLOPs per input and forward latency.

    Args:
        arch: one of "ReLU", "BSpline" or "LSpline".
        layers: layer specification forwarded to the model constructor.
        ctrl: forwarded to LinearBSpline (unused for "ReLU").
        range_: forwarded to LinearBSpline (unused for "ReLU").
        input_size: width of the input fed to the model. The real-data timing
            feeds California-housing feature rows to the model, so this should
            match that dataset's feature count.
        batch_size: batch size used for both profiling and timing.

    Returns:
        (flops, params, fwd_lat_real) where
        flops is the FLOP count per single input (float), including the
            spline correction from utils for "BSpline"/"LSpline";
        params is the parameter count exactly as returned by
            get_model_profile(as_string=True);
        fwd_lat_real is the forward latency in microseconds per input on
            California-housing batches, rounded to 2 decimals.

    Raises:
        ValueError: if arch is not recognised, or the profiler's FLOP string
            cannot be parsed.
    """
    # Fail fast on an unknown architecture instead of hitting an unbound
    # `model` further down.
    if arch == "ReLU":
        model = LinearReLU(layers)
    elif arch == "BSpline":
        model = LinearBSpline(layers, ctrl, range_)
    elif arch == "LSpline":
        model = linspline.LSplineFromBSpline(LinearBSpline(layers, ctrl, range_).get_layers())
    else:
        raise ValueError(
            f"Unknown arch {arch!r}; expected 'ReLU', 'BSpline' or 'LSpline'"
        )

    #^ FLOPs and Params

    base_flops, macs, params = get_model_profile(model=model, # model
            input_shape=(batch_size, input_size), # input shape to the model. If specified, the model takes a tensor with this shape as the only positional argument.
            args=None, # list of positional arguments to the model.
            kwargs=None, # dictionary of keyword arguments to the model.
            print_profile=False, #! prints the model graph with the measured profile attached to each module
            detailed=True, # print the detailed profile
            module_depth=-1, # depth into the nested modules, with -1 being the inner most modules
            top_modules=1, # the number of top modules to print aggregated profile
            warm_up=10, # the number of warm-ups before measuring the time of each module
            as_string=True, # print raw numbers (e.g. 1000) or as human-readable strings (e.g. 1k)
            output_file=None, # path to the output file. If None, the profiler prints to stdout.
            ignore_modules=None) # the list of modules to ignore in the profiling

    # The profiler reports FLOPs for the whole batch as a human-readable
    # string; convert to a number (any of no-suffix/K/M/G/T) and normalise
    # to a single input.
    flops = _parse_profiler_count(base_flops) / batch_size

    if arch == "BSpline":
        flops += utils.calc_bspline_flops(model) #! this might be incorrect- need to check w the implementation
    if arch == "LSpline":
        flops += utils.calc_lspline_flops(model)

    housing = fetch_california_housing()

    #^ Forward Latency
    # NOTE(review): both timing loops run with autograd enabled, so the
    # measured time includes whatever graph bookkeeping the model triggers.

    per = 1000 # number of sims

    # Synthetic inputs, uniform in [-1.5, 1.5). Built up front so tensor
    # creation is excluded from the timed region.
    inpts = [ torch.rand(batch_size, input_size) * 3 - 1.5 for _ in range(per) ]
    start_time = time.perf_counter()
    for i in range(per):
        _ = model(inpts[i]) # model output is irrelevant
    end_time = time.perf_counter()
    # Not returned; kept so the synthetic pass still runs before the
    # real-data measurement exactly as before.
    fwd_lat_sim = round((end_time - start_time) * 1000 * 1000 / per / batch_size, 2) # per sample latency: seconds -> microsec per input

    X, y = torch.tensor(housing.data, dtype=torch.float32), torch.tensor(housing.target, dtype=torch.float32).reshape(-1, 1)
    # using a dataloader to randomize batching
    train_dataset = TensorDataset(X, y)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    # A fresh iterator is created per draw, so every element is the first
    # batch of a new shuffle.
    inpts = []
    for i in range(per):
        inpts.append(next(iter(train_loader))[0]) # pre-emptively do this so it doesn't affect timing

    start_time = time.perf_counter()
    for i in range(per):
        _ = model(inpts[i]) # model output is irrelevant
    end_time = time.perf_counter()

    fwd_lat_real = round((end_time - start_time) * 1000 * 1000 / per / batch_size, 2) # per sample latency: seconds -> microsec per input

    return flops, params, fwd_lat_real

In [48]:
from IPython.display import clear_output
from tabulate import tabulate

import importlib

# Re-import linspline so edits made to the module since it was first loaded
# are picked up by this cell.
importlib.reload(linspline)

# Caveats to keep in mind when reading the results table.
for caveat in (
    "FLOPS may be inaccurate for LSpline",
    "LSpline params should be = bspline params",
):
    print(caveat)

# Experiment configurations for the profiling sweep: each entry is an
# (arch, layers) pair that is passed to profile_model.
# Every entry below is currently commented out, so as written this list is
# empty and the sweep produces an empty table; uncomment one group at a time
# to run that comparison.
# NOTE(review): the saved output under this cell appears to come from the
# "Controlling for fwd lat" group -- confirm before relying on it.
archs_layers = [

    #^ Testing number of control points:

    # ("BSpline", [8]),
    # ("BSpline", [24, 8]),
    # ("BSpline", [24, 48, 24, 8]),
    # ("BSpline", [24, 48, 96, 48, 24, 8]),

    # ("ReLU", [8]),
    # ("ReLU", [24, 8]),
    # ("ReLU", [24, 48, 24, 8]),
    # ("ReLU", [8, 48, 192, 48, 8]),
    # ("ReLU", [24, 48, 96, 24, 8]),

    # ("BSpline", [8]),
    # ("BSpline", [24, 8]),
    # ("BSpline", [24, 48, 24, 8]),
    # ("BSpline", [24, 48, 96, 24, 8]),

    # ("LSpline", [8]),
    # ("LSpline", [24, 8]),
    # ("LSpline", [24, 48, 24, 8]),
    # ("LSpline", [24, 48, 96, 24, 8]),

    #^ Controlling for parameters:
    # ("BSpline", [8]),
    # ("ReLU", [11]),

    # ("BSpline", [24, 8]),
    # ("ReLU", [30, 8]),

    # ("BSpline", [24, 48, 24, 8]),
    # ("ReLU", [26, 50, 26, 8]),

    # ("BSpline", [24, 48, 96, 48, 24, 8]),
    # ("ReLU", [26, 50, 98, 50, 26, 8]),

    #^ Controlling for FLOPS
    #! Why doesn't FLOPs correspond to forward latency?
    # ("BSpline", [8]),
    # ("ReLU", [24]),

    # ("BSpline", [24, 8]),
    # ("ReLU", [64, 8]),

    # ("BSpline", [24, 48, 24, 8]),
    # ("ReLU", [24, 96, 24, 8]),

    # ("BSpline", [24, 48, 96, 48, 24, 8]),
    # ("ReLU", [24, 48, 144, 48, 24, 8]),

    #^ Controlling for fwd lat

    # ("BSpline", [8]),
    # ("ReLU", [16, 32, 64, 128, 64, 32, 16, 8]),

    # ("BSpline", [16, 8]),
    # ("ReLU", [24, 48, 96, 192, 384, 576, 384, 192, 96, 48, 24, 8]),

    # ("BSpline", [24, 48, 24, 8]),
    # ("ReLU", [24, 48, 192, 768, 1152, 2304, 1152, 768, 192, 48, 24, 8]),

]

# Run profile_model over every (arch, layers) pair and tabulate the results.
store = []     # one table row per configuration
failures = []  # (exception, configuration) pairs, reported after the loop

# for ctrl in [3,5,11,23,55,111]:
ctrl=3

for (arch, layers) in archs_layers:
    try:
        flops, params, fwd_lat_real = profile_model(arch, layers, ctrl=ctrl)
        store.append([f"{arch} {layers}", params, flops, fwd_lat_real])
    except Exception as e:
        # Remember the failure instead of printing it here: clear_output()
        # below would erase the message immediately.
        failures.append((e, (arch, layers)))
        store.append([f"{arch} {layers}", f"{e}", 0, 0])

    # Drop anything printed while profiling this configuration so only the
    # final summary remains in the cell output.
    clear_output()

# Report failures once the last clear_output() has run so they stay visible.
for e, config in failures:
    print("Error: ", e, "on ", config)

headers = ["Model", "Params", "FLOPs/input", "μs/input"]
print(tabulate(store, headers=headers, tablefmt="grid"))

+----------------------------------------------------------------+----------+----------------+------------+
| Model                                                          | Params   |    FLOPs/input |   μs/input |
| BSpline [8]                                                    | 105      |   432          |      30.88 |
+----------------------------------------------------------------+----------+----------------+------------+
| ReLU [16, 32, 64, 128, 64, 32, 16, 8]                          | 22.13 K  | 43896          |      32.43 |
+----------------------------------------------------------------+----------+----------------+------------+
| BSpline [16, 8]                                                | 361      |  1392          |      56.94 |
+----------------------------------------------------------------+----------+----------------+------------+
| ReLU [24, 48, 96, 192, 384, 576, 384, 192, 96, 48, 24, 8]      | 640.67 K |     1.279e+06  |      61    |
+---------------------------

In [None]:
import linspline
import importlib

# Reload linspline so the comparison below uses the latest module code.
importlib.reload(linspline)

# Profile the single-layer B-spline model and its linear-spline counterpart
# in turn, printing each (flops, params, latency) tuple.
for arch in ("BSpline", "LSpline"):
    print(profile_model(arch, [8]))