# In [None]:
"""Reusable helper functions: the same steps are needed for each of the 3 patients."""

import numpy as np
import pandas as pd
import sklearn as sk
import matplotlib.pyplot as plt
import seaborn as sns
import mne
import warnings  # Hide all warnings here
import random

# Keep MNE's logger quiet: only messages at WARNING level or above are shown.
mne.set_log_level('WARNING')
# Blanket filter: ignore every Python warning raised from here on.
warnings.simplefilter("ignore")
# Explicitly ignore DeprecationWarning as well (already covered by the blanket filter above).
warnings.filterwarnings("ignore",category=DeprecationWarning)


# Registry of per-patient data sets; presumably the dictionary handed to
# make_needed_dataframe() as ``master_dict`` (confirm against the calling code).
master_seizure_patient_dict = {}

# Columns that seizure_set_only() adds on top of the raw EEG channels: the seizure
# flag first, then the regional sums. make_needed_dataframe() treats element 0 of
# its ``columns`` argument as the flag and the rest as summary columns.
# NOTE: 'outter_right_hemi_sum' is misspelled on purpose - it must match the
# column label created in seizure_set_only().
added_columns = ['seizure = 1','outer_left_hemi_sum','inner_left_hemi_sum','outter_right_hemi_sum',
                 'inner_right_hemi_sum','center_line','left_temple_to_left_rear','temple_to_temple',
                 'right_temple_to_right_rear']

# Human-readable list of the analysis helpers defined in this file.
program_functions = ['compare_describe_methods(dfx,dfy)',' sns_line(data,x,y',
                     'equally_sized_data_sets_leading_to_seizure(X: str, y: str, patient: str)',
                     'compare_describes(**kwargs)','distribution_plot(data=None, kind="hist")']


def load_pandas_data(filename):
	"""Read a raw EEG recording (e.g. an .edf file) and return it as a dataframe.

	:param filename: path of the recording, handed to ``mne.io.read_raw``
	:return: the result of ``to_data_frame()`` on the preloaded recording
	"""
	# Re-apply the quiet logging / warning settings before touching the file.
	mne.set_log_level('WARNING')
	warnings.simplefilter("ignore")

	return mne.io.read_raw(filename,preload=True).to_data_frame()


def seizure_set_only(seizure_set,start,end):
	"""
	Prepare a master dataframe for the analysis.

	The frame is indexed by its 'time' column, rows whose time lies between ``start``
	and ``end`` (both inclusive, ``.loc`` label slicing) are flagged in a new
	'seizure = 1' column, and one summary column per electrode region is added as the
	row-wise sum of that region's channels. The columns are then reordered so each
	region's channels are followed by their sum, the index is replaced by an integer
	index, and every row after the last flagged row is dropped.

	Note: ``seizure_set`` itself is modified in place (re-indexed by 'time', new
	columns added); the returned frame is a separate, reordered and truncated object.

	:param seizure_set: an EEG with a seizure and then truncated after the seizure
	to reduce file size
	:param start: start time of the seizure as listed in the summary file from Children's hospital data set
	:param end: end time of the seizure as listed in the summary file from Children's hospital data set
	:return: a dataframe
	"""

	seizure_set.set_index('time',inplace=True)

	seizure_set['seizure = 1'] = 0
	seizure_set.loc[start: end,'seizure = 1'] = 1

	# Summary column -> the channels it totals, listed in final display order.
	# ('outter_right_hemi_sum' keeps its historical spelling: other code uses that label.)
	channel_groups = {
		# left hemisphere location totals
		'outer_left_hemi_sum': ['FP1-F7','F7-T7','T7-P7','P7-O1'],
		'inner_left_hemi_sum': ['FP1-F3','F3-C3','C3-P3','P3-O1'],
		# right hemisphere location totals
		'outter_right_hemi_sum': ['FP2-F8','F8-T8','T8-P8-0','P8-O2'],
		'inner_right_hemi_sum': ['FP2-F4','F4-C4','C4-P4','P4-O2'],
		# center totals
		'center_line': ['FZ-CZ','CZ-PZ'],
		# left temple to rear totals (its own channels, not the center-line ones)
		'left_temple_to_left_rear': ['P7-T7','T7-FT9'],
		# temple to temple totals
		'temple_to_temple': ['FT9-FT10'],
		# right temple to rear totals
		'right_temple_to_right_rear': ['FT10-T8','T8-P8-1'],
	}

	for summary_name,channels in channel_groups.items():
		seizure_set[summary_name] = seizure_set[channels].sum(axis=1)

	# reorder columns for better visual early analysis:
	# the seizure flag first, then each group's channels followed by the group's sum
	new_cols = ['seizure = 1']
	for summary_name,channels in channel_groups.items():
		new_cols.extend(channels)
		new_cols.append(summary_name)

	seizure_set = seizure_set.reindex(columns=new_cols)
	seizure_set.set_index(create_int_index(seizure_set),inplace=True)

	# The label of the last seizure row is used as a row position here, which
	# assumes the integer index mirrors the row positions.
	locations = seizure_location(seizure_set)
	seizure_set = seizure_set.iloc[: locations[1] + 1,:]

	return seizure_set


def is_df_clean(df):
	"""Check if a dataframe is clean, i.e. free of null and infinite values.

	Nulls are checked first, column by column in column order; infinities are only
	looked for (in numeric columns) when no nulls were found.

	:param df: dataframe to inspect
	:return: ``None`` when the dataframe is clean, otherwise a message naming the
	first offending column
	"""
	# Work positionally so duplicated column labels cannot break the check.
	has_null = df.isna().any(axis=0).to_numpy()
	for column,flagged in zip(df.columns,has_null):
		if flagged:
			return f"There are null values in {column}: "f"Note: this is expected if column is 'seizures = 1' as most of the values are zero, but 1 where the seizure exist."

	# np.isinf is only defined for numbers, so non-numeric columns are skipped.
	numeric = df.select_dtypes(include=[np.number])
	has_inf = np.isinf(numeric.to_numpy(dtype=float)).any(axis=0)
	for column,flagged in zip(numeric.columns,has_inf):
		if flagged:
			return f'There are infinite values in {column}'

	return None


def create_int_index(df):
	"""Create and return an index that is `int` based.

	:param df: any sized object (typically a dataframe) whose length sets the index size
	:return: numpy integer array ``0, 1, ..., len(df) - 1`` - one label per row, equal
	to the row's position
	"""
	# arange yields exactly len(df) consecutive integers; evenly spacing len(df)
	# points over [0, len(df)] instead would end on len(df) and skip len(df) - 1.
	return np.arange(len(df),dtype=int)


def seizure_location(df):
	"""Return the index locations of the seizure.

	:param df: dataframe holding a 'seizure = 1' column
	:return: tuple ``(first, last)`` of the index labels whose 'seizure = 1' value
	is not zero
	"""
	flagged_rows = df['seizure = 1'] != 0
	seizure_labels = df.index[flagged_rows].tolist()

	first_label = seizure_labels[0]
	last_label = seizure_labels[-1]
	return first_label,last_label


def make_needed_dataframe(master_df,columns: list,patient: str,master_dict) -> dict:
	"""
	Make a variety of X, y training and test sets for a variety of EDA and modeling needs.

	"X" sets hold the rows whose 'seizure = 1' value is not 1, "y" sets the rows whose
	value is not 0. Each comes in three flavours: all columns, without the added
	columns, and only the summary columns. Three views of the full, unfiltered data
	are stored as well.

	:param master_df: dataframe containing a 'seizure = 1' column and every column
	named in ``columns``
	:param columns: the added columns; element 0 is the seizure flag column, the
	remaining elements are the summary columns
	:param patient: key under which this patient's sets are stored in ``master_dict``
	:param master_dict: dictionary that is updated in place with the new entry
	:return: ``master_dict`` (the same object that was passed in)
	"""
	seizure_column = columns[0]
	summary_columns = columns[1:]

	non_seizure_rows = master_df[master_df['seizure = 1'] != 1]
	seizure_rows = master_df[master_df['seizure = 1'] != 0]

	local_patient_dict = {
		# full data, every column except the seizure flag
		'all data minus seizure column': master_df.drop(seizure_column,axis=1),
		# full data without any added columns
		'all data minus added columns': master_df.drop(columns,axis=1),

		# the Xs
		'X only summary columns': non_seizure_rows[summary_columns],
		'X all columns': non_seizure_rows,
		'X no added columns': non_seizure_rows.drop(columns,axis=1),

		# the ys
		'y all columns': seizure_rows,
		'y no added column': seizure_rows.drop(columns,axis=1),
		'y only summary columns': seizure_rows[summary_columns],

		# full data, only the summary columns; stored under its own key so it is
		# not replaced by the X-only summary view
		'all data only summary columns': master_df[summary_columns],
	}

	master_dict[patient] = local_patient_dict

	return master_dict


def compare_describe_methods(dfx,dfy):
	"""Put the ``describe()`` statistics of two of patient 1's data sets side by side.

	Both data sets are read from the module-level ``master_data_patient_dict`` under
	its 'Patient 1 initial set' entry, so that dictionary must exist before calling.
	NOTE(review): this file defines ``master_seizure_patient_dict`` - confirm which
	name is intended.

	:param dfx: key of the first data set; its statistics get the ``_X`` suffix
	:param dfy: key of the second data set; its statistics get the ``_y`` suffix
	:return: dataframe with one row per described column and the suffixed statistics
	as columns, sorted by column label
	"""
	stat_names = ['count','mean','std','min','25%','50%','75%','max']

	described = [
		master_data_patient_dict['Patient 1 initial set'][key].describe().transpose()
		for key in (dfx,dfy)
	]
	side_by_side = pd.concat(described,axis=1)

	# The fixed label list assumes describe() produced exactly the eight numeric statistics.
	side_by_side.columns = [f'{stat}_X' for stat in stat_names] + [f'{stat}_y' for stat in stat_names]
	return side_by_side.sort_index(axis=1)


def sns_line(data,x,y,fit_reg=True,n_boot=2000,seed=911,logx=False,truncate=True):
	"""Draw a seaborn ``regplot`` of ``y`` against ``x`` from ``data``.

	Every argument is forwarded unchanged to ``sns.regplot``; only the defaults for
	``n_boot`` and ``seed`` are chosen here. Returns whatever ``sns.regplot`` returns.
	"""
	regplot_options = dict(fit_reg=fit_reg,n_boot=n_boot,seed=seed,logx=logx,truncate=truncate)
	return sns.regplot(data=data,x=x,y=y,**regplot_options)


def equally_sized_data_sets_leading_to_seizure(X: pd.DataFrame,y: pd.DataFrame,patient: str):
	""" Create four separate data sets for analysis, each of are equal in length to the patient's seizure set and are chronologically ordered ending in the final set: the seizure set. The four sets are then:
	 X_set_normal --> X_set_pre_aura --> X_set_aura --> y_set_seizure

	The three X sets are the last ``3 * len(y)`` rows of ``X``, cut by row position
	into consecutive windows of ``len(y)`` rows each. The first/last index label of
	every set is printed as a quick visual check.

	:param X: pre-seizure data, ordered in time (pandas object supporting ``.iloc`` and ``.index``)
	:param y: seizure data; returned unchanged as 'y_set_seizure'
	:param patient: accepted for interface compatibility; not used by the function
	:return: dict with the keys 'X_set_normal', 'X_set_pre_aura', 'X_set_aura' and 'y_set_seizure'
	:raises IndexError: if ``y`` is empty or ``X`` has fewer than ``3 * len(y)`` rows
	"""

	y_length = len(y)
	X_original_length = len(X)
	required_rows = 3 * y_length

	# Without this guard the offsets below go negative and wrap around, which ends
	# in an opaque "index 0 is out of bounds" error further down.
	if y_length == 0 or X_original_length < required_rows:
		raise IndexError(
			f"X must hold at least {required_rows} rows (three windows of len(y)={y_length}) "
			f"and y must not be empty; got len(X)={X_original_length}"
		)

	X_set_aura_beginning_index = X_original_length - y_length
	X_set_pre_aura_beginning_index = X_set_aura_beginning_index - y_length
	X_set_normal_beginning = X_set_pre_aura_beginning_index - y_length

	# .iloc: the offsets are row positions computed from lengths, not index labels.
	segments = {
		'X_set_normal': X.iloc[X_set_normal_beginning:X_set_pre_aura_beginning_index],
		'X_set_pre_aura': X.iloc[X_set_pre_aura_beginning_index:X_set_aura_beginning_index],
		'X_set_aura': X.iloc[X_set_aura_beginning_index:],
		'y_set_seizure': y,
	}

	# (first label, last label) of each set, in chronological order
	test = tuple((part.index[0],part.index[-1]) for part in segments.values())

	print(test)
	return segments


def compare_describes(**kwargs):
	"""Put the ``describe()`` statistics of the four chronological data sets side by side.

	Each data set may be passed under its original keyword or under the key used in
	the dict returned by ``equally_sized_data_sets_leading_to_seizure`` (so that dict
	can be unpacked straight into this function):

	- ``X_norm`` or ``X_set_normal``
	- ``X_pre_aura`` or ``X_set_pre_aura``
	- ``X_aura`` or ``X_set_aura``
	- ``y_set_seizure``

	When both spellings are given, the first one listed wins.

	:return: dataframe with one row per described column and the statistics of the
	four sets as suffixed columns, sorted by column label
	:raises KeyError: if one of the four data sets is missing
	"""
	accepted_names = (
		('X_norm','X_set_normal'),
		('X_pre_aura','X_set_pre_aura'),
		('X_aura','X_set_aura'),
		('y_set_seizure',),
	)

	described = []
	for names in accepted_names:
		supplied = [name for name in names if name in kwargs]
		if not supplied:
			raise KeyError(f"missing data set: expected keyword {' or '.join(names)}")
		described.append(kwargs[supplied[0]].describe().transpose())

	sets_compared = pd.concat(described,axis=1)

	# The fixed label list assumes describe() produced exactly the eight numeric
	# statistics per set. Labels are kept verbatim (including 'cnt_a_norm').
	sets_compared.columns = ['cnt_a_norm','mean_a_norm','std_a_norm','min_a_norm','25%_a_norm','50%_a_norm',
	                         '75%_a_norm','max_a_norm','count_b_pre_aura','mean_b_pre_aura','std_b_pre_aura',
	                         'min_b_pre_aura','25%_b_pre_aura','50%_b_pre_aura','75%_b_pre_aura','max_b_pre_aura',
	                         'count_c_aura','mean_c_aura','std_c_aura','min_c_aura','25%_c_aura','50%_c_aura',
	                         '75%_c_aura','max_c_aura','count_seizure','mean_seizure','std_seizure','min_seizure',
	                         '25%_seizure','50%_seizure','75%_seizure','max_seizure']

	return sets_compared.sort_index(axis=1)


def distribution_plot(data=None,kind='hist',legend=True):
	"""Draw a seaborn distribution plot of ``data``.

	``data``, ``kind`` and ``legend`` are forwarded unchanged to ``sns.displot``,
	whose return value is handed back to the caller.
	"""
	plot = sns.displot(data=data,kind=kind,legend=legend)
	return plot
