In [1]:
import copy

import numpy as np
import pickle
import os
from tigramite.toymodels import structural_causal_processes as toys
from utils import links_to_cyclic_graph
from generate_data import unionize_graphs, intersect_graphs
from matplotlib import pyplot as plt

import argparse
import networkx as nx

from config_generator import generate_configurations, generate_string_from_params, generate_name_from_params

import tigramite.plotting as tp

In [2]:
from metrics import get_results

In [3]:
def get_cycle_metrics(para_setup, metrics_result_path):
    
    N, density, max_lag, pc_alpha, sample_size, regime_children_known, nb_changed_links, nb_regimes, nb_repeats, cycles_only, remove_only, use_cmiknnmixed, use_regressionci, save_folder = para_setup

    results = []

    if regime_children_known == True:
        regime_known = 'True'
    elif regime_children_known == False:
        regime_known = None
    elif regime_children_known == 'and_parents':
        regime_known = 'and_parents'
        
    failed_res = 0
    for repeat in range(nb_repeats):
        try:
            res = get_results(para_setup, repeat)
        except:
            res = None
        if res is not None:
            results.append(res)
        else:
            failed_res += 1
        
    if len(results) > 0:
        num_cycles = calculate_cycles_union(results, N=N, regime_known=regime_known, max_lag=max_lag)
    else:
        num_cycles = None
        
    return num_cycles

In [19]:
def calculate_cycles_union(results, boot_samples=200, N=2, regime_known=None, max_lag=0):
    true_regime_links = [res['true_regime_links_with_regime_ind'] for res in results]

    regime_indicators = [res['regime_indicator'] for res in results]
    
    nb_regimes = len(true_regime_links[0])
    true_regime_graphs = {}
    for regime in range(nb_regimes):
        true_g = [toys.links_to_graph(links[regime], tau_max=max_lag) for links in true_regime_links]
        true_regime_graphs[regime] = np.stack(true_g)
    
    true_union_graphs = np.stack([unionize_graphs([true_regime_graphs[regime][i] for regime in range(nb_regimes)], nb_regimes) for i in range(len(true_regime_graphs[0]))])
    cycle_lengths = []
    
    # true_union_graphs.shape[0]
    for i in range(true_union_graphs.shape[0]):
        # tp.plot_graph(true_union_graphs[i, :, :, :])
        # tp.plot_graph(true_regime_graphs[0][i, :, :, :])
        # tp.plot_graph(true_regime_graphs[1][i, :, :, :])
        
        plt.show()
        num_nodes = true_union_graphs[i].shape[0]
        G = nx.DiGraph()
        for j in range(num_nodes):
            G.add_node(j)
        for j in range(num_nodes):
            for k in range(num_nodes):
                if true_union_graphs[i, j, k, :] == '-->':
                    G.add_edge(j, k)
                elif true_union_graphs[i, j, k, :] == '<--':
                    G.add_edge(k, j)
        try:
            cycle_length = len(nx.find_cycle(G, orientation="original"))
            cycle_lengths.append(cycle_length)
                # tp.plot_graph(true_union_graphs[i, :, :, :])
                # tp.plot_graph(true_regime_graphs[0][i, :, :, :])
                # tp.plot_graph(true_regime_graphs[1][i, :, :, :])
            # print(len(nx.find_cycle(G, orientation="original")))
        except:
            print('no cycle')

    return cycle_lengths
    

In [24]:
config_path = './../update_configs/no_cmi_cycles_no_remove.yaml'  # Path to your YAML configuration file
# config_path = './../update_configs/no_cmi_no_cycles_no_remove_larger_smaller.yaml'  # Path to your YAML configuration file

results_folder, all_configurations = generate_configurations(config_path)

for configuration in all_configurations:
# for i in range(3):
    
    configuration = all_configurations[i]
    metrics_result_path = configuration[0][-1] + '/metrics_v3/'
    
    file_name_union = metrics_result_path + generate_name_from_params(configuration[0]) + '_union.dat'
    file_name_regimes = metrics_result_path + generate_name_from_params(configuration[0]) + '_regimes.dat'
    file_name_avg = metrics_result_path + generate_name_from_params(configuration[0]) + '_avg_regimes.dat'
    cycle_lengths = get_cycle_metrics(configuration[0], metrics_result_path)
    print(configuration, np.unique(cycle_lengths, return_counts=True))
    

((8, 0.4, 0, 0.05, 200, True, 1, 2, 100, True, False, False, 1.0, '/home/pope_oa/repos/extreme_events/update_cycles_no_remove'), '8-0.4-0-0.05-200-True-1-2-100-True-False-False-1.0') (array([2, 3]), array([69,  2]))
((8, 0.4, 0, 0.05, 200, True, 1, 2, 100, True, False, False, 1.0, '/home/pope_oa/repos/extreme_events/update_cycles_no_remove'), '8-0.4-0-0.05-200-True-1-2-100-True-False-False-1.0') (array([2, 3]), array([69,  2]))
((8, 0.4, 0, 0.05, 200, True, 1, 2, 100, True, False, False, 1.0, '/home/pope_oa/repos/extreme_events/update_cycles_no_remove'), '8-0.4-0-0.05-200-True-1-2-100-True-False-False-1.0') (array([2, 3]), array([69,  2]))
((8, 0.4, 0, 0.05, 200, True, 1, 2, 100, True, False, False, 1.0, '/home/pope_oa/repos/extreme_events/update_cycles_no_remove'), '8-0.4-0-0.05-200-True-1-2-100-True-False-False-1.0') (array([2, 3]), array([69,  2]))
((8, 0.4, 0, 0.05, 200, True, 1, 2, 100, True, False, False, 1.0, '/home/pope_oa/repos/extreme_events/update_cycles_no_remove'), '8-0.4-