In [None]:
from models import *
import sys

import numpy as np
import pandas as pd

import sage


############################################
# Define helper functions
############################################

def save_sages(sage_vals, path, feature_names=None):
    """Write per-metabolite SAGE values and their standard deviations to CSV.

    Parameters
    ----------
    sage_vals : object exposing ``.values`` and ``.std``
        One entry per feature, e.g. the result returned by the
        ``sage.PermutationSampler`` call in ``get_sage_mets``.
    path : str
        Destination CSV path. Missing parent directories are created, so an
        expensive SAGE run is not lost to a missing output folder.
    feature_names : sequence of str, optional
        Row labels, in the same order as the features SAGE was run on.
        Defaults to the columns of the global ``aml_data`` frame (the original
        behaviour). NOTE(review): SAGE is run on the TwinsUK arrays, so this
        default assumes the AML and TwinsUK column orders match — confirm.

    Output columns: ``metabolite_id``, ``sage_value``, ``sage_value_sd``.
    """
    import os

    if feature_names is None:
        feature_names = aml_data.columns.to_list()

    directory = os.path.dirname(path)
    if directory:
        os.makedirs(directory, exist_ok=True)

    # Built row-wise then transposed (original construction kept unchanged).
    res = pd.DataFrame([feature_names, sage_vals.values, sage_vals.std]).T
    res.columns = ["metabolite_id", "sage_value", "sage_value_sd"]
    res.to_csv(path, index=False)
    
def save_pw_sages(sage_vals, pw_groupnames, path):
    """Write pathway-level (grouped) SAGE values and their standard deviations to CSV.

    Parameters
    ----------
    sage_vals : object exposing ``.values`` and ``.std``
        Grouped SAGE result, one entry per group (as produced in ``get_sage_pws``).
    pw_groupnames : sequence of str
        Group labels, in the order the groups were passed to SAGE.
    path : str
        Destination CSV path. Missing parent directories are created, so an
        expensive SAGE run is not lost to a missing output folder.

    Output columns: ``pathway_name``, ``sage_value``, ``sage_value_sd``.

    NOTE(review): ``group_names`` from ``get_groups`` may list more names than
    there are groups (ungrouped metabolites are appended); confirm the saved
    rows line up with the values SAGE returns.
    """
    import os

    directory = os.path.dirname(path)
    if directory:
        os.makedirs(directory, exist_ok=True)

    # Built row-wise then transposed (original construction kept unchanged).
    res = pd.DataFrame([pw_groupnames, sage_vals.values, sage_vals.std]).T
    res.columns = ["pathway_name", "sage_value", "sage_value_sd"]
    res.to_csv(path, index=False)

def get_sage_pws(model, ref_data, test_data, dim, pw_groups, pw_groupnames, pw_type,
                 n_background=10, batch_size=512, out_dir="results/sage_values/VAE"):
    """Compute grouped (pathway-level) SAGE values for one latent dimension and save them.

    The explained "model" is ``model.encode_mu`` restricted to column ``dim``.
    The targets ``Y`` are that same output on ``test_data``, scored with MSE.

    Parameters
    ----------
    model : object with an ``encode_mu(data)`` method
        Presumably returns a 2-D (n_samples, latent_dim) array — TODO confirm.
    ref_data : array
        Background data for the grouped marginal imputer; only the first
        ``n_background`` rows are used.
    test_data : array
        Samples whose SAGE values are estimated.
    dim : int
        Latent dimension to explain. Dimension 0 is written to disk as 18.
    pw_groups : list of list of int
        Column indices per group (see ``get_groups``).
    pw_groupnames : list of str
        Labels written next to the grouped values.
    pw_type : {"super", "sub"}
        Selects the output file prefix (``superpw`` / ``subpw``).
    n_background : int, default 10
        Number of leading ``ref_data`` rows used as imputation background.
    batch_size : int, default 512
        Batch size passed to the SAGE sampler.
    out_dir : str, default "results/sage_values/VAE"
        Directory the CSV is written into.

    Returns
    -------
    The SAGE result object returned by the permutation sampler.

    Raises
    ------
    ValueError
        If ``pw_type`` is not ``"super"`` or ``"sub"``. This is checked before
        the expensive SAGE run; previously an unknown type computed everything
        and then silently skipped saving.
    """
    file_prefixes = {"super": "superpw", "sub": "subpw"}
    if pw_type not in file_prefixes:
        raise ValueError(f"pw_type must be 'super' or 'sub', got {pw_type!r}")

    def get_dim_vals(dat):
        # Slice with dim:dim+1 so the column axis is kept (single-output model).
        return model.encode_mu(dat)[:, dim:(dim + 1)]

    # Targets: the encoder's own output for the test samples.
    dim_output = get_dim_vals(test_data)

    # NOTE: any callable function that returns a prediction is allowed in PermutationSampler
    imputer_pw = sage.GroupedMarginalImputer(ref_data[:n_background], pw_groups)
    sampler_pw = sage.PermutationSampler(get_dim_vals, imputer_pw, 'mse')
    sage_values_pw = sampler_pw(test_data, Y=dim_output, batch_size=batch_size)

    # rename dimension 0 to dimension 18 (file-naming convention of this project)
    dim_idx = 18 if dim == 0 else dim

    save_pw_sages(sage_values_pw, pw_groupnames,
                  f"{out_dir}/{file_prefixes[pw_type]}_dim_{dim_idx}.csv")

    return sage_values_pw


def get_sage_mets(model, ref_data, test_data, dim,
                  n_background=10, batch_size=10, out_dir="results/sage_values/VAE"):
    """Compute per-metabolite SAGE values for one latent dimension and save them.

    The explained "model" is ``model.encode_mu`` restricted to column ``dim``.
    The targets ``Y`` are that same output on ``test_data``, scored with MSE.

    Parameters
    ----------
    model : object with an ``encode_mu(data)`` method
        Presumably returns a 2-D (n_samples, latent_dim) array — TODO confirm.
    ref_data : array
        Background data for the marginal imputer; only the first
        ``n_background`` rows are used.
    test_data : array
        Samples whose SAGE values are estimated.
    dim : int
        Latent dimension to explain. Dimension 0 is written to disk as 18.
    n_background : int, default 10
        Number of leading ``ref_data`` rows used as imputation background
        (kept small as in the original — presumably to bound runtime).
    batch_size : int, default 10
        Batch size passed to the SAGE sampler.
    out_dir : str, default "results/sage_values/VAE"
        Directory the CSV is written into.

    Returns
    -------
    The SAGE result object returned by the permutation sampler.
    """
    def get_dim_vals(dat):
        # Slice with dim:dim+1 so the column axis is kept (single-output model).
        return model.encode_mu(dat)[:, dim:(dim + 1)]

    # Targets: the encoder's own output for the test samples.
    dim_output = get_dim_vals(test_data)

    # NOTE: any callable function that returns a prediction is allowed in PermutationSampler
    imputer = sage.MarginalImputer(ref_data[:n_background])
    sampler = sage.PermutationSampler(get_dim_vals, imputer, 'mse')
    sage_values = sampler(test_data, Y=dim_output, batch_size=batch_size)

    # rename dimension 0 to dimension 18 (file-naming convention of this project)
    dim_idx = 18 if dim == 0 else dim

    save_sages(sage_values, f"{out_dir}/met_dim_{dim_idx}.csv")

    return sage_values


def get_groups(pw_name, annotations=None, features=None, columns=None):
    """Group metabolites by an annotation column, for grouped SAGE.

    Parameters
    ----------
    pw_name : str
        Annotation column to group by (e.g. ``'SUPER_PATHWAY'``, ``'SUB_PATHWAY'``).
    annotations : pandas.DataFrame, optional
        Annotations with a ``'COMP_IDstr'`` id column. Defaults to the global
        ``met_annos``.
    features : sequence of str, optional
        Ids checked for group membership. Those in no group are appended to
        ``group_names``. Defaults to the global ``feature_names``.
    columns : sequence of str, optional
        Column order of the data matrix, used to turn ids into positional
        indices. Defaults to the global ``data_cols``.

    Returns
    -------
    dict
        ``'feature_groups'``: {group name: [ids]}. Rows whose ``pw_name`` value
        is missing are dropped by ``groupby``.
        ``'group_names'``: group names, followed by every id in ``features``
        that belongs to no group.
        ``'groups'``: one list of column indices per group, in the key order of
        ``feature_groups``.

    Raises
    ------
    ValueError
        If a grouped id does not appear in ``columns``.
    """
    if annotations is None:
        annotations = met_annos
    if features is None:
        features = feature_names
    if columns is None:
        columns = data_cols

    # Feature groups
    feature_groups = annotations.groupby(pw_name)['COMP_IDstr'].apply(list).to_dict()

    # A set of every grouped id makes the leftover scan linear instead of
    # rescanning every group's list for each feature.
    grouped_ids = {fid for members in feature_groups.values() for fid in members}
    group_names = list(feature_groups) + [col for col in features if col not in grouped_ids]

    # id -> first position in `columns` (same result as list.index, built once).
    position = {}
    for idx, col in enumerate(columns):
        position.setdefault(col, idx)

    # Group indices
    groups = []
    for members in feature_groups.values():
        missing = [fid for fid in members if fid not in position]
        if missing:
            raise ValueError(f"ids not found in data columns: {missing}")
        groups.append([position[fid] for fid in members])

    return {'feature_groups': feature_groups, 'group_names': group_names, 'groups': groups}

In [None]:
############################################
# Initialize variables and instantiate objects
############################################

# --- Input spreadsheets ---------------------------------------------------
twins_path = 'data/TwinsUK.xls'
aml_path = 'data/AML.xls'

aml_data = pd.read_excel(aml_path, sheet_name='Metabolite Data')
aml_anno = pd.read_excel(aml_path, sheet_name='Sample Annotations')

twins_train_df = pd.read_excel(twins_path, sheet_name='Training Set')
twins_test_df = pd.read_excel(twins_path, sheet_name='Testing Set')

# Raw arrays of the TwinsUK splits; these are what the SAGE runs below consume.
twins_train, twins_test = twins_train_df.values, twins_test_df.values

# --- Saved VAE weights ----------------------------------------------------
model_path = 'models/'
path_vae = f'{model_path}VAE.h5'
path_encoder = f'{model_path}VAE_encoder.h5'
path_decoder = f'{model_path}VAE_decoder.h5'

# --- Architecture / training configuration used to rebuild the model -----
input_dim = twins_train_df.shape[1]
intermediate_dim, latent_dim = 200, 18
kl_beta, learning_rate = 1e-2, 1e-3

mtmodel = mtVAE(input_dim, intermediate_dim, latent_dim, kl_beta, learning_rate)

# Restore the trained weights into each sub-model.
mtmodel.vae.load_weights(path_vae)
mtmodel.encoder.load_weights(path_encoder)
mtmodel.decoder.load_weights(path_decoder)

# --- Annotations used to build grouped SAGE features ----------------------
# NOTE(review): group indices are derived from the AML column order but are
# applied to the TwinsUK arrays; this assumes both files share the same column
# order — confirm.
data_cols = aml_data.columns.to_list()
met_annos = (
    pd.read_excel(aml_path, sheet_name='Metabolite Annotations')
    .loc[lambda annos: annos['COMP_IDstr'].isin(data_cols)]
)
feature_names = met_annos['COMP_IDstr'].to_list()
biochemical_names = met_annos['BIOCHEMICAL'].to_list()

In [None]:
############################################
# MAIN PART OF SCRIPT
############################################

# Groupings depend only on the annotations, not on the latent dimension, so
# build them once instead of once per dimension.
superpw = get_groups('SUPER_PATHWAY')
subpw = get_groups('SUB_PATHWAY')

# Keep every result; the original overwrote them on each iteration.
sage_results = {'superpathway': {}, 'subpathway': {}, 'metabolite': {}}

# `dim` is the loop variable. The original reused `latent_dim`, clobbering the
# model-configuration constant (18) defined in the config cell.
for dim in range(latent_dim):
    # Pass fresh copies of the group lists, as the original did by rebuilding
    # them every call, in case the imputer modifies what it is given.
    sage_pw_values = get_sage_pws(mtmodel, twins_train, twins_test, dim,
                                  [list(g) for g in superpw['groups']],
                                  list(superpw['group_names']), "super")
    sage_results['superpathway'][dim] = sage_pw_values

    sage_pw_values = get_sage_pws(mtmodel, twins_train, twins_test, dim,
                                  [list(g) for g in subpw['groups']],
                                  list(subpw['group_names']), "sub")
    sage_results['subpathway'][dim] = sage_pw_values

    sage_met_values = get_sage_mets(mtmodel, twins_train, twins_test, dim)
    sage_results['metabolite'][dim] = sage_met_values