In [2]:
import numpy as np
import pickle
import matplotlib.pyplot as plt
from castle.algorithms import PC
from castle.common.priori_knowledge import PrioriKnowledge
import os
import networkx as nx
import gravis as gv

import sys
if '../' not in sys.path:
    sys.path.append('../')
import utils.process as process
import utils.params as params

In [3]:
# Root folder of the input datasets (downloaded from Zenodo).
basefolder_data = '../datasets'
sites = ['tropical', 'boreal', 'temperate']

# Seed the global NumPy RNG so the same subsample is drawn on every run:
# 4000 distinct indices taken from range(10000), without replacement.
np.random.seed(23143)
chosen = np.random.choice(10000, size=4000, replace=False)

# 1. Monthly aggregation of integrated gradients and PC algorithm

1.1 Monthly aggregation of daily IG values:

In [None]:
# For every site: load the integrated gradients, aggregate them per month
# (mode='sum') and keep only the rows selected by `chosen`.
features = {}
for site in sites:
    # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
    # The context manager guarantees the file handle is closed after loading.
    with open(f"{basefolder_data}/{site}/integrated_gradients.p", "rb") as ig_file:
        integrated_gradients = pickle.load(ig_file)
    monthly_integrated_gradients = process.divide_into_months(integrated_gradients, mode='sum')
    features[site] = monthly_integrated_gradients[chosen]

1.2 PC algorithm with gcastle:

In [10]:
# Run the PC algorithm once per (site, alpha) and cache every causal matrix in
# `filename` as a nested dict: results[site][alpha] -> numpy array.
restore = False  # True: reuse entries already stored in `filename` instead of recomputing
filename = './cm_with_prior.p'

n_nodes = 51      # number of variables handed to PrioriKnowledge
block_length = 17  # modulus used to compare node positions in the prior below

if restore and os.path.exists(filename):
    # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
    with open(filename, 'rb') as cache_file:
        results = pickle.load(cache_file)
else:
    results = {}

alphas = [0.01]
for site in sites:
    # A restored cache may lack some sites; make sure every site has an entry.
    site_results = results.setdefault(site, {})
    for alpha in alphas:
        print(site, alpha)
        if restore and alpha in site_results:
            print('Restored')
            continue

        # Forbid the edge (i, j) whenever i % 17 > j % 17.
        # NOTE(review): presumably index % 17 is the month position inside one
        # driver's block, so this forbids edges pointing backwards in time --
        # confirm against params.month_names_with_drivers.
        priori_knowledge = PrioriKnowledge(n_nodes)
        for i in range(n_nodes):
            for j in range(n_nodes):
                if i % block_length > j % block_length:
                    priori_knowledge.add_forbidden_edge(i, j)

        pc = PC(alpha=alpha, priori_knowledge=priori_knowledge)
        pc.learn(features[site])
        site_results[alpha] = np.array(pc.causal_matrix)

        # Checkpoint after every new result; `with` makes sure the data is
        # flushed and the handle closed before the next (long) PC run starts.
        with open(filename, 'wb') as cache_file:
            pickle.dump(results, cache_file)
            

tropical 0.01
boreal 0.01
temperate 0.01


# 2. Correlation graphs layout with gravis

In [14]:
# Map node index (0..50) -> node label, used to relabel the graph nodes.
id_to_month_name_with_driver = dict(zip(np.arange(51), params.month_names_with_drivers))

# Node size per site: for each node label, the mean (over axis 0) of the
# per-month aggregation of the integrated gradients computed with mode='std'.
node_sizes = {}
for site in sites:
    # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
    # The context manager guarantees the file handle is closed after loading.
    with open(f"{basefolder_data}/{site}/integrated_gradients.p", "rb") as ig_file:
        integrand = pickle.load(ig_file)
    integrand_divided = process.divide_into_months(integrand, mode='std')
    node_sizes[site] = dict(
        zip(params.month_names_with_drivers, np.mean(integrand_divided, axis=0))
    )

In [12]:
def add_node_color(graph, communities):
    """Give every node a 'color' attribute according to its community.

    The k-th community in `communities` receives the k-th colour of a fixed
    seven-colour palette. The number of communities is printed first.

    Parameters
    ----------
    graph : object exposing ``graph.nodes[node]`` as a mutable mapping
    communities : sized iterable of iterables of node identifiers

    Raises
    ------
    NotImplementedError
        If there are more communities than colours in the palette.
    """
    palette = ('blue', 'orange', 'green', 'red', 'pink', 'purple', 'yellow')
    n_groups = len(communities)
    print('Number of communities for graph coloring: ', n_groups)

    if n_groups > len(palette):
        raise NotImplementedError('Too many communities: increase colors')

    # zip stops at the shorter input, so surplus palette entries are ignored.
    for members, shade in zip(communities, palette):
        for member in members:
            graph.nodes[member]['color'] = shade


def remove_prev_year_nodes(graph):
    """Drop, in place, every node whose string form ends with '_p'."""
    # Collect the labels first so the node view is not modified while iterating.
    stale_nodes = []
    for label in graph.nodes:
        if str(label).endswith('_p'):
            stale_nodes.append(label)
    graph.remove_nodes_from(stale_nodes)


def symmetrize_matrix(matrix):
    """Turn a square matrix into a symmetric 0/1 adjacency matrix, in place.

    After the call, ``matrix[i, j] == matrix[j, i] == 1`` whenever either of
    the two entries was non-zero on input, every other entry is 0, and the
    diagonal is 0. For input that is already 0/1 this matches the previous
    element-wise loop; for other values the result is now truly symmetric
    (previously a non-zero lower-triangle entry could keep its original value
    while its mirror was set to 1).

    Parameters
    ----------
    matrix : numpy.ndarray
        Square 2-D array. It is overwritten with the result.

    Returns
    -------
    numpy.ndarray
        The same array object that was passed in.

    Raises
    ------
    ValueError
        If `matrix` is not a square 2-D array.
    """
    if matrix.ndim != 2 or matrix.shape[0] != matrix.shape[1]:
        raise ValueError('symmetrize_matrix expects a square 2-D array')

    linked = matrix != 0
    linked = linked | linked.T       # an edge in either direction counts
    np.fill_diagonal(linked, False)  # no self-loops
    matrix[...] = linked             # write back in place, cast to matrix dtype
    return matrix

In [13]:
# Build one undirected graph per site from the stored PC result and export it
# as an interactive HTML page with gravis.
remove_prev_year = True   # drop the nodes whose label ends with '_p'
color = 'features'        # node colouring: 'communities', 'seasons' or 'features'
alpha = 0.01              # which PC run (key of data[site]) to plot

# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
with open('./cm_with_prior.p', 'rb') as cache_file:
    data = pickle.load(cache_file)

for site in sites:

    # symmetrize_matrix writes into its argument; pass a copy so the loaded
    # results in `data` stay untouched and the cell can be re-run safely.
    matrix = symmetrize_matrix(data[site][alpha].copy())

    graph = nx.from_numpy_array(matrix)
    graph = nx.relabel_nodes(graph, id_to_month_name_with_driver)
    nx.set_node_attributes(graph, node_sizes[site], 'size')

    if remove_prev_year:
        remove_prev_year_nodes(graph)

    # Assignment of node colors: pick the node groups, then colour them once.
    if color == 'communities':
        node_groups = nx.algorithms.community.greedy_modularity_communities(graph)
    elif color == 'seasons':
        if remove_prev_year:
            node_groups = params.seasons_groups_with_drivers_no_prev_year
        else:
            node_groups = params.seasons_groups_with_drivers
    elif color == 'features':
        if remove_prev_year:
            node_groups = params.features_groups_with_drivers_no_prev_year
        else:
            node_groups = params.features_groups_with_drivers
    else:
        raise NotImplementedError('Coloring not implemented for this option')
    add_node_color(graph, node_groups)

    fig = gv.d3(graph, use_node_size_normalization=True, node_size_normalization_max=30,
                use_edge_size_normalization=True, edge_size_data_source='weight',
                zoom_factor=0.7, graph_height=800, node_label_size_factor=1.3,
                edge_size_factor=0.2)

    fig.export_html(f'pc_{site}_{color}.html', overwrite=True)

Number of communities for graph coloring:  3
Number of communities for graph coloring:  3
Number of communities for graph coloring:  3
