# Visualization
This notebook visualizes the results of a single run in detail.

In [None]:
# core stuff
import gravann,os
import numpy as np
import pickle as pk
from tqdm import tqdm
import torch

# For animation etc.
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"
from IPython import display
import matplotlib.pyplot as plt
import imageio

# Ensure that changes in imported module (gravann most importantly) are autoreloaded
%load_ext autoreload
%autoreload 2

# If possible enable CUDA
gravann.enableCUDA()
gravann.fixRandomSeeds()
device = os.environ["TORCH_DEVICE"]
print("Will use device ",device)

# Define the run folder and load the trained model

In [None]:
# Folder of the run to visualize. Forward slashes work on every OS; the former
# Windows-style backslashes ("\H", "\L") were invalid string escape sequences and
# did not resolve as path separators on Linux/macOS.
results_folder = ("results/siren_all_runs_ACC_siren_diff_train/Hollow.pk/"
                  "LR=0.0001_loss=normalized_L1_loss_encoding=direct_encoding_batch_size=1000_"
                  "target_sample=spherical_activation=Tanh_omega=3e+01/")
# Must match how the run was trained; it is passed to gravann.load_model_run below.
differential_training = True

model, encoding, sample, c, use_acc, mascon_points, mascon_masses_u, mascon_masses_nu, cfg = gravann.load_model_run(
        results_folder, differential_training)

# Release cached GPU memory left over from loading
torch.cuda.empty_cache()


In [None]:
import pyvista as pv

# Volume subtracted for the "Hollow" body. Numerically this equals
# 4/3 * pi * r**3 with r = sqrt(0.2) ~= 0.4472, the radius of the inner sphere
# drawn in the next cell (assumed to be the hollow cavity — TODO confirm).
HOLLOW_CAVITY_VOLUME = 0.37465678565

# NOTE: pickle.load can execute arbitrary code; only load mesh files from trusted sources.
with open("3dmeshes/" + sample, "rb") as file:
    verts, triangles = pk.load(file)

# PyVista face format: each face is prefixed with its number of vertices (3 = triangle)
faces = [[3, t[0], t[1], t[2]] for t in triangles]

# Create PV Polydata
mesh = pv.PolyData(np.asarray(verts), np.asarray(faces))

volume = mesh.volume
if sample == "Hollow.pk":  # subtract the sphere inside
    volume = volume - HOLLOW_CAVITY_VOLUME
print("Volume=", volume)

In [None]:
pv.set_plot_theme("document")
# Vertex array; only needed by the alternative (commented-out) masks below.
v = np.array(verts)

# Alternative density masks used for other bodies; they act on vertices (point data).
#Itokawa
# mask = (v[:, 0] - 0.5*v[:, 2]) + (np.random.random((len(verts), ))*2-1)*1e-1 > 0.4 
# mask = (v[:, 0] - 0.5*v[:, 2]) > 0.4 

#Bennu
# mask = np.logical_or(v[:,2]+(np.random.random((len(verts), ))*2-1)*1e-1 > 0.25,  v[:,2] + (np.random.random((len(verts), ))*2-1)*1e-1 < -0.25)
# mask = np.logical_or(v[:,2] > 0.25,  v[:,2] < -0.25)

#Hollow: cut away one part of the body so the inner sphere becomes visible.
sphere = pv.Sphere(radius=0.4472, center=(0.2, 0, 0))
# Start from a freshly built mesh rather than `mesh`: `mesh` is reassigned below, so
# re-running this cell would otherwise cut an already-cut mesh and leave `submesh` empty.
full_mesh = pv.PolyData(np.asarray(verts), np.asarray(faces))
cell_center = full_mesh.cell_centers().points
# Cells kept as the solid body: everything except the region y <= -0.25 and x <= 0.25
mask2 = np.logical_or(cell_center[:, 1] > -0.25, cell_center[:, 0] > 0.25)
# Exact complement of mask2, so cells lying on the cut boundary are not dropped from both parts
mask3 = ~mask2
submesh = full_mesh.extract_cells(mask3.nonzero()[0])
mesh = full_mesh.extract_cells(mask2.nonzero()[0])
# Uniform density: one value per remaining cell
mask = [1] * int(mask2.sum())

mesh["Density"] = mask

In [None]:
# Render the body and save a screenshot.
# (For an interactive window outside the notebook, pyvistaqt.BackgroundPlotter() can replace pv.Plotter.)
os.makedirs("figures", exist_ok=True)  # writing the screenshot fails if the folder is missing
p = pv.Plotter(notebook=True)
p.show_axes()
# Controlling the text properties of the scalar bar
sargs = dict(
    title_font_size=20,
    label_font_size=16,
    shadow=True,
    n_labels=0,
    italic=True,
    fmt="%.1f",
    font_family="arial",
)

if sample == "Hollow.pk":
    # Cut-open solid, the inner sphere, and the cut-away part as a wireframe
    p.add_mesh(mesh, color="gray", lighting=True, smooth_shading=False)
    p.add_mesh(sphere, color="lightgrey", lighting=True, smooth_shading=True)
    p.add_mesh(submesh, color="lightgrey", style="wireframe")
else:
    # Colour by the "Density" cell data set in the previous cell
    p.add_mesh(mesh, scalars="Density", cmap=["lightgrey", "gray"], lighting=True, smooth_shading=False,
               show_scalar_bar=True, scalar_bar_args=sargs,
               annotations={0: "                  Hollow", 1: "Uniform"})
# Camera: position, focal point, view-up vector
p.camera_position = [(0, -4, 0), (0, 0, 0), (0, 0, 1)]
p.camera_set = True

p.show(screenshot="figures/" + sample + "_nu.png", window_size=[800, 800])

In [None]:
# Contours of the model output as-is (no base shape density added). For a differentially
# trained run, save it under its own name: the next cell writes "figures/<sample>.png"
# and would otherwise overwrite this figure.
contour_path = "figures/" + sample + ("_model_only.png" if differential_training else ".png")
gravann.plot_model_vs_mascon_contours(model, encoding, mascon_points, mascon_masses_u, c=c, heatmap=True,
                                      save_path=contour_path, mascon_alpha=0.175);

In [None]:
#differential
# Plot with the shape's base value and a constant density of 1/volume added on top of
# the model output (presumably reconstructing the full density of a differentially
# trained run — confirm in gravann.plot_model_vs_mascon_contours).
base_shape_path = "3dmeshes/" + sample
uniform_density = 1. / volume
gravann.plot_model_vs_mascon_contours(
    model, encoding, mascon_points, mascon_masses_u,
    c=c,
    heatmap=True,
    mascon_alpha=0.175,
    save_path="figures/" + sample + ".png",
    add_shape_base_value=base_shape_path,
    add_const_density=uniform_density,
);

In [None]:
# Compare model vs. mascon accelerations at a fixed altitude in each coordinate plane.
shape_path = "3dmeshes/" + sample
acc_kwargs = dict(altitude=0.05, N=5000, c=c, logscale=True)
gravann.plot_model_mascon_acceleration(shape_path, model, encoding, mascon_points, mascon_masses_u, plane="XY", **acc_kwargs);
gravann.plot_model_mascon_acceleration(shape_path, model, encoding, mascon_points, mascon_masses_u, plane="XZ", **acc_kwargs);
gravann.plot_model_mascon_acceleration(shape_path, model, encoding, mascon_points, mascon_masses_u, plane="YZ", **acc_kwargs);

In [None]:
# Render density contours for a sweep of slice offsets and collect them as RGB frames.
mascon_alpha = 0
images = []
for offset in tqdm(np.linspace(-0.75, 0.75, 10)):
    # load_model_run returns mascon_masses_u / mascon_masses_nu; there is no `mascon_masses`
    fig = gravann.plot_model_vs_mascon_contours(model, encoding, mascon_points, mascon_masses_u, c=c,
                                                offset=offset, heatmap=True, mascon_alpha=mascon_alpha)
    fig.canvas.draw()
    # buffer_rgba() carries the true pixel shape (incl. HiDPI scaling), unlike reshaping by
    # get_width_height(); tostring_rgb() is deprecated/removed in recent Matplotlib.
    # Drop alpha and copy so the frame does not alias the renderer's buffer.
    image = np.asarray(fig.canvas.buffer_rgba())[..., :3].copy()
    images.append(image)
    plt.close("all")

In [None]:
# Assemble the collected frames into an animated GIF.
gifPath = "gifs/contourf.gif"
os.makedirs(os.path.dirname(gifPath), exist_ok=True)  # saving fails if the folder is missing
imageio.mimsave(gifPath, images)

In [None]:
# Display GIF in Jupyter, CoLab, IPython.
# The file is a GIF, not a PNG: let IPython infer the format (and MIME type) from the extension.
display.Image(filename=gifPath)