In [1]:
from mpi4py import MPI # ERROR https://stackoverflow.com/questions/36156822/error-when-starting-open-mpi-in-mpi-init-via-python
import numpy as np
import os, sys, pickle, time
from datetime import datetime, date

from tigramite import data_processing as pp
from tigramite.pcmci import PCMCI
from tigramite.independence_tests.parcorr import ParCorr
from tigramite.independence_tests.gpdc import GPDC
from tigramite.independence_tests.cmiknn import CMIknn 
from tigramite.independence_tests.cmisymb import CMIsymb
from tigramite import plotting as tp
import pandas as pd

import matplotlib
from matplotlib.backends.backend_pdf import PdfPages
from matplotlib import pyplot as plt

from utils_pcmci import *
from utils_pca_fun import *
from geo_field_jakob import GeoField
import pickle

# Default communicator: all MPI processes launched for this run. Used below for
# rank checks (COMM.rank), scattering/gathering jobs and point-to-point sends.
COMM = MPI.COMM_WORLD

In [2]:
import itertools

def load_varimax_data(variables, seasons_mask, model_name, n_comps, mask):
    """Load PCA-VARIMAX results and concatenate their components column-wise.

    For every (variable, season) pair the pickle written by the PCA-VARIMAX
    training run is read from ``./runs/train``. The component time series of
    all pairs are concatenated by columns (season-major, then variable order),
    together with a time mask repeated for every component column.

    Parameters
    ----------
    variables : list of str
        Variable keys (e.g. ``['sst', 'prate']``); their order defines the
        column order within each season.
    seasons_mask : dict
        Season name -> list of months; the month list is part of the file name.
    model_name : dict
        Variable -> dataset/model name used in the file name.
    n_comps : dict
        Variable -> number of components used in the file name.
    mask : {'unmasked', 'masked'}
        Selects the 'ts_unmasked' or 'ts_masked' series of the results.

    Returns
    -------
    data_dict : dict
        ``'{var}_{season}'`` -> {'results', 'time_mask', 'dateseries'}.
    concatenated_data : numpy.ma.MaskedArray or None
        All component series side by side; None if, for some season, the
        series of the different variables are not aligned.
    concatenated_mask_data : numpy.ma.MaskedArray or None
        Time mask repeated for every component column; None on mismatch.

    Raises
    ------
    ValueError
        If ``variables`` is empty, ``mask`` is unsupported, or the series of
        different seasons/variables have different lengths.
    """
    if not variables:
        raise ValueError("at least one variable is required")
    if mask == 'unmasked':
        ts_to_use = 'ts_unmasked'
    elif mask == 'masked':
        ts_to_use = 'ts_masked'
    else:
        raise ValueError("mask must be 'unmasked' or 'masked', got %r" % (mask,))

    data_dict = {}
    for var, months in itertools.product(variables, seasons_mask):
        file_name = './runs/train/train_varimax_%s_3dm_comps-%d_months-%s.bin' % (
            model_name[var], n_comps[var], seasons_mask[months])
        # NOTE: pickle can execute arbitrary code -- only load trusted result files.
        with open(file_name, 'rb') as fh:
            results = pickle.load(fh)['results']

        if mask == 'unmasked':
            time_mask = results['time_mask']
            dateseries = results['time'][:]
        else:
            keep = ~results['time_mask']
            # NOTE(review): keeps every third unmasked step starting at the second
            # one (all False values, assuming a boolean time_mask) -- presumably
            # one entry per 3-month season; confirm against the training script.
            time_mask = results['time_mask'][keep][1::3]
            dateseries = results['time'][keep][:]

        data_dict[f'{var}_{months}'] = {'results': results,
                                        'time_mask': time_mask,
                                        'dateseries': dateseries}

    # Within a season all variables must share the same time axis before their
    # components can be placed side by side. Every variable is compared with the
    # first one, which works for any number of variables (including one).
    check = []
    for season in seasons_mask:
        reference = data_dict[f'{variables[0]}_{season}']
        for serie in ['dateseries', 'time_mask']:
            for var in variables[1:]:
                other = data_dict[f'{var}_{season}']
                if serie == 'dateseries':
                    # Only the length of the date axis is compared.
                    all_equal = reference[serie].shape == other[serie].shape
                else:
                    # array_equal also copes with arrays of different shapes.
                    all_equal = bool(np.array_equal(reference[serie], other[serie]))
                check.append(all_equal)
                if not all_equal:
                    print('%s in season %s not equal' % (serie, season))

    # If the series are not equal for the same season and for different variables, return None
    if not all(check):
        return data_dict, None, None

    concat_list = []
    concat_mask_list = []
    for months in seasons_mask:
        for var in variables:
            entry = data_dict[f'{var}_{months}']
            ts = entry['results'][ts_to_use]
            T, N = ts.shape
            concat_list.append(ts)
            # Repeat the (T,) time mask once per component column -> (T, N).
            concat_mask_list.append(np.repeat(entry['time_mask'].reshape(T, 1), N, axis=1))

    # Ensure dimensions match before concatenation
    concat_shapes = [arr.shape[0] for arr in concat_list]
    if len(set(concat_shapes)) != 1:
        raise ValueError("Arrays to concatenate must have the same shape. Found shapes: {}".format(concat_shapes))
    concatenated_data = np.ma.concatenate(concat_list, axis=1)
    concatenated_mask_data = np.ma.concatenate(concat_mask_list, axis=1)

    return data_dict, concatenated_data, concatenated_mask_data

def filter_results_for_variables(arr, indices):
    """Keep only the results of specified variables to simplify graph plots.

    Result matrices are indexed as (i, j, lag) = (parent, variable, lag). For
    example, entry (3, 5, 5) concerns variable 3 as a causal parent of
    variable 5 at lag 5, so the parents of variable j at a given lag are
    ``arr[:, j, lag]``.

    Parameters
    ----------
    arr : numpy.ndarray
        3-D result array, e.g. ``resdict['results']['graph']`` or
        ``resdict['results']['val_matrix']``.
    indices : iterable of int
        Target variables (second axis) whose parents are kept.

    Returns
    -------
    numpy.ndarray
        Array of the same shape as ``arr``. ``[:, j, :]`` is copied from
        ``arr`` for every j in ``indices``. All other entries are 0 for
        floating-point input and '' (object dtype) otherwise, e.g. for graph
        link strings.
    """
    # Any floating dtype (not only float64) gets a numeric background.
    if np.issubdtype(arr.dtype, np.floating):
        new_array = np.zeros_like(arr)
    else:
        new_array = np.full(arr.shape, '', dtype=object)

    # Copy all parents and lags of the selected target variables in one step.
    cols = np.asarray(list(indices), dtype=np.intp)
    new_array[:, cols, :] = arr[:, cols, :]

    return new_array

def create_individual_plots(result_dict, season, pdf, var_names=None):
    """Create one causal-graph page per component of ``season`` in ``pdf``.

    Each page shows the causal parents of one variable (component).

    Parameters
    ----------
    result_dict : dict
        ``result_dict[season][comp]`` holds at least the 'val_matrix' and
        'graph' arrays of one component; components are plotted in the
        dictionary's iteration order.
    season : str
        Key of ``result_dict`` to plot.
    pdf : PdfPages
        Open multi-page PDF (as created with ``PdfPages(...)`` by the callers)
        receiving one page per component.
    var_names : list of str, optional
        Node labels. Defaults to the notebook-level ``selected_components``,
        which is what the function always used before this parameter existed.
    """
    if var_names is None:
        var_names = selected_components

    for comp_results in result_dict[season].values():
        tp.plot_graph(
            figsize=(12, 12),
            val_matrix=comp_results['val_matrix'],
            graph=comp_results['graph'],
            var_names=var_names,
            link_colorbar_label='cross-MCI',
            link_label_fontsize=6,
            node_colorbar_label='auto-MCI',
            show_autodependency_lags=False,
            arrow_linewidth=3,
            edge_ticks=0.2,
            node_ticks=0.1,
            node_size=0.07,
            arrowhead_size=12,
            node_label_size=10,
        )
        plt.savefig(pdf, format='pdf')
        # Close each figure so that many components do not accumulate open figures.
        plt.close()

def causal_graphs_precipitation_individual(var):
    """Write, per season, one PDF with a causal-parent graph per component of ``var``.

    A component belongs to ``var`` when ``var`` occurs as a substring of its
    entry in ``comps_order_file['name']``. The season is the last
    underscore-separated token of that name.

    Uses the notebook globals seasons_mask, comps_order_file, resdict (with
    'results' filled in), n_comps, method_arg, period_length, model and mask.
    """
    # season -> result-array indices of the components whose name contains `var`
    comps_by_season = {s: [] for s in seasons_mask.keys()}
    for pos, comp_name in enumerate(comps_order_file['name']):
        if var not in comp_name:
            continue
        season_tag = comps_order_file['name'][pos].split('_')[-1]
        comps_by_season[season_tag].append(comps_order_file['comps'][pos])

    # season -> component position -> {result name -> array filtered to that component}
    kept_keys = [key for key in resdict['results'].keys() if key != 'conf_matrix']
    per_component = {}
    for s in seasons_mask.keys():
        per_component[s] = {}
        for k, comp_index in enumerate(comps_by_season[s]):
            per_component[s][k] = {
                key: filter_results_for_variables(resdict['results'][key], [comp_index])
                for key in kept_keys
            }

    comps_tag = str(list(n_comps.values()))
    for s in seasons_mask.keys():
        out_path = f'./plots/Causal_parents_individual_components_{comps_tag}comps_{s}_{method_arg}_{period_length}_{model}_{mask}.pdf'
        with PdfPages(out_path) as pdf:
            create_individual_plots(per_component, s, pdf)


In [3]:
#### Parameters to run ####
train = True
# Number of VARIMAX components loaded per variable (part of the input file names).
n_comps = {'sst': 15, 'prate': 5}
verbosity = 0
period_length = 38 * 2
mask = 'unmasked'
ip = 1
# time_bin_length = 3 # used for daily data

# seasons_mask = {'DJF': [12, 1, 2], 'MAM': [3, 4, 5], 'JJA': [6, 7, 8], 'SON': [9, 10, 11]}
seasons_mask = {'FMA': [2, 3, 4], 'MJJ': [5, 6, 7], 'ASO': [8, 9, 10], 'NDJ': [11, 12, 1]}

variables = [
    'sst',
    # 'mtpr',
    'prate',
    # 'sst_amo',
    # 'msl_amo',
    # 'd2m',
    # 'u10'
    ]

model_name = {
    'sst': 'FULL_ERA5_SST_1940-2024_converted_detrend.nc',
    # 'mtpr': 'ERA5_mean_precipitation_1940-2024_converted_detrend.nc',
    'prate': 'PRATE_NCEP_NCAR_Reanalysis_1948-2024.nc',
    # 'sst_amo': 'AMO_ERA5_SST_1940-2024_converted_detrend.nc',
    # # 'msl': 'ERA5_mean_SLP_1940-2024_converted_detrend.nc',
    # 'msl_amo': 'AMO_ERA5_mean_SLP_1940-2024_converted_detrend.nc',
    # 'd2m': 'ERA5_10m-Ucomp_1940-2024_converted_detrend.nc',
    # 'u10': 'ERA5_2m-temp_1940-2024_converted_detrend.nc'
    }

# One label per component column, in the same order in which load_varimax_data
# concatenates them: season-major, then variable, then component.
comp_names = [f'{var}_{months}' for months in seasons_mask.keys() for var in variables for _ in range(n_comps[var])]

# Count the columns actually built above. (Averaging n_comps and multiplying by
# len(model_name) only gives the right number when n_comps, model_name and
# variables all contain the same variables.)
total_comps = len(comp_names)

selected_components = ['c' + str(i) for i in range(1, total_comps + 1)]
print(selected_components)

# comps_order_file = pd.read_csv('./selected_comps_NCEP2_djf.csv')
# 'comp_number' is the component label, 'comps' its column index in the data
# matrix, and 'name' is '<label>_<variable>_<season>'.
comps_order_file = pd.DataFrame({'comp_number': list(selected_components),
                                 'comps': list(range(total_comps)),
                                 'name': [label + '_' + name for label, name in zip(selected_components, comp_names)]})

# if mask == 'masked':
#     temp = []
#     for i in range(len(fulldata)):
#         # np.mean(reshaped_mtpr[i:i+12, ], axis=0)
#         temp.append(np.mean(reshaped_mtpr[i:i+12, ], axis=0))
#     reshaped_mtpr = np.array(temp)

['c1', 'c2', 'c3', 'c4', 'c5', 'c6', 'c7', 'c8', 'c9', 'c10', 'c11', 'c12', 'c13', 'c14', 'c15', 'c16', 'c17', 'c18', 'c19', 'c20', 'c21', 'c22', 'c23', 'c24', 'c25', 'c26', 'c27', 'c28', 'c29', 'c30', 'c31', 'c32', 'c33', 'c34', 'c35', 'c36', 'c37', 'c38', 'c39', 'c40', 'c41', 'c42', 'c43', 'c44', 'c45', 'c46', 'c47', 'c48', 'c49', 'c50', 'c51', 'c52', 'c53', 'c54', 'c55', 'c56', 'c57', 'c58', 'c59', 'c60', 'c61', 'c62', 'c63', 'c64', 'c65', 'c66', 'c67', 'c68', 'c69', 'c70', 'c71', 'c72', 'c73', 'c74', 'c75', 'c76', 'c77', 'c78', 'c79', 'c80']


In [4]:
# First load, used to report the size of the concatenated component matrix.
# The main PCMCI cell reloads the data for every period it processes.
datadict, fulldata, fulldata_mask = load_varimax_data(variables, seasons_mask, model_name, n_comps, mask)
if fulldata is None:
    # load_varimax_data is meant to return None data when the series of the
    # different variables do not match for some season; fail with a clear message
    # instead of an AttributeError on `.shape` below.
    raise RuntimeError("load_varimax_data returned no data: the per-season series of %s "
                       "are not aligned (see the messages printed above)." % variables)
# if mask == 'masked':
#     fulldata = np.ma.concatenate([fulldata, reshaped_mtpr], axis=1)
#     fulldata_mask = np.ma.concatenate([fulldata_mask, np.repeat(True, fulldata.shape[0]).reshape(fulldata.shape[0], 1)], axis=1)

print("Fulldata shape = %s" % str(fulldata.shape))
print("Fulldata masked shape = %s" % str(fulldata_mask.shape))

Fulldata shape = (912, 80)
Fulldata masked shape = (912, 80)


In [6]:
model = "_".join(variables)
months = "_".join(seasons_mask.keys())

for method_arg in ['pcmci']:

    # A single period spanning the whole time axis is used. Chunking the time
    # axis into several periods would mean adding more boundaries to this list.
    period_indices = [0, fulldata.shape[0]]

    print("period_indices ", period_indices)

    for ip, period_start_index in enumerate(period_indices[:-1]):
        # Reload for every period: fulldata / fulldata_mask are rebound below to
        # period-sliced and column-selected arrays.
        datadict, fulldata, fulldata_mask = load_varimax_data(variables, seasons_mask, model_name, n_comps, mask)
        if fulldata is None:
            raise RuntimeError("load_varimax_data returned no data: the per-season series of %s "
                               "are not aligned (see the messages printed above)." % variables)
        # Dates of the first (variable, season) entry. Across variables only the
        # shapes of the date series are compared when loading.
        dateseries = datadict[next(iter(datadict.keys()))]['dateseries'][:]
        # fulldata = np.ma.concatenate([fulldata, reshaped_mtpr], axis=1)
        # fulldata_mask = np.ma.concatenate([fulldata_mask, np.repeat(False, fulldata.shape[0]).reshape(fulldata.shape[0], 1)], axis=1)

        period_end_index = period_indices[ip+1]
        print('Period: ', period_start_index, period_end_index, period_end_index-period_start_index)

        # Slice to the period of interest
        fulldata = fulldata[period_start_index:period_end_index, :]
        fulldata_mask = fulldata_mask[period_start_index:period_end_index, :]

        # Print information of the data so far
        print(fulldata.shape)
        print("Fulldata shape = %s" % str(fulldata.shape))
        print("Fulldata masked shape = %s" % str(fulldata_mask.shape))
        print("Unmasked samples %d" % (fulldata_mask[:,0]==False).sum())

        time_bin_length = 3
        # NOTE(review): the binning below is disabled, so no aggregation is
        # actually applied to fulldata despite this message.
        print("Aggregating data to time_bin_length=%s" %time_bin_length)

        ## Time bin data (disabled)
        # fulldata = pp.time_bin_with_mask(fulldata, time_bin_length=time_bin_length)[0]
        # fulldata_mask = pp.time_bin_with_mask(fulldata_mask, time_bin_length=time_bin_length)[0] > 0.5
        # print("Fulldata after binning shape = %s" % str(fulldata.shape))
        # print("Fulldata after binning masked shape = %s" % str(fulldata_mask.shape))

        # Minimum and maximum lags of the PC step
        TAU_MIN = 1
        TAU_MAX = 5

        # Column index (in fulldata) of every selected component
        selected_comps_indices = [int(comps_order_file[comps_order_file['comp_number']==i]['comps'].values[0]) for i in selected_components]

        # No prior restriction; expanded by PCMCI._set_link_assumptions below
        link_assumptions = None

        # This is for the deprecated version that uses 'selected_links' parameter
        # # Selected links may be used to restricted estimation to given links.
        # # Used to tell pcmci.run_pc_stable() and pcmci.run_mci() to only search for links into j variable.
        # selected_links = {n: {m: [(i, -t) for i in selected_comps_indices for \
        #                           t in range(0, TAU_MAX)] if m == n else [] for m in selected_comps_indices} for n in selected_comps_indices}

        fulldata = fulldata[:, selected_comps_indices]
        fulldata_mask = fulldata_mask[:, selected_comps_indices]
        print(fulldata.shape)
        print(fulldata_mask.shape)

        dataframe = pp.DataFrame(fulldata, mask=fulldata_mask, )

        print("Fulldata shape = %s" % str(dataframe.values[0].shape))
        print("Unmasked samples %d" % (dataframe.mask[0][:,0]==False).sum())

        T, N = dataframe.values[0].shape
        print('T: ', T, 'N: ', N)

        resdict = {
            "CI_params":{
                'significance':'analytic',
                # 'use_mask':True,
                'mask_type':['y'],
                'recycle_residuals':False,
                }
        }

        # Chosen conditional independence test
        cond_ind_test = ParCorr(verbosity=verbosity, **resdict['CI_params'])

        # Create master PCMCI object
        pcmci_master = PCMCI(
                dataframe=dataframe,
                cond_ind_test=cond_ind_test,
                verbosity=0)

        _int_sel_links = pcmci_master._set_link_assumptions(link_assumptions, TAU_MIN, TAU_MAX)

        # For target n keep only the links into n, so that each parallel job
        # (run_pc_stable / run_mci for variable n) only searches for links into n.
        link_assumptions_parallelized = {n:
                                        {m: _int_sel_links[m] if m == n else [] for m in selected_comps_indices}
                                        for n in selected_comps_indices}

        # Add the PC and MCI parameters to resdict (instead of rebinding it) so that
        # CI_params is kept and stored in the pickled results as well.
        resdict.update({
            "PC_params":{
                # Significance level in condition-selection step. If a list of levels is is
                # provided or pc_alpha=None, the optimal pc_alpha is automatically chosen via model-selection.
                'pc_alpha':None,
                # Minimum time lag (must be >0)
                'tau_min':TAU_MIN,
                # Maximum time lag
                'tau_max':TAU_MAX,
                # Maximum cardinality of conditions in PC condition-selection step. The recommended default choice is None to leave it unrestricted.
                'max_conds_dim':None,
                # Selected links may be used to restricted estimation to given links.
                # 'selected_links': None, # Deprecated
                # Used in this cell to split the jobs across MPI ranks and to report links.
                'selected_variables' : selected_comps_indices,
                'link_assumptions': link_assumptions_parallelized,
                # Optionally specify variable names
                'var_names': selected_comps_indices,
                },

            "MCI_params":{
                # Minimum time lag (can also be 0)
                'tau_min':0,
                # Maximum time lag
                'tau_max':TAU_MAX,
                # Maximum number of parents of X to condition on in MCI step, leave this to None to condition on all estimated parents.
                'max_conds_px':None,
                # Selected links may be used to restricted estimation to given links.
                # 'selected_links': None, # Deprecated
                'link_assumptions': link_assumptions_parallelized,
                # Alpha level for MCI tests (just used for printing since all p-values are stored anyway)
                'alpha_level' : 0.05,
                }
            })

        # Store results in file
        file_name = './runs/pcmci_results/test'+'/pcmci_results_%s_3dm_comps-%s_months-%s_%s_%s_%s_%s.bin' % (model, str(list(n_comps.values())), months, method_arg, period_length, ip, mask)

        print(file_name)

        #
        #  Start of the script
        #
        if COMM.rank == 0:
            # Only the master node (rank=0) runs this
            if verbosity > -1:
                print("\n##\n## Running Parallelized Tigramite PC algorithm\n##"
                        "\n\nParameters:")
                print("\nindependence test = %s" % cond_ind_test.measure
                        + "\ntau_min = %d" % resdict['PC_params']['tau_min']
                        + "\ntau_max = %d" % resdict['PC_params']['tau_max']
                        + "\npc_alpha = %s" % resdict['PC_params']['pc_alpha']
                        + "\nmax_conds_dim = %s" % resdict['PC_params']['max_conds_dim'])
                print("\n")

            # Split selected_variables into however many cores are available.
            splitted_jobs = split(resdict['PC_params']['selected_variables'], COMM.size)
            if verbosity > -1:
                print("Splitted selected_variables = ", splitted_jobs)
        else:
            splitted_jobs = None


        ##
        ##  PC algo condition-selection step
        ##
        # Scatter jobs across cores.
        scattered_jobs = COMM.scatter(splitted_jobs, root=0)

        print("\nCPU %d estimates parents of %s" % (COMM.rank, scattered_jobs))

        # Now each rank just does its jobs and collects everything in a results list.
        results = []
        time_start = time.time()
        for j_index, j in enumerate(scattered_jobs):
            # Estimate conditions
            (j, pcmci_of_j, parents_of_j) = run_pc_stable_parallel(j, dataframe, cond_ind_test, verbosity=verbosity, params=resdict['PC_params'])

            results.append((j, pcmci_of_j, parents_of_j))

            num_here = len(scattered_jobs)
            current_runtime = (time.time() - time_start)/3600.
            current_runtime_hr = int(current_runtime)
            current_runtime_min = 60.*(current_runtime % 1.)
            estimated_runtime = current_runtime * num_here / (j_index+1.)
            estimated_runtime_hr = int(estimated_runtime)
            estimated_runtime_min = 60.*(estimated_runtime % 1.)
            # print ("\t# CPU %s task %d/%d: %dh %.1fmin / %dh %.1fmin: Variable %s" % (COMM.rank, j_index+1, num_here,
            #                         current_runtime_hr, current_runtime_min,
            #                         estimated_runtime_hr, estimated_runtime_min,  resdict['PC_params']['var_names'][j]))


        # Gather results on rank 0.
        results = MPI.COMM_WORLD.gather(results, root=0)


        if COMM.rank == 0:
            # Collect all results in dictionaries and send results to workers
            all_parents = {}
            pcmci_objects = {}
            for res in results:
                for (j, pcmci_of_j, parents_of_j) in res:
                    all_parents[j] = parents_of_j[j]
                    pcmci_objects[j] = pcmci_of_j
            # Printing the resulting condition sets is disabled: the p_max
            # attribute is not available on the returned PCMCI objects.
            #if verbosity > -1:
            #    print("\n\n## Resulting condition sets:")
            #    for j in [var for var in all_parents.keys()]:
            #       pcmci_objects[j]._print_parents_single(j, all_parents[j],
            #           pcmci_objects[j].p_max[j],
            #           pcmci_objects[j].p_max[j]) ERROR IN GETTING p_max[] attribute

            if verbosity > -1:
                print("\n##\n## Running Parallelized Tigramite MCI algorithm\n##"
                        "\n\nParameters:")

                print("\nindependence test = %s" % cond_ind_test.measure
                        + "\ntau_min = %d" % resdict['MCI_params']['tau_min']
                        + "\ntau_max = %d" % resdict['MCI_params']['tau_max']
                        + "\nmax_conds_px = %s" % resdict['MCI_params']['max_conds_px'])

                print("Master node: Sending all_parents and pcmci_objects to workers.")

            for i in range(1, COMM.size):
                COMM.send((all_parents, pcmci_objects), dest=i)

        else:
            if verbosity > -1:
                print("Slave node %d: Receiving all_parents and pcmci_objects..."
                        "" % COMM.rank)
            (all_parents, pcmci_objects) = COMM.recv(source=0)



        ##
        ##   MCI step
        ##
        # Scatter jobs again across cores.
        scattered_jobs = COMM.scatter(splitted_jobs, root=0)

        # Now each rank just does its jobs and collects everything in a results list.
        results = []
        for j_index, j in enumerate(scattered_jobs):
            # print("\n\t# Variable %s (%d/%d)" % (var_names[j], j_index+1, len(scattered_jobs)))

            (j, results_in_j) = run_mci_parallel(j, pcmci_objects[j], all_parents, params=resdict['MCI_params'])
            results.append((j, results_in_j))

            num_here = len(scattered_jobs)
            current_runtime = (time.time() - time_start)/3600.
            current_runtime_hr = int(current_runtime)
            current_runtime_min = 60.*(current_runtime % 1.)
            estimated_runtime = current_runtime * num_here / (j_index+1.)
            estimated_runtime_hr = int(estimated_runtime)
            estimated_runtime_min = 60.*(estimated_runtime % 1.)
            # print ("\t# CPU %s task %d/%d: %dh %.1fmin / %dh %.1fmin: Variable %s" % (COMM.rank, j_index+1, num_here,
            #                         current_runtime_hr, current_runtime_min,
            #                         estimated_runtime_hr, estimated_runtime_min,  resdict['PC_params']['var_names'][j]))


        # Gather results on rank 0.
        results = MPI.COMM_WORLD.gather(results, root=0)


        if COMM.rank == 0:
            # Collect all results in dictionaries: each job only filled the
            # column [:, j, :] of its own target variable j.
            if verbosity > -1:
                print("\nCollecting results...")
            all_results = {}
            for res in results:
                for (j, results_in_j) in res:
                    for key in results_in_j.keys():
                        if results_in_j[key] is None:
                            all_results[key] = None
                        else:
                            if key not in all_results.keys():
                                # p-values default to 1 (not significant), graph to 0-objects, others to 0.
                                if key == 'p_matrix':
                                    all_results[key] = np.ones(results_in_j[key].shape)
                                elif key == 'graph':
                                    all_results[key] = np.zeros(results_in_j[key].shape, dtype=object)
                                else:
                                    all_results[key] = np.zeros(results_in_j[key].shape)
                                all_results[key][:,j,:] = results_in_j[key][:,j,:]
                            else:
                                all_results[key][:,j,:] =  results_in_j[key][:,j,:]

            p_matrix=all_results['p_matrix']
            val_matrix=all_results['val_matrix']
            # The printing code below already handles conf_matrix being None.
            conf_matrix=all_results.get('conf_matrix')

            sig_links = (p_matrix <= resdict['MCI_params']['alpha_level'])

            if verbosity > -1:
                print("\n## Significant links at alpha = %s:" % resdict['MCI_params']['alpha_level'])
                for j in resdict['PC_params']['selected_variables']:

                    links = dict([((p[0], -p[1] ), np.abs(val_matrix[p[0],
                                    j, abs(p[1])]))
                                    for p in zip(*np.where(sig_links[:, j, :]))])

                    # Sort by value
                    sorted_links = sorted(links, key=links.get, reverse=True)

                    n_links = len(links)

                    string = ("\n    Variable %s has %d "
                                "link(s):" % (resdict['PC_params']['var_names'][j], n_links))
                    for p in sorted_links:
                        string += ("\n        (%s %d): pval = %.5f" %
                                    (resdict['PC_params']['var_names'][p[0]], p[1],
                                    p_matrix[p[0], j, abs(p[1])]))

                        string += " | val = %.3f" % (
                            val_matrix[p[0], j, abs(p[1])])

                        if conf_matrix is not None:
                            string += " | conf = (%.3f, %.3f)" % (
                                conf_matrix[p[0], j, abs(p[1])][0],
                                conf_matrix[p[0], j, abs(p[1])][1])

                    print(string)


            if verbosity > -1:
                print("Pickling to ", file_name)

            resdict['results'] = all_results
            with open(file_name, 'wb') as fh:
                pickle.dump(resdict, fh, protocol=-1)

            # Saving the causal parents (input for causally informed NNs) in an easier to read format to build the networks
            parents_dict = {i: {} for i in range(len(resdict['PC_params']['selected_variables']))}

            for j in parents_dict.keys():
                links, lags = np.where(sig_links[:, j, :])
                parents_dict[j]['selected_features'] = links
                parents_dict[j]['lags'] = lags
                parents_dict[j]['selected_features_names'] = []
                parents_dict[j]['n_features'] = len(links)
                parents_dict[j]['val_matrix'] = []
                parents_dict[j]['p_values'] = []

                for p in zip(links, lags):
                    parents_dict[j]['val_matrix'].append(val_matrix[p[0], j, p[1]])
                    parents_dict[j]['p_values'].append(p_matrix[p[0], j, p[1]])

                    # Feature names are 1-based component labels, with a lag suffix for lagged parents.
                    if p[1] == 0:
                        parents_dict[j]['selected_features_names'].append(f'c{p[0]+1}')
                    else:
                        parents_dict[j]['selected_features_names'].append(f'c{p[0]+1}_lag_{p[1]}')


            file_name = './runs/pcmci_results/test'+'/variable_selection_pcmci_%s_3dm_comps-%s_months-%s_%s_%s_%s_%s.bin' % (model, str(list(n_comps.values())), months, method_arg, period_length, ip, mask)
            print('Saving variable selection from PCMCI to:', file_name)

            with open(file_name, 'wb') as fh:
                pickle.dump(parents_dict, fh, protocol=-1)

            # Create and save individual components causal graphs for the precipitation rate
            # (assumes variables[1] is the precipitation variable -- 'prate' in the current config).
            causal_graphs_precipitation_individual(variables[1])
            for season_key in seasons_mask.keys():
                print(f"Plots saved to ./plots/Causal_parents_individual_components_{str(list(n_comps.values()))}comps_{season_key}_{method_arg}_{period_length}_{model}_{mask}.pdf")


period_indices  [0, 912]
Period:  0 912 912
(912, 80)
Fulldata shape = (912, 80)
Fulldata masked shape = (912, 80)
Unmasked samples 228
Aggregating data to time_bin_length=3
(912, 80)
(912, 80)
Fulldata shape = (912, 80)
Unmasked samples 228
T:  912 N:  80
./runs/pcmci_results/test/pcmci_results_sst_prate_3dm_comps-[15, 5]_months-FMA_MJJ_ASO_NDJ_pcmci_76_0_unmasked.bin

##
## Running Parallelized Tigramite PC algorithm
##

Parameters:

independence test = par_corr
tau_min = 1
tau_max = 5
pc_alpha = None
max_conds_dim = None


Splitted selected_variables =  [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79]]

CPU 0 estimates parents of [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 2

#### Causal graphs 

In [None]:
if False:
    #### Extracting results for precipitation rate and SST components ####
    # The precipitation variable used to be hard-coded as 'mtpr', which is
    # commented out of `variables`, so with the current configuration nothing
    # matched. Use the same variable the main cell plots (variables[1]).
    precip_var = variables[1]

    # components related to precipitation rate
    variables_of_interest = {season: [] for season in seasons_mask.keys()}

    for i, name in enumerate(comps_order_file['name']):
        if precip_var in name:
            season = comps_order_file['name'][i].split('_')[-1]
            variables_of_interest[season].append(comps_order_file['comps'][i])

    print(variables_of_interest)


    # components related to SST
    SST_vars = {season: [] for season in seasons_mask.keys()}

    for i, name in enumerate(comps_order_file['name']):
        if 'sst' in name:
            season = comps_order_file['name'][i].split('_')[-1]

            SST_vars[season].append(comps_order_file['comps'][i])

    # A bare expression inside an `if` block is not auto-displayed, so print it.
    print(SST_vars)

In [None]:
if False:
    # Reduce the MCI results to the links pointing into the components of
    # interest: once per season (all components together) and once per single
    # component. 'conf_matrix' is skipped in both cases.
    def _filter_per_season(groups):
        # season -> {result name -> array filtered to all components of that season}
        return {
            s: {key: filter_results_for_variables(resdict['results'][key], groups[s])
                for key in resdict['results'].keys() if key != 'conf_matrix'}
            for s in seasons_mask.keys()
        }

    def _filter_per_component(groups):
        # season -> component position -> {result name -> array filtered to that component}
        return {
            s: {c: {key: filter_results_for_variables(resdict['results'][key], groups[s][c:c+1])
                    for key in resdict['results'].keys() if key != 'conf_matrix'}
                for c in range(len(groups[s]))}
            for s in seasons_mask.keys()
        }

    precipitation_results = _filter_per_season(variables_of_interest)
    precipitation_results_ind = _filter_per_component(variables_of_interest)
    # precipitation_results_ind['DJF'][0]['graph'][:,5,5] # check that precipitation_results_ind is correct

    sst_results = _filter_per_season(SST_vars)
    sst_results_ind = _filter_per_component(SST_vars)

In [None]:
if False:
    # Overview graph of all links into the precipitation components of one season.
    # tp and plt are already imported in the first cell.
    # The season used to be hard-coded as 'DJF', which is not a key of the
    # current seasons_mask; plot the first configured season instead.
    season_to_plot = next(iter(seasons_mask))

    # parents_dict = pcmci_master.return_parents_dict(resdict['results']['graph'], val_matrix) # same shape but sparser

    tp.plot_graph(
        figsize=(24, 24),
        val_matrix=precipitation_results[season_to_plot]['val_matrix'],
        graph=precipitation_results[season_to_plot]['graph'],
        var_names=selected_components,
        link_colorbar_label='cross-MCI',
        node_colorbar_label='auto-MCI',
        show_autodependency_lags=False,
        arrow_linewidth=4,
        edge_ticks=0.2,
        node_ticks=0.2,
        node_size=0.1,
        arrowhead_size=15,
        node_label_size=10,
        )

    plt.show()

In [None]:
if False:
    # Individual plots for each season component: one PDF per season.
    for season in seasons_mask.keys():
        out_path = f'./plots/Causal_parents_individual_components_{season}_{method_arg}_{period_length}_{ip}.pdf'
        with PdfPages(out_path) as pdf:
            create_individual_plots(precipitation_results_ind, season, pdf)
        # Report each file. The old message lacked the f-prefix and was printed
        # once, outside the loop.
        print(f"Plots saved to {out_path}")