In [1]:
import torch
from torch.autograd import profiler

from torch_kinetics import models, reactions

In [2]:
reaction_1 = reactions.UniReaction(
    name="A->B",
    enzyme="enz_1",
    substrates=["A"],
    products=["B"],
    kcat=34.0,
    kma=500.0,
)
reaction_2 = reactions.UniReaction(
    name="B->C",
    enzyme="enz_2",
    substrates=["B"],
    products=["C"],
    kcat=200.0,
    kma=8000.0,
)

model = models.Model()
model.add_reaction(reaction_1)
model.add_reaction(reaction_2)

In [4]:
s0_rand = torch.rand(100, 5)
s0_rand

tensor([[0.0459, 0.3841, 0.7612, 0.3249, 0.1214],
        [0.3647, 0.5830, 0.3504, 0.9644, 0.9378],
        [0.8772, 0.0606, 0.8906, 0.9406, 0.3259],
        [0.7089, 0.0340, 0.5659, 0.0398, 0.5260],
        [0.9666, 0.4584, 0.3739, 0.3480, 0.3593],
        [0.6400, 0.7897, 0.8562, 0.3439, 0.3230],
        [0.2041, 0.7083, 0.0769, 0.6814, 0.9705],
        [0.3675, 0.8909, 0.5795, 0.8384, 0.7731],
        [0.9262, 0.0495, 0.2007, 0.6320, 0.3265],
        [0.3806, 0.2958, 0.9957, 0.2053, 0.9646],
        [0.1763, 0.2350, 0.0262, 0.4376, 0.9142],
        [0.4053, 0.7869, 0.2378, 0.3642, 0.0080],
        [0.2135, 0.0456, 0.3292, 0.7331, 0.0976],
        [0.3492, 0.8689, 0.3159, 0.8769, 0.8521],
        [0.4990, 0.7753, 0.8849, 0.5283, 0.6433],
        [0.5632, 0.3618, 0.7504, 0.3337, 0.0189],
        [0.8558, 0.2932, 0.2698, 0.4072, 0.3008],
        [0.4060, 0.5259, 0.2552, 0.1239, 0.9383],
        [0.2302, 0.2808, 0.0808, 0.2171, 0.5885],
        [0.7052, 0.6302, 0.2904, 0.8150, 0.7986],


In [5]:
model(torch.zeros(1), s0_batched)

tensor([[-129.5238,    0.0000,  129.5238,    0.0000,    0.0000],
        [-125.4854,    0.0000, -249.4727,    0.0000,  374.9581],
        [-117.0273,    0.0000, -242.3004,    0.0000,  359.3277],
        [ -99.7565,    0.0000, -117.5474,    0.0000,  217.3040],
        [ -71.2393,    0.0000,  -14.0079,    0.0000,   85.2472]],
       grad_fn=<ScatterAddBackward0>)

In [6]:
with profiler.profile(with_stack=True, profile_memory=True) as prof:
    s_prime = model(torch.zeros(1), s0_rand)

STAGE:2023-08-08 20:34:24 79765:2665984 ActivityProfilerController.cpp:311] Completed Stage: Warm Up
STAGE:2023-08-08 20:34:24 79765:2665984 ActivityProfilerController.cpp:317] Completed Stage: Collection
STAGE:2023-08-08 20:34:24 79765:2665984 ActivityProfilerController.cpp:321] Completed Stage: Post Processing


### Baseline execution

In [24]:
print(prof.key_averages(group_by_stack_n=5).table(sort_by='self_cpu_time_total', row_limit=10))

------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                    Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg       CPU Mem  Self CPU Mem    # of Calls  
------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
               aten::mul        16.96%      49.000us        16.96%      49.000us      12.250us       1.56 Kb       1.56 Kb             4  
            aten::expand        12.46%      36.000us        13.49%      39.000us       3.250us           0 b           0 b            12  
             aten::index        10.03%      29.000us        12.80%      37.000us      18.500us          16 b          16 b             2  
       aten::scatter_add        10.03%      29.000us        12.80%      37.000us       9.250us       7.81 Kb       7.81 Kb             4  
            aten::gather   

### Inplace `scatter_ad_`

In [7]:
print(prof.key_averages(group_by_stack_n=5).table(sort_by='self_cpu_time_total', row_limit=10))

------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                    Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg       CPU Mem  Self CPU Mem    # of Calls  
------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
               aten::mul        15.23%      39.000us        15.23%      39.000us       9.750us       1.56 Kb       1.56 Kb             4  
      aten::scatter_add_        12.11%      31.000us        12.11%      31.000us       7.750us           0 b           0 b             4  
            aten::expand        11.33%      29.000us        12.11%      31.000us       2.583us           0 b           0 b            12  
             aten::empty        10.55%      27.000us        10.55%      27.000us       9.000us          20 b          20 b             3  
             aten::index   