In [1]:
import itertools as itr
import warnings as warn
# warn.filterwarnings('always')
from pathlib import Path

import numpy as nmp
import pandas as pnd

import pymc3 as pmc
import joblib as jbl
import clonosGP as cln

In [2]:
# Re-import edited modules (autoreload mode 2 = all modules) before each cell
# runs, so changes to library code such as clonosGP are picked up without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [3]:
def run_model(prior, cov, lik, data, *, niters=40000, burn=None, seed=42):
    """Fit one clonosGP model configuration with ADVI and summarise its loss.

    Parameters
    ----------
    prior, cov, lik : str
        Model configuration forwarded to ``cln.infer`` in ``model_args``
        (e.g. ``'GP0'``, ``'Exp'``, ``'Bin'``).
    data : pandas.DataFrame
        Input table, forwarded unchanged to ``cln.infer``.
    niters : int, optional
        Number of ADVI iterations (default 40000).
    burn : int, optional
        Number of leading loss values discarded before summarising.
        Defaults to 3/4 of ``niters`` (i.e. 30000 for the default ``niters``).
    seed : int, optional
        Seed for numpy's global RNG, pymc3's package-level Theano RNG and the
        ``random_seed`` passed to ``cln.infer`` (default 42).

    Returns
    -------
    pandas.DataFrame
        Exactly one row with columns PRIOR, COV, LIK, METRIC (always
        ``'LOSS'``), MEDIAN, LOW and HIGH.  MEDIAN/LOW/HIGH are the
        50% / 2.5% / 97.5% quantiles of the post-burn-in loss history.

    Raises
    ------
    ValueError
        If the loss history has no entries left after discarding ``burn``.
    """
    if burn is None:
        burn = (3 * niters) // 4

    # Seed inside the function: joblib workers are separate processes, each
    # with its own RNG state.
    nmp.random.seed(seed)
    # `pmc.tt_rng(seed)` on its own only *returns* a fresh stream and discards
    # it; `set_tt_rng` installs it as pymc3's package-level generator.
    pmc.set_tt_rng(pmc.tt_rng(seed))

    # 0.05 for the 'Bin' likelihood, 0.0 otherwise.  NOTE(review): the meaning
    # of `threshold` is defined inside clonosGP, not visible here.
    thr = 0.05 if lik == 'Bin' else 0.0

    res = cln.infer(data,
                    model_args={'K': 20, 'prior': prior, 'cov': cov, 'lik': lik, 'threshold': thr},
                    pymc3_args={'niters': niters, 'method': 'advi', 'flow': 'scale-loc',
                                'learning_rate': 1e-2, 'random_seed': seed})

    hist = nmp.asarray(res['fit'].hist)
    tail = hist[burn:]
    if tail.size == 0:
        raise ValueError(f'loss history has {hist.size} entries; nothing left after burn={burn}')
    low, median, high = nmp.quantile(tail, [0.025, 0.5, 0.975])

    # A list holding one record gives a one-row frame; a dict of plain scalars
    # would make pandas raise "If using all scalar values, you must pass an index".
    return pnd.DataFrame([{
        'PRIOR': prior,
        'COV': cov,
        'LIK': lik,
        'METRIC': 'LOSS',
        'MEDIAN': median,
        'LOW': low,
        'HIGH': high,
    }])


In [4]:
# Model configurations to fit: the 'Flat' prior with the 'Exp' covariance for
# both likelihoods, plus every GP prior x covariance x likelihood combination.
ARGS = [('Flat', 'Exp', 'Bin'), ('Flat', 'Exp', 'BBin')] + list(
    itr.product(['GP0', 'GP1', 'GP2', 'GP3'], ['Exp', 'Mat32', 'Mat52', 'ExpQ'], ['Bin', 'BBin']))

DATA_DIR = Path('data')
RESULTS_DIR = Path('results')
N_JOBS = 4  # set to 1 to run the fits sequentially in this process (easier debugging)

# Full dataset list, for a complete run:
# FNAMES = ['cll_Rincon_2019_patient2.csv', 'cll_Schuh_2012_CLL003.csv', 'cll_Schuh_2012_CLL006.csv',
#           'cll_Schuh_2012_CLL077.csv', 'melanoma_Cutts_2017.csv', 'cll_Rincon_2019_patient1.csv']
FNAMES = ['cll_Schuh_2012_CLL006.csv']

# Create the output folder before the long model runs, so a missing directory
# cannot make the final `to_csv` fail after all fits have finished.
RESULTS_DIR.mkdir(parents=True, exist_ok=True)

for fname in FNAMES:
    DATA = pnd.read_csv(DATA_DIR / fname)
    fits = jbl.Parallel(n_jobs=N_JOBS, verbose=10)(
        jbl.delayed(run_model)(prior, cov, lik, DATA) for prior, cov, lik in ARGS)
    # One single-row frame per configuration -> one table per dataset.
    RES = pnd.concat(fits).reset_index(drop=True)
    RES.to_csv(RESULTS_DIR / fname, index=False)

[Parallel(n_jobs=4)]: Using backend LokyBackend with 4 concurrent workers.
[Parallel(n_jobs=4)]: Done   5 tasks      | elapsed:  2.0min
[Parallel(n_jobs=4)]: Done  10 tasks      | elapsed:  3.0min
[Parallel(n_jobs=4)]: Done  17 tasks      | elapsed:  5.2min
[Parallel(n_jobs=4)]: Done  24 tasks      | elapsed:  7.3min
[Parallel(n_jobs=4)]: Done  31 out of  34 | elapsed: 12.0min remaining:  1.2min
[Parallel(n_jobs=4)]: Done  34 out of  34 | elapsed: 14.3min finished
