# In [None]:
"""
This file is expected to be run after persistent homology features have been created in
a (default) folder structure like:

/complete_data
    /sample_1
        /clique_dictionaries
        /gen
        /edge_vals
    /sample_2
        ...
    ...

This file takes a MATLAB array file sub_i_sess_j.mat representing the data obtained from subject i during session j.
The MATLAB array file is expected to have the target-class field 'fMRI_labels_selected_thresh', which is used for classification.

In this analysis, we only use the 1st homology dimension.
"""

import os
from os.path import join
import pickle as pickle
import pickle as pk
from sklearn.neural_network import MLPClassifier
from scipy.io import loadmat
from compute_PH_features import get_birth_and_death
import Holes as ho
import numpy as np
from copy import deepcopy
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
import time

def load_feature_data(sample_data_pths, targets=None):
    """
    Map each sample directory to its generator files and its target class.

    Parameters
    ----------
    sample_data_pths : iterable of str
        Sample directories whose final path component looks like
        ``<prefix>_<index>[_...]``; each must contain a ``gen`` sub-directory.
    targets : indexable, optional
        Target classes, consumed positionally in the same order as
        ``sample_data_pths``. When omitted, the module-level ``targets``
        (assigned by the script entry point) is used, which keeps the
        original one-argument call working.

    Returns
    -------
    dict
        ``{sample_idx: {'raw_pth': [paths in gen/], 'target_cls': target}}``

    Raises
    ------
    NameError
        If ``targets`` is omitted and no module-level ``targets`` exists.
    """
    if targets is None:
        # Backward compatibility: the original implementation silently read
        # the global created in the ``__main__`` section.
        try:
            targets = globals()['targets']
        except KeyError:
            raise NameError("name 'targets' is not defined")

    feature_data = {}
    # NOTE(review): the target is picked by position in ``sample_data_pths``,
    # not by ``sample_idx``. If the caller builds the list with os.listdir
    # (arbitrary order), labels may not line up with samples -- confirm.
    for position, dp in enumerate(sample_data_pths):
        # basename(normpath(...)) tolerates trailing separators and
        # platform-specific separators, unlike splitting on '/'.
        dir_name = os.path.basename(os.path.normpath(dp))
        sample_idx = int(dir_name.split('_')[1])
        gen_pth = join(dp, 'gen')
        feature_data[sample_idx] = {
            'raw_pth': [join(gen_pth, x) for x in os.listdir(gen_pth)],
            'target_cls': targets[position],
        }
    return feature_data

def check_keys(feature_data, targets):
    """
    Verify that every trial index has a generated feature set.

    For each ``i`` in ``range(targets.shape[0])`` the key ``i`` must be
    present in ``feature_data``.

    Parameters
    ----------
    feature_data : dict
        Mapping keyed by integer sample index.
    targets : array-like with a ``shape`` attribute
        Its first dimension gives the number of trials.

    Raises
    ------
    RuntimeError
        If any trial index is missing from ``feature_data``.
    """
    for trial_idx in range(targets.shape[0]):
        # Explicit test instead of ``assert``: assertions are removed when
        # Python runs with -O, which would silently disable this check.
        if trial_idx not in feature_data:
            raise RuntimeError("There isn't the same number of generated feature "
                               "sets as there are trials. Please investigate!")

def get_barcode(gen_lst, title=''):
    """
    Forward ``gen_lst`` and ``title`` to ``ho.barcode_creator`` and return
    whatever it returns.
    """
    barcode = ho.barcode_creator(gen_lst, title=title)
    return barcode


def summary(gen_lst, homology_group=1):
    """
    Return the result of calling ``summary()`` on the entry of ``gen_lst``
    stored at index ``homology_group`` (default 1).
    """
    selected_group = gen_lst[homology_group]
    return selected_group.summary()


def get_features_and_target_data(feature_data_pths, hom_dim=1):
    """
    Load the pickled generators of every sample and collect its target.

    Parameters
    ----------
    feature_data_pths : dict
        ``{sample: {'raw_pth': [pickle paths], 'target_cls': target}}``.
    hom_dim : int, optional
        Index selected from each unpickled object (default 1, i.e. the
        H1 group).

    Returns
    -------
    (dict, dict)
        ``feature_data_array`` maps each sample to the list of
        ``generators[hom_dim]`` values, one per file in ``raw_pth``;
        ``target_data_array`` maps each sample to its ``target_cls``.
    """
    feature_data_array = {}
    target_data_array = {}

    for sample, sample_info in feature_data_pths.items():
        window_generators = []
        for gen_pth in sample_info['raw_pth']:
            # NOTE(security): unpickling can execute arbitrary code; only
            # load generator files produced by a trusted pipeline.
            # Pickle data must be read in binary mode, and the context
            # manager guarantees the handle is closed.
            with open(gen_pth, 'rb') as gen_file:
                generators = pickle.load(gen_file)
            window_generators.append(generators[hom_dim])
        feature_data_array[sample] = window_generators
        target_data_array[sample] = sample_info['target_cls']

    return feature_data_array, target_data_array

# def get_features_and_target_data(feature_data_pths, hom_dim=1):
#     """
#     Load the training and target data for a particular homology dimension
#     """
#     feature_data_array = {}
#     target_data_array = {}
    
#     def get_persistence_lifetime(cycle):
#         return float(cycle.end) - float(cycle.start)
    
#     def pad(lst, ln):
#         if len(lst) < ln:
#             diff = ln - len(lst)
#             pad = [0]*diff
#             lst += pad
#         return lst

#     for sample in feature_data_pths.keys():
#         sample_gens_per_time_window = feature_data_pths[sample]['raw_pth']
#         sample_feature_data = []  # [births, deaths, num_time_windows]
#         for gen_pth in sample_gens_per_time_window:
#             time_window_birth_deaths = []
            
#             feature_data = pickle.load(open(gen_pth, 'r'))
#             sample_cycles = feature_data[hom_dim]
            
#             for cycle in sample_cycles:
#                 cycle_lifetime = get_persistence_lifetime(cycle)
#                 sample_feature_data.append((cycle, cycle_lifetime))  # H1 group
                
#         sorted_feature_data = [[float(i[0].start), float(i[0].end)] for i in sorted(sample_feature_data, key=lambda x:x[1])] # sort by ascending lifetime
#         feature_data_array[sample] = sorted_feature_data
#         target_data = feature_data_pths[sample]['target_cls']
#         target_data_array[sample] = target_data 


#     for sample in feature_data_pths.keys():
        

#     return feature_data_array, target_data_array


def show_persistence_across_trial(gens):
    """
    Call ``get_barcode`` once for every element of ``gens``.

    A trial is made of several windows, each with its own generators, so
    walking through them in order shows how persistence changes within
    the trial. Nothing is returned.
    """
    for window_generators in gens:
        get_barcode(window_generators)


def get_summaries_across_trial(gens):
    """
    Call ``summary`` (default homology group) on every element of ``gens``.

    The values returned by ``summary`` are discarded; nothing is returned.
    """
    for window_generators in gens:
        summary(window_generators)

def get_feature_data_with_targets(feature_data_array):
    """
    Convert per-window generators into birth/death lists for every sample.

    Despite its name, this function does not return targets.

    Parameters
    ----------
    feature_data_array : dict
        ``{sample: [generators per time window, ...]}``.

    Returns
    -------
    (list, int)
        ``train_x`` holds, per sample, a list of ``[b, d]`` pairs (one per
        window) as returned by ``get_birth_and_death``; ``max_bi`` is the
        largest ``len(b)`` seen across all samples and windows.
    """
    train_x = []
    max_bi = 0

    for sample in feature_data_array:
        # Per sample: one [b, d] pair for each time window.
        sample_data = []
        for window_generators in feature_data_array[sample]:
            b, d = get_birth_and_death(window_generators)
            # NOTE(review): only len(b) is tracked; presumably len(d) always
            # equals len(b) -- confirm against get_birth_and_death.
            max_bi = max(max_bi, len(b))
            sample_data.append([b, d])
        train_x.append(sample_data)

    # Single-argument parenthesized print gives identical output on
    # Python 2 and Python 3.
    print("max b_i is: {0}".format(max_bi))
    print("Number of samples is: {0}".format(len(train_x)))
    return train_x, max_bi

def get_sample_data(train_x, max_bi, n_windows=14):
    """
    Build a zero-padded, flattened feature matrix from birth/death lists.

    Each sample becomes one row of length ``n_windows * 2 * max_bi`` laid
    out as ``(window, birth-or-death, value)`` in C order. Missing windows
    and missing values are left as zeros.

    Parameters
    ----------
    train_x : list
        Per sample, a list of ``[b, d]`` pairs (one per time window).
    max_bi : int
        Padded length of every birth and death vector.
    n_windows : int, optional
        Number of time-window slots per sample (default 14).

    Returns
    -------
    numpy.ndarray
        Array of shape ``(len(train_x), n_windows * 2 * max_bi)``.

    Raises
    ------
    IndexError
        If a sample contains more than ``n_windows`` windows.
    ValueError
        If a birth or death vector is longer than ``max_bi``.
    """
    n_samples = len(train_x)
    padded = np.zeros((n_samples, n_windows, 2, max_bi))

    for sample_idx, sample in enumerate(train_x):
        for window_idx, (b, d) in enumerate(sample):
            # Writing straight into the per-sample slot keeps every window.
            # Assigning to the full max_bi-wide slice rejects over-long
            # vectors; shorter ones are right-padded with zeros here.
            padded[sample_idx, window_idx, 0, :] = np.concatenate(
                [np.asarray(b, dtype=float).ravel(), np.zeros(max(max_bi - len(b), 0))])
            padded[sample_idx, window_idx, 1, :] = np.concatenate(
                [np.asarray(d, dtype=float).ravel(), np.zeros(max(max_bi - len(d), 0))])

    # Explicit width (instead of -1) so that an empty train_x still reshapes.
    return padded.reshape(n_samples, n_windows * 2 * max_bi)
        
if __name__ == '__main__':
    # Script entry point: load the generated features and labels for
    # subject 1 / session 1, build the padded feature matrix, then train
    # and evaluate an MLP classifier on a held-out split.

    # ---- Path parameters ----
    data_pth = 'data/completed_data/'
    repo_dir = os.getcwd()
    subject_1_session_1_pth = join(repo_dir, 'data/raw_data/sub1_ses1.mat')

    # ---- Load data ----
    sample_data_pths = [join(data_pth, x) for x in os.listdir(data_pth)]
    mat = loadmat(subject_1_session_1_pth)  # load mat-file
    # Kept as a module-level name: load_feature_data reads ``targets``.
    targets = mat['fMRI_labels_selected_thresh']
    feature_data_pths = load_feature_data(sample_data_pths)
    check_keys(feature_data_pths, targets)
    feature_data_array, target_data_array = get_features_and_target_data(feature_data_pths)

    train_x, max_bi = get_feature_data_with_targets(feature_data_array)

    # ---- Get train_y ----
    # Look targets up by sample key, in the same iteration order used to
    # build train_x, rather than slicing dict.values() (Python-2-only and
    # dependent on two dicts happening to share an ordering).
    train_y = np.array([target_data_array[sample] for sample in feature_data_array])

    # ---- Get padded_train_x ----
    padded_train_x = get_sample_data(train_x, max_bi)
    print("padded train_x shape is: {0}".format(padded_train_x.shape))
    print("train_y shape is: {0}".format(train_y.shape))

    # ---- Perform prediction ----
    mlp = MLPClassifier(hidden_layer_sizes=(500, 250, 100, 50), max_iter=100000, )
    X_train, X_test, y_train, y_test = train_test_split(
        padded_train_x, train_y, test_size=0.33, random_state=42)
    t1 = time.time()
    mlp.fit(X_train, y_train)
    t2 = time.time()
    print(t2 - t1)  # training wall-clock time in seconds
    preds = mlp.predict(X_test)
    print("preds shape is:{0}".format(preds.shape))
    print(classification_report(y_test, preds))