# Modelling irregular bodies shape via ANNs
In this notebook we explore the possibility to use ANNs to represent the generic shape and density of an irregular body and be trained to reproduce a known gravitational potential field.

To get statically stable asteroids we use results from MPIA work by Francesco Biscani obtained during simulation of protoplanetary formation made by large n-body simulations. Data are included as a submodule in the git project.

To run this notebook create a conda environment using the following commands:
```
 conda create -n geodesyann python=3.8 ipython scikit-learn numpy h5py matplotlib
 conda activate geodesyann
 conda install -c open3d-admin open3d
 pip install sobol_seq
```

And you will need pytorch (CPU is enough) for the ANN part


In [2]:
# core stuff
import h5py
import numpy as np
import scipy
from copy import deepcopy

# pytorch
from torch import nn
import torch
# For debugging and development purposes this is now set to float64 ... change for speed on GPUs
torch.set_default_tensor_type(torch.DoubleTensor)

# misc
import sobol_seq
from scipy import integrate

# plotting stuff
import matplotlib.pyplot as plt
from mpl_toolkits import mplot3d
%matplotlib notebook

# Loading and visualizing the ground truth asteroid (a point cloud)

In [3]:
# We import the data from MPIA containing pseudo-stable asteroid shapes
f = h5py.File('sample_vis_data/sample_01/state_10567.hdf5','r')
f2 = h5py.File('sample_vis_data/sample_01/global.hdf5', 'r')

In [4]:
# The file state_ ... contains the positions of all particles as well as the indices
# of those belonging to a cluster. Here we extract the largest ones.
dims = [(len(f[cluster][()]), cluster) for cluster in f.keys() if 'cluster' in cluster]
largest_clusters = sorted(dims,reverse=True)

In [5]:
# We have ordered the largest asteroids, we now extract positions for one in particular
rank = 4
print("Target: ", largest_clusters[rank][1])
# The particles idxs for this cluster
idx = f[largest_clusters[rank][1]][()]
# The particle radius
radius = f2['radius'][()]
# Particle positions
x_raw = f['x'][()][idx]
y_raw = f['y'][()][idx]
z_raw = f['z'][()][idx]
print("Diameter: ", 2 * radius)

Target:  cluster_2400
Diameter:  0.00043088693800637674


In [6]:
from sklearn.neighbors import NearestNeighbors
# We put xyz in a different shape (point_cloud)
point_cloud = np.append(x_raw, np.append(y_raw,z_raw))
point_cloud = point_cloud.reshape((3,len(x_raw)))
point_cloud = np.transpose(point_cloud)

nbrs = NearestNeighbors(n_neighbors=4, algorithm='ball_tree').fit(point_cloud)
distances, indices = nbrs.kneighbors(point_cloud)

print("Minimum distance between particles: ", min(distances[:,1]))
print("Maximum distance between particles: ", max(distances[:,1]))

# We take out particles that are not "touching" at least two neighbours
unstable_points = np.where(distances[:, 3]> 2 * radius * 1.01)[0]
print("Number of unstable points: ", len(unstable_points))
x = np.delete(x_raw, unstable_points, 0)
y = np.delete(y_raw, unstable_points, 0)
z = np.delete(z_raw, unstable_points, 0)

Minimum distance between particles:  0.00042108838275198475
Maximum distance between particles:  0.0005132013991266676
Number of unstable points:  5


In [7]:
# We subtract the mean so that the origin is the center of figure
x = x - np.mean(x)
y = y - np.mean(y)
z = z - np.mean(z)
# We normalize so that the axes are at most one
max_value = max([max(abs(it)) for it in [x,y,z]])
x = x / max_value
y = y / max_value
z = z / max_value
plot_radius = radius /  max_value  * 3000
# We put xyz in a different shape (point_cloud)
point_cloud = np.append(x, np.append(y,z))
point_cloud = point_cloud.reshape((3,len(x)))
point_cloud = np.transpose(point_cloud)
point_cloud = torch.tensor(point_cloud)

### Visualization via matplotlib

In [8]:
fig = plt.figure()
ax = fig.add_subplot(221, projection='3d')

# And visualize the masses
ax.scatter(x, y, z, color = 'k', s = plot_radius/2, alpha=0.1)
ax.set_xlim([-1,1])
ax.set_ylim([-1,1])
ax.set_zlim([-1,1])
ax.view_init(elev=45., azim=125.)

ax2 = fig.add_subplot(222)
ax2.scatter(x, y, color = 'k', s = plot_radius/2, alpha=0.1)
ax2.set_xlim([-1,1])
ax2.set_ylim([-1,1])

ax3 = fig.add_subplot(223)
ax3.scatter(x, z, color = 'k', s = plot_radius/2, alpha=0.1)
ax3.set_xlim([-1,1])
ax3.set_ylim([-1,1])

ax4 = fig.add_subplot(224)
ax4.scatter(y, z, color = 'k', s = plot_radius/2, alpha=0.1)
ax4.set_xlim([-1,1])
ax4.set_ylim([-1,1])

<IPython.core.display.Javascript object>

(-1.0, 1.0)

### Visualization via open3d

In [9]:
import numpy as np
import open3d as o3d

colors = np.exp(-np.array(color))
colors = np.append(np.append(colors, colors), colors).reshape((3, len(color))).transpose()

pcd = o3d.geometry.PointCloud()
pcd.points = o3d.utility.Vector3dVector(point_cloud)
pcd.colors = o3d.utility.Vector3dVector(colors)


hull, _ = pcd.compute_convex_hull()
hull_ls = o3d.geometry.LineSet.create_from_triangle_mesh(hull)
hull_ls.paint_uniform_color((1, 0, 0))
o3d.visualization.draw_geometries([pcd, hull_ls])

NameError: name 'color' is not defined

# Computing the gravitational potential of the ground truth the asteroid
The Canvendish constant is not included (or $G=1$), so that we have:
$$
U_L = - \sum_{i=1}^N \frac{m_i}{|\mathbf x - \mathbf r_i|}
$$
where, assuming the asteroid with a unitary mass $m_i = 1/N$, hence:
$$
U_L = - \frac 1N \sum_{i=1}^N \frac{1}{|\mathbf x - \mathbf r_i|}
$$

In [20]:
# This will create the labels for the supervised learning
def U_L(target_points, point_cloud):
    retval=torch.empty(len(target_points),1)
    for i, target_point in enumerate(target_points):
        retval[i] = torch.mean(1./torch.norm(torch.sub(point_cloud,target_point), dim=1))
    return - retval 

In [21]:
target_points = torch.rand(100,3)

In [22]:
%timeit U_L(target_points, point_cloud)

56.2 ms ± 953 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


# Representing an asteroid via a neural network


## 1 - Instantiating the network
The networks inputs are the cartesian coordinates of a point in the unit cube, encoded via some transformation

In [23]:
# All encodings work taking as input a tensor (N, 3) containing the cartesian coordinates of N points
# and returning a tensor of (N, M) that can be used as input to the ANN

# Encoding N.1 (directional encoding):
# x = [x,y,z] is encoded as [ix, iy, iz, r]
def directional_encoding(sampled_points):
    unit = sampled_points / torch.norm(sampled_points,dim=1).view(-1,1)
    return torch.cat((unit, torch.norm(sampled_points,dim=1).view(-1,1)), dim=1)

# Encoding N.2 (positional encoding):
# x = [x,y,z] is encoded as [sin(pi x), sin(pi y), sin(pi z), cos(pi x), cos(pi y), cos(pi z), sin(2 pi x), ....]
def positional_encoding(sampled_points, N = 4):
    retval = torch.cat((torch.sin(np.pi * dummy[:,0]).view(-1,1), torch.cos(np.pi * dummy[:,0]).view(-1,1), torch.sin(np.pi * dummy[:,1]).view(-1,1), torch.cos(np.pi * dummy[:,1]).view(-1,1), torch.sin(np.pi * dummy[:,2]).view(-1,1), torch.cos(np.pi * dummy[:,2]).view(-1,1)), dim=1)
    for i in range(1, N):
        retval = torch.cat((retval, torch.sin(2**i * np.pi * dummy[:,0]).view(-1,1), torch.cos(2**i * np.pi * dummy[:,0]).view(-1,1), torch.sin(2**i * np.pi * dummy[:,1]).view(-1,1), torch.cos(2**i * np.pi * dummy[:,1]).view(-1,1), torch.sin(2**i * np.pi * dummy[:,2]).view(-1,1), torch.cos(2**i * np.pi * dummy[:,2]).view(-1,1)), dim=1)
    return retval

# Encoding N.3 (direct encoding):
def direct_encoding(sampled_points):
    return sampled_points
        
# Encoding N.4 (spherical coordinates). These can be used with positional encoding to create effectively harmonics
def spherical_coordinates(sampled_points):
    phi = torch.atan2(dummy[:,1], dummy[:,0]) / np.pi
    r = torch.norm(dummy, dim=1)
    theta = torch.div(dummy[:,2], r)
    return torch.cat((r.view(-1,1), phi.view(-1,1), theta.view(-1,1)), dim=1)
    

In [150]:
# Encoding choosen
encoding = directional_encoding

# Network initialization scheme (note that if xavier uniform is used all outputs will start at, roughly 0.5)
def weights_init(m):
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)
        nn.init.uniform_(m.bias.data, -0.0, 0.0)

# Network architecture. Note that the dimensionality of the first linear layer must match the output
# of the encoding chosen
n_neurons = 100
model = nn.Sequential(
          nn.Linear(4,n_neurons),
          nn.ReLU(),
          nn.Linear(n_neurons,n_neurons),
          nn.ReLU(),
          nn.Linear(n_neurons,n_neurons),
          nn.ReLU(),
          nn.Linear(n_neurons,n_neurons),
          nn.ReLU(),
          nn.Linear(n_neurons,n_neurons),
          nn.ReLU(),
          nn.Linear(n_neurons,n_neurons),
          nn.ReLU(),
          nn.Linear(n_neurons,n_neurons),
          nn.ReLU(),
          nn.Linear(n_neurons,n_neurons),
          nn.ReLU(),
          nn.Linear(n_neurons,n_neurons),
          nn.ReLU(),
          nn.Linear(n_neurons,1),
          nn.Sigmoid(),
        )

# Applying our weight initialization
_  = model.apply(weights_init)

## Visualizing an asteroid represented by the network
The network output is the density in the unit cube. It is, essentially, a three dimensional function and as such it is difficult to plot. 

### Approach 1: plotting a grid of points colored with the rho value

In [152]:
def plot_asteroid1(model, encoding, N=20, bw = False, ax = None, alpha = 0.2, views_2d = False):
    # We create the grid
    x = torch.linspace(-1,1,N)
    y = torch.linspace(-1,1,N)
    z = torch.linspace(-1,1,N)
    X, Y, Z = torch.meshgrid((x,y,z))

    # We compute the density on the grid points (no gradient as its only for plotting)
    nn_inputs = torch.cat((X.reshape(-1,1), Y.reshape(-1,1), Z.reshape(-1,1)), dim=1)
    nn_inputs = encoding(nn_inputs)
    RHO = model(nn_inputs).detach()
 
    # And we plot it
    fig = plt.figure()
    if views_2d:
        ax = fig.add_subplot(221, projection='3d')
    else:
        ax = fig.add_subplot(111, projection='3d')
    if bw:
        col = torch.cat((1-RHO, 1-RHO, 1-RHO, RHO), dim=1)
        alpha = None
    else:
        col = RHO
    
    ax.scatter(X.reshape(-1,1), Y.reshape(-1,1), Z.reshape(-1,1), marker='.', c=col, s=100, alpha = alpha)
    ax.set_xlim([-1,1])
    ax.set_ylim([-1,1])
    ax.set_zlim([-1,1])
    ax.view_init(elev=45., azim=125.)
    
    if views_2d:
        ax2 = fig.add_subplot(222)
        ax2.scatter(X.reshape(-1,1)[:,0], Y.reshape(-1,1)[:,0], marker='.', c=col, s=100, alpha=alpha)
        ax2.set_xlim([-1,1])
        ax2.set_ylim([-1,1])

        ax3 = fig.add_subplot(223)
        ax3.scatter(X.reshape(-1,1)[:,0], Z.reshape(-1,1)[:,0], marker='.', c=col, s=100, alpha=alpha)
        ax3.set_xlim([-1,1])
        ax3.set_ylim([-1,1])

        ax4 = fig.add_subplot(224)
        ax4.scatter(Y.reshape(-1,1)[:,0], Z.reshape(-1,1)[:,0], marker='.', c=col, s=100, alpha=alpha)
        ax4.set_xlim([-1,1])
        ax4.set_ylim([-1,1])
    
    return fig
       
fig = plot_asteroid1(model, encoding, bw = False, views_2d=False)
plt.title("Believe it or not I am an asteroid")

<IPython.core.display.Javascript object>

Text(0.5, 0.92, 'Believe it or not I am an asteroid')

### Approach 2: considering rho as a probability density function and sampling points from it

In [153]:
# Rejection sampling
def plot_asteroid2(model, encoding, N=30**3, views_2d = False, bw = False, alpha = 0.2):
    points = torch.rand(N, 3) *2 -1
    nn_inputs = encoding(points)
    RHO = model(nn_inputs).detach()
    mask = RHO > torch.rand(N,1)
    RHO = RHO[mask]
    points = [[it[0].item(), it[1].item(), it[2].item()] for it,m in zip(points, mask) if m] 
    points = torch.tensor(points)
    
    fig = plt.figure()
    if views_2d:
        ax = fig.add_subplot(221, projection='3d')
    else:
        ax = fig.add_subplot(111, projection='3d')
    if bw:
        col = 'k'
    else:
        col = RHO
    # And we plot it
    ax.scatter(points[:,0], points[:,1], points[:,2], marker='.', c=col, s=100, alpha=alpha)
    ax.set_xlim([-1,1])
    ax.set_ylim([-1,1])
    ax.set_zlim([-1,1])
    ax.view_init(elev=45., azim=125.)

    
    if views_2d:
        ax2 = fig.add_subplot(222)
        ax2.scatter(points[:,0], points[:,1], marker='.', c=col, s=100, alpha=alpha)
        ax2.set_xlim([-1,1])
        ax2.set_ylim([-1,1])

        ax3 = fig.add_subplot(223)
        ax3.scatter(points[:,0], points[:,2], marker='.', c=col, s=100, alpha=alpha)
        ax3.set_xlim([-1,1])
        ax3.set_ylim([-1,1])

        ax4 = fig.add_subplot(224)
        ax4.scatter(points[:,1], points[:,2], marker='.', c=col, s=100, alpha=alpha)
        ax4.set_xlim([-1,1])
        ax4.set_ylim([-1,1])
    
    return fig
    
plot_asteroid2(model, encoding)
plt.title("Believe it or not I am an asteroid")

# Note that if the network is initialized by xavier the density will roughly be 0.5 everywhere so rejection sampling
# will also create a uniform cloud of points

<IPython.core.display.Javascript object>

Text(0.5, 0.92, 'Believe it or not I am an asteroid')

## Computing the gravitational potential of the asteroid ANN model wia Monte Carlo methods
The Network represents the mass density $\rho$, but the potential field created by it is given by the integral:
$$
U_P = - \int_V \frac\rho {|\mathbf r-\mathbf x|}  dV 
$$
where the volume V is the cube $[-1,1]^3$. 

We thus must approximate the above integral, and to do so we use the Monte Carlo formula:
$$
\int_V f(\mathbf x) dv = \frac 1N \sum_i \frac{f(\mathbf x_i)}{g(\mathbf x_i)}
$$
where g is the pdf of the distribution we sample $x_i$ from.

Applying the formula above to our integral we get:

- Naive Monte Carlo
$$
U_P(\mathbf r) \approx \frac 8N \sum_i \frac {\rho_i}{|\mathbf r-\mathbf x_i|} 
$$


where we assumed that $g$ was the uniform random distribution.

Improvements to the above strategy are often a) the use of low-discrepancy seqeunces, b) importance sampling and c) MCMC methods. While a) is indeed applicable, b) and c) seem problematic in our case.

In [27]:
# We generate a low-discrepancy sequence here and keep it in memory (generating it requires some CPU time)
sobol_points = sobol_seq.i4_sobol_generate(3, 200000)

# Naive Montecarlo
def U_Pmc(target_points, model, N = 3000):
    # We generate randomly points in the [-1,1]^3 bounds
    sample_points = torch.rand(N,3) * 2 - 1
    nn_inputs = encoding(sample_points)
    rho = model(nn_inputs)
    retval=torch.empty(len(target_points),1)
    # Only for the points inside we accumulate the integrand (MC method)
    for i, target_point in enumerate(target_points):
        retval[i] = torch.sum(rho/torch.norm(target_point - sample_points, dim=1).view(-1,1)) / N
    return  - 8 * retval

# Low-discrepancy Montecarlo
def U_Pld(target_points, model, N = 3000, noise = 1e-5):
    # We generate randomly points in the [-1,1]^3 bounds
    sample_points = torch.tensor(sobol_points[:N,:] * 2 - 1) + torch.rand(N,3) * noise
    nn_inputs = encoding(sample_points)
    rho = model(nn_inputs)
    retval=torch.empty(len(target_points),1)
    # Only for the points inside we accumulate the integrand (MC method)
    for i, target_point in enumerate(target_points):
        retval[i] = torch.sum(rho/torch.norm(target_point - sample_points, dim=1).view(-1,1)) / N
    return  - 8 * retval

# Importance sampling could be a third approach to use here and benchmark .. so far we did not have success
# in assembling an algorithm able to

In [45]:
# Here we create some target points where to compute the potential
N_try = 1
target_points = (torch.rand(N_try,3)*2-1)*1.1
a = torch.logical_and((target_points[:,0]>-1),(target_points[:,0]<1))
b = torch.logical_and((target_points[:,1]>-1),(target_points[:,1]<1))
c = torch.logical_and((target_points[:,2]>-1),(target_points[:,2]<1))
d = torch.logical_and(torch.logical_or(a,b), c)
target_points=target_points[d]
print("Target point is: ", target_points)

Target point is:  tensor([[-0.0566, -1.0814,  0.6126]])


#### We time the MC methods

In [46]:
%timeit U_Pld(target_points, model, N = 10000)

58 ms ± 1.87 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [47]:
%timeit U_Pmc(target_points, model, N = 10000)

62.9 ms ± 4.92 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


#### We study their convergence
Note that results may be different according to how discontinuous and ill is the function $\rho$ representaed by the network and to what ground truth we set

In [48]:
# We use as ground truth the value computed by the low-discrepancy method for 100000 points
ground_truth = U_Pld(target_points, model, N = 200000).detach()
# Or, uncomment these lines to attempt quadrature
#def f(x,y,z):
#    nn_inputs = encoding(torch.tensor([[x,y,z]])) 
#    return model(nn_inputs).detach().item()
#
#res, err = integrate.tplquad(f, -1, 1, lambda x: -1, lambda x: 1, lambda x, y: -1, lambda x, y: 1, epsabs = 1e-5, epsrel=1e-5)
print("Ground truth is: ", ground_truth.item())

Ground truth is:  -3.1593624613421794


In [49]:
grid = range(1000, 60000, 100)
mc = []
for g in grid:
    mc.append(torch.abs(U_Pmc(target_points, model, N = g).detach()-ground_truth))

In [50]:
ld = []
for g in grid:
    ld.append(torch.abs(U_Pld(target_points, model, N = g, noise=1e-5).detach()-ground_truth))

In [52]:
# We plot the results
fig = plt.figure()
plt.semilogy(grid, mc, label = "naive MC")
plt.semilogy(grid, ld, label = "low-discrepancy")
plt.legend()

# Note that we do not have a proper ground truth here to compare with, nevertheless the trend seems clear.

<IPython.core.display.Javascript object>

<matplotlib.legend.Legend at 0x7f16c03d5cd0>

# Training The ANN to match the ground truth potential

Let it run up to when its < 1e-3 to actually see something that resembles the original asteroid. When stuck increase the number of monte carlo samples or play around the learning rate.

In [53]:
# Here we set some details of the training
# This loss function adds a normalization constant to the network outputs so
# that the MSE is minimized. Like this it is only the mass distribution to count
# not its absolute value. Seems on paper a great idea, everybody should do it .. but does it work?
def normalized_loss(predicted, labels):
    c = sum(torch.mul(labels, predicted))/sum(torch.pow(predicted,2))
    return sum(torch.pow(torch.sub(labels,c*predicted),2)) / len(labels)

# Here we set the loss function
#loss_fn = torch.nn.MSELoss()
loss_fn = normalized_loss

# Here we set the choosen Monte Carlo method
mc_method = U_Pld

learning_rate = 1e-4
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [60]:
# This is the main training loop
for i in range(5000):
    # At each new epoch we generate new points (its like a new batch), but we make sure
    # they are outside the unit cube
    targets = (torch.rand(100,3)*2-1)*1.1
    a = torch.logical_and((targets[:,0]>-1),(targets[:,0]<1))
    b = torch.logical_and((targets[:,1]>-1),(targets[:,1]<1))
    c = torch.logical_and((targets[:,2]>-1),(targets[:,2]<1))
    d = torch.logical_and(torch.logical_or(a,b), c)
    targets=targets[d]
    labels = U_L(targets, point_cloud)
    
    # Compute the loss
    predicted = mc_method(targets, model, N=100000)
    loss = loss_fn(predicted, labels)
    print(i, loss.item())
    # Before the backward pass, use the optimizer object to zero all of the
    # gradients for the variables it will update (which are the learnable
    # weights of the model). This is because by default, gradients are
    # accumulated in buffers( i.e, not overwritten) whenever .backward()
    # is called. Checkout docs of torch.autograd.backward for more details.
    optimizer.zero_grad()

    # Backward pass: compute gradient of the loss with respect to model
    # parameters
    loss.backward()

    # Calling the step function on an Optimizer makes an update to its
    # parameters
    optimizer.step()

0 2.8880378896937918e-05
1 2.3911347827866795e-05
2 8.525001049511075e-06
3 9.887291442060018e-06
4 7.073956602304201e-05
5 2.2361566606497982e-05
6 1.396521415425562e-05
7 1.2918614002290528e-05
8 1.8310815790225214e-05
9 2.622784949558246e-05
10 1.6399584871489385e-05
11 1.5421931551120097e-05
12 8.330702054499115e-06
13 1.4681882394439987e-05
14 7.958303575797302e-06
15 1.204189484114966e-05
16 1.119107905811304e-05
17 3.371032805451456e-05
18 2.939424678500174e-05
19 9.526395635930094e-06
20 1.641287626618358e-05
21 0.00020835053383695944
22 2.7451632338334515e-05
23 6.780405946970884e-05
24 5.3718334082445034e-05
25 2.5099003711996e-05
26 5.355332179403202e-05
27 6.047502628268048e-05
28 4.692414688802653e-05
29 2.867724545504297e-05
30 1.3929804424419884e-05
31 5.3303850067046045e-05
32 2.3606439532208105e-05
33 4.2732399931389956e-05
34 4.521625355910836e-05
35 1.330105782170851e-05
36 7.441256120516635e-06
37 3.767106500303081e-05
38 0.00024303070307190732
39 7.574223075433466e

315 1.1656646892741166e-05
316 2.489579887932019e-05
317 1.0085086729968072e-05
318 8.367431327130033e-06
319 1.3572758330655623e-05
320 5.68689382077081e-05
321 6.7395142610071e-05
322 9.901763473591494e-06
323 4.508848290971662e-05
324 5.827316655820221e-06
325 6.669222450960227e-06
326 6.913391211122049e-06
327 1.0004964198914403e-05
328 1.2625623535193193e-05
329 5.639715316943907e-06
330 1.508440702749698e-05
331 1.3333298079812224e-05
332 3.218000315598941e-06
333 2.555639237977814e-05
334 2.378300103636433e-05
335 3.514686146669534e-06
336 4.262017701662527e-05
337 1.2781279113850925e-05
338 4.604783064361106e-06
339 6.696252820867026e-06
340 1.2453605977169896e-05
341 7.669355392716338e-05
342 3.893150531779141e-05
343 2.832982527727049e-05
344 2.200083969371405e-05
345 1.1926343479955985e-05
346 1.2708392814709623e-05
347 2.9851457388476407e-05
348 9.954788427985106e-05
349 1.9498341076486975e-05
350 2.7571150937809293e-05
351 7.878802863910777e-06
352 5.251353430104134e-06
35

625 1.5386954486408285e-05
626 8.863668180485037e-06
627 6.0192553083969265e-05
628 1.620192031886254e-05
629 2.151102212118646e-05
630 2.0759182853483537e-05
631 3.285364568432383e-06
632 1.161155713728956e-05
633 1.5853644812058226e-05
634 9.88028380042247e-06
635 1.3806624155550921e-05
636 1.7081352508191312e-05
637 5.060495755174659e-06
638 5.189343987341153e-06
639 1.041077927531347e-05
640 1.533790129716369e-05
641 2.5510895900477262e-06
642 8.923485633064763e-06
643 1.8426023109438577e-05
644 8.879232028436735e-06
645 6.972037347204411e-06
646 5.440676794121655e-05
647 1.095612134647082e-05
648 5.461429269934824e-06
649 1.8980789077615793e-05
650 6.547220055108195e-06
651 9.946745308343937e-06
652 5.12108033965445e-06
653 1.1058128288822768e-05
654 8.174775470919727e-06
655 3.142925701196384e-05
656 1.37690159768341e-05
657 4.954505800623338e-05
658 4.775360658770548e-05
659 2.2655052757970273e-05
660 1.3390103112958842e-05
661 1.3101250922546801e-05
662 2.7151033774041362e-05
6

935 5.3607476542200994e-06
936 2.2828676799057884e-05
937 1.0758272665117404e-05
938 1.1723266800764762e-05
939 1.507994860235132e-05
940 2.4311476764357243e-05
941 2.4413924523955437e-05
942 4.063516461678237e-05
943 7.844560194328274e-06
944 1.6990812823774394e-05
945 5.9912558466821655e-05
946 1.3135953631145597e-05
947 3.368798523955513e-05
948 1.961452502411295e-05
949 1.6525719323441772e-05
950 1.5400375676623048e-05
951 7.5424561019663935e-06
952 2.3485466995640366e-05
953 2.2261270856811903e-05
954 1.5461068914599985e-05
955 1.5572714589463546e-05
956 1.689702401903385e-05
957 2.651161884310518e-05
958 1.0853683092407253e-05
959 1.2091465627806486e-05
960 1.3695328800636217e-05
961 2.284871528294917e-05
962 9.521116207671049e-06
963 2.436744385970845e-05
964 2.435988291477461e-05
965 2.83171486440419e-05
966 1.619698354311044e-05
967 1.2074690037549481e-05
968 1.3930507441434442e-05
969 1.3627690028423911e-05
970 1.847123767513086e-05
971 8.823822518152435e-05
972 1.66128650669

1237 9.802355629184643e-06
1238 4.739316665133267e-06
1239 9.324492944094681e-06
1240 2.048450560821774e-05
1241 6.34457832508402e-05
1242 1.2435126575535737e-05
1243 8.418319668285794e-06
1244 1.4043065392525878e-05
1245 4.2340716812184585e-05
1246 1.0525925995673015e-05
1247 6.538964475949429e-06
1248 3.803644628794541e-05
1249 1.6664097992791648e-05
1250 2.6165553274084625e-06
1251 3.5463686290980055e-05
1252 2.0508932300803782e-05
1253 5.21668758669916e-05
1254 3.1290114667206015e-05
1255 0.0004460755994841888
1256 2.9231965990999578e-05
1257 9.369800489286404e-05
1258 2.8171174521353137e-05
1259 8.884060790187184e-05
1260 4.566978426212855e-05
1261 2.303282736402974e-05
1262 3.819647810535002e-05
1263 1.680046157900038e-05
1264 4.542626900852608e-05
1265 3.762546495453901e-05
1266 1.0111978035892317e-05
1267 1.7347884814235344e-05
1268 5.852462173435181e-05
1269 1.6270300048088404e-05
1270 2.5245339904615325e-05
1271 6.494117331800318e-05
1272 1.7516666804628294e-05
1273 1.4164977

1537 2.0724853610520362e-05
1538 1.5559703246100616e-05
1539 1.8742444773575474e-05
1540 1.2774591924666263e-05
1541 2.7806796825234617e-05
1542 4.877196960679375e-05
1543 5.847517010815261e-05
1544 9.354676629575425e-06
1545 3.2349570297092785e-05
1546 1.1033982820129479e-05
1547 1.93296959179901e-05
1548 3.3225537202070027e-05
1549 1.87441627905038e-05
1550 6.500686930910731e-06
1551 2.3056667804470445e-05
1552 7.9276859212017e-06
1553 2.406135450208846e-05
1554 7.752783148749563e-06
1555 3.888926119890088e-05
1556 6.020741396822753e-06
1557 3.398835761360778e-05
1558 9.261899446135088e-06
1559 1.3943833360406702e-05
1560 2.22800426726736e-05
1561 1.1918481919180751e-05
1562 1.5569869562381864e-05
1563 1.8349998042826023e-05
1564 1.931034556389395e-05
1565 2.8468461276150532e-05
1566 5.494968423513561e-06
1567 1.1517433254697672e-05
1568 1.1460455731962114e-05
1569 1.2989086797512608e-05
1570 1.7050571864605736e-06
1571 1.0202412227495752e-05
1572 3.815166473051395e-05
1573 2.7788486

1836 1.8389575344325614e-05
1837 1.4389807469188258e-05
1838 5.385117952311922e-06
1839 5.025829742892396e-06
1840 1.1857432427513151e-05
1841 9.95336409876193e-06
1842 2.2121405476364836e-05
1843 5.60030958283321e-05
1844 3.229747732318015e-05
1845 7.488426582364411e-06
1846 1.463706679713115e-05
1847 1.9187559839072237e-05
1848 2.09275073666074e-05
1849 0.0006222743566979778
1850 7.686079921178352e-06
1851 2.0409907936194232e-05
1852 1.388385147913447e-05
1853 8.20708121805525e-06
1854 9.154333122800961e-06
1855 3.1228088280048866e-05
1856 2.2136034285745205e-05
1857 4.359066273137703e-06
1858 1.0921242925092178e-05
1859 4.3376216700401404e-05
1860 1.528555949863273e-05
1861 7.030523401395034e-06
1862 1.3689900747826987e-05
1863 2.7004372397600998e-05
1864 1.0630422837047656e-05
1865 1.025256703215432e-05
1866 6.856334238836165e-06
1867 2.6910741215777566e-05
1868 1.4270144516669325e-05
1869 6.428278242956338e-06
1870 1.650414604382704e-05
1871 7.902605541578294e-06
1872 3.8515414140

2137 8.338775971501661e-06
2138 2.1574934916678644e-05
2139 1.1244584541829575e-05
2140 1.3357321644266144e-05
2141 1.6483546911305853e-05
2142 6.130195999857028e-06
2143 7.838527730292356e-06
2144 1.6476342021779292e-05
2145 5.8497475652122425e-05
2146 2.710098313407715e-05
2147 1.8135465595801376e-05
2148 4.8982343524878275e-05
2149 4.714406617533299e-06
2150 7.153590292292243e-06
2151 2.5828710962843398e-05
2152 8.505732651301685e-06
2153 8.163866200176088e-05
2154 9.32227872903378e-05
2155 2.076688746186762e-05
2156 8.325515773646138e-06
2157 3.781782805823898e-05
2158 1.566597215114016e-05
2159 1.0727752340321808e-05
2160 2.2433084483019962e-05
2161 0.0001398332841231756
2162 1.2958796784223184e-05
2163 2.1591936101006457e-05
2164 2.5681790740877948e-05
2165 0.00011576067601595947
2166 2.604637093866618e-05
2167 2.3749245001392908e-05
2168 1.8485258719562175e-05
2169 8.470674916499804e-06
2170 2.0263689977790674e-05
2171 1.57055140410931e-05
2172 1.1686478194433832e-05
2173 2.8501

2436 2.2253230802523647e-05
2437 2.300899391343272e-05
2438 3.599500261728684e-05
2439 1.3422078421142624e-05
2440 1.06693846746665e-05
2441 1.574924613468531e-05
2442 3.0875682554667705e-05
2443 1.8649293395425702e-05
2444 5.809072713385194e-06
2445 7.121806369862746e-06
2446 1.2712413117665361e-05
2447 8.942007771853977e-06
2448 1.9958019640435645e-05
2449 1.1980158735406073e-05
2450 4.448033778422646e-05
2451 2.226737505724355e-05
2452 3.30214456173489e-05
2453 2.0415228455837927e-05
2454 2.398404639905058e-05
2455 4.340131945289436e-06
2456 1.847286891104602e-05
2457 3.3970119976042815e-05
2458 7.884357815919788e-06
2459 6.809382871859423e-06
2460 9.43169719861949e-06
2461 6.478129902381544e-05
2462 4.741102507828681e-06
2463 5.694352696268688e-05
2464 9.126275375876751e-05
2465 1.3863343255586111e-05
2466 3.081440951120089e-05
2467 3.197224390159858e-05
2468 1.541017509094805e-05
2469 1.298775489196751e-05
2470 2.0352798602777526e-05
2471 2.6665155416404636e-05
2472 1.596218297878

2736 2.35358747737678e-05
2737 8.778774970030287e-06
2738 1.7212456419199358e-05
2739 1.4784436618532732e-05
2740 2.2515961614755326e-05
2741 5.404556506855689e-05
2742 2.3246162115394684e-05
2743 1.4537731518155075e-05
2744 1.2810751979913311e-05
2745 6.196217047736181e-06
2746 2.7761275217969038e-05
2747 1.3007294385601412e-05
2748 1.2765511299823658e-05
2749 0.00010332975472306257
2750 1.0482256028926025e-05
2751 3.949649454853377e-06
2752 1.711577548166598e-05
2753 1.554060072164853e-05
2754 0.00021799708954009137
2755 1.2212249547999275e-05
2756 1.0174073147025e-05
2757 1.70564092979552e-05
2758 2.05386771666675e-05
2759 1.1384065335941127e-05
2760 2.805242309187867e-05
2761 5.220693116634106e-06
2762 3.082970106140981e-06
2763 2.1851178841636848e-05
2764 1.18569135259155e-05
2765 9.320364583562002e-06
2766 1.6076166365309102e-05
2767 9.477058327187651e-06
2768 1.5924070702402058e-05
2769 5.4287575913314275e-06
2770 1.4770514094949048e-05
2771 9.83567385521536e-05
2772 1.313652569

3036 6.683380744100513e-05
3037 3.2950000197433346e-05
3038 1.5144759718191329e-05
3039 2.9758526117510122e-05
3040 1.5315154614924686e-05
3041 1.7918644752419765e-05
3042 4.332191727607394e-05
3043 2.3982850338932123e-05
3044 1.3133023304383337e-05
3045 8.134844507563934e-06
3046 1.4679335878863223e-05
3047 3.0369025519225456e-05
3048 3.320363808710697e-05
3049 2.4306236966298462e-05
3050 1.2335904208405285e-05
3051 3.119091185987388e-05
3052 1.1794181237520064e-05
3053 1.957004865970446e-05
3054 7.972561749722277e-05
3055 2.36376593365102e-05
3056 9.499862026534305e-06
3057 1.4189727664053955e-05
3058 2.0813111277274693e-05
3059 1.0903810341101082e-05
3060 2.3835224693639754e-05
3061 3.599585834577912e-05
3062 1.5150225487677911e-05
3063 4.9886867195740025e-05
3064 1.1127270145226864e-05
3065 6.394036172450078e-06
3066 9.146077811495232e-06
3067 2.797188779474944e-06
3068 7.551009665780119e-06
3069 1.7432313120755488e-05
3070 4.817233494557104e-06
3071 1.5294342270556318e-05
3072 1.0

3336 1.0150202926357619e-05
3337 7.146883194692812e-06
3338 1.5934631846485045e-05
3339 0.00013217710524036306
3340 2.501533601361506e-05
3341 2.4549759796266378e-05
3342 5.309362257844817e-05
3343 2.422708786631178e-05
3344 1.5091043044700001e-05
3345 2.2728896620092493e-05
3346 1.8135771562839032e-05
3347 1.0911352028811118e-05
3348 2.803739945398195e-05
3349 2.4928381171070655e-05
3350 1.3078379119237313e-05
3351 3.172092746160052e-05
3352 3.821747295372113e-05
3353 1.1992701477526587e-05
3354 4.2564428304999493e-05
3355 1.848562713329431e-05
3356 1.9556605221913412e-05
3357 9.334386711666838e-06
3358 1.2721476526184527e-05
3359 5.716171684178851e-06
3360 2.4083739853599307e-05
3361 8.360179837601944e-06
3362 5.090138782015495e-06
3363 2.5566755009735056e-05
3364 8.241533934386678e-06
3365 5.219276297861884e-06
3366 2.1271603336116832e-05
3367 5.318719231459185e-06
3368 1.1159233835546018e-05
3369 1.0906929038091897e-05
3370 2.7685483928684885e-05
3371 8.167018609503192e-06
3372 1.2

3635 4.770819037755051e-06
3636 1.6710165003664792e-05
3637 2.182167109485257e-05
3638 2.764772916812828e-05
3639 3.273532130029616e-06
3640 2.4494452128982233e-05
3641 1.903405018810666e-05
3642 2.8820258484368108e-05
3643 8.796661092021497e-06
3644 4.831917668147168e-06
3645 3.208727597670764e-05
3646 0.00014574966812328538
3647 1.279974159004528e-05
3648 3.855088498401603e-05
3649 0.0001372155935243764
3650 1.1946159675574726e-05
3651 2.4894590267190878e-05
3652 1.3210570118777816e-05
3653 2.00345271378205e-05
3654 4.554216225609464e-05
3655 1.4538172853505447e-05
3656 4.320593519805203e-05
3657 2.875685662439801e-05
3658 1.3418306290880111e-05
3659 1.906232259331039e-05
3660 5.523106148992809e-06
3661 2.4246293659415413e-05
3662 1.3978587226109892e-05
3663 1.426819361601601e-05
3664 3.861583107300081e-05
3665 1.562273218887288e-05
3666 3.346810728427072e-05
3667 1.5514168739340444e-05
3668 8.717865299608032e-05
3669 1.6028484970841073e-05
3670 8.714783620213604e-06
3671 8.858661536

3935 2.015419238808992e-05
3936 1.1151243872544607e-05
3937 5.815495157353796e-05
3938 6.303938718262521e-05
3939 2.298568748447621e-05
3940 1.3376087409449803e-05
3941 3.212995064658717e-05
3942 1.7185452796692175e-05
3943 2.878139434141315e-05
3944 1.2294555966512707e-05
3945 2.4828562866558657e-05
3946 9.069945235678834e-06
3947 3.0141109946485437e-05
3948 3.995277460434014e-06
3949 1.954622284280078e-05
3950 1.709468494233012e-05
3951 0.0004207875937865373
3952 1.3107377990532473e-05
3953 1.8136584370777466e-05
3954 1.0250494773187126e-05
3955 3.0401156482245533e-05
3956 1.4334651599125917e-05
3957 1.2470785464011896e-05
3958 1.802941887947041e-05
3959 8.50036649274183e-06
3960 1.7623408970283336e-05
3961 1.2724397207636598e-05
3962 6.06031940784994e-06
3963 1.8851153080039882e-05
3964 9.35672182800561e-05
3965 9.667295910038595e-06
3966 1.3143077117611106e-05
3967 6.041978899059591e-06
3968 4.1772255100610575e-06
3969 5.147618097485813e-06
3970 1.283259306680213e-05
3971 1.0226562

4235 1.1269610925899918e-05
4236 2.5436576911309227e-06
4237 7.698222468466497e-06
4238 0.0007885165267753063
4239 1.546566739301153e-05
4240 6.242293567347782e-05
4241 4.1756255201958054e-05
4242 1.840246507202644e-05
4243 0.00013519535695158917
4244 4.446970169530664e-05
4245 5.13021045896997e-05
4246 4.3212461067418754e-05
4247 3.587651720019229e-05
4248 1.5255365098444511e-05
4249 1.314000031618847e-05
4250 2.0347476786086598e-05
4251 1.8876921226730055e-05
4252 2.0070529745892185e-05
4253 9.438455761277781e-06
4254 1.2900603195537019e-05
4255 8.78483990411079e-06
4256 2.45749416527098e-05
4257 1.63614523709273e-05
4258 4.679750898202384e-06
4259 2.909493907317194e-05
4260 1.255733567637536e-05
4261 1.6909297889058967e-05
4262 1.4635158786520926e-05
4263 6.71566363174318e-06
4264 1.1911446692835929e-05
4265 2.2031726148207597e-05
4266 2.1975714941613173e-05
4267 9.394347224996536e-06
4268 2.50223551775795e-05
4269 1.509889805753063e-05
4270 1.3838256544609055e-05
4271 1.15327155048

4535 1.2510901445165128e-05
4536 5.906503365738568e-06
4537 1.195590772943333e-05
4538 4.797148427686645e-06
4539 5.306617871841984e-06
4540 8.08796085538698e-06
4541 4.030802268369607e-05
4542 2.3820762857125714e-06
4543 5.8748718189516884e-06
4544 4.246113733233058e-06
4545 1.8915201825764147e-05
4546 7.5519433108969844e-06
4547 1.873644352893474e-05
4548 7.275990539180955e-06
4549 1.4830514055047028e-05
4550 7.2640598620260945e-06
4551 5.415247021346419e-05
4552 1.6183798811528013e-05
4553 1.0059741301339167e-05
4554 1.118220852167694e-05
4555 4.806195203968603e-06
4556 1.4031661060439082e-05
4557 5.739401809299415e-06
4558 2.429674822882717e-05
4559 1.5878523902294932e-05
4560 9.224904607500025e-06
4561 1.7295213844703734e-05
4562 8.793512750294773e-05
4563 4.893207020641049e-06
4564 1.4413560888829094e-05
4565 8.998588936905417e-06
4566 1.2433391152856446e-05
4567 1.1422599059182363e-05
4568 3.507972894972586e-06
4569 2.1753622289852405e-06
4570 1.4365537621743518e-05
4571 1.12871

4834 3.6374713289677416e-06
4835 2.5449587721725555e-05
4836 1.4037143326007419e-05
4837 9.341255319134655e-06
4838 1.6424901409721047e-05
4839 7.32014888094537e-06
4840 0.00010186310848831879
4841 6.274000447524018e-05
4842 0.00013189176785406538
4843 3.390823226529261e-05
4844 7.041858366502462e-06
4845 0.0005002957863398967
4846 7.012964566716758e-06
4847 1.703557235719942e-05
4848 2.3486085593516824e-05
4849 2.1304192835752054e-05
4850 2.6403158735930038e-05
4851 1.544765570673749e-05
4852 6.255051011489207e-06
4853 1.9246848364715837e-05
4854 1.1130328882137934e-05
4855 1.1779773694915647e-05
4856 2.4079872170448426e-05
4857 1.0369037658675086e-05
4858 1.3814290699895852e-05
4859 1.002301530719514e-05
4860 1.7908113514746636e-05
4861 3.398960876051764e-05
4862 1.1717248983384345e-05
4863 1.6842363656283027e-05
4864 3.59776038850769e-05
4865 6.398206897428698e-06
4866 9.149694914134479e-06
4867 2.1160714566991204e-05
4868 5.067913186196526e-06
4869 7.874911542697853e-06
4870 5.7994

In [134]:
fig = plot_asteroid2(model, encoding, N=30**3, bw=True, views_2d=True)
#plt.title("Do I look like an asteroid now?")

<IPython.core.display.Javascript object>

#### Saving the model

In [72]:
# Uncomment to save to models/cluster_xxxx
# torch.save(model.state_dict(), "models/" + largest_clusters[rank][1])

#### Loading the model

In [154]:
## It is important that the network architecture is compatible, otherwise this will fail
model.load_state_dict(torch.load("models/" + largest_clusters[rank][1]))


<All keys matched successfully>

# TODO list:

* Code efficiency -> move to GPU and make training scalable to more sample points / mc points.
* MC integration -> importance sampling maybe?
* Network architecture -> study different encodings
* How to visualize and interpret the results quantitatively.
* Propagate trajectories around the asteroids (ground truth and trained).
* Incorporate visual cues.
* Training with gravity rather than potential?
* What happens for non uniform bodies?