In [None]:
import numpy as np
import pandas as pd
from ModuleRefinement import utils, ModuleLBG, center_algorithms
import plotly.express as px

from sklearn.model_selection import StratifiedKFold

from orthrus.core.dataset import DataSet as DS

from IPython.utils import io

In [None]:

species = 'mouse'
dataset_name = 'salmonella_Liver'
fold = 1  # 0-based index of the StratifiedKFold split whose training data is used

# WGCNA modules for the selected fold are written here (relative to the notebook).
out_file = f'modules/5fold/fold{fold}_{dataset_name}.pickle'

module_file = f'./{out_file}'
refined_modules_dir = './refined_modules/'

# Each entry is [center method, data dimension, center dimension];
# the scalar settings below are taken from the first entry.
center_methods = [['flag_mean', 1, 1]]

center_method = center_methods[0][0]
data_dimension = center_methods[0][1]
center_dimension = center_methods[0][2]

# Build the refined-module path from the settings above instead of hard-coding it,
# so changing `fold`, `dataset_name`, or the center method keeps it in sync with
# `module_file` (layout: <dir>/<method>_<data_dim>_<center_dim>/5fold/fold<k>_<dataset>.pickle).
center_method_tag = f'{center_method}_{data_dimension}_{center_dimension}'
refined_module_file = f'{refined_modules_dir}{center_method_tag}/5fold/fold{fold}_{dataset_name}.pickle'

module_paths = [module_file, refined_module_file]
algorithms = ['WGCNA', 'LBG']
organism = 'mmusculus'

In [None]:
# CSV holding the data for the dataset selected in the configuration cell.
data_path = '../data/{}.csv'.format(dataset_name)

# utils.load_data yields four values; all are kept for the cells below.
(class_data,
 unique_labels,
 data_all,
 labels_all) = utils.load_data(data_path)

In [None]:
# Keep random_state at 0 so the fold assignment is reproducible across runs.
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)
n_splits = skf.get_n_splits(data_all, labels_all)

# Fail loudly rather than silently computing nothing (which would leave any stale
# `split_data` from an earlier run in the kernel) when `fold` is out of range.
if not 0 <= fold < n_splits:
    raise ValueError(f'fold must be in [0, {n_splits}), got {fold}')

# Compute WGCNA modules on the training portion of the selected fold only.
for fold_number, (train_index, test_index) in enumerate(skf.split(data_all, labels_all)):
    if fold_number != fold:
        continue

    split_data = data_all.iloc[train_index]
    split_labels = labels_all.iloc[train_index]

    # Capture the WGCNA call's output so it does not flood the notebook;
    # it remains inspectable via `captured`.
    with io.capture_output() as captured:
        r_value = utils.wgcna_modules(split_data, species, out_file)

    print(f'MODULES COMPUTED! R value = {r_value}')
    break  # the requested fold is done; no need to generate the remaining splits

In [None]:
    split_path = module_path.split('/')

    the_modules, _ = load_modules(module_path)
    
    for center_method in center_methods:
        center_method_str = f'{center_method[0]}_{center_method[1]}_{center_method[2]}'
        save_path0 = f'{save_path_prefix}/{center_method_str}'
        if not os.path.isdir(save_path0):
            os.mkdir(save_path0)
        
        save_path1 = f'{save_path0}/{split_path[2]}'
        if not os.path.isdir(save_path1):
            os.mkdir(save_path1)

        save_path =  f'{save_path1}/{split_path[3][:-7]}.pickle'

        split_data = split_data.loc[:, (split_data != 0).any(axis=0)] #remove columns with all 0s

        #make index here! this is who's in what module
        feature_names = list(split_data.columns)
        feature_labels = pd.DataFrame(columns = feature_names, data = np.zeros((1,len(feature_names))))
        ii=0
        for _, m in the_modules.iterrows():
            feature_labels[m] = ii
            ii+=1
        index = feature_labels.iloc[0]

        restricted_data = split_data[feature_names]

        if center_method[2] > 1:
            my_mlbg = ModuleLBG(center_method = center_method[0], center_dimension = center_method[1],
                                 data_dimension = center_method[2], distance = 'max correlation', centrality = 'degree')
        else:
            my_mlbg = ModuleLBG(center_method = center_method[0], center_dimension = center_method[1],
                                 data_dimension = center_method[2], distance = 'correlation', centrality = 'degree')

        normalized_split_data = my_mlbg.process_data(np.array(restricted_data))

        my_mlbg.calc_centers(normalized_split_data, index)


        my_mlbg.fit_transform(normalized_split_data)

        labels = my_mlbg.get_labels(normalized_split_data)