In [None]:
import os

import pandas as pd
import plotly.graph_objects as go
import torch
from plotly.subplots import make_subplots

from metriwatt.ncu import NCUProfiler
from metriwatt.network import ToyNetwork, run_toy_network_forward_ncu, run_toy_network_forward_backward_ncu, construct_toy_network_and_input_for_ncu
from metriwatt.profiler import TorchProfiler

In [None]:
def toy_network_forward_flops(dim, n_layers, n_tokens):
    """Return the FLOPs TorchProfiler reports for the "forward" step.

    A fresh ``ToyNetwork(n_layers=n_layers, dim=dim)`` is applied to a random
    CUDA tensor of shape ``(dim, n_tokens)``. Network and input are created
    inside the profiler but outside the ``"forward"`` record context; only the
    network call itself is wrapped in that context.

    Args:
        dim: first dimension of the random input; also passed as ``dim`` to
            ``ToyNetwork``.
        n_layers: passed as ``n_layers`` to ``ToyNetwork``.
        n_tokens: second dimension of the random input.

    Returns:
        The ``"flops"`` entry of the ``"forward"`` row of
        ``get_flops_by_step()``.
    """
    with TorchProfiler() as profiler:
        # Keep this order: both statements may draw from the global RNG.
        model = ToyNetwork(n_layers=n_layers, dim=dim)
        inputs = torch.randn(dim, n_tokens, device="cuda")
        with profiler.record_context("forward"):
            _ = model(inputs)

    # The per-step table is read only after the profiler context has exited.
    flops_by_step = profiler.get_flops_by_step()
    return flops_by_step.loc["forward", "flops"]


def toy_network_backward_flops(dim, n_layers, n_tokens):
    """Return the FLOPs TorchProfiler reports for the "backward" step.

    A fresh ``ToyNetwork(n_layers=n_layers, dim=dim)`` is applied to a random
    CUDA tensor of shape ``(dim, n_tokens)``. The forward pass runs inside the
    profiler but outside any record context; only ``sum().backward()`` on the
    output is wrapped in the ``"backward"`` record context.

    Args:
        dim: first dimension of the random input; also passed as ``dim`` to
            ``ToyNetwork``.
        n_layers: passed as ``n_layers`` to ``ToyNetwork``.
        n_tokens: second dimension of the random input.

    Returns:
        The ``"flops"`` entry of the ``"backward"`` row of
        ``get_flops_by_step()``.
    """
    with TorchProfiler() as profiler:
        # Keep this order: both statements may draw from the global RNG.
        model = ToyNetwork(n_layers=n_layers, dim=dim)
        inputs = torch.randn(dim, n_tokens, device="cuda")
        outputs = model(inputs)  # forward pass, deliberately not labelled
        with profiler.record_context("backward"):
            # Reduce to a scalar so backward() needs no explicit gradient.
            scalar = outputs.sum()
            scalar.backward()

    # The per-step table is read only after the profiler context has exited.
    flops_by_step = profiler.get_flops_by_step()
    return flops_by_step.loc["backward", "flops"]


def toy_network_forward_flops_ncu(
    dim,
    n_layers,
    n_tokens,
    results_dir="experiments/toy_network/results",
):
    """Return NCU-measured FLOPs of the forward run minus the setup run.

    Two independent ``NCUProfiler`` runs are made with the same keyword
    arguments: one over ``run_toy_network_forward_ncu`` and one over
    ``construct_toy_network_and_input_for_ncu``. The second total is
    subtracted from the first (presumably to remove the cost of building the
    network and input, judging by the function names - confirm against
    ``metriwatt.network``).

    Both raw profiler results are written as CSV files into ``results_dir``
    (``toy_network_forward_flops_ncu.csv`` and
    ``toy_network_setup_flops_ncu.csv``), overwriting existing files.

    Args:
        dim: forwarded as ``dim`` to both profiled functions.
        n_layers: forwarded as ``n_layers`` to both profiled functions.
        n_tokens: forwarded as ``n_tokens`` to both profiled functions.
        results_dir: directory receiving the CSV dumps. It is created if it
            does not exist. The default reproduces the previously hard-coded
            location, relative to the current working directory.

    Returns:
        ``get_total_flops()`` of the forward run minus ``get_total_flops()``
        of the setup run.
    """
    # Create the output directory *before* the expensive profiling runs, so a
    # missing directory cannot discard the measurements at the very end.
    if results_dir:
        os.makedirs(results_dir, exist_ok=True)

    profile_kwargs = {
        "dim": dim,
        "n_layers": n_layers,
        "n_tokens": n_tokens,
    }

    # Each profiler gets its own copy of the kwargs so neither run can be
    # affected by the other mutating the dict.
    ncu = NCUProfiler()
    ncu.profile_function(run_toy_network_forward_ncu, dict(profile_kwargs))
    forward_flops = ncu.get_total_flops()

    ncu2 = NCUProfiler()
    ncu2.profile_function(construct_toy_network_and_input_for_ncu, dict(profile_kwargs))
    setup_flops = ncu2.get_total_flops()

    # Dump the raw results only after both runs, matching the original order.
    ncu.result.to_csv(os.path.join(results_dir, "toy_network_forward_flops_ncu.csv"))
    ncu2.result.to_csv(os.path.join(results_dir, "toy_network_setup_flops_ncu.csv"))

    return forward_flops - setup_flops


def toy_network_forward_backward_flops_ncu(
    dim,
    n_layers,
    n_tokens,
    results_dir="experiments/toy_network/results",
):
    """Return NCU-measured FLOPs of the forward+backward run minus the setup run.

    Two independent ``NCUProfiler`` runs are made with the same keyword
    arguments: one over ``run_toy_network_forward_backward_ncu`` and one over
    ``construct_toy_network_and_input_for_ncu``. The second total is
    subtracted from the first (presumably to remove the cost of building the
    network and input, judging by the function names - confirm against
    ``metriwatt.network``).

    Only the forward+backward profiler result is written to disk, as
    ``toy_network_forward_backward_flops_ncu.csv`` inside ``results_dir``,
    overwriting an existing file. The setup result is not saved here.

    Args:
        dim: forwarded as ``dim`` to both profiled functions.
        n_layers: forwarded as ``n_layers`` to both profiled functions.
        n_tokens: forwarded as ``n_tokens`` to both profiled functions.
        results_dir: directory receiving the CSV dump. It is created if it
            does not exist. The default reproduces the previously hard-coded
            location, relative to the current working directory.

    Returns:
        ``get_total_flops()`` of the forward+backward run minus
        ``get_total_flops()`` of the setup run.
    """
    # Create the output directory *before* the expensive profiling runs, so a
    # missing directory cannot discard the measurement afterwards.
    if results_dir:
        os.makedirs(results_dir, exist_ok=True)

    profile_kwargs = {
        "dim": dim,
        "n_layers": n_layers,
        "n_tokens": n_tokens,
    }

    # Each profiler gets its own copy of the kwargs so neither run can be
    # affected by the other mutating the dict.
    ncu = NCUProfiler()
    ncu.profile_function(run_toy_network_forward_backward_ncu, dict(profile_kwargs))
    forward_backward_flops = ncu.get_total_flops()

    # Dump the raw result before the setup run, matching the original order.
    ncu.result.to_csv(
        os.path.join(results_dir, "toy_network_forward_backward_flops_ncu.csv")
    )

    ncu2 = NCUProfiler()
    ncu2.profile_function(construct_toy_network_and_input_for_ncu, dict(profile_kwargs))
    setup_flops = ncu2.get_total_flops()

    return forward_backward_flops - setup_flops


def toy_network_params(dim, n_layers):
    """Count the elements of every trainable parameter of a fresh ToyNetwork.

    Args:
        dim: passed as ``dim`` to ``ToyNetwork``.
        n_layers: passed as ``n_layers`` to ``ToyNetwork``.

    Returns:
        Sum of ``numel()`` over all parameters whose ``requires_grad`` is true.
    """
    model = ToyNetwork(n_layers=n_layers, dim=dim)
    trainable_elements = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable_elements += param.numel()
    return trainable_elements

In [None]:
# Baseline toy-network configuration (single-letter symbols as used in the
# analytical FLOP formulas):
#   N - number of layers
#   D - size of input/output dimension
#   M - sequence length
N, D, M = 10, 1024, 128

In [8]:
# --- Trainable parameter count -------------------------------------------
baseline_num_params = toy_network_params(D, N)
print(baseline_num_params)

# --- Measured FLOPs: TorchProfiler ---------------------------------------
baseline_forward_flops_exp = toy_network_forward_flops(D, N, M)
baseline_backward_flops_exp = toy_network_backward_flops(D, N, M)

# --- Measured FLOPs: NCU --------------------------------------------------
# NCU has no separate backward measurement here; it is obtained by
# difference: (forward + backward run) - (forward-only run).
baseline_forward_flops_ncu = toy_network_forward_flops_ncu(D, N, M)
baseline_backward_flops_ncu = (
    toy_network_forward_backward_flops_ncu(D, N, M) - baseline_forward_flops_ncu
)

# --- Analytical FLOPs: dominant D*D*M terms only --------------------------
# Forward: 2*D*D*M per layer (presumably one D x D by D x M matmul per layer,
# counting multiply and add separately - confirm against ToyNetwork).
baseline_forward_flops_theory = 2 * D * D * M * N
# Backward: 4*D*D*M for N-1 layers plus 2*D*D*M for the remaining layer
# (presumably the first layer needs no input gradient - confirm).
baseline_backward_flops_theory = 4 * D * D * M * (N - 1) + 2 * D * D * M

# --- Analytical FLOPs: including the lower-order D*M terms ----------------
# Adds D*M per layer in forward, and 2*D*M for N-1 layers plus D*M for the
# remaining layer in backward (presumably elementwise work - confirm).
baseline_forward_flops_theory_exact = baseline_forward_flops_theory + D * M * N
baseline_backward_flops_theory_exact = (
    baseline_backward_flops_theory + 2 * D * M * (N - 1) + D * M
)

# Placeholders: currently identical to the "exact" figures.
baseline_forward_flops_theory_exact_gpu = baseline_forward_flops_theory_exact  # TODO: add actual equation
baseline_backward_flops_theory_exact_gpu = baseline_backward_flops_theory_exact  # TODO: add actual equation

# --- Side-by-side comparison table ----------------------------------------
baseline_df = pd.DataFrame(
    dict(
        forward_flops=[
            baseline_forward_flops_exp,
            baseline_forward_flops_ncu,
            baseline_forward_flops_theory,
            baseline_forward_flops_theory_exact,
            baseline_forward_flops_theory_exact_gpu,
        ],
        backward_flops=[
            baseline_backward_flops_exp,
            baseline_backward_flops_ncu,
            baseline_backward_flops_theory,
            baseline_backward_flops_theory_exact,
            baseline_backward_flops_theory_exact_gpu,
        ],
    ),
    index=[
        "torch profiler (experimental)",
        "ncu profiler (experimental)",
        "theoretical (analyzed)",
        "theoretical exact (analyzed)",
        "theoretical exact gpu implementation (analyzed)",
    ],
)
# Note: this changes the float display format for the rest of the session.
pd.set_option("display.float_format", "{:,.0f}".format)
print(baseline_df)

10485760
                                                 forward_flops  backward_flops
torch profiler (experimental)                    2,684,354,560   5,100,273,664
ncu profiler (experimental)                      2,727,606,309   5,134,953,113
theoretical (analyzed)                           2,684,354,560   5,100,273,664
theoretical exact (analyzed)                     2,685,665,280   5,102,764,032
theoretical exact gpu implementation (analyzed)  2,685,665,280   5,102,764,032
