In [16]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
from environment import RNAInvEnvironment, make_vec_env, Monitor
from RNA_helper import get_puzzle
import torch as th
from models import EmbeddinsFeatureExtractor
from stable_baselines3.common import logger

In [3]:
# Candidate puzzle indices: 1, 41, 84, 92, 97, 5
puzzle_idx = 5

# get_puzzle returns three values here because return_name=True is passed.
objective_structure, sequence, puzzle_name = get_puzzle(
    idx=puzzle_idx,
    return_name=True,
    verbose=False,
)

# Display the length of the target structure as the cell output.
len(objective_structure)

5 -209.8000030517578 -209.8000030517578
yes



379

In [4]:
# Run configuration shared by the environment, the policy and the run name.
max_steps = 1
features_dim = 512
EMBEDDING_DIM = 16
metric = 'energies_mse'

# Run identifier: lower-cased puzzle name with spaces replaced by underscores,
# followed by the feature size, embedding size and metric name.
model_name = '_'.join([
    puzzle_name.lower().replace(' ', '_'),
    str(features_dim),
    str(EMBEDDING_DIM),
    metric,
])
print(model_name)

saccharomyces_cerevisiae_-_difficulty_level_0_512_16_energies_mse


In [5]:
# Keyword arguments handed to make_vec_env as env_kwargs for RNAInvEnvironment.
env_kwargs = dict(
    objective_structure=objective_structure,
    max_steps=max_steps,
    tuple_obs_space=True,
    metric_type=metric,
    # Per-run text file path derived from the run name.
    sequences_file=f'solved_puzzles/{model_name}.txt',
)

In [6]:
# Training environment: make_vec_env is asked for n_envs RNAInvEnvironment
# instances, each configured with env_kwargs.
n_envs = 12
env = make_vec_env(
    RNAInvEnvironment,
    n_envs=n_envs,
    env_kwargs=env_kwargs,
)
# Single (non-vectorised) environment alternative, currently disabled:
# env = RNAInvEnvironment(objective_structure=objective_structure, max_steps=max_steps, tuple_obs_space=True)



In [7]:
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.policies import ActorCriticPolicy

In [8]:
# Policy options: plug in the project's feature extractor and pass it the
# embedding and output sizes chosen in the configuration cell.
policy_kwargs = {
    'features_extractor_class': EmbeddinsFeatureExtractor,
    'features_extractor_kwargs': {
        'EMBEDDING_DIM': EMBEDDING_DIM,
        'features_dim': features_dim,
    },
}

In [9]:
# PPO agent trained on the vectorised environment `env`.
#   n_steps=512 is the rollout length per environment; the training log
#   reports 6144 timesteps per iteration, i.e. 512 steps x 12 environments.
#   seed=0 fixes the pseudo-random generators so that runs are repeatable.
#   NOTE(review): results may still vary slightly on GPU, where some
#   operations are not deterministic.
model = PPO(
    ActorCriticPolicy,
    env,
    verbose=1,
    tensorboard_log='tensorboard_logs',
    n_steps=512,
    gamma=0.99,
    seed=0,
    policy_kwargs=policy_kwargs
)

Using cuda device


In [10]:
# log_path = f"logs/{model_name}"
# # set up logger
# new_logger = logger.configure(log_path, ["stdout", "csv", "log", "tensorboard", "json"])
# model.set_logger(new_logger)

In [11]:
# eval_env = make_vec_env(
#     RNAInvEnvironment, n_envs=1,
#     env_kwargs={'objective_structure': objective_structure, 'max_steps': max_steps, 'tuple_obs_space': True}
# )

# Evaluation environment: a single RNAInvEnvironment built from the same
# env_kwargs as the training environments. monitor_dir and monitor_kwargs are
# passed through to make_vec_env; info_keywords lists the info entries to record.
eval_env = make_vec_env(
    RNAInvEnvironment,
    n_envs=1,
    env_kwargs=env_kwargs,
    monitor_dir=f'logs/{model_name}',
    monitor_kwargs=dict(
        info_keywords=(
            'free_energy',
            'structure_distance',
            'energy_to_objective',
            'energy_reward',
            'distance_reward',
            'folding_struc',
            'sequence',
            'solved',
            'unique_sequences_N',
        ),
    ),
)

In [12]:
# Periodic evaluation on eval_env; best_model_save_path is a per-run directory
# under models/.
# NOTE(review): SB3 counts eval_freq in callback calls rather than total
# timesteps - confirm that every 5 rollouts of 512 steps is the intended cadence.
eval_callback = EvalCallback(
    eval_env,
    n_eval_episodes=1024,
    eval_freq=5 * 512,
    best_model_save_path=f'models/{model_name}',
    deterministic=True,
    verbose=1,
)


In [None]:
%%time
model.learn(
    total_timesteps=1_000_000,
    tb_log_name=model_name,
    callback=[eval_callback]
)

Logging to tensorboard_logs/saccharomyces_cerevisiae_-_difficulty_level_0_512_16_energies_mse_1
----------------------------------
| rollout/           |           |
|    ep_len_mean     | 1         |
|    ep_rew_mean     | -7.54e+03 |
| time/              |           |
|    fps             | 5         |
|    iterations      | 1         |
|    time_elapsed    | 1135      |
|    total_timesteps | 6144      |
----------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 1           |
|    ep_rew_mean          | -7.44e+03   |
| time/                   |             |
|    fps                  | 5           |
|    iterations           | 2           |
|    time_elapsed         | 2426        |
|    total_timesteps      | 12288       |
| train/                  |             |
|    approx_kl            | 0.036745187 |
|    clip_fraction        | 0.374       |
|    clip_range           | 0.2         |
|    ent