In [None]:
from models import *

import numpy as np
import pandas as pd

import sage

In [None]:
############################################
# Define helper functions
############################################

def save_sages(sage_vals, path, metabolite_ids=None):
    """Write per-metabolite SAGE values to a CSV file.

    Args:
        sage_vals: SAGE explanation object exposing ``.values`` and ``.std``,
            one entry per input feature.
        path: destination CSV path.
        metabolite_ids: labels for the features, in the same order as
            ``sage_vals.values``. Defaults to the columns of the global
            ``aml_data`` frame (the previous hard-wired behaviour).

    Raises:
        ValueError: if the number of labels does not match the number of SAGE
            values (previously this silently padded the CSV with empty cells,
            misaligning labels and values).
    """
    if metabolite_ids is None:
        metabolite_ids = aml_data.columns.to_list()
    metabolite_ids = list(metabolite_ids)

    n_values = len(sage_vals.values)
    if len(metabolite_ids) != n_values:
        raise ValueError(
            f"{len(metabolite_ids)} metabolite ids but {n_values} SAGE values"
        )

    # Column-wise construction keeps numeric dtypes (the old row-list + .T gave object columns).
    res = pd.DataFrame({
        "metabolite_id": metabolite_ids,
        "sage_value": sage_vals.values,
        "sage_value_sd": sage_vals.std,
    })
    res.to_csv(path, index=False)
    
def save_pw_sages(sage_vals, pw_groupnames, path):
    """Write pathway-level (grouped) SAGE values to a CSV file.

    Rows are (pathway name, SAGE value, SAGE standard deviation), paired
    positionally: the i-th name goes with the i-th value.

    Args:
        sage_vals: SAGE explanation object exposing ``.values`` and ``.std``.
        pw_groupnames: group labels, in the same order as the SAGE groups.
        path: destination CSV path.
    """
    header = ["pathway_name", "sage_value", "sage_value_sd"]
    # Build one labelled row per field, then transpose so the labels become columns.
    stacked = pd.DataFrame([pw_groupnames, sage_vals.values, sage_vals.std], index=header)
    stacked.transpose().to_csv(path, index=False)

def get_sage_pws(save_path, model, ref_data, test_data, dim, pw_groups, pw_groupnames, pw_type,
                 n_background=10, batch_size=512):
    """Compute grouped (pathway-level) SAGE values for one latent dimension and save them.

    The explained quantity is column ``dim`` of ``model.encode(...)``; its value on
    ``test_data`` is used as the target ``Y`` with an MSE loss.

    Args:
        save_path: directory the CSV is written into.
        model: object with an ``encode`` method (assumed to return a 2-D array of
            shape (n_samples, n_latent) -- TODO confirm against ``KPCA_model``).
        ref_data: reference rows; the first ``n_background`` rows feed the imputer.
        test_data: rows to explain.
        dim: 0-based latent dimension index.
        pw_groups: list of column-index lists, one per group.
        pw_groupnames: group labels aligned with ``pw_groups``.
        pw_type: ``"super"`` or ``"sub"``; selects the output file prefix.
        n_background: number of leading ``ref_data`` rows used as imputer background.
        batch_size: batch size passed to the SAGE sampler.

    Returns:
        The SAGE explanation object produced by the sampler.

    Raises:
        ValueError: if ``pw_type`` is not ``"super"`` or ``"sub"``. Checked before the
            expensive SAGE run; previously an unknown type silently skipped saving.
    """
    file_prefixes = {"super": "superpw", "sub": "subpw"}
    if pw_type not in file_prefixes:
        raise ValueError(f"pw_type must be 'super' or 'sub', got {pw_type!r}")

    def get_dim_vals(dat):
        # Slice (not index) so the output keeps a trailing axis of size 1.
        return model.encode(dat)[:, dim:(dim + 1)]

    # calculate base values
    dim_output = get_dim_vals(test_data)

    # Setup and calculate
    # NOTE: any callable function that returns a prediction is allowed in PermutationSampler
    # NOTE(review): the background is the *first* n_background rows, not a random sample.
    imputer_pw = sage.GroupedMarginalImputer(ref_data[:n_background], pw_groups)
    sampler_pw = sage.PermutationSampler(get_dim_vals, imputer_pw, 'mse')
    sage_values_pw = sampler_pw(test_data, Y=dim_output, batch_size=batch_size)

    # do dim+1 to add 1 to the dimension index
    save_pw_sages(sage_values_pw, pw_groupnames,
                  save_path + f"/{file_prefixes[pw_type]}_dim_{dim+1}.csv")

    return sage_values_pw


def get_sage_mets(save_path, model, ref_data, test_data, dim, n_background=10, batch_size=10):
    """Compute per-metabolite SAGE values for one latent dimension and save them.

    The explained quantity is column ``dim`` of ``model.encode(...)``; its value on
    ``test_data`` is used as the target ``Y`` with an MSE loss.

    Args:
        save_path: directory the CSV is written into (``met_dim_<dim+1>.csv``).
        model: object with an ``encode`` method (assumed to return a 2-D array of
            shape (n_samples, n_latent) -- TODO confirm against ``KPCA_model``).
        ref_data: reference rows; the first ``n_background`` rows feed the imputer.
        test_data: rows to explain.
        dim: 0-based latent dimension index.
        n_background: number of leading ``ref_data`` rows used as imputer background.
        batch_size: batch size passed to the SAGE sampler (kept at the previous 10;
            the pathway variant uses 512).

    Returns:
        The SAGE explanation object produced by the sampler.
    """
    def get_dim_vals(dat):
        # Slice (not index) so the output keeps a trailing axis of size 1.
        return model.encode(dat)[:, dim:(dim + 1)]

    dim_output = get_dim_vals(test_data)

    # Setup and calculate
    # NOTE: any callable function that returns a prediction is allowed in PermutationSampler
    # NOTE(review): the background is the *first* n_background rows, not a random sample.
    imputer = sage.MarginalImputer(ref_data[:n_background])
    sampler = sage.PermutationSampler(get_dim_vals, imputer, 'mse')
    sage_values = sampler(test_data, Y=dim_output, batch_size=batch_size)

    # do dim+1 to add 1 to the dimension index
    save_sages(sage_values, save_path + f"/met_dim_{dim+1}.csv")

    return sage_values


def get_groups(pw_name, annos=None, features=None, columns=None):
    """Build SAGE feature groups from a metabolite-annotation column.

    Metabolites sharing a value in ``pw_name`` form one group. Metabolites with no
    value there (``groupby`` drops missing keys) each become their own singleton
    group, so ``groups[i]`` always corresponds to ``group_names[i]``. Previously
    those metabolites were named but given no index group, which shifted the
    positional pairing of names and SAGE values in the saved CSVs.

    Args:
        pw_name: annotation column to group by, e.g. 'SUPER_PATHWAY' or 'SUB_PATHWAY'.
        annos: annotation table with a 'COMP_IDstr' column and ``pw_name``;
            defaults to the global ``met_annos``.
        features: metabolite IDs that should appear in some group; defaults to the
            global ``feature_names``.
        columns: ordered data column names used to map IDs to column indices;
            defaults to the global ``data_cols``.

    Returns:
        dict with
          'feature_groups': {group value: [metabolite IDs]} for annotated metabolites,
          'group_names':    group values followed by the IDs of ungrouped metabolites,
          'groups':         list of column-index lists aligned with 'group_names'.
    """
    if annos is None:
        annos = met_annos
    if features is None:
        features = feature_names
    if columns is None:
        columns = data_cols

    # Feature groups
    feature_groups = annos.groupby(pw_name)['COMP_IDstr'].apply(list).to_dict()

    # First-occurrence index of each column name (same result as list.index, O(1) lookups).
    col_index = {}
    for i, col in enumerate(columns):
        col_index.setdefault(col, i)

    # Group names and indices, in feature_groups order.
    group_names = list(feature_groups)
    groups = [[col_index[feature] for feature in members] for members in feature_groups.values()]

    # Every metabolite not covered above gets its own singleton group.
    grouped = {feature for members in feature_groups.values() for feature in members}
    for col in features:
        if col not in grouped:
            grouped.add(col)  # avoid duplicate (overlapping) singleton groups
            group_names.append(col)
            groups.append([col_index[col]])

    return {'feature_groups': feature_groups, 'group_names': group_names, 'groups': groups}

In [None]:
############################################
# Initialize variables and instantiate objects
############################################

# Data & model configuration: number of KPCA latent dimensions
latent_dim = 18

# Input workbooks
twins_path  = 'data/TwinsUK.xls'
aml_path    = 'data/AML.xls'

# Output directories for the SAGE CSVs, one per KPCA kernel
cosine_path = "results/sage_values/cosine"
sigmoid_path = "results/sage_values/sigmoid"
rbf_path = "results/sage_values/rbf"
poly_path = "results/sage_values/poly"

# AML workbook: metabolite matrix and sample annotations
aml_data = pd.read_excel(aml_path, sheet_name='Metabolite Data')
aml_anno = pd.read_excel(aml_path, sheet_name='Sample Annotations')

# TwinsUK workbook: training and testing splits
twins_train_df = pd.read_excel(twins_path, sheet_name='Training Set')
twins_test_df  = pd.read_excel(twins_path, sheet_name='Testing Set')

# Array forms of the splits; these arrays are used for score calculations
twins_train = twins_train_df.values
twins_test  = twins_test_df.values

######################
# Define KPCA models
######################
# Positional hyperparameters are passed straight to KPCA_model (defined in `models`).
poly_KPCA_model_    = KPCA_model(twins_train_df.values, latent_dim, "poly", 2, 0.001, 3, 5.0)
cosine_KPCA_model_  = KPCA_model(twins_train_df.values, latent_dim, "cosine", 1, 0, 0, 0)
sigmoid_KPCA_model_ = KPCA_model(twins_train_df.values, latent_dim, "sigmoid", 1, 0.05, 0, 0)
rbf_KPCA_model_     = KPCA_model(twins_train_df.values, latent_dim, "rbf", 1, 0.005, 0, 0)

# Metabolite annotations restricted to the columns present in the AML data;
# these drive the grouped (pathway) SAGE values.
data_cols = aml_data.columns.to_list()
met_annos = (
    pd.read_excel(aml_path, sheet_name='Metabolite Annotations')
    .pipe(lambda annos: annos.loc[annos['COMP_IDstr'].isin(data_cols)])
)
feature_names = met_annos['COMP_IDstr'].to_list()
biochemical_names = met_annos['BIOCHEMICAL'].to_list()

In [None]:
############################################
# MAIN PART OF SCRIPT
############################################

# Pathway groupings do not depend on the latent dimension, so build them once
# instead of on every iteration.
superpw = get_groups('SUPER_PATHWAY')
subpw   = get_groups('SUB_PATHWAY')

# (output directory, fitted KPCA model), in the order results are produced.
kpca_runs = [
    (cosine_path, cosine_KPCA_model_),
    (sigmoid_path, sigmoid_KPCA_model_),
    (rbf_path, rbf_KPCA_model_),
    (poly_path, poly_KPCA_model_),
]

# Use a distinct loop variable: the old loop reused `latent_dim`, overwriting the
# configured model size (leaving it at 17) and hard-coding range(18).
for dim in range(latent_dim):
    # Superpathway then subpathway grouped SAGE values, each for every kernel.
    for pw, pw_type in ((superpw, "super"), (subpw, "sub")):
        for out_path, kpca_model in kpca_runs:
            sage_pw_values = get_sage_pws(out_path, kpca_model, twins_train, twins_test, dim,
                                          pw['groups'], pw['group_names'], pw_type)

    # Per-metabolite SAGE values for every kernel.
    for out_path, kpca_model in kpca_runs:
        sage_met_values = get_sage_mets(out_path, kpca_model, twins_train, twins_test, dim)