In [None]:
import os

import numpy as np
import plotly.graph_objects as go

import cuarray
import netchem
import netcalc

In [None]:
# Directory holding the bundled "pyro" example files. It is used with the `/`
# operator and `.glob` later in the notebook, i.e. as a pathlib-style path.
tutorial_directory = netchem.data_files("pyro")

In [None]:
# Parameters of the synthetic bivariate normal used to sanity-check the MI
# estimator: zero mean, unit variances, correlation r between the two variables.
r = 0.9
gaussian_2D_size = 10_000
gaussian_2D_mean = np.zeros(2)
gaussian_2D_cov = np.array([
    [1.0, r],
    [r, 1.0],
])

In [None]:
# Draw gaussian_2D_size samples from the bivariate normal defined above and lay
# them out as a (2, gaussian_2D_size) float32 matrix: row 0 is the first
# variable, row 1 the second.
# A fixed seed makes the sample, and every number derived from it below,
# reproducible under Restart & Run All.
gaussian_2D_seed = 0
gaussian_2D_rng = np.random.default_rng(gaussian_2D_seed)
# `.T` alone is a Fortran-ordered view; ascontiguousarray makes each variable's
# row contiguous in memory before it is handed to cuarray.fromNumpy2D.
gaussian_2D = np.ascontiguousarray(
    gaussian_2D_rng.multivariate_normal(
        mean=gaussian_2D_mean,
        cov=gaussian_2D_cov,
        size=gaussian_2D_size,
    ).T,
    dtype=np.float32,
)

In [None]:
# Scatter of the samples: row 0 of gaussian_2D on x, row 1 on y.
# Small, semi-transparent markers keep the elongated (correlated) cloud readable
# despite the large number of overlapping points.
gaussian_2D_figure = go.Figure(
    data=go.Scatter(
        x=gaussian_2D[0],
        y=gaussian_2D[1],
        mode="markers",
        marker=dict(size=3, opacity=0.4),
    ),
    layout=dict(
        title=dict(text=f"Bivariate Gaussian samples (r = {r})"),
        xaxis=dict(title=dict(text="variable 0")),
        yaxis=dict(title=dict(text="variable 1")),
    ),
)

In [None]:
# Render the scatter plot of the Gaussian samples.
gaussian_2D_figure.show()

In [None]:
# Copy the float32 sample matrix (one row per variable) into a cuarray so it
# can be passed to netcalc as X.
X = cuarray.FloatCuArray()
X.fromNumpy2D(gaussian_2D)

In [None]:
# Integer array listing which variable pairs to evaluate; it is passed as ab=
# to netcalc.mutualInformation. NOTE(review): presumably each row holds two
# row indices into X — confirm against netcalc docs.
ab = cuarray.IntCuArray()

In [None]:
# A single pair (1 row x 2 columns): variable 0 with variable 1.
ab.init(1, 2)
ab[0][0] = 0
ab[0][1] = 1

In [None]:
# Empty output array; it is passed as I= to netcalc.mutualInformation.
I = cuarray.FloatCuArray()

In [None]:
# Settings for the mutualInformation call on the Gaussian example.
platform = netcalc.GPU_PLATFORM
n = gaussian_2D_size  # number of samples, i.e. columns of gaussian_2D
d = 1   # each variable is one row of gaussian_2D, i.e. scalar-valued
xd = 2  # variables per pair (ab rows have two entries)
k = 4   # NOTE(review): presumably the k-nearest-neighbour count — confirm

In [None]:
# Estimate the mutual information for every pair listed in ab; the result is
# written into I.
netcalc.mutualInformation(
    platform=platform,
    X=X,
    ab=ab,
    I=I,
    n=n,
    d=d,
    xd=xd,
    k=k,
)

In [None]:
# Show the estimated mutual information (via the cuarray's string form).
print(I)

In [None]:
# Reference values for the estimate above, in nats (natural log).
#
# For a bivariate normal the mutual information is
#     I(X; Y) = -1/2 ln(det(Sigma) / (sigma_x^2 * sigma_y^2)) = -1/2 ln(1 - rho^2).
# The division by the variances matters: MI is invariant to rescaling either
# variable, whereas det(Sigma) alone is not.
#
# analytic_I is computed from the covariance the samples were drawn with.
# It is exact, with no sampling noise.
analytic_I = -0.5 * np.log(
    np.linalg.det(gaussian_2D_cov) / np.prod(np.diag(gaussian_2D_cov))
)
# Plug-in value from the drawn samples. It uses the sample *correlation*
# matrix, so it is normalised the same way as the formula above.
sample_I = -0.5 * np.log(np.linalg.det(np.corrcoef(gaussian_2D)))
print(f"analytic I = {analytic_I:.6f}")
print(f"sample plug-in I = {sample_I:.6f}")

In [None]:
# Empty network object; it is populated by pyro_network.init in the next cell.
pyro_network = netchem.Network()

In [None]:
# Load the pyro example from its trajectory (pyro.dcd) and topology (pyro.pdb),
# restricted to frames first_frame..last_frame.
# NOTE(review): whether last_frame is inclusive is not visible here; confirm
# against netchem.Network.init.
trajectory_file = str(tutorial_directory / "pyro.dcd")
topology_file = str(tutorial_directory / "pyro.pdb")
first_frame = 0
last_frame = 999
pyro_network.init(
    trajectoryFile=trajectory_file,
    topologyFile=topology_file,
    firstFrame=first_frame,
    lastFrame=last_frame,
)



In [None]:
# Output array for the generalized correlations (pyro_R) and the integer array
# of node pairs to evaluate (pyro_ab, filled in the next cell).
pyro_R = cuarray.FloatCuArray()
pyro_ab = cuarray.IntCuArray()

In [None]:
# Enumerate every ordered node pair, including (i, i), in row-major order.
# Row p of pyro_ab holds the pair (i, j) with p = i * pyro_num_nodes + j, which
# is what lets the flat result be reshaped into a num_nodes x num_nodes matrix.
pyro_num_nodes = pyro_network.numNodes()
pyro_num_node_pairs = pyro_num_nodes**2
pyro_ab.init(pyro_num_node_pairs, 2)
for pyro_node_pair_index in range(pyro_num_node_pairs):
    i, j = divmod(pyro_node_pair_index, pyro_num_nodes)
    pyro_ab[pyro_node_pair_index][0] = i
    pyro_ab[pyro_node_pair_index][1] = j

In [None]:
# Settings for the generalized-correlation run on the pyro network.
pyro_n = pyro_network.numFrames()  # samples per node: the number of loaded frames
pyro_d = 3   # dimensions per node sample; presumably x, y, z coordinates — confirm
pyro_xd = 2  # nodes per pair (pyro_ab rows have two entries)
pyro_k = 4   # same k as the Gaussian example
# Use the named platform constant, as the Gaussian example does, rather than a
# bare literal. NOTE(review): the literal previously used here was 0,
# presumably the GPU platform's value — confirm.
pyro_platform = netcalc.GPU_PLATFORM

In [None]:
# Compute the generalized correlation for every node pair in pyro_ab from the
# node coordinates; results are written into pyro_R.
netcalc.generalizedCorrelation(
    platform=pyro_platform,
    X=pyro_network.nodeCoordinates(),
    ab=pyro_ab,
    R=pyro_R,
    n=pyro_n,
    d=pyro_d,
    xd=pyro_xd,
    k=pyro_k,
)

In [None]:
# Lay the flat per-pair results out as a square matrix. Row i, column j is the
# pair stored at index i * pyro_num_nodes + j in pyro_ab.
pyro_R_np = pyro_R.toNumpy2D().reshape((pyro_num_nodes,) * 2)

In [None]:
# Node indices 0 .. pyro_num_nodes - 1, used as heatmap axis coordinates.
pyro_R_figure_x = list(range(pyro_num_nodes))
pyro_R_figure_y = list(range(pyro_num_nodes))

In [None]:
# Heatmap of the node-by-node generalized-correlation matrix (uninterrupted run).
# Rows of pyro_R_np map to y, columns to x.
pyro_R_figure = go.Figure(
    data=go.Heatmap(
        x=pyro_R_figure_x,
        y=pyro_R_figure_y,
        z=pyro_R_np,
        colorscale='jet',
        zsmooth='best',
        colorbar=dict(title=dict(text="R")),
    ),
    layout=dict(
        title=dict(text="Generalized correlation between pyro nodes"),
        xaxis=dict(title=dict(text="node index")),
        yaxis=dict(title=dict(text="node index")),
    ),
)

In [None]:
# Render the heatmap of the uninterrupted run.
pyro_R_figure.show()

In [None]:
# Repeat the generalized-correlation run with periodic checkpoints. The restart
# cell below reads them back as "<checkpoint_file_name>_<step>.npy" and
# "<checkpoint_file_name>_ab.npy".
# NOTE(review): checkpoints are written into the bundled data directory; if
# that location is read-only, point checkpoint_file_name somewhere writable.
checkpoint_frequency = 10000
checkpoint_file_name = str(tutorial_directory / "pyro_R")
# Fresh output array, so pyro_R keeps the results of the uninterrupted run
# instead of being overwritten from this cell.
checkpoint_pyro_R = cuarray.FloatCuArray()
netcalc.generalizedCorrelationWithCheckpointing(
    X=pyro_network.nodeCoordinates(),
    R=checkpoint_pyro_R,
    ab=pyro_ab,
    k=pyro_k,
    n=pyro_n,
    d=pyro_d,
    xd=pyro_xd,
    platform=pyro_platform,
    checkpointFrequency=checkpoint_frequency,
    checkpointFileName=checkpoint_file_name,
)

In [None]:
# Show the checkpoint files written by the checkpointing run.
# Path.glob yields entries in filesystem order, so sort for a stable display.
sorted(tutorial_directory.glob(f"{os.path.basename(checkpoint_file_name)}_*"))

In [None]:
# Resume the calculation from the most recent "<checkpoint_file_name>_<step>.npy"
# on disk. Which steps exist depends on checkpoint_frequency and on the number
# of node pairs, so a hard-coded step number breaks as soon as either changes.
checkpoint_dir, checkpoint_stem = os.path.split(checkpoint_file_name)
checkpoint_prefix = f"{checkpoint_stem}_"
checkpoint_suffix = ".npy"
checkpoint_steps = sorted(
    int(entry[len(checkpoint_prefix):-len(checkpoint_suffix)])
    for entry in os.listdir(checkpoint_dir or os.curdir)
    if entry.startswith(checkpoint_prefix)
    and entry.endswith(checkpoint_suffix)
    # Skips "<stem>_ab.npy", whose middle part is not a step number.
    and entry[len(checkpoint_prefix):-len(checkpoint_suffix)].isdecimal()
)
if not checkpoint_steps:
    raise FileNotFoundError(
        f"no '{checkpoint_prefix}<step>{checkpoint_suffix}' checkpoint found in "
        f"{checkpoint_dir or os.curdir!r}; run the checkpointing cell first"
    )
restart_step = checkpoint_steps[-1]
restart_R_file_name = f"{checkpoint_file_name}_{restart_step}.npy"
restart_ab_file_name = f"{checkpoint_file_name}_ab.npy"
restart_pyro_R = cuarray.FloatCuArray()

netcalc.generalizedCorrelationRestartWithCheckpointing(
    X=pyro_network.nodeCoordinates(),
    R=restart_pyro_R,
    k=pyro_k,
    n=pyro_n,
    d=pyro_d,
    xd=pyro_xd,
    platform=pyro_platform,
    checkpointFrequency=checkpoint_frequency,
    checkpointFileName=checkpoint_file_name,
    restartAbFileName=restart_ab_file_name,
    restartRFileName=restart_R_file_name,
)

In [None]:
# Square node-by-node matrix from the restarted run, in the same row-major pair
# layout as pyro_R_np.
restart_pyro_R_np = restart_pyro_R.toNumpy2D().reshape((pyro_num_nodes,) * 2)

In [None]:
# True when the restarted run reproduced the uninterrupted run (pyro_R_np)
# element-for-element. np.array_equal replaces the nested builtin sum, which
# looped over rows in Python and raised TypeError on an empty matrix.
np.array_equal(restart_pyro_R_np, pyro_R_np)

In [None]:
# Heatmap of the restarted run. It uses the same axes and colour settings as the
# uninterrupted-run heatmap so the two can be compared directly.
restart_pyro_R_figure = go.Figure(
    data=go.Heatmap(
        x=pyro_R_figure_x,
        y=pyro_R_figure_y,
        z=restart_pyro_R_np,
        colorscale='jet',
        zsmooth='best',
        colorbar=dict(title=dict(text="R")),
    ),
    layout=dict(
        title=dict(text="Generalized correlation between pyro nodes (restarted from checkpoint)"),
        xaxis=dict(title=dict(text="node index")),
        yaxis=dict(title=dict(text="node index")),
    ),
)

In [None]:
# Render the heatmap of the restarted run.
restart_pyro_R_figure.show()

In [None]:
# Re-display the uninterrupted-run heatmap directly below the restarted one
# for visual comparison.
pyro_R_figure.show()