In [1]:
import torch
from torch.autograd import profiler

from torch_kinetics import models, reactions

In [2]:
reaction_1 = reactions.UniReaction(
    name="A->B",
    enzyme="enz_1",
    substrates=["A"],
    products=["B"],
    kcat=34.0,
    kma=500.0,
)
reaction_2 = reactions.UniReaction(
    name="B->C",
    enzyme="enz_2",
    substrates=["B"],
    products=["C"],
    kcat=200.0,
    kma=8000.0,
)

model = models.Model()
model.add_reaction(reaction_1)
model.add_reaction(reaction_2)

In [3]:
s0_rand = torch.rand(100, 5)
s0_rand

tensor([[0.5698, 0.3105, 0.6760, 0.0053, 0.2287],
        [0.4137, 0.9749, 0.2393, 0.1805, 0.5112],
        [0.5921, 0.0493, 0.5195, 0.4242, 0.2241],
        [0.5981, 0.7395, 0.6908, 0.3848, 0.9207],
        [0.5791, 0.4024, 0.8693, 0.6014, 0.6073],
        [0.8820, 0.2059, 0.3681, 0.3879, 0.3754],
        [0.9016, 0.4803, 0.0908, 0.5563, 0.4625],
        [0.7858, 0.3727, 0.4345, 0.8105, 0.8525],
        [0.5062, 0.9480, 0.1211, 0.1743, 0.2193],
        [0.7061, 0.9760, 0.0071, 0.8095, 0.3465],
        [0.0732, 0.1781, 0.5915, 0.2499, 0.5477],
        [0.2926, 0.0345, 0.0193, 0.2459, 0.6387],
        [0.8477, 0.2388, 0.6247, 0.9225, 0.6335],
        [0.1916, 0.5154, 0.8917, 0.8151, 0.8007],
        [0.9499, 0.3578, 0.2799, 0.5827, 0.8390],
        [0.3062, 0.2073, 0.5836, 0.6348, 0.1186],
        [0.1163, 0.5982, 0.0101, 0.1812, 0.6852],
        [0.6669, 0.8398, 0.0808, 0.9499, 0.1855],
        [0.5717, 0.9246, 0.2893, 0.7458, 0.4723],
        [0.9288, 0.0452, 0.2237, 0.9257, 0.5230],


In [4]:
model(torch.zeros(1), s0_rand)

tensor([[-1.2018e-02,  0.0000e+00,  1.1929e-02,  0.0000e+00,  8.9218e-05],
        [-2.7406e-02,  0.0000e+00,  2.6326e-02,  0.0000e+00,  1.0798e-03],
        [-1.9808e-03,  0.0000e+00, -3.5272e-03,  0.0000e+00,  5.5079e-03],
        [-3.0039e-02,  0.0000e+00,  2.3395e-02,  0.0000e+00,  6.6446e-03],
        [-1.5829e-02,  0.0000e+00,  2.7608e-03,  0.0000e+00,  1.3068e-02],
        [-1.2327e-02,  0.0000e+00,  8.7578e-03,  0.0000e+00,  3.5695e-03],
        [-2.9394e-02,  0.0000e+00,  2.8132e-02,  0.0000e+00,  1.2625e-03],
        [-1.9882e-02,  0.0000e+00,  1.1078e-02,  0.0000e+00,  8.8038e-03],
        [-3.2601e-02,  0.0000e+00,  3.2073e-02,  0.0000e+00,  5.2773e-04],
        [-4.6793e-02,  0.0000e+00,  4.6649e-02,  0.0000e+00,  1.4351e-04],
        [-8.8630e-04,  0.0000e+00, -2.8094e-03,  0.0000e+00,  3.6957e-03],
        [-6.8571e-04,  0.0000e+00,  5.6700e-04,  0.0000e+00,  1.1872e-04],
        [-1.3742e-02,  0.0000e+00, -6.6394e-04,  0.0000e+00,  1.4406e-02],
        [-6.7144e-03,  0.

In [10]:
with profiler.profile(with_stack=True, profile_memory=True) as prof:
    s_prime = model(torch.zeros(1), s0_rand)

STAGE:2023-08-08 21:08:46 45415:2784074 ActivityProfilerController.cpp:311] Completed Stage: Warm Up
STAGE:2023-08-08 21:08:46 45415:2784074 ActivityProfilerController.cpp:317] Completed Stage: Collection
STAGE:2023-08-08 21:08:46 45415:2784074 ActivityProfilerController.cpp:321] Completed Stage: Post Processing


### Baseline execution

In [24]:
print(prof.key_averages(group_by_stack_n=5).table(sort_by='self_cpu_time_total', row_limit=10))

------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                    Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg       CPU Mem  Self CPU Mem    # of Calls  
------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
               aten::mul        16.96%      49.000us        16.96%      49.000us      12.250us       1.56 Kb       1.56 Kb             4  
            aten::expand        12.46%      36.000us        13.49%      39.000us       3.250us           0 b           0 b            12  
             aten::index        10.03%      29.000us        12.80%      37.000us      18.500us          16 b          16 b             2  
       aten::scatter_add        10.03%      29.000us        12.80%      37.000us       9.250us       7.81 Kb       7.81 Kb             4  
            aten::gather   

### Inplace `scatter_ad_`

In [7]:
print(prof.key_averages(group_by_stack_n=5).table(sort_by='self_cpu_time_total', row_limit=10))

------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                    Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg       CPU Mem  Self CPU Mem    # of Calls  
------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
               aten::mul        15.23%      39.000us        15.23%      39.000us       9.750us       1.56 Kb       1.56 Kb             4  
      aten::scatter_add_        12.11%      31.000us        12.11%      31.000us       7.750us           0 b           0 b             4  
            aten::expand        11.33%      29.000us        12.11%      31.000us       2.583us           0 b           0 b            12  
             aten::empty        10.55%      27.000us        10.55%      27.000us       9.000us          20 b          20 b             3  
             aten::index   

### Replace `gather` and `scatter_add_` with `index_select` and `index_add_`

In [11]:
print(prof.key_averages(group_by_stack_n=5).table(sort_by='self_cpu_time_total', row_limit=10))

------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                    Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg       CPU Mem  Self CPU Mem    # of Calls  
------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
        aten::index_add_        21.29%      66.000us        28.71%      89.000us      22.250us           0 b           0 b             4  
               aten::mul        13.23%      41.000us        13.23%      41.000us      10.250us       1.56 Kb       1.56 Kb             4  
               aten::div         9.68%      30.000us         9.68%      30.000us      15.000us         800 b         800 b             2  
             aten::empty         8.71%      27.000us         8.71%      27.000us       3.857us          20 b          20 b             7  
            aten::select   

## Observations
- Surprisingly the version with explicit `expand` combined with `gather` and `scatter_add_` has the best time performance
- Decided to leave the `index_select` and `index_add_` version because of the better readability