# ALS sample analysis

In [2]:
from env import neptune_api_token

config = {
    'entity': 'ejmockler',
    'project': 'ALS-NUPS-2000',
    'neptuneApiToken': neptune_api_token,
}

In [3]:
import neptune
import pandas as pd
project = neptune.init_project(project=config['entity'] + '/' + config['project'], api_token=config['neptuneApiToken'])
runs_table_df = project.fetch_runs_table().to_pandas()

https://new-ui.neptune.ai/ejmockler/ALS-NUPS-2000/


In [49]:
runs_table_df['nTest'].unique()

array([136.8])

In [50]:
import multiprocess as multiprocessing
import neptune
import os
%env NEPTUNE_ALLOW_SELF_SIGNED_CERTIFICATE='TRUE'

def download_file(run_id, field='sampleResults', extension='csv'):
    path = f'./{field}/{run_id}.{extension}'
    if not os.path.exists(field):
        os.mkdir(field)
    if os.path.isfile(path): 
        return
    run = neptune.init_run(with_id=run_id, project=config['entity'] + '/' + config['project'], api_token=config['neptuneApiToken'])
    try:
        if field == 'globalFeatureImportance' or field == 'testLabels':
            for i in range(11):
                path = f'./{field}/{run_id}_{i}.{extension}'
                run[f"{field}/{i}"].download(destination=path)
        else: run[field].download(destination=path)
    except:
            pass
    run.stop()
    

# Get the number of available CPUs
cpu_count = multiprocessing.cpu_count()

# Create a list of tuples containing the run ID and the corresponding function call
sample_probability_tasks = [(download_file, run['sys/id']) for _, run in runs_table_df.iterrows()]
sample_label_tasks = [(download_file, run['sys/id'], 'testLabels', 'csv') for _, run in runs_table_df.iterrows()]
# shap_tasks = [(download_file, run['sys/id'], 'shapExplanationsPerFold', 'pkl') for _, run in runs_table_df.iterrows()]
embedding_tasks = [(download_file, run['sys/id'], 'embedding', 'csv') for _, run in runs_table_df.iterrows()]
# global_importance_tasks = [(download_file, run['sys/id'], 'globalFeatureImportance', 'csv') for _, run in runs_table_df.iterrows()]

with multiprocessing.Pool(cpu_count) as pool:
    # Use the multiprocessing Pool to map the tasks to different processes
    pool.starmap(lambda func, *args: func(*args), sample_probability_tasks)
    pool.starmap(lambda func, *args: func(*args), sample_label_tasks)
    # pool.starmap(lambda func, *args: func(*args), shap_tasks)
    pool.starmap(lambda func, *args: func(*args), embedding_tasks)
    # pool.starmap(lambda func, *args: func(*args), global_importance_tasks)

env: NEPTUNE_ALLOW_SELF_SIGNED_CERTIFICATE='TRUE'



To avoid unintended consumption of logging hours during interactive sessions, the following monitoring options are disabled unless set to 'True' when initializing the run: 'capture_stdout', 'capture_stderr', and 'capture_hardware_metrics'.


To avoid unintended consumption of logging hours during interactive sessions, the following monitoring options are disabled unless set to 'True' when initializing the run: 'capture_stdout', 'capture_stderr', and 'capture_hardware_metrics'.


To avoid unintended consumption of logging hours during interactive sessions, the following monitoring options are disabled unless set to 'True' when initializing the run: 'capture_stdout', 'capture_stderr', and 'capture_hardware_metrics'.


To avoid unintended consumption of logging hours during interactive sessions, the following monitoring options are disabled unless set to 'True' when initializing the run: 'capture_stdout', 'capture_stderr', and 'capture_hardware_metrics'.


To avoid unintended consumption

https://new-ui.neptune.ai/ejmockler/ALS-NUPS-2000/e/NUPS2000-6907
Shutting down background jobs, please wait a moment...
Done!
All 0 operations synced, thanks for waiting!
Explore the metadata in the Neptune app:
https://new-ui.neptune.ai/ejmockler/ALS-NUPS-2000/e/NUPS2000-6907/metadata
https://new-ui.neptune.ai/ejmockler/ALS-NUPS-2000/e/NUPS2000-6906
https://new-ui.neptune.ai/ejmockler/ALS-NUPS-2000/e/NUPS2000-7643
https://new-ui.neptune.ai/ejmockler/ALS-NUPS-2000/e/NUPS2000-7920
https://new-ui.neptune.ai/ejmockler/ALS-NUPS-2000/e/NUPS2000-6999
https://new-ui.neptune.ai/ejmockler/ALS-NUPS-2000/e/NUPS2000-8012
https://new-ui.neptune.ai/ejmockler/ALS-NUPS-2000/e/NUPS2000-7091
https://new-ui.neptune.ai/ejmockler/ALS-NUPS-2000/e/NUPS2000-7828
https://new-ui.neptune.ai/ejmockler/ALS-NUPS-2000/e/NUPS2000-8104
https://new-ui.neptune.ai/ejmockler/ALS-NUPS-2000/e/NUPS2000-7367
https://new-ui.neptune.ai/ejmockler/ALS-NUPS-2000/e/NUPS2000-7275
https://new-ui.neptune.ai/ejmockler/ALS-NUPS-2000/e/

AssertionError: Cannot have cache with result_hander not alive

In [None]:
globalImportance_file_list

[]

In [5]:
from functools import partial

def process_csv(file_path, keepRunId=False):
    df = pd.read_csv(file_path)
    filename_stem = os.path.splitext(os.path.basename(file_path))[0]
    if keepRunId: df['run'] = filename_stem
    return df

def process_attribute(dataframe, run_id, field):
    run = neptune.init_run(with_id=run_id, project=config['entity'] + '/' + config['project'], api_token=config['neptuneApiToken'])
    try:
        dataframe[field] = [*run[field].fetch()] * len(dataframe)
    except:
        pass
    run.stop()
    return dataframe

result_path = 'sampleResults/'
embedding_path = 'embedding/'
result_file_list = []

def build_file_list(path):
    file_list = []
    for filename in os.listdir(path):
        if filename.endswith(".csv"):
            file_path = os.path.join(path, filename)
            file_list.append(file_path)
    return file_list

#globalImportance_path = 'globalFeatureImportance/'
#globalImportance_file_list = build_file_list(globalImportance_path)
    
result_file_list = build_file_list(result_path)
embedding_path_list = build_file_list(embedding_path)

with multiprocessing.Pool() as pool:
    processSampleResultCSV = partial(process_csv, keepRunId=True)
    dataframes = pool.map(processSampleResultCSV, result_file_list)
    embeddingDataframes = pool.map(process_csv, embedding_path_list)
    # globalFeatureImportanceDataframes = pool.map(process_csv, globalImportance_file_list)
    dataframes = pool.starmap(process_attribute, [(df, run_id, 'sys/tags') for df, run_id in zip(dataframes, runs_table_df['sys/id'])])

sampleResults = pd.concat(dataframes, ignore_index=False)
sampleResults = sampleResults.rename({'sys/tags': 'model'}, axis=1)

embedding = pd.concat(embeddingDataframes, ignore_index=False).drop_duplicates().set_index('id', drop=True)
embedding.index.name = 'id'

# globalImportances = pd.concat(globalFeatureImportanceDataframes, ignore_index=False)

NameError: name 'multiprocessing' is not defined

In [None]:
variants = globalFeatureImportanceDataframes[0]['Unnamed: 0'].to_list()

In [None]:
globalFeatureImportanceDataframes = [runDataframe.set_index('Unnamed: 0', drop=True) for runDataframe in globalFeatureImportanceDataframes if 'Unnamed: 0' in runDataframe.columns]
globalFeatureImportanceDataframes = [runDataframe.set_index(runDataframe.index.rename('variant')) for runDataframe in globalFeatureImportanceDataframes]
for i in range(len(globalFeatureImportanceDataframes)):
    globalFeatureImportanceDataframes[i].index = variants
    

In [None]:
globalImportanceDataframes = pd.concat(globalFeatureImportanceDataframes)
globalImportanceDataframes = globalImportanceDataframes


In [None]:
caseGlobalFeatureImportances = globalImportanceDataframes[['feature_importances_case']].reset_index()

In [145]:
averageCaseGlobalFeatureImportances = caseGlobalFeatureImportances.groupby(['index']).mean()
averageCaseGlobalFeatureImportances['feature_importances_case'] = averageCaseGlobalFeatureImportances['feature_importances_case'].abs()
averageCaseGlobalFeatureImportances = averageCaseGlobalFeatureImportances.sort_values('feature_importances_case', ascending=False)

In [162]:
averageCaseGlobalFeatureImportances.sort_values('feature_importances_case', ascending=False).to_csv('averageModelImportanceCoefficients.csv')

In [None]:
sampleResults

In [11]:
sampleResults = sampleResults.dropna()

In [None]:
embedding

In [1]:
import plotly.express as px

# Group by sample ID and calculate mean probability for each sample
mean_probs = sampleResults.groupby(['id', 'model'])['probability'].mean().reset_index()

# Create histogram
fig = px.histogram(mean_probs, x='probability', color='model', title="Mean Sample ALS Probability")
fig.show()

NameError: name 'sampleResults' is not defined

In [None]:
import pickle
import multiprocess as mp

def load_pickled(args):
    field, runID = args
    return pickle.load(open(f'{field}/{runID}.pkl', 'rb'))

def load_fold_dataframe(args):
    field, runID = args
    try:
        return pd.concat([pd.read_csv(f'{field}/{runID}_{i}.csv') for i in range(1,11)])
    except:
        pass

labels = list()
shapExplanations = list()

# Prepare arguments for the load_pickled function
label_args_list = [('testLabels', runID) for runID in sampleResults['run'].unique()]
shap_args_list = [('shapExplanationsPerFold', runID) for runID in sampleResults['run'].unique()]

with mp.Pool(mp.cpu_count() // 2) as pool:
    loadedLabels = pd.concat(pool.map(load_fold_dataframe, label_args_list)).drop_duplicates().set_index('id', drop=True)

# Update sampleResults DataFrame
sampleResults = sampleResults.join(loadedLabels, on='id')
# sampleResults['shapExplanations'] = shapExplanations

In [45]:
runs_table_df[runs_table_df['bootstrapIteration'] == 1]

Unnamed: 0,sys/creation_time,sys/description,sys/failed,sys/hostname,sys/id,sys/modification_time,sys/monitoring_time,sys/name,sys/owner,sys/ping_time,...,monitoring/fff42861/stdout,monitoring/fff42861/tid,monitoring/fffa1e9e/cpu,monitoring/fffa1e9e/gpu,monitoring/fffa1e9e/gpu_memory,monitoring/fffa1e9e/hostname,monitoring/fffa1e9e/memory,monitoring/fffa1e9e/pid,monitoring/fffa1e9e/stdout,monitoring/fffa1e9e/tid
3135,2023-06-06 10:40:01.325000+00:00,,False,noot,NUPS2000-1399,2023-06-06 10:40:41.141000+00:00,40,Untitled,ejmockler,2023-06-06 10:40:41.141000+00:00,...,,,,,,,,,,
3136,2023-06-06 10:39:58.535000+00:00,,False,noot,NUPS2000-1398,2023-06-06 10:40:48.660000+00:00,50,Untitled,ejmockler,2023-06-06 10:40:48.660000+00:00,...,,,,,,,,,,
3137,2023-06-06 10:39:55.366000+00:00,,False,noot,NUPS2000-1397,2023-06-06 10:40:41.240000+00:00,46,Untitled,ejmockler,2023-06-06 10:40:41.240000+00:00,...,,,,,,,,,,
3138,2023-06-06 10:39:53.841000+00:00,,False,noot,NUPS2000-1396,2023-06-06 10:40:35.028000+00:00,41,Untitled,ejmockler,2023-06-06 10:40:35.028000+00:00,...,,,,,,,,,,
3139,2023-06-06 10:39:52.543000+00:00,,False,noot,NUPS2000-1395,2023-06-06 10:40:35.804000+00:00,43,Untitled,ejmockler,2023-06-06 10:40:35.804000+00:00,...,,,,,,,,,,
3140,2023-06-06 10:39:47.881000+00:00,,False,noot,NUPS2000-1394,2023-06-06 10:40:36.126000+00:00,48,Untitled,ejmockler,2023-06-06 10:40:36.126000+00:00,...,,,,,,,,,,
3141,2023-06-06 10:39:42.931000+00:00,,False,noot,NUPS2000-1393,2023-06-06 10:40:35.481000+00:00,52,Untitled,ejmockler,2023-06-06 10:40:35.481000+00:00,...,,,,,,,,,,


## Determine sample accuracy

In [None]:
import numpy as np
import pandas as pd

# calculate accuracy for each sample, model pair
resolvedSampleResults = []
for id, group in sampleResults.groupby(['id', 'model'], group_keys=False):
    assert len(group['testLabel'].unique()) == 1  # All labels should be the same
    sampleLabel = group['testLabel'].unique()[0]
    sampleClassifications = np.around(group['probability']) # Scale probablility to label values for direct classification
    group['accuracy'] = np.mean(sampleClassifications == sampleLabel)
    resolvedSampleResults.append(group) 
sampleResults = pd.concat(resolvedSampleResults)
sampleResults_cases = sampleResults[sampleResults['testLabel'] == 1]
sampleResults_controls = sampleResults[sampleResults['testLabel'] == 0]


In [None]:
pd.set_option('display.max_rows', 25)
pd.set_option('display.min_rows', 25)

In [None]:
sampleResults[[column for column in sampleResults.columns if column not in ['shapExplanations']]]['accuracy'].describe()

count    4.223016e+06
mean     5.081953e-01
std      2.605973e-01
min      0.000000e+00
25%      2.840909e-01
50%      5.174825e-01
75%      7.328767e-01
max      9.872611e-01
Name: accuracy, dtype: float64

In [None]:
import plotly.express as px

# Group by sample ID and label, calculate mean accuracy for each sample
mean_accuracy = sampleResults.loc[sampleResults['model']=='LinearSVC'].groupby(['id', 'testLabel'])['accuracy'].mean().reset_index()

# Create histogram of mean accuracy
fig = px.histogram(mean_accuracy, x='accuracy', color='testLabel', pattern_shape='testLabel', title="Mean Sample Accuracy")
fig.show()

## Correlate sample accuracy across models

- Use Spearman rank-order method since model accuracies should be monotonic 

In [29]:
accuracy_grouped_models_df = sampleResults.groupby(['id', 'model', 'testLabel'])['accuracy'].mean().reset_index()

In [30]:
accuracy_grouped_models_df

Unnamed: 0,id,model,testLabel,accuracy
0,ALS__CGND-HDA-00001__UP-WGS-185,AdaBoostClassifier,1,0.486486
1,ALS__CGND-HDA-00001__UP-WGS-185,LinearSVC,1,0.533784
2,ALS__CGND-HDA-00001__UP-WGS-185,LogisticRegression,1,0.533333
3,ALS__CGND-HDA-00001__UP-WGS-185,MultinomialNB,1,0.542484
4,ALS__CGND-HDA-00001__UP-WGS-185,RadialBasisSVC,1,0.591549
...,...,...,...,...
19147,aals-CTR__CGND-HDA-03896__NEUNG931UCJ,LogisticRegression,0,0.755605
19148,aals-CTR__CGND-HDA-03896__NEUNG931UCJ,MultinomialNB,0,0.758542
19149,aals-CTR__CGND-HDA-03896__NEUNG931UCJ,RadialBasisSVC,0,0.773836
19150,aals-CTR__CGND-HDA-03896__NEUNG931UCJ,RandomForestClassifier,0,0.723982


In [32]:
model_case_accuracy_pivot_df = accuracy_grouped_models_df.loc[accuracy_grouped_models_df['testLabel']==1].pivot(index='id', columns='model', values='accuracy')
model_control_accuracy_pivot_df = accuracy_grouped_models_df.loc[accuracy_grouped_models_df['testLabel']==0].pivot(index='id', columns='model', values='accuracy')

In [33]:
model_control_accuracy_pivot_df

model,AdaBoostClassifier,LinearSVC,LogisticRegression,MultinomialNB,RadialBasisSVC,RandomForestClassifier,XGBClassifier
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
CTR__CGND-HDA-00196__NEUHC282LVJ,0.918182,0.923963,0.919283,0.917995,0.904656,0.914027,0.917241
CTR__CGND-HDA-00209__NEUZW491LJA,0.127273,0.117512,0.150224,0.118451,0.108647,0.126697,0.110345
CTR__CGND-HDA-00260__NEUCX966RX5,0.881818,0.910138,0.896861,0.899772,0.900222,0.884615,0.862069
CTR__CGND-HDA-00434__NEUXM830AFG,0.438636,0.495392,0.461883,0.464692,0.461197,0.513575,0.473563
CTR__CGND-HDA-00435__NEUKZ685AR4,0.272727,0.278802,0.304933,0.277904,0.259424,0.294118,0.282759
...,...,...,...,...,...,...,...
aals-CTR__CGND-HDA-03811__NEUFU823AF8,0.379545,0.361751,0.325112,0.309795,0.365854,0.289593,0.324138
aals-CTR__CGND-HDA-03813__NEUEK795BLX,0.763636,0.776498,0.751121,0.763098,0.760532,0.769231,0.770115
aals-CTR__CGND-HDA-03874__NEUDD665KML,0.129545,0.168203,0.179372,0.168565,0.137472,0.171946,0.163218
aals-CTR__CGND-HDA-03876__NEUXY894VJR,0.704545,0.672811,0.739910,0.710706,0.676275,0.692308,0.714943


In [34]:
from scipy.stats import spearmanr

model_accuracy_pivot_df = pd.concat([model_case_accuracy_pivot_df, model_control_accuracy_pivot_df], axis=0)
sample_accuracy_pivot_df = model_accuracy_pivot_df.transpose()
modelAccuracyCorrelation, _ = spearmanr(model_accuracy_pivot_df)
sampleAccuracyCorrelation, _ = spearmanr(sample_accuracy_pivot_df)
modelCorrelation_df = pd.DataFrame(modelAccuracyCorrelation, index=model_accuracy_pivot_df.columns, columns=model_accuracy_pivot_df.columns)
sampleCorrelation_df = pd.DataFrame(sampleAccuracyCorrelation, index=sample_accuracy_pivot_df.columns, columns=sample_accuracy_pivot_df.columns)

modelCaseAccuracyCorrelation, _ = spearmanr(model_case_accuracy_pivot_df)
modelCaseCorrelation_df =  pd.DataFrame(modelCaseAccuracyCorrelation, index=model_case_accuracy_pivot_df.columns, columns=model_case_accuracy_pivot_df.columns)
modelControlAccuracyCorrelation, _ = spearmanr(model_control_accuracy_pivot_df)
modelControlCorrelation_df =  pd.DataFrame(modelControlAccuracyCorrelation, index=model_control_accuracy_pivot_df.columns, columns=model_control_accuracy_pivot_df.columns)

In [35]:
import plotly.figure_factory as ff

fig = ff.create_annotated_heatmap(z=modelCorrelation_df.values,
                                  x=list(modelCorrelation_df.columns),
                                  y=list(modelCorrelation_df.columns),
                                  annotation_text=modelCorrelation_df.round(2).values,
                                  colorscale='Viridis')
fig.update_layout(title_text='Spearman Correlation of Per-Sample Accuracy Across Models',
                  xaxis = dict(title='Model'),
                  yaxis = dict(title='Model'),
                  margin={'t':175},)
fig.show()

fig = ff.create_annotated_heatmap(z=modelCaseCorrelation_df.values,
                                  x=list(modelCaseCorrelation_df.columns),
                                  y=list(modelCaseCorrelation_df.columns),
                                  annotation_text=modelCorrelation_df.round(2).values,
                                  colorscale='Viridis')
fig.update_layout(title_text='Spearman Correlation of Per-Case Accuracy Across Models',
                  xaxis = dict(title='Model'),
                  yaxis = dict(title='Model'),
                  margin={'t':175},)
fig.show()

fig = ff.create_annotated_heatmap(z=modelControlCorrelation_df.values,
                                  x=list(modelControlCorrelation_df.columns),
                                  y=list(modelControlCorrelation_df.columns),
                                  annotation_text=modelControlCorrelation_df.round(2).values,
                                  colorscale='Viridis')
fig.update_layout(title_text='Spearman Correlation of Per-Control Accuracy Across Models',
                  xaxis = dict(title='Model'),
                  yaxis = dict(title='Model'),
                  margin={'t':175},)
fig.show()


## Plot heatmap of variants x cases and variants x controls

- Sort samples by accuracy 
- Color by feature value
    - Cluster case variants, show dendogram
    - Order control variants by case clustering too

## Select outlier samples in accuracy distribution

In [45]:
accuracyThreshold = (0.85, 0.15)
accurateSamples = sampleResults[sampleResults['accuracy'] >= accuracyThreshold[0]].groupby(['id', 'label'])['accuracy'].mean().reset_index()
discordantSamples = sampleResults[sampleResults['accuracy'] <= accuracyThreshold[1]].groupby(['id', 'label'])['accuracy'].mean().reset_index()

In [46]:
print(f"total samples: {sampleResults['id'].unique().shape[0]}")
print(f"cases with classification accuracy above {accuracyThreshold[0]:.0%}: {accurateSamples.loc[accurateSamples['label'] == 1,'id'].unique().shape[0]}")
print(f"controls with classification accuracy above {accuracyThreshold[0]:.0%}: {accurateSamples.loc[accurateSamples['label'] == 0,'id'].unique().shape[0]}")

print(f"cases with classification accuracy above {accuracyThreshold[1]:.0%}: {discordantSamples.loc[discordantSamples['label'] == 1,'id'].unique().shape[0]}")
print(f"controls with classification accuracy above {accuracyThreshold[1]:.0%}: {discordantSamples.loc[discordantSamples['label'] == 0,'id'].unique().shape[0]}")

total samples: 2736
cases with classification accuracy above 85%: 695
controls with classification accuracy above 85%: 91
cases with classification accuracy above 15%: 410
controls with classification accuracy above 15%: 146


In [47]:
accurateSamples[[column for column in accurateSamples.columns if column not in ['shapExplanations']]]

Unnamed: 0,id,label,accuracy
0,ALS__CGND-HDA-00004__UP-WGS-187,1,0.906250
1,ALS__CGND-HDA-00012__UP-WGS-195,1,0.906250
2,ALS__CGND-HDA-00013__UP-WGS-196,1,0.866667
3,ALS__CGND-HDA-00028__UP-WGS-211,1,0.915254
4,ALS__CGND-HDA-00051__UP-WGS-234,1,0.888889
5,ALS__CGND-HDA-00057__UP-WGS-241,1,0.920635
6,ALS__CGND-HDA-00064__UP-WGS-248,1,0.956522
7,ALS__CGND-HDA-00076__UP-WGS-260,1,0.882353
8,ALS__CGND-HDA-00101__UP-WGS-285,1,0.895349
9,ALS__CGND-HDA-00102__UP-WGS-286,1,0.868852


## View variants by sample accuracy

In [54]:
accurateCases

Unnamed: 0_level_0,label,accuracy
id,Unnamed: 1_level_1,Unnamed: 2_level_1
ALS__CGND-HDA-01215__NEUUA360BR1,1,1.00
ALS__CGND-HDA-03062__UP-WGS-535,1,1.00
ALS__CGND-HDA-02741__PF-UCL-28,1,1.00
ALS__CGND-HDA-00353__358ALS,1,1.00
ALS__CGND-HDA-01013__NEUFL908GEL,1,1.00
ALS__CGND-HDA-00872__MH-WASHU-250,1,1.00
ALS__CGND-HDA-01813__TD-ALS-136,1,1.00
aals-ALS__CGND-HDA-04067__NEUJA207UUV,1,1.00
ALS__CGND-HDA-00644__MH-WASHU-22,1,1.00
ALS__CGND-HDA-03505__NEUBA645MFF,1,1.00


In [55]:
accurateCaseEmbeddings

Unnamed: 0_level_0,"('1', '186347356', 'TPR')","('1', '225419442', 'LBR')","('1', '229487591', 'NUP133')","('1', '229495987', 'NUP133')","('1', '246842749', 'AHCTF1')","('1', '246861024', 'AHCTF1')","('1', '246877229', 'AHCTF1')","('1', '246885532', 'AHCTF1')","('2', '183131001', 'NUP35')","('2', '183131014', 'NUP35')",...,"('14', '24210671', 'CHMP4A')","('16', '56839701', 'NUP93')","('16', '71922608', 'IST1')","('16', '71924122', 'IST1')","('16', '71924149', 'IST1')","('17', '47671806', 'KPNB1')","('18', '12984145', 'SEH1L')","('19', '7961534', 'ELAVL1')","('19', '49908960', 'NUP62')","('19', '58551790', 'CHMP2A')"
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
ALS__CGND-HDA-01215__NEUUA360BR1,0.0,1.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,...,1.0,0.0,0.0,0.0,0.0,1.0,1.0,1.0,0.0,0.0
ALS__CGND-HDA-03062__UP-WGS-535,0.0,1.0,1.0,1.0,0.0,0.0,0.0,1.0,0.0,1.0,...,0.0,1.0,1.0,0.0,1.0,1.0,1.0,1.0,0.0,0.0
ALS__CGND-HDA-02741__PF-UCL-28,0.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,1.0,...,0.0,1.0,1.0,1.0,0.0,1.0,1.0,0.0,1.0,0.0
ALS__CGND-HDA-00353__358ALS,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,1.0,...,0.0,0.0,0.0,0.0,1.0,1.0,1.0,0.0,0.0,0.0
ALS__CGND-HDA-01013__NEUFL908GEL,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,1.0,...,0.0,0.0,1.0,0.0,1.0,1.0,1.0,1.0,1.0,0.0
ALS__CGND-HDA-00872__MH-WASHU-250,0.0,1.0,1.0,1.0,0.0,0.0,0.0,1.0,0.0,1.0,...,0.0,0.0,0.0,1.0,1.0,1.0,1.0,0.0,1.0,0.0
ALS__CGND-HDA-01813__TD-ALS-136,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,...,0.0,0.0,0.0,1.0,0.0,1.0,1.0,0.0,0.0,0.0
aals-ALS__CGND-HDA-04067__NEUJA207UUV,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,1.0,...,1.0,0.0,0.0,0.0,1.0,1.0,1.0,0.0,0.0,0.0
ALS__CGND-HDA-00644__MH-WASHU-22,1.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,1.0,1.0,...,0.0,1.0,0.0,0.0,0.0,1.0,1.0,0.0,0.0,0.0
ALS__CGND-HDA-03505__NEUBA645MFF,1.0,1.0,1.0,1.0,0.0,0.0,0.0,1.0,1.0,1.0,...,1.0,0.0,0.0,1.0,0.0,1.0,1.0,1.0,1.0,1.0


In [148]:
topVariants

["('6', '17675015', 'NUP153')",
 "('7', '849532', 'SUN1')",
 "('7', '135584907', 'NUP205')",
 "('2', '183131014', 'NUP35')",
 "('3', '13319787', 'NUP210')",
 "('3', '13353975', 'NUP210')",
 "('18', '12984145', 'SEH1L')",
 "('7', '135607350', 'NUP205')",
 "('1', '246842749', 'AHCTF1')",
 "('7', '842031', 'SUN1')"]

In [155]:
outlierCaseEmbeddings.index

Index(['ALS__CGND-HDA-01215__NEUUA360BR1', 'ALS__CGND-HDA-00644__MH-WASHU-22',
       'ALS__CGND-HDA-03062__UP-WGS-535', 'ALS__CGND-HDA-02651__UP-WGS-406',
       'aals-ALS__CGND-HDA-03814__NEUCH829YF0',
       'ALS__CGND-HDA-03630__NSTNNXTLA8ZQ',
       'aals-ALS__CGND-HDA-03913__NEUJA497KXF',
       'ALS__CGND-HDA-03505__NEUBA645MFF',
       'aals-ALS__CGND-HDA-04067__NEUJA207UUV',
       'ALS__CGND-HDA-01813__TD-ALS-136',
       ...
       'ALS__CGND-HDA-02288__13-190-33', 'ALS__CGND-HDA-02561__93-094-34',
       'aals-ALS__CGND-HDA-02700__NEUPP607CPW',
       'aals-ALS__CGND-HDA-00234__NEUZN836GME',
       'ALS__CGND-HDA-02422__03-151-16',
       'aals-ALS__CGND-HDA-03560__NEUET719NJD',
       'ALS__CGND-HDA-03442__NEUYU889EY1', 'ALS__CGND-HDA-01795__TD-ALS-87',
       'ALS__CGND-HDA-03696__NEUBZ354DBH',
       'aals-ALS__CGND-HDA-03605__NEUPJ681DUM'],
      dtype='object', name='id', length=1105)

In [159]:
import dash_bio

accurateCases = accurateSamples.loc[accurateSamples['label'] == 1].sort_values(by=['accuracy'], ascending=False).set_index('id', drop=True)
accurateCaseEmbeddings = embedding.loc[embedding.index.isin(accurateCases.index)].reindex(accurateCases.index)
accurateControls = accurateSamples.loc[accurateSamples['label'] == 0].sort_values(by=['accuracy'], ascending=False).set_index('id', drop=True)
accurateControlEmbeddings = embedding.loc[embedding.index.isin(accurateControls.index)].reindex(accurateControls.index)

discordantCases = discordantSamples.loc[discordantSamples['label'] == 1].sort_values(by=['accuracy'], ascending=False).set_index('id', drop=True)
discordantCaseEmbeddings = embedding.loc[embedding.index.isin(discordantCases.index)].reindex(discordantCases.index)
discordantControls = discordantSamples.loc[discordantSamples['label'] == 0].sort_values(by=['accuracy'], ascending=False).set_index('id', drop=True)
discordantControlEmbeddings = embedding.loc[embedding.index.isin(discordantControls.index)].reindex(discordantControls.index)

outlierCases = pd.concat([accurateCases, discordantCases]).sort_values(by=['accuracy'], ascending=False)
outlierCaseEmbeddings = embedding.loc[embedding.index.isin(outlierCases.index)].reindex(outlierCases.index)

outlierControls = pd.concat([accurateControls, discordantControls]).sort_values(by=['accuracy'], ascending=False)
outlierControlEmbeddings = embedding.loc[embedding.index.isin(outlierControls.index)].reindex(outlierControls.index)

outlierSamples = pd.concat([accurateSamples, discordantSamples]).sort_values(by=['accuracy'], ascending=False)
outlierEmbeddings = embedding.loc[embedding.index.isin(outlierSamples.index)].reindex(outlierSamples.index)

topVariants = averageCaseGlobalFeatureImportances.iloc[:10].index
outlierCaseEmbeddings = outlierCaseEmbeddings[topVariants]
outlierControlEmbeddings = outlierControlEmbeddings[topVariants]


plot = dash_bio.Clustergram(
    data=outlierCaseEmbeddings.T.values,
    row_labels=list(outlierCaseEmbeddings.columns.values),
    column_labels=list(outlierCaseEmbeddings.index),
    hidden_labels='column',
    cluster='row',
    height=1200,
    width=1200,
    color_map= [
        [0.0, '#636EFA'],
        [0.25, '#AB63FA'],
        [0.5, '#FFFFFF'],
        [0.75, '#E763FA'],
        [1.0, '#EF553B']
    ])
plot.update_layout(title={'text': 'Outlier Case Variants (Accuracy >= 85% or <= 15%, highest on left)', 'x': 0.5, 'xanchor': 'center'})
plot.write_html('outlier_cases_clustergram.html')

# TODO filter important variants 
plot = dash_bio.Clustergram(
    data=outlierControlEmbeddings.T.values,
    row_labels=list(outlierControlEmbeddings.columns.values),
    column_labels=list(outlierControlEmbeddings.index),
    hidden_labels='column',
    cluster='row',
    height=1200,
    width=1200,
    color_map= [
        [0.0, '#636EFA'],
        [0.25, '#AB63FA'],
        [0.5, '#FFFFFF'],
        [0.75, '#E763FA'],
        [1.0, '#EF553B']
    ])
plot.update_layout(title={'text': 'Outlier Control Variants (Accuracy >= 85% or <= 15%, highest on left)', 'x': 0.5, 'xanchor': 'center'})
plot.write_html('outlier_controls_clustergram.html')

plot = dash_bio.Clustergram(
    data=embedding.T.values,
    row_labels=list(embedding.columns.values),
    column_labels=list(embedding.index),
    hidden_labels='column',
    cluster='row',
    height=1200,
    width=1200,
    color_map= [
        [0.0, '#636EFA'],
        [0.25, '#AB63FA'],
        [0.5, '#FFFFFF'],
        [0.75, '#E763FA'],
        [1.0, '#EF553B']
    ])
plot.update_layout(title={'text': 'Outlier Case & Control Variants (Accuracy >= 85% or <= 15%, highest on left)', 'x': 0.5, 'xanchor': 'center'})
plot.write_html('outliers_clustergram.html')

In [57]:

plot = dash_bio.Clustergram(
    data=discordantCaseEmbeddings.T.values,
    row_labels=list(discordantCaseEmbeddings.columns.values),
    column_labels=list(discordantCaseEmbeddings.index),
    hidden_labels='column',
    cluster='row',
    height=1200,
    width=1200,
    color_map= [
        [0.0, '#636EFA'],
        [0.25, '#AB63FA'],
        [0.5, '#FFFFFF'],
        [0.75, '#E763FA'],
        [1.0, '#EF553B']
    ])
plot.update_layout(title={'text': 'Discordant Case Variants (Accuracy <= 15%, highest on left)', 'x': 0.5, 'xanchor': 'center'})
plot.write_html('discordant_cases_clustergram.html')

plot = dash_bio.Clustergram(
    data=discordantControlEmbeddings.T.values,
    row_labels=list(discordantControlEmbeddings.columns.values),
    column_labels=list(discordantControlEmbeddings.index),
    hidden_labels='column',
    cluster='row',
    height=1200,
    width=1200,
    color_map= [
        [0.0, '#636EFA'],
        [0.25, '#AB63FA'],
        [0.5, '#FFFFFF'],
        [0.75, '#E763FA'],
        [1.0, '#EF553B']
    ])
plot.update_layout(title={'text': 'Discordant Control Variants (Accuracy <= 15%, highest on left)', 'x': 0.5, 'xanchor': 'center'})
plot.write_html('discordant_controls_clustergram.html')


In [51]:
accurateCaseEmbeddings


Unnamed: 0_level_0,"('1', '186347356', 'TPR')","('1', '225419442', 'LBR')","('1', '229487591', 'NUP133')","('1', '229495987', 'NUP133')","('1', '246842749', 'AHCTF1')","('1', '246861024', 'AHCTF1')","('1', '246877229', 'AHCTF1')","('1', '246885532', 'AHCTF1')","('2', '183131001', 'NUP35')","('2', '183131014', 'NUP35')",...,"('14', '24210671', 'CHMP4A')","('16', '56839701', 'NUP93')","('16', '71922608', 'IST1')","('16', '71924122', 'IST1')","('16', '71924149', 'IST1')","('17', '47671806', 'KPNB1')","('18', '12984145', 'SEH1L')","('19', '7961534', 'ELAVL1')","('19', '49908960', 'NUP62')","('19', '58551790', 'CHMP2A')"
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
ALS__CGND-HDA-01215__NEUUA360BR1,0.0,1.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,...,1.0,0.0,0.0,0.0,0.0,1.0,1.0,1.0,0.0,0.0
ALS__CGND-HDA-03062__UP-WGS-535,0.0,1.0,1.0,1.0,0.0,0.0,0.0,1.0,0.0,1.0,...,0.0,1.0,1.0,0.0,1.0,1.0,1.0,1.0,0.0,0.0
ALS__CGND-HDA-02741__PF-UCL-28,0.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,1.0,...,0.0,1.0,1.0,1.0,0.0,1.0,1.0,0.0,1.0,0.0
ALS__CGND-HDA-00353__358ALS,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,1.0,...,0.0,0.0,0.0,0.0,1.0,1.0,1.0,0.0,0.0,0.0
ALS__CGND-HDA-01013__NEUFL908GEL,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,1.0,...,0.0,0.0,1.0,0.0,1.0,1.0,1.0,1.0,1.0,0.0
ALS__CGND-HDA-00872__MH-WASHU-250,0.0,1.0,1.0,1.0,0.0,0.0,0.0,1.0,0.0,1.0,...,0.0,0.0,0.0,1.0,1.0,1.0,1.0,0.0,1.0,0.0
ALS__CGND-HDA-01813__TD-ALS-136,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,...,0.0,0.0,0.0,1.0,0.0,1.0,1.0,0.0,0.0,0.0
aals-ALS__CGND-HDA-04067__NEUJA207UUV,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,1.0,...,1.0,0.0,0.0,0.0,1.0,1.0,1.0,0.0,0.0,0.0
ALS__CGND-HDA-00644__MH-WASHU-22,1.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,1.0,1.0,...,0.0,1.0,0.0,0.0,0.0,1.0,1.0,0.0,0.0,0.0
ALS__CGND-HDA-03505__NEUBA645MFF,1.0,1.0,1.0,1.0,0.0,0.0,0.0,1.0,1.0,1.0,...,1.0,0.0,0.0,1.0,0.0,1.0,1.0,1.0,1.0,1.0


## Shapely value distribution for well-classified samples

In [24]:
def serialize_sample_shap(sampleValues):
    # some models only have probabilities for a single class
    return sampleValues.values[:,1] if len(sampleValues.values.shape) > 1 else sampleValues.values

serializedData = []
for sampleValues in accurateSamples['shapExplanations']:
    serializedData.append(serialize_sample_shap(sampleValues))
    
shapelyValueDataframe = pd.DataFrame(serializedData, columns=accurateSamples.iloc[0]['shapExplanations'].feature_names)
shapelyValueDataframe.index = accurateSamples['id']

In [25]:
accurateSamples['shapExplanations'].iloc[0].values[0]

array([ 0.00208532, -0.00208532])

In [26]:
shapelyValueDataframe

Unnamed: 0_level_0,1_186347356_TPR,1_225419442_LBR,1_229487591_NUP133,1_229495987_NUP133,1_246842749_AHCTF1,1_246861024_AHCTF1,1_246877229_AHCTF1,1_246885532_AHCTF1,2_183131001_NUP35,2_183131014_NUP35,...,14_24210671_CHMP4A,16_56839701_NUP93,16_71922608_IST1,16_71924122_IST1,16_71924149_IST1,17_47671806_KPNB1,18_12984145_SEH1L,19_7961534_ELAVL1,19_49908960_NUP62,19_58551790_CHMP2A
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
ALS__CGND-HDA-00028__UP-WGS-211,-0.002085,1.105033e-05,-0.000271,-0.000595,0.001739,2.072271e-04,-1.647792e-04,1.491997e-03,0.000524,6.176167e-05,...,3.368424e-03,-3.966042e-03,0.001352,7.185846e-04,0.002947,0.000800,-0.004286,-6.833568e-05,-0.002589,3.036540e-05
ALS__CGND-HDA-00028__UP-WGS-211,0.000000,0.000000e+00,0.000000,0.000000,0.000000,0.000000e+00,0.000000e+00,0.000000e+00,0.000000,0.000000e+00,...,0.000000e+00,0.000000e+00,0.000000,0.000000e+00,0.000000,0.000000,-0.005933,0.000000e+00,0.000000,0.000000e+00
ALS__CGND-HDA-00028__UP-WGS-211,-0.031587,8.240942e-05,-0.000965,-0.001356,0.003614,-5.852192e-04,-2.933458e-04,-1.994608e-03,0.008768,8.917580e-05,...,1.315728e-02,-3.282740e-02,0.015376,1.463580e-03,0.003535,0.001407,-0.001442,-9.142397e-04,-0.009797,2.347364e-03
ALS__CGND-HDA-00028__UP-WGS-211,-0.053750,0.000000e+00,-0.003750,-0.010000,0.040000,-3.750000e-03,0.000000e+00,0.000000e+00,0.026250,0.000000e+00,...,2.000000e-02,-8.000000e-02,0.053750,1.125000e-02,0.060000,0.027500,-0.023750,0.000000e+00,-0.023750,0.000000e+00
ALS__CGND-HDA-00028__UP-WGS-211,-0.003167,1.814708e-05,-0.000262,-0.000506,0.001549,-2.973182e-05,9.041043e-06,1.365186e-05,0.000809,2.332971e-05,...,-3.626024e-04,-4.553258e-03,0.000817,6.931123e-04,0.001802,0.001349,-0.004090,1.926242e-04,-0.001914,6.618462e-05
ALS__CGND-HDA-00028__UP-WGS-211,-0.001699,0.000000e+00,-0.000138,-0.000233,0.000588,6.899551e-06,-3.813430e-05,2.734416e-04,0.000481,2.511202e-05,...,2.728980e-04,-2.866491e-03,0.000360,3.727614e-04,0.002225,0.001197,-0.002400,3.084472e-05,-0.001197,2.007815e-05
ALS__CGND-HDA-00028__UP-WGS-211,0.134497,0.000000e+00,-0.036077,0.013720,-0.031532,0.000000e+00,1.566214e-03,-1.446319e-02,-0.015263,8.245045e-03,...,7.782458e-02,-1.930796e-02,0.014084,6.551375e-03,-0.025357,-0.037268,-0.039882,-1.977051e-04,0.003734,7.545884e-03
ALS__CGND-HDA-00028__UP-WGS-211,-0.020860,2.516267e-05,0.002815,0.002884,0.001877,1.747048e-03,-4.068545e-03,-6.103635e-04,0.008227,2.005737e-04,...,1.318474e-02,-3.783663e-02,0.010932,4.318332e-03,0.005522,0.002097,-0.007436,-1.807261e-03,-0.003859,2.939592e-03
ALS__CGND-HDA-00028__UP-WGS-211,-0.011680,7.367027e-05,0.004437,0.003698,0.009642,1.617706e-03,-1.005164e-03,4.100567e-04,0.008354,4.934337e-05,...,3.612122e-04,-4.435089e-02,0.004922,-3.163747e-04,0.009375,0.003198,-0.007594,-5.277894e-06,-0.004437,6.607813e-04
ALS__CGND-HDA-00028__UP-WGS-211,-0.025460,-6.109835e-06,-0.000622,-0.001382,0.006130,-9.644239e-04,-6.352159e-04,-2.895871e-04,0.006104,-1.403618e-05,...,1.725073e-03,-4.423027e-02,0.007977,3.642917e-03,0.007469,0.002016,-0.004029,8.430942e-04,-0.005967,-6.267683e-05


In [28]:
import numpy as np
import plotly.express as px

df_stats = shapelyValueDataframe.describe().T

# Create a DataFrame for the plotting
df_plot_mean = df_stats[['mean']].reset_index()
df_plot_std = df_stats[['std']].reset_index()

# Sort by the greatest mean Shapley value and least standard deviation
df_plot_mean = df_plot_mean.sort_values(['mean'], ascending=False)
df_plot_std = df_plot_std.sort_values(['std'], ascending=True)

# Create a bar chart for means
fig_mean = px.bar(df_plot_mean, x='index', y='mean', labels={'index':'Feature', 'mean':'Mean Shapely Value'})
fig_mean.show()

# Create a bar chart for standard deviations
fig_std = px.bar(df_plot_std, x='index', y='std', labels={'index':'Feature', 'std':'Standard Deviation'})
fig_std.show()


In [38]:
df_stats[['mean', 'std']].corr(method='pearson')

Unnamed: 0,mean,std
mean,1.0,0.714121
std,0.714121,1.0


In [54]:
accurateCases = [id for id in accurateSamples['id'].unique() if "CTR" not in id]
pd.Series(accurateCases, name='id').to_csv('accurateCases.csv', index=False)

In [55]:
discordantSampleIDs = discordantSamples['id'].unique()
discordantCases = [id for id in discordantSampleIDs if "CTR" not in id]
pd.Series(discordantCases, name='id').to_csv('discordantCases.csv', index=False)

In [50]:
discordantSampleIDs

array(['ALS__CGND-HDA-00277__2140ALS', 'ALS__CGND-HDA-00317__362ALS',
       'ALS__CGND-HDA-00323__820ALS', 'ALS__CGND-HDA-00332__2271ALS',
       'ALS__CGND-HDA-00360__125ALS', 'ALS__CGND-HDA-00782__MH-WASHU-160',
       'ALS__CGND-HDA-00845__MH-WASHU-223',
       'ALS__CGND-HDA-01098__276-11-5',
       'ALS__CGND-HDA-01224__NEUUM419GYB', 'ALS__CGND-HDA-01294__EC11',
       'ALS__CGND-HDA-01489__NEUGY188ZTM',
       'ALS__CGND-HDA-01539__UP-WGS-017',
       'ALS__CGND-HDA-01591__UP-WGS-069',
       'ALS__CGND-HDA-01795__TD-ALS-87',
       'ALS__CGND-HDA-01824__TD-ALS-129',
       'ALS__CGND-HDA-01897__TD-ALS-143',
       'ALS__CGND-HDA-02089__NEUWZ812JXY',
       'ALS__CGND-HDA-02254__NEUMG708VB0',
       'ALS__CGND-HDA-02288__13-190-33', 'ALS__CGND-HDA-02313__05-156-09',
       'ALS__CGND-HDA-02329__96-119-66', 'ALS__CGND-HDA-02340__94-106-48',
       'ALS__CGND-HDA-02365__91-072-76', 'ALS__CGND-HDA-02392__87-017-42',
       'ALS__CGND-HDA-02413__98-130-70', 'ALS__CGND-HDA-02416__97-