In [44]:
import os
import json
import sympy
import requests

import pyciemss
import pyciemss.visuals.plots as plots
import pyciemss.visuals.vega as vega
import pyciemss.visuals.trajectories as trajectories

from mira.metamodel import *
from mira.modeling.amr.petrinet import AMRPetriNetModel, template_model_to_petrinet_json
from mira.sources.amr.petrinet import template_model_from_amr_json

In [45]:
# Enable IPython's autoreload extension. Mode 1 reloads only the modules that
# have been registered with %aimport before each cell is executed.
%load_ext autoreload
%autoreload 1

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [46]:
# Base URL (with trailing slash) of the raw model files in the
# DARPA-ASKEM/simulation-integration repository.
MODEL_PATH = "https://raw.githubusercontent.com/DARPA-ASKEM/simulation-integration/main/data/models/"

# URLs always use "/" as the separator, whereas os.path.join uses the host
# platform's separator (a backslash on Windows). Because MODEL_PATH already
# ends with "/", plain concatenation yields the correct URL on every platform.
model1 = MODEL_PATH + "SEIRD_base_model01_petrinet.json"
model2 = MODEL_PATH + "SEIRHD_base_model01_petrinet.json"
model3 = MODEL_PATH + "LV_sheep_foxes.json"

In [47]:
# Simulation settings passed to pyciemss.sample in the next cell.
start_time, end_time = 0.0, 10.0   # time window handed to the sampler
logging_step_size = 1              # spacing between logged time points
n = num_samples = 100              # number of samples; `n` is an alias

In [48]:
# Sample trajectories from model1 with the settings defined above.
# result1['data'] holds one row per (sample_id, timepoint_id), as shown by the preview.
result1 = pyciemss.sample(model1, end_time, logging_step_size, num_samples, start_time=start_time)
display(result1['data'].head())
# Persist only the identifier columns and the S, I and R state columns; the E and D
# states and the parameter/observable columns visible in the preview are dropped.
result1['data'][['timepoint_id', 'sample_id', 'S_state', 'I_state', "R_state"]].to_csv('sir.csv', index= False) # read back as `sir_dataset` by the binning cell


Unnamed: 0,timepoint_id,sample_id,timepoint_unknown,persistent_beta_param,persistent_death_param,persistent_gamma_param,persistent_I0_param,S_state,I_state,E_state,R_state,D_state,infected_observable_state,dead_observable_state
0,0,0,0.0,0.092653,0.00496,0.150774,4.47884,19339996.0,4.47884,40.0,0.0,0.0,4.47884,0.0
1,1,0,1.0,0.092653,0.00496,0.150774,4.47884,19339996.0,12.114385,31.866884,1.284434,0.006403,12.114385,0.006403
2,2,0,2.0,0.092653,0.00496,0.150774,4.47884,19339996.0,17.08383,26.038223,3.502205,0.017459,17.08383,0.017459
3,3,0,3.0,0.092653,0.00496,0.150774,4.47884,19339996.0,20.208096,21.822861,6.318605,0.031499,20.208096,0.031499
4,4,0,4.0,0.092653,0.00496,0.150774,4.47884,19339996.0,22.059082,18.738466,9.502293,0.04737,22.059082,0.04737


### Bin each state variable per sample and build combined bin labels


In [49]:
import pandas as pd
import numpy as np
import math
import matplotlib.pyplot as plt
from itertools import combinations
import networkx as nx
from pyciemss.visuals import plots, vega

# Configuration flags forwarded to process_dataset below
bin_outliers = False       # True: bin between the 1st/99th percentiles and add outer bins up to min/max
log = False                # True: use logarithmically spaced bin edges instead of linear ones
remove_duplicates = False  # forwarded to process_dataset's `remove_duplicates` argument
# Reload the S/I/R trajectories written to sir.csv by the sampling cell.
sir_dataset = pd.read_csv("sir.csv")

def remove_consecutive_duplicates(lst):
    """
    Collapse runs of equal neighbouring elements into a single element.

    Parameters:
    lst (list): Sequence to scan; it is not modified.

    Returns:
    list: The elements of `lst` in order, keeping only the first element of
    every run of consecutive equal items. An empty input is returned as is.
    """
    if not lst:
        return lst

    head, tail = lst[0], lst[1:]
    kept = [head]
    for candidate in tail:
        # Compare against the last element that was kept, not the previous input item.
        if candidate != kept[-1]:
            kept.append(candidate)
    return kept

def _compute_bin_edges(values, min_val, max_val, num_bins, bin_outliers, log):
    """Return the array of bin edges for one column of one sample."""
    if log:
        if bin_outliers:
            # Log-spaced edges between the 1st and 99th percentiles (lower exponent
            # clamped at 0.0), with the true min/max added as outer edges.
            p1_log = np.log10(values.quantile(0.01))
            p99_log = np.log10(values.quantile(0.99))
            inner_edges = np.logspace(max(0.0, p1_log), p99_log, num=num_bins)
            return np.concatenate(([min_val], inner_edges, [max_val]))
        return np.logspace(np.log10(min_val), np.log10(max_val), num=num_bins)

    if bin_outliers:
        # Linear edges between the 1st and 99th percentiles, plus outer min/max edges.
        inner_edges = np.linspace(values.quantile(0.01), values.quantile(0.99), num_bins)
        return np.concatenate(([min_val], inner_edges, [max_val]))
    return np.linspace(min_val, max_val, num_bins)


def process_dataset(sir_dataset, bin_outliers, log, remove_duplicates=False,
                    verbose=True, label_mapping_path='label_mapping.csv'):
    """
    Bins every '_state' column separately for each sample and derives a
    combined bin label per row. No plotting is performed.

    For every column whose name contains '_state', a '<column>_bin' column is
    added holding the np.digitize index of the value within that sample's bin
    edges. These indices start at 1, whereas the keys in the returned bin
    dictionary start at 0, so bin label k corresponds to dictionary key k - 1.

    Parameters:
    sir_dataset (pd.DataFrame): The dataset containing state columns and a 'sample_id' column.
        It is not modified; a copy is returned.
    bin_outliers (bool): If True, edges span the 1st-99th percentiles and the sample
        min/max are added as extra outer edges.
    log (bool): If True, edges are log-spaced (np.logspace) instead of linear.
    remove_duplicates (bool): If True, each combined label is split on whitespace and
        consecutive duplicate tokens are removed.
    verbose (bool): If True (default), print the progress/diagnostic messages.
    label_mapping_path (str or None): CSV path the label mapping is written to;
        pass None to skip writing the file.

    Returns:
    pd.DataFrame: Copy of the input with '<column>_bin', 'combined_bin' and 'short_bin' columns.
    dict: bin_dict[column][sample_id] -> list of {bin_number: (lower_edge, upper_edge)}.

    Raises:
    ValueError: If the dataset has no rows.
    """
    if sir_dataset.shape[0] == 0:
        # math.log2(0) would otherwise fail with an opaque "math domain error".
        raise ValueError("sir_dataset is empty; nothing to bin")

    # Determine the number of bins using Sturges' rule with a minimum number of bins
    num_bins = max(5, math.ceil(math.log2(sir_dataset.shape[0]) + 1))
    if verbose:
        print(f"Number of bins: {num_bins}")

    # Columns whose name contains '_state' are the ones that get binned
    state_columns = sir_dataset.columns[sir_dataset.columns.str.contains('_state')].tolist()
    if verbose:
        print(f"Columns to be binned: {state_columns}")

    bin_dict = {}
    sir_dataset_output = sir_dataset.copy()

    # Group the untouched input (same index as the copy) so the frame being
    # iterated is never the one that receives the new columns.
    for sample_id, group in sir_dataset.groupby('sample_id'):
        if verbose:
            print(f"Processing sample_id: {sample_id}")

        for col in state_columns:
            values = group[col]
            max_val = values.max()
            min_val = values.min()
            if verbose:
                print(f"Binning column: {col}, Sample_id: {sample_id}, Min value: {min_val}, Max value: {max_val}")

            bin_edges = _compute_bin_edges(values, min_val, max_val, num_bins, bin_outliers, log)
            if verbose:
                print(f"Bin edges for column {col}, Sample_id {sample_id}: {bin_edges}")

            # Save the bin edges to the dictionary
            bin_dict.setdefault(col, {})[sample_id] = [
                {i: (bin_edges[i], bin_edges[i + 1])} for i in range(len(bin_edges) - 1)
            ]
            if verbose:
                print(f"Bin dictionary for column {col}, Sample_id {sample_id}: {bin_dict[col][sample_id]}")

            # Digitize the whole column at once. The last edge is excluded so the
            # maximum value falls into the final bin instead of one past it.
            bin_col = col + '_bin'
            sir_dataset_output.loc[group.index, bin_col] = pd.Series(
                np.digitize(values.to_numpy(), bin_edges[:-1]), index=group.index
            )

            if verbose:
                binned = sir_dataset_output.loc[group.index, bin_col]
                print(f"Binned data for column {col}, Sample_id {sample_id}: {binned.unique()}")
                counts = binned.value_counts(sort=False)
                print(f"Count per bin for column {col}, Sample_id {sample_id}:\n{counts}\n")

    # Combine bin information for all state columns into a single label per row.
    # The column list and prefixes are the same for every row, so compute them once.
    bin_columns = [c for c in sir_dataset_output.columns if '_bin' in c]
    prefixes = [c.replace('_state_bin', '') for c in bin_columns]
    sir_dataset_output['combined_bin'] = sir_dataset_output.apply(
        lambda row: '_'.join(f"{prefix}_{int(row[c])}" for prefix, c in zip(prefixes, bin_columns)),
        axis=1,
    )

    # Optionally remove consecutive duplicate tokens inside each label.
    # NOTE(review): labels built above contain no whitespace unless a column name
    # does, so this step appears to leave them unchanged -- confirm the intent.
    if remove_duplicates:
        sir_dataset_output['combined_bin'] = sir_dataset_output['combined_bin'].apply(
            lambda x: ' '.join(remove_consecutive_duplicates(x.split()))
        )

    # Create a mapping from combined bin labels to shorter labels (order of first appearance)
    unique_bins = sir_dataset_output['combined_bin'].unique()
    label_mapping = {label: f'bin_{i}' for i, label in enumerate(unique_bins)}
    sir_dataset_output['short_bin'] = sir_dataset_output['combined_bin'].map(label_mapping)
    if verbose:
        print(f"Unique bins: {unique_bins}")
        print(f"Label mapping: {label_mapping}")

    # Save the label mapping to a CSV file for reference
    if label_mapping_path is not None:
        label_mapping_df = pd.DataFrame(list(label_mapping.items()), columns=['original_label', 'short_label'])
        label_mapping_df.to_csv(label_mapping_path, index=False)

    return sir_dataset_output, bin_dict

# Bin the reloaded trajectories; returns the labelled frame and the per-sample bin edges.
sir_dataset_output, bin_dict = process_dataset(sir_dataset, bin_outliers, log, remove_duplicates)


Number of bins: 12
Columns to be binned: ['S_state', 'I_state', 'R_state']
Processing sample_id: 0
Binning column: S_state, Sample_id: 0, Min value: 19339966.0, Max value: 19339996.0
Bin edges for column S_state, Sample_id 0: [19339966.         19339968.72727273 19339971.45454545 19339974.18181818
 19339976.90909091 19339979.63636364 19339982.36363636 19339985.09090909
 19339987.81818182 19339990.54545455 19339993.27272727 19339996.        ]
Bin dictionary for column S_state, Sample_id 0: [{0: (19339966.0, 19339968.727272727)}, {1: (19339968.727272727, 19339971.454545453)}, {2: (19339971.454545453, 19339974.181818184)}, {3: (19339974.181818184, 19339976.90909091)}, {4: (19339976.90909091, 19339979.636363637)}, {5: (19339979.636363637, 19339982.363636363)}, {6: (19339982.363636363, 19339985.09090909)}, {7: (19339985.09090909, 19339987.818181816)}, {8: (19339987.818181816, 19339990.545454547)}, {9: (19339990.545454547, 19339993.272727273)}, {10: (19339993.272727273, 19339996.0)}]
Binned 

### Build the sequence of combined bin labels for each sample


In [50]:
def remove_consecutive_duplicates(lst):
    """
    Drop elements that repeat their immediate predecessor.

    NOTE(review): this re-defines the identical helper from the binning cell
    earlier in the notebook; one of the two definitions could be removed.

    Parameters:
    lst (list): Input sequence; left untouched.

    Returns:
    list: `lst` with each run of consecutive equal items reduced to its first
    item. A falsy (empty) input is handed back unchanged.
    """
    if not lst:
        return lst

    first, rest = lst[0], lst[1:]
    compacted = [first]
    for element in rest:
        # Only append when the element differs from the most recently kept one.
        if element != compacted[-1]:
            compacted.append(element)
    return compacted

def get_bin_lists(binned_data, *, n=None, fig_width=10):
    """
    Build one space-separated string of combined bin labels per sample.

    Parameters:
    binned_data (pd.DataFrame): Frame with 'sample_id' and 'combined_bin' columns.
    n (int, optional): Use only the first `n` distinct sample ids (in order of
        first appearance). None, the default, keeps all samples.
    fig_width (int, optional): Accepted but not used by this function.

    Returns:
    list: One string per selected sample, joining that sample's 'combined_bin'
    values in row order with single spaces.
    """
    selected_ids = binned_data['sample_id'].unique()[:n]
    return [
        ' '.join(
            binned_data.loc[binned_data['sample_id'] == current_id, 'combined_bin'].values.tolist()
        )
        for current_id in selected_ids
    ]

In [51]:

# One space-separated string of combined bin labels per sample (all samples, since n is not given).
bins_list =  get_bin_lists(sir_dataset_output)
bins_list

['S_11_I_1_R_1 S_11_I_5_R_1 S_11_I_8_R_2 S_11_I_10_R_3 S_11_I_11_R_4 S_11_I_11_R_5 S_9_I_11_R_6 S_9_I_11_R_8 S_7_I_11_R_9 S_1_I_11_R_10 S_5_I_11_R_11',
 'S_11_I_1_R_1 S_11_I_1_R_1 S_11_I_2_R_1 S_10_I_3_R_2 S_10_I_4_R_3 S_9_I_5_R_3 S_7_I_6_R_5 S_7_I_7_R_6 S_5_I_8_R_7 S_4_I_10_R_9 S_1_I_11_R_11',
 'S_11_I_1_R_1 S_11_I_1_R_1 S_10_I_1_R_1 S_10_I_2_R_2 S_9_I_3_R_2 S_9_I_3_R_3 S_9_I_4_R_4 S_6_I_6_R_5 S_5_I_7_R_7 S_3_I_9_R_9 S_1_I_11_R_11',
 'S_11_I_1_R_1 S_11_I_2_R_1 S_11_I_3_R_1 S_10_I_4_R_2 S_11_I_4_R_3 S_11_I_5_R_4 S_7_I_6_R_5 S_6_I_7_R_6 S_4_I_9_R_8 S_3_I_10_R_9 S_1_I_11_R_11',
 'S_11_I_1_R_1 S_11_I_2_R_1 S_11_I_2_R_1 S_10_I_3_R_2 S_10_I_4_R_2 S_9_I_5_R_3 S_6_I_6_R_4 S_6_I_7_R_6 S_5_I_8_R_7 S_2_I_10_R_9 S_1_I_11_R_11',
 'S_11_I_1_R_1 S_11_I_7_R_1 S_11_I_10_R_2 S_11_I_11_R_3 S_10_I_11_R_5 S_7_I_11_R_6 S_6_I_11_R_7 S_4_I_10_R_8 S_3_I_10_R_9 S_3_I_9_R_10 S_1_I_8_R_11',
 'S_11_I_1_R_1 S_11_I_2_R_1 S_11_I_2_R_1 S_10_I_3_R_2 S_9_I_4_R_3 S_8_I_5_R_4 S_7_I_6_R_5 S_6_I_7_R_6 S_4_I_8_R_8 S_4_I_10_

In [52]:
### Build the SentenTree from the bin list

In [53]:
# NOTE(review): autoreload was already loaded and configured in an earlier cell
# (see the "already loaded" message in the output); this cell repeats that setup.
%load_ext autoreload
%autoreload 1

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [54]:
import sententree
import sententree_vega
import networkx as nx
# NOTE(review): this rebinds the name `vega`, which earlier cells bound to
# `pyciemss.visuals.vega`; every later `vega.*` call uses this top-level module
# instead -- confirm that is intended.
import vega
import vl_convert

# Register the modules that autoreload mode 1 should reload before each cell runs.
%aimport sententree
%aimport sententree_vega
%aimport vega

In [55]:
# Log-level markers to look for in a line.
logging_keywords = ["TRACE", "DEBUG", "INFO", "LOG"]

def contains_any(line, keywords):
    """
    Report whether `line` contains at least one of `keywords`.

    Parameters:
    line (str): Text to search.
    keywords (iterable of str): Substrings to look for.

    Returns:
    bool: True as soon as one keyword is found in `line`, False if none is
    (including when `keywords` is empty).
    """
    # A generator lets any() stop at the first match instead of testing every
    # keyword and building an intermediate list first.
    return any(w in line for w in keywords)

def clean_line(l):
    """
    Strip the first space-delimited token from a line and replace colons with dashes.

    Everything up to and including the first space is discarded (a line without
    any space becomes empty), surrounding whitespace is trimmed, and then every
    "::" followed by every remaining ":" is replaced with "-".
    """
    _, _, remainder = l.partition(" ")
    trimmed = remainder.strip()
    without_double = trimmed.replace("::", "-")
    return without_double.replace(":", "-")


In [56]:
# Install Graphviz (run this in a terminal if 'brew' is not accessible from Jupyter)
# !brew install graphviz

# Update the PATH environment variable
# import os
# os.environ['PATH'] += os.pathsep + '/usr/local/bin'


In [57]:

# Build the SentenTree graph from the per-sample bin-label strings.
# NOTE(review): the meaning of min_support / num_exemplars / tag_with is defined in
# the local `sententree` module, which is not shown in this notebook.
G = sententree.build_sententree(bins_list,
                                min_support = 1,
                                num_exemplars=3,
                                tag_with=sententree.tag_words_with_index)



In [58]:
### Turn sententree network into split chart

In [59]:
# Inspect the graph edges with their attribute dicts (the output shows a 'weight' per edge).
edges = G.edges(data=True)
edges

OutEdgeDataView([('S_11_I_1_R_1_0', 'S_11_I_1_R_1_1', {'weight': 13}), ('S_11_I_1_R_1_0', 'S_10_I_1_R_1_1', {'weight': 2}), ('S_11_I_1_R_1_0', 'S_11_I_2_R_1_1', {'weight': 13}), ('S_11_I_1_R_1_0', 'S_10_I_2_R_1_1', {'weight': 3}), ('S_11_I_1_R_1_0', 'S_11_I_3_R_1_1', {'weight': 11}), ('S_11_I_1_R_1_0', 'S_11_I_4_R_1_1', {'weight': 5}), ('S_11_I_1_R_1_0', 'S_11_I_5_R_1_1', {'weight': 8}), ('S_11_I_1_R_1_0', 'S_9_I_3_R_1_1', {'weight': 3}), ('S_11_I_1_R_1_0', 'S_9_I_2_R_1_1', {'weight': 2}), ('S_11_I_1_R_1_0', 'S_10_I_3_R_1_1', {'weight': 1}), ('S_11_I_1_R_1_0', 'S_7_I_6_R_1_1', {'weight': 1}), ('S_11_I_1_R_1_0', 'S_11_I_7_R_1_1', {'weight': 5}), ('S_11_I_1_R_1_0', 'S_11_I_8_R_1_1', {'weight': 3}), ('S_11_I_1_R_1_0', 'S_11_I_6_R_1_1', {'weight': 4}), ('S_11_I_1_R_1_0', 'S_10_I_8_R_1_1', {'weight': 1}), ('S_11_I_1_R_1_1', 'S_11_I_2_R_1_2', {'weight': 11}), ('S_11_I_1_R_1_1', 'S_10_I_2_R_1_2', {'weight': 1}), ('S_11_I_1_R_1_1', 'S_10_I_1_R_1_2', {'weight': 1}), ('S_11_I_2_R_1_2', 'S_10_I_3

In [60]:
# Inspect the graph nodes with their attribute dicts (the output shows a 'count' per node).
nodes = G.nodes(data =True)
nodes

NodeDataView({'S_11_I_1_R_1_0': {'count': 1}, 'S_11_I_1_R_1_1': {'count': 1}, 'S_11_I_2_R_1_2': {'count': 1}, 'S_10_I_3_R_2_3': {'count': 1}, 'S_9_I_4_R_3_4': {'count': 1}, 'S_8_I_5_R_4_5': {'count': 1}, 'S_7_I_6_R_5_6': {'count': 1}, 'S_6_I_7_R_6_7': {'count': 1}, 'S_4_I_8_R_8_8': {'count': 1}, 'S_3_I_10_R_9_9': {'count': 1}, 'S_1_I_11_R_11_10': {'count': 1}, 'S_2_I_10_R_9_9': {'count': 1}, 'S_10_I_1_R_1_1': {'count': 1}, 'S_11_I_2_R_1_1': {'count': 1}, 'S_4_I_10_R_9_9': {'count': 1}, 'S_11_I_4_R_3_4': {'count': 1}, 'S_5_I_8_R_8_8': {'count': 1}, 'S_11_I_5_R_4_5': {'count': 1}, 'S_5_I_8_R_7_8': {'count': 1}, 'S_9_I_4_R_4_5': {'count': 1}, 'S_7_I_5_R_5_6': {'count': 1}, 'S_10_I_4_R_3_4': {'count': 1}, 'S_10_I_4_R_3_5': {'count': 1}, 'S_10_I_4_R_2_4': {'count': 1}, 'S_9_I_4_R_3_5': {'count': 1}, 'S_8_I_5_R_4_6': {'count': 1}, 'S_9_I_5_R_3_5': {'count': 1}, 'S_6_I_6_R_4_6': {'count': 1}, 'S_6_I_6_R_5_6': {'count': 1}, 'S_10_I_2_R_1_1': {'count': 1}, 'S_10_I_5_R_4_5': {'count': 1}, 'S_3_I

In [70]:
import networkx as nx
from typing import List, Dict, Tuple, Any
import json

def prepare_vega_data(
    graph: nx.Graph,
    y_axis_attributes: List[str]
) -> Dict[str, List[Dict[str, Any]]]:
    """
    Prepare the graph data in a format suitable for Vega plotting.

    Every graph node is emitted once per entry of `y_axis_attributes`, so a
    node appears in each sub-graph with its own y position.

    Parameters:
    graph (nx.Graph): The input NetworkX graph. Node labels must be strings of
        the form 'S_x_I_y_R_z_a', where the last component is the x position.
    y_axis_attributes (List[str]): Label components (e.g. 'S', 'I', 'R') whose
        following number is used as the y position.

    Returns:
    Dict[str, List[Dict[str, Any]]]: {"nodes": [...], "edges": [...]} where each
    node record has id, x, y, count, exemplar, graph (1-based index of the
    attribute), graph_name and label_id, and each edge record has source,
    target (original node labels) and weight.

    Raises:
    ValueError: If a node label lacks a requested component or a numeric part
        cannot be parsed as an integer.
    """
    def extract_position(node_id: str, coord: str) -> Tuple[int, int]:
        """
        Extract position coordinates from node ID based on a given coordinate.

        Parameters:
        node_id (str): The node ID in the format 'S_x_I_y_R_z_a'
        coord (str): The coordinate ('S', 'I', 'R') to extract the position.

        Returns:
        Tuple[int, int]: x (trailing component of the ID) and y (the number
        right after `coord`).
        """
        components = node_id.split('_')
        x_base = int(components[-1])
        y_base = int(components[components.index(coord) + 1])
        return x_base, y_base

    # Read nodes and edges straight from the graph. This avoids relabelling the
    # nodes to integers and does not depend on the key name that
    # nx.node_link_data uses for the edge list ("links" vs. "edges"), which
    # differs between NetworkX versions.
    nodes = []
    for node_id, attrs in graph.nodes(data=True):
        for attr_id, y_axis_attr in enumerate(y_axis_attributes):
            x, y = extract_position(node_id, y_axis_attr)
            nodes.append({
                "id": node_id,
                "x": x,
                "y": y,
                "count": attrs.get("count", 1),
                "exemplar": attrs.get("exemplar", False),
                "graph": attr_id + 1,
                "graph_name": y_axis_attr,
                "label_id": f"{y_axis_attr}_{y}"
            })

    edges = [
        {
            "source": source,
            "target": target,
            "weight": attrs.get("weight", 1)
        }
        for source, target, attrs in graph.edges(data=True)
    ]

    return {"nodes": nodes, "edges": edges}


# Label components plotted on the y axis; their order determines each record's 'graph' index (1-based).
y_axis_attributes = ["I", "S", "R"]
vega_data = prepare_vega_data(G, y_axis_attributes)
print("Data prepared:")
# Dumps every node and edge record, so the output is long for larger graphs.
print(json.dumps(vega_data, indent=2))


Data prepared:
{
  "nodes": [
    {
      "id": "S_11_I_1_R_1_0",
      "x": 0,
      "y": 1,
      "count": 1,
      "exemplar": false,
      "graph": 1,
      "graph_name": "I",
      "label_id": "I_1"
    },
    {
      "id": "S_11_I_1_R_1_0",
      "x": 0,
      "y": 11,
      "count": 1,
      "exemplar": false,
      "graph": 2,
      "graph_name": "S",
      "label_id": "S_11"
    },
    {
      "id": "S_11_I_1_R_1_0",
      "x": 0,
      "y": 1,
      "count": 1,
      "exemplar": false,
      "graph": 3,
      "graph_name": "R",
      "label_id": "R_1"
    },
    {
      "id": "S_11_I_1_R_1_1",
      "x": 1,
      "y": 1,
      "count": 1,
      "exemplar": false,
      "graph": 1,
      "graph_name": "I",
      "label_id": "I_1"
    },
    {
      "id": "S_11_I_1_R_1_1",
      "x": 1,
      "y": 11,
      "count": 1,
      "exemplar": false,
      "graph": 2,
      "graph_name": "S",
      "label_id": "S_11"
    },
    {
      "id": "S_11_I_1_R_1_1",
      "x": 1,
      "y": 

In [72]:
def update_vega_schema(
    graph: "nx.Graph | dict",
    vega_json_path: str = "./bins_sententree.vg.json",
    y_axis_attributes=None,
) -> dict:
    """
    Load a Vega schema and replace its node and edge data.

    Parameters:
    graph (dict or nx.Graph): Either the already prepared data (a dict with
        'nodes' and 'edges' lists, as returned by prepare_vega_data) or a
        NetworkX graph, which is converted with prepare_vega_data.
    vega_json_path (str): Path to the Vega JSON schema file.
    y_axis_attributes (List[str], optional): Attributes used for the y-axis;
        required only when `graph` is a NetworkX graph.

    Returns:
    dict: The schema with the 'values' of its 'nodes' and 'edges' data sets replaced.

    Raises:
    ValueError: If a graph is given without `y_axis_attributes`.
    """
    # Use the data handed in by the caller rather than a notebook-level global,
    # so the result depends only on the arguments.
    if isinstance(graph, dict):
        plot_data = graph
    else:
        if y_axis_attributes is None:
            raise ValueError(
                "y_axis_attributes is required when a graph (rather than prepared data) is passed"
            )
        plot_data = prepare_vega_data(graph, y_axis_attributes)

    # Load the existing Vega schema
    with open(vega_json_path, "r") as f:
        schema = json.load(f)

    # Update the schema with nodes and edges data
    schema["data"] = vega.replace_named_with(schema["data"], "nodes", ["values"], plot_data['nodes'])
    schema["data"] = vega.replace_named_with(schema["data"], "edges", ["values"], plot_data['edges'])

    return schema

# Note: the prepared node/edge dict (`vega_data`), not the graph `G`, is passed here.
schema = update_vega_schema(vega_data)

# Render the schema to a PNG (scale=4) and save it next to the notebook.
with open("bin_example.png", "wb") as f:
    png = vl_convert.vega_to_png(schema, scale=4)
    f.write(png)

# Show the chart inline in the notebook.
vega.display(schema, format="interactive")




