## Inference

In this notebook, we will show how to use our best model (or any other model saved during training) to generate kinetic models.

### Setup

---

Assuming you have downloaded the best models and are at the root of the repository, start the container with the following command (replace the output and inference model paths with your own):

```bash
apptainer shell --nv --bind "$(pwd)":/home/renaissance/work --bind "/scratch/izar/$USER/rl-for-kinetics/output:/home/renaissance/output" --bind "/home/$USER/rl-gen-of-kinetic-models/inference:/home/renaissance/inference" /scratch/izar/$USER/images/renaissance_with_ml.sif
```


Then, start the Jupyter notebook with the following command:

```bash
jupyter notebook --ip=0.0.0.0 --port=8888 --no-browser
```

The commands below make sure you are in the right directory and have all the necessary imports.

In [1]:
import os
os.chdir('..') # important: run this cell only once!

In [2]:
import os

import wandb
from omegaconf import OmegaConf

import torch

from helpers.utils import setup_kinetic_env, sample_params
from helpers.ppo_agent import PolicyNetwork

### Generate kinetic parameters

---

Define the name of the model you want to use for inference.

In [3]:
selected_model = "best_model" # name of the folder in your best_models directory 
ckpt_dir = f"/home/renaissance/inference/{selected_model}"
policy_path = f"{ckpt_dir}/policy.pt"
config_path = f"{ckpt_dir}/config.yaml"
normalisation_path = f"{ckpt_dir}/normalisation.pt"
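
Before loading anything, it can help to confirm that the checkpoint folder contains the three files referenced above. The following is a minimal sketch; `missing_checkpoint_files` is a hypothetical helper and not part of the repository.

```python
from pathlib import Path


def missing_checkpoint_files(ckpt_dir, names=("policy.pt", "config.yaml", "normalisation.pt")):
    """Return the expected checkpoint files that are absent from ckpt_dir."""
    root = Path(ckpt_dir)
    return [name for name in names if not (root / name).is_file()]


# Same folder layout as in the cell above.
missing = missing_checkpoint_files("/home/renaissance/inference/best_model")
if missing:
    print(f"Missing files: {missing}")
```

An empty list means all three files are present; otherwise check the `--bind` path for the inference directory used when starting the container.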

Load the model configuration.

In [4]:
cfg = OmegaConf.load(config_path)

Load the kinetic RL environment (which includes the reward function). This takes around 30 seconds and prints the configuration of the environment and model.

In [5]:
kinetic_env = setup_kinetic_env(cfg)
kinetic_env.logging_enabled = False

--------------------------------------------------
env:
  p_size: 384
  action_scale: 1
seed: 42
paths:
  names_km: data/varma_ecoli_shikki/parameter_names_km_fdp1.pkl
  output_dir: /home/renaissance/output
  met_model_name: varma_ecoli_shikki
device: cpu
logger:
  tags:
  - debug
  - best_setup
  - clip_eps_decay
  - reproduce
  entity: ludekcizinsky
  project: rl-renaissance
method:
  name: ppo_refinement
  actor_lr: 0.0001
  critic_lr: 0.001
  gae_lambda: 0.98
  max_log_std: 2
  min_log_std: -6
  clip_eps_end: 0.1
  parameter_dim: 384
  clip_eps_start: 0.3
  discount_factor: 0.99
  value_loss_weight: 0.5
  entropy_loss_weight: 0.01
reward:
  eig_partition: -2.5
training:
  batch_size: 25
  num_epochs: 10
  num_episodes: 100
  max_grad_norm: 0.5
  save_trained_models: true
  max_steps_per_episode: 50
  n_eval_samples_in_episode: 50
launch_cmd: train.py logger.tags=[debug, best_setup, clip_eps_decay, reproduce]
constraints:
  max_km: 3
  min_km: -25
  ss_idx: 1712
lr_scheduler:
  name

Load the policy network.

In [6]:
policy = PolicyNetwork(cfg)
policy.load_pretrained_policy_net(policy_path)

FYI: Loaded pretrained policy network from /home/renaissance/inference/best_model/policy.pt.


Load normalisation parameters to scale the states during inference.

In [7]:
normalisation = torch.load(normalisation_path)
obs_mean, obs_var = normalisation["obs_mean"], normalisation["obs_var"]
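
How these statistics are applied happens inside `sample_params` and is not shown in this notebook. A common convention is to standardise each state component as $(x - \text{mean}) / \sqrt{\text{var} + \epsilon}$. The sketch below illustrates that convention on plain Python lists; `normalise_obs` and `eps` are hypothetical names, and the repository's actual scaling may differ.

```python
import math


def normalise_obs(obs, obs_mean, obs_var, eps=1e-8):
    """Standardise each observation component: (x - mean) / sqrt(var + eps)."""
    return [(x - m) / math.sqrt(v + eps) for x, m, v in zip(obs, obs_mean, obs_var)]


# Toy three-dimensional example; the printed config lists p_size: 384.
scaled = normalise_obs([2.0, -1.0, 0.5], [1.0, -1.0, 0.0], [4.0, 1.0, 0.25])
print(scaled)
```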

Finally, sample the kinetic parameters (`N` is the number of samples, `max_steps` is the maximum number of environment steps for each sample).

If `deterministic` is set to `True`, the policy always takes the mean as the action. If `verbose` is set to `True`, the max eigenvalue and reward are printed for each step.
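
To make the difference between the two modes concrete, here is a minimal sketch of action selection for a diagonal Gaussian policy. It is an illustration only: `select_action` is a hypothetical function, the clamping bounds are taken from `min_log_std` and `max_log_std` in the printed config, and whether `PolicyNetwork` applies them in exactly this way is an assumption.

```python
import math
import random


def select_action(mean, log_std, deterministic, rng, min_log_std=-6.0, max_log_std=2.0):
    """Pick an action from a diagonal Gaussian policy.

    deterministic=True returns the mean; otherwise each component is
    sampled as mean + std * standard normal noise.
    """
    if deterministic:
        return list(mean)
    action = []
    for mu, ls in zip(mean, log_std):
        ls = max(min_log_std, min(max_log_std, ls))  # clamp log-std to the assumed bounds
        action.append(mu + math.exp(ls) * rng.gauss(0.0, 1.0))
    return action


rng = random.Random(42)
mean, log_std = [0.1, -0.3], [-1.0, -2.0]
print(select_action(mean, log_std, deterministic=True, rng=rng))   # always the mean
print(select_action(mean, log_std, deterministic=False, rng=rng))  # depends on the RNG state
```

With `deterministic=False`, repeated calls explore around the mean, which is why the trajectories in the verbose output differ from sample to sample.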

In [8]:
output = sample_params(policy, kinetic_env, obs_mean, obs_var, N=10, max_steps=50, verbose=True, deterministic=False)

Sampling parameters:   0%|          | 0/10 [00:00<?, ?it/s]

[0]: max_eig: 1286.6741302019732, reward: 0.0010000020611536182
[1]: max_eig: 95.74345568879855, reward: 0.0010000020611536182
[2]: max_eig: 0.3227273018170841, reward: 0.05710832020412048
[3]: max_eig: 561.2529998133442, reward: 0.0010000020611536182
[4]: max_eig: 0.5135783828163806, reward: 0.04781620312404489
[5]: max_eig: 2.372957367329585, reward: 0.00859261697165222
[6]: max_eig: 1.29042514316338, reward: 0.023087137494391622
[7]: max_eig: -0.42674525248723333, reward: 0.11272362424820181
[8]: max_eig: -1.7536339531521101, reward: 0.32261363606294313
[9]: max_eig: -1.5562585935990072, reward: 0.28114521581425106
[10]: max_eig: -0.7931721452601949, reward: 0.15457561039892273
[11]: max_eig: -0.6250504368538443, reward: 0.13397005490560865
[12]: max_eig: -1.0680156674880774, reward: 0.1937896905213515
[13]: max_eig: -1.4083786183658895, reward: 0.25231308463418906
[14]: max_eig: -0.9241762924918994, reward: 0.17238776031509637
[15]: max_eig: -1.9759917778139229, reward: 0.372915455

Sampling parameters:  10%|█         | 1/10 [00:05<00:46,  5.18s/it]

[17]: max_eig: -2.5764463436210723, reward: 0.5201022839298046
[0]: max_eig: 104.73993219560546, reward: 0.0010000020611536182
[1]: max_eig: 8.28913805436602, reward: 0.0010206218641927334
[2]: max_eig: -0.05698509111369617, reward: 0.0809508582955542
[3]: max_eig: -0.05698548808375075, reward: 0.08095088749616759
[4]: max_eig: -0.0569063881561251, reward: 0.08094506919823431
[5]: max_eig: -0.10988675134866159, reward: 0.08492972430517984
[6]: max_eig: 0.08022402805200408, reward: 0.07142206397850658
[7]: max_eig: 0.22850302643608403, reward: 0.062312261684060004
[8]: max_eig: -0.8859654609680433, reward: 0.16702922412334503
[9]: max_eig: -1.4795104594287287, reward: 0.26593205493403005
[10]: max_eig: -1.1724119606509529, reward: 0.21055861166164774
[11]: max_eig: -1.4429061913577446, reward: 0.2588652232368188
[12]: max_eig: -1.3066617382995593, reward: 0.2336624220419001
[13]: max_eig: -1.7776791895921327, reward: 0.3278821295290538
[14]: max_eig: -1.0318572182998917, reward: 0.18822

Sampling parameters:  20%|██        | 2/10 [00:19<01:23, 10.38s/it]

[49]: max_eig: -0.1117540422157276, reward: 0.08507340353122111
[0]: max_eig: 24.902780603217614, reward: 0.0010000020611536182
[1]: max_eig: 25.96396093989543, reward: 0.0010000020611536182
[2]: max_eig: -0.017174473020853943, reward: 0.07807097912765915
[3]: max_eig: 108.87859708447749, reward: 0.0010000020611536182
[4]: max_eig: 104.46909302798099, reward: 0.0010000020611536182
[5]: max_eig: 37.22655953315994, reward: 0.0010000020611536182
[6]: max_eig: -0.3859783926754971, reward: 0.10874144634968114
[7]: max_eig: -0.39001243148503506, reward: 0.1091298656212032
[8]: max_eig: -0.3983020562967804, reward: 0.10993189956193132
[9]: max_eig: -0.385907674934719, reward: 0.10873464821362845
[10]: max_eig: -0.3941464632476948, reward: 0.10952918887674926
[11]: max_eig: -0.40403166272744734, reward: 0.11048929546808745
[12]: max_eig: -0.4084289666179176, reward: 0.11091877551672447
[13]: max_eig: -0.4048231615432571, reward: 0.11056649155894595
[14]: max_eig: -0.4034233572178502, reward: 0

Sampling parameters:  30%|███       | 3/10 [00:27<01:06,  9.52s/it]

[28]: max_eig: -2.6132289326601015, reward: 0.5292770284621189
[0]: max_eig: -0.41309100271096183, reward: 0.11137572362146687
[1]: max_eig: -0.8213407714253336, reward: 0.1582730907577058
[2]: max_eig: -0.932514016754161, reward: 0.17357508008785194
[3]: max_eig: -1.0188422863185642, reward: 0.18625261716162525
[4]: max_eig: 2341.4636756445, reward: 0.0010000020611536182
[5]: max_eig: 5262.883266092898, reward: 0.0010000020611536182
[6]: max_eig: 1671.7976180900112, reward: 0.0010000020611536182
[7]: max_eig: 140.44379076104028, reward: 0.0010000020611536182
[8]: max_eig: 2052.197019832958, reward: 0.0010000020611536182
[9]: max_eig: 4163.999596455726, reward: 0.0010000020611536182
[10]: max_eig: 3578.693311262437, reward: 0.0010000020611536182
[11]: max_eig: 4070.712393954, reward: 0.0010000020611536182
[12]: max_eig: 4494.473208419784, reward: 0.0010000020611536182
[13]: max_eig: 4228.227864401281, reward: 0.0010000020611536182
[14]: max_eig: 3305.1483019497837, reward: 0.0010000020

Sampling parameters:  40%|████      | 4/10 [00:42<01:09, 11.56s/it]

[49]: max_eig: -0.00198311135686942, reward: 0.07699732048597448
[0]: max_eig: 1165.766198790649, reward: 0.0010000020611536182
[1]: max_eig: 4418.858207949433, reward: 0.0010000020611536182
[2]: max_eig: -0.29607683645054106, reward: 0.10039874052444255
[3]: max_eig: 3661.1165813433554, reward: 0.0010000020611536182
[4]: max_eig: 3733.9048949983817, reward: 0.0010000020611536182
[5]: max_eig: 7099.644083833706, reward: 0.0010000020611536182
[6]: max_eig: 63.27742456877301, reward: 0.0010000020611536182
[7]: max_eig: 460.66865745377385, reward: 0.0010000020611536182
[8]: max_eig: 292.3579618427961, reward: 0.0010000020611536182
[9]: max_eig: 174.2998669967932, reward: 0.0010000020611536182
[10]: max_eig: 70.58529930352222, reward: 0.0010000020611536182
[11]: max_eig: 172.63598721773135, reward: 0.0010000020611536182
[12]: max_eig: 27.0341051906261, reward: 0.0010000020611536182
[13]: max_eig: 4.821955566645194, reward: 0.001660432115356872
[14]: max_eig: 1.1971769646869344, reward: 0.0

Sampling parameters:  50%|█████     | 5/10 [00:48<00:48,  9.64s/it]

[21]: max_eig: -3.327595675591532, reward: 0.6968463081949866
[0]: max_eig: 48.030829126980116, reward: 0.0010000020611536182
[1]: max_eig: 22.090590182716305, reward: 0.0010000020611536182
[2]: max_eig: 13.914698608071348, reward: 0.0010000743339036021
[3]: max_eig: 0.4274019483434705, reward: 0.05181549133674975
[4]: max_eig: -0.2432744532600435, reward: 0.09577090977837838
[5]: max_eig: -0.18996054100142684, reward: 0.09129490341996166
[6]: max_eig: -2.2855782392808384, reward: 0.4475990033068087
[7]: max_eig: -2.351575128538629, reward: 0.4639617530527744
[8]: max_eig: -1.368430962369769, reward: 0.24487165681680642


Sampling parameters:  60%|██████    | 6/10 [00:51<00:29,  7.34s/it]

[9]: max_eig: -2.648794375067219, reward: 0.5381301146935885
[0]: max_eig: -0.00255258190821155, reward: 0.07703731934892703
[1]: max_eig: -0.0026282355726220718, reward: 0.07704263461642931
[2]: max_eig: -0.002516992589601611, reward: 0.07703481903610455
[3]: max_eig: -0.022618194727796668, reward: 0.07845908932285446
[4]: max_eig: -0.002509652369798259, reward: 0.07703430336139143
[5]: max_eig: -0.002994046612886514, reward: 0.07706834054224616
[6]: max_eig: -0.041960580566238956, reward: 0.07985262619070847
[7]: max_eig: -1.3829297049119704, reward: 0.24755511801452773
[8]: max_eig: -1.529827723719683, reward: 0.275846165225002
[9]: max_eig: -1.4942726604381331, reward: 0.2688168493439073
[10]: max_eig: -1.3764167581392424, reward: 0.24634723404286968
[11]: max_eig: -1.3730170278994915, reward: 0.24571831239075484
[12]: max_eig: -1.4198717870930178, reward: 0.25448175432681025
[13]: max_eig: -2.0582320203159465, reward: 0.39231977581176886


Sampling parameters:  70%|███████   | 7/10 [00:55<00:18,  6.33s/it]

[14]: max_eig: -3.3067351248347485, reward: 0.692413343171276
[0]: max_eig: -0.2805071757824291, reward: 0.09901363302651202
[1]: max_eig: -0.20637271808006605, reward: 0.09265212379323885
[2]: max_eig: -0.6128400120184175, reward: 0.13256862504812034


Sampling parameters:  80%|████████  | 8/10 [00:56<00:09,  4.67s/it]

[3]: max_eig: -3.549225856265859, reward: 0.7416262146300213
[0]: max_eig: 1.417503267618468, reward: 0.020502770882714995
[1]: max_eig: -0.0025270625928838384, reward: 0.07703552649235558
[2]: max_eig: -0.0025507910454527446, reward: 0.07703719353080359
[3]: max_eig: -0.002431319373127758, reward: 0.07702880040923589
[4]: max_eig: -0.024600561152609213, reward: 0.07860086631001854
[5]: max_eig: -0.04821003909004105, reward: 0.08030775130233202
[6]: max_eig: -0.04780506606970562, reward: 0.08027818600683453
[7]: max_eig: -0.004160558442574374, reward: 0.07715036582007465
[8]: max_eig: -0.0023662876943378484, reward: 0.07702423216238156
[9]: max_eig: -0.0026200286309514377, reward: 0.07704205799746361
[10]: max_eig: -0.002423660834457003, reward: 0.077028262410727
[11]: max_eig: -0.003708570242020224, reward: 0.07711857386887086
[12]: max_eig: -0.0022665830225098564, reward: 0.07701722874880361
[13]: max_eig: -0.002122954632048134, reward: 0.07700714110466082
[14]: max_eig: -0.002214635

Sampling parameters:  90%|█████████ | 9/10 [01:09<00:07,  7.26s/it]

[45]: max_eig: -2.945400849173856, reward: 0.610545189860243
[0]: max_eig: 21143.314474941577, reward: 0.0010000020611536182
[1]: max_eig: 35990.39248110202, reward: 0.0010000020611536182
[2]: max_eig: 30119.525588306155, reward: 0.0010000020611536182
[3]: max_eig: 18481.155339424768, reward: 0.0010000020611536182
[4]: max_eig: 13222.227469907215, reward: 0.0010000020611536182
[5]: max_eig: 24815.9178398235, reward: 0.0010000020611536182
[6]: max_eig: 4210.259715255856, reward: 0.0010000020611536182
[7]: max_eig: 1420.7297689468624, reward: 0.0010000020611536182
[8]: max_eig: -0.004253984916593236, reward: 0.07715693877177968
[9]: max_eig: -0.004309461631842164, reward: 0.0771608420421107
[10]: max_eig: -0.04309246969033371, reward: 0.07993488003317908
[11]: max_eig: -1.9485145069668837, reward: 0.3665198319744743
[12]: max_eig: -0.08461270625419225, reward: 0.08300683905498406
[13]: max_eig: 60.05931950903243, reward: 0.0010000020611536182
[14]: max_eig: 2.417289035452572, reward: 0.0

Sampling parameters: 100%|██████████| 10/10 [01:18<00:00,  7.87s/it]

[30]: max_eig: -2.6612128986980452, reward: 0.5412161622196827
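
If you want to post-process the verbose output, each complete line follows the pattern `[step]: max_eig: <float>, reward: <float>`. The sketch below parses one such line; `parse_step` is a hypothetical helper. Lines cut off before the reward field return `None`, but a line truncated in the middle of a number cannot be told apart from a complete one.

```python
import re

FLOAT = r"[-+]?\d+(?:\.\d*)?(?:[eE][-+]?\d+)?"
LINE_RE = re.compile(rf"\[(\d+)\]: max_eig: ({FLOAT}), reward: ({FLOAT})")


def parse_step(line):
    """Return (step, max_eig, reward) for a complete log line, else None."""
    match = LINE_RE.fullmatch(line.strip())
    if match is None:
        return None
    return int(match.group(1)), float(match.group(2)), float(match.group(3))


print(parse_step("[30]: max_eig: -2.6612128986980452, reward: 0.5412161622196827"))
print(parse_step("Sampling parameters: 100%|██████████| 10/10 [01:18<00:00,  7.87s/it]"))
```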





Now, look at the generated kinetic models or inspect the rewards.

In [11]:
final_states, final_rewards, final_max_eigs, steps_to_valid_solution = output

For instance, compute the IR, here the fraction of samples whose final reward exceeds 0.5:

In [12]:
ir = sum([1 for r in final_rewards if r > 0.5]) / len(final_rewards)

print(f"IR: {ir} and average number of steps to valid solution: {steps_to_valid_solution}")

IR: 0.8 and average number of steps to valid solution: 21.875
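
As a sanity check, the same numbers can be recomputed by hand from the last logged step of each sample in the verbose output above. The sketch below uses those values rounded to four decimals. Two points are assumptions inferred from the log rather than confirmed by the code: that a reward above 0.5 corresponds to `max_eig` below `reward.eig_partition` (-2.5), and that the step count of a sample is its last zero-based step index plus one.

```python
# Final (max_eig, reward) pairs of the 10 samples, copied from the log above (rounded).
finals = [
    (-2.5764, 0.5201), (-0.1118, 0.0851), (-2.6132, 0.5293), (-0.0020, 0.0770),
    (-3.3276, 0.6968), (-2.6488, 0.5381), (-3.3067, 0.6924), (-3.5492, 0.7416),
    (-2.9454, 0.6105), (-2.6612, 0.5412),
]
# Zero-based index of the last logged step of each sample, also from the log.
last_steps = [17, 49, 28, 49, 21, 9, 14, 3, 45, 30]

eig_partition = -2.5  # reward.eig_partition in the printed config

ir_from_reward = sum(r > 0.5 for _, r in finals) / len(finals)
ir_from_eig = sum(e < eig_partition for e, _ in finals) / len(finals)

# Assumption: a sample's step count is its last step index plus one.
valid_steps = [s + 1 for s, (_, r) in zip(last_steps, finals) if r > 0.5]
mean_steps = sum(valid_steps) / len(valid_steps)

print(ir_from_reward, ir_from_eig, mean_steps)  # 0.8 0.8 21.875
```

Both criteria select the same 8 of 10 samples, and the mean step count over those samples matches the 21.875 reported above.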
