### Notebook for Testing Trained Models  
Cells are run sequentially:  
    **~ Cell 1**: Importing libraries, ensure correct filepath to the models.  
    **~ Cell 2**: Importing model, `supervised_import` determines whether to import a supervised model or semi-supervised model,   
    ...then the specific run can be selected by setting `run_number` to the corresponding index, or by specifying the supervised run name.  
    **~ Cell 3**: Generate test data, parameters of the test data can be varied and distributions plotted.  
    **~ Cell 4**: Compute the Ricci tensors for the loaded model on the test data.  
    **~ Cell 5**: Plot visualisations of the loaded model metrics and computed Ricci tensors on the test data. Plotting parameters can be  
    ...edited, including setting the `plot_radial_limit` to reduce to the patch portions used for the global manifold definition.  
    **~ Cell 6**: Compute the Global test loss for the loaded model on the test data, printing components.  
    **~ Cell 7**: Plot visualisations of analytic metrics, either the identity initialisation or analytic round metric.  

In [None]:
# Show the path of the Python interpreter that is executing this code
import sys
interpreter_path = sys.executable
print(interpreter_path)

In [None]:
# Import libraries
import os
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import tensorflow as tf
import tensorflow_probability as tfp
tf.keras.backend.set_floatx('float64')

# Import relevant functions
from helper_functions.helper_functions import cholesky_from_vec, plot_fig
from sampling.ball import BallSample, CubeSample
from network.ball import BallGlobalModel, BallPatchSubModel
from geometry.common import compute_ricci_tensor
from geometry.ball import PatchChange_Coordinates_Ball, PatchChange_Metric_Ball, AnalyticMetric_Ball
from losses.ball import GlobalLossBall

# Output the list of saved semi-supervised runs (to select which to import with in the following cell)
root_runs_path = os.getcwd()+"/runs/" #...set the correct filepath (kept as a string with a trailing slash, later paths are appended to it)
# Only sub-directories count as runs. os.listdir returns entries in arbitrary order, so they are sorted
# here: the printed index is what gets used as `run_number`, and it should not change between sessions.
saved_runs = sorted(run for run in os.listdir(root_runs_path) if os.path.isdir(os.path.join(root_runs_path, run)))
for run_name_idx, run_name in enumerate(saved_runs):
    print(run_name_idx, run_name)


In [None]:
# Load the trained model to be tested
# Toggle between the two kinds of saved model: True --> supervised, False --> semi-supervised
supervised_import = False

if supervised_import:
    # Supervised models sit directly in the runs directory and are identified by name
    model_name = "identity2d2p"
    model_path = f"{root_runs_path}supervised_model_{model_name}.keras"
else:
    # Semi-supervised runs are chosen by their index in the saved_runs list printed by the previous cell
    run_number = 0
    model_path = f"{root_runs_path}{saved_runs[run_number]}/final_model.keras"
    # If every epoch was saved, a specific epoch can be loaded instead:
    #epoch_number = 85
    #model_path = root_runs_path + saved_runs[run_number] + f"/epoch_{epoch_number}_model.keras"

# The custom model classes must be handed to keras so that it can rebuild them from the saved file
model_classes = {'BallGlobalModel': BallGlobalModel, 'BallPatchSubModel': BallPatchSubModel}
loaded_model = tf.keras.models.load_model(model_path, custom_objects=model_classes)

In [None]:
# Test-sample hyperparameters (the dimension is fixed by the loaded model)
dim = loaded_model.dim
num_samples = int(1e4)
density_power = 1.  # ...the \alpha (>0) value in the beta function, values < 1 skew to extremities, >1 skew to centre
patch_width = 1.    #...the radii of the patches considered

# Draw the test points in the first patch
first_patch_sample = BallSample(
    num_samples, dimension=dim, patch_width=patch_width, density_power=density_power)
test_samples = [tf.convert_to_tensor(first_patch_sample)]
# For a 2-patch model the same points are also expressed in the coordinates of the 2nd patch
if loaded_model.n_patches > 1:
    test_samples.append(PatchChange_Coordinates_Ball(test_samples[0]))

# Scatter the first two coordinates of the sample, one panel per patch
n_patches = loaded_model.n_patches
fig, axes = plt.subplots(1, n_patches, figsize=(5*n_patches, 5))
if n_patches == 1:
    axes = [axes]  #...a single panel comes back as a bare Axes, wrap it so it can be indexed like the multi-patch case
for patch_idx, ax in enumerate(axes):
    patch_sample = test_samples[patch_idx]
    ax.set_title(f'Patch {patch_idx+1}')
    ax.scatter(patch_sample[:,0], patch_sample[:,1], alpha=0.1)
    ax.set_xlim(-1,1)
    ax.set_ylim(-1,1)
plt.tight_layout()

# Only the first two coordinates are drawn, so flag this whenever the patches are not 2d
if dim != 2:
    print(f'...note these images are 2d sections of the {dim}d patches!')


In [None]:
# Evaluate each patch submodel on its own test sample: raw network outputs, the metrics built from them, and the Ricci tensors
patch_indices = range(loaded_model.n_patches)
predicted_vielbeins = [loaded_model.patch_submodels[p].predict(test_samples[p]) for p in patch_indices]
# The metric in each patch is assembled from the submodel output vector via cholesky_from_vec
predicted_metrics = [cholesky_from_vec(vielbein).numpy() for vielbein in predicted_vielbeins]
# The Ricci tensor is computed from the submodel itself at the sample points, not from the predictions above
predicted_riccis = [compute_ricci_tensor(test_samples[p], loaded_model.patch_submodels[p]).numpy() for p in patch_indices]


In [None]:
# Plotting: examine the trained metrics / ricci tensors
# Top row: the selected metric component in each patch; bottom row: the same component of the Ricci tensor.
# Set the plotting hyperparameters
metric_index_1, metric_index_2 = 0, 0 #...choose the metric component to plot
dim_x, dim_y = 0, 1      #...select the input dimensions to plot
plot_radial_limit = 0.6  #...range limit for plotting ricci --> set to patch_width for no restriction
save_pdf_filename = None #...select the filename for this visualisation as a pdf, None means do not save

# Setup the plots
elevation_angle, azimuth_angle = 30, 45 #...set the orientation of the 3d plots
fig, axes = plt.subplots(2, loaded_model.n_patches, figsize=(5*loaded_model.n_patches,10), subplot_kw={'projection': '3d'})
if loaded_model.n_patches == 1:
    # With one column subplots returns a 1d array; add the column axis back so axes[row, col] indexing works
    axes = np.expand_dims(axes,-1)

# Restrict the plotted points to a max radius (avoiding numerical instabilities)
if plot_radial_limit < patch_width:
    # The radius is measured over all coordinates of each point, separately in each patch
    masks = [tf.sqrt(tf.reduce_sum(tf.square(ts), axis=-1)) < plot_radial_limit for ts in test_samples]
    test_samples_limited = [tf.boolean_mask(test_samples[patch_idx], masks[patch_idx]) for patch_idx in range(loaded_model.n_patches)]
    predicted_metrics_limited = [tf.boolean_mask(predicted_metrics[patch_idx], masks[patch_idx]) for patch_idx in range(loaded_model.n_patches)]
    predicted_riccis_limited = [tf.boolean_mask(predicted_riccis[patch_idx], masks[patch_idx]) for patch_idx in range(loaded_model.n_patches)]
else:
    test_samples_limited, predicted_metrics_limited, predicted_riccis_limited = test_samples, predicted_metrics, predicted_riccis

# Axis labels follow the selected input dimensions (1-indexed), so they stay correct when dim_x / dim_y are changed
x_label, y_label = rf'$x_{{{dim_x+1}}}$', rf'$x_{{{dim_y+1}}}$'
axis_limit = plot_radial_limit*1.2 #...leave a margin around the plotted disc

### Metric (row 0) and Ricci (row 1) ###
# Both rows are drawn identically, only the tensor and its title symbol differ
plot_rows = (("g", predicted_metrics_limited), ("R", predicted_riccis_limited))
for row_idx, (tensor_symbol, tensor_values) in enumerate(plot_rows):
    for i in range(loaded_model.n_patches):
        ax = axes[row_idx,i]
        component = tensor_values[i][:, metric_index_1, metric_index_2]
        # The component value is used both as the height and as the colour of each point
        ax.scatter(test_samples_limited[i][:,dim_x], test_samples_limited[i][:,dim_y],
                component,
                c=component,
                cmap="viridis")
        ax.set_title(rf"${tensor_symbol}_{{{metric_index_1},{metric_index_2}}}$ Train (patch {i+1})")
        ax.set_xlim(-axis_limit,axis_limit)
        ax.set_ylim(-axis_limit,axis_limit)
        ax.set_xlabel(x_label)
        ax.set_ylabel(y_label)
        ax.view_init(elev=elevation_angle, azim=azimuth_angle)

# Save the plots
plt.tight_layout()
if save_pdf_filename:
    plt.savefig(save_pdf_filename)
    

In [None]:
# Evaluate the Global loss of the loaded model on the test sample
# Radial limit for each patch: 0.1 beyond the radial midpoint
radial_midpoint = np.sqrt(3. - 2 * np.sqrt(2.))
rl = radial_midpoint + 0.1

# Build the loss from the hyperparameters stored on the loaded model
global_loss = GlobalLossBall(loaded_model.hp, radial_limit=rl)

# Full network output on the first-patch test points
first_patch_samples = test_samples[0]
metric_pred = loaded_model.predict(first_patch_samples)

# Evaluate the loss and report the total, its components and the sample sizes
global_loss_value, global_loss_data, global_sample_sizes = global_loss.call(loaded_model, first_patch_samples, metric_pred)
loss_report = "\n".join([
    f'Global loss: {global_loss_value}',
    'Loss data:',
    f'{global_loss_data}',
    f'Sample sizes ([patches], overlap): {global_sample_sizes}',
])
print(loss_report)


In [None]:
# Plot the analytic metrics
identity_bool = False #...select whether the analytic metric is the identity non-geometric metric (True), or round metric (False)

# Choose the metric component and the input dimensions to plot
# (both are set here so that this cell does not rely on the trained-metric plotting cell having been run)
metric_index_1, metric_index_2 = 0, 0
dim_x, dim_y = 0, 1

# Compute the analytic metrics on the test data
analytic_metrics = [AnalyticMetric_Ball(test_samples[patch_idx], identity=identity_bool) for patch_idx in range(loaded_model.n_patches)]

# Setup the plots
elevation_angle, azimuth_angle = 30, 45 #...set the orientation of the 3d plots
fig, axes = plt.subplots(1, loaded_model.n_patches, figsize=(5*loaded_model.n_patches, 5), subplot_kw={'projection': '3d'})
if loaded_model.n_patches == 1:
    axes = [axes] #...a single panel comes back as a bare Axes, wrap it so it can be indexed

# Axis labels follow the selected input dimensions (1-indexed)
x_label, y_label = rf'$x_{{{dim_x+1}}}$', rf'$x_{{{dim_y+1}}}$'

# Plot the metric component in each patch
for i in range(loaded_model.n_patches):
    component = analytic_metrics[i][:, metric_index_1, metric_index_2]
    # The component value is used both as the height and as the colour of each point
    axes[i].scatter(test_samples[i][:,dim_x], test_samples[i][:,dim_y],
            component,
            c=component,
            cmap="viridis")
    axes[i].set_title(rf"$g_{{{metric_index_1},{metric_index_2}}}$ Analytic (patch {i+1})")
    axes[i].set_xlabel(x_label)
    axes[i].set_ylabel(y_label)
    axes[i].view_init(elev=elevation_angle, azim=azimuth_angle)

plt.tight_layout()