In [2]:
# Reload edited project modules (e.g. helpers.*) before each cell executes,
# so changes to the helper code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [3]:
import os

# Move to the repository root so `helpers.*` imports and the config's relative
# data paths resolve. Guarded so re-running this cell does not keep climbing
# the directory tree (the unguarded chdir('..') moved up one level per run).
if not os.path.isdir("helpers"):
    os.chdir("..")

In [4]:
import os

import wandb
from omegaconf import OmegaConf

### Download the best models from W&B to local storage

---

In [4]:
# Local destination for checkpoints pulled from W&B. Set RL_KINETICS_CKPT_DIR to
# override the default when running as a different user or on another machine.
CKPT_DIR = os.environ.get(
    "RL_KINETICS_CKPT_DIR", "/scratch/izar/cizinsky/rl-for-kinetics/best_models"
)

In [5]:
# Create the checkpoint directory (and any missing parents) if needed. Pure
# Python avoids shell-quoting problems with the unquoted `!mkdir -p {CKPT_DIR}`.
os.makedirs(CKPT_DIR, exist_ok=True)

#### Get an overview of the relevant runs

In [6]:
api = wandb.Api()

# Every run in the project; candidates are filtered by tag below.
runs = api.runs("ludekcizinsky/rl-renaissance")

# Keep only the runs carrying the "debug" tag.
tagged_runs = []
for candidate in runs:
    if "debug" in candidate.tags:
        tagged_runs.append(candidate)

# One line per run: id | name | tags.
for run in tagged_runs:
    print(f"{run.id} | {run.name} | tags: {run.tags}")

mttrmd5v | floral-water-246 | tags: ['debug', 'part2', 'working-version']
fef6a8ch | electric-frog-313 | tags: ['best_setup', 'debug']
zj7x0j8r | denim-glade-315 | tags: ['best_setup', 'clip_eps_decay', 'debug']
x3ovrtnm | wise-universe-316 | tags: ['best_setup', 'clip_eps_decay', 'debug', 'reproduce']
q40t6n35 | avid-darkness-317 | tags: ['debug', 'dev']
pzzyel44 | honest-capybara-318 | tags: ['debug', 'dev']


#### Download the best model for each selected run

In [7]:
run_ids = [run.id for run in tagged_runs]

# Runs whose checkpoints should be fetched. Replace with `run_ids` to pull
# every "debug"-tagged run listed above.
runs_to_download = ["x3ovrtnm"]

for run_id in runs_to_download:
    print(f"Downloading checkpoint for run {run_id}...")
    run = next((r for r in runs if r.id == run_id), None)
    assert run is not None, f"Run {run_id} not found!"

    # The model artifact is named after the run; only version v0 is fetched.
    artifact_path = f"ludekcizinsky/rl-renaissance/{run.name}:v0"
    print(artifact_path)
    artifact = api.artifact(artifact_path, type="model")

    download_path = os.path.join(CKPT_DIR, run.name)
    os.makedirs(download_path, exist_ok=True)

    # Save the run's config next to the checkpoints so inference can rebuild
    # the same setup from the checkpoint directory alone.
    run_cfg = OmegaConf.create(run.config)
    OmegaConf.save(run_cfg, os.path.join(download_path, "config.yaml"))

    # Checkpoints
    artifact.download(download_path)
    print(f"Downloaded checkpoint to {download_path}.")

Downloading checkpoint for run x3ovrtnm...
ludekcizinsky/rl-renaissance/wise-universe-316:v0


[34m[1mwandb[0m:   5 of 5 files downloaded.  


Downloaded checkpoint to /scratch/izar/cizinsky/rl-for-kinetics/best_models/wise-universe-316.


### Inference

---

```bash
apptainer shell --nv --bind "$(pwd)":/home/renaissance/work --bind "/scratch/izar/$USER/rl-for-kinetics/output:/home/renaissance/output" --bind "/scratch/izar/$USER/rl-for-kinetics/best_models:/home/renaissance/best_models" /scratch/izar/$USER/images/renaissance_with_ml.sif
```

```bash
jupyter notebook --ip=0.0.0.0 --port=8888 --no-browser
```

In [1]:
# List checkpoint directories available inside the container (bind-mounted
# to /home/renaissance/best_models by the apptainer command above).
!ls /home/renaissance/best_models

celestial-dragon-134  hearty-pine-137	  sweet-eon-143
chocolate-sponge-135  olive-forest-140	  upbeat-thunder-139
comfy-terrain-128     pious-pyramid-211   vibrant-oath-142
crisp-sunset-141      resilient-oath-127  wise-universe-316
floral-water-246      sandy-blaze-130	  zesty-bird-206
fragrant-plasma-133   stilted-lake-138


In [7]:
# Inspect the files stored for the selected run (weights, normalisation stats, config).
!ls /home/renaissance/best_models/wise-universe-316

best_setup_e86_s41.pt  first_valid_setup_e11_s1.pt  policy.pt
config.yaml	       normalisation.pt		    value.pt


In [13]:
import torch
import numpy as np
from helpers.utils import setup_kinetic_env, log_final_eval_metrics, sample_params
from helpers.ppo_agent import PolicyNetwork

In [14]:
# Checkpoint directory of the model to evaluate, and the files inside it.
selected_model = "wise-universe-316"
ckpt_dir = os.path.join("/home/renaissance/best_models", selected_model)
policy_path, config_path, normalisation_path = (
    os.path.join(ckpt_dir, fname)
    for fname in ("policy.pt", "config.yaml", "normalisation.pt")
)

In [10]:
# Load the training-run config saved alongside the checkpoint.
cfg = OmegaConf.load(config_path)

In [11]:
# Build the environment from the run's config via the project helper.
# NOTE(review): the recorded output of this cell ends in ForkPoolWorker
# tracebacks; the cell may have been interrupted, so re-run it on a clean kernel.
kinetic_env = setup_kinetic_env(cfg)
# Turn off the environment's logging for interactive inference.
# NOTE(review): per-sample "max_eig/reward" prints still appear during
# sampling below, so confirm what this flag actually controls.
kinetic_env.logging_enabled = False

--------------------------------------------------
env:
  p_size: 384
  action_scale: 1
seed: 42
paths:
  names_km: data/varma_ecoli_shikki/parameter_names_km_fdp1.pkl
  output_dir: /home/renaissance/output
  met_model_name: varma_ecoli_shikki
device: cpu
logger:
  tags:
  - debug
  - best_setup
  - clip_eps_decay
  - reproduce
  entity: ludekcizinsky
  project: rl-renaissance
method:
  name: ppo_refinement
  actor_lr: 0.0001
  critic_lr: 0.001
  gae_lambda: 0.98
  max_log_std: 2
  min_log_std: -6
  clip_eps_end: 0.1
  parameter_dim: 384
  clip_eps_start: 0.3
  discount_factor: 0.99
  value_loss_weight: 0.5
  entropy_loss_weight: 0.01
reward:
  eig_partition: -2.5
training:
  batch_size: 25
  num_epochs: 10
  num_episodes: 100
  max_grad_norm: 0.5
  save_trained_models: true
  max_steps_per_episode: 50
  n_eval_samples_in_episode: 50
launch_cmd: train.py logger.tags=[debug, best_setup, clip_eps_decay, reproduce]
constraints:
  max_km: 3
  min_km: -25
  ss_idx: 1712
lr_scheduler:
  name

Process ForkPoolWorker-31:
Process ForkPoolWorker-28:
Process ForkPoolWorker-24:
Process ForkPoolWorker-2:
Process ForkPoolWorker-19:
Process ForkPoolWorker-25:
Process ForkPoolWorker-27:
Process ForkPoolWorker-6:
Process ForkPoolWorker-23:
Process ForkPoolWorker-11:
Process ForkPoolWorker-30:
Process ForkPoolWorker-9:
Process ForkPoolWorker-14:
Process ForkPoolWorker-1:
Process ForkPoolWorker-20:
Process ForkPoolWorker-29:
Traceback (most recent call last):
Traceback (most recent call last):
Process ForkPoolWorker-26:
Process ForkPoolWorker-8:
Process ForkPoolWorker-22:
Process ForkPoolWorker-17:
Process ForkPoolWorker-16:
Traceback (most recent call last):
Process ForkPoolWorker-13:
Process ForkPoolWorker-12:
Process ForkPoolWorker-32:
Process ForkPoolWorker-15:
Process ForkPoolWorker-4:
Process ForkPoolWorker-3:
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last)

In [12]:
# Inference runs on CPU: build a fresh policy network from the config,
# move it to the device, then restore the trained weights into it.
device = "cpu"
policy = PolicyNetwork(cfg)
policy = policy.to(device)
policy.load_pretrained_policy_net(policy_path)

FYI: Loaded pretrained policy network from /home/renaissance/best_models/wise-universe-316/policy.pt.


In [16]:
# Observation-normalisation statistics saved alongside the policy checkpoint.
# map_location remaps any stored tensors onto the inference device, so the load
# also works on CPU-only nodes when the stats were saved from GPU tensors.
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
normalisation = torch.load(normalisation_path, map_location=device)
obs_mean, obs_var = normalisation["obs_mean"], normalisation["obs_var"]

In [14]:
api = wandb.Api()

# W&B run that produced `selected_model` (wise-universe-316 is run x3ovrtnm).
# The evaluation metrics logged below go into this run's summary, so it must be
# the run the checkpoint came from. The previous id (mttrmd5v, floral-water-246)
# would have written wise-universe-316's metrics into an unrelated run.
selected_run_id = "x3ovrtnm"
run = api.run(f"ludekcizinsky/rl-renaissance/{selected_run_id}")
assert run.name == selected_model, (
    f"Run {selected_run_id} is '{run.name}', expected '{selected_model}'"
)

In [17]:
# Seed the RNGs from the run config so the sampled parameter sets are reproducible.
torch.manual_seed(cfg.seed)
np.random.seed(cfg.seed)

# Sample N parameter sets with the trained policy. Episodes are capped at the
# same step budget as the training config (max_steps_per_episode).
N_SAMPLES = 10
output = sample_params(
    policy,
    kinetic_env,
    obs_mean,
    obs_var,
    N=N_SAMPLES,
    max_steps=cfg.training.max_steps_per_episode,
)

Sampling parameters:   0%|          | 0/10 [00:00<?, ?it/s]

[0]: max_eig: 1286.6741302019732, reward: 0.0010000020611536182
[1]: max_eig: 95.74345568879855, reward: 0.0010000020611536182
[2]: max_eig: 0.3227273018170841, reward: 0.05710832020412048
[3]: max_eig: 561.2529998133442, reward: 0.0010000020611536182
[4]: max_eig: 0.5135783828163806, reward: 0.04781620312404489
[5]: max_eig: 2.372957367329585, reward: 0.00859261697165222
[6]: max_eig: 1.29042514316338, reward: 0.023087137494391622
[7]: max_eig: -0.42674525248723333, reward: 0.11272362424820181
[8]: max_eig: -1.7536339531521101, reward: 0.32261363606294313
[9]: max_eig: -1.5562585935990072, reward: 0.28114521581425106
[10]: max_eig: -0.7931721452601949, reward: 0.15457561039892273
[11]: max_eig: -0.6250504368538443, reward: 0.13397005490560865
[12]: max_eig: -1.0680156674880774, reward: 0.1937896905213515
[13]: max_eig: -1.4083786183658895, reward: 0.25231308463418906
[14]: max_eig: -0.9241762924918994, reward: 0.17238776031509637
[15]: max_eig: -1.9759917778139229, reward: 0.372915455

Sampling parameters:  10%|█         | 1/10 [00:05<00:51,  5.71s/it]

[17]: max_eig: -2.5764463436210723, reward: 0.5201022839298046
[0]: max_eig: 104.73993219560546, reward: 0.0010000020611536182
[1]: max_eig: 8.28913805436602, reward: 0.0010206218641927334
[2]: max_eig: -0.05698509111369617, reward: 0.0809508582955542
[3]: max_eig: -0.05698548808375075, reward: 0.08095088749616759
[4]: max_eig: -0.0569063881561251, reward: 0.08094506919823431
[5]: max_eig: -0.10988675134866159, reward: 0.08492972430517984
[6]: max_eig: 0.08022402805200408, reward: 0.07142206397850658
[7]: max_eig: 0.22850302643608403, reward: 0.062312261684060004
[8]: max_eig: -0.8859654609680433, reward: 0.16702922412334503
[9]: max_eig: -1.4795104594287287, reward: 0.26593205493403005
[10]: max_eig: -1.1724119606509529, reward: 0.21055861166164774
[11]: max_eig: -1.4429061913577446, reward: 0.2588652232368188
[12]: max_eig: -1.3066617382995593, reward: 0.2336624220419001
[13]: max_eig: -1.7776791895921327, reward: 0.3278821295290538
[14]: max_eig: -1.0318572182998917, reward: 0.18822

Sampling parameters:  20%|██        | 2/10 [00:22<01:36, 12.03s/it]

[49]: max_eig: -0.1117540422157276, reward: 0.08507340353122111
[0]: max_eig: 24.902780603217614, reward: 0.0010000020611536182
[1]: max_eig: 25.96396093989543, reward: 0.0010000020611536182
[2]: max_eig: -0.017174473020853943, reward: 0.07807097912765915
[3]: max_eig: 108.87859708447749, reward: 0.0010000020611536182
[4]: max_eig: 104.46909302798099, reward: 0.0010000020611536182
[5]: max_eig: 37.22655953315994, reward: 0.0010000020611536182
[6]: max_eig: -0.3859783926754971, reward: 0.10874144634968114
[7]: max_eig: -0.39001243148503506, reward: 0.1091298656212032
[8]: max_eig: -0.3983020562967804, reward: 0.10993189956193132
[9]: max_eig: -0.385907674934719, reward: 0.10873464821362845
[10]: max_eig: -0.3941464632476948, reward: 0.10952918887674926
[11]: max_eig: -0.40403166272744734, reward: 0.11048929546808745
[12]: max_eig: -0.4084289666179176, reward: 0.11091877551672447
[13]: max_eig: -0.4048231615432571, reward: 0.11056649155894595
[14]: max_eig: -0.4034233572178502, reward: 0

Sampling parameters:  30%|███       | 3/10 [00:31<01:15, 10.79s/it]

[28]: max_eig: -2.6132289326601015, reward: 0.5292770284621189
[0]: max_eig: -0.41309100271096183, reward: 0.11137572362146687
[1]: max_eig: -0.8213407714253336, reward: 0.1582730907577058
[2]: max_eig: -0.932514016754161, reward: 0.17357508008785194
[3]: max_eig: -1.0188422863185642, reward: 0.18625261716162525
[4]: max_eig: 2341.4636756445, reward: 0.0010000020611536182
[5]: max_eig: 5262.883266092898, reward: 0.0010000020611536182
[6]: max_eig: 1671.7976180900112, reward: 0.0010000020611536182
[7]: max_eig: 140.44379076104028, reward: 0.0010000020611536182
[8]: max_eig: 2052.197019832958, reward: 0.0010000020611536182
[9]: max_eig: 4163.999596455726, reward: 0.0010000020611536182
[10]: max_eig: 3578.693311262437, reward: 0.0010000020611536182
[11]: max_eig: 4070.712393954, reward: 0.0010000020611536182
[12]: max_eig: 4494.473208419784, reward: 0.0010000020611536182
[13]: max_eig: 4228.227864401281, reward: 0.0010000020611536182
[14]: max_eig: 3305.1483019497837, reward: 0.0010000020

Sampling parameters:  40%|████      | 4/10 [00:47<01:17, 12.84s/it]

[49]: max_eig: -0.00198311135686942, reward: 0.07699732048597448
[0]: max_eig: 1165.766198790649, reward: 0.0010000020611536182
[1]: max_eig: 4418.858207949433, reward: 0.0010000020611536182
[2]: max_eig: -0.29607683645054106, reward: 0.10039874052444255
[3]: max_eig: 3661.1165813433554, reward: 0.0010000020611536182
[4]: max_eig: 3733.9048949983817, reward: 0.0010000020611536182
[5]: max_eig: 7099.644083833706, reward: 0.0010000020611536182
[6]: max_eig: 63.27742456877301, reward: 0.0010000020611536182
[7]: max_eig: 460.66865745377385, reward: 0.0010000020611536182
[8]: max_eig: 292.3579618427961, reward: 0.0010000020611536182
[9]: max_eig: 174.2998669967932, reward: 0.0010000020611536182
[10]: max_eig: 70.58529930352222, reward: 0.0010000020611536182
[11]: max_eig: 172.63598721773135, reward: 0.0010000020611536182
[12]: max_eig: 27.0341051906261, reward: 0.0010000020611536182
[13]: max_eig: 4.821955566645194, reward: 0.001660432115356872
[14]: max_eig: 1.1971769646869344, reward: 0.0

Sampling parameters:  50%|█████     | 5/10 [00:54<00:53, 10.75s/it]

[21]: max_eig: -3.327595675591532, reward: 0.6968463081949866
[0]: max_eig: 48.030829126980116, reward: 0.0010000020611536182
[1]: max_eig: 22.090590182716305, reward: 0.0010000020611536182
[2]: max_eig: 13.914698608071348, reward: 0.0010000743339036021
[3]: max_eig: 0.4274019483434705, reward: 0.05181549133674975
[4]: max_eig: -0.2432744532600435, reward: 0.09577090977837838
[5]: max_eig: -0.18996054100142684, reward: 0.09129490341996166
[6]: max_eig: -2.2855782392808384, reward: 0.4475990033068087
[7]: max_eig: -2.351575128538629, reward: 0.4639617530527744
[8]: max_eig: -1.368430962369769, reward: 0.24487165681680642


Sampling parameters:  60%|██████    | 6/10 [00:57<00:32,  8.19s/it]

[9]: max_eig: -2.648794375067219, reward: 0.5381301146935885
[0]: max_eig: -0.00255258190821155, reward: 0.07703731934892703
[1]: max_eig: -0.0026282355726220718, reward: 0.07704263461642931
[2]: max_eig: -0.002516992589601611, reward: 0.07703481903610455
[3]: max_eig: -0.022618194727796668, reward: 0.07845908932285446
[4]: max_eig: -0.002509652369798259, reward: 0.07703430336139143
[5]: max_eig: -0.002994046612886514, reward: 0.07706834054224616
[6]: max_eig: -0.041960580566238956, reward: 0.07985262619070847
[7]: max_eig: -1.3829297049119704, reward: 0.24755511801452773
[8]: max_eig: -1.529827723719683, reward: 0.275846165225002
[9]: max_eig: -1.4942726604381331, reward: 0.2688168493439073
[10]: max_eig: -1.3764167581392424, reward: 0.24634723404286968
[11]: max_eig: -1.3730170278994915, reward: 0.24571831239075484
[12]: max_eig: -1.4198717870930178, reward: 0.25448175432681025
[13]: max_eig: -2.0582320203159465, reward: 0.39231977581176886


Sampling parameters:  70%|███████   | 7/10 [01:02<00:21,  7.08s/it]

[14]: max_eig: -3.3067351248347485, reward: 0.692413343171276
[0]: max_eig: -0.2805071757824291, reward: 0.09901363302651202
[1]: max_eig: -0.20637271808006605, reward: 0.09265212379323885
[2]: max_eig: -0.6128400120184175, reward: 0.13256862504812034


Sampling parameters:  80%|████████  | 8/10 [01:03<00:10,  5.24s/it]

[3]: max_eig: -3.549225856265859, reward: 0.7416262146300213
[0]: max_eig: 1.417503267618468, reward: 0.020502770882714995
[1]: max_eig: -0.0025270625928838384, reward: 0.07703552649235558
[2]: max_eig: -0.0025507910454527446, reward: 0.07703719353080359
[3]: max_eig: -0.002431319373127758, reward: 0.07702880040923589
[4]: max_eig: -0.024600561152609213, reward: 0.07860086631001854
[5]: max_eig: -0.04821003909004105, reward: 0.08030775130233202
[6]: max_eig: -0.04780506606970562, reward: 0.08027818600683453
[7]: max_eig: -0.004160558442574374, reward: 0.07715036582007465
[8]: max_eig: -0.0023662876943378484, reward: 0.07702423216238156
[9]: max_eig: -0.0026200286309514377, reward: 0.07704205799746361
[10]: max_eig: -0.002423660834457003, reward: 0.077028262410727
[11]: max_eig: -0.003708570242020224, reward: 0.07711857386887086
[12]: max_eig: -0.0022665830225098564, reward: 0.07701722874880361
[13]: max_eig: -0.002122954632048134, reward: 0.07700714110466082
[14]: max_eig: -0.002214635

Sampling parameters:  90%|█████████ | 9/10 [01:18<00:08,  8.24s/it]

[45]: max_eig: -2.945400849173856, reward: 0.610545189860243
[0]: max_eig: 21143.314474941577, reward: 0.0010000020611536182
[1]: max_eig: 35990.39248110202, reward: 0.0010000020611536182
[2]: max_eig: 30119.525588306155, reward: 0.0010000020611536182
[3]: max_eig: 18481.155339424768, reward: 0.0010000020611536182
[4]: max_eig: 13222.227469907215, reward: 0.0010000020611536182
[5]: max_eig: 24815.9178398235, reward: 0.0010000020611536182
[6]: max_eig: 4210.259715255856, reward: 0.0010000020611536182
[7]: max_eig: 1420.7297689468624, reward: 0.0010000020611536182
[8]: max_eig: -0.004253984916593236, reward: 0.07715693877177968
[9]: max_eig: -0.004309461631842164, reward: 0.0771608420421107
[10]: max_eig: -0.04309246969033371, reward: 0.07993488003317908
[11]: max_eig: -1.9485145069668837, reward: 0.3665198319744743
[12]: max_eig: -0.08461270625419225, reward: 0.08300683905498406
[13]: max_eig: 60.05931950903243, reward: 0.0010000020611536182
[14]: max_eig: 2.417289035452572, reward: 0.0

Sampling parameters: 100%|██████████| 10/10 [01:28<00:00,  8.85s/it]

[30]: max_eig: -2.6612128986980452, reward: 0.5412161622196827





In [18]:
# Split the 4-tuple returned by sample_params into named results.
(
    final_states,
    final_rewards,
    final_max_eigs,
    t_to_valid_solution,
) = output

In [19]:
# Time-to-valid-solution value reported by sample_params.
# NOTE(review): its exact definition/units live in helpers.utils; confirm there.
t_to_valid_solution

21.875

In [20]:
# Sanity check: expect one final state per sample (the N passed to sample_params).
len(final_states)

10

In [None]:
# Inspect dtype, shape and device of one final state before stacking them.
final_states[0].dtype, final_states[0].shape, final_states[0].device

(torch.float32, torch.Size([384]), device(type='cpu'))

In [25]:
# Stack the per-sample states into a new (n_samples, param_dim) tensor instead of
# overwriting the `final_states` list. Overwriting made this cell fail on re-run,
# because torch.stack needs a sequence of tensors. list() accepts either form.
final_states_tensor = torch.stack(list(final_states)).cpu()
# Std of each parameter across samples, averaged over parameters: a quick
# measure of how diverse the sampled parameter sets are.
torch.std(final_states_tensor, dim=0).mean().item()

6.304969310760498

In [26]:
# Final reward of each sampled trajectory.
final_rewards

[0.5201022839298046,
 0.08507340353122111,
 0.5292770284621189,
 0.07699732048597448,
 0.6968463081949866,
 0.5381301146935885,
 0.692413343171276,
 0.7416262146300213,
 0.610545189860243,
 0.5412161622196827]

In [27]:
# Compact tensor view of the rewards. torch.tensor with an explicit dtype
# replaces the legacy torch.Tensor constructor (same float32 result).
torch.tensor(final_rewards, dtype=torch.float32)

tensor([0.5201, 0.0851, 0.5293, 0.0770, 0.6968, 0.5381, 0.6924, 0.7416, 0.6105,
        0.5412])

In [28]:
# Compact tensor view of the final max eigenvalues. torch.tensor with an
# explicit dtype replaces the legacy torch.Tensor constructor (same float32 result).
torch.tensor(final_max_eigs, dtype=torch.float32)

tensor([-2.5764e+00, -1.1175e-01, -2.6132e+00, -1.9831e-03, -3.3276e+00,
        -2.6488e+00, -3.3067e+00, -3.5492e+00, -2.9454e+00, -2.6612e+00])

In [None]:
# Run the final evaluation with N=100 and max_steps=50, handing over the W&B
# summary of `run` (presumably where the helper writes its metrics).
# NOTE(review): unlike sample_params, no obs_mean/obs_var are passed here;
# confirm log_final_eval_metrics applies the same observation normalisation.
log_final_eval_metrics(
    policy,
    kinetic_env,
    N=100,
    max_steps=50,
    wandb_summary=run.summary,
)

In [16]:
# Persist the locally modified run summary back to the W&B server.
run.summary.update()

In [17]:
# NOTE(review): no wandb.init() appears in this notebook; the public-API calls
# above do not start a run, so this call is likely a no-op.
wandb.finish()