In [1]:
%matplotlib inline

import copy
import os
import pickle
from collections import defaultdict
import itertools

import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd
import scipy as sp
from sklearn.cluster import AgglomerativeClustering


from context import messi
from messi.data_processing import *
from messi.hme import hme
from messi.gridSearch import gridSearch

# Reproduction of MESSI results on MERFISH

Here we reproduce MESSI's results shown in our manuscript, using the excitatory neurons in the female naive animals of the MERFISH hypothalamus dataset as an example. The corresponding results are shown in Figure 3. 

In order to reproduce the results in the manuscript, the number of experts and the other hyperparameters are set to the grid-search results reported in Table S3 of our manuscript.

The results included in this notebook were produced on Linux Mint 19 with an Intel(R) Xeon(R) W-2123 CPU @ 3.60GHz and 16 GB of memory.

## User-defined arguments

These are the same as the arguments for command-line usage. See the docs for detailed information on each argument.

In [3]:
# Input data location and output directory for models / predictions
# (paths are relative to this notebook).
input_path = '../../input/merfish/'
output_path = '../output/'

# Subset of the dataset to model.
data_type = 'merfish'
sex = 'Female'
behavior = "Naive"
behavior_no_space = behavior.replace(" ", "_")
current_cell_type = 'Excitatory'
current_cell_type_no_space = current_cell_type.replace(" ", "_")

# Grid search is disabled: hyperparameters come from Table S3 (set per fold below).
grid_search = False
n_sets = 5  # passed to gridSearch; only used when grid_search is True

n_classes_0 = 1  # passed to hme as n_classes_0
# n_classes_1 (number of level-2 experts) is looked up per held-out animal from Table S3.
n_epochs = 20  # passed to hme as n_epochs

preprocess = 'neighbor_cat'
top_k_response = None
top_k_regulator = None
response_type = 'original'  # use raw values to fit the model

# 'CV': leave-one-animal-out; 'train': a single fit on all animals.
mode = "CV"
n_replicates = 1

# Fail early on a typo: the main loop treats every mode other than 'train' as
# leave-one-animal-out, but only predicts the held-out animal when mode == 'CV'.
if mode not in ('CV', 'train'):
    raise ValueError(f"mode must be 'CV' or 'train', got {mode!r}")

## Read in and preprocess data

### Read in the ligand & receptor lists and meta information for the dataset

The meta information can be obtained by running readyData.py. See the docs for details.

In [4]:
# Reader functions per supported data type, in the order:
# (meta reader, per-slice data reader, per-dataset index helper).
read_in_functions = {'merfish': [read_meta_merfish, read_merfish_data, get_idx_per_dataset_merfish],
                    'merfish_cell_line': [read_meta_merfish_cell_line, read_merfish_cell_line_data, get_idx_per_dataset_merfish_cell_line],
                    'starmap': [read_meta_starmap_combinatorial, read_starmap_combinatorial, get_idx_per_dataset_starmap_combinatorial]}

# set data reading functions corresponding to the data type
try:
    read_meta, read_data, get_idx_per_dataset = read_in_functions[data_type]
except KeyError:
    raise NotImplementedError(
        f"Now only support processing 'merfish', 'merfish_cell_line' or 'starmap', got {data_type!r}"
    ) from None

# read in ligand and receptor lists
l_u, r_u = get_lr_pairs(input_path='../messi/input/')  # may need to change to the default value

# read in meta information about the dataset
meta_all, meta_all_columns, cell_types_dict, genes_list, genes_list_u, \
response_list_prior, regulator_list_prior = \
    read_meta(input_path, behavior_no_space, sex, l_u, r_u)  # TO BE MODIFIED: number of responses

# get all available animals/samples; sorted because a set has no guaranteed
# iteration order and this list fixes the order of the CV folds
all_animals = sorted(set(meta_all[:, meta_all_columns['Animal_ID']]))

Removed genes: {'Blank_1', 'Blank_5', 'Blank_3', 'Blank_2', 'Fos', 'Blank_4'}
Total number of cell types for merfish: 16


# Reproduce results by running outer CV

In order to reproduce the results in the manuscript, here we set the number of experts and the other hyperparameters to the values presented in Table S3 of the manuscript.

In [5]:
# Hyperparameters from Table S3 of the manuscript, one row per cell type.
# Each row gives, for held-out animals 1..4 in order:
#   (number of level-2 experts, whether soft weights are used).
# The main loop also uses the soft-weights flag as `partial_fit_expert`.
_animal_ids = (1, 2, 3, 4)
_table_s3 = [
    ('Inhibitory',    (10, 10, 10, 10), (True, True, True, True)),
    ('Excitatory',    (8, 8, 10, 10),   (True, True, True, True)),
    ('Astrocyte',     (4, 4, 3, 3),     (True, True, False, False)),
    ('OD Mature 2',   (3, 3, 4, 3),     (True, True, True, False)),
    ('Endothelial 1', (1, 1, 2, 2),     (True, True, False, False)),
    ('OD Immature 1', (1, 1, 2, 2),     (True, True, True, True)),
    ('OD Mature 1',   (1, 1, 1, 1),     (True, True, True, True)),
    ('Microglia',     (1, 1, 1, 1),     (True, True, True, True)),
]

# cell type -> {held-out animal: number of experts}
n_experts_types = {cell_type: dict(zip(_animal_ids, experts))
                   for cell_type, experts, _soft in _table_s3}

# cell type -> {held-out animal: use soft weights?}
soft_weights_types = {cell_type: dict(zip(_animal_ids, soft))
                      for cell_type, _experts, soft in _table_s3}

del _animal_ids, _table_s3

In [6]:
# Wall-clock timestamp taken before the CV loop; the final cell prints the
# elapsed seconds relative to it.
import time

start_time = time.time()

In [7]:
# Outer leave-one-animal-out cross-validation (mode == 'CV'), or a single fit on
# all animals (mode == 'train'). For each held-out animal this cell:
#   1. looks up that fold's hyperparameters from Table S3,
#   2. reads the training / test slices and builds neighborhood graphs,
#   3. builds and transforms the regulator features and the responses,
#   4. fits MESSI (optionally after a grid search) and pickles the model,
#   5. in 'CV' mode, predicts the held-out animal and saves the predictions.


def _column_index(frame):
    """Map each column name of ``frame`` to its positional index."""
    return {name: position for position, name in enumerate(frame.columns)}


def _load_datasets(meta_per_dataset):
    """Read every ``(animal_id, bregma)`` slice in ``meta_per_dataset`` via ``read_data``.

    Returns one list per slice:
    ``[hp_np, hp_columns, hp_cor_np, hp_cor_columns, hp_genes_np, hp_genes_columns]``
    where each ``*_columns`` dict maps column name -> positional index.
    ``hp_np`` and ``hp_columns`` are ``None`` when ``read_data`` returns no ``hp``.
    """
    datasets = []
    for animal_id, slice_bregma in meta_per_dataset:
        hp, hp_cor, hp_genes = read_data(input_path, slice_bregma, animal_id, genes_list, genes_list_u)
        if hp is not None:
            hp_columns = _column_index(hp)
            hp_np = hp.to_numpy()
        else:
            hp_columns = None
            hp_np = None
        datasets.append([hp_np, hp_columns, hp_cor.to_numpy(), _column_index(hp_cor),
                         hp_genes.to_numpy(), _column_index(hp_genes)])
        # only the numpy copies are kept; drop the frames right away to limit memory
        del hp, hp_cor, hp_genes
    return datasets


# In 'train' mode there is no held-out animal, so the loop runs exactly once.
test_animals = [''] if mode == 'train' else all_animals

for test_animal in test_animals:
    # specify number of experts and other hyperparameters according to Table S3;
    # the table is indexed by the held-out animal, so it has no entry for 'train' mode
    try:
        n_classes_1 = n_experts_types[current_cell_type][test_animal]
        soft_weights = soft_weights_types[current_cell_type][test_animal]
    except KeyError as err:
        raise ValueError(
            f"No Table S3 hyperparameters for cell type {current_cell_type!r} "
            f"and held-out animal {test_animal!r}"
        ) from err
    partial_fit_expert = soft_weights

    if grid_search:
        condition = f"response_{top_k_response}_l1_{n_classes_0}_l2_grid_search"
    else:
        condition = f"response_{top_k_response}_l1_{n_classes_0}_l2_{n_classes_1}"

    samples_test = np.array([test_animal])
    # sorted: set iteration order is not guaranteed
    samples_train = np.array(sorted(set(all_animals) - {test_animal}))
    print(f"Test set is {samples_test}")
    print(f"Training set is {samples_train}")

    # ------ read data ------
    idx_train, idx_test, idx_train_in_general, \
    idx_test_in_general, idx_train_in_dataset, \
    idx_test_in_dataset, meta_per_dataset_train, \
    meta_per_dataset_test = find_idx_for_train_test(samples_train, samples_test,
                                                    meta_all, meta_all_columns,
                                                    data_type, current_cell_type, get_idx_per_dataset,
                                                    return_in_general=False, bregma=None)

    # TBD: the current approach uses a lot memory;
    datasets_train = _load_datasets(meta_per_dataset_train)
    datasets_test = _load_datasets(meta_per_dataset_test)

    # ------ pre-processing -------

    # construct neighborhood graph
    if data_type == 'merfish_RNA_seq':
        neighbors_train = None
        neighbors_test = None
    else:
        dis_filter = 100 if data_type == 'merfish' else 1e9
        neighbors_train = get_neighbors_datasets(datasets_train, "Del", k=10, dis_filter=dis_filter,
                                                 include_self=False)
        neighbors_test = get_neighbors_datasets(datasets_test, "Del", k=10, dis_filter=dis_filter,
                                                include_self=False)

    # set parameters for different feature types (rebuilt per fold in case the
    # helpers below store state in these dicts)
    lig_n = {'name': 'regulators_neighbor', 'helper': preprocess_X_neighbor_per_cell,
             'feature_list_type': 'regulator_neighbor', 'per_cell': True, 'baseline': False,
             'standardize': True, 'log': True, 'poly': False}
    rec_s = {'name': 'regulators_self', 'helper': preprocess_X_self_per_cell,
             'feature_list_type': 'regulator_self', 'per_cell': True, 'baseline': False,
             'standardize': True, 'log': True, 'poly': False}
    lig_s = {'name': 'regulators_neighbor_self', 'helper': preprocess_X_self_per_cell,
             'feature_list_type': 'regulator_neighbor', 'per_cell': True, 'baseline': False,
             'standardize': True, 'log': True, 'poly': False}
    type_n = {'name': 'neighbor_type', 'helper': preprocess_X_neighbor_type_per_dataset,
              'feature_list_type': None, 'per_cell': False, 'baseline': False,
              'standardize': True, 'log': False, 'poly': False}
    base_s = {'name': 'baseline', 'helper': preprocess_X_baseline_per_dataset, 'feature_list_type': None,
              'per_cell': False, 'baseline': True, 'standardize': True, 'log': False, 'poly': False}

    if data_type == 'merfish_cell_line':
        feature_types = [lig_n, rec_s, base_s, lig_s]
    else:
        feature_types = [lig_n, rec_s, type_n, base_s, lig_s]

    # untransformed features
    X_trains, X_tests, regulator_list_neighbor, regulator_list_self = prepare_features(data_type, datasets_train,
                                                                                       datasets_test,
                                                                                       meta_per_dataset_train,
                                                                                       meta_per_dataset_test,
                                                                                       idx_train, idx_test,
                                                                                       idx_train_in_dataset,
                                                                                       idx_test_in_dataset,
                                                                                       neighbors_train,
                                                                                       neighbors_test,
                                                                                       feature_types,
                                                                                       regulator_list_prior,
                                                                                       top_k_regulator,
                                                                                       genes_list_u, l_u, r_u,
                                                                                       cell_types_dict)
    total_regulators = regulator_list_neighbor + regulator_list_self

    log_response = True  # take log transformation of the response genes
    Y_train, Y_train_true, Y_test, Y_test_true, response_list = prepare_responses(data_type, datasets_train,
                                                                                  datasets_test,
                                                                                  idx_train_in_general,
                                                                                  idx_test_in_general,
                                                                                  idx_train_in_dataset,
                                                                                  idx_test_in_dataset,
                                                                                  neighbors_train,
                                                                                  neighbors_test,
                                                                                  response_type, log_response,
                                                                                  response_list_prior,
                                                                                  top_k_response,
                                                                                  genes_list_u, l_u, r_u)
    if grid_search:
        # keep untransformed copies: the grid search re-runs the transformation per split
        X_trains_gs = copy.deepcopy(X_trains)
        Y_train_gs = copy.copy(Y_train)

    # transform features
    transform_features(X_trains, X_tests, feature_types)
    print(f"Minimum value after transformation can below 0: {np.min(X_trains['regulators_self'])}")

    # combine different type of features
    if data_type == 'merfish':
        num_coordinates = 3
    elif data_type == 'starmap' or data_type == 'merfish_cell_line':
        num_coordinates = 2
    else:
        num_coordinates = None

    # Always (re)assign X_train / X_test here, or fail loudly: otherwise a later
    # fold would silently reuse the previous fold's matrices.
    if np.ndim(X_trains['baseline']) <= 1:
        raise ValueError(f"Baseline training features are not 2-D for held-out animal {test_animal!r}")
    X_train, X_train_clf_1, X_train_clf_2 = combine_features(X_trains, preprocess, num_coordinates)
    if np.ndim(X_tests['baseline']) > 1:
        X_test, X_test_clf_1, X_test_clf_2 = combine_features(X_tests, preprocess, num_coordinates)
    elif mode == 'CV':
        raise ValueError(f"Baseline test features are not 2-D for held-out animal {test_animal!r}")

    print(f"Dimension of X train is: {X_train.shape}")
    if mode == 'CV':
        print(f"Dimension of X test is: {X_test.shape}")

    # ------ modeling by MESSI ------
    for _i in range(n_replicates):

        # ------ set parameters ------
        model_name_gates = 'logistic'
        model_name_experts = 'mrots'

        # soft_weights / partial_fit_expert keep the Table S3 values looked up
        # above for this fold; they must not be overridden here.

        # specify default parameters for MESSI
        model_params = {'n_classes_0': n_classes_0,
                        'n_classes_1': n_classes_1,
                        'model_name_gates': model_name_gates,
                        'model_name_experts': model_name_experts,
                        'num_responses': Y_train.shape[1],
                        'soft_weights': soft_weights,
                        'partial_fit_expert': partial_fit_expert,
                        'n_epochs': n_epochs,
                        'tolerance': 3}

        print(f"Model parameters for training is {model_params}")

        # set up directory for saving the model
        sub_condition = f"{condition}_{model_name_gates}_{model_name_experts}"
        sub_dir = f"{data_type}/{behavior_no_space}/{sex}/{current_cell_type_no_space}/{preprocess}/{sub_condition}"
        current_dir = os.path.join(output_path, sub_dir)
        os.makedirs(current_dir, exist_ok=True)

        print(f"Model and validation results (if applicable) saved to: {current_dir}")

        if mode == 'CV':
            suffix = f"_{test_animal}_{_i}"
        else:
            suffix = f"_{_i}"

        if grid_search:
            # prepare input meta data
            if data_type == 'merfish':
                meta_per_part = [tuple(i) for i in meta_per_dataset_train]
                meta_idx = meta2idx(idx_train_in_dataset, meta_per_part)
            else:
                meta_per_part, meta_idx = combineParts(samples_train, datasets_train, idx_train_in_dataset)

            # prepare parameters list to be tuned
            if data_type == 'merfish_cell_line':
                current_cell_type_data = 'U-2_OS'
            elif data_type == 'starmap':
                current_cell_type_data = 'STARmap_excitatory'
            else:
                current_cell_type_data = current_cell_type

            # NOTE(review): `search_range_dict` is not defined in this notebook;
            # presumably it comes from `from messi.data_processing import *` — confirm.
            params = {'n_classes_1': list(search_range_dict[current_cell_type_data]), 'soft_weights': [True, False],
                      'partial_fit_expert': [True, False]}

            keys, values = zip(*params.items())
            params_list = [dict(zip(keys, v)) for v in itertools.product(*values)]
            # only keep combinations where soft_weights and partial_fit_expert agree
            new_params_list = [d for d in params_list if d['soft_weights'] == d['partial_fit_expert']]
            ratio = 0.2

            # initialize with default values
            model_params_val = model_params.copy()
            model_params_val['n_epochs'] = 1
            model_params_val['tolerance'] = 0
            print(f"Default model parameters for validation {model_params_val}")
            model = hme(**model_params_val)

            gs = gridSearch(params, model, ratio, n_sets, new_params_list)
            gs.generate_val_sets(samples_train, meta_per_part)
            gs.runCV(X_trains_gs, Y_train_gs, meta_per_part, meta_idx, feature_types, data_type,
                     preprocess)
            gs.get_best_parameter()
            print(f"Best params from grid search: {gs.best_params}")

            # modify the parameter setting
            model_params.update(gs.best_params)

            print(f"Model parameters for training after grid search {model_params}")

            with open(os.path.join(current_dir, f"validation_results{suffix}.pickle"), 'wb') as fh:
                pickle.dump(gs, fh)

        # ------ initialize the sample assignments ------
        if grid_search and 'n_classes_1' in params:
            n_clusters = gs.best_params['n_classes_1']
        else:
            n_clusters = n_classes_1

        clustering = AgglomerativeClustering(n_clusters=n_clusters).fit(Y_train)
        model_params['init_labels_1'] = [clustering.labels_]

        # ------ construct MESSI  ------
        model = hme(**model_params)

        # train
        model.train(X_train, X_train_clf_1, X_train_clf_2, Y_train)

        # save the model
        with open(os.path.join(current_dir, f"hme_model{suffix}.pickle"), 'wb') as fh:
            pickle.dump(model, fh)

        # predict the left-out animal
        if mode == 'CV':
            Y_hat_final = model.predict(X_test, X_test_clf_1, X_test_clf_2)

            mae = abs(Y_test - Y_hat_final).mean(axis=1).mean()
            print(f"Mean absolute value for {test_animal} is {mae}")

            # np.save appends the ".npy" extension
            np.save(os.path.join(current_dir, f"test_predictions{suffix}"), Y_hat_final)

    print("\n")

Test set is [1]
Training set is [2 3 4]
Preprocess for Excitatory of merfish
19855
11757
Reading file: merfish_animal2_bregma-029.csv
The dimensions of the sample is: (6091, 170)
Reading file: merfish_animal2_bregma-024.csv
The dimensions of the sample is: (6263, 170)
Reading file: merfish_animal2_bregma-019.csv
The dimensions of the sample is: (6328, 170)
Reading file: merfish_animal2_bregma-014.csv
The dimensions of the sample is: (6135, 170)
Reading file: merfish_animal2_bregma-009.csv
The dimensions of the sample is: (5819, 170)
Reading file: merfish_animal2_bregma-004.csv
The dimensions of the sample is: (5693, 170)
Reading file: merfish_animal2_bregma001.csv
The dimensions of the sample is: (5677, 170)
Reading file: merfish_animal2_bregma006.csv
The dimensions of the sample is: (5381, 170)
Reading file: merfish_animal2_bregma011.csv
The dimensions of the sample is: (5727, 170)
Reading file: merfish_animal2_bregma016.csv
The dimensions of the sample is: (5284, 170)
Reading file: m



------ epoch 1 ------
Best score: 1000000000.0
Current score: 398113.96558739466
level 1 gate error: 0
level 2 gate error: [0.889138701791625]
experts error: [68364.70856602101, 28370.010229524174, 84819.25195742327, 92086.04970744454, 52496.39037529202, 28540.357866236576, 29015.12230263258, 14421.185444118602]




------ epoch 2 ------
Best score: 398113.96558739466
Current score: 387696.0987812748
level 1 gate error: 0
level 2 gate error: [0.7973014146111466]
experts error: [69071.93423886992, 28303.259058824362, 76035.3743100921, 89122.59760322857, 51907.68451985893, 28190.58895373807, 29840.434112959378, 15223.428682288917]




------ epoch 3 ------
Best score: 387696.0987812748
Current score: 385763.60853260686
level 1 gate error: 0
level 2 gate error: [0.7818721560291331]
experts error: [69531.13997132106, 28240.418177270527, 73967.83984658352, 87679.04651461403, 52357.903500187356, 27988.425211733553, 30669.53888462708, 15328.514554113763]




------ epoch 4 ------
Best score: 385763.60853260686
Current score: 385184.49423688964
level 1 gate error: 0
level 2 gate error: [0.7834034995311641]
experts error: [69174.62106973416, 28196.8802859436, 74034.5952174537, 86258.93351397198, 52618.12800709106, 27880.647223149666, 31666.814373003992, 15353.091143041936]




------ epoch 5 ------
Best score: 385184.49423688964
Current score: 384907.9242673249
level 1 gate error: 0
level 2 gate error: [0.7858786065189721]
experts error: [68531.1533905271, 28215.213968092823, 74185.71975333459, 85243.08736382046, 52914.974033596765, 27864.197442874298, 32544.468560869554, 15408.323875602719]




------ epoch 6 ------
Best score: 384907.9242673249
Current score: 384883.2166177118
level 1 gate error: 0
level 2 gate error: [0.7886346507204876]
experts error: [68154.24843853255, 28196.9277540968, 74275.97357912305, 84540.73438353554, 53254.611166592666, 27815.647979988113, 33137.52807986062, 15506.75660133175]




------ epoch 7 ------
Best score: 384883.2166177118
Current score: 384842.9527265786
level 1 gate error: 0
level 2 gate error: [0.7909722527953909]
experts error: [67798.58283748344, 28194.628443353427, 74469.33397413735, 83893.82936994566, 53547.9094524392, 27819.998306401863, 33601.462192195955, 15516.417178368905]




------ epoch 8 ------
Best score: 384842.9527265786
Current score: 384707.0297724219
level 1 gate error: 0
level 2 gate error: [0.7908764058195177]
experts error: [67411.0976686229, 28140.31348567091, 74517.85767947193, 83568.9872844636, 53802.50063105402, 27827.355040105398, 33927.63665985375, 15510.490446773605]




------ epoch 9 ------
Best score: 384707.0297724219
Current score: 384730.5969108038
level 1 gate error: 0
level 2 gate error: [0.791679104451329]
experts error: [67107.39087240495, 28098.997966154773, 74626.882300188, 83431.21769311874, 53863.76270743416, 27840.04474046723, 34231.37745473079, 15530.131497200704]




------ epoch 10 ------
Best score: 384707.0297724219
Current score: 384769.4932614904
level 1 gate error: 0
level 2 gate error: [0.7920378356023607]
experts error: [66773.88505890981, 28065.735059776693, 74677.94879537258, 83369.93296342503, 54011.15847901439, 27806.044196671744, 34499.422950804124, 15564.573719680493]




------ epoch 11 ------
Best score: 384707.0297724219
Current score: 384913.7278781428
level 1 gate error: 0
level 2 gate error: [0.7928185012547598]
experts error: [66390.57484371692, 28032.642177029833, 74915.99424189405, 83315.46788581662, 54137.50134867483, 27780.283669500284, 34772.40530896592, 15568.065584043083]




------ epoch 12 ------
Best score: 384707.0297724219
Current score: 384977.4029096831
level 1 gate error: 0
level 2 gate error: [0.7942539946227496]
experts error: [66119.44021761093, 28005.05619500701, 74959.88186471086, 83382.33620623522, 54190.52880209242, 27757.347194214286, 34991.432373390766, 15570.58580242704]
12 epochs in total
Mean absolute value for 1 is 0.3801548360510801


Test set is [2]
Training set is [1 3 4]
Preprocess for Excitatory of merfish
21185
10427
Reading file: merfish_animal1_bregma-029.csv
The dimensions of the sample is: (6509, 170)
Reading file: merfish_animal1_bregma-024.csv
The dimensions of the sample is: (6412, 170)
Reading file: merfish_animal1_bregma-019.csv
The dimensions of the sample is: (6507, 170)
Reading file: merfish_animal1_bregma-014.csv
The dimensions of the sample is: (6605, 170)
Reading file: merfish_animal1_bregma-009.csv
The dimensions of the sample is: (6185, 170)
Reading file: merfish_animal1_bregma-004.csv
The dimensions of the sample



------ epoch 1 ------
Best score: 1000000000.0
Current score: 433268.63783688034
level 1 gate error: 0
level 2 gate error: [0.8444451442341709]
experts error: [153076.16389737924, 35224.26849074252, 64192.8950462014, 55838.81470278713, 29184.803618567028, 21820.192606418863, 27218.149410870825, 46712.50561876917]




------ epoch 2 ------
Best score: 433268.63783688034
Current score: 421307.2629487564
level 1 gate error: 0
level 2 gate error: [0.7690817837425469]
experts error: [132623.50805560863, 35649.957588542195, 63225.377129314154, 58761.80712515598, 31264.421628481978, 23100.86076613644, 28208.020305915812, 48472.541267817505]




------ epoch 3 ------
Best score: 421307.2629487564
Current score: 419326.5061366189
level 1 gate error: 0
level 2 gate error: [0.7576210903309818]
experts error: [123840.57187455604, 35868.237500444884, 63610.958542409535, 61125.9687072003, 32464.292133865336, 23754.144289450847, 28608.331231282882, 50053.24423631869]




------ epoch 4 ------
Best score: 419326.5061366189
Current score: 418825.64093973354
level 1 gate error: 0
level 2 gate error: [0.7603582980754624]
experts error: [119286.11368085514, 35744.23342706559, 64026.51343247252, 62423.04548996606, 33366.181438314496, 24465.02274075038, 28824.70496826498, 50689.06540374627]




------ epoch 5 ------
Best score: 418825.64093973354
Current score: 418590.6800627944
level 1 gate error: 0
level 2 gate error: [0.7633463517985869]
experts error: [116092.70382884306, 35793.28326307423, 64363.12345085609, 63238.06732329718, 33902.28537604808, 25100.087543461228, 28937.655575483463, 51162.71035537927]




------ epoch 6 ------
Best score: 418590.6800627944
Current score: 418432.63941748824
level 1 gate error: 0
level 2 gate error: [0.7702357195114894]
experts error: [113738.1552514383, 35889.17869656372, 64425.12348898893, 63541.25894294855, 34452.18215442081, 25734.50972837296, 29012.789737222876, 51638.67118181251]




------ epoch 7 ------
Best score: 418432.63941748824
Current score: 418532.61320045503
level 1 gate error: 0
level 2 gate error: [0.7740398129322456]
experts error: [111925.47499275848, 35954.81011202108, 64613.27823573772, 63736.285032822234, 34825.49617740709, 26420.222796109792, 29066.461285921152, 51989.81052786457]




------ epoch 9 ------
Best score: 418432.63941748824
Current score: 418577.29669997224
level 1 gate error: 0
level 2 gate error: [0.7801936364962676]
experts error: [108917.3493887152, 35873.137035293796, 65395.907027440604, 64502.84742902916, 35254.469835917094, 27154.688698042206, 29131.038555099738, 52347.07853679797]




------ epoch 10 ------
Best score: 418432.63941748824
Current score: 418584.99880938727
level 1 gate error: 0
level 2 gate error: [0.7824807712167612]
experts error: [107537.03415462117, 35858.625951736816, 65960.60523651661, 64878.002822843795, 35430.76147405717, 27392.194092691116, 29162.99234309801, 52364.00025305141]
10 epochs in total
Mean absolute value for 2 is 0.3762879954560655


Test set is [3]
Training set is [1 2 4]
Preprocess for Excitatory of merfish
27440
4172
Reading file: merfish_animal1_bregma-029.csv
The dimensions of the sample is: (6509, 170)
Reading file: merfish_animal1_bregma-024.csv
The dimensions of the sample is: (6412, 170)
Reading file: merfish_animal1_bregma-019.csv
The dimensions of the sample is: (6507, 170)
Reading file: merfish_animal1_bregma-014.csv
The dimensions of the sample is: (6605, 170)
Reading file: merfish_animal1_bregma-009.csv
The dimensions of the sample is: (6185, 170)
Reading file: merfish_animal1_bregma-004.csv
The dimensions of the sam



------ epoch 1 ------
Best score: 1000000000.0
Current score: 556914.7869875293
level 1 gate error: 0
level 2 gate error: [0.9138363265671118]
experts error: [173333.72168397263, 87619.52947735989, 36606.40629741332, 56862.24500580864, 86958.39923525132, 26844.333999098108, 34144.42056192167, 16870.98970379069, 24028.34555618476, 13645.481630401642]




------ epoch 2 ------
Best score: 556914.7869875293
Current score: 541094.7653278662
level 1 gate error: 0
level 2 gate error: [0.8185227235229762]
experts error: [147812.35189475704, 87988.98920075395, 37962.95609807239, 59502.92016074653, 90118.37913234682, 28320.387008121743, 32849.21561757954, 17981.17224956197, 24978.845641003645, 13578.729802198937]




------ epoch 3 ------
Best score: 541094.7653278662
Current score: 539221.2094135807
level 1 gate error: 0
level 2 gate error: [0.8104469755680833]
experts error: [137753.25407581657, 90324.21555039886, 38713.388100240445, 61036.37036813282, 92044.64141839521, 28903.524352101846, 32459.49281245776, 18567.294262222462, 25789.360355557757, 13628.857671281392]




------ epoch 4 ------
Best score: 539221.2094135807
Current score: 538633.6595007979
level 1 gate error: 0
level 2 gate error: [0.8163740192346035]
experts error: [132233.6299687469, 92353.72413274758, 38987.66815259826, 61905.51531100739, 92335.24763479896, 29461.744790347562, 32276.140141652526, 18925.56826996658, 26523.703976297737, 13629.90074861505]




------ epoch 5 ------
Best score: 538633.6595007979
Current score: 538451.2535342845
level 1 gate error: 0
level 2 gate error: [0.8270927741220225]
experts error: [128764.36463204624, 94085.60202327227, 39200.279016762004, 62509.2811968731, 92269.84138990691, 29817.855422045584, 32183.849235751564, 19110.33298196736, 26847.985661807605, 13661.034881077801]




------ epoch 6 ------
Best score: 538451.2535342845
Current score: 538563.9511449456
level 1 gate error: 0
level 2 gate error: [0.831713546367055]
experts error: [126141.98330862995, 95499.78858089981, 39350.5659015004, 63138.37594791852, 92093.2864545318, 30098.931129187255, 32162.230571418302, 19247.56242968085, 27169.6319764019, 13660.763131230353]




------ epoch 7 ------
Best score: 538451.2535342845
Current score: 538758.8411739601
level 1 gate error: 0
level 2 gate error: [0.8354101705925271]
experts error: [124034.15540404955, 96195.05165021378, 39433.64489894446, 63802.38461726555, 92282.51977933239, 30257.515199972804, 32182.72957848789, 19384.91747857635, 27534.87067975701, 13650.216477189766]




------ epoch 8 ------
Best score: 538451.2535342845
Current score: 538958.5058065988
level 1 gate error: 0
level 2 gate error: [0.8407315452127053]
experts error: [122888.30685442413, 96272.64929555471, 39453.351098831925, 64404.78285610316, 92392.48075750849, 30426.866056470488, 32170.954999254252, 19481.246726976373, 27873.434741192785, 13593.591688737228]




------ epoch 9 ------
Best score: 538451.2535342845
Current score: 539161.6314755442
level 1 gate error: 0
level 2 gate error: [0.8464077911321729]
experts error: [121713.0528949112, 96267.18874854235, 39470.201762892764, 65077.58007652011, 92710.22085690103, 30507.578584692834, 32154.638572424738, 19532.573095434673, 28142.061337037943, 13585.68913839541]
9 epochs in total
Mean absolute value for 3 is 0.3728791641632251


Test set is [4]
Training set is [1 2 3]
Preprocess for Excitatory of merfish
26356
5256
Reading file: merfish_animal1_bregma-029.csv
The dimensions of the sample is: (6509, 170)
Reading file: merfish_animal1_bregma-024.csv
The dimensions of the sample is: (6412, 170)
Reading file: merfish_animal1_bregma-019.csv
The dimensions of the sample is: (6507, 170)
Reading file: merfish_animal1_bregma-014.csv
The dimensions of the sample is: (6605, 170)
Reading file: merfish_animal1_bregma-009.csv
The dimensions of the sample is: (6185, 170)
Reading file: merfish_animal1_bregm



------ epoch 1 ------
Best score: 1000000000.0
Current score: 527131.5704880377
level 1 gate error: 0
level 2 gate error: [1.0127799212707973]
experts error: [88620.20505848448, 68319.97879332537, 93866.3060698133, 32993.316969545005, 43697.91630738474, 20940.935003805906, 32920.46009468532, 28848.783564258832, 96272.31181374431, 20650.344033069177]




------ epoch 2 ------
Best score: 527131.5704880377
Current score: 513534.7259049733
level 1 gate error: 0
level 2 gate error: [0.8864554418841262]
experts error: [82746.58925714699, 66962.1760021498, 88011.54597392083, 34093.735557329266, 44769.462294723235, 22247.572597901497, 32348.904743238578, 28890.590775307195, 92321.14717170558, 21142.1150761084]




------ epoch 3 ------
Best score: 513534.7259049733
Current score: 511917.2680749441
level 1 gate error: 0
level 2 gate error: [0.8669622384972885]
experts error: [79861.82473826053, 66494.3838077263, 88994.62481805256, 34579.91575279971, 45661.86970551746, 23527.277824194156, 32024.01001241451, 28829.11292305217, 90643.99261324486, 21299.38891744331]




------ epoch 4 ------
Best score: 511917.2680749441
Current score: 511679.7734559846
level 1 gate error: 0
level 2 gate error: [0.8659591843502601]
experts error: [77336.31912003698, 66388.66763037503, 90891.69570547313, 34975.977381425546, 46045.81849889626, 24819.91318947829, 31830.990305453077, 28779.096121452425, 89365.30418661253, 21245.125357597004]




------ epoch 5 ------
Best score: 511679.7734559846
Current score: 511868.4554201901
level 1 gate error: 0
level 2 gate error: [0.8697065167115485]
experts error: [74975.29340378723, 66449.28792448208, 92741.37063844483, 35339.96320961204, 46261.20980697403, 25923.51662013384, 31697.94973515497, 28740.865898363478, 88551.38491356857, 21186.743563152257]




------ epoch 6 ------
Best score: 511679.7734559846
Current score: 511908.7536563823
level 1 gate error: 0
level 2 gate error: [0.8746814315009017]
experts error: [73051.25061292296, 66542.12943672585, 94093.87643831733, 35590.054968872544, 46490.620840499316, 26931.537501472023, 31584.2443533004, 28896.269877191356, 87656.79105258304, 21071.103893066007]




------ epoch 7 ------
Best score: 511679.7734559846
Current score: 511938.2739003165
level 1 gate error: 0
level 2 gate error: [0.8780618015262682]
experts error: [71503.30865704858, 66831.08125330489, 94877.79702582685, 35697.351658003565, 46733.54783151238, 27754.272986213942, 31506.715569207656, 29056.18084796001, 86957.82385455356, 21019.316154883487]




------ epoch 8 ------
Best score: 511679.7734559846
Current score: 512075.5941493139
level 1 gate error: 0
level 2 gate error: [0.8803953398476155]
experts error: [69983.79194259127, 67271.92002338613, 95642.53633135122, 35751.81016138761, 46770.155348418775, 28438.158252652203, 31441.78815789247, 29276.797027156998, 86546.45688604312, 20951.299623094244]
8 epochs in total
Mean absolute value for 4 is 0.37518774373610836




In [8]:
# Report the total wall-clock time since `start_time` was recorded.
print(f"--- {time.time() - start_time} seconds ---")

--- 5313.278757810593 seconds ---
