In [1]:
import os

import ase.io
import numpy as np
import plotly.graph_objects as go

import milad
from milad import generate
from milad import moments
from milad import zernike


def prepare_molecule(molecule: 'ase.Atoms', radius: float = 0.7) -> 'ase.Atoms':
    """Centre a molecule on its centre of mass and rescale it to a given radius.

    The positions are translated so the centre of mass sits at the origin.
    They are then scaled uniformly so that the atom furthest from the origin
    ends up at distance ``radius``.

    Note: the passed object is modified in place (via ``set_positions``), and
    that same object is returned.

    Args:
        molecule: the molecule to normalise. This is an ``ase.Atoms``, or any
            object exposing ``get_center_of_mass()``, ``positions`` and
            ``set_positions()``.
        radius: target distance of the outermost atom from the origin.
            Defaults to 0.7, the value this function previously hard-coded.

    Returns:
        The same ``molecule`` object, now centred and rescaled.
    """
    centred = molecule.positions - molecule.get_center_of_mass()
    max_dist = np.linalg.norm(centred, axis=1).max()
    if max_dist > 0:
        centred = centred * (radius / max_dist)
    # Otherwise every atom coincides with the centre of mass (e.g. a single
    # atom). There is no extent to rescale, and dividing by zero would fill
    # the positions with NaN, so the centred (all-zero) positions are kept.
    molecule.set_positions(centred)
    return molecule

# Point count for the synthetic random-points alternative; in this view it
# is only referenced by the commented-out line below.
num_points = 4
# positions = generate.random_points_in_sphere(num_points, radius=.7)

# Location of the urea structure. Set MILAD_UREA_PDB to point elsewhere
# instead of editing a machine-specific absolute path. The default keeps the
# original location.
UREA_PDB = os.environ.get('MILAD_UREA_PDB', '/home/martin/src/milad/unsaved/urea.pdb')
urea = prepare_molecule(ase.io.read(UREA_PDB))

In [20]:
# --- Reconstruction settings ---
max_order = 7   # maximum order passed to the moment calculations below
n_samples = 31  # grid points per axis for the voxel reconstruction
sigmas = 0.05   # Gaussian width; only needed by the Gaussian alternative below

# Treat each atom as a weighted delta at its normalised position, with the
# atomic number as the weight.
positions = urea.positions
weights = np.array(urea.numbers)

# Alternative (smeared atoms, uses `sigmas`):
#   geom_moms = moments.geometric_moments_of_gaussians(
#       max_order, positions, weights=weights, sigmas=sigmas)
geom_moms = moments.geometric_moments_of_deltas(max_order, positions, weights=weights)

# Convert the geometric moments into Zernike moments up to the same order.
moms = zernike.from_geometric_moments(max_order, geom_moms)

In [21]:
# Evaluate the expansion at the atom positions themselves. These values are
# printed later and used for the isosurface upper bound.
reconstructed_values = moms.value_at(positions)

# Reconstruct on a cubic grid spanning [-1, 1] along each axis.
spacing = np.linspace(-1., 1., n_samples)
grid = np.array(np.meshgrid(spacing, spacing, spacing))  # shape (3, n, n, n)

grid_points = grid.reshape(3, -1).T  # (n_samples**3, 3) array of xyz points
grid_vals = moms.value_at(grid_points)

# Only the interior of the unit ball is meaningful for the Zernike expansion,
# so zero every point outside it. A vectorised mask replaces the former
# per-point Python loop.
outside_unit_ball = np.linalg.norm(grid_points, axis=1) > 1
grid_vals[outside_unit_ball] = 0.

vals = grid_vals.reshape(grid.shape[1:])

In [27]:
# 3D isosurface of the voxel reconstruction.
# The lower bound is a fixed threshold. It appears chosen to sit just below
# the smallest value at an atom position in the printed output (~34.7);
# revisit it if the molecule or settings change. The upper bound is one above
# the largest value at an atom, truncated to int.
iso_min = 34
iso_max = int(reconstructed_values.max() + 1)

isosurface = go.Isosurface(
    x=grid[0].flatten(),
    y=grid[1].flatten(),
    z=grid[2].flatten(),
    value=vals.flatten(),
    isomin=iso_min,
    isomax=iso_max,
)
fig = go.Figure(data=isosurface)

print(reconstructed_values)
fig.show()


[ 41.44645488 110.80035824  34.73731812  78.33951265 123.31686708
 110.82555839  34.83947502  41.56343147]


In [23]:
# 2D contour of the grid slice at the middle index of the third axis
# (z = 0 when n_samples is odd). With meshgrid's default 'xy' indexing,
# grid axis 1 varies x and grid axis 0 varies y, matching Contour's
# (rows = y, columns = x) layout.
mid_index = n_samples // 2
x_axis = grid[0, 0, :, 0]
y_axis = grid[1, :, 0, 0]
mid_slice = vals[:, :, mid_index]

fig = go.Figure(data=go.Contour(x=x_axis, y=y_axis, z=mid_slice))
# Lock the y scale to x so the slice is not stretched.
fig.update_yaxes(scaleanchor='x')

print(reconstructed_values)
fig.show()

[ 41.44645488 110.80035824  34.73731812  78.33951265 123.31686708
 110.82555839  34.83947502  41.56343147]
