In [None]:
"""Notebook for visualising a transition and graph creation used for the MGK GPR model"""

In [None]:
import numpy as np
import pandas as pd
from graph_helpers import all_atoms_to_networkx
import matplotlib.pyplot as plt
import networkx as nx
import nglview
from scipy.spatial import distance

In [None]:
def get_atoms_view(atoms_before, atoms_after):

    atoms_before = atoms_before.copy()
    atoms_after = atoms_after.copy()

    atoms_before.set_positions((p := atoms_before.get_positions()) - p[0])
    atoms_after.set_positions((p := atoms_after.get_positions()) - p[0])

    n_before = len(atoms_before)

    positions_before = atoms_before.get_positions().copy()
    positions_after = atoms_after.get_positions().copy()

    D = distance.cdist(positions_before, positions_after)
    survived_pruning = np.any(D <= 1e-5, axis=1)
    assert survived_pruning.sum() == len(atoms_after)

    indices_full = np.arange(n_before)
    indices_survived = indices_full[survived_pruning]

    labelText = np.array(["  "] * indices_full.size)
    labelText[indices_survived] = np.arange(indices_survived.size).astype(str)

    i_S_bef, i_M_bef, i_E_bef = 0, n_before - 2, n_before - 1

    view = nglview.show_ase(atoms_before, default_representation=False)
    view.clear_representations()
    view.add_representation("ball+stick", selection=indices_full, opacity=0.3)
    view.add_representation("ball+stick", selection=indices_survived)
    view.add_representation("ball+stick", selection=[i_S_bef], color="#FF5722")
    view.add_representation("ball+stick", selection=[i_M_bef], color="#FF5722")
    view.add_representation("ball+stick", selection=[i_E_bef], color="#FF5722")

    labelText[i_S_bef] = "S"
    labelText[i_M_bef] = "M"
    labelText[i_E_bef] = "E"
    view.add_label(
        color="black", labelType="text", labelText=labelText, attachment="middle_center"
    )

    view.center(selection=[0])
    desired_orientation = [
        1.4801925992525733,
        -9.821012471758628,
        11.616216321538623,
        0,
        -14.694371725634804,
        2.094024828506622,
        3.6428336959793937,
        0,
        -3.932455618069119,
        -11.521393656032389,
        -9.239752094334555,
        0,
        -1.9299999475479126,
        2.639999896287918,
        1.8285000324249268,
        1,
    ]
    view.control.orient(desired_orientation)
    view.control.zoom(0.05)
    return view, labelText


config = dict(
    r_cut=3,  # maximum radius of neighborhood for edge consideration for each atom
    neigh_max=5,  # maximum number of connected edges allowed per atom
    HR_radius=5,  # maximum distance from start, middle or end positions allowed for atom still to be considered
    seed=13,  # random_state used to pick the single transition visualised below
)

# NOTE(review): unpickling can execute arbitrary code -- only load this file if
# it is trusted and produced locally.
atoms_all = pd.read_pickle("../data/atoms.pkl")

# Pick one random non-test transition whose origin is "traj". Column `d` is set
# from `d_empi`; presumably `d` is the target column that all_atoms_to_networkx
# reads -- TODO confirm. Use `&` (not `*`) to combine the boolean masks.
df = (
    atoms_all[(atoms_all["split"] != "test") & (atoms_all["origin"] == "traj")]
    .sample(n=1, random_state=config["seed"])
    .reset_index(drop=True)
    .assign(d=lambda frame: frame["d_empi"].copy())
)

collect_networks, atoms_pruned = all_atoms_to_networkx(
    df,
    r_cut=config["r_cut"],
    neigh_max=config["neigh_max"],
    HR_radius=config["HR_radius"],
    collect_pruned=True,
)

# Presumably atoms_pruned[0] is the (before pruning, after pruning) pair for the
# sampled transition -- verify against graph_helpers.
atoms_before = atoms_pruned[0][0]
atoms_after = atoms_pruned[0][1]
view, labels = get_atoms_view(atoms_before, atoms_after)
view

In [None]:
# Ask nglview to render the current scene of `view` as an image. As the last
# expression in the cell, its return value (if any) is shown as the cell output.
view.render_image()

In [None]:
# Export the structure view as structure_MGK.png with a trimmed, transparent
# background (factor=5 is passed straight to nglview's download_image).
view.download_image(
    filename="structure_MGK.png",
    transparent=True,
    trim=True,
    antialias=True,
    factor=5,
)

In [None]:
def plot_graph(G_to_draw, middle=False, name="atom_graph.png"):
    ws = np.array([attr["w"] for u, v, attr in G_to_draw.edges(data=True)])
    ws = (8 - 0.2) * ws + 0.2
    plt.figure(figsize=(6, 6))
    symbol_color_map = {
        1: "#fefefe",
        6: "#868686",
        7: "#4766ff",
        8: "#fb1c1c",
        16: "#f6f62f",
    }  #

    symbol_colors = [
        symbol_color_map[attr["atomic_number"]]
        for i, attr in G_to_draw.nodes(data=True)
    ]
    symbol_colors[0] = "#FF5722"
    symbol_colors[-1] = "#FF5722"
    if middle:
        symbol_colors[-2] = "#FF5722"

    n = len(G_to_draw)
    i_S_bef, i_M_bef, i_E_bef = 0, n - 2, n - 1
    labelText = {i: f"{i}" for i in range(n)}
    labelText[i_S_bef] = "S"
    labelText[i_M_bef] = "M"
    labelText[i_E_bef] = "E"

    nx.draw_networkx(
        G_to_draw,
        pos=nx.kamada_kawai_layout(G_to_draw),
        edgecolors="black",
        with_labels=True,
        labels=labelText,
        width=ws,
        node_size=1000,
        node_color=symbol_colors,
        linewidths=1,
    )

    plt.box(False)
    plt.savefig(name, dpi=400, bbox_inches="tight", transparent=True)
    plt.show()


# Draw the graph stored in the first row of collect_networks; presumably this is
# the only row, since a single transition was sampled. middle=True highlights
# the second-to-last node as the middle atom. The figure is saved under
# plot_graph's default name, atom_graph.png.
plot_graph(collect_networks.iloc[0].graph, middle=True)