In [None]:
# Reload imported modules (such as torch_autoneb and main) before every cell runs,
# so edits to the project code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from torch_autoneb import suggest, config, find_minimum, neb, to_simple_graph, visualise
import torch
from torch import optim
import main
import pickle
import numpy as np
import os
import networkx as nx
import matplotlib.pyplot as plt
import yaml
%matplotlib inline

In [None]:
# Replace this path with your own project directory (the directory holding the
# config.yaml and graph.p files loaded below). A leading "~" is expanded to the
# user's home directory; os.path.isdir would not do this on its own.
project_directory = os.path.expanduser("your_project_dir")
assert os.path.isdir(project_directory), f"Project directory was not found: {project_directory!r}"

In [None]:
# Load the run configuration. The "minimum" section is parsed into an
# OptimConfig and the "exploration" section into a LandscapeExplorationConfig;
# lex_config supplies the node/edge attribute keys used by the cells below.
# The encoding is pinned to UTF-8 so decoding does not depend on the platform's
# default text encoding (e.g. cp1252 on Windows).
with open(os.path.join(project_directory, "config.yaml"), "r", encoding="utf-8") as file:
    configuration = yaml.safe_load(file)

min_config = config.OptimConfig.from_dict(configuration["minimum"])
lex_config = config.LandscapeExplorationConfig.from_dict(configuration["exploration"])

In [None]:
# Load the pickled exploration graph and convert it with to_simple_graph,
# passing the configured weight key and the AutoNEB cycle count.
# NOTE(review): pickle.load can execute arbitrary code, so only open graph.p
# files you produced yourself or otherwise trust.
graph_path = os.path.join(project_directory, "graph.p")
with open(graph_path, "rb") as graph_file:
    graph = pickle.load(graph_file)

simple_graph = to_simple_graph(
    graph,
    lex_config.weight_key,
    lex_config.auto_neb_config.cycle_count,
)

# Displayed: (nodes in graph, edges in graph, edges in simple_graph)
(len(graph.nodes), len(graph.edges), len(simple_graph.edges))

## Connectivity Graph

In [None]:
# Draw the full connectivity graph. Nodes carry the minimum value
# (lex_config.value_key) and edges carry the saddle value (lex_config.weight_key).
visualise.draw_connectivity_graph(
    simple_graph,
    lex_config.value_key,
    lex_config.weight_key,
)

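## Minimum Spanning Tree

The tree keeps only the lowest-barrier connections. Between any two minima, the path through a minimum spanning tree (weighted by the saddle value `lex_config.weight_key`) has the smallest possible highest saddle among all paths in the connectivity graph.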

In [None]:
# Spanning tree of the connectivity graph with minimal total saddle weight.
# networkx returns a spanning forest instead if simple_graph is disconnected.
mst = nx.minimum_spanning_tree(simple_graph, weight=lex_config.weight_key)
visualise.draw_connectivity_graph(
    mst,
    lex_config.value_key,
    lex_config.weight_key,
)

## Evaluation

In [None]:
# Collect the minimum value of every MST node and the saddle value of every MST
# edge. torch.tensor with the default dtype builds the same floating-point
# tensor as the legacy torch.Tensor(list) constructor.
node_values = torch.tensor(
    [node_data[lex_config.value_key] for _, node_data in mst.nodes(data=True)],
    dtype=torch.get_default_dtype(),
)
saddle_values = torch.tensor(
    [edge_data[lex_config.weight_key] for _, _, edge_data in mst.edges(data=True)],
    dtype=torch.get_default_dtype(),
)

In [None]:
def _format_mean_std(values: torch.Tensor) -> str:
    """Format a tensor of values as "mean ± std" with four decimals.

    Tensor.std() is the Bessel-corrected sample standard deviation, which is
    undefined (nan) for fewer than two values, and the mean of an empty tensor
    is nan as well. Those cases are reported as "n/a" instead of "nan".
    """
    count = values.numel()
    if count == 0:
        return "n/a (no values)"
    mean_text = f"{values.mean().item():.4f}"
    std_text = f"{values.std().item():.4f}" if count > 1 else "n/a"
    return f"{mean_text} ± {std_text}"


print("Averages over minimum spanning tree:")
print(f"Minima:  {_format_mean_std(node_values)} ({lex_config.value_key})")
print(f"Saddles: {_format_mean_std(saddle_values)} ({lex_config.weight_key})")