In [None]:
import math
import time
from collections import defaultdict, Counter

import numpy as np
import pandas as pd
from scipy import stats
from sklearn.linear_model import LogisticRegression, LinearRegression, ElasticNet, Ridge
from sklearn.metrics import roc_curve, auc, precision_recall_curve, average_precision_score
from sklearn.metrics import brier_score_loss, precision_score, recall_score, f1_score, roc_auc_score
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

from ml4h.explorations import latent_space_dataframe

# IPython imports
%matplotlib inline
import matplotlib.pyplot as plt
from matplotlib import colors

In [None]:
# Results accumulator: model name -> (scores, errors) as returned by
# latent_space_regression in the cells below.
all_scores = {}

# Phenotype label table; it is joined to each latent file on 'fpath'.
label_file = '/home/sam/trained_models/explore_phenotypes_3/tensors_all_union.csv'
labels = pd.read_csv(label_file)

# Label columns to evaluate (names are the post-rename ones, see col_rename).
# An empty list makes latent_space_regression fall back to every label column.
phenotypes = [
    'Sex_Male_0_0',
    'sex',
    'diabetes_type_2',
    'hypercholesterolemia',
    'hypertension',
    'Atrial_fibrillation',
    'age', 'bmi', 'bmi_0', 'RRInterval',
    'LVM', 'LVEDV', 'LVESV',
    'PC1', 'PC2', 'PC5',
]

# Short aliases for the long source column names; applied to the merged frame.
col_rename = {f'22009_Genetic-principal-components_0_{i}': f'PC{i}' for i in range(1,41)}
col_rename.update({
    'Genetic-sex_Male_0_0': 'sex',
    '21003_Age-when-attended-assessment-centre_2_0': 'age',
    '21001_Body-mass-index-BMI_2_0': 'bmi',
    '21001_Body-mass-index-BMI_0_0': 'bmi_0',
    '2887_Number-of-cigarettes-previously-smoked-daily_0_0': 'smoking',
    '30690_Cholesterol_0_0': 'cholesterol',
})

In [None]:
# Brain MRI configuration. NOTE: the next two cells reassign latent_files,
# latent_size and pairs, so on a top-to-bottom run the last such cell wins.

# Display name -> TSV file holding that model's latent vectors.
latent_files = {
    'Brain T1': '/home/sam/trained_models/brain_t1_slice_80_autoencoder_256d_v2022_06_07/hidden_axial_80_brain_t1_slice_80_autoencoder_256d_v2022_06_07.tsv',
    'Brain MNI': '/home/sam/trained_models/t1_mni_slices_48_80_autoencoder_256d/hidden_axial_68_100_t1_mni_slices_48_80_autoencoder_256d.tsv',
}

# Display name -> number of latent columns to use (passed as num_features).
latent_size = {
    'Brain T1': 256,
    'Brain MNI': 256,
}

# (p1, p2) model pairs for compare_pairs, which reports p1 minus p2.
pairs = [('Brain MNI', 'Brain T1')]

In [None]:
# DXA 12 configuration. NOTE: this reassigns latent_files, latent_size and
# pairs, replacing whatever an earlier configuration cell set.

# Display name -> TSV file holding that model's latent vectors.
latent_files = {
    'DXA 12': '/home/sam/trained_models/dxa_12_autoencoder_256d/hidden_dxa_1_12_dxa_12_autoencoder_256d.tsv',
    'DXA 12 Homography': '/home/sam/trained_models/dxa_12_homography_autoencoder_512d/hidden_dxa_1_12_dxa_12_homography_autoencoder_512d.tsv',
}

# Display name -> number of latent columns to use (passed as num_features).
latent_size = {
    'DXA 12': 256,
    'DXA 12 Homography': 512,
}

# (p1, p2) model pairs for compare_pairs, which reports p1 minus p2.
pairs = [('DXA 12 Homography', 'DXA 12')]

In [None]:
# DXA 2 / DXA 5 configuration: autoencoder (AE) vs. DropFuse (DF) latents.
# NOTE: this reassigns latent_files, latent_size and pairs, replacing whatever
# an earlier configuration cell set.

# Display name -> TSV file holding that model's latent vectors.
latent_files = {
    'DXA 2 AE': '/home/sam/trained_models/dxa_2_autoencoder_256d/hidden_dxa_1_2_dxa_2_autoencoder_256d.tsv',
    'DXA 2 DF': '/home/sam/trained_models/dxa_2_5_dropfuse_256d/hidden_dxa_1_2_dxa_2_5_dropfuse_256d.tsv',
    'DXA 5 AE': '/home/sam/trained_models/dxa_5_autoencoder_256d/hidden_dxa_1_5_dxa_5_autoencoder_256d.tsv',
    'DXA 5 DF': '/home/sam/trained_models/dxa_2_5_dropfuse_256d/hidden_dxa_1_5_dxa_2_5_dropfuse_256d.tsv',
}

# Display name -> number of latent columns to use (passed as num_features).
latent_size = {
    'DXA 2 AE': 256,
    'DXA 2 DF': 256,
    'DXA 5 AE': 256,
    'DXA 5 DF': 256,
}

# (p1, p2) model pairs for compare_pairs, which reports p1 minus p2.
pairs = [
    ('DXA 2 DF', 'DXA 2 AE'),
    ('DXA 5 DF', 'DXA 5 AE'),
]

# Alternative cardiac MRI / ECG configuration, kept disabled for reference.
# latent_files = {
# 'cMRI Dense AE': '/home/sam/csvs/dense_autoencoder_1_sample_cmri_inferences.tsv',
# 'cMRI Circular AE': '/home/sam/csvs/circular_autoencoder_lax_4ch_v2023_04_28.tsv',
# 'ECG 10s': '/home/sam/trained_models/ecg_rest_autoencoder_256d_v2023_05_09/hidden_strip_ecg_rest_autoencoder_256d_v2023_05_09.tsv',
# 'ECG Median': '/home/sam/trained_models/hypertuned_48m_16e_ecg_median_raw_10_autoencoder_256d/hidden_embed_hypertuned_48m_16e_ecg_median_raw_10_autoencoder_256d.tsv',
# 'cMRI AE':'/home/sam/trained_models/hypertuned_32m_8e_lax_4ch_heart_center_autoencoder_256d/hidden_lax_4ch_heart_center_hypertuned_32m_8e_lax_4ch_heart_center_autoencoder_256d.tsv',
# 'cMRI ECG DropFuse': '/home/sam/trained_models/dropout_pair_contrastive_lax_4ch_cycle_ecg_median_10_pretrained_256d_v2020_06_07/hidden_lax_4ch_heart_center_dropout_pair_contrastive_lax_4ch_cycle_ecg_median_10_pretrained_256d_v2020_06_07.tsv',
# 'ECG cMRI DropFuse': '/home/sam/trained_models/dropout_pair_contrastive_lax_4ch_cycle_ecg_median_10_pretrained_256d_v2020_06_07/hidden_ecg_rest_median_raw_10_dropout_pair_contrastive_lax_4ch_cycle_ecg_median_10_pretrained_256d_v2020_06_07.tsv',
# }
# latent_size = {
#     'cMRI Dense AE': 50,
#     'cMRI Circular AE': 50,
#     'ECG 10s': 256,
#     'ECG Median': 256,
#     'cMRI AE':256,
#     'cMRI ECG DropFuse':256,
#     'ECG cMRI DropFuse':256,
# }
# pairs = [
#     ('cMRI Circular AE','cMRI Dense AE'),
#     ('ECG Median','ECG 10s'),
#     ('cMRI ECG DropFuse','cMRI AE'),
#     ('ECG cMRI DropFuse','cMRI AE'),
# ]

In [None]:
def fit_logistic(label_header, train, test, indexes, verbose=False):
    """Fit an elastic-net logistic regression and return its held-out ROC AUC.

    Args:
        label_header: name of the binary label column present in both frames.
        train: DataFrame of training rows (feature columns plus the label).
        test: DataFrame of held-out rows with the same columns.
        indexes: list of feature column names used as model inputs.
        verbose: when True, print split sizes, class counts, train/test AUC
            and the percentage of zero coefficients.

    Returns:
        ROC AUC of the fitted classifier on ``test``.

    Raises:
        ValueError: propagated from scikit-learn, e.g. when a split contains
            only one class.
    """
    if verbose:
        print(f'{label_header} len train {len(train)} len test {len(test)}')
        print(f'\nTrain:\n{train[label_header].value_counts()} \n\nTest:\n{test[label_header].value_counts()}')
    clf = LogisticRegression(penalty='elasticnet', solver='saga', class_weight='balanced', l1_ratio=0.5)
    clf.fit(train[indexes], train[label_header])

    # roc_auc_score takes (y_true, y_score). Use the continuous decision
    # function: thresholded predict() labels collapse the ROC curve to a single
    # operating point, and the arguments must not be swapped.
    auc_score = roc_auc_score(test[label_header], clf.decision_function(test[indexes]))
    if verbose:
        # Diagnostics are only needed for the verbose report.
        sparsity = np.mean(clf.coef_ == 0) * 100
        train_auc_score = roc_auc_score(train[label_header], clf.decision_function(train[indexes]))
        print(f'{label_header} AUC:{auc_score:.3f}  Train AUC:{train_auc_score:.3f}, Sparsity: {sparsity:.2f}\n')
    return auc_score

def fit_linear(label_header, train, test, indexes, verbose=False):
    """Fit a standardized ridge regression and return its held-out R^2.

    Args:
        label_header: name of the continuous label column present in both frames.
        train: DataFrame of training rows (feature columns plus the label).
        test: DataFrame of held-out rows with the same columns.
        indexes: list of feature column names used as model inputs.
        verbose: when True, print split sizes, distinct-value counts and
            train/test R^2.

    Returns:
        R^2 of the fitted pipeline on ``test``.
    """
    if verbose:
        print(f'{label_header} len train {len(train)} len test {len(test)}')
        print(f'\nTrain:\n{len(train[label_header].value_counts())} \n\nTest:\n{len(test[label_header].value_counts())}')

    # Features are z-scored before the ridge fit.
    scaler = StandardScaler(with_mean=True)
    regressor = Ridge(solver='lsqr', max_iter=250000)
    model = make_pipeline(scaler, regressor)

    x_train, y_train = train[indexes], train[label_header]
    model.fit(x_train, y_train)

    test_r2 = model.score(test[indexes], test[label_header])
    train_r2 = model.score(x_train, y_train)
    if verbose:
        print(f'{label_header} R^2:{test_r2:.3f}  Train R^2:{train_r2:.3f}\n')
    return test_r2

def _repeated_holdout(fit_fn, label, full, indexes, folds, train_ratio, verbose, rng):
    """Score ``fit_fn`` on ``folds`` independent random train/test splits of ``full``."""
    fold_scores = []
    for _ in range(folds):
        train = full.sample(frac=train_ratio, random_state=rng)
        test = full.drop(train.index)
        fold_scores.append(fit_fn(label, train, test, indexes, verbose))
    return fold_scores


def latent_space_regression(label_file, latent_file, num_features = 256, start_features=0, train_ratio = 0.6, folds=4, verbose=False, random_state=None):
    """Probe a latent space by predicting each phenotype from its latent columns.

    The label table and the latent table are inner-joined on
    ``labels.fpath == latent.sample_id`` and the module-level ``col_rename``
    mapping is applied to the merged columns. For every label, rows with a
    missing value are dropped and the model is scored on ``folds`` random
    train/test splits: labels with more than two distinct values use
    ``fit_linear`` (R^2), all others use ``fit_logistic`` (AUC).

    Args:
        label_file: path to the CSV of labels; must contain an ``fpath`` column.
        latent_file: path to the latent vectors; must contain ``sample_id``.
            A ``.csv`` file is expected to name its feature columns
            ``'0', '1', ...``; any other extension is read as tab-separated
            with columns ``'latent_0', 'latent_1', ...``.
        num_features: exclusive upper bound of the latent column indices used.
        start_features: first latent column index used.
        train_ratio: fraction of rows sampled into the training split.
        folds: number of random splits per label.
        verbose: forwarded to the fit functions.
        random_state: optional seed for reproducible splits. ``None`` (the
            default) keeps the previous behavior of drawing from NumPy's
            global random state.

    Returns:
        ``(scores, errors)``: two dicts keyed by ``'<label> R^2'`` or
        ``'<label> AUC'`` holding the mean score across splits and twice the
        standard deviation across splits, respectively.
    """
    labels = pd.read_csv(label_file)
    if latent_file.split('.')[-1].lower() == 'csv':
        indexes = [f'{i}' for i in range(start_features, num_features)]
        latent = pd.read_csv(latent_file)
    else:
        indexes = [f'latent_{i}' for i in range(start_features, num_features)]
        latent = pd.read_csv(latent_file, sep='\t')

    df = pd.merge(labels, latent, left_on='fpath', right_on='sample_id', how='inner')
    df = df.rename(columns=col_rename)

    # When no phenotypes are configured, fall back to every label column. The
    # names must be the renamed ones because `df` has already been renamed;
    # the raw names would no longer be found.
    candidates = phenotypes if len(phenotypes) else labels.rename(columns=col_rename).columns

    # One generator shared by all splits so each fold differs but the whole
    # run is reproducible for a given seed.
    rng = None if random_state is None else np.random.RandomState(random_state)

    scores = {}
    errors = {}
    for label in candidates:
        try:
            full = df[df[label].notna()]
            if full[label].nunique() > 2:
                fit_fn, metric = fit_linear, 'R^2'
            else:
                fit_fn, metric = fit_logistic, 'AUC'
            s = _repeated_holdout(fit_fn, label, full, indexes, folds, train_ratio, verbose, rng)
            scores[f'{label} {metric}'] = np.mean(s)
            errors[f'{label} {metric}'] = 2*np.std(s)
        except Exception as e:
            # Best effort: a label that is missing or cannot be fit is reported
            # and skipped so the remaining labels are still scored.
            print(f'Could not fit LR for {label} {e}')

    for k,v in sorted(scores.items(), key=lambda x: x[0].lower()):
        print(f'{k} {v:.3f}')

    return scores, errors


def plot_nested_dictionary(all_scores):
    """Draw one horizontal bar chart per metric comparing all models.

    Args:
        all_scores: dict mapping model name to a ``(scores, errors)`` pair of
            dicts keyed by metric name, as returned by
            ``latent_space_regression``.

    The figure is laid out on a near-square grid of at least 2x2 axes. Within
    each chart, models are sorted case-insensitively by name and the bars are
    drawn with the error values as horizontal error bars.
    """
    # Re-nest from model -> metric to metric -> model.
    renest = defaultdict(dict)
    errors = defaultdict(dict)
    for model in all_scores:
        for metric in all_scores[model][0]:
            renest[metric][model] = all_scores[model][0][metric]
            errors[metric][model] = all_scores[model][1][metric]

    # Size the grid from the number of distinct metrics so none is dropped
    # when models do not all report the same metrics.
    n = max(4, len(renest))
    cols = max(2, int(math.ceil(math.sqrt(n))))
    rows = max(2, int(math.ceil(n / cols)))
    fig, axes = plt.subplots(rows, cols, figsize=(cols * 4, rows * 2.5), dpi=300)
    flat_axes = axes.ravel()

    # Explicit list of Tableau color names; bars cycle through it in order.
    bar_colors = list(colors.TABLEAU_COLORS)

    for metric, ax in zip(renest, flat_axes):
        models = sorted(renest[metric], key=str.lower)
        values = [renest[metric][m] for m in models]
        err = [errors[metric][m] for m in models]
        y_pos = np.arange(len(models))

        # Draw the bars once (drawing them twice only stacked identical bars).
        ax.barh(y_pos, values, xerr=err, align='center', color=bar_colors)
        ax.set_yticks(y_pos)
        ax.set_yticklabels(models)
        ax.invert_yaxis()  # labels read top-to-bottom
        if 'AUC' in metric:
            ax.set_xlabel('AUROC')
        else:
            ax.set_xlabel('$R^2$')

        if '21001_Body-mass-index-BMI_0_0' in metric:
            ax.set_title('BMI')
        elif '21003_Age-when-attended-assessment-centre_2_0' in metric:
            ax.set_title('Age')
        elif 'Sex_Male_0_0' in metric:
            ax.set_title('Sex')
        else:
            ax.set_title(metric.split(' ')[0])

    # Hide grid cells that received no metric instead of leaving empty frames.
    for ax in flat_axes[len(renest):]:
        ax.set_visible(False)

    fig.tight_layout()

In [None]:
# Score every configured latent space on 3 random 80/20 splits and store the
# resulting (scores, errors) tuple under the model's display name.
for name, latent_path in latent_files.items():
    all_scores[name] = latent_space_regression(
        label_file,
        latent_path,
        num_features=latent_size[name],
        folds=3,
        train_ratio=0.8,
    )

In [None]:
def compare_pairs(scores, pairs):
    """Print per-metric and averaged score comparisons for pairs of models.

    Args:
        scores: dict mapping model name to a ``(metric -> score,
            metric -> error)`` pair of dicts.
        pairs: iterable of ``(p1, p2)`` model names. Metrics are taken from
            ``p2``; each line shows both scores and the difference ``p1 - p2``.

    For each pair, metrics whose name contains 'R^2' or 'AUC' are averaged per
    model, along with their error values, and printed as
    ``mean (mean - error, mean + error)``. A metric missing from ``p1`` is
    reported and skipped, and a metric family with no entries prints ``nan``
    instead of raising ``ZeroDivisionError``.
    """
    def _mean(total, count):
        # NaN rather than a crash when a metric family is absent.
        return total / count if count else float('nan')

    for (p1, p2) in pairs:
        s1 = scores[p1]
        s2 = scores[p2]
        # Named `totals` so the imported `scipy.stats` module is not shadowed.
        totals = Counter()
        print(f'comparing {p1} to {p2}' )
        for k in s2[0]:
            if k not in s1[0]:
                # A label that failed to fit for one model has no entry there.
                print(f'\t{k} missing for {p1}, skipped')
                continue
            print(f'\t{k} {s1[0][k]:0.3f}, {s2[0][k]:0.3f},  Diff: {(s1[0][k] - s2[0][k]):0.3f} ')
            if 'R^2' in k:
                kind = 'R^2'
            elif 'AUC' in k:
                kind = 'AUC'
            else:
                continue
            totals[f'{p1} {kind} sum'] += s1[0][k]
            totals[f'{p2} {kind} sum'] += s2[0][k]
            totals[f'{p1} {kind} std'] += s1[1][k]
            totals[f'{p2} {kind} std'] += s2[1][k]
            totals[f'{kind} n'] += 1

        auc_n = totals['AUC n']
        r2_n = totals['R^2 n']
        auc1 = _mean(totals[f'{p1} AUC sum'], auc_n)
        auc_std1 = _mean(totals[f'{p1} AUC std'], auc_n)
        auc2 = _mean(totals[f'{p2} AUC sum'], auc_n)
        auc_std2 = _mean(totals[f'{p2} AUC std'], auc_n)
        r21 = _mean(totals[f'{p1} R^2 sum'], r2_n)
        r2_std1 = _mean(totals[f'{p1} R^2 std'], r2_n)
        r22 = _mean(totals[f'{p2} R^2 sum'], r2_n)
        r2_std2 = _mean(totals[f'{p2} R^2 std'], r2_n)
        print(f"\n {p1} vs {p2} ")
        print(f"\t\t Mean AUCs {auc1:0.3f} ({auc1-auc_std1:0.3f}, {auc1+auc_std1:0.3f}), {auc2:0.3f} ({auc2-auc_std2:0.3f}, {auc2+auc_std2:0.3f}) ")
        print(f" \t\t Mean R^2 {r21:0.3f} ({r21-r2_std1:0.3f}, {r21+r2_std1:0.3f}),  {r22:0.3f} ({r22-r2_std2:0.3f}, {r22+r2_std2:0.3f}) \n\n\n")

In [None]:
# Print per-metric scores and mean R^2 / AUC for each configured (p1, p2) model pair.
compare_pairs(all_scores, pairs)

In [None]:
# Score the DXA 5 autoencoder latents with latent_space_regression's defaults
# (num_features=256, train_ratio=0.6, folds=4). This replaces any existing
# 'DXA 5 AE' entry in all_scores.
lf='/home/sam/trained_models/dxa_5_autoencoder_256d/hidden_dxa_1_5_dxa_5_autoencoder_256d.tsv'
all_scores['DXA 5 AE'] = latent_space_regression(label_file, lf)

In [None]:
# Score the DXA 5 DropFuse latents with latent_space_regression's defaults
# (num_features=256, train_ratio=0.6, folds=4). This replaces any existing
# 'DXA 5 DF' entry in all_scores.
lf='/home/sam/trained_models/dxa_2_5_dropfuse_256d/hidden_dxa_1_5_dxa_2_5_dropfuse_256d.tsv'
all_scores['DXA 5 DF'] = latent_space_regression(label_file, lf)

In [None]:
# Score the DXA 11 autoencoder and DropFuse latents with the function defaults;
# `lf` is reused as a scratch variable for the current latent file path.
lf='/home/sam/trained_models/dxa_11_autoencoder_256d/hidden_dxa_1_11_dxa_11_autoencoder_256d.tsv'
all_scores['DXA 11 AE'] = latent_space_regression(label_file, lf)
lf='/home/sam/trained_models/dxa_11_12_dropfuse_256d_v2023_04_17/hidden_dxa_1_11_dxa_11_12_dropfuse_256d_v2023_04_17.tsv'
all_scores['DXA 11 DF'] = latent_space_regression(label_file, lf)

In [None]:
# Score the 256-d MRI autoencoder latents with the function defaults.
# The path has no placeholders, so it is a plain string literal.
lf = '/home/sam/trained_models/hypertuned_64m_18e_lax_4ch_heart_center_autoencoder_256d/hidden_embed_hypertuned_64m_18e_lax_4ch_heart_center_autoencoder_256d.tsv'
all_scores['MRI Autoencoder 256D'] = latent_space_regression(label_file, lf)

In [None]:
# Score the 320-d ECG PCLR embeddings. This file is a CSV, so
# latent_space_regression reads feature columns named '0' .. '319'.
# The path has no placeholders, so it is a plain string literal.
lf = '/home/sam/csvs/03-11-2021_simclr_320-320_ukb_embeddings.csv'
all_scores['ECG PCLR 320D'] = latent_space_regression(label_file, lf, num_features=320)

In [None]:
# Draw one bar chart per metric comparing every model stored in all_scores.
plot_nested_dictionary(all_scores)