# Prepare result table:

In [1]:
%load_ext autoreload
%autoreload 2
import collections
import os
import pandas as pd
import numpy as np
import pickle
import json
import sys
import tensorflow as tf
from vis.utils import utils
from loguru  import logger
from tqdm import tqdm
from screening.validation.crossval import crossval_table, crossval_ref_filter, crossval_max_value_filter
from pprint import pprint

# Let TensorFlow allocate GPU memory on demand instead of reserving it all up
# front. TensorFlow requires the memory-growth setting to be identical on every
# visible GPU, so it is applied to all devices rather than only the first one.
# With no GPU present the loop is simply skipped.
physical_devices = tf.config.list_physical_devices('GPU')
for gpu_device in physical_devices:
    tf.config.experimental.set_memory_growth(gpu_device, True)
# Keep tf.function bodies running as graphs (do not force eager execution).
tf.config.run_functions_eagerly(False)

2024-07-31 18:25:05.048407: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-07-31 18:25:05.373036: I tensorflow/core/util/port.cc:104] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-07-31 18:25:30.323037: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:981] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2024-07-31 18:25:30.628957: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:981] successful NUMA node 

## Best models:

In [2]:

def create_op_dict( op_name, extra_suf="" ):
    """Build the ordered column-name -> storage-key mapping for one operating point.

    Parameters
    ----------
    op_name : str
        Operating-point name; used as the prefix of the operating-point keys
        (e.g. ``'sens90/threshold'``).
    extra_suf : str, optional
        Text appended to both prefixes, giving ``'summary<extra_suf>/...'`` and
        ``'<op_name><extra_suf>/...'``.

    Returns
    -------
    collections.OrderedDict
        Keys are short column names, values are the slash-separated key strings.
        The entry order is fixed: global summary metrics, the per-partition
        summary metrics (val, test, op), the operating-point metrics, their
        per-partition variants, and finally ``'inference'``.

    Notes
    -----
    The ``acc`` / ``pd`` / ``fa`` summary entries and the per-partition
    ``acc_at_*`` entries are intentionally not part of the mapping.
    """
    summary_root = f'summary{extra_suf}'
    op_root = f'{op_name}{extra_suf}'
    partitions = ('_val', '_test', '_op')

    # Global summary metrics.
    pairs = [
        ('max_sp',    f'{summary_root}/max_sp'),
        ('auc',       f'{summary_root}/auc'),
        ('sens',      f'{summary_root}/sensitivity'),
        ('spec',      f'{summary_root}/specificity'),
        ('threshold', f'{summary_root}/threshold'),
    ]
    # ROC entries keep their own ordering: all, val, op, test.
    pairs += [(f'roc{tail}', f'{summary_root}/roc{tail}') for tail in ('', '_val', '_op', '_test')]
    pairs.append(('min_spec_sens_reached', f'{op_root}/min_spec_sens_reached'))

    # Summary metrics repeated for each partition.
    for tail in partitions:
        pairs += [
            (f'max_sp{tail}', f'{summary_root}/max_sp{tail}'),
            (f'auc{tail}',    f'{summary_root}/auc{tail}'),
            (f'sens{tail}',   f'{summary_root}/sensitivity{tail}'),
            (f'spec{tail}',   f'{summary_root}/specificity{tail}'),
        ]

    # Metrics measured at the requested operating point.
    pairs += [
        ('sp_index',     f'{op_root}/sp_index'),
        ('sens_at',      f'{op_root}/sensitivity'),
        ('spec_at',      f'{op_root}/specificity'),
        ('acc_at',       f'{op_root}/acc'),
        ('threshold_at', f'{op_root}/threshold'),
    ]
    for tail in partitions:
        pairs += [
            (f'sp_index{tail}',     f'{op_root}/sp_index{tail}'),
            (f'sens_at{tail}',      f'{op_root}/sensitivity{tail}'),
            (f'spec_at{tail}',      f'{op_root}/specificity{tail}'),
            (f'threshold_at{tail}', f'{op_root}/threshold{tail}'),
        ]

    pairs.append(('inference', f'{op_root}/inference'))
    return collections.OrderedDict(pairs)


# Suffix appended to the 'summary' and operating-point prefixes of every key.
# Alternative setting kept for reference:
#extra_suf='_val'
extra_suf=''

# One key mapping per operating point, in the order sens90, max_sp, spec70.
conf_dict = collections.OrderedDict(
    ( op_name, create_op_dict( op_name, extra_suf=extra_suf ) )
    for op_name in ( 'sens90', 'max_sp', 'spec70' )
)
pprint(conf_dict)

OrderedDict([('sens90',
              OrderedDict([('max_sp', 'summary/max_sp'),
                           ('auc', 'summary/auc'),
                           ('sens', 'summary/sensitivity'),
                           ('spec', 'summary/specificity'),
                           ('threshold', 'summary/threshold'),
                           ('roc', 'summary/roc'),
                           ('roc_val', 'summary/roc_val'),
                           ('roc_op', 'summary/roc_op'),
                           ('roc_test', 'summary/roc_test'),
                           ('min_spec_sens_reached',
                            'sens90/min_spec_sens_reached'),
                           ('max_sp_val', 'summary/max_sp_val'),
                           ('auc_val', 'summary/auc_val'),
                           ('sens_val', 'summary/sensitivity_val'),
                           ('spec_val', 'summary/specificity_val'),
                           ('max_sp_test', 'summary/max_sp_test'),
                

In [3]:
# Directory that is prefixed to every job name listed in `models`.
basepath='/mnt/brics_data/joao.pinto'


# Each entry is a pair (job directory name, short train_tag).
# NOTE(review): the only consumer visible in this notebook is the commented-out
# `cv.fill( basepath+'/'+path , train_tag )` loop further down; the results are
# currently loaded from a pickle instead.
models = [
    ( 'user.philipp.gaspar.convnets_v0.altogether.shenzhen_santacasa.exp_wgan_p2p.67de4190c1.r2'                , 'v0.alto.sh-sc.ewp'         ),
    ( 'user.philipp.gaspar.convnets_v0.altogether.shenzhen_santacasa.exp_wgan_p2p_cycle.a19a3a4f8c.r2'          , 'v0.alto.sh-sc.ewpc'        ),
    ( 'user.philipp.gaspar.convnets_v0.altogether.shenzhen_santacasa_manaus.exp_wgan_p2p.0d13030165.r2'         , 'v0.alto.sh-sc-ma.ewp'      ),
    ( 'user.philipp.gaspar.convnets_v0.altogether.shenzhen_santacasa_manaus.exp_wgan_p2p_cycle.c5143abd1b.r2'   , 'v0.alto.sh-sc-ma.ewpc'     ),
    
    ( 'user.philipp.gaspar.convnets_v0.baseline.shenzhen_santacasa.exp.989f87bed5.r2'                           , 'v0.base.sh-sc.e'           ),
    ( 'user.philipp.gaspar.convnets_v0.baseline.shenzhen_santacasa_manaus.exp.ffe6cbee11.r2'                    , 'v0.base.sh-sc-ma.e'        ),
    
    ( 'user.philipp.gaspar.convnets_v0.interleaved.shenzhen_santacasa.exp_wgan_p2p.e540d24b4b.r2'               , 'v0.inte.sh-sc.ewp'         ),
    ( 'user.philipp.gaspar.convnets_v0.interleaved.shenzhen_santacasa.exp_wgan_p2p_cycle.a19a3a4f8c.r2'         , 'v0.inte.sh-sc.ewpc'        ),
    ( 'user.philipp.gaspar.convnets_v0.interleaved.shenzhen_santacasa_manaus.exp_wgan_p2p.ac79954ba0.r2'        , 'v0.inte.sh-sc-ma.ewp'      ),
    ( 'user.philipp.gaspar.convnets_v0.interleaved.shenzhen_santacasa_manaus.exp_wgan_p2p_cycle.c5143abd1b.r2'  , 'v0.inte.sh-sc-ma.ewpc'     ),

    ( 'user.philipp.gaspar.convnets_v1.baseline.shenzhen_santacasa.exp.20240303.r2'                             , 'v1.base.sh-sc.e'           ),
    ( 'user.philipp.gaspar.convnets_v1.interleaved.shenzhen_santacasa.exp_wgan_p2p.20240303.r2'                 , 'v1.inte.sh-sc.ewp'         ),
    ( 'user.philipp.gaspar.convnets_v1.altogether.shenzhen_santacasa.exp_wgan_p2p.20240303.r2'                  , 'v1.alto.sh-sc.ewp'         ),
    ( 'user.philipp.gaspar.convnets_v1.baseline.shenzhen_santacasa_manaus.exp.20240303.r2'                      , 'v1.base.sh-sc-ma.e'        ),
    ( 'user.philipp.gaspar.convnets_v1.interleaved.shenzhen_santacasa_manaus.exp_wgan_p2p.20240303.r2'          , 'v1.inte.sh-sc-ma.ewp'      ),
    ( 'user.philipp.gaspar.convnets_v1.altogether.shenzhen_santacasa_manaus.exp_wgan_p2p.20240303.r2'           , 'v1.alto.sh-sc-ma.ewp'      ),
]



In [87]:
def extract_inference_table( table ):
    """Flatten the per-row ``inference`` payload of ``table`` into one long frame.

    Each row of ``table`` is expected to carry an ``inference`` mapping of
    ``dataset name -> {'probs': DataFrame}``. The ``probs`` frames of a row are
    stacked with a leading ``dataset`` column and then tagged with the row's
    ``train_tag``, ``op_name``, ``test``, ``sort`` and ``file_name``.

    Parameters
    ----------
    table : pandas.DataFrame
        Results table holding the columns listed above plus ``inference``.

    Returns
    -------
    pandas.DataFrame
        Concatenation of all per-row frames. Leading columns are, in order:
        ``train_tag, op_name, test, sort, file_name, dataset``, followed by the
        columns of the ``probs`` frames.

    Side effects
    ------------
    The ``inference`` column is dropped from ``table`` **in place**, so the
    function cannot be applied twice to the same table object. The ``probs``
    frames themselves are left untouched (they are copied before being tagged).
    """

    def get_datasets( row ):
        payload = row['inference']
        frames = []
        for name in payload.keys():
            # Copy first: inserting into the stored frame would mutate the
            # caller's data and fail with "already exists" if it is seen again.
            probs = payload[name]['probs'].copy()
            probs.insert(0,'dataset',name)
            frames.append(probs)
        return pd.concat(frames,axis='rows')

    data_list = []
    for _, row in tqdm(table.iterrows(), total=len(table),desc='processing...'):
        data = get_datasets(row)
        # Key access (row['sort']) avoids clashing with Series attributes.
        # Columns are inserted at position 0 in reverse of their final order.
        for column in ('file_name', 'sort', 'test', 'op_name', 'train_tag'):
            data.insert(0, column, row[column])
        data_list.append(data)
    data=pd.concat(data_list,axis='rows')
    table.drop(columns=['inference'],inplace=True)
    return data

In [88]:
# Instantiate the project's cross-validation helper with the key mapping in `conf_dict`.
cv = crossval_table( conf_dict )
# Filling the table from the job directories in `models` is disabled here;
# un-comment the loop below to rebuild it.
#for path, train_tag in models:
#    cv.fill( basepath+'/'+path , train_tag )
# Instead, load a previously stored table from the current working directory.
# NOTE(review): read_pickle un-pickles arbitrary objects — only load files you trust.
cv.table = pd.read_pickle("table_results.pkl")
table = cv.table
# Flatten the per-row inference payload into its own frame. This also drops the
# 'inference' column from `table` in place.
inference = extract_inference_table( table )
# NOTE(review): at this point `table` no longer has the 'inference' column, so
# enabling the second write below would overwrite "table_results.pkl" with a
# table this cell can no longer process on the next run.
#inference.to_pickle("table_inference.pkl")
#table.to_pickle("table_results.pkl")

processing...: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4320/4320 [00:08<00:00, 539.87it/s]


In [89]:
# Preview the results table (its 'inference' column was removed by extract_inference_table).
table.head()

Unnamed: 0,train_tag,op_name,test,sort,file_name,max_sp,auc,sens,spec,threshold,...,spec_at_val,threshold_at_val,sp_index_test,sens_at_test,spec_at_test,threshold_at_test,sp_index_op,sens_at_op,spec_at_op,threshold_at_op
0,v0.alto.sh-sc.ewp,sens90,3,0,/mnt/brics_data/joao.pinto/user.philipp.gaspar...,0.886746,0.944909,0.887129,0.886364,0.441392,...,0.8,0.34974,0.782621,0.727273,0.84,0.34974,0.872353,0.897959,0.847118,0.34974
1,v0.alto.sh-sc.ewp,max_sp,3,0,/mnt/brics_data/joao.pinto/user.philipp.gaspar...,0.886746,0.944909,0.887129,0.886364,0.441392,...,0.8,0.34974,0.782621,0.727273,0.84,0.34974,0.872353,0.897959,0.847118,0.34974
2,v0.alto.sh-sc.ewp,spec70,3,0,/mnt/brics_data/joao.pinto/user.philipp.gaspar...,0.886746,0.944909,0.887129,0.886364,0.441392,...,0.7,0.156788,0.768306,0.818182,0.72,0.156788,0.822598,0.944341,0.709273,0.156788
3,v0.alto.sh-sc.ewp,sens90,5,2,/mnt/brics_data/joao.pinto/user.philipp.gaspar...,0.841054,0.907026,0.812623,0.869973,0.47497,...,0.82,0.296286,0.748762,0.757576,0.74,0.296286,0.800813,0.888889,0.717337,0.296286
4,v0.alto.sh-sc.ewp,max_sp,5,2,/mnt/brics_data/joao.pinto/user.philipp.gaspar...,0.841054,0.907026,0.812623,0.869973,0.47497,...,0.82,0.296286,0.748762,0.757576,0.74,0.296286,0.800813,0.888889,0.717337,0.296286


In [90]:
# Preview the long-format frame returned by extract_inference_table.
inference.head()

Unnamed: 0,train_tag,op_name,test,sort,file_name,dataset,project_id,target,y_prob,y
0,v0.alto.sh-sc.ewp,sens90,3,0,/mnt/brics_data/joao.pinto/user.philipp.gaspar...,russia,russi_AAA_13_01_1999_043AAC,1,0.011683,False
1,v0.alto.sh-sc.ewp,sens90,3,0,/mnt/brics_data/joao.pinto/user.philipp.gaspar...,russia,russi_ASM_13_12_1980_34F0E8,1,0.307573,False
2,v0.alto.sh-sc.ewp,sens90,3,0,/mnt/brics_data/joao.pinto/user.philipp.gaspar...,russia,russi_DAS_16_03_2000_299323,1,0.052381,False
3,v0.alto.sh-sc.ewp,sens90,3,0,/mnt/brics_data/joao.pinto/user.philipp.gaspar...,russia,russi_DTU_23_01_1993_DC263C,1,0.40187,True
4,v0.alto.sh-sc.ewp,sens90,3,0,/mnt/brics_data/joao.pinto/user.philipp.gaspar...,russia,russi_DVV_26_07_1974_2A0A3C,1,0.617209,True


In [91]:


class crossval_ref_filter:
    """Pick the best sort/test among rows whose metric lies close to a reference value.

    NOTE(review): this definition shadows the `crossval_ref_filter` imported from
    `screening.validation.crossval` in the first cell of the notebook.
    """

    def __init__(self, ref, ref_col_name, max_col_name, test_key):
        # ref          : target value for the `ref_col_name` column
        # ref_col_name : column compared against `ref`
        # max_col_name : column maximised when choosing the best sort
        # test_key     : column maximised when choosing the best test
        self.ref=ref
        self.ref_col_name=ref_col_name
        self.max_col_name=max_col_name
        self.test_key=test_key


    def __call__( self ,table,  ref, ref_col_name, max_col_name, group_col_names,   col_count : str='test', step : float=0.01):
        """Keep rows near `ref` and return, per group, the row maximising `max_col_name`.

        For every `train_tag`, a window ``0 < |row[ref_col_name] - ref| < delta``
        is widened in increments of `step` (delta runs over
        ``np.arange(0, 1+step, step)``) until the selected rows contain as many
        distinct `col_count` values as the whole input `table`. If that never
        happens, the rows selected for the last delta are used.

        Parameters
        ----------
        table : pandas.DataFrame
            Must contain 'train_tag', `ref_col_name`, `max_col_name`,
            `col_count` and the `group_col_names` columns.
        ref, ref_col_name, max_col_name :
            Same meaning as the constructor arguments; the values passed here
            are the ones used.
        group_col_names : list of str
            Grouping columns for the final arg-max.
        col_count : str
            Column whose distinct values must all be represented.
        step : float
            Window increment.

        Returns
        -------
        pandas.DataFrame
            One row per group (selected by index label), columns unchanged.
        """
        count=len(table[col_count].unique())
        deltas = np.arange(0,1+step,step)

        tables = []
        for train_tag in table['train_tag'].unique().tolist():
            table_train_tag = table.loc[table.train_tag==train_tag]
            # The distance to the reference does not depend on delta: compute it
            # once, vectorised, instead of a row-wise apply for every delta.
            distance = (table_train_tag[ref_col_name]-ref).abs().to_numpy()
            # NOTE(review): the strict lower bound excludes rows sitting exactly
            # on `ref` (distance 0); kept as in the original — confirm intent.
            away_from_ref = distance>0
            selected = np.zeros(len(table_train_tag), dtype=bool)
            for delta in deltas:
                selected = away_from_ref & (distance<delta)
                # NOTE: force to have always the same number of tests boxes
                if len(table_train_tag.loc[selected, col_count].unique())==count:
                    break
            tables.append(table_train_tag.loc[selected])

        table = pd.concat(tables, axis='rows')
        idxmask = table.groupby(group_col_names)[max_col_name].idxmax().values
        return table.loc[idxmask]


    def filter_sorts( self, table ):
        """Return the best sort for each (train_tag, op_name, test) using the stored settings."""
        return self.__call__(table, self.ref, self.ref_col_name, self.max_col_name, ['train_tag','op_name','test'], 
                             col_count='test',step=0.01)

    def filter_tests( self, best_sorts):
        """Return, for each (train_tag, op_name), the row of `best_sorts` with the largest `test_key`."""
        idxmask = best_sorts.groupby(['train_tag', 'op_name'])[self.test_key].idxmax().values
        return best_sorts.loc[idxmask]


class crossval_max_value_filter:
    """Pick, within each group, the row holding the largest value of a metric column."""

    def __init__(self, sort_key, test_key):
        # sort_key: column maximised inside each (train_tag, op_name, test) group
        # test_key: column maximised inside each (train_tag, op_name) group
        self.sort_key, self.test_key = sort_key, test_key

    @staticmethod
    def _rows_at_group_max(frame, group_cols, value_col):
        """Return the rows of `frame` (by index label) where `value_col` peaks in each group."""
        winners = frame.groupby(group_cols)[value_col].idxmax()
        return frame.loc[winners.values]

    def filter_sorts( self, table ):
        """Best sort per (train_tag, op_name, test), ranked by `sort_key`."""
        return self._rows_at_group_max(table, ['train_tag', 'op_name','test'], self.sort_key)

    def filter_tests( self, best_sorts ):
        """Best test per (train_tag, op_name), ranked by `test_key`."""
        return self._rows_at_group_max(best_sorts, ['train_tag', 'op_name'], self.test_key)


def apply_filters( table):
    """Select the best sorts and best tests for every operating point in `table`.

    The rows of each `op_name` are handed to the filter registered for it:
    'sens90' keeps rows near sens_op = 0.9 and maximises spec_op, 'max_sp'
    maximises max_sp_op, and 'spec70' keeps rows near spec_op = 0.7 and
    maximises sens_op. An `op_name` without a registered filter raises KeyError.

    Returns
    -------
    (pandas.DataFrame, pandas.DataFrame)
        The concatenated best sorts and the concatenated best tests, in the
        order the operating points first appear in `table`.
    """
    filters = {
        'sens90' : crossval_ref_filter(0.9, 'sens_op', 'spec_op', test_key='spec_op'),
        'max_sp' : crossval_max_value_filter(sort_key='max_sp_op', test_key='max_sp_op'),
        'spec70' : crossval_ref_filter(0.7, 'spec_op', 'sens_op', test_key='sens_op'),
    }

    selections = []
    for op_name in table.op_name.unique():
        op_filter = filters[op_name]
        sorts_for_op = op_filter.filter_sorts( table.loc[table.op_name==op_name])
        tests_for_op = op_filter.filter_tests( sorts_for_op )
        selections.append((sorts_for_op, tests_for_op))

    all_sorts = pd.concat([sorts for sorts, _ in selections], axis='rows')
    all_tests = pd.concat([tests for _, tests in selections], axis='rows')
    return all_sorts, all_tests

# Best sort per (train_tag, op_name, test) and best test per (train_tag, op_name).
best_sorts, best_tests = apply_filters(table)


In [97]:
# Show the selected row per (train_tag, op_name) with its summary sensitivity/specificity columns.
best_tests[['train_tag','op_name','test','sort','sens_val','spec_val','sens_op','spec_op','sens_test','spec_test', 'file_name']]

Unnamed: 0,train_tag,op_name,test,sort,sens_val,spec_val,sens_op,spec_op,sens_test,spec_test,file_name
798,v0.alto.sh-sc-ma.ewp,sens90,0,8,0.868421,0.887097,0.920474,0.965278,0.763158,0.938462,/mnt/brics_data/joao.pinto/user.philipp.gaspar...
828,v0.alto.sh-sc-ma.ewpc,sens90,7,0,0.684211,0.892308,0.935691,0.967767,0.763158,0.888889,/mnt/brics_data/joao.pinto/user.philipp.gaspar...
6,v0.alto.sh-sc.ewp,sens90,4,0,0.852941,0.92,0.919476,0.962312,0.69697,0.84,/mnt/brics_data/joao.pinto/user.philipp.gaspar...
354,v0.alto.sh-sc.ewpc,sens90,9,7,0.823529,0.895833,0.884547,0.957265,0.794118,0.770833,/mnt/brics_data/joao.pinto/user.philipp.gaspar...
1362,v0.base.sh-sc-ma.e,sens90,7,0,0.868421,0.923077,0.908012,0.936097,0.921053,0.84127,/mnt/brics_data/joao.pinto/user.philipp.gaspar...
1230,v0.base.sh-sc.e,sens90,2,7,0.794118,0.9375,0.891089,0.936795,0.636364,0.9,/mnt/brics_data/joao.pinto/user.philipp.gaspar...
2184,v0.inte.sh-sc-ma.ewp,sens90,1,2,0.648649,0.923077,0.880471,0.953373,0.736842,0.830769,/mnt/brics_data/joao.pinto/user.philipp.gaspar...
2580,v0.inte.sh-sc-ma.ewpc,sens90,3,6,0.842105,0.84127,0.730645,0.826291,0.72973,0.892308,/mnt/brics_data/joao.pinto/user.philipp.gaspar...
1647,v0.inte.sh-sc.ewp,sens90,8,8,0.852941,0.916667,0.914339,0.960445,0.823529,0.9375,/mnt/brics_data/joao.pinto/user.philipp.gaspar...
2061,v0.inte.sh-sc.ewpc,sens90,6,0,0.823529,0.84,0.787719,0.862999,0.823529,0.857143,/mnt/brics_data/joao.pinto/user.philipp.gaspar...


In [98]:
# Write the selected rows to "table_best_models.pkl" in the current working directory.
best_tests.to_pickle("table_best_models.pkl")

## Prepare datasets:

In [99]:
def read_datasets(basepath : str="/mnt/brics_data/public/datasets"):
    """Load all dataset tables under `basepath`, drop blacklisted images and return one frame.

    Every CSV is read with its first column as index, its `image_path` column
    is prefixed with the directory containing the CSV, and the tables are
    concatenated. Rows whose `project_id` appears in any of the blacklist
    pickles (key ``'black_list'``) are removed.

    Parameters
    ----------
    basepath : str
        Root directory holding the per-dataset folders.

    Returns
    -------
    pandas.DataFrame
        Combined table with a fresh 0..n-1 index; the per-file index is kept
        as a regular column by `reset_index()`.

    Raises
    ------
    FileNotFoundError
        If one of the expected CSV or blacklist files is missing.
    """

    datasets = [
        f"{basepath}/Shenzhen/china/raw/Shenzhen_china_table_from_raw.csv",
        f"{basepath}/SantaCasa/imageamento_anonimizado_valid/raw/SantaCasa_imageamento_anonimizado_valid_table_from_raw.csv",
        f"{basepath}/Manaus/manaus/raw/Manaus_manaus_table_from_raw.csv",
        f"{basepath}/Caxias/caxias/raw/images.csv",
        f"{basepath}/Indonesia/indonesia/raw/images.csv",
        f"{basepath}/Russia/russia/raw/images.csv",
        f"{basepath}/Rio/fiocruz/raw/Rio_fiocruz_table_from_raw.csv",
    ]

    blacklists = [
        f"{basepath}/Caxias/caxias/raw/blacklist.pkl",
        f"{basepath}/Indonesia/indonesia/raw/blacklist.pkl",
        f"{basepath}/Russia/russia/raw/blacklist.pkl",
    ]

    data_list = []
    for path in datasets:
        # Directory of the CSV; image paths in the table are relative to it.
        raw_path = '/'.join(path.split('/')[:-1])
        data=pd.read_csv(path, index_col=0)
        # Only one column is rewritten, so map over it instead of applying row-wise.
        data['image_path'] = data['image_path'].map(
            lambda image_path, raw_path=raw_path: f"{raw_path}/{image_path}"
        )
        data_list.append( data )
    data = pd.concat(data_list,axis='rows')

    blacklist = []
    for path in blacklists:
        # NOTE(review): pickle can execute arbitrary code — these files must be trusted.
        # The context manager closes the handle, which the bare open() call did not.
        with open(path,'rb') as handle:
            blacklist.extend(pickle.load(handle)['black_list'])

    data = data[~data['project_id'].isin(blacklist)]
    data=data.reset_index()
    return data

# Load every dataset table from the default base path, with blacklisted images removed.
data = read_datasets()

In [100]:
# Preview the combined dataset table.
data.head()

Unnamed: 0,index,dataset_name,project_id,image_path,insertion_date,metadata,target
0,135,china,china_CHNCXR_0001_0_E464A8,/mnt/brics_data/public/datasets/Shenzhen/china...,2021-08-17,"{'gender': 'male', 'age': 45, 'has_tb': False,...",0
1,323,china,china_CHNCXR_0002_0_961172,/mnt/brics_data/public/datasets/Shenzhen/china...,2021-08-17,"{'gender': 'male', 'age': 63, 'has_tb': False,...",0
2,102,china,china_CHNCXR_0003_0_BA565D,/mnt/brics_data/public/datasets/Shenzhen/china...,2021-08-17,"{'gender': 'female', 'age': 48, 'has_tb': Fals...",0
3,229,china,china_CHNCXR_0004_0_96C984,/mnt/brics_data/public/datasets/Shenzhen/china...,2021-08-17,"{'gender': 'male', 'age': 58, 'has_tb': False,...",0
4,37,china,china_CHNCXR_0005_0_B6ECEF,/mnt/brics_data/public/datasets/Shenzhen/china...,2021-08-17,"{'gender': 'male', 'age': 28, 'has_tb': False,...",0


In [101]:
# Write the combined dataset table to "table_dataset.pkl" in the current working directory.
data.to_pickle("table_dataset.pkl")