In [2]:
%matplotlib inline
%load_ext autoreload
%autoreload 2
import sys, os, re

import pandas as pd
pd.set_option('display.max_columns', 100)
pd.set_option('display.max_colwidth', 100)
pd.set_option('display.width', 80)
pd.set_option('display.max_rows', 100)

sys.path.append('/net/pulsar/home/koes/mtr22')
import param_search

expt_dir = '/net/pulsar/home/koes/mtr22/gan/torch_training'
os.chdir(expt_dir)
os.getcwd()

'/net/pulsar/home/koes/mtr22/gan/torch_training'

In [2]:
template_file = 'train.sh'
template = '''\
#!/bin/bash
#SBATCH --job-name={job_name}
#SBATCH --nodes=1
#SBATCH --ntasks=1
#SBATCH --cpus-per-task=18
#SBATCH --partition=dept_gpu
#SBATCH --gres=gpu:1
#SBATCH --mem=32gb
#SBATCH --time=672:00:00
#SBATCH --qos=normal
#SBATCH -o %J.stdout
#SBATCH -e %J.stderr
#SBATCH --dependency=singleton
{job_params}
source ~/.bashrc
cd $SLURM_SUBMIT_DIR

python3 $LIGAN_ROOT/train.py \\
    --random_seed {random_seed} \\
    --data_root {data_root} \\
    --rec_molcache {rec_molcache} \\
    --lig_molcache {lig_molcache} \\
    --train_file {train_file} \\
    --test_file {test_file} \\
    --batch_size {batch_size} \\
    --rec_map_file {rec_map_file} \\
    --lig_map_file {lig_map_file} \\
    --model_type {model_type} \\
    --skip_connect {skip_connect} \\
    --kldiv_loss_wt {kldiv_loss_wt} \\
    --recon_loss_wt {recon_loss_wt} \\
    --gan_loss_type {gan_loss_type} \\
    --gan_loss_wt {gan_loss_wt} \\
    --disc_grad_norm {disc_grad_norm} \\
    --max_iter {max_iter} \\
    --test_interval {test_interval} \\
    --n_test_batches {n_test_batches} \\
    --save_interval {save_interval} \\
    --out_prefix {job_name}
'''
with open(template_file, 'w') as f:
    f.write(template)

In [192]:
# for ligand-only models
lig_only_param_space = param_search.ParamSpace(
    data_root='/net/pulsar/home/koes/mtr22/molport',
    rec_molcache='/net/pulsar/home/koes/mtr22/gan/data/molportFULL_rec.molcache2',
    lig_molcache='/net/pulsar/home/koes/mtr22/gan/data/molportFULL_lig.molcache2',
    train_file='/net/pulsar/home/koes/mtr22/gan/data/molportFULL_rand_train0.types',
    test_file='/net/pulsar/home/koes/mtr22/gan/data/molportFULL_rand_test0.types',
    skip_connect=False
)

# for receptor-conditional models
rec_cond_param_space = param_search.ParamSpace(
    data_root='/net/pulsar/home/koes/paf46/Research/CrossDocking_script/PocketomeOutput/PocketomeGenCross_Output',
    rec_molcache='/net/pulsar/home/koes/paf46/git/cnnaffinitypaper/models/crossdock2020_rec.molcache2',
    lig_molcache='/net/pulsar/home/koes/paf46/git/cnnaffinitypaper/models/crossdock2020_lig.molcache2',
    train_file='/net/pulsar/home/koes/paf46/git/cnnaffinitypaper/types/it2_tt_0_train0.types',
    test_file='/net/pulsar/home/koes/paf46/git/cnnaffinitypaper/types/it2_tt_0_test0.types',
    skip_connect=True
)

# general parameters
general_param_space = param_search.ParamSpace(
    random_seed=[0, 1, 2],
    batch_size=10,
    rec_map_file='/net/pulsar/home/koes/mtr22/gan/data/my_rec_map',
    lig_map_file='/net/pulsar/home/koes/mtr22/gan/data/my_lig_map',
    kldiv_loss_wt=0.1,
    recon_loss_wt=1,
    disc_grad_norm_type='2',
    max_iter=100000,
    test_interval=100,
    n_test_batches=10,
    fit_interval=1000,
    save_interval=10000,
    n_latent=[128, 256, 1024],
)

lig_only_param_space.update(general_param_space)
rec_cond_param_space.update(general_param_space)

# model type-specific parameters
ae_vae_param_space = param_search.ParamSpace(
    model_type=['AE', 'VAE'],
    optim_type=['Adam'],
    learning_rate=[1e-5],
    gan_loss_type='0',
    extra_sbatch_line='',
    **lig_only_param_space,
)

ce_cvae_param_space = param_search.ParamSpace(
    model_type=['CE', 'CVAE'],
    optim_type=['Adam'],
    learning_rate=[1e-5],
    gan_loss_type='0',
    extra_sbatch_line='',
    **rec_cond_param_space,
)

gan_param_space = param_search.ParamSpace(
    model_type=['GAN'],
    optim_type=['Adam'],
    learning_rate=[1e-5],
    gan_loss_type='x',
    gan_loss_wt=1,
    extra_sbatch_line='',
    **lig_only_param_space,
)

wgan_param_space = param_search.ParamSpace(
    model_type=['GAN'],
    optim_type=['RMSprop'],
    learning_rate=[1e-7],
    gan_loss_type='w',
    gan_loss_wt=1,
    extra_sbatch_line='',
    **lig_only_param_space,
)

cgan_param_space = param_search.ParamSpace(
    model_type=['CGAN'],
    optim_type=['Adam'],
    learning_rate=[1e-5],
    gan_loss_type='x',
    gan_loss_wt=1,
    extra_sbatch_line='',
    **rec_cond_param_space,
)

cwgan_param_space = param_search.ParamSpace(
    model_type=['CGAN'],
    optim_type=['RMSprop'],
    learning_rate=[1e-7],
    gan_loss_type='w',
    gan_loss_wt=1,
    extra_sbatch_line='',
    **rec_cond_param_space,
)

# dual encoder models require 12gb GPUs

vaegan_param_space = param_search.ParamSpace(
    model_type=['VAEGAN'],
    optim_type=['Adam'],
    learning_rate=[1e-5],
    gan_loss_type='x',
    gan_loss_wt=10,
    extra_sbatch_line='#SBATCH -C M12',
    **lig_only_param_space,
)

vaewgan_param_space = param_search.ParamSpace(
    model_type=['VAEGAN'],
    optim_type=['RMSprop'],
    learning_rate=[1e-7],
    gan_loss_type='w',
    gan_loss_wt=1,
    extra_sbatch_line='#SBATCH -C M12',
    **lig_only_param_space,
)

cvaegan_param_space = param_search.ParamSpace(
    model_type=['CVAEGAN'],
    optim_type=['Adam'],
    learning_rate=[1e-5],
    gan_loss_type='x',
    gan_loss_wt=10,
    extra_sbatch_line='#SBATCH -C M12',
    **rec_cond_param_space,
)

cvaewgan_param_space = param_search.ParamSpace(
    model_type=['CVAEGAN'],
    optim_type=['RMSprop'],
    learning_rate=[1e-7],
    gan_loss_type='w',
    gan_loss_wt=1,
    extra_sbatch_line='#SBATCH -C M12',
    **rec_cond_param_space,
)

all_param_spaces = [

    gan_param_space,
    wgan_param_space,
    cgan_param_space,
    cwgan_param_space,
    vaegan_param_space,
    vaewgan_param_space,
    cvaegan_param_space,
    cvaewgan_param_space,

    #ae_vae_param_space,
    #ce_cvae_param_space,
]

[len(p) for p in all_param_spaces]

[9, 9, 9, 9, 9, 9, 9, 9]

In [183]:
job_files = param_search.setup(
    expt_dir=expt_dir,
    name_format='train_{model_type}_8_{n_latent}_{gan_loss_type}_{random_seed}',
    template_file='train.sh',
    param_space=sum(map(list, all_param_spaces), [])
)
len(job_files) 

# Submit training jobs to cluster

In [320]:
job_ids = param_search.submit(job_files[8])
print(job_ids)

[6736277, 6736278, 6736279, 6736280, 6736281, 6736282, 6736283, 6736284, 6736285, 6736286, 6736287, 6736288, 6736289, 6736290, 6736291, 6736292, 6736293, 6736294, 6736295, 6736296, 6736297, 6736298, 6736299, 6736300, 6736301, 6736302, 6736303, 6736304, 6736305, 6736306, 6736307, 6736308, 6736309, 6736310, 6736311, 6736312, 6736313, 6736314, 6736315, 6736316, 6736317, 6736318, 6736319, 6736320, 6736321, 6736322, 6736323, 6736324, 6736325, 6736326, 6736327, 6736328, 6736329, 6736330, 6736331, 6736332, 6736333, 6736334, 6736335, 6736336, 6736337, 6736338, 6736339, 6736340, 6736341, 6736342, 6736343, 6736344, 6736345, 6736346, 6736347, 6736348]


In [136]:
import numpy as np
import glob

experiment = pd.DataFrame(dict(
    group_name=[
        'non_gan',
        'train_fitting',
    ],
    job_id=[
        [6733823, 6733824, 6733825, 6733826],
        [6736277, 6736278, 6736279, 6736280, 6736281, 6736282, 6736283, 6736284, 6736285, 6736286, 6736287, 6736288, 6736289, 6736290, 6736291, 6736292, 6736293, 6736294, 6736295, 6736296, 6736297, 6736298, 6736299, 6736300, 6736301, 6736302, 6736303, 6736304, 6736305, 6736306, 6736307, 6736308, 6736309, 6736310, 6736311, 6736312, 6736313, 6736314, 6736315, 6736316, 6736317, 6736318, 6736319, 6736320, 6736321, 6736322, 6736323, 6736324, 6736325, 6736326, 6736327, 6736328, 6736329, 6736330, 6736331, 6736332, 6736333, 6736334, 6736335, 6736336, 6736337, 6736338, 6736339, 6736340, 6736341, 6736342, 6736343, 6736344, 6736345, 6736346, 6736347, 6736348],
    ]
)).explode('job_id')
experiment

Unnamed: 0,group_name,job_id
0,non_gan,6733823
0,non_gan,6733824
0,non_gan,6733825
0,non_gan,6733826
1,train_fitting,6736277
1,train_fitting,6736278
1,train_fitting,6736279
1,train_fitting,6736280
1,train_fitting,6736281
1,train_fitting,6736282


In [186]:
import param_search

queue_status = param_search.status()
queue_status[['job_id', 'array_idx']] = queue_status.job_id.str.split('_', expand=True).apply(pd.to_numeric)
queue_status[queue_status.user == 'mtr22']

df = experiment.merge(queue_status, on='job_id', how='left')

def find_work_dir(x):
    if not pd.isnull(x['work_dir']):
        return x['work_dir']
    f = glob.glob('*/{}.*'.format(x['job_id']))
    return os.path.dirname(f[0])

def find_job_state(x):
    if not pd.isnull(x['job_state']):
        return x['job_state']
    if x['stderr'] is not None:
        return 'ERR'
    return 'OK'

def catch_exc(func, exc_type=FileNotFoundError, default=None):
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except exc_type as e:
            return default
    return wrapper

df['work_dir'] = df.apply(find_work_dir, axis=1)
df['job_name'] = df['work_dir'].map(os.path.basename)
df['stdout_file'] = df.apply(lambda x: '{work_dir}/{job_id}.stdout'.format(**x), axis=1)
df['stderr_file'] = df.apply(lambda x: '{work_dir}/{job_id}.stderr'.format(**x), axis=1)
df['stdout'] = df['stdout_file'].map(catch_exc(param_search.job_output.read_stdout_file))
df['stderr'] = df['stderr_file'].map(catch_exc(param_search.job_output.read_stderr_file))
df['job_state'] = df.apply(find_job_state, axis=1)

In [187]:
df.set_index(['group_name', 'job_name', 'job_id'])[['job_state', 'stdout', 'stderr']]

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,job_state,stdout,stderr
group_name,job_name,job_id,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
non_gan,train_AE_7,6733823,R,[iteration=383449 phase=train] loss=34.3289 recon_loss=34.3289 lig_norm=128.6688 lig_gen_norm=12...,
non_gan,train_VAE_7,6733824,R,[iteration=414528 phase=train] loss=82.1720 recon_loss=48.3153 kldiv_loss=338.5666 lig_norm=124....,
non_gan,train_CE_7,6733825,R,[iteration=352135 phase=train] loss=555.5048 recon_loss=555.5048 lig_norm=112.1223 lig_gen_norm=...,
non_gan,train_CVAE_7,6733826,R,[iteration=284864 phase=train] loss=136.5785 recon_loss=101.8774 kldiv_loss=347.0108 lig_norm=12...,
train_fitting,train_VAEGAN_8_1024_x_1,6736277,ERR,,slurmstepd: error: Detected 1 oom-kill event(s) in step 6736277.batch cgroup. Some of your proce...
train_fitting,train_VAEGAN_8_256_x_2,6736278,R,[iteration=48286 disc_iter=48286 phase=train model=disc batch=0] loss=nan gan_loss=nan lig_norm=...,
train_fitting,train_VAEGAN_8_128_w_0,6736279,ERR,UFF Exception,NameError: name 'traceback' is not defined
train_fitting,train_VAEGAN_8_1024_w_0,6736280,ERR,,slurmstepd: error: Detected 1 oom-kill event(s) in step 6736280.batch cgroup. Some of your proce...
train_fitting,train_VAEGAN_8_128_w_1,6736281,ERR,,slurmstepd: error: Detected 1 oom-kill event(s) in step 6736281.batch cgroup. Some of your proce...
train_fitting,train_VAEGAN_8_1024_w_1,6736282,R,[iteration=50516 disc_iter=50518 phase=train model=gen batch=0] loss=-41551336526031827464071633...,


# Read in training output metrics

In [190]:
# READ IN METRIC FILES
metrics = param_search.metrics(glob.glob('train_*_8_*/train.sh'), metric_pat='(.*)\.metrics')

train_VAEGAN_8_1024_x_1/train.sh No objects to concatenate
train_VAEGAN_8_1024_w_0/train.sh No objects to concatenate
train_VAEGAN_8_128_w_1/train.sh No objects to concatenate
train_VAEGAN_8_256_w_2/train.sh No objects to concatenate
train_VAEGAN_8_1024_w_2/train.sh No objects to concatenate
train_CVAEGAN_8_128_x_1/train.sh No objects to concatenate
train_CVAEGAN_8_256_x_1/train.sh No objects to concatenate
train_CVAEGAN_8_128_x_2/train.sh No objects to concatenate
train_CVAEGAN_8_1024_x_2/train.sh No objects to concatenate
train_CVAEGAN_8_128_w_0/train.sh No objects to concatenate
train_CVAEGAN_8_1024_w_0/train.sh No objects to concatenate
train_CVAEGAN_8_256_w_1/train.sh No objects to concatenate
train_CVAEGAN_8_1024_w_1/train.sh No objects to concatenate
train_CVAEGAN_8_256_w_2/train.sh No objects to concatenate
train_GAN_8_256_x_1/train.sh No objects to concatenate
train_GAN_8_1024_x_1/train.sh No objects to concatenate
train_GAN_8_128_x_2/train.sh No objects to concatenate
train_G

In [194]:
def dtype(s):
    return s.dtype

metric_index_cols = ['iteration', 'disc_iter', 'phase', 'model', 'batch']
all_index_cols = list(rec_cond_param_space.keys()) + metric_index_cols

metrics[all_index_cols].agg([dtype, pd.Series.nunique, pd.Series.unique]).transpose()

Unnamed: 0,dtype,nunique,unique
data_root,object,2,"[/net/pulsar/home/koes/mtr22/molport, /net/pulsar/home/koes/paf46/Research/CrossDocking_script/P..."
rec_molcache,object,2,"[/net/pulsar/home/koes/mtr22/gan/data/molportFULL_rec.molcache2, /net/pulsar/home/koes/paf46/git..."
lig_molcache,object,2,"[/net/pulsar/home/koes/mtr22/gan/data/molportFULL_lig.molcache2, /net/pulsar/home/koes/paf46/git..."
train_file,object,2,"[/net/pulsar/home/koes/mtr22/gan/data/molportFULL_rand_train0.types, /net/pulsar/home/koes/paf46..."
test_file,object,2,"[/net/pulsar/home/koes/mtr22/gan/data/molportFULL_rand_test0.types, /net/pulsar/home/koes/paf46/..."
skip_connect,bool,2,"[False, True]"
random_seed,int64,3,"[2, 0, 1]"
batch_size,int64,1,[10]
rec_map_file,object,1,[/net/pulsar/home/koes/mtr22/gan/data/my_rec_map]
lig_map_file,object,1,[/net/pulsar/home/koes/mtr22/gan/data/my_lig_map]


In [195]:
#metrics['iteration'] = metrics['iteration'].mask(metrics['iteration'].isna(), metrics['gen_iter'], axis=0)
#metrics['recon_loss_fixed'] = metrics['recon_loss'] * (19*48*48*48) / 2

metrics['model'] = metrics['model'].fillna('gen')
metrics['real'] = metrics['real'].fillna(False)

metrics[
    metrics['iteration'] == metrics.groupby('model_type')['iteration'].transform(max)
].groupby(['model_type', 'iteration', 'phase', 'model'])[
    ['loss', 'kldiv_loss', 'recon_loss', 'gan_loss', 'gen_grad_norm', 'disc_grad_norm', 'lig_norm', 'lig_gen_norm']
].mean()

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,loss,kldiv_loss,recon_loss,gan_loss,gen_grad_norm,disc_grad_norm,lig_norm,lig_gen_norm
model_type,iteration,phase,model,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1
CVAEGAN,41300,test,disc,,,,,,,36.272277,
CVAEGAN,41300,test,gen,,,,,,,37.963648,
GAN,65800,test,disc,,,,,,,38.342677,
GAN,65800,test,gen,,,,,,,,
VAEGAN,50600,test,disc,1.5726530000000002e+31,1.011927e+22,1.940239e+28,1.571683e+31,,,38.800762,179367900000000.0
VAEGAN,50600,test,gen,-3.1853920000000003e+31,1.00558e+22,1.924106e+28,-3.1873160000000005e+31,,,38.355567,181871800000000.0


In [196]:
import numpy as np
import scipy.stats
import param_search

def filled_lines(data, x, y, hue, ax, **kwargs):
    for h, hue_data in data.groupby(hue):
        try:
            mean = hue_data.groupby(x)[y].agg(np.nanmean)
            sem = hue_data.groupby(x)[y].agg(scipy.stats.sem)
            ax.fill_between(mean.index, mean - 2*sem, mean + 2*sem, alpha=0.5, label=h)
            ax.plot(mean.index, mean, label=h)
        except Exception as e:
            print(e, hue_data[y].dtype, y)
    ax.set_xlabel(x)
    ax.set_ylabel(y)
    
iter_bin_size = 1000
metrics['iter_bin'] = ((metrics['iteration'] // iter_bin_size) * iter_bin_size).astype(int)

model_types0 = ['AE', 'CE', 'VAE', 'CVAE']
model_types1 = ['GAN', 'CGAN', 'VAEGAN', 'CVAEGAN']

fig = param_search.plot(
    metrics[
        (metrics['model_type'].isin('GAN')) &
        (metrics['phase'] == 'test')
    ],
    x='iter_bin',
    y=['loss', 'kldiv_loss', 'recon_loss', 'gan_loss', 'gen_grad_norm', 'disc_grad_norm', 'lig_norm', 'lig_gen_norm'],
    ylim=dict(]),
    hue=('model_type', 'model', 'phase'),
    plot_func=filled_lines,
    n_cols=4, height=3.5
)

SyntaxError: invalid syntax (<ipython-input-196-614e6131cfc0>, line 30)