# Compute Mode Connectivity Graph & Merge Tree

Notes

- The full pipeline includes:
    - train PINN models using 100 different random seeds
    - setup MC runs using `setup_MC_runs.ipynb` (this creates a file used by `train_eval_pinn_multi.py`)
    - submit MC runs using `train_eval_pinn_multi.py`
    - process MC runs and construct graphs using this notebook

<br>

- Results are stored in the following locations:
    - individual checkpoints: 
    ```
    /global/cfs/cdirs/m636/geniesse/projects/characterizing-pinns-failure-modes/pbc_examples/checkpoints
    ```
    - curve checkpoints:
    ```
    /global/cfs/cdirs/m636/geniesse/projects/dnn-mode-connectivity/checkpoints_global
    ```

<br>

- Requires installing the following packages:

    - https://github.com/mrzv/nesoi
    - https://ripser.scikit-tda.org/en/latest/

In [2]:
# !pip install git+https://github.com/mrzv/nesoi.git 

In [3]:
# !pip install Ripser

# Setup

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os 
import numpy as np
import pandas as pd
import networkx as nx
import scipy 

import matplotlib.pyplot as plt
import seaborn as sns

# Compute pairwise-connected graph

1. Process model pair information
2. Construct the graph and add a mode-connectivity (curve) node along each edge

## Process model pair information


Here is an example checkpoint folder for a single curve. Note that the folder name includes the names of the two models connected by the curve.

In [None]:
!ls checkpoints_global/PINN_convection_beta_1.0_lr_1.0_seed_001_PINN_convection_beta_1.0_lr_1.0_seed_002

Here is a quick overview of the options that can be changed below:
- `beta = {1.0, 50.0}    # PINN wave speed coefficient`
- `eval_epoch = {0, 50}  # how long the curve was trained for (0: linear MC, 50: nonlinear MC)`
- `keep_nodes = [2…100]  # how many models to include in the graph` 


In [None]:
# configure options here
# NOTE: these three values select the pairs file and the curve checkpoints read
# below, and all three are embedded in the output file names written at the end.
beta = 50.0      # PINN wave speed coefficient
eval_epoch = 0   # how long the curve was trained for (0: linear MC, 50: nonlinear MC)
keep_nodes = 5  # how many models to include in the graph (max 100); pairs are kept only if both seeds are at most this value

In [None]:
# read the table of model pairs; every row names the two endpoint checkpoints
# (columns `init_start` and `init_end`) of one mode-connectivity curve
pairs_file = f"PINN_convection_beta_{beta}_lr_1.0_n_seeds_100_pairs.csv"
df_pairs = pd.read_csv(pairs_file)


def _seed_of(checkpoint_name):
    """Return the integer found after the last "_seed_" and before ".pt"."""
    return int(checkpoint_name.split("_seed_")[-1].split(".pt")[0])


# curve name: both endpoint names joined by "_", with every ".pt" stripped
curve_names = [
    f"{start}_{end}".replace(".pt","")
    for start, end in zip(df_pairs.init_start, df_pairs.init_end)
]
df_pairs["curve_name"] = curve_names

# curve metrics file (.npz) evaluated at `eval_epoch`
df_pairs["result_file"] = [
    f"checkpoints_global/{name}/checkpoint-{eval_epoch}_curve.npz" for name in curve_names
]

# curve checkpoint file (.pt) saved at `eval_epoch`
df_pairs["checkpoint_file"] = [
    f"checkpoints_global/{name}/checkpoint-{eval_epoch}.pt" for name in curve_names
]

# integer seed of each endpoint, used as the graph node id
df_pairs["source"] = [_seed_of(name) for name in df_pairs.init_start]
df_pairs["target"] = [_seed_of(name) for name in df_pairs.init_end]

# (source, target) tuple per row
df_pairs["edge"] = list(zip(df_pairs.source.values, df_pairs.target.values))

# keep only the pairs whose two seeds are both at most `keep_nodes`
within_limit = df_pairs.source.le(keep_nodes) & df_pairs.target.le(keep_nodes)
df_pairs = df_pairs[within_limit].reset_index(drop=True)


# show df
df_pairs

## Construct graph and add mode connected point between each edge


In [None]:
# G1: plain pairwise graph with one edge per row of df_pairs
edges = df_pairs.edge.to_list()
G1 = nx.Graph(edges)

# G2: same pairs, but every edge is routed through an extra "curve" node
G2 = nx.Graph()

# curve-node ids continue after the model nodes; G1 gains no nodes inside the
# loop, so its size is read once here
# NOTE(review): this numbering assumes the model ids are exactly 1..len(G1)
n_models = len(G1)

for index, df_pair in df_pairs.iterrows():

    # per-point error along the curve for this pair
    # (alternatives left by the author: result['tr_loss'], float(result['mc_metric']))
    with np.load(df_pair.result_file) as result:
        curve_losses = result['tr_error_u_rel']

    source_node, target_node = df_pair.source, df_pair.target
    curve_node = n_models + index + 1

    # the curve endpoints give the losses of the two models themselves
    first_loss, last_loss = curve_losses[0], curve_losses[-1]
    # the curve node carries the entry with the largest absolute value
    peak_loss = curve_losses[np.abs(curve_losses).argmax()]

    # update G1
    G1.nodes[source_node]['loss'] = first_loss
    G1.nodes[target_node]['loss'] = last_loss

    # update G2: node data first, then the two-hop path through the curve node
    for node, value in (
        (source_node, first_loss),
        (curve_node, peak_loss),
        (target_node, last_loss),
    ):
        G2.add_node(node, loss=value)
    nx.add_path(G2, [source_node, curve_node, target_node])

    print(f"[+] ({source_node}, {curve_node:2d}, {target_node}) => ({G2.nodes[source_node]['loss']:.6f}, {G2.nodes[curve_node]['loss']:.6f}, {G2.nodes[target_node]['loss']:.6f})")

# report graph sizes
print(f"{G1.number_of_nodes()=}")
print(f"{G1.number_of_edges()=}")
print(f"{G2.number_of_nodes()=}")
print(f"{G2.number_of_edges()=}")

## Convert graphs to format for `nesoi`

In [None]:
# endpoints of every G2 edge (the attribute dict of each edge is dropped),
# shifted down by one for nesoi
# NOTE(review): the shift assumes G2's node ids are 1-based and contiguous
endpoint_pairs = [(u, v) for u, v, _attrs in nx.to_edgelist(G2, nodelist=None)]
edges = np.array(endpoint_pairs) - 1

# loss of every G2 node, listed in ascending node-id order
node_losses = []
for node in sorted(G2.nodes):
    node_losses.append(G2.nodes[node]['loss'])
loss = np.array(node_losses)

# save things (optional)
# save_as = f"{pairs_file.replace('.csv','')}_eval_epoch_{eval_epoch}_keep_nodes_{keep_nodes}.npz"
# np.savez(save_as, edges=edges, loss=loss)
# print(save_as)

# Compute Merge Tree (using `nesoi`)

In [None]:
import sys
import numpy as np
import nesoi

In [None]:
# rebuild the graph on the shifted (0-based) ids: one node per entry of `loss`,
# one edge per row of `edges`
# (author's open question: not sure this is necessary... can we just use G2 from above?)
n_vertices = len(loss)
G = nx.Graph()
G.add_nodes_from(range(n_vertices))
G.add_edges_from(edges)

In [None]:
# Build the merge tree with nesoi: one vertex per entry of `loss`.
# NOTE(review): the meaning of the second argument (False) is not visible here —
# presumably it selects the sweep direction; confirm against the nesoi docs.
tree = nesoi.TMT_float(len(loss), False)

# register vertex i with its scalar value
for i,v in enumerate(loss):
    tree.add(i,v)

# feed every graph edge to the tree; `edges` was already shifted to 0-based ids
# when it was built, which is why the "-1" variant below is disabled
for e in edges:
    # tree.merge(e[0]-1, e[1]-1)
    tree.merge(e[0], e[1])

# NOTE(review): presumably finalizes the tree after all merges — confirm in nesoi
tree.repair()

# Output persistence diagram
# each item is a triple (u, s, v); the loss at u and at s is printed next to it
for (u,s,v) in tree.traverse_persistence():
    print(u,s,v, loss[u], loss[s])


## Compute full tree (includes degree-2 nodes)

In [None]:
# Generate conventional tree
from collections import defaultdict

# paths[k] collects the `s` entries attached to vertex k.
# tree[u] yields a pair (s, v) — presumably the saddle and the branch that u
# merges into; confirm against nesoi. Vertices with v == u are skipped.
paths = defaultdict(set)
for u in range(len(loss)):
    s, v = tree[u]
    if u == v:
        continue
    paths[v].add(s)
    if u != s:
        paths[u].add(s)

# turn every set into a list ordered by increasing loss
for key in list(paths):
    paths[key] = sorted(paths[key], key=lambda x: loss[x])

# chain each key to its lowest entry, then link consecutive entries
path_edges = []
for start, chain in paths.items():
    hops = [(start, chain[0]), *zip(chain, chain[1:])]
    for a, b in hops:
        print(a, b)
    path_edges.extend(hops)


# assemble the tree and drop vertices that ended up without any neighbor
T_full = nx.Graph()
T_full.add_nodes_from(range(len(loss)))
T_full.add_edges_from(path_edges)

isolated = [n for n in T_full if not list(T_full.neighbors(n))]
T_full.remove_nodes_from(isolated)

# node id -> loss, for the nodes that remain
height_full = {n: loss[n] for n in T_full}

## Compute condensed tree (no degree-2 nodes)

In [None]:
# Generate conventional tree
from collections import defaultdict

# paths[k] collects the `s` entry of every persistence triple (u, s, v) in which
# k appears as u or v; triples with u == v are skipped
paths = defaultdict(list)
for u, s, v in tree.traverse_persistence():
    if u == v:
        continue
    paths[v].append(s)
    paths[u].append(s)

# order every list by increasing loss (in place)
for chain in paths.values():
    chain.sort(key=lambda x: loss[x])

# chain each key to its lowest entry, then link consecutive entries
path_edges = []
for start, chain in paths.items():
    hops = [(start, chain[0]), *zip(chain, chain[1:])]
    for a, b in hops:
        print(a, b)
    path_edges.extend(hops)


import networkx as nx

# assemble the tree and drop vertices that ended up without any neighbor
T = nx.Graph()
T.add_nodes_from(range(len(loss)))
T.add_edges_from(path_edges)
T.remove_nodes_from([n for n in T if not list(T.neighbors(n))])


### DEBUGGING
# connect max in the tree with max in full tree
# T.add_edges_from([(13,11)])
# NOTE(review): if both maxima are the same node this adds a self-loop to T — confirm intended
nodes = sorted(T.nodes, key=lambda x: loss[x])
max_node = nodes[-1]

full_nodes = sorted(T_full.nodes, key=lambda x: loss[x])
max_node_full = full_nodes[-1]

print()
print(max_node, max_node_full)
T.add_edge(max_node, max_node_full)

# node id -> loss, computed after the extra edge so a newly added node is included
height = {n: loss[n] for n in T}
# height = {n: loss[n] for n in T if len(T[n])}
# height = {n: loss[n] for n in T if len(list(T.neighbors(n)))}

## Save intermediates (optional)

In [None]:
# import pickle

# with open("merge_tree.pkl", "wb") as f:
#     pickle.dump(T, f)
    
# with open("merge_tree_full.pkl", "wb") as f:
#     pickle.dump(T_full, f)
    
# with open("height.pkl", "wb") as f:
#     pickle.dump(height, f) 
    
# with open("height_full.pkl", "wb") as f:
#     pickle.dump(height_full, f)

## Draw the trees

In [None]:
# NOTE(review): seeds NumPy's global RNG before drawing — presumably for a
# reproducible layout; confirm the layout actually consumes random numbers.
np.random.seed(1)
# color nodes by loss: height_full was built by iterating T_full, so its values follow T_full's node order
nx.draw_kamada_kawai(T_full, node_color = list(height_full.values()), with_labels=True)

In [None]:
# NOTE(review): seeds NumPy's global RNG before drawing — presumably for a
# reproducible layout; confirm the layout actually consumes random numbers.
np.random.seed(1)
# color nodes by loss: height was built by iterating T, so its values follow T's node order
nx.draw_kamada_kawai(T, node_color = list(height.values()), with_labels=True)

## Draw the trees (using DMT_tools)

In [None]:
# borrowed from https://github.com/trneedham/Decorated-Merge-Trees
from DMT_tools import mergeTree_pos
def draw_merge_tree(G, height, axes=False, ax=None, **kwargs):
    """Draw merge tree ``G`` at the positions returned by ``mergeTree_pos(G, height)``.

    Parameters
    ----------
    G : graph accepted by ``nx.draw_networkx``
    height : passed unchanged to ``mergeTree_pos`` together with ``G``
    axes : when truthy, show ticks and tick labels on the left side only
    ax : Matplotlib axes to draw on; a new figure and axes are created when None
    **kwargs : forwarded to ``nx.draw_networkx`` (node labels are always on)

    Returns
    -------
    None
    """
    layout = mergeTree_pos(G, height)

    # fall back to a fresh figure only when the caller did not supply axes
    if ax is None:
        _, ax = plt.subplots()

    nx.draw_networkx(G, pos=layout, ax=ax, with_labels=True, **kwargs)

    if not axes:
        return
    ax.tick_params(left=True, bottom=False, labelleft=True, labelbottom=False)


# overlay both trees on one shared axes
fig, ax = plt.subplots(figsize=(20,10))
# full tree first, in red with the smaller nodes
draw_merge_tree(T_full, height_full, ax=ax, node_size=500, width=10, node_color="indianred", edge_color="indianred")
# condensed tree drawn second on the same axes, in blue with larger nodes
draw_merge_tree(T, height, ax=ax, node_size=1000, width=10, node_color="cadetblue", edge_color="cadetblue")

# Save Merge Tree files (for 1D profiles)

The minimum format:

    Merge Tree Edges
    "SegmentationId","upNodeId","downNodeId"

    Merge Tree Nodes
    "NodeId","Scalar","CriticalType"

    Merge Tree Segmentations
    "Loss","SegmentationId"

## Helper functions

In [None]:
def get_critical_type(T, n):
    """Classify node ``n`` of tree ``T`` as a critical-point code.

    Codes: 0 = minimum (one neighbor), 1 = saddle (three neighbors),
    3 = root (``loss[n]`` equals the maximum of the module-level ``loss``),
    -1 = anything else. The root test takes precedence over the neighbor count.
    """
    # the number of neighbors stands in for the critical-point type
    degree = sum(1 for _ in T.neighbors(n))

    if max(loss) == loss[n]:
        return 3

    code_by_degree = {1: 0, 3: 1}
    return code_by_degree.get(degree, -1)

sys.setrecursionlimit(5050)
def find_nearest_critical_point(G, n, verbose=0):
    """Walk from ``n`` through degree-2 nodes until a node of another degree is hit.

    At every node with exactly two neighbors the walk steps to the neighbor
    with the smaller ``loss`` (module-level sequence indexed by node id); on a
    tie the neighbor that ``G.neighbors`` yields first is taken.

    The walk is iterative, so long degree-2 chains do not depend on the
    interpreter's recursion limit, and ``verbose`` applies to every step.

    Parameters
    ----------
    G : object exposing ``neighbors(node)``
    n : starting node
    verbose : 0 = silent, 1 = print each visited node, 2 = also print the
        neighbor list before and after sorting

    Returns
    -------
    The first node on the walk whose neighbor count is not 2 (``n`` itself if
    it already qualifies).

    Raises
    ------
    ValueError
        If the walk comes back to a degree-2 node it already left, which
        means following the lower-loss neighbor can never terminate.
    """
    visited = set()

    while True:
        if verbose > 0:
            print(f"{n=}")

        # compute neighbors
        nbrs = list(G.neighbors(n))

        # return if critical point
        if len(nbrs) != 2:
            if verbose > 0:
                print(f"Found critical point!!! ({n=})")
            return n

        # a revisit means the walk is cycling; fail loudly instead of spinning
        if n in visited:
            raise ValueError(
                f"no critical point reachable: walk revisited node {n!r}"
            )
        visited.add(n)

        # order the two neighbors by loss and continue with the lower one
        if verbose > 1:
            print(n, nbrs)
        nbrs.sort(key=lambda x: loss[x])
        if verbose > 1:
            print(n, nbrs)

        nbr = nbrs[0]
        if verbose > 0:
            print(f"\t{nbr=}")
        n = nbr
        

## Initialize data frames for saving

In [None]:
# Empty frames with the column layout of the three output files.
# Each of them is re-created from scratch in the cell that fills it in.
df_mt_edges = pd.DataFrame(columns=["SegmentationId","upNodeId","downNodeId"])  # one row per merge-tree edge
df_mt_nodes = pd.DataFrame(columns=["NodeId","Scalar","CriticalType"])  # one row per merge-tree node
df_mt_seg = pd.DataFrame(columns=["Loss","SegmentationId"])  # one row per entry of `loss`

## Sort edges in the merge tree by loss

In [None]:
# every edge of T as a mutable [a, b] pair
T_edges_sorted = list(map(list, T.edges()))

# reorder each pair in place so the lower-loss endpoint comes first,
# printing the pair before and after, followed by a blank line
for T_edge in T_edges_sorted:
    print(T_edge)
    T_edge.sort(key=lambda node: loss[node])
    print(T_edge, end="\n\n")

## Store merge tree edge information

In [None]:
# one row per edge of T: the row position is the segmentation id; each pair in
# T_edges_sorted is ordered [lower-loss node, higher-loss node]
df_mt_edges = pd.DataFrame(columns=["SegmentationId","upNodeId","downNodeId"]).assign(
    SegmentationId=list(range(len(T_edges_sorted))),
    upNodeId=[pair[1] for pair in T_edges_sorted],
    downNodeId=[pair[0] for pair in T_edges_sorted],
)
df_mt_edges

## Store merge tree node information

In [None]:
# one row per node that appears in an edge of T; the later columns are derived
# from the NodeId column assigned just before them
df_mt_nodes = pd.DataFrame(columns=["NodeId","Scalar","CriticalType"]).assign(
    NodeId=list(set(np.ravel(T.edges()))),
    Scalar=lambda df: [loss[node] for node in df.NodeId],
    CriticalType=lambda df: [get_critical_type(T, node) for node in df.NodeId],
)
df_mt_nodes

## Store merge tree segmentation information

In [None]:
# one row per entry of `loss`; SegmentationId is filled in row by row below
df_mt_seg = pd.DataFrame(columns=["Loss","SegmentationId"])
df_mt_seg = df_mt_seg.assign(Loss = loss)

### find the nearest critical point in T_full
### ... map the critical point to segmentationId based on down node
for node_id in range(len(loss)):

    print(node_id)

    # walk T_full from this node to the nearest critical point
    down_node_id = find_nearest_critical_point(T_full, node_id)

    # find the segmentation id for the edge (first matching row wins)
    if down_node_id not in df_mt_edges.downNodeId.values:
        # this should only happen for root node: it is never the lower end of
        # an edge, so match it as an upper end instead.
        # seg_id is resolved BEFORE it is reported; formatting it first either
        # raised NameError (first iteration) or showed the previous node's id.
        seg_id = df_mt_edges[df_mt_edges.upNodeId.eq(down_node_id)].SegmentationId.values[0]
        print(f"Found possible root node ({node_id=}, {down_node_id=}) ... using SegmentationId based on upNodeId ({seg_id=})")
    else:
        seg_id = df_mt_edges[df_mt_edges.downNodeId.eq(down_node_id)].SegmentationId.values[0]

    # display results
    # print(f"{node_id=}, {down_node_id=}, {seg_id=}")

    # update segmentation
    df_mt_seg.at[node_id, 'SegmentationId'] = seg_id

# show df_mt_seg
df_mt_seg

## Re-compute the merge tree edge information (after using it for the segmentation)

a bit hacky... i know

In [None]:
### Re Number edges AFTER using it for segmentation
# rebuild the edge table from T_edges_sorted (pairs are [lower-loss, higher-loss])
df_mt_edges = pd.DataFrame(columns=["SegmentationId","upNodeId","downNodeId"]).assign(
    SegmentationId=list(range(len(T_edges_sorted))),
    upNodeId=[pair[1] for pair in T_edges_sorted],
    downNodeId=[pair[0] for pair in T_edges_sorted],
)

# sorted list of the distinct node ids that occur in any edge
T_nodes = list(np.sort(np.unique(np.ravel(T_edges_sorted))))
print(T_nodes)

# keep the original ids in the *DataId columns, then replace upNodeId and
# downNodeId by each id's position in T_nodes
df_mt_edges = df_mt_edges.assign(
    upNodeIdDataId=lambda df: df.upNodeId,
    downNodeIdDataId=lambda df: df.downNodeId,
    upNodeId=lambda df: df.upNodeIdDataId.apply(T_nodes.index),
    downNodeId=lambda df: df.downNodeIdDataId.apply(T_nodes.index),
)


df_mt_edges

## Save the merge tree files (consumed by the topological profile code)

In [None]:
# base name that encodes the three configuration options
file_name = f"PINN_convection_beta_{beta}_lr_1.0_n_seeds_100_pairs_eval_epoch_{eval_epoch}_keep_nodes_{keep_nodes}.npz"

# the three output names swap the ".npz" suffix for a table-specific one
mt_nodes_file, mt_edges_file, mt_seg_file = (
    file_name.replace(".npz", suffix)
    for suffix in ("_MergeTree.csv", "_MergeTree_edge.csv", "_MergeTree_segmentation.csv")
)

outputs = [
    (df_mt_nodes, mt_nodes_file),
    (df_mt_edges, mt_edges_file),
    (df_mt_seg, mt_seg_file),
]

# write all three tables without the index column, then report the paths
for frame, path in outputs:
    frame.to_csv(path, index=None)

for _frame, path in outputs:
    print(f"[+] {path}")