# Training on data from the OSIRIS-REx mission

In this notebook we train a GeodesyNet on the observed trajectories of detached pebbles.

During the time of observation, Bennu ejected multiple small rock pebbles. Many of these stayed in orbit around Bennu for several days before either falling back to Bennu's surface or escaping its gravitational influence. These trajectories yield additional samples of the gravity field. However, due to their small size, they have a high surface-to-mass ratio and radiation effects play a large role, adding substantial unmodelled effects to their trajectories and thus to the value of the purely gravitational acceleration that can be computed from them.
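To get a feel for the magnitudes involved, the following back-of-the-envelope sketch compares solar radiation pressure with point-mass gravity for a spherical pebble. All numbers are illustrative assumptions, not values taken from the mission data: a pebble of 5 mm radius and density 2000 kg/m^3, a perfectly absorbing surface, the approximate radiation pressure at 1 AU, and an approximate gravitational parameter for Bennu. The function names are hypothetical.

```python
# Rough comparison of solar radiation pressure (SRP) and gravity on a pebble.
# All values below are illustrative assumptions, not mission data.

SRP_PRESSURE_1AU = 4.56e-6  # N/m^2, approximate pressure on an absorbing surface at 1 AU
GM_BENNU = 4.89             # m^3/s^2, approximate gravitational parameter of Bennu


def srp_acceleration(radius_m, density_kg_m3, pressure=SRP_PRESSURE_1AU):
    """SRP acceleration of a homogeneous absorbing sphere: pressure * area / mass."""
    # Cross-section over mass: (pi r^2) / (4/3 pi r^3 rho) = 3 / (4 rho r)
    area_to_mass = 3.0 / (4.0 * density_kg_m3 * radius_m)
    return pressure * area_to_mass


def gravity_acceleration(distance_m, gm=GM_BENNU):
    """Point-mass gravitational acceleration at a given distance from the centre."""
    return gm / distance_m**2


a_srp = srp_acceleration(radius_m=0.005, density_kg_m3=2000.0)
a_grav = gravity_acceleration(distance_m=1000.0)
ratio = a_srp / a_grav
print(f"SRP: {a_srp:.2e} m/s^2, gravity: {a_grav:.2e} m/s^2, ratio: {ratio:.1%}")
```

Under these assumptions the perturbation amounts to several percent of the gravitational acceleration at 1 km from the centre, and since the area-to-mass ratio scales as 1/radius, it grows for smaller pebbles.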

NOTE: Compared with a standard training run (see the Starter Notebook), the only difference is the dataset. The Starter Notebook uses synthetically generated data provided via a sampler, whereas here we use data precomputed from the real observed pebble trajectories. We do not make use of any prior knowledge of Bennu's shape model.
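As an illustration of how acceleration samples could be derived from a tracked trajectory, the sketch below applies a central second difference to uniformly sampled positions. This is an assumption made for illustration only: the notebook that produced the precomputed dataset is not shown here and may use a different procedure. The function name `accelerations_from_positions` is hypothetical.

```python
# Illustrative only: estimate accelerations from uniformly sampled positions
# with a central second difference. Not necessarily the procedure used to
# build the dataset loaded in this notebook.

def accelerations_from_positions(positions, dt):
    """positions: list of (x, y, z) tuples sampled every dt seconds.

    Returns one acceleration per interior sample:
    a_k = (r_{k+1} - 2 r_k + r_{k-1}) / dt^2
    """
    accelerations = []
    for k in range(1, len(positions) - 1):
        prev, curr, nxt = positions[k - 1], positions[k], positions[k + 1]
        accelerations.append(
            tuple((n - 2.0 * c + p) / dt**2 for p, c, n in zip(prev, curr, nxt))
        )
    return accelerations


# Toy check: constant acceleration a = (0, 0, -2) gives z(t) = 0.5 * a_z * t^2
dt = 0.5
track = [(0.0, 0.0, 0.5 * -2.0 * (k * dt) ** 2) for k in range(5)]
acc = accelerations_from_positions(track, dt)
print(acc)
```

Finite differencing amplifies measurement noise, which is one more reason to treat accelerations obtained this way as noisy samples of the gravity field.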

We suggest running this notebook in the same conda environment as the one described in the Starter Notebook.

In [None]:
# core stuff
import gravann
import numpy as np
import pickle as pk
import os
from collections import deque

# pytorch
from torch import nn
import torch

# plotting stuff
import matplotlib.pyplot as plt
from mpl_toolkits import mplot3d
%matplotlib notebook

# Ensure that changes in imported module (gravann most importantly) are autoreloaded
%load_ext autoreload
%autoreload 2

# If possible enable CUDA
gravann.enableCUDA()
gravann.fixRandomSeeds()
device = os.environ["TORCH_DEVICE"]
print("Will use device", device)


# Loading and visualizing the ground truth asteroid (a point cloud)

In [None]:
mascon_points,mascon_masses, mascon_masses_nu = gravann.load_sample("Bennu.pk")

In [None]:
gravann.plot_mascon(mascon_points, mascon_masses)

## Loading the trajectory data, created in a different notebook from SPICE kernels

In [None]:
with open("osirisrex/bennu_pebbles_filtered.pk", "rb") as file:
    _,_,data_points_p, data_labels_p = pk.load(file)
data_points_p = torch.tensor(data_points_p)
data_labels_p = torch.tensor(data_labels_p)

# Representing an asteroid via a neural network


## 1 - Instantiating the network
The network's inputs are the Cartesian coordinates of a point in the unit cube, encoded via some transformation.
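The idea of an encoding can be sketched as a callable that maps each point to the feature vector fed to the network. The two classes below are hypothetical stand-ins written in plain Python, not the `gravann` implementations: one passes the coordinates through unchanged (which is what a direct encoding is assumed to do), the other appends sine and cosine features as an example of a richer transformation.

```python
import math


class IdentityEncoding:
    """Hypothetical stand-in for a direct encoding: (x, y, z) passes through unchanged."""

    dim = 3  # number of features seen by the network

    def __call__(self, points):
        return [tuple(p) for p in points]


class FourierEncoding:
    """Hypothetical alternative: append sin/cos features at n_freq frequencies."""

    def __init__(self, n_freq=2):
        self.n_freq = n_freq
        self.dim = 3 + 3 * 2 * n_freq

    def __call__(self, points):
        encoded = []
        for p in points:
            features = list(p)
            for k in range(self.n_freq):
                for coord in p:
                    features.append(math.sin(2**k * math.pi * coord))
                    features.append(math.cos(2**k * math.pi * coord))
            encoded.append(tuple(features))
        return encoded


points = [(0.0, 0.5, -0.5)]
identity_features = IdentityEncoding()(points)
fourier_features = FourierEncoding(n_freq=2)(points)
print(identity_features, len(fourier_features[0]))
```

In both cases the `dim` attribute tells the network how many inputs to expect, so the encoding and the first layer have to be chosen together.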

In [None]:
# Encoding chosen
encoding = gravann.direct_encoding()

# For "normal" training
model = gravann.init_network(encoding, model_type="siren")

# For differential training
# model = gravann.init_network(encoding, model_type="siren", activation = nn.Tanh())

# When a new network is created we init empty training logs
loss_log = []
weighted_average_log = []
running_loss_log = []
n_inferences = []
# .. and we init some loss trend indicators
weighted_average = deque([], maxlen=20)

In [None]:
# IF YOU NOW WANT TO LOAD THE ALREADY TRAINED NETWORK UNCOMMENT HERE.
## It is important that the network architecture is compatible, otherwise this will fail
#model.load_state_dict(torch.load("FILENAME"))

## Visualizing an asteroid represented by the network
The network output is the density in the unit cube. It is, essentially, a three-dimensional function. (This does not work with differential training!)
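A density defined over a cube can be visualized by rejection sampling: draw points uniformly in the cube and keep each one with probability proportional to the density there, so that dense regions show up as dense point clouds. The sketch below is a minimal plain-Python version under stated assumptions (a cube spanning [-1, 1] on each axis and a toy density in place of the network); it is not the implementation behind `gravann.plot_model_rejection`, and the names are hypothetical.

```python
import random


def rejection_sample(density, n_points, rho_max, seed=0):
    """Draw n_points from a density on [-1, 1]^3 by rejection sampling.

    density: callable mapping an (x, y, z) tuple to a value in [0, rho_max].
    """
    rng = random.Random(seed)
    accepted = []
    while len(accepted) < n_points:
        candidate = tuple(rng.uniform(-1.0, 1.0) for _ in range(3))
        # Accept with probability density(candidate) / rho_max
        if rng.uniform(0.0, rho_max) < density(candidate):
            accepted.append(candidate)
    return accepted


def toy_density(p):
    """Toy stand-in for the network: a uniform ball of radius 0.8, empty elsewhere."""
    return 1.0 if sum(c * c for c in p) < 0.8**2 else 0.0


samples = rejection_sample(toy_density, n_points=200, rho_max=1.0)
print(len(samples))
```

Scatter-plotting `samples` would then reveal the shape encoded by the density; with an untrained network the result is expected to look like noise.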

In [None]:
gravann.plot_model_rejection(model, encoding, views_2d=False, N=2500, progressbar=True, c=10)
plt.title("Believe me, I am an asteroid")

# Training the ANN to match the observed accelerations

Let it run until the loss drops below 1e-3 to actually see something that resembles the original asteroid. If training gets stuck, increase the number of integration sample points (`N_integration`) or adjust the learning rate.
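The training loop computes a scale factor `c = sum(predicted * labels) / sum(predicted * predicted)`. This is the scalar that minimizes the squared error between `c * predicted` and `labels`, so the network only has to learn the field up to a multiplicative constant. The sketch below shows this in plain Python, together with a hypothetical scale-invariant L1 loss built on top of it; the exact definition of `gravann.normalized_L1_loss` is not reproduced here, and both function names are assumptions.

```python
def least_squares_scale(predicted, labels):
    """Scalar c minimizing sum((c * predicted - labels)^2)."""
    numerator = sum(p * y for p, y in zip(predicted, labels))
    denominator = sum(p * p for p in predicted)
    return numerator / denominator


def scale_invariant_l1(predicted, labels):
    """Hypothetical loss: L1 error after optimal rescaling, relative to |labels|."""
    c = least_squares_scale(predicted, labels)
    residual = sum(abs(c * p - y) for p, y in zip(predicted, labels))
    return residual / sum(abs(y) for y in labels)


# Predictions with the right shape but the wrong overall scale
predicted = [1.0, 2.0, -1.0]
labels = [2.0, 4.0, -2.0]
c = least_squares_scale(predicted, labels)
loss = scale_invariant_l1(predicted, labels)
print(c, loss)
```

Because the loss ignores the overall scale, the factor has to be reapplied when the model is compared with physical quantities, which is presumably why `c` is passed to the plotting calls later in this notebook.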

In [None]:
# EXPERIMENTAL SETUP ------------------------------------------------------------------------------------
# Here we set some hyperparameters
N_integration = 30000
batch_size = 100

# Here we set the loss
loss_fn = gravann.normalized_L1_loss

# Here we set the chosen integration method
integrator = gravann.ACC_trap

# Here we set the optimizer
learning_rate = 1e-4
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, factor=0.8, patience=200, min_lr=1e-6, verbose=True)


In [None]:
# This cell can be stopped and started again without losing memory of the training or its indicators
torch.cuda.empty_cache()
plt.close('all')
# The main training loop
for i in range(5000):
    # Sample random data from our observations.
    # This could be moved into a separate sampler in
    # _sample_observation_points.py, with an associated
    # label sampler in mascon_labels.py
    idxs = np.random.choice(np.arange(len(data_points_p)), batch_size, replace=False)
    target_points = data_points_p[idxs]
    labels = data_labels_p[idxs]
    
    # Compute the loss (use N=3000 to start with, then, eventually, beef it up to 200000)
    predicted = integrator(target_points, model, encoding, N=N_integration)
    c = torch.sum(predicted*labels)/torch.sum(predicted*predicted)
    if loss_fn in (gravann.contrastive_loss, gravann.normalized_relative_component_loss):
        loss = loss_fn(predicted, labels)
    else:
        loss = loss_fn(predicted.view(-1), labels.view(-1))
    
    # Update the loss trend indicators
    weighted_average.append(loss.item())
    weighted_average_log.append(np.mean(weighted_average))
    loss_log.append(loss.item())
    n_inferences.append((N_integration*batch_size) // 1000)
    
    # Print every 25 iterations
    if i % 25 == 0:
        wa_out = np.mean(weighted_average)
        print(f"It={i}\t loss={loss.item():.3e}\t  weighted_average={wa_out:.3e}\t  c={c:.3e}")
    
    # Before the backward pass, use the optimizer object to zero all of the
    # gradients for the variables it will update (which are the learnable
    # weights of the model). This is because, by default, gradients are
    # accumulated in buffers (i.e., not overwritten) whenever .backward()
    # is called. Check out the docs of torch.autograd.backward for more details.
    optimizer.zero_grad()

    # Backward pass: compute gradient of the loss with respect to model
    # parameters
    loss.backward()

    # Calling the step function on an Optimizer makes an update to its
    # parameters
    optimizer.step()
    
    # Perform a step in LR scheduler to update LR
    scheduler.step(loss.item())

In [None]:
# Rejection plot
gravann.plot_model_rejection(model, encoding, views_2d=True, bw=True, N=1500, alpha=0.1, s=50, c=c, crop_p=0.1, progressbar=True)

In [None]:
# Plot the loss history
plt.figure()
abscissa = np.cumsum(n_inferences)
plt.semilogy(abscissa, loss_log)
plt.semilogy(abscissa, weighted_average_log)
plt.xlabel("Thousands of model evaluations")
plt.ylabel("Loss")
plt.legend(["Loss","Weighted Average Loss"])

In [None]:
# Rejection plot overlaid with the mascon
gravann.plot_model_vs_mascon_contours(model, encoding, mascon_points, mascon_masses,c=c, progressbar = True, N=2500, heatmap=False)

In [None]:
# Compute the acceleration plot
gravann.plot_model_mascon_acceleration("3dmeshes/Bennu.pk", model, encoding, mascon_points, mascon_masses, plane="XY", c=c, N=5000, logscale=False)

In [None]:
gravann.plot_model_mascon_acceleration("3dmeshes/Bennu.pk", model, encoding, mascon_points, mascon_masses, plane="XZ", c=c, N=5000, logscale=False)

In [None]:
gravann.plot_model_mascon_acceleration("3dmeshes/Bennu.pk", model, encoding, mascon_points, mascon_masses, plane="YZ", c=c, N=5000, logscale=False)

#### Saving the model

In [None]:
# Uncomment to save the trained weights to the models folder
#torch.save(model.state_dict(), "models/siren_acc_bennu.mdl")