In [4]:
# Enable IPython's autoreload extension in mode 2, so that edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [5]:
import os

# Step up from the notebook's folder so that project packages (e.g. `helpers`,
# imported further down) can be resolved from the working directory.
# The guard keeps this cell idempotent: once `helpers/` is visible from the
# current directory, re-running the cell no longer climbs one level higher.
# NOTE(review): assumes `helpers/` lives directly in the repository root — confirm.
if not os.path.isdir("helpers"):
    os.chdir("..")

In [6]:
import os

import wandb
from omegaconf import OmegaConf

### Download the best models from wandb to local storage

---

In [4]:
# Directory that receives the downloaded checkpoints. The default is the
# original scratch location; set the RL_KINETICS_CKPT_DIR environment variable
# to use a different directory without editing the notebook.
CKPT_DIR = os.environ.get(
    "RL_KINETICS_CKPT_DIR",
    "/scratch/izar/cizinsky/rl-for-kinetics/best_models",
)

In [5]:
# Create the checkpoint directory (including missing parents) in pure Python.
# Unlike the unquoted shell interpolation, this is safe for paths that contain
# spaces or shell metacharacters, and it raises if the directory cannot be made.
os.makedirs(CKPT_DIR, exist_ok=True)

#### Get overview of the relevant runs

In [None]:
# Open the W&B public API and fetch the runs of the project.
api = wandb.Api()
runs = api.runs("ludekcizinsky/rl-renaissance")

# Narrow the overview down to the runs that carry the "dev" tag.
tagged_runs = list(filter(lambda candidate: "dev" in candidate.tags, runs))

# One line per run: id, name and the full tag list.
for run in tagged_runs:
    print(f"{run.id} | {run.name} | tags: {run.tags}")

#### Download the best model for each run

In [None]:
# IDs of every dev-tagged run (kept available for bulk downloads).
run_ids = [run.id for run in tagged_runs]

# Runs whose checkpoints are actually downloaded by this cell.
# Replace the list with `run_ids` to fetch every dev-tagged run instead.
selected_run_ids = ["mttrmd5v"]

for run_id in selected_run_ids:
    print(f"Downloading checkpoint for run {run_id}...")

    # Look the run up among all project runs (not only the dev-tagged ones).
    run = next((candidate for candidate in runs if candidate.id == run_id), None)
    assert run is not None, "Run not found!"

    # The model artifact is addressed by the run's name.
    # NOTE(review): version `v0` is always requested — confirm that this is the
    # intended ("best") artifact version.
    artifact_path = f"ludekcizinsky/rl-renaissance/{run.name}:v0"
    print(artifact_path)
    artifact = api.artifact(artifact_path, type="model")

    # Each run gets its own sub-folder inside the checkpoint directory.
    download_path = f"{CKPT_DIR}/{run.name}"
    os.makedirs(download_path, exist_ok=True)

    # Config: store the run's configuration next to the checkpoints.
    run_cfg = OmegaConf.create(run.config)
    OmegaConf.save(run_cfg, f"{download_path}/config.yaml")

    # Checkpoints: download the artifact's files into the same folder.
    artifact.download(download_path)
    print(f"Downloaded checkpoint to {download_path}.")

### Inference

---

```bash
apptainer shell --nv --bind "$(pwd)":/home/renaissance/work --bind "/scratch/izar/$USER/rl-for-kinetics/output:/home/renaissance/output" --bind "/scratch/izar/$USER/rl-for-kinetics/best_models:/home/renaissance/best_models" /scratch/izar/$USER/images/renaissance_with_ml.sif
```

```bash
jupyter notebook --ip=0.0.0.0 --port=8888 --no-browser
```

In [7]:
# List the model folders available under the best_models directory
# (the path that the apptainer command above bind-mounts into the container).
!ls /home/renaissance/best_models

celestial-dragon-134  hearty-pine-137	  sweet-eon-143
chocolate-sponge-135  olive-forest-140	  upbeat-thunder-139
comfy-terrain-128     pious-pyramid-211   vibrant-oath-142
crisp-sunset-141      resilient-oath-127  zesty-bird-206
floral-water-246      sandy-blaze-130
fragrant-plasma-133   stilted-lake-138


In [8]:
# Inspect the files stored for one specific model folder.
!ls /home/renaissance/best_models/floral-water-246

best_setup_e95_s48.pt  first_valid_setup_e7_s5.pt  value.pt
config.yaml	       policy.pt


In [9]:
import torch
import numpy as np
from helpers.utils import setup_kinetic_env, log_final_eval_metrics
from helpers.ppo_agent import PolicyNetwork

In [10]:
# Name of the model folder to evaluate; every path below is derived from it.
selected_model = "floral-water-246"

# Folder of the selected model, followed by the two files read from it.
ckpt_dir = "/".join(["/home/renaissance/best_models", selected_model])
policy_path = "/".join([ckpt_dir, "policy.pt"])
config_path = "/".join([ckpt_dir, "config.yaml"])

In [11]:
# Load the configuration file stored in the selected model's folder.
cfg = OmegaConf.load(config_path)

In [12]:
# Create the environment from the loaded config via the project helper.
# Per the cell output, this call prints the config and loads kinetic and
# thermodynamic data.
kinetic_env = setup_kinetic_env(cfg)
# Turn the environment's logging flag off for this session.
kinetic_env.logging_enabled = False

--------------------------------------------------
env:
  p_size: 384
  action_scale: 1
seed: 42
paths:
  names_km: data/varma_ecoli_shikki/parameter_names_km_fdp1.pkl
  output_dir: /home/renaissance/output
  met_model_name: varma_ecoli_shikki
device: cpu
logger:
  tags:
  - dev
  entity: ludekcizinsky
  project: rl-renaissance
method:
  name: ppo_refinement
  actor_lr: 0.0001
  clip_eps: 0.2
  critic_lr: 0.001
  gae_lambda: 0.98
  max_log_std: 2
  min_log_std: -6
  parameter_dim: 384
  discount_factor: 0.99
  value_loss_weight: 0.5
  entropy_loss_weight: 0.01
reward:
  eig_partition: -2.5
training:
  batch_size: 25
  num_epochs: 10
  num_episodes: 100
  max_grad_norm: 0.5
  save_trained_models: true
  max_steps_per_episode: 50
  n_eval_samples_in_episode: 50
launch_cmd: train.py logger.tags=[dev]
constraints:
  max_km: 3
  min_km: -25
  ss_idx: 1712
lr_scheduler:
  name: constant

--------------------------------------------------
FYI: Loading kinetic and thermodynamic data.


In [13]:
# The policy is placed on the CPU for this evaluation.
# NOTE(review): the loaded config also carries a `device` entry (printed as
# `cpu` above) — confirm whether the hard-coded value should follow it.
device = "cpu"
# Instantiate the policy network from the config, move it to the device and
# load the weights saved at `policy_path`.
policy = PolicyNetwork(cfg).to(device)
policy.load_pretrained_policy_net(policy_path)

FYI: Loaded pretrained policy network from /home/renaissance/best_models/floral-water-246/policy.pt.


In [14]:
# Fetch a handle to a single W&B run by its id; its summary is passed to the
# evaluation call below.
# NOTE(review): presumably run "mttrmd5v" is the run named "floral-water-246"
# whose checkpoint was loaded above — confirm the two refer to the same run.
api = wandb.Api()
run = api.run("ludekcizinsky/rl-renaissance/mttrmd5v")

In [15]:
# Run the project's final-evaluation helper with the loaded policy and
# environment, handing it the run's summary object.
# NOTE(review): presumably N is the number of sampled parameter sets (the
# progress bar shows 100 iterations) and max_steps the per-rollout step limit,
# with the resulting metrics written into `wandb_summary` — confirm against
# helpers.utils.
log_final_eval_metrics(policy, kinetic_env, N=100, max_steps=50, wandb_summary=run.summary)

Sampling parameters:   0%|          | 0/100 [00:00<?, ?it/s]

Sampling parameters: 100%|██████████| 100/100 [08:55<00:00,  5.35s/it]


In [16]:
# Persist the changes made to the run's summary object back to W&B.
run.summary.update()

In [17]:
# Close the active W&B run, if any.
# NOTE(review): no `wandb.init()` call is visible in this notebook, so this may
# be a no-op — confirm whether it is still needed.
wandb.finish()