In [1]:
import plotly.express as px
import pandas as pd

def plot_scatter(
    data,
    color_dict,
    color,
    width=600,
    height=600,
    **px_kwargs,
):
    """Scatter-plot 2-D or 3-D points with plotly express.

    Parameters
    ----------
    data : array-like with ``.ndim`` / ``.shape`` (e.g. a NumPy array)
        Point coordinates of shape ``(n_points, 2)`` or ``(n_points, 3)``.
        Column 0 becomes ``x``, column 1 ``y`` and, for 3-D data, column 2 ``z``.
    color_dict : dict (or any mapping with ``.copy()`` and item assignment)
        Extra per-point columns, e.g. ``{'rmsd': values}``. Every column is
        shown in the hover tooltip. Keys must not clash with the coordinate
        column names used for this ``data`` (``'x'``, ``'y'`` and, for 3-D, ``'z'``).
    color : str or None
        Name of the column used to color the markers.
    width, height : int, optional
        Figure size in pixels (default 600 x 600, as before).
    **px_kwargs
        Extra keyword arguments forwarded to ``px.scatter`` / ``px.scatter_3d``.
        They override the defaults built here (e.g. pass ``size=...`` so that
        ``size_max`` takes effect).

    Returns
    -------
    The figure returned by ``px.scatter`` (2-D data) or ``px.scatter_3d`` (3-D data).

    Raises
    ------
    ValueError
        If ``data`` is not 2-D with 2 or 3 columns, or if ``color_dict`` uses a
        key that would be overwritten by a coordinate column.
    """
    # Validate explicitly instead of `assert`, which is stripped under `python -O`,
    # and reject 1-D input with a clear message instead of an IndexError.
    if getattr(data, "ndim", None) != 2 or data.shape[1] not in (2, 3):
        raise ValueError(
            "data must have shape (n_points, 2) or (n_points, 3); "
            f"got {getattr(data, 'shape', None)}"
        )
    dim = data.shape[1]
    axes = ("x", "y", "z")[:dim]

    # Previously a user column named like a coordinate was silently replaced.
    clashing = set(color_dict).intersection(axes)
    if clashing:
        raise ValueError(
            f"color_dict keys {sorted(clashing)} clash with coordinate columns {list(axes)}"
        )

    # Copy so the caller's mapping is not mutated when coordinates are added.
    df_dict = color_dict.copy()
    for i, name in enumerate(axes):
        df_dict[name] = data[:, i]
    df = pd.DataFrame(df_dict)

    scatter_kwargs = dict(
        x="x",
        y="y",
        color=color,
        width=width,
        height=height,
        # size_max only has an effect when a `size` column is also supplied
        # (e.g. via px_kwargs); kept for backward compatibility.
        size_max=7,
        hover_data=list(df_dict.keys()),
    )
    if dim == 3:
        scatter_kwargs["z"] = "z"
    scatter_kwargs.update(px_kwargs)

    plot_fn = px.scatter_3d if dim == 3 else px.scatter
    return plot_fn(df, **scatter_kwargs)




## Load the training output and per-state RMSD values

In [11]:
import numpy as np
from synd.core import load_model

# Input files, relative to the notebook's working directory.
DATA_PATH = 'static_model/z.npy'
MODEL_PATH = 'ntl9_folding_synd/ntl9_folding.synd'
EXTRA_RMSD_PATH = 'ntl9_folding_synd/near_target_CA_rmsd.npy'
# Number of model states whose RMSD is fetched through the backmapper.
# NOTE(review): hard-coded in the original notebook; confirm it equals the model's state count.
N_STATES = 3152

# Load the data (one row per point to plot).
data = np.load(DATA_PATH)
# Load the model.
model = load_model(MODEL_PATH)

# Get the rmsd from the backmapper for all of the states. Concatenating the
# returned list directly (rather than wrapping it in np.array first) gives the
# same result and does not fail if the per-state outputs differ in length.
state_rmsd = np.concatenate(model.backmap(list(range(N_STATES))))
# A zero is appended for the rmsd of the native state, followed by the extra
# near-target data — same order as before, in a single concatenation.
native_rmsd = np.zeros(1)
extra_rmsd = np.load(EXTRA_RMSD_PATH)
rmsd = np.concatenate((state_rmsd, native_rmsd, extra_rmsd))

# plot_scatter builds one DataFrame from `data` and `rmsd`, so they must line up row-for-row.
assert len(rmsd) == data.shape[0], (
    f"rmsd has {len(rmsd)} values but data has {data.shape[0]} rows"
)

## Plot the data colored by RMSD

In [12]:
# Scatter the loaded points, using the single 'rmsd' column as the marker color.
plot_scatter(data, color_dict={'rmsd': rmsd}, color='rmsd')