# In [None]:
#%%

from pathlib import Path
import os
import sys

# Root directory that gets added to the import search path so that the
# project-local `src` package can be imported further down.
# The commented line is the alternative for running as a script; it is kept
# for reference because interactive sessions do not define __file__.
# project_path = Path(__file__).absolute().parents[2]
project_path = Path.cwd()

# Register the root only once: re-running this cell must not pile up
# duplicate entries in sys.path.
if project_path.as_posix() not in sys.path:
    sys.path.append(project_path.as_posix())
#%%
# Star import from the project utilities.  `load_config` and `get_path` are
# not defined in this file, so they are presumably provided by this import
# (TODO confirm against src/MyModule/utils).
from src.MyModule.utils import *

# Project configuration; it is loaded here but not used in the cells visible
# in this part of the file.
config = load_config()
# Location of the decoded synthetic data sets that are evaluated below.
input_path = get_path("data/processed/2_produce_data/synthetic_decoded/")

# Location for the evaluation results.
# NOTE(review): `ouput_path` looks like a typo for `output_path`; the name is
# left untouched because other cells may already refer to it.
ouput_path = get_path("data/processed/3_evaluate_data/")
#%%

# File locations of the synthetic data set (CSV) and of the original
# training data (pickle) it is compared with.
syn_path = input_path.joinpath("Synthetic_data_epsilon10000_50.csv")
original_path = get_path("data/processed/preprocess_1/train_ori_50.pkl")

# NOTE(review): `pd` is used here although `import pandas as pd` only appears
# in a later cell; presumably the star import from the project utilities
# already exposes it - confirm, or run the import cell first.
synthetic = pd.read_csv(syn_path)
# NOTE(review): unpickling can execute arbitrary code, so this file must come
# from a trusted source.
original_data = pd.read_pickle(original_path)


#%%
import pandas as pd
import numpy as np
from scipy.linalg import norm
from scipy.spatial.distance import euclidean



#%%
class Variable:

    '''
    Base container for the observations of one variable.

    properties : data type label (``dtype``) and the values as a list (``data``)
    functions : ``count_values`` and ``probability_distribution``; both are
                placeholders here and have to be implemented by subclasses
    '''

    def __init__(self,
                 dtype : str,
                 data : list):
        '''
        dtype : label describing the kind of variable (not validated here)
        data : the observations; a pandas Series is converted to a plain list,
               any other object is stored unchanged
        '''

        self.dtype = dtype
        self.data = data
        if isinstance(self.data, pd.Series):
            self.data = self.data.tolist()

    def __repr__(self):
        # A repr must be free of side effects and cheap: report the class,
        # the dtype label and the number of observations instead of printing
        # the complete data list.
        return "{}(dtype={!r}, n={})".format(
            type(self).__name__, self.dtype, len(self.data)
        )

    def count_values(self):
        '''
        count values in a list; must be implemented by subclasses

        raises NotImplementedError when called on the base class
        '''
        raise NotImplementedError

    def probability_distribution(self):
        '''
        creates probability distribution; must be implemented by subclasses

        raises NotImplementedError when called on the base class
        '''
        raise NotImplementedError

#%%
class ContinuousVariable(Variable):
    """
    Numeric variable that is summarised through a histogram.

    The constructor is inherited unchanged from ``Variable``.
    """

    def count_values(self, bins, range = None, density = False):
        """
        returns count information

        bins, range and density are passed straight to ``numpy.histogram``;
        only the per-bin values are returned, the bin edges are discarded.
        """
        counts, _ = np.histogram(self.data, bins, range = range, density = density)
        return counts

    def probability_distribution(self, bins, range = None):
        """
        returns distribution histogram series

        bins  : bin specification understood by ``numpy.histogram``
        range : optional (lower, upper) limits of the histogram

        The result is a pandas Series with one entry per bin holding that
        bin's count divided by the total count.  If no observation is counted
        at all the division is 0 / 0 and every entry is NaN.
        """
        # `range` has to be forwarded; otherwise the caller's limits are
        # silently ignored and the bins span the data's own min/max instead.
        counts = self.count_values(bins, range = range)
        total = counts.sum()
        return pd.Series(data = counts / total)
#%%

class CategoricalVariable(Variable):
    """
    Variable with a finite set of distinct values.

    ``categories`` holds the distinct values found in ``data`` as a set.
    """

    def __init__(self, dtype, data):
        super().__init__(dtype, data)
        self.categories = set(self.data)

    def count_values(self):
        """
        returns a list with the number of occurrences of every category,
        in the iteration order of ``self.categories``
        """
        # Local import keeps this cell self-contained.
        from collections import Counter

        # One pass over the data instead of one ``list.count`` scan per
        # category.
        tally = Counter(self.data)
        return [tally[category] for category in self.categories]

    def probability_distribution(self):
        """
        returns a pandas Series indexed by category holding the relative
        frequency of every category
        """
        counts = self.count_values()
        probabilities = np.array(counts) / sum(counts)
        # Hand pandas an explicit list: it iterates the set in the same order
        # as ``count_values`` did, and it avoids relying on pandas accepting
        # an unordered set as an index.
        return pd.Series(data = probabilities, index = list(self.categories))

#%%

# Wrap the BSPT_SEX_CD column of the original data as a categorical variable
# (presumably a sex code - confirm against the data dictionary).
ori_sex = CategoricalVariable('categorical', original_data.BSPT_SEX_CD)
#%%
# Bare expression: the distribution is computed but not stored, so it is only
# useful for inspection in an interactive session.
ori_sex.probability_distribution()



#%%
def hellinger_distance(var1 : "Variable",
                       var2 : "Variable",
                       bins = None):

    """
    calculates the hellinger distance of two variables

        H(P, Q) = (1 / sqrt(2)) * sqrt( sum_i (sqrt(p_i) - sqrt(q_i))^2 )

    The result is 0 for identical distributions and 1 for distributions
    without any common support.

    var1, var2 : variables of the same dtype
    dtype : categorical or continuous
        'categorical' - the distributions returned by
            ``probability_distribution()`` are aligned on their categories;
            a category missing in one variable gets probability 0 there.
        anything else is treated as continuous - both variables are binned
            with the same bin edges, derived from the pooled data.
    bins : bin specification for continuous variables (number of bins,
        sequence of edges or any other value accepted by
        ``numpy.histogram_bin_edges``); ignored for categorical variables

    returns the distance as a float
    raises AssertionError if the dtypes differ, ValueError if ``bins`` is
        missing for continuous variables or one of them has no observation
        inside the bins
    """
    assert var1.dtype == var2.dtype, "the two data type are not the same!"

    if var1.dtype == 'categorical' :
        var1prob, var2prob = var1.probability_distribution(), var2.probability_distribution()

        # Outer alignment on the category labels; a category that only one
        # variable has contributes with probability 0 for the other one.
        var1prob, var2prob = var1prob.align(var2prob, join = "outer", fill_value = 0.0)

        p = np.asarray(var1prob, dtype = float)
        q = np.asarray(var2prob, dtype = float)

    else :
        # dtype is continuous
        if bins is None:
            raise ValueError("bins is required to compare continuous variables")

        # Both histograms must share their bin edges, otherwise bin i of one
        # variable would cover a different interval than bin i of the other.
        pooled = np.concatenate([np.asarray(var1.data, dtype = float),
                                 np.asarray(var2.data, dtype = float)])
        edges = np.histogram_bin_edges(pooled, bins = bins)

        counts1 = np.asarray(var1.count_values(edges), dtype = float)
        counts2 = np.asarray(var2.count_values(edges), dtype = float)
        if counts1.sum() == 0 or counts2.sum() == 0:
            raise ValueError("both variables need at least one observation inside the bins")

        p = counts1 / counts1.sum()
        q = counts2 / counts2.sum()

    return float(np.sqrt(np.sum((np.sqrt(p) - np.sqrt(q)) ** 2)) / np.sqrt(2))
    
    
#%%
# Wrap the BSPT_SEX_CD column of the original and of the synthetic data as
# categorical variables so that their distributions can be compared.
ori_sex = CategoricalVariable('categorical', original_data.BSPT_SEX_CD)
syn_sex = CategoricalVariable('categorical', synthetic.BSPT_SEX_CD)


#%%
# Bare expression: the distance between the two variables is computed but not
# stored.
hellinger_distance(ori_sex, syn_sex)

#%%
        


#%% define hellinger distance