In [245]:
import os
import pandas as pd
from scipy.stats import pearsonr
import numpy as np
import json
import torch.nn.functional as F
import torch
import random
from tqdm import tqdm_notebook as tqdm

In [339]:
output_dir = '/home/dc925/project/medsts/output/final_first_batch'

# Per-model output folders of the first training batch, in the order
# clinicalbert, ncbi_bert, scibert, mt_dnn_base.
clinicalbert_outputs, ncbibert_outputs, scibert_outputs, mtdnnbase_outputs = (
    os.path.join(output_dir, model_folder)
    for model_folder in ('clinicalbert_outputs', 'ncbi_bert_outputs', 'scibert_outputs', 'mt_dnn_base_outputs')
)

In [340]:
# Run directories ('run*') of each first-batch model. Sorted because os.listdir
# returns entries in arbitrary order, and the run order fixes the pred<i> column
# indices used during ensemble selection.
clinicalbert_runs = sorted(os.path.join(clinicalbert_outputs, r) for r in os.listdir(clinicalbert_outputs) if r.startswith('run'))
ncbibert_runs = sorted(os.path.join(ncbibert_outputs, r) for r in os.listdir(ncbibert_outputs) if r.startswith('run'))
scibert_runs = sorted(os.path.join(scibert_outputs, r) for r in os.listdir(scibert_outputs) if r.startswith('run'))
mtdnnbase_runs = sorted(os.path.join(mtdnnbase_outputs, r) for r in os.listdir(mtdnnbase_outputs) if r.startswith('run'))


In [248]:
# Root folder of the second training batch.
output_dir2 = os.path.join('/home/dc925/project/medsts/output', 'final_second_batch')

In [249]:
# Per-model output folders of the second training batch.
clinicalbert_outputs_2 = os.path.join(output_dir2, 'clinicalbert_outputs')
ncbibert_outputs_2 = os.path.join(output_dir2, 'ncbi_bert_outputs')
scibert_outputs_2 = os.path.join(output_dir2, 'scibert_outputs')
mtdnnbase_outputs_2 = os.path.join(output_dir2, 'mt_dnn_base_outputs')
# Run directories ('run*') of each model. Sorted because os.listdir order is
# arbitrary, and the run order fixes the pred<i> column indices downstream.
clinicalbert_runs_2 = sorted(os.path.join(clinicalbert_outputs_2, r) for r in os.listdir(clinicalbert_outputs_2) if r.startswith('run'))
ncbibert_runs_2 = sorted(os.path.join(ncbibert_outputs_2, r) for r in os.listdir(ncbibert_outputs_2) if r.startswith('run'))
scibert_runs_2 = sorted(os.path.join(scibert_outputs_2, r) for r in os.listdir(scibert_outputs_2) if r.startswith('run'))
mtdnnbase_runs_2 = sorted(os.path.join(mtdnnbase_outputs_2, r) for r in os.listdir(mtdnnbase_outputs_2) if r.startswith('run'))

In [382]:
# Root folder and per-model output folders of the auxiliary-task runs.
output_dir3 = '/home/dc925/project/medsts/output/aux'
clinicalbert_outputs_aux = os.path.join(output_dir3, 'clinicalbert_aux_outputs')
ncbibert_outputs_aux = os.path.join(output_dir3, 'ncbi_bert_aux_outputs')
scibert_outputs_aux = os.path.join(output_dir3, 'scibert_aux_outputs')
mtdnnbase_outputs_aux = os.path.join(output_dir3, 'mt_dnn_base_aux_outputs')
# Run directories ('run*') of each model. Sorted because os.listdir order is
# arbitrary, and the run order fixes the pred<i> column indices downstream.
clinicalbert_runs_aux = sorted(os.path.join(clinicalbert_outputs_aux, r) for r in os.listdir(clinicalbert_outputs_aux) if r.startswith('run'))
ncbibert_runs_aux = sorted(os.path.join(ncbibert_outputs_aux, r) for r in os.listdir(ncbibert_outputs_aux) if r.startswith('run'))
scibert_runs_aux = sorted(os.path.join(scibert_outputs_aux, r) for r in os.listdir(scibert_outputs_aux) if r.startswith('run'))
mtdnnbase_runs_aux = sorted(os.path.join(mtdnnbase_outputs_aux, r) for r in os.listdir(mtdnnbase_outputs_aux) if r.startswith('run'))

In [383]:
# Candidate pool for the ensemble: only the auxiliary-task runs. The first- and
# second-batch run lists (clinicalbert_runs, ..., mtdnnbase_runs_2) are
# intentionally left out of this pool.
all_runs = [*clinicalbert_runs_aux, *ncbibert_runs_aux, *scibert_runs_aux, *mtdnnbase_runs_aux]




In [392]:
# Output directory of every run for every cross-validation fold:
# kfold_model_outputs[fold][run_index] == '<run dir>/k_fold_<fold>'.
# Built with a comprehension so no loop variable (e.g. a stale `k`) leaks into the
# notebook namespace; the fold to process is set explicitly in a later cell.
folds = [0, 1, 2, 3, 4]
kfold_model_outputs = [
    [os.path.join(r, 'k_fold_{}'.format(fold)) for r in all_runs]
    for fold in folds
]

In [393]:
# Sanity check: number of output directories in fold 0 (one per entry of all_runs).
len(kfold_model_outputs[0])

64

In [394]:
# Collects the ensemble table produced for each processed fold.
tables = list()


In [395]:
# Collects the output directories of the models selected for each processed fold.
all_topmodels = list()

In [420]:
# Index of the cross-validation fold processed by the following cells. Those cells
# append to `tables` / `all_topmodels`, so presumably they are re-run once per fold
# after changing this value.
k = 4
# Guard against a negative or out-of-range index silently picking the wrong fold.
assert k in folds, 'k must be one of {}'.format(folds)

In [421]:
# Ensemble-size search for fold `k`. Rank every run's predictions by Pearson
# correlation with the gold `label`. Then pick the number `n` of top-ranked runs
# whose averaged prediction correlates best with the label.
errors_paths = [os.path.join(p, 'checkpoint/errors.csv') for p in kfold_model_outputs[k]]
# Per-run predictions do not depend on `n`, so load them once, outside the search.
dfs = [pd.read_csv(ep).sort_values(by='pid').set_index('pid') for ep in errors_paths]
ensemble_table = pd.concat([dfs[0]['label'], dfs[0]['sent1'], dfs[0]['sent2']], axis=1)
pred_cols = ['pred{}'.format(i) for i in range(len(dfs))]
for col, df in zip(pred_cols, dfs):
    # Column assignment aligns on the `pid` index.
    ensemble_table[col] = df['pred']

# Correlate only the numeric columns. `sent1`/`sent2` are text, and
# DataFrame.corr raises on non-numeric columns in pandas >= 2.0.
label_corr = ensemble_table[['label'] + pred_cols].corr()['label'].drop('label')
ranked_cols = label_corr.sort_values(ascending=False).index

# Candidate ensemble sizes: even numbers from 4 up to (excluding) 40% of all runs.
candidate_ns = range(4, len(all_runs) * 2 // 5, 2)
assert len(candidate_ns) > 0, 'too few runs to search over ensemble sizes'

# Start below any possible correlation so `top_models` is always assigned.
best_val = float('-inf')
best_n = 0
top_models = None
for n in candidate_ns:
    cols = ranked_cols[:n]
    val = ensemble_table['label'].corr(ensemble_table[cols].mean(axis=1))
    if val > best_val:
        best_val = val
        best_n = n
        top_models = cols
        print('{} \t {}'.format(best_n, best_val))
print(best_val)
print(best_n)

# Build the final ensemble from exactly the models selected above. Re-ranking via
# ensemble_table.corr() at this point would count a leftover 'ensemble' column
# among the candidates and shift the selection.
top_cols = top_models
ensemble_table['ensemble'] = ensemble_table[top_cols].mean(axis=1)
tables.append(ensemble_table)

4 	 0.8626613808154351
8 	 0.86596166942313
10 	 0.8677331443742552
0.8677331443742552
10


In [422]:
# Recover each selected model's position in kfold_model_outputs[k] from its
# 'pred<i>' column name. Strip the known prefix rather than splitting on 'd', and
# fail loudly if the column naming ever changes.
assert all(s.startswith('pred') for s in top_models), list(top_models)
top_models_index = [int(s[len('pred'):]) for s in top_models]

In [423]:
# Output directories (within fold k) of the models chosen for the ensemble.
topmodel_paths = list(map(lambda idx: kfold_model_outputs[k][idx], top_models_index))

In [424]:
# Accumulate the selected model directories across folds. Each path is unique per
# (run, fold), so skipping paths already present only prevents double counting when
# this cell is re-run for the same fold.
all_topmodels += [p for p in topmodel_paths if p not in all_topmodels]

In [425]:
# Display every selected model directory collected so far.
list(all_topmodels)

['/home/dc925/project/medsts/output/aux/ncbi_bert_aux_outputs/run_6/k_fold_0',
 '/home/dc925/project/medsts/output/aux/ncbi_bert_aux_outputs/run_7/k_fold_0',
 '/home/dc925/project/medsts/output/aux/scibert_aux_outputs/run_8/k_fold_0',
 '/home/dc925/project/medsts/output/aux/scibert_aux_outputs/run_9/k_fold_0',
 '/home/dc925/project/medsts/output/aux/ncbi_bert_aux_outputs/run_3/k_fold_0',
 '/home/dc925/project/medsts/output/aux/scibert_aux_outputs/run_13/k_fold_0',
 '/home/dc925/project/medsts/output/aux/scibert_aux_outputs/run_5/k_fold_0',
 '/home/dc925/project/medsts/output/aux/scibert_aux_outputs/run_7/k_fold_0',
 '/home/dc925/project/medsts/output/aux/scibert_aux_outputs/run_1/k_fold_0',
 '/home/dc925/project/medsts/output/aux/scibert_aux_outputs/run_12/k_fold_0',
 '/home/dc925/project/medsts/output/aux/ncbi_bert_aux_outputs/run_4/k_fold_0',
 '/home/dc925/project/medsts/output/aux/ncbi_bert_aux_outputs/run_0/k_fold_0',
 '/home/dc925/project/medsts/output/aux/clinicalbert_aux_outputs

In [426]:
# '<model folder><run folder>' label for each selected model, with the fold
# component dropped so selections of the same run in different folds coincide.
final_list = [''.join(parts[-3:-1]) for parts in (p.split('/') for p in all_topmodels)]

In [432]:
from collections import Counter

In [433]:
# How many folds selected each run.
c = Counter()
c.update(final_list)

In [434]:
# The ten runs selected most often across folds.
c.most_common(n=10)

[('mt_dnn_base_aux_outputsrun_14', 3),
 ('clinicalbert_aux_outputsrun_5', 2),
 ('mt_dnn_base_aux_outputsrun_10', 2),
 ('mt_dnn_base_aux_outputsrun_15', 2),
 ('ncbi_bert_aux_outputsrun_6', 1),
 ('ncbi_bert_aux_outputsrun_7', 1),
 ('scibert_aux_outputsrun_8', 1),
 ('scibert_aux_outputsrun_9', 1),
 ('ncbi_bert_aux_outputsrun_3', 1),
 ('scibert_aux_outputsrun_13', 1)]

In [435]:
def _model_fold_label(path):
    """Return '<model dir><run dir>,<fold>' for a '.../<model dir>/<run dir>/k_<fold>' path.

    Built from the path components directly. Splitting the joined string on every
    'k_' would also split inside a model or run folder name containing 'k_'.
    """
    model_dir, run_dir, fold_dir = path.split('/')[-3:]
    fold = fold_dir[len('k_'):] if fold_dir.startswith('k_') else fold_dir
    return '{}{},{}'.format(model_dir, run_dir, fold)


final_list_full = [_model_fold_label(p) for p in all_topmodels]

In [436]:
# Display the 'model+run,fold' label of every selected model.
list(final_list_full)

['ncbi_bert_aux_outputsrun_6,fold_0',
 'ncbi_bert_aux_outputsrun_7,fold_0',
 'scibert_aux_outputsrun_8,fold_0',
 'scibert_aux_outputsrun_9,fold_0',
 'ncbi_bert_aux_outputsrun_3,fold_0',
 'scibert_aux_outputsrun_13,fold_0',
 'scibert_aux_outputsrun_5,fold_0',
 'scibert_aux_outputsrun_7,fold_0',
 'scibert_aux_outputsrun_1,fold_0',
 'scibert_aux_outputsrun_12,fold_0',
 'ncbi_bert_aux_outputsrun_4,fold_0',
 'ncbi_bert_aux_outputsrun_0,fold_0',
 'clinicalbert_aux_outputsrun_8,fold_0',
 'ncbi_bert_aux_outputsrun_5,fold_0',
 'mt_dnn_base_aux_outputsrun_14,fold_0',
 'scibert_aux_outputsrun_3,fold_0',
 'clinicalbert_aux_outputsrun_5,fold_0',
 'clinicalbert_aux_outputsrun_1,fold_0',
 'clinicalbert_aux_outputsrun_15,fold_2',
 'scibert_aux_outputsrun_10,fold_2',
 'scibert_aux_outputsrun_11,fold_2',
 'mt_dnn_base_aux_outputsrun_10,fold_2',
 'clinicalbert_aux_outputsrun_14,fold_2',
 'scibert_aux_outputsrun_6,fold_2',
 'mt_dnn_base_aux_outputsrun_15,fold_1',
 'mt_dnn_base_aux_outputsrun_9,fold_1',
 '