In [2]:
# Import libraries
import os
import warnings
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import geopandas as gpd
import gudhi
from tqdm import tqdm
from persim import PersistenceImager
import invr
import matplotlib as mpl

from pysal.lib import weights
from pysal.lib import weights


from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import connected_components

from scipy.linalg import solve
from scipy.sparse.linalg import spsolve
import numpy as np

import scipy as sp

# Ignore FutureWarnings
warnings.simplefilter(action='ignore', category=FutureWarning)

# Matplotlib default settings
mpl.rcParams.update(mpl.rcParamsDefault)

In [3]:
def generate_adjacent_counties(dataframe, variable_name):
    """Find, for every region in ``dataframe``, the other regions it intersects.

    A spatial self-join on the ``intersects`` predicate pairs up regions; pairs
    that match a region with itself (equal ``sortedID``) are discarded and the
    remaining partners are collected into one list per region.

    Parameters
    ----------
    dataframe : geopandas.GeoDataFrame
        Must contain a ``sortedID`` column and a ``geometry`` column.
    variable_name : str
        Not used by this function.

    Returns
    -------
    tuple
        ``(adjacencies_list, merged_df, county_list)``: the list of neighbour
        lists, a GeoDataFrame joining each region's neighbour list (columns
        ``county`` / ``adjacent``) back onto ``dataframe``, and the list of
        ``sortedID`` values that have at least one neighbour.
    """
    pairs = gpd.sjoin(dataframe, dataframe, predicate='intersects', how='left')
    # Every geometry intersects itself, so drop the self-matches.
    pairs = pairs[pairs['sortedID_left'] != pairs['sortedID_right']]

    neighbours = (
        pairs.groupby('sortedID_left')['sortedID_right']
        .apply(list)
        .reset_index()
        .rename(columns={'sortedID_left': 'county', 'sortedID_right': 'adjacent'})
    )

    merged_df = gpd.GeoDataFrame(
        pd.merge(neighbours, dataframe, left_on='county', right_on='sortedID', how='left'),
        geometry='geometry',
    )
    return neighbours['adjacent'].tolist(), merged_df, neighbours['county'].tolist()

In [4]:
def form_simplicial_complex(adjacent_county_list, county_list):
    """Build the simplicial complex for the given adjacency information.

    Delegates to ``invr.incremental_vr`` with an empty starting complex and a
    maximum dimension of 3, and returns whatever that call produces.
    """
    return invr.incremental_vr([], adjacent_county_list, 3, county_list)

In [5]:
def create_variable_folders(base_path, variables):
    """Ensure ``base_path/<variable>`` exists for every name in ``variables``.

    Directories that already exist are left untouched; a confirmation line is
    printed once all of them have been handled.
    """
    folder_paths = (os.path.join(base_path, name) for name in variables)
    for folder_path in folder_paths:
        os.makedirs(folder_path, exist_ok=True)
    print('Done creating folders for each variable')


In [79]:
def _component_generalized_variance(Q_block):
    """Return the generalized variance of one connected component.

    ``Q_block`` is the dense graph-Laplacian (precision) block of the
    component. The block is singular (the all-ones vector is in its null
    space), so a small multiple of the identity is added before inverting and
    the resulting covariance is then corrected for the sum-to-zero constraint.
    The result is the geometric mean of the constrained marginal variances.
    """
    n = Q_block.shape[0]

    # Jitter scaled by the largest diagonal entry keeps the solve well posed.
    jitter = Q_block.diagonal().max() * np.sqrt(np.finfo(np.float64).eps)
    Q_perturbed = sp.sparse.csc_array(Q_block + jitter * np.eye(n))

    # Inverse of the (perturbed) precision matrix is the covariance.
    sigma = spsolve(Q_perturbed, sp.sparse.identity(n, format='csc'))
    if sp.sparse.issparse(sigma):
        # spsolve returns a sparse result for a sparse right-hand side;
        # work with a plain ndarray from here on.
        sigma = sigma.toarray()

    ones = np.ones(n)   # basis of Null(Q) for a connected component
    W = sigma @ ones    # Sigma * B
    Q_inv = sigma - np.outer(W * (1.0 / (ones @ W)), W)

    # Arithmetic mean in log-space is the geometric mean after exp.
    return float(np.exp(np.mean(np.log(np.diag(Q_inv)))))


def generate_generalized_variance(simplices,data_frame, variable_name):
    """Compute a generalized variance per connected component of the lattice.

    Parameters
    ----------
    simplices : iterable of list
        Simplices given as lists of ``sortedID`` vertices. Only simplices with
        two or three vertices are used.
    data_frame : geopandas.GeoDataFrame
        Regions with a ``sortedID`` column; Queen contiguity is built from the
        rows whose ``sortedID`` occurs in a used simplex.
    variable_name : str
        Not used; kept so existing callers keep working.

    Returns
    -------
    tuple of dict
        ``(generalized_variance_dic, component_census, component_simplices)``,
        each keyed by the component label ``0 .. n_components - 1``:
        the generalized variance of the component, the ``sortedID`` values of
        its regions, and the edges/triangles belonging to it. All three are
        empty when ``simplices`` holds no edge or triangle.
    """
    selected_census = {
        vertex
        for simplex in simplices
        if len(simplex) in (2, 3)
        for vertex in simplex
    }
    if not selected_census:
        # No edges/triangles means there is no lattice to build.
        return {}, {}, {}

    filtered_census_df = data_frame.loc[data_frame["sortedID"].isin(list(selected_census))]

    # Lattice stored in a geo-table.
    wq = weights.contiguity.Queen.from_dataframe(filtered_census_df)
    adjacency, _ = wq.full()

    # Graph Laplacian: -1 for every neighbour pair, degree on the diagonal.
    Q = -np.asarray(adjacency).astype(int)
    np.fill_diagonal(Q, np.abs(Q.sum(axis=1)))

    # Row/column ``p`` of Q is taken to describe the p-th row of the filtered
    # frame; its sortedID is what the simplices refer to.
    vertex_ids = filtered_census_df["sortedID"].tolist()

    n_components, labels = connected_components(csgraph=csr_matrix(Q), directed=False, return_labels=True)

    print(f"Number of connected components: {n_components}")

    # Regions per component, recorded by sortedID (not by matrix position) so
    # they can be matched against simplex vertices even when some regions
    # were filtered out above.
    component_census = {k: [] for k in range(n_components)}
    vertex_component = {}
    for vertex, label in zip(vertex_ids, labels):
        component_census[int(label)].append(vertex)
        vertex_component[vertex] = int(label)

    # A simplex lies entirely inside one component, so its first vertex
    # decides where it goes. This is done for a single component as well.
    component_simplices = {k: [] for k in range(n_components)}
    for simplex in simplices:
        if len(simplex) in (2, 3):
            component = vertex_component.get(simplex[0])
            if component is not None:
                component_simplices[component].append(simplex)

    generalized_variance_dic = {}
    for k in range(n_components):
        members = np.flatnonzero(labels == k)
        if members.size == 1:
            # Isolated region: no neighbours to derive a variance from.
            # NOTE(review): the value 1 was flagged "CHECK THIS VALUE" by the
            # original author - confirm.
            generalized_variance_dic[k] = 1
        else:
            generalized_variance_dic[k] = _component_generalized_variance(Q[np.ix_(members, members)])

    return generalized_variance_dic, component_census, component_simplices

In [7]:
def process_state(state, selected_variables, selected_variables_with_censusinfo, base_path, PERSISTENCE_IMAGE_PARAMS, INFINITY):
    """Compute per-component generalized variances for every county of a state.

    Reads ``<data_path>/<state>/<state>.shp`` (``data_path`` is a notebook-level
    global), then for each county (``STCNTY``) and each variable ranks the
    county's regions by the variable, builds the adjacency-based simplicial
    complex and prints the generalized variance of each connected component.

    Parameters
    ----------
    state : str
        State abbreviation; used as folder and file name of the shapefile.
    selected_variables : list of str
        Variable columns to process.
    selected_variables_with_censusinfo : list of str
        Columns to keep from the shapefile; must include ``FIPS``, ``STCNTY``,
        the selected variables and ``geometry``.
    base_path, PERSISTENCE_IMAGE_PARAMS, INFINITY
        Currently unused; reserved for the persistence-image step, which is
        not called yet.

    Returns
    -------
    None
        Results are only printed.
    """
    svi_od_path = os.path.join(data_path, state, state + '.shp')
    svi_od = gpd.read_file(svi_od_path)

    svi_od_filtered_state = svi_od[selected_variables_with_censusinfo].reset_index(drop=True)

    for county_stcnty in svi_od_filtered_state['STCNTY'].unique():
        # Restrict to the regions of the current county.
        county_svi_df = svi_od_filtered_state[svi_od_filtered_state['STCNTY'] == county_stcnty]

        for variable_name in selected_variables:
            # sortedID is the rank of each region by the variable's value.
            df_one_variable = county_svi_df[['STCNTY', 'FIPS', variable_name, 'geometry']].sort_values(by=variable_name)
            df_one_variable = df_one_variable.assign(sortedID=range(len(df_one_variable)))
            df_one_variable = gpd.GeoDataFrame(df_one_variable, geometry='geometry')
            # Re-label (not re-project) the CRS, as before, via the supported API.
            # NOTE(review): confirm EPSG:3395 matches the coordinates in the
            # shapefile; this only overrides the label.
            df_one_variable = df_one_variable.set_crs("EPSG:3395", allow_override=True)

            _, adjacent_counties_df, _ = generate_adjacent_counties(df_one_variable, variable_name)
            county_list = adjacent_counties_df['county'].tolist()
            adjacent_counties_dict = dict(zip(county_list, adjacent_counties_df['adjacent']))
            simplices = form_simplicial_complex(adjacent_counties_dict, county_list)

            print(f'length of simplices: {len(simplices)}')

            if len(simplices) == 0:
                print(f'No simplices for {variable_name} in {county_stcnty}')
                print(df_one_variable)
                continue

            print(f'State: {state}')
            print(f'County: {county_stcnty}')
            print(f'Variable: {variable_name}')

            # The helper returns three per-component dicts; only the
            # variances are reported here.
            generalized_variance_dic, component_census, component_simplices = generate_generalized_variance(
                simplices=simplices, data_frame=df_one_variable, variable_name=variable_name
            )

            print(f'Generalized Variance: {generalized_variance_dic}\n')

            # Generate persistence images based on the generalized variance
            # generate_persistence_images(simplices, df_one_variable, variable_name, county_stcnty, base_path, PERSISTENCE_IMAGE_PARAMS, generalized_variance)

In [8]:
# length of simplices: 79
# State: FL
# County: 12087
# County: EP_POV
# Number of connected components: 3

In [9]:
# Root folder of the per-state shapefiles. Set the SVI_DATA_PATH environment
# variable to run the notebook on another machine; the original location is
# kept as the default.
data_path = os.environ.get(
    'SVI_DATA_PATH',
    '/home/h6x/git_projects/ornl-svi-data-processing/processed_data/SVI/SVI2018_MIN_MAX_SCALED_MISSING_REMOVED',
)

# Variable columns analysed in this notebook.
selected_variables = [
    'EP_POV', 'EP_UNEMP', 'EP_NOHSDP', 'EP_UNINSUR', 'EP_AGE65', 'EP_AGE17', 'EP_DISABL',
    'EP_SNGPNT', 'EP_LIMENG', 'EP_MINRTY', 'EP_MUNIT', 'EP_MOBILE', 'EP_CROWD', 'EP_NOVEH', 'EP_GROUPQ',
]

In [10]:
selected_variables_with_censusinfo = ['FIPS', 'STCNTY'] + selected_variables + ['geometry']

In [11]:
state = 'FL'

In [12]:
svi_od_path = os.path.join(data_path, state, state + '.shp')
svi_od = gpd.read_file(svi_od_path)

In [13]:
county_stcnty = '12087'

In [14]:
svi_od_filtered_state = svi_od[selected_variables_with_censusinfo].reset_index(drop=True)

In [15]:
# Filter the dataframe to include only the current county
county_svi_df = svi_od_filtered_state[svi_od_filtered_state['STCNTY'] == county_stcnty]

In [16]:
county_svi_df

Unnamed: 0,FIPS,STCNTY,EP_POV,EP_UNEMP,EP_NOHSDP,EP_UNINSUR,EP_AGE65,EP_AGE17,EP_DISABL,EP_SNGPNT,EP_LIMENG,EP_MINRTY,EP_MUNIT,EP_MOBILE,EP_CROWD,EP_NOVEH,EP_GROUPQ,geometry
1328,12087970500,12087,0.17,0.0,0.094194,0.189,0.266298,0.288732,0.119423,0.056,0.041348,0.172,0.032,0.01,0.023,0.048,0.006006,"POLYGON ((-80.45661 25.08716, -80.45599 25.086..."
1329,12087970600,12087,0.13,0.054,0.068387,0.125,0.409945,0.097183,0.185039,0.018,0.009188,0.183,0.226,0.182,0.0,0.019,0.02002,"MULTIPOLYGON (((-80.50134 25.04190, -80.49775 ..."
1330,12087970800,12087,0.073,0.017,0.08129,0.067,0.344751,0.173239,0.185039,0.027,0.007657,0.147,0.196,0.006,0.017,0.034,0.035035,"POLYGON ((-80.59884 24.96615, -80.59444 24.967..."
1331,12087970900,12087,0.075,0.025,0.047742,0.139,0.316022,0.126761,0.111549,0.012,0.012251,0.058,0.144,0.05,0.02,0.024,0.013013,"MULTIPOLYGON (((-80.66103 24.90022, -80.66065 ..."
1332,12087971001,12087,0.031,0.009,0.046452,0.07,0.41105,0.128169,0.06168,0.033,0.0,0.304,0.264,0.002,0.018,0.018,0.005005,"POLYGON ((-81.03029 24.72284, -81.02916 24.723..."
1333,12087971100,12087,0.168,0.024,0.15871,0.187,0.220994,0.305634,0.161417,0.123,0.128637,0.517,0.082,0.105,0.064,0.042,0.0,"POLYGON ((-81.06663 24.72410, -81.06532 24.725..."
1334,12087971200,12087,0.083,0.053,0.091613,0.254,0.292818,0.123944,0.111549,0.05,0.212864,0.386,0.29,0.086,0.008,0.099,0.004004,"POLYGON ((-81.09250 24.70338, -81.08595 24.708..."
1335,12087971401,12087,0.168,0.029,0.110968,0.182,0.282873,0.204225,0.190289,0.041,0.02144,0.213,0.007,0.132,0.062,0.043,0.026026,"POLYGON ((-81.40365 24.63094, -81.40055 24.634..."
1336,12087971402,12087,0.073,0.024,0.087742,0.103,0.21768,0.235211,0.203412,0.083,0.053599,0.246,0.016,0.013,0.0,0.028,0.0,"MULTIPOLYGON (((-81.45474 24.80844, -81.44351 ..."
1337,12087971502,12087,0.078,0.0,0.098065,0.177,0.309392,0.167606,0.136483,0.061,0.027565,0.067,0.017,0.153,0.02,0.017,0.0,"MULTIPOLYGON (((-81.52348 24.76011, -81.51207 ..."


In [17]:
variable_name = 'EP_POV'

In [18]:
# Rank the county's regions by the chosen variable; sortedID is that rank.
df_one_variable = county_svi_df[['STCNTY', 'FIPS', variable_name, 'geometry']].sort_values(by=variable_name)
df_one_variable = df_one_variable.assign(sortedID=range(len(df_one_variable)))
df_one_variable = gpd.GeoDataFrame(df_one_variable, geometry='geometry')
# Re-label (not re-project) the CRS through the supported API instead of
# assigning to ``.crs``.
# NOTE(review): confirm EPSG:3395 matches the coordinates in the shapefile.
df_one_variable = df_one_variable.set_crs("EPSG:3395", allow_override=True)

# county_list comes straight from the helper; no need to rebuild it.
adjacencies_list, adjacent_counties_df, county_list = generate_adjacent_counties(df_one_variable, variable_name)
adjacent_counties_dict = dict(zip(adjacent_counties_df['county'], adjacent_counties_df['adjacent']))
simplices = form_simplicial_complex(adjacent_counties_dict, county_list)

print(f'length of simplices: {len(simplices)}')

length of simplices: 79


In [19]:
simplices

[[0],
 [1],
 [2],
 [3],
 [4],
 [1, 4],
 [5],
 [6],
 [5, 6],
 [7],
 [8],
 [6, 8],
 [9],
 [1, 9],
 [4, 9],
 [1, 4, 9],
 [10],
 [7, 10],
 [11],
 [3, 11],
 [12],
 [13],
 [3, 13],
 [10, 13],
 [14],
 [4, 14],
 [15],
 [4, 15],
 [9, 15],
 [4, 9, 15],
 [14, 15],
 [4, 14, 15],
 [16],
 [0, 16],
 [17],
 [5, 17],
 [18],
 [11, 18],
 [14, 18],
 [19],
 [9, 19],
 [14, 19],
 [15, 19],
 [9, 15, 19],
 [14, 15, 19],
 [20],
 [0, 20],
 [16, 20],
 [0, 16, 20],
 [21],
 [7, 21],
 [10, 21],
 [7, 10, 21],
 [22],
 [2, 22],
 [12, 22],
 [23],
 [17, 23],
 [24],
 [2, 24],
 [25],
 [1, 25],
 [9, 25],
 [1, 9, 25],
 [19, 25],
 [9, 19, 25],
 [26],
 [4, 26],
 [14, 26],
 [4, 14, 26],
 [18, 26],
 [14, 18, 26],
 [27],
 [0, 27],
 [20, 27],
 [0, 20, 27],
 [23, 27],
 [28],
 [12, 28]]

In [20]:
# generalized variance code

In [21]:
# Vertices that take part in at least one edge or triangle, in first-seen
# order. dict.fromkeys de-duplicates in one pass, and no loop variable named
# ``set`` is left behind to shadow the builtin for the rest of the session.
selected_census = list(dict.fromkeys(
    vertex
    for simplex in simplices
    if len(simplex) in (2, 3)
    for vertex in simplex
))

In [22]:
filtered_census_df = df_one_variable.loc[df_one_variable["sortedID"].isin(selected_census)]

In [23]:
# lattice stored in a geo-table
wq = weights.contiguity.Queen.from_dataframe(filtered_census_df)
neighbors_q = wq.neighbors

 There are 3 disconnected components.
  W.__init__(self, neighbors, ids=ids, **kw)


In [24]:
# Dense Queen-contiguity matrix, negated so neighbours are marked with -1.
laplacian = pd.DataFrame(*wq.full()).astype(int).multiply(-1).to_numpy(copy=True)

# Diagonal = absolute row sum, i.e. the number of neighbours of each region.
np.fill_diagonal(laplacian, np.abs(laplacian.sum(axis=1)))

# Label rows and columns with the sortedID of the corresponding region.
vertex_labels = filtered_census_df["sortedID"].values
QTemp = pd.DataFrame(laplacian, index=vertex_labels, columns=vertex_labels)

# Plain numpy view of the same matrix for the linear algebra below.
Q = QTemp.to_numpy()

In [35]:
QTemp

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,19,20,21,22,23,24,25,26,27,28
0,3,0,0,0,0,0,0,0,0,0,...,0,-1,0,0,0,0,0,0,-1,0
1,0,3,0,0,-1,0,0,0,0,-1,...,0,0,0,0,0,0,-1,0,0,0
2,0,0,2,0,0,0,0,0,0,0,...,0,0,0,-1,0,-1,0,0,0,0
3,0,0,0,2,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
4,0,-1,0,0,5,0,0,0,0,-1,...,0,0,0,0,0,0,0,-1,0,0
5,0,0,0,0,0,2,-1,0,0,0,...,0,0,0,0,0,0,0,0,0,0
6,0,0,0,0,0,-1,2,0,-1,0,...,0,0,0,0,0,0,0,0,0,0
7,0,0,0,0,0,0,0,2,0,0,...,0,0,-1,0,0,0,0,0,0,0
8,0,0,0,0,0,0,-1,0,1,0,...,0,0,0,0,0,0,0,0,0,0
9,0,-1,0,0,-1,0,0,0,0,5,...,-1,0,0,0,0,0,-1,0,0,0


In [25]:
graph = csr_matrix(QTemp)
n_components, labels = connected_components(csgraph=graph, directed=False, return_labels=True)

print(f"Number of connected components: {n_components}")

Number of connected components: 3


In [26]:
labels

array([0, 1, 2, 1, 1, 0, 0, 1, 0, 1, 1, 1, 2, 1, 1, 1, 0, 0, 1, 1, 0, 1,
       2, 0, 2, 1, 1, 0, 2], dtype=int32)

In [36]:
len(labels)

29

In [57]:
# Step 6: Group vertices by their component labels
component_simplices = {i: [] for i in range(n_components)}  # Initialize a dictionary for simplices per component


In [49]:
component_census = {i: [] for i in range(n_components)}  # Initialize a dictionary for simplices per component


In [58]:
component_census

{0: [0, 5, 6, 8, 16, 17, 20, 23, 27],
 1: [1, 3, 4, 7, 9, 10, 11, 13, 14, 15, 18, 19, 21, 25, 26],
 2: [2, 12, 22, 24, 28]}

In [67]:
component_simplices

{0: [], 1: [], 2: []}

In [52]:
# Rebuild from scratch so re-running this cell cannot append duplicates, and
# record each region by its sortedID (QTemp's row label) rather than by its
# row position, since simplex vertices are sortedID values.
row_labels = QTemp.index.tolist()
component_census = {component: [] for component in range(n_components)}
for position, label in enumerate(labels):
    component_census[int(label)].append(row_labels[position])

In [53]:
component_census

{0: [0, 5, 6, 8, 16, 17, 20, 23, 27],
 1: [1, 3, 4, 7, 9, 10, 11, 13, 14, 15, 18, 19, 21, 25, 26],
 2: [2, 12, 22, 24, 28]}

In [68]:
# Rebuild from scratch so re-running this cell cannot append duplicates.
component_simplices = {component: [] for component in component_census}

# Reverse lookup: vertex -> component it belongs to.
vertex_to_component = {
    vertex: component
    for component, vertices in component_census.items()
    for vertex in vertices
}

for simplex in simplices:
    if len(simplex) in (2, 3):
        # All vertices of a simplex share a component; the first one decides.
        component = vertex_to_component.get(simplex[0])
        if component is not None:
            component_simplices[component].append(simplex)

    

In [69]:
component_simplices

{0: [[5, 6],
  [6, 8],
  [0, 16],
  [5, 17],
  [0, 20],
  [16, 20],
  [0, 16, 20],
  [17, 23],
  [0, 27],
  [20, 27],
  [0, 20, 27],
  [23, 27]],
 1: [[1, 4],
  [1, 9],
  [4, 9],
  [1, 4, 9],
  [7, 10],
  [3, 11],
  [3, 13],
  [10, 13],
  [4, 14],
  [4, 15],
  [9, 15],
  [4, 9, 15],
  [14, 15],
  [4, 14, 15],
  [11, 18],
  [14, 18],
  [9, 19],
  [14, 19],
  [15, 19],
  [9, 15, 19],
  [14, 15, 19],
  [7, 21],
  [10, 21],
  [7, 10, 21],
  [1, 25],
  [9, 25],
  [1, 9, 25],
  [19, 25],
  [9, 19, 25],
  [4, 26],
  [14, 26],
  [4, 14, 26],
  [18, 26],
  [14, 18, 26]],
 2: [[2, 22], [12, 22], [2, 24], [12, 28]]}

In [70]:
component_census[0]

[0, 5, 6, 8, 16, 17, 20, 23, 27]

In [None]:
for simplex in simplices:
    if len(simplex) == 1:
        # Placeholder: vertex insertion is disabled; ``pass`` keeps the cell
        # syntactically valid.
        # st.insert([simplex[0]], filtration=0.0)
        pass

In [71]:
component_simplices[0]

[[5, 6],
 [6, 8],
 [0, 16],
 [5, 17],
 [0, 20],
 [16, 20],
 [0, 16, 20],
 [17, 23],
 [0, 27],
 [20, 27],
 [0, 20, 27],
 [23, 27]]

In [74]:
# Show, for every edge of component 0, its last vertex and that vertex's
# variable value (the filtration value the edge would get).
edges_of_component_0 = [s for s in component_simplices[0] if len(s) == 2]
for simplex in edges_of_component_0:
    print(simplex)
    last_simplex = simplex[-1]
    print(last_simplex)
    row_mask = df_one_variable['sortedID'] == last_simplex
    filtration_value = df_one_variable.loc[row_mask, variable_name].values[0]
    print(filtration_value)


[5, 6]
6
0.073
[6, 8]
8
0.075
[0, 16]
16
0.123
[5, 17]
17
0.13
[0, 20]
20
0.137
[16, 20]
20
0.137
[17, 23]
23
0.17
[0, 27]
27
0.197
[20, 27]
27
0.197
[23, 27]
27
0.197


In [None]:
# Draft: filtration value of every edge = variable value of its last vertex.
for simplex in simplices:
    if len(simplex) != 2:
        continue
    last_simplex = simplex[-1]
    row_mask = df_one_variable['sortedID'] == last_simplex
    filtration_value = df_one_variable.loc[row_mask, variable_name].values[0]
    # st.insert(simplex, filtration=filtration_value)



In [None]:
# Draft: filtration value of every triangle = variable value of its last vertex.
for simplex in simplices:
    if len(simplex) != 3:
        continue
    last_simplex = simplex[-1]
    row_mask = df_one_variable['sortedID'] == last_simplex
    filtration_value = df_one_variable.loc[row_mask, variable_name].values[0]
    # st.insert(simplex, filtration=filtration_value)

In [75]:
component_census

{0: [0, 5, 6, 8, 16, 17, 20, 23, 27],
 1: [1, 3, 4, 7, 9, 10, 11, 13, 14, 15, 18, 19, 21, 25, 26],
 2: [2, 12, 22, 24, 28]}

In [76]:
len(component_census)

3

In [80]:
generalized_variance_dic, component_census, component_simplices = generate_generalized_variance(simplices=simplices,data_frame=df_one_variable, variable_name=variable_name)

Number of connected components: 3


 There are 3 disconnected components.
  W.__init__(self, neighbors, ids=ids, **kw)


In [81]:
generalized_variance_dic

{0: 1.0807413978231466, 1: 1.0463128075284294, 2: 0.7300371729906561}

In [82]:
component_census

{0: [0, 5, 6, 8, 16, 17, 20, 23, 27],
 1: [1, 3, 4, 7, 9, 10, 11, 13, 14, 15, 18, 19, 21, 25, 26],
 2: [2, 12, 22, 24, 28]}

In [83]:
component_simplices

{0: [[5, 6],
  [6, 8],
  [0, 16],
  [5, 17],
  [0, 20],
  [16, 20],
  [0, 16, 20],
  [17, 23],
  [0, 27],
  [20, 27],
  [0, 20, 27],
  [23, 27]],
 1: [[1, 4],
  [1, 9],
  [4, 9],
  [1, 4, 9],
  [7, 10],
  [3, 11],
  [3, 13],
  [10, 13],
  [4, 14],
  [4, 15],
  [9, 15],
  [4, 9, 15],
  [14, 15],
  [4, 14, 15],
  [11, 18],
  [14, 18],
  [9, 19],
  [14, 19],
  [15, 19],
  [9, 15, 19],
  [14, 15, 19],
  [7, 21],
  [10, 21],
  [7, 10, 21],
  [1, 25],
  [9, 25],
  [1, 9, 25],
  [19, 25],
  [9, 19, 25],
  [4, 26],
  [14, 26],
  [4, 14, 26],
  [18, 26],
  [14, 18, 26]],
 2: [[2, 22], [12, 22], [2, 24], [12, 28]]}

In [86]:
# get the keys of the dictionary component_census

In [92]:
for key in component_census.keys():
    print(key)

0
1
2


In [94]:
PERSISTENCE_IMAGE_PARAMS = {
        'pixel_size': 0.001,
        'birth_range': (0.0, 1.00),
        'pers_range': (0.0, 0.40),
        'kernel_params': {'sigma': 0.0003}
    }

In [100]:
# One persistence image per connected component.
# NOTE(review): a component without any finite interval contributes no image,
# so list positions match component keys only when every component has one.
per_images_per_subcomponent = []

# sortedID -> variable value, built once instead of scanning the frame for
# every simplex.
filtration_by_vertex = df_one_variable.set_index('sortedID')[variable_name].to_dict()

for key in component_census.keys():
    print(key)

    generalized_variance = generalized_variance_dic[key]
    simplices_sub = component_simplices[key]
    census_sub = component_census[key]
    print(f'Generalized Variance: {generalized_variance}')
    print(f'Simplices: {simplices_sub}')
    print(f'Census: {census_sub}')

    # Build the filtered complex of this component.
    st = gudhi.SimplexTree()
    st.set_dimension(2)

    # All vertices enter at filtration 0.
    for vertex in census_sub:
        print(vertex)
        st.insert([vertex], filtration=0.0)

    # Edges first, then triangles; each enters at the variable value of its
    # last vertex.
    for simplex_size in (2, 3):
        for simplex in simplices_sub:
            if len(simplex) == simplex_size:
                st.insert(simplex, filtration=filtration_by_vertex[simplex[-1]])

    st.compute_persistence()

    intervals_dim0 = st.persistence_intervals_in_dimension(0)
    intervals_dim1 = st.persistence_intervals_in_dimension(1)

    # Keep only finite intervals: dimension 1 first, then dimension 0.
    pdgms = [[birth, death] for birth, death in intervals_dim1 if death < np.inf]
    pdgms.extend([birth, death] for birth, death in intervals_dim0 if death < np.inf)

    if len(pdgms) > 0:
        pimgr = PersistenceImager(pixel_size=0.01)
        pimgr.fit(pdgms)

        # Use the fixed ranges and pixel size so that images from different
        # components have the same shape.
        pimgr.pixel_size = PERSISTENCE_IMAGE_PARAMS['pixel_size']
        pimgr.birth_range = PERSISTENCE_IMAGE_PARAMS['birth_range']
        pimgr.pers_range = PERSISTENCE_IMAGE_PARAMS['pers_range']
        pimgr.kernel_params = PERSISTENCE_IMAGE_PARAMS['kernel_params']

        pimgs = pimgr.transform(pdgms)
        pimgs = np.rot90(pimgs, k=1)
        per_images_per_subcomponent.append(pimgs)




0
Generalized Variance: 1.0807413978231466
Simplices: [[5, 6], [6, 8], [0, 16], [5, 17], [0, 20], [16, 20], [0, 16, 20], [17, 23], [0, 27], [20, 27], [0, 20, 27], [23, 27]]
Census: [0, 5, 6, 8, 16, 17, 20, 23, 27]
0
5
6
8
16
17
20
23
27
1
Generalized Variance: 1.0463128075284294
Simplices: [[1, 4], [1, 9], [4, 9], [1, 4, 9], [7, 10], [3, 11], [3, 13], [10, 13], [4, 14], [4, 15], [9, 15], [4, 9, 15], [14, 15], [4, 14, 15], [11, 18], [14, 18], [9, 19], [14, 19], [15, 19], [9, 15, 19], [14, 15, 19], [7, 21], [10, 21], [7, 10, 21], [1, 25], [9, 25], [1, 9, 25], [19, 25], [9, 19, 25], [4, 26], [14, 26], [4, 14, 26], [18, 26], [14, 18, 26]]
Census: [1, 3, 4, 7, 9, 10, 11, 13, 14, 15, 18, 19, 21, 25, 26]
1
3
4
7
9
10
11
13
14
15
18
19
21
25
26
2
Generalized Variance: 0.7300371729906561
Simplices: [[2, 22], [12, 22], [2, 24], [12, 28]]
Census: [2, 12, 22, 24, 28]
2
12
22
24
28


In [102]:
len(per_images_per_subcomponent)

3

In [103]:
per_images_per_subcomponent[0].shape

(400, 1000)

In [104]:
per_images_per_subcomponent[1].shape

(400, 1000)

In [105]:
per_images_per_subcomponent[2].shape

(400, 1000)

In [106]:
type(per_images_per_subcomponent[2])

numpy.ndarray

In [107]:
type(per_images_per_subcomponent)

list

In [108]:
# Stack the per-component images along axis 0 (numpy is already imported in
# the notebook's import cell).
combined_array = np.concatenate(per_images_per_subcomponent, axis=0)

In [109]:
combined_array.shape

(1200, 1000)

In [110]:
# Element-wise sum of all component images, however many components there are.
A = np.sum(per_images_per_subcomponent, axis=0)

In [112]:
A.shape

(400, 1000)

In [117]:
print(per_images_per_subcomponent[0][12][40])
print(per_images_per_subcomponent[1][12][40])
print(per_images_per_subcomponent[2][12][40])

0.0
0.0
0.0


In [118]:
B = np.sum(per_images_per_subcomponent, axis=0)

In [119]:
B.shape

(400, 1000)