In [1]:
import torch
from torch.autograd import profiler

from torch_kinetics import models, reactions

In [2]:
# Pathway under test: A -> B -> C, one single-substrate reaction per step,
# each with its own enzyme and its own kinetic parameters (kcat, kma).
reaction_1 = reactions.UniReaction(
    name="A->B", enzyme="enz_1", substrates=["A"], products=["B"], kcat=34.0, kma=500.0
)
reaction_2 = reactions.UniReaction(
    name="B->C", enzyme="enz_2", substrates=["B"], products=["C"], kcat=200.0, kma=8000.0
)

# Register both steps on a fresh model, in pathway order.
model = models.Model()
for step in (reaction_1, reaction_2):
    model.add_reaction(step)

In [None]:
# Random initial state: 100 rows x 5 columns, sampled uniformly from [0, 1).
# NOTE(review): the 5 columns presumably hold species A, B, C plus the two
# enzymes -- confirm against the model's state layout.
# A dedicated, seeded generator makes the input identical after every
# "Restart Kernel -> Run All" without touching the global RNG state.
s0_rand = torch.rand(100, 5, generator=torch.Generator().manual_seed(0))
s0_rand[:5]  # preview only; avoids dumping all 100 rows into the output

In [None]:
# Plain (unprofiled) model call on the random state batch.
# NOTE(review): the first argument is presumably the time point -- confirm.
t0 = torch.zeros(1)
model(t0, s0_rand)

In [10]:
# Profile a single model call, recording source stacks (with_stack) and
# memory usage (profile_memory). The collected events stay in `prof`
# so the report cells below can summarise them.
with profiler.profile(with_stack=True, profile_memory=True) as prof:
    # Everything on this line, including the creation of the zeros tensor,
    # runs inside the profiled region and therefore shows up in the report.
    s_prime = model(torch.zeros(1), s0_rand)

STAGE:2023-08-08 21:08:46 45415:2784074 ActivityProfilerController.cpp:311] Completed Stage: Warm Up
STAGE:2023-08-08 21:08:46 45415:2784074 ActivityProfilerController.cpp:317] Completed Stage: Collection
STAGE:2023-08-08 21:08:46 45415:2784074 ActivityProfilerController.cpp:321] Completed Stage: Post Processing


### Baseline execution

In [24]:
# Baseline report: the 10 entries with the highest self CPU time, with events
# averaged per operator and grouped by their top 5 stack frames.
op_stats = prof.key_averages(group_by_stack_n=5)
print(op_stats.table(sort_by='self_cpu_time_total', row_limit=10))

------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                    Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg       CPU Mem  Self CPU Mem    # of Calls  
------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
               aten::mul        16.96%      49.000us        16.96%      49.000us      12.250us       1.56 Kb       1.56 Kb             4  
            aten::expand        12.46%      36.000us        13.49%      39.000us       3.250us           0 b           0 b            12  
             aten::index        10.03%      29.000us        12.80%      37.000us      18.500us          16 b          16 b             2  
       aten::scatter_add        10.03%      29.000us        12.80%      37.000us       9.250us       7.81 Kb       7.81 Kb             4  
            aten::gather   

### Replace `scatter_add` with in-place `scatter_add_`

In [7]:
# Same report as the baseline: top 10 entries by self CPU time, grouped by the
# top 5 stack frames.
# NOTE(review): `prof` must come from a profiling run made after the code
# change named in the heading above -- confirm the profiling cell was re-run.
op_stats = prof.key_averages(group_by_stack_n=5)
print(op_stats.table(sort_by='self_cpu_time_total', row_limit=10))

------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                    Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg       CPU Mem  Self CPU Mem    # of Calls  
------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
               aten::mul        15.23%      39.000us        15.23%      39.000us       9.750us       1.56 Kb       1.56 Kb             4  
      aten::scatter_add_        12.11%      31.000us        12.11%      31.000us       7.750us           0 b           0 b             4  
            aten::expand        11.33%      29.000us        12.11%      31.000us       2.583us           0 b           0 b            12  
             aten::empty        10.55%      27.000us        10.55%      27.000us       9.000us          20 b          20 b             3  
             aten::index   

### Replace `gather` and `scatter_add_` with `index_select` and `index_add_`

In [11]:
# Same report as the baseline: top 10 entries by self CPU time, grouped by the
# top 5 stack frames.
# NOTE(review): `prof` must come from a profiling run made after the code
# change named in the heading above -- confirm the profiling cell was re-run.
op_stats = prof.key_averages(group_by_stack_n=5)
print(op_stats.table(sort_by='self_cpu_time_total', row_limit=10))

------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                    Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg       CPU Mem  Self CPU Mem    # of Calls  
------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
        aten::index_add_        21.29%      66.000us        28.71%      89.000us      22.250us           0 b           0 b             4  
               aten::mul        13.23%      41.000us        13.23%      41.000us      10.250us       1.56 Kb       1.56 Kb             4  
               aten::div         9.68%      30.000us         9.68%      30.000us      15.000us         800 b         800 b             2  
             aten::empty         8.71%      27.000us         8.71%      27.000us       3.857us          20 b          20 b             7  
            aten::select   

## Observations
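- Surprisingly, the version that combines an explicit `expand` with `gather` and `scatter_add_` is the fastest of the three. Note that each table above comes from a single profiled call, so differences of a few microseconds may be measurement noise.
- Decided to keep the `index_select` and `index_add_` version anyway, because it is more readable.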