In [43]:
import os
import json
import sys
import glob
os.environ['NUMEXPR_MAX_THREADS'] = '32'
lib_path = os.path.abspath(os.path.join('..'))
sys.path.append(lib_path)
from report_generator.traj_loading import (
    load_sim_trajs,
    load_native_trajs,
    AndyTraj,
    NativeTraj,
)
from report_generator.tica_plots import make_tica_model, calc_bond_lens
from report_generator.simulate_loss import load_model, get_losses
import deeptime
import numpy as np
from matplotlib.figure import Figure
from matplotlib import pyplot as plt
import matplotlib
from report_generator.kullback_leibler_divergence import kullback_leibler_divergence_1d
import scipy
import plotly.graph_objects as go
import ipywidgets as widgets
from plotly.subplots import make_subplots

import nglview as nv  # For molecular visualization
import mdtraj         # For trajectory handling
import pickle
import hashlib
# Trained NN model directory; it also holds prior_params.json.
model_path = "/media/DATA_18_TB_1/daniel_s/cgschnet/Majewski_prior_v1_2024.09.12/model_single_chain_subsetD_Majewski_v1_2024.09.12__wd0_explr1en3_0.85_bs4"
# Trajectories simulated with the AI model, 3 replicas per protein.
# Outer list: one entry per protein; inner list: that protein's .h5 replica files.
simulated_traj_paths = [['/tmp/benchmark_sims_1010/2024-10-19T20:13:21.061446/BBA_0.h5', '/tmp/benchmark_sims_1010/2024-10-19T20:13:21.061446/BBA_1.h5', '/tmp/benchmark_sims_1010/2024-10-19T20:13:21.061446/BBA_2.h5']]
frame_stride = 100  # NOTE(review): not used in this cell
LOSS_STRIDE = 100  # stride used when computing losses and subsampling native bond lengths
# Prior parameters saved alongside the model.
# NOTE(review): the original comment here was "harmonic motion??" — presumably
# these parameterize harmonic prior terms; confirm against the training code.
with open(os.path.join(model_path, "prior_params.json"), "r") as prior_file:
    prior_params = json.load(prior_file)

# Native reference trajectories: loaded from the C database and run through the MD simulator.
native_trajectory_paths = sorted(glob.glob("/media/DATA_18_TB_1/andy/benchmark_set_small/*"))
simulated_trajs_per_protein = load_sim_trajs(simulated_traj_paths)

# One list of NativeTraj per native protein directory, in sorted-path order.
native_trajs_per_protein: list[list[NativeTraj]] = [
    load_native_trajs(path, prior_params) for path in native_trajectory_paths
]

def get_cache_filename(native_trajs_path):
    """Return a deterministic cache file name for ``native_trajs_path``.

    The name embeds the MD5 hex digest of ``str(native_trajs_path)``, so the
    same path always maps to the same file name. Only the path string feeds
    the key: if the data behind the path changes, the name does not.
    """
    key_bytes = str(native_trajs_path).encode()
    digest = hashlib.md5(key_bytes).hexdigest()
    return "tica_model_cache_{}.pkl".format(digest)

def load_or_compute_tica_model(native_trajs, native_trajs_path, cache_dir="./cache"):
    """Return ``(tica_model, native_bond_lens)``, using an on-disk pickle cache.

    The cache file name comes from ``get_cache_filename(native_trajs_path)``.
    If that file exists in ``cache_dir`` and can be unpickled, its contents
    are returned. Otherwise ``make_tica_model(native_trajs)`` is called and
    its result is pickled there.

    Args:
        native_trajs: native trajectories passed to ``make_tica_model``.
        native_trajs_path: used only to derive the cache key. The key ignores
            the trajectory contents, so delete the cache when the data changes.
        cache_dir: directory holding cache files; it is created if missing.

    Returns:
        The ``(tica_model, native_bond_lens)`` tuple.
    """
    # exist_ok avoids the check-then-create race of exists()+makedirs().
    os.makedirs(cache_dir, exist_ok=True)
    cache_filename = os.path.join(cache_dir, get_cache_filename(native_trajs_path))

    if os.path.exists(cache_filename):
        print("Loading TICA model from cache...")
        # Security: unpickling can execute arbitrary code; only use a trusted cache dir.
        try:
            with open(cache_filename, 'rb') as f:
                return pickle.load(f)
        except (EOFError, pickle.UnpicklingError) as err:
            # A truncated/corrupt cache (e.g. from an interrupted earlier
            # write) must not break every future run: rebuild it instead.
            print(f"Cache file {cache_filename} is unreadable ({err}); recomputing...")

    print("Computing TICA model...")
    tica_model, native_bond_lens = make_tica_model(native_trajs)

    # Write to a temporary file first, then atomically move it into place, so
    # an interrupted dump never leaves a partial cache file behind.
    tmp_filename = cache_filename + ".tmp"
    with open(tmp_filename, 'wb') as f:
        pickle.dump((tica_model, native_bond_lens), f)
    os.replace(tmp_filename, cache_filename)

    return tica_model, native_bond_lens
    
def create_interactive_tica_plot(tica_model, sim_trajs, native_bond_lens, losses):
    """Build a TICA figure linked to an NGLView viewer of the first simulated trajectory.

    Each simulated trajectory is featurised with ``calc_bond_lens`` and
    projected with ``tica_model.transform``. The native bond lengths are
    projected the same way. Both are plotted by ``make_figure_interactive``.
    Clicking a point in the simulated scatter trace (``data[1]``) moves the
    viewer to the matching frame when that point belongs to ``sim_trajs[0]``.

    Args:
        tica_model: fitted model exposing ``transform``.
        sim_trajs: simulated trajectory objects with a ``.trajectory``
            attribute; only ``sim_trajs[0]`` is loaded into the viewer.
        native_bond_lens: per-trajectory native bond-length arrays.
        losses: per-trajectory loss arrays; concatenated, flattened and
            forwarded to ``make_figure_interactive``.

    Returns:
        ``widgets.VBox`` holding a row (figure, viewer) and an ``Output``
        widget that collects click messages.
    """
    sim_datas = [calc_bond_lens(sim_traj.trajectory) for sim_traj in sim_trajs]
    sim_projected_datas = [tica_model.transform(sim_data) for sim_data in sim_datas]
    native_projected_datas = [tica_model.transform(x) for x in native_bond_lens]

    fig, divergence = make_figure_interactive(
        sim_projected_datas, native_projected_datas, np.concatenate(losses).flatten()
    )

    # The scatter trace concatenates all simulated trajectories, so a clicked
    # point index is global. These cumulative end offsets map it back to a
    # (trajectory, local frame) pair. Assumes transform() yields one row per
    # trajectory frame — TODO confirm against calc_bond_lens.
    traj_ends = np.cumsum([len(data) for data in sim_projected_datas])

    # NGLView widget for protein visualization; only the first trajectory is loaded.
    view = nv.NGLWidget()
    view.layout.width = '1600px'
    view.layout.height = '800px'
    view.add_trajectory(sim_trajs[0].trajectory)
    view.frame = 0

    fig_widget = go.FigureWidget(fig)
    # The callback output widget is owned here; the original relied on an
    # undefined `out` (a notebook global at best, NameError otherwise).
    out = widgets.Output()

    def update_protein_view(trace, points, selector):
        if not points.point_inds:
            return
        global_index = points.point_inds[0]
        traj_index = int(np.searchsorted(traj_ends, global_index, side='right'))
        start = int(traj_ends[traj_index - 1]) if traj_index > 0 else 0
        local_frame = global_index - start
        with out:
            if traj_index == 0:
                view.frame = local_frame
                print(f"Displaying frame {local_frame}")
            else:
                # Only sim_trajs[0] is in the viewer; don't jump to a wrong frame.
                print(f"Frame {local_frame} of trajectory {traj_index} is not loaded in the viewer")

    # data[1] is the simulated-data scatter trace built by make_figure_interactive.
    fig_widget.data[1].on_click(update_protein_view)

    # Combine TICA plot and protein view.
    layout = widgets.HBox([fig_widget, view])

    return widgets.VBox([layout, out])

def make_figure_interactive(sim_projected_datas, native_projected_data_trajs, losses):
    """Build the TICA comparison figure and the 1D KL divergence on component 0.

    The figure draws a 2D Gaussian-KDE heatmap of the native projection
    (first two TICA components). The simulated projections are overlaid as a
    scatter trace, which is ``fig.data[1]``; callers attach click handlers to
    it. 1D KDEs of component 0 for both data sets are added on a secondary
    y-axis.

    Args:
        sim_projected_datas: list of simulated projections, one per
            trajectory; each is 2D with at least two columns (rows are samples).
        native_projected_data_trajs: list of native projections, same layout.
        losses: currently unused; kept for interface compatibility.

    Returns:
        ``(fig, divergence)``, where ``fig`` is a ``go.FigureWidget`` and
        ``divergence`` is ``kullback_leibler_divergence_1d`` of simulated vs
        native component 0.
    """
    import numpy as np
    # `import scipy` alone does not guarantee the stats submodule is loaded.
    import scipy.stats
    from plotly.subplots import make_subplots
    import plotly.graph_objects as go

    # Concatenate per-trajectory data along the sample axis.
    native_projected_data = np.concatenate(native_projected_data_trajs)
    sim_projected_data = np.concatenate(sim_projected_datas)

    divergence = kullback_leibler_divergence_1d(
        sim_projected_data[:, 0],
        native_projected_data[:, 0]
    )

    # Extract the first two TICA components.
    x_sim = sim_projected_data[:, 0]
    y_sim = sim_projected_data[:, 1]
    x_native = native_projected_data[:, 0]
    y_native = native_projected_data[:, 1]

    # 2D KDE of the native data on a 100x100 grid spanning the native range.
    native_2d_kernel = scipy.stats.gaussian_kde(native_projected_data[:, :2].T)
    xmin, xmax = x_native.min(), x_native.max()
    ymin, ymax = y_native.min(), y_native.max()
    X, Y = np.mgrid[xmin:xmax:100j, ymin:ymax:100j]
    positions = np.vstack([X.ravel(), Y.ravel()])
    Z = np.reshape(native_2d_kernel(positions).T, X.shape)

    fig = go.FigureWidget(make_subplots(specs=[[{"secondary_y": True}]]))

    # Z is indexed [x, y] (mgrid order); Heatmap expects z[row=y][col=x], hence Z.T.
    heatmap = go.Heatmap(
        x=np.linspace(xmin, xmax, num=100),
        y=np.linspace(ymin, ymax, num=100),
        z=Z.T,
        colorscale='Viridis',
        showscale=False,
        name='Native 2D KDE'
    )
    fig.add_trace(heatmap)

    # Hover labels identify the source trajectory and its local frame index,
    # rather than the index into the concatenated array.
    hover_text = [
        f"Traj {traj_idx} frame {frame_idx}"
        for traj_idx, data in enumerate(sim_projected_datas)
        for frame_idx in range(len(data))
    ]
    scatter = go.Scatter(
        x=x_sim,
        y=y_sim,
        mode='markers',
        marker=dict(color='black', size=5),
        name='Simulated Data',
        text=hover_text
    )
    fig.add_trace(scatter)

    # 1D KDEs of component 0 over the union of both x ranges.
    x_range = np.linspace(
        min(x_sim.min(), x_native.min()),
        max(x_sim.max(), x_native.max()),
        100
    )
    sim_kernel = scipy.stats.gaussian_kde(x_sim)
    native_kernel = scipy.stats.gaussian_kde(x_native)
    y_sim_kde = sim_kernel(x_range)
    y_native_kde = native_kernel(x_range)

    sim_kde_line = go.Scatter(
        x=x_range,
        y=y_sim_kde,
        mode='lines',
        line=dict(color='red'),
        name='Simulated KDE',
        yaxis='y2'
    )
    native_kde_line = go.Scatter(
        x=x_range,
        y=y_native_kde,
        mode='lines',
        line=dict(color='blue'),
        name='Native KDE',
        yaxis='y2'
    )
    fig.add_trace(sim_kde_line)
    fig.add_trace(native_kde_line)

    fig.update_layout(
        autosize=False,
        width=1200,
        height=800,
        title='TICA',
        xaxis_title='TICA 0th component',
        yaxis_title='TICA 1st component',
        yaxis2=dict(
            title='Probability',
            overlaying='y',
            side='right',
            position=1.0
        ),
        hovermode='closest'
    )

    # Both primary axes span the union of simulated and native data. Using
    # only the simulated y range (as before) cropped the native heatmap.
    fig.update_xaxes(range=[x_range.min(), x_range.max()])
    fig.update_yaxes(
        range=[min(y_sim.min(), y_native.min()), max(y_sim.max(), y_native.max())],
        secondary_y=False,
    )
    fig.update_yaxes(
        range=[0, max(y_sim_kde.max(), y_native_kde.max())],
        secondary_y=True
    )

    return fig, divergence


# Build and show one interactive TICA plot per protein. Simulated and native
# data are paired by position, so `simulated_traj_paths` must list proteins in
# the same (sorted) order as `native_trajectory_paths`.
if len(simulated_trajs_per_protein) != len(native_trajectory_paths):
    # zip() below silently stops at the shorter list; make that visible.
    print(
        f"Note: {len(simulated_trajs_per_protein)} simulated vs "
        f"{len(native_trajectory_paths)} native protein entries; "
        "only the leading pairs are plotted."
    )

interactive_plot = None
for native_trajs, native_trajs_path, simulate_trajectories in zip(
    native_trajs_per_protein, native_trajectory_paths, simulated_trajs_per_protein
):
    tica_model, native_bond_lens = load_or_compute_tica_model(native_trajs, native_trajs_path)
    losses = get_losses(None, model_path, native_trajs, prior_params, LOSS_STRIDE)

    # Create the interactive TICA plot with protein visualization.
    interactive_plot = create_interactive_tica_plot(
        tica_model,
        simulate_trajectories,
        [x[::LOSS_STRIDE] for x in native_bond_lens],
        losses,
    )
    # Display inside the loop. Calling display() after the loop showed only
    # the last protein's plot. `display` is the IPython notebook builtin.
    display(interactive_plot)


loading path: /media/DATA_18_TB_1/andy/benchmark_set_small/BBA/starting_pos_0
Loading TICA model from cache...


VBox(children=(HBox(children=(FigureWidget({
    'data': [{'colorscale': [[0.0, '#440154'], [0.111111111111111…

Displaying frame 122
Displaying frame 601
Displaying frame 577
Displaying frame 622
Displaying frame 2641
Displaying frame 725
