In [None]:
%matplotlib inline

%load_ext autoreload
%autoreload 2

In [None]:
import os
import sys
from os.path import exists

sys.path.append('../..')

In [None]:
import pylab as plt
import pandas as pd
import numpy as np
from loguru import logger
import seaborn as sns

from stable_baselines3 import PPO, DQN

In [None]:
from vimms.Common import POSITIVE, set_log_level_warning, load_obj, save_obj
from vimms.ChemicalSamplers import UniformRTAndIntensitySampler, GaussianChromatogramSampler, UniformMZFormulaSampler, \
    MZMLFormulaSampler, MZMLRTandIntensitySampler, MZMLChromatogramSampler
from vimms.Noise import UniformSpikeNoise
from vimms.Evaluation import evaluate_real
from vimms.Chemicals import ChemicalMixtureFromMZML
from vimms.Roi import RoiBuilderParams, SmartRoiParams

from mass_spec_utils.data_import.mzmine import load_picked_boxes

from vimms_gym.env import DDAEnv
from vimms_gym.chemicals import generate_chemicals
from vimms_gym.evaluation import evaluate, run_method
from vimms_gym.common import METHOD_RANDOM, METHOD_FULLSCAN, METHOD_TOPN, METHOD_PPO, METHOD_DQN

# 1. Parameters

In [None]:
# Ranges forwarded to the chemical generator through params['chemical_creator'].
# NOTE(review): each tuple looks like a (lower, upper) bound -- confirm against generate_chemicals.
n_chemicals = (200, 500)
mz_range = (100, 600)  # also reused for the spike-noise m/z range and the mzML m/z sampler bounds
rt_range = (400, 800)  # also passed to the env settings and used as the simulation start/end time
intensity_range = (1E4, 1E20)  # its natural log bounds the mzML RT/intensity sampler

In [None]:
# Split each (lower, upper) range into separately named bounds.
min_mz, max_mz = mz_range[0], mz_range[1]
min_rt, max_rt = rt_range[0], rt_range[1]

# The intensity bounds are handed to the sampler on a natural-log scale.
min_log_intensity, max_log_intensity = (
    np.log(bound) for bound in (intensity_range[0], intensity_range[1])
)

In [None]:
# Default acquisition settings shared by the controllers and the gym environment.
isolation_window = 0.7
N = 10  # NOTE: the grid-search loops later reuse `N` as a loop variable and overwrite this value
rt_tol = 120  # NOTE: likewise overwritten by the `for rt_tol in rt_tols` loops
exclusion_t_0 = 15
mz_tol = 10
min_ms1_intensity = 5000
ionisation_mode = POSITIVE

# Spike-noise settings: copied into params['noise'] and used to build UniformSpikeNoise when enabled.
enable_spike_noise = True
noise_density = 0.1
noise_max_val = 1E3

In [None]:
# Samplers are fitted from a full-scan mzML file and cached in a pickle so
# that later runs of the notebook can skip the fitting step.
mzml_filename = '../fullscan_QCB.mzML'
samplers_pickle = 'samplers_QCB_medium.p'

if not exists(samplers_pickle):
    # No cache yet: fit the three samplers from the mzML file and store them.
    logger.info('Creating samplers from %s' % mzml_filename)
    roi_params = RoiBuilderParams(min_roi_length=3, at_least_one_point_above=5000)
    samplers = {
        'mz': MZMLFormulaSampler(mzml_filename, min_mz=min_mz, max_mz=max_mz),
        'rt_intensity': MZMLRTandIntensitySampler(
            mzml_filename,
            min_rt=min_rt,
            max_rt=max_rt,
            min_log_intensity=min_log_intensity,
            max_log_intensity=max_log_intensity,
        ),
        'chromatogram': MZMLChromatogramSampler(mzml_filename, roi_params=roi_params),
    }
    save_obj(samplers, samplers_pickle)
else:
    # Cache found: reuse the previously fitted samplers.
    logger.info('Loaded %s' % samplers_pickle)
    samplers = load_obj(samplers_pickle)

# Expose each sampler under its own name, whichever branch produced them.
mz_sampler = samplers['mz']
ri_sampler = samplers['rt_intensity']
cr_sampler = samplers['chromatogram']

In [None]:
# Bundle every setting into one nested dictionary:
#   'chemical_creator' is handed to generate_chemicals,
#   'noise' is read back when the spike noise object is built,
#   and the whole dictionary is passed to run_method.
params = dict(
    chemical_creator=dict(
        mz_range=mz_range,
        rt_range=rt_range,
        intensity_range=intensity_range,
        n_chemicals=n_chemicals,
        mz_sampler=mz_sampler,
        ri_sampler=ri_sampler,
        cr_sampler=cr_sampler,
    ),
    noise=dict(
        enable_spike_noise=enable_spike_noise,
        noise_density=noise_density,
        noise_max_val=noise_max_val,
        mz_range=mz_range,
    ),
    env=dict(
        ionisation_mode=ionisation_mode,
        rt_range=rt_range,
        isolation_window=isolation_window,
        mz_tol=mz_tol,
        rt_tol=rt_tol,
    ),
)

In [None]:
max_peaks = 200  # passed to run_method for every evaluated configuration
in_dir = 'results'  # directory holding the '<env_name>_<model_name>.zip' files loaded for PPO/DQN

In [None]:
n_eval_episodes = 1  # number of chemical sets generated for the evaluation
deterministic = True  # NOTE(review): not referenced in the visible cells -- confirm it is still needed

# 2. Evaluation

#### Generate some chemical sets

In [None]:
# Reduce log output before the long evaluation runs.
# NOTE(review): presumably raises the ViMMS log level to WARNING -- confirm in vimms.Common.
set_log_level_warning()

In [None]:
eval_dir = 'optimise_baselines'  # later assigned to out_dir, which is passed to run_method and Environment
method = METHOD_TOPN  # method evaluated by the run_method grid search

In [None]:
# Generate one chemical set per evaluation episode.
chemical_creator_params = params['chemical_creator']

chem_list = []
# NOTE: the loop variable `i` stays defined after this cell, and later cells
# read it when building output file names -- do not rename it or turn this
# loop into a comprehension without updating those cells.
for i in range(n_eval_episodes):
    print(i)  # progress indicator
    chems = generate_chemicals(chemical_creator_params)
    chem_list.append(chems)

#### Run different methods

In [None]:
# Sanity check: show how many chemicals each generated set contains.
for chems in chem_list:
    print(len(chems))

In [None]:
# Display the current max_peaks setting.
max_peaks

In [None]:
# Results from the runs are written to the evaluation directory.
out_dir = eval_dir
in_dir, out_dir  # display both directories as a sanity check

#### Compare to Top-10

In [None]:
env_name = 'DDAEnv'  # passed to run_method; also the first part of the model file name
model_name = 'PPO'  # second part of the '<env_name>_<model_name>.zip' model file name
intensity_threshold = 0.5  # forwarded to run_method and to evaluate

In [None]:
# Grid of values explored for the Top-N runs: every rt_tol is combined with every N.
rt_tols = [15, 30, 60, 120, 240, 300]
Ns = [5, 10, 15, 20, 25]

In [None]:
# Grid search over N and rt_tol for the method selected in `method`.
# The evaluation result of each run is stored in topN_res under (N, rt_tol).

# The model depends only on `method`, not on the grid point, so it is loaded
# once here instead of once per iteration.
if method == METHOD_PPO:
    fname = os.path.join(in_dir, '%s_%s.zip' % (env_name, model_name))
    model = PPO.load(fname)
elif method == METHOD_DQN:
    fname = os.path.join(in_dir, '%s_%s.zip' % (env_name, model_name))
    model = DQN.load(fname)
else:
    model = None

topN_res = {}
# NOTE: `rt_tol` and `N` are also notebook-level settings; these loops rebind them.
for rt_tol in rt_tols:
    for N in Ns:

        effective_rt_tol = rt_tol
        # dict(params) is only a shallow copy: the nested 'env' dictionary
        # would still be shared with `params`, so assigning into it would
        # silently change the original settings. Build a fresh 'env'
        # dictionary carrying the rt_tol of this grid point instead.
        copy_params = dict(params)
        copy_params['env'] = dict(params['env'], rt_tol=effective_rt_tol)

        banner = 'method = %s max_peaks = %d N = %d rt_tol = %d' % (method, max_peaks, N, effective_rt_tol)
        print(banner)
        print()

        episodic_results = run_method(env_name, copy_params, max_peaks, chem_list, method, out_dir,
                                      N=N, min_ms1_intensity=min_ms1_intensity, model=model,
                                      print_eval=True, print_reward=False, intensity_threshold=intensity_threshold)
        # Only the evaluation result of the first episode is kept.
        eval_results = [er.eval_res for er in episodic_results][0]

        key = (N, rt_tol)
        topN_res[key] = eval_results
        print()

In [None]:
# Display the collected results, keyed by (N, rt_tol).
topN_res

In [None]:
# Collect results per method; the controller runs in later cells add their own entries.
method_eval_results = {
    method: topN_res
}

#### Test classic controllers in ViMMS

In [None]:
from vimms.MassSpec import IndependentMassSpectrometer
from vimms.Controller import TopNController, TopN_SmartRoiController, WeightedDEWController
from vimms.Environment import Environment

In [None]:
# Build the spike noise object from the 'noise' settings when noise is
# enabled; otherwise the simulations run without spike noise.
if enable_spike_noise:
    noise_params = params['noise']
    noise_density, noise_max_val = noise_params['noise_density'], noise_params['noise_max_val']
    noise_min_mz, noise_max_mz = noise_params['mz_range'][0], noise_params['mz_range'][1]
    spike_noise = UniformSpikeNoise(
        noise_density,
        noise_max_val,
        min_mz=noise_min_mz,
        max_mz=noise_max_mz,
    )
else:
    spike_noise = None

Run Top-N Controller

In [None]:
# Grid search of the ViMMS Top-N controller over N and rt_tol on the first
# generated chemical set. Results are stored under (N, rt_tol).
method = 'TopN_Controller'
print('method = %s' % method)
print()

# Index of the chemical set being simulated. Naming it explicitly removes the
# dependency on a loop variable left over from an earlier cell.
chem_idx = 0
chems = chem_list[chem_idx]
res = {}
# NOTE: `rt_tol` and `N` are also notebook-level settings; these loops rebind them.
for rt_tol in rt_tols:
    for N in Ns:

        mass_spec = IndependentMassSpectrometer(ionisation_mode, chems, spike_noise=spike_noise)
        controller = TopNController(ionisation_mode, N, isolation_window, mz_tol, rt_tol,
                                    min_ms1_intensity)
        # One output file per grid point: with a shared file name every run
        # would overwrite the output of the previous one.
        out_file = '%s_%d_N_%d_rt_tol_%d.mzML' % (method, chem_idx, N, rt_tol)
        env = Environment(mass_spec, controller, min_rt, max_rt, progress_bar=False, out_dir=out_dir,
                          out_file=out_file, save_eval=True)
        env.run()
        eval_res = evaluate(env, intensity_threshold)
        key = (N, rt_tol)
        print(N, rt_tol, eval_res)
        res[key] = eval_res

method_eval_results[method] = res

Run SmartROI Controller

**TODO:** the SmartROI cells below are unfinished.

In [None]:
# SmartROI grid: each alpha is used as intensity_increase_factor and each beta
# (divided by 100) as drop_perc when SmartRoiParams is built.
alphas = [2, 3, 5, 10, 1E3, 1E6]
betas = [0, 0.1, 0.5, 1, 5]
# Fixed values passed positionally to TopN_SmartRoiController (third and fifth argument).
smartroi_N = 10
smartroi_dew = 15

In [None]:
# Grid search of the ViMMS SmartROI controller over alpha (intensity increase
# factor) and beta (drop percentage) on the first generated chemical set.
# Results are stored under (alpha, beta).
method = 'SmartROI_Controller'
print('method = %s' % method)
print()

# Index of the chemical set being simulated. Naming it explicitly removes the
# dependency on a loop variable left over from an earlier cell.
chem_idx = 0
chems = chem_list[chem_idx]
res = {}
for alpha in alphas:
    for beta in betas:

        mass_spec = IndependentMassSpectrometer(ionisation_mode, chems, spike_noise=spike_noise)

        roi_params = RoiBuilderParams(min_roi_intensity=500, min_roi_length=0)
        # beta is given in percent, drop_perc is passed as a fraction.
        smartroi_params = SmartRoiParams(intensity_increase_factor=alpha, drop_perc=beta/100.0)
        controller = TopN_SmartRoiController(ionisation_mode, isolation_window, smartroi_N, mz_tol, smartroi_dew,
                                             min_ms1_intensity, roi_params, smartroi_params)

        # One output file per grid point: with a shared file name every run
        # would overwrite the output of the previous one. %s is used because
        # alpha and beta may be floats.
        out_file = '%s_%d_alpha_%s_beta_%s.mzML' % (method, chem_idx, alpha, beta)
        env = Environment(mass_spec, controller, min_rt, max_rt, progress_bar=False, out_dir=out_dir,
                          out_file=out_file, save_eval=True)
        env.run()
        eval_res = evaluate(env, intensity_threshold)
        # Key by the parameters actually varied here. The earlier (N, rt_tol)
        # key used stale values from the Top-N cell, so every result landed
        # in the same dictionary entry.
        key = (alpha, beta)
        print(alpha, beta, eval_res)
        res[key] = eval_res

method_eval_results[method] = res

Run WeightedDEW Controller

In [None]:
# WeightedDEW grid: t0 is passed as exclusion_t_0 and t1 as the fifth positional
# argument of WeightedDEWController; combinations with t0 > t1 are skipped.
t0s = [1, 3, 10, 15, 30, 60]
t1s = [15, 60, 120, 240, 360, 3600]
weighteddew_N = 10  # fixed value passed as the second positional argument of the controller

In [None]:
# Grid search of the ViMMS WeightedDEW controller over t0 and t1 on the first
# generated chemical set. Results are stored under (t0, t1).
method = 'WeightedDEW_Controller'
print('method = %s' % method)
print()

# Index of the chemical set being simulated. Naming it explicitly removes the
# dependency on a loop variable left over from an earlier cell.
chem_idx = 0
chems = chem_list[chem_idx]
res = {}
for t0 in t0s:
    for t1 in t1s:

        # Combinations where t0 exceeds t1 are not simulated.
        if t0 > t1:
            print('Invalid combination')
            continue

        mass_spec = IndependentMassSpectrometer(ionisation_mode, chems, spike_noise=spike_noise)

        controller = WeightedDEWController(ionisation_mode, weighteddew_N, isolation_window, mz_tol, t1,
                                           min_ms1_intensity, exclusion_t_0=t0)

        # One output file per grid point: with a shared file name every run
        # would overwrite the output of the previous one.
        out_file = '%s_%d_t0_%s_t1_%s.mzML' % (method, chem_idx, t0, t1)
        env = Environment(mass_spec, controller, min_rt, max_rt, progress_bar=False, out_dir=out_dir,
                          out_file=out_file, save_eval=True)
        env.run()
        eval_res = evaluate(env, intensity_threshold)
        key = (t0, t1)
        print(t0, t1, eval_res)
        res[key] = eval_res

method_eval_results[method] = res