In [None]:
# Render matplotlib figures inside the notebook.
%matplotlib inline

# Automatically reload edited modules (e.g. a local vimms checkout) before running code.
%load_ext autoreload
%autoreload 2

In [None]:
import sys

# Make a local vimms checkout importable. The first two entries are
# developer-specific absolute paths (NOTE(review): consider a configurable
# path instead); '..' covers running from a subdirectory of the repository.
sys.path.extend([
    'C:\\Users\\joewa\\Work\\git\\vimms',
    'C:\\Users\\Vinny\\work\\vimms',
    '..',
])

In [None]:
import numpy as np
import torch
import random as rand
import pylab as plt
import multiprocessing

In [None]:
from vimms.Common import *
from vimms.Gym import FragmentEnv
from vimms.Evaluation import evaluate_simulated_env, evaluate_multiple_simulated_env

In [None]:
# Seed every RNG the notebook uses (numpy, stdlib random, torch) from one constant
# so the seed can be changed in a single place.
SEED = 0
np.random.seed(SEED)
rand.seed(SEED)
# Trailing ';' hides the torch Generator object that would otherwise be shown as cell output.
torch.manual_seed(SEED);

### Parameters

In [None]:
# Quieten vimms logging to WARNING.
# NOTE(review): set_log_level_info() is called again before any vimms code runs
# (only constants and the params dict are built in between), so this call appears
# to have no visible effect — confirm whether it is still needed.
set_log_level_warning()

In [None]:
# Ranges handed to FragmentEnv through `params` (chemical creator, noise, env).
mz_range = (100, 600)
rt_range = (0, 500)
intensity_range = (1e5, 1e10)
# presumably (min, max) number of chemicals per generated sample — confirm in vimms.Gym
n_chemicals = (400, 1000)

In [None]:
# Alternative, smaller configuration (fewer chemicals, shorter RT range).
# Uncomment to override the values set in the previous cell.
# n_chemicals = (200, 500)
# mz_range = (100, 600)
# rt_range = (0, 300)
# intensity_range = (1E5, 1E10)

In [None]:
# Scalar lower/upper bounds unpacked from the ranges above; intensity bounds are
# on a natural-log scale. NOTE(review): these names are not used later in this notebook.
min_mz, max_mz = mz_range
min_rt, max_rt = rt_range
min_log_intensity, max_log_intensity = np.log(intensity_range)

In [None]:
# Settings for the environment (params['env']) ...
N = 10
isolation_window = 0.7
mz_tol = 10  # NOTE(review): units not shown here (presumably ppm) — confirm
rt_tol = 15  # NOTE(review): units not shown here (presumably seconds) — confirm
min_ms1_intensity = 5000
ionisation_mode = POSITIVE

# ... and for the noise model (params['noise']).
noise_density = 0.3
noise_max_val = 10_000.0

### Custom gym

In [None]:
# Nested configuration passed to FragmentEnv(params); each section groups the
# settings for one part of the simulation (presumably chemical generation, noise,
# and the acquisition environment — confirm in vimms.Gym).
params = dict(
    chemical_creator=dict(
        mz_range=mz_range,
        rt_range=rt_range,
        intensity_range=intensity_range,
        n_chemicals=n_chemicals,
    ),
    noise=dict(
        noise_density=noise_density,
        noise_max_val=noise_max_val,
        mz_range=mz_range,
    ),
    env=dict(
        ionisation_mode=ionisation_mode,
        rt_range=rt_range,
        N=N,
        isolation_window=isolation_window,
        mz_tol=mz_tol,
        rt_tol=rt_tol,
        min_ms1_intensity=min_ms1_intensity,
    ),
)

In [None]:
# Raise vimms logging to INFO for training and evaluation; presumably this is what
# makes the `logger.info(...)` summary inside `evaluation` visible — confirm.
set_log_level_info()

### PPO

In [None]:
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.env_checker import check_env

In [None]:
# Sanity-check that FragmentEnv conforms to the Gym interface stable-baselines3 expects.
# `env` is rebound to a vectorised env in the training cell below.
env = FragmentEnv(params)
check_env(env)

In [None]:
# Number of parallel training environments: half of the available CPUs, but never
# fewer than one (int(1 / 2) would give 0 on a single-CPU machine and make
# make_vec_env fail).
n_envs = max(1, multiprocessing.cpu_count() // 2)
n_envs

In [None]:
# Train PPO with an MLP policy on n_envs copies of FragmentEnv, logging to TensorBoard.
# NOTE(review): no `seed` is passed to make_vec_env or PPO, so training is presumably
# not reproducible despite the global seeds set earlier — confirm.
env = make_vec_env(FragmentEnv, n_envs=n_envs, env_kwargs={'params': params})
model = PPO("MlpPolicy", env, verbose=1, tensorboard_log="./results/ppo_fragmentenv_tensorboard/")
model.learn(total_timesteps=250000)

In [None]:
# Persist the trained policy; it is loaded back from the same path in the Evaluation section.
model.save('results/ppo_all_chems')

## Evaluation

In [None]:
# Load the saved policy so the evaluation below can be run without retraining.
model = PPO.load('results/ppo_all_chems')

In [None]:
def evaluation(model, num_episodes, max_steps=1000, deterministic=False):
    """Run a controller on ``num_episodes`` episodes of ``FragmentEnv`` and collect results.

    Parameters
    ----------
    model : object or str
        A trained agent exposing ``predict(observation, deterministic=...)``
        (e.g. the PPO model loaded above), or one of the baseline names
        ``'random'`` (actions sampled from ``env.action_space``) or ``'TopN'``
        (action ``-1`` sent on every step).
    num_episodes : int
        Number of episodes to run.
    max_steps : int, optional
        Step cap per episode (default 1000). An episode that has not returned
        ``done`` within this many steps is left out of the results and a
        warning is logged.
    deterministic : bool, optional
        Forwarded to ``model.predict``; ignored for the string baselines.
        Defaults to False (actions sampled from the policy).

    Returns
    -------
    tuple of (np.ndarray, np.ndarray, list)
        Total reward of each finished episode, that total divided by
        ``len(env.chems)``, and ``env.vimms_env`` captured at the end of each
        finished episode. NOTE(review): if ``env.reset()`` reuses the same
        ``vimms_env`` object, the list entries alias each other — confirm.

    Raises
    ------
    ValueError
        If ``model`` is a string other than ``'random'`` or ``'TopN'``.

    Notes
    -----
    Reads the notebook globals ``params`` (environment configuration) and
    ``write_mzml_every``: an episode's mzML is written to ``results`` when its
    index is a multiple of ``write_mzml_every`` or it is the last episode.
    """
    if isinstance(model, str) and model not in ('random', 'TopN'):
        raise ValueError("model must be a trained agent, 'random' or 'TopN', got %r" % (model,))
    # Name used in the mzML output file names; any trained agent is labelled 'PPO'.
    label = model if isinstance(model, str) else 'PPO'

    env = FragmentEnv(params)
    total_rewards = []
    total_reward_per_chems = []
    env_list = []

    try:
        for i_episode in range(num_episodes):
            observation = env.reset()
            total_reward = 0
            out_file = 'test_%s_%d.mzML' % (label, i_episode)
            for t in range(max_steps):
                if label == 'random':
                    action = env.action_space.sample()
                elif label == 'TopN':
                    # NOTE(review): -1 presumably makes FragmentEnv fall back to its
                    # TopN behaviour — confirm in vimms.Gym.
                    action = -1
                else:
                    action, _ = model.predict(observation, deterministic=deterministic)

                observation, reward, done, info = env.step(action)
                total_reward += reward
                if done:
                    seen_actions = env.seen_actions.most_common()
                    n_chems = len(env.chems)
                    reward_per_chems = total_reward / n_chems
                    print('Episode %d timesteps %d reward %f n_chems %d reward/chems %f actions %s' % (i_episode, t+1, total_reward, n_chems, reward_per_chems, seen_actions))
                    total_rewards.append(total_reward)
                    total_reward_per_chems.append(reward_per_chems)
                    env_list.append(env.vimms_env)
                    if i_episode % write_mzml_every == 0 or i_episode == num_episodes-1:
                        env.vimms_env.write_mzML('results', out_file)
                    break
            else:
                # No `done` within max_steps: the episode is dropped, which makes the
                # returned arrays shorter than num_episodes, so report it explicitly.
                logger.warning('Episode %d did not finish within %d steps; excluded from results'
                               % (i_episode, max_steps))
    finally:
        # Release the environment even if an episode raises.
        env.close()

    if total_rewards:
        logger.info('Average total reward = %f' % np.mean(total_rewards))
    else:
        logger.warning('No episode finished; no average reward to report')
    return np.array(total_rewards), np.array(total_reward_per_chems), env_list

In [None]:
# Episodes per controller, and how often `evaluation` writes an episode's mzML
# file (`evaluation` reads `write_mzml_every` as a global).
num_episodes, write_mzml_every = 100, 20

### PPO evaluation

In [None]:
# Evaluate the trained PPO policy.
ppo_total_rewards, ppo_reward_per_chems, ppo_env_list = evaluation(model, num_episodes)

### Random

In [None]:
# Baseline: actions sampled at random from the environment's action space.
random_total_rewards, random_reward_per_chems, random_env_list = evaluation('random', num_episodes)

### TopN

In [None]:
# Baseline: action -1 on every step (the TopN controller).
topN_total_rewards, topN_reward_per_chems, topN_env_list = evaluation('TopN', num_episodes)

### Plots

In [None]:
def percent_improvement(scores, ref_scores):
    """Return ``(scores - ref_scores) / ref_scores * 100`` element-wise as floats.

    Entries where ``ref_scores`` is 0 come out as ``inf``/``nan`` without
    emitting a NumPy RuntimeWarning.
    """
    scores = np.asarray(scores, dtype=float)
    ref_scores = np.asarray(ref_scores, dtype=float)
    with np.errstate(divide='ignore', invalid='ignore'):
        return (scores - ref_scores) / ref_scores * 100


def plot_diff(controller_names, scores_list, ref_name, ref_scores):
    """Plot each controller's per-episode percentage score improvement over a reference.

    Parameters
    ----------
    controller_names : iterable of str
        Legend labels, paired positionally with ``scores_list``.
    scores_list : iterable of array-like
        Per-episode scores for each controller; each must have the same length
        as ``ref_scores``.
    ref_name : str
        Name of the reference controller, shown in the title.
    ref_scores : array-like
        Per-episode scores of the reference controller.
    """
    for controller_name, scores in zip(controller_names, scores_list):
        # The y-axis is in percent, so plot the relative change, not the raw difference.
        plt.plot(percent_improvement(scores, ref_scores), label=controller_name)
    plt.title('Score improvement over %s' % ref_name)
    plt.ylabel('Score Improvement (%)')
    plt.xlabel('Episode')
    plt.legend()

def plot_arr(controller_names, arr_list, title):
    """Overlay one line per controller showing a per-episode quantity.

    Parameters
    ----------
    controller_names : iterable of str
        Legend labels, paired positionally with ``arr_list``.
    arr_list : iterable of array-like
        One sequence of per-episode values per controller.
    title : str
        Name of the quantity; used as the y-axis label and in the plot title.
    """
    labelled_series = list(zip(controller_names, arr_list))
    for label, values in labelled_series:
        plt.plot(values, label=label)
    plt.title('%s per Episode' % title)
    plt.ylabel(title)
    plt.xlabel('Episode')
    plt.legend()
        
def get_scores(env_list):
    """Score every environment in ``env_list`` with ``get_score``; return the scores as a NumPy array."""
    return np.array([get_score(vimms_env) for vimms_env in env_list])

def get_score(env):
    """Return ``coverage_proportion * intensity_proportion`` as reported by ``evaluate_simulated_env``."""
    results = evaluate_simulated_env(env)
    coverage = results['coverage_proportion']
    intensity = results['intensity_proportion']
    return coverage * intensity

In [None]:
# Total reward per finished episode for each controller.
plot_arr(['PPO', 'TopN', 'Random'], [ppo_total_rewards, topN_total_rewards, random_total_rewards], 'Total Rewards')

In [None]:
# Total reward normalised by the number of chemicals in each episode.
plot_arr(['PPO', 'TopN', 'Random'], [ppo_reward_per_chems, topN_reward_per_chems, random_reward_per_chems], 'Reward/Chems')

In [None]:
# Score every finished episode for each controller (PPO, TopN, Random), then plot them.
ppo_scores, topN_scores, random_scores = (
    get_scores(episode_envs)
    for episode_envs in (ppo_env_list, topN_env_list, random_env_list)
)
plot_arr(['PPO', 'TopN', 'Random'], [ppo_scores, topN_scores, random_scores], 'Scores')

In [None]:
# Per-episode score change of PPO and Random relative to the TopN baseline.
# Assumes the score arrays have equal length (one entry per finished episode).
plot_diff(['PPO', 'Random'], [ppo_scores, random_scores], 'TopN', topN_scores)