In [1]:
%%capture
!pip3 uninstall torch -y
!pip3 uninstall torchvision -y
!pip3 uninstall torchaudio -y
!pip3 install torch --extra-index-url https://download.pytorch.org/whl/cu116
!pip3 install torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://data.pyg.org/whl/torch-1.13.0+cu116.html

In [2]:
from functools import partial
from itertools import product
from typing import Tuple

import pandas as pd
import torch
import torch.utils.benchmark as benchmark
from torch_geometric.nn.models.schnet import GaussianSmearing, InteractionBlock
from torch_scatter import scatter
from tqdm import tqdm

In [3]:
# Show the installed torch build; it should match the wheel index used in the
# install cell (torch-1.13.0+cu116).
torch.__version__

'1.13.0+cu116'

In [4]:
class GatherOp(torch.nn.Module):
    """Benchmark fixture that gathers rows from a fixed random matrix.

    The module holds two buffers created at construction time:

    * ``input``: a ``(num_inputs, num_features)`` tensor drawn from ``randn``.
    * ``index``: ``num_outputs`` integer row ids drawn from ``[0, num_inputs)``.

    Calling the module takes no arguments and returns the selected rows, so
    the timed statement contains only the gather itself.
    """

    def __init__(self, num_inputs, num_features, num_outputs) -> None:
        super().__init__()
        # Draw order (features first, then row ids) is kept fixed so that a
        # given torch seed always produces the same buffers.
        features = torch.randn(num_inputs, num_features)
        row_ids = torch.randint(num_inputs, (num_outputs,))
        for buffer_name, tensor in (("input", features), ("index", row_ids)):
            self.register_buffer(buffer_name, tensor)

    def forward(self):
        """Return the ``(num_outputs, num_features)`` rows picked by ``index``."""
        return torch.index_select(self.input, 0, self.index)


class ScatterOp(torch.nn.Module):
    """Benchmark fixture that scatter-reduces a fixed random matrix.

    Buffers created at construction time:

    * ``input``: a ``(num_inputs, num_features)`` tensor drawn from ``randn``.
    * ``index``: one target id per input row, drawn from ``[0, num_outputs)``.

    ``reduce`` is forwarded unchanged to ``torch_scatter.scatter`` and the
    output size along dim 0 is fixed to ``num_outputs``.
    """

    def __init__(self, num_inputs, num_features, num_outputs, reduce) -> None:
        super().__init__()
        # Draw order (values first, then target ids) is kept fixed so that a
        # given torch seed always produces the same buffers.
        values = torch.randn(num_inputs, num_features)
        target_ids = torch.randint(num_outputs, (num_inputs,))
        self.register_buffer("input", values)
        self.register_buffer("index", target_ids)
        # Bind the static arguments once so forward() only passes tensors.
        self.scatter = partial(scatter, dim_size=num_outputs, reduce=reduce)

    def forward(self):
        """Reduce the rows of ``input`` into ``num_outputs`` rows along dim 0."""
        return self.scatter(self.input, self.index, dim=0)


class ScatterAddOp(ScatterOp):
    """``ScatterOp`` fixed to the ``"add"`` reduction.

    The constructor takes the same three size arguments as the other
    benchmark ops, positionally or by keyword.
    """

    def __init__(self, num_inputs, num_features, num_outputs) -> None:
        super().__init__(num_inputs, num_features, num_outputs, reduce="add")


class ScatterMinOp(ScatterOp):
    """``ScatterOp`` fixed to the ``"min"`` reduction.

    The constructor takes the same three size arguments as the other
    benchmark ops, positionally or by keyword.
    """

    def __init__(self, num_inputs, num_features, num_outputs) -> None:
        super().__init__(num_inputs, num_features, num_outputs, reduce="min")


class ScatterMaxOp(ScatterOp):
    """``ScatterOp`` fixed to the ``"max"`` reduction.

    The constructor takes the same three size arguments as the other
    benchmark ops, positionally or by keyword.
    """

    def __init__(self, num_inputs, num_features, num_outputs) -> None:
        super().__init__(num_inputs, num_features, num_outputs, reduce="max")


class InteractionBlockOp(torch.nn.Module):
    """Benchmark fixture that runs one SchNet ``InteractionBlock`` forward pass.

    Args:
        num_inputs: Number of nodes; node features are ``(num_inputs, num_features)``.
        num_features: Used for both ``hidden_channels`` and ``num_filters``.
        num_outputs: Number of edges; endpoints are drawn from ``[0, num_inputs)``.
        num_gaussians: Size of the Gaussian expansion, shared by the block and
            the smearing so the two always agree. Defaults to 50.
        cutoff: Cutoff passed to the block; also the upper bound of the
            smearing range and of the sampled edge weights. Edge weights are
            drawn from ``uniform_(1.0, cutoff)``, so values below 1.0 are not
            meaningful here. Defaults to 6.0.

    The edge expansion is computed once in the constructor and stored as a
    buffer, so calling the module only runs the interaction block.
    """

    def __init__(
        self, num_inputs, num_features, num_outputs, num_gaussians=50, cutoff=6.0
    ) -> None:
        super().__init__()
        # The block is built first, then the tensors below in this order.
        # Keeping the order fixed keeps results reproducible for a given seed.
        self.block = InteractionBlock(
            hidden_channels=num_features,
            num_gaussians=num_gaussians,
            num_filters=num_features,
            cutoff=cutoff,
        )
        node_features = torch.randn(num_inputs, num_features)
        src = torch.randint(num_inputs, (num_outputs,))
        dst = torch.randint(num_inputs, (num_outputs,))
        edge_index = torch.stack([src, dst])
        edge_weight = torch.empty(num_outputs).uniform_(1.0, cutoff)

        # Expand the edge weights with the same basis size and range as the block.
        grbf = GaussianSmearing(start=0.0, stop=cutoff, num_gaussians=num_gaussians)
        edge_attr = grbf(edge_weight)

        self.register_buffer("input", node_features)
        self.register_buffer("edge_index", edge_index)
        self.register_buffer("edge_weight", edge_weight)
        self.register_buffer("edge_attr", edge_attr)

    def forward(self):
        """Apply the interaction block to the stored node and edge tensors."""
        return self.block(self.input, self.edge_index, self.edge_weight, self.edge_attr)

In [5]:
def mpbench(
    operation: str,
    seed: int,
    num_inputs: int,
    num_features: int,
    num_outputs: int,
    device: str = "cpu",
    min_run_time: float = 2.0,
) -> Tuple[float, float]:
    """Time one benchmark op for a single problem size.

    Args:
        operation: Name of a class importable from ``__main__`` whose
            constructor accepts ``(num_inputs, num_features, num_outputs)``
            and whose instances are callable with no arguments.
        seed: Value passed to ``torch.manual_seed`` before the op is built.
        num_inputs: First size argument forwarded to the op constructor.
        num_features: Second size argument forwarded to the op constructor.
        num_outputs: Third size argument forwarded to the op constructor.
        device: Device string the op is moved to before timing.
        min_run_time: Seconds handed to ``Timer.blocked_autorange``.

    Returns:
        ``(mean, iqr)`` of the measured time per call, both in microseconds.

    NOTE(review): the mean is reported next to the IQR; the measurement object
    also exposes a median if a matching pair is preferred.
    """
    # The snippet is an indented f-string. ``torch`` is not in ``inputs``; the
    # recorded runs in this notebook show the Timer accepts both as written.
    setup = f"""
    from __main__ import {operation}
    torch.manual_seed(seed)
    op = {operation}(num_inputs, num_features, num_outputs)
    op.to(device='{device}')
    """

    inputs = {
        "seed": seed,
        "num_inputs": num_inputs,
        "num_features": num_features,
        "num_outputs": num_outputs,
    }

    # Parse the device string so indexed devices such as "cuda:0" are
    # recognised as CUDA too, then release cached blocks from earlier runs.
    if torch.device(device).type == "cuda":
        torch.cuda.empty_cache()

    timer = benchmark.Timer(stmt="op()", setup=setup, globals=inputs)
    result = timer.blocked_autorange(min_run_time=min_run_time)
    # convert seconds to microseconds
    return 1e6 * result.mean, 1e6 * result.iqr


def run_sweep(operation: str, device: str = "cuda", tag: str = "a100") -> None:
    """Benchmark ``operation`` over the full size grid and pickle the results.

    Args:
        operation: Class name handed to ``mpbench``.
        device: Device string handed to ``mpbench``. Defaults to ``"cuda"``.
        tag: Suffix of the output file name, intended to identify the
            hardware. Defaults to ``"a100"``.

    One row per grid point is collected with the three sizes plus ``time`` and
    ``iqr`` as returned by ``mpbench``. The table is written to
    ``{operation}_{tag}.pickle`` in the working directory once the sweep is
    complete; nothing is returned.
    """
    # nodes ~ 16, 32, ..., 4096
    input_sizes = [2**e for e in range(4, 13)]

    # embedding size ~ 16, 32, ..., 512
    feature_sizes = [2**e for e in range(4, 10)]

    # edges ~ 32, 64, ..., 32768
    output_sizes = [2**e for e in range(5, 16)]

    # Materialise the grid so tqdm can show a total.
    grid = list(product(input_sizes, feature_sizes, output_sizes))

    seed = 0
    data = []

    for num_inputs, num_features, num_outputs in tqdm(grid):
        metrics = {
            "num_inputs": num_inputs,
            "num_features": num_features,
            "num_outputs": num_outputs,
        }
        mean_us, iqr_us = mpbench(operation, seed, device=device, **metrics)
        data.append({**metrics, "time": mean_us, "iqr": iqr_us})

    df = pd.DataFrame(data)
    df.to_pickle(f"{operation}_{tag}.pickle")

In [6]:
# Sweep the gather benchmark over the full size grid; results go to GatherOp_a100.pickle.
run_sweep("GatherOp")

100%|██████████| 594/594 [22:04<00:00,  2.23s/it]


In [7]:
# Sweep the scatter-add benchmark over the full size grid; results go to ScatterAddOp_a100.pickle.
run_sweep("ScatterAddOp")

100%|██████████| 594/594 [23:47<00:00,  2.40s/it]


In [8]:
# Sweep the SchNet interaction-block benchmark over the full size grid; results go to InteractionBlockOp_a100.pickle.
run_sweep("InteractionBlockOp")

100%|██████████| 594/594 [27:52<00:00,  2.82s/it]


In [9]:
# Sweep the scatter-max benchmark over the full size grid; results go to ScatterMaxOp_a100.pickle.
run_sweep("ScatterMaxOp")

100%|██████████| 594/594 [26:48<00:00,  2.71s/it]


In [10]:
# Sweep the scatter-min benchmark over the full size grid; results go to ScatterMinOp_a100.pickle.
run_sweep("ScatterMinOp")

100%|██████████| 594/594 [26:27<00:00,  2.67s/it]
