In [None]:
import os
import sys

# Make the project's `src` package importable from this notebook. The root defaults to the
# author's checkout; set the CS236_PROJECT_ROOT environment variable to use another clone.
PROJECT_ROOT = os.environ.get("CS236_PROJECT_ROOT", "/Users/kevin/Projects/CS236_Course_Project")
# Only append once so re-running this cell does not keep growing sys.path.
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

In [None]:
import plotly.graph_objs as go
import plotly as plt
import numpy as np
import torch

from src.model import Generator
from src.trainer import Trainer
from src.metrics import compute_metrics

In [None]:
# Seed torch's RNGs (CPU and every CUDA device) so the stochastic steps in this notebook
# (noise added while decoding, randomly sampled latents) are reproducible on Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

BATCH_SIZE = 1
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
CKPT_PATH = "/Users/kevin/Projects/CS236_Course_Project/checkpoints/sinkhorn_energy_gaussian_laplacian_2023-12-02_14-11-58/100.pth"

In [None]:
# Load the two test point clouds whose latents are interpolated below.
def load_point_cloud(path: str) -> torch.Tensor:
    """Load a ``.npy`` point cloud as a batched float32 tensor on ``DEVICE``.

    ``unsqueeze(0)`` adds a leading batch dimension of size 1. The tensor is cast to
    float32 (``np.load`` may return float64, which would not match float32 model weights)
    and moved to ``DEVICE``, the same device the notebook passes to ``net_g.decode``.
    """
    return torch.from_numpy(np.load(path)).unsqueeze(0).float().to(DEVICE)


real_pc_one = load_point_cloud("/Users/kevin/Projects/CS236_Course_Project/mock_data/Test/451927.8000000001_453201.865.npy")
real_pc_two = load_point_cloud("/Users/kevin/Projects/CS236_Course_Project/mock_data/Test/451996.5849999999_453310.49.npy")

In [None]:
def visualize_point_cloud(data: np.ndarray):
    """Render one point cloud as an interactive 3D scatter plot.

    Args:
        data: array of shape (num_points, >=3). Columns 0, 1 and 2 are plotted on the
            x, y and z axes, and column 2 also drives the marker colour.

    The figure is displayed immediately with ``fig.show()``; nothing is returned.
    """
    xs, ys, zs = data[:, 0], data[:, 1], data[:, 2]

    points = go.Scatter3d(
        x=xs,
        y=ys,
        z=zs,
        mode='markers',
        marker={
            'size': 3,
            'color': zs,  # colour by the z column; any other column would work too
            'colorscale': 'Viridis',
            'opacity': 0.8,
        },
    )

    # aspectmode='data' keeps the axes proportional to the data's extent instead of a cube.
    figure = go.Figure(data=[points], layout=go.Layout(scene={'aspectmode': 'data'}))
    figure.show()
    
def plot_samples(samples, num=5, rows=2, cols=3):
    """Plot several point clouds side by side in a grid of 3D subplots.

    Args:
        samples: iterable of arrays shaped (num_points, >=3), e.g. decoded generator outputs.
        num: maximum number of samples to plot (the first ``num`` are used);
            ``None`` plots all of them.
        rows: number of subplot rows.
        cols: number of subplot columns; panels are filled row by row.

    Returns:
        The plotly Figure (rendered automatically when it is a cell's last expression).

    Raises:
        ValueError: if more samples are selected than the grid has panels.
    """
    # Import the submodule explicitly: `import plotly` alone does not guarantee that
    # `plotly.subplots` is loaded, so reaching it as a package attribute can fail.
    from plotly.subplots import make_subplots

    selected = list(samples) if num is None else list(samples)[:num]
    capacity = rows * cols
    if len(selected) > capacity:
        raise ValueError(
            f"{len(selected)} samples do not fit in a {rows}x{cols} grid ({capacity} panels)"
        )

    fig = make_subplots(
        rows=rows,
        cols=cols,
        specs=[[{"type": "scatter3d"} for _ in range(cols)] for _ in range(rows)],
    )

    for i, sample in enumerate(selected):
        fig.add_trace(
            go.Scatter3d(
                # Same column -> axis mapping as visualize_point_cloud (0/1/2 -> x/y/z) so the
                # grid and the single-cloud plots can be compared directly.
                x=sample[:, 0],
                y=sample[:, 1],
                z=sample[:, 2],
                mode="markers",
                marker=dict(size=3, opacity=0.8),
            ),
            row=i // cols + 1,
            col=i % cols + 1,
        )
    # Draw every scene with proportional axes (as visualize_point_cloud does) rather than a cube.
    fig.update_scenes(aspectmode="data")
    fig.update_layout(showlegend=False)
    return fig

In [None]:
# Load trained generator model checkpoint

# Setup model
net_g = Generator()

# Setup trainer
trainer = Trainer(net_g=net_g, batch_size=BATCH_SIZE, device=DEVICE)

# Load checkpoint
trainer.load_checkpoint(CKPT_PATH)

# Switch to inference mode only after the Trainer is built and the weights are loaded.
# Trainer / load_checkpoint are defined elsewhere and could call net_g.train(), so calling
# eval() last guarantees the network is in eval mode for everything below.
net_g.eval()
assert not net_g.training

In [None]:
# Encode both clouds; only the latent (second output of the forward pass) is used below.
# Inference only, so skip recording an autograd graph.
with torch.no_grad():
    _, z1_one = net_g(real_pc_one)
    _, z1_two = net_g(real_pc_two)

In [None]:
# Inspect the shape of the latent returned by the encoder for the first cloud
z1_one.size()

In [None]:
# Vector a goes from z1_one to z1_two. Interpolation below adds multiples of `a` back onto
# z1_one, so keep the encoder's full shape instead of squeezing: squeeze() drops *every*
# size-1 dimension, which can make `z1_one + t * a` broadcast to the wrong shape.
assert z1_one.shape == z1_two.shape, f"latent shapes differ: {z1_one.shape} vs {z1_two.shape}"
a = z1_two - z1_one
a.shape

In [None]:
# Linearly interpolate from z1_one (t = 0) to z1_two (t = 1) at t = 0, 1/n, ..., 1 with
# n = interpolation_steps, giving n + 1 latent codes. The endpoints are the encoder outputs
# themselves, and z1_two is included exactly once, as the t = 1 endpoint.
interpolation_steps = 3
assert interpolation_steps >= 1, "need at least one step to go from z1_one to z1_two"

latent_codes = [z1_one]
for step in range(1, interpolation_steps):
    latent_codes.append(z1_one + (step / interpolation_steps) * a)
latent_codes.append(z1_two)

In [None]:
# Decode every interpolated latent and show them in one grid (first = z1_one, last = z1_two).
# decode(latent, 1, 500, DEVICE, ...): the positional args are presumably batch size and
# number of points — TODO confirm against Generator.decode.
# .cpu() is required before .numpy() whenever DEVICE is a CUDA device.
generated_samples = []
with torch.no_grad():
    for z1 in latent_codes:
        decoded_output = net_g.decode(z1, 1, 500, DEVICE, interpolating=True)
        generated_samples.append(decoded_output.squeeze().detach().cpu().numpy())

plot_samples(generated_samples)

In [None]:
# Decode the first endpoint latent through decode()'s default path (no `interpolating` flag).
# .cpu() is required before .numpy() whenever DEVICE is a CUDA device.
with torch.no_grad():
    z1_one_output = net_g.decode(z1_one, 1, 500, DEVICE).squeeze().detach().cpu().numpy()
visualize_point_cloud(z1_one_output)

In [None]:
# Decode the second endpoint latent through decode()'s default path (no `interpolating` flag).
# .cpu() is required before .numpy() whenever DEVICE is a CUDA device.
with torch.no_grad():
    z1_two_output = net_g.decode(z1_two, 1, 500, DEVICE).squeeze().detach().cpu().numpy()
visualize_point_cloud(z1_two_output)

In [None]:
# Draw a random latent from a dedicated, seeded CPU generator so the sample is reproducible
# on re-runs and identical on CPU and GPU, then move it to DEVICE.
# Shape (1, 500, 512) is presumably (batch, points, latent features) — TODO confirm
# against the Generator's latent layout.
latent_rng = torch.Generator().manual_seed(0)
latent_one = torch.randn(1, 500, 512, generator=latent_rng).to(DEVICE)

In [None]:
# Turn the random latent into a point cloud by calling the decoder network's
# latent_to_point_cloud directly (rather than net_g.decode()); inference only, so no autograd.
with torch.no_grad():
    generated_one = net_g.decoder_network.latent_to_point_cloud(latent_one)
generated_one.shape

### TBD: what is the correct approach?

`net_g.decode()` adds random noise to the encoded latent point cloud to generate a new point cloud. We do this in an attempt to learn a meaningful latent-space representation of point clouds. At inference time, however, my hypothesis is that we should not add random noise and should instead generate point clouds directly from the interpolated latents (as indicated in the cell below).

In [None]:
# Example taken from Trainer.test(): a full forward pass returning (generated cloud, latent).
# real_point_cloud is not defined anywhere else in this notebook, so bind it to the first
# loaded test cloud. latent_point_cloud is reused by the next cell.
real_point_cloud = real_pc_one
with torch.no_grad():
    generated_point_cloud, latent_point_cloud = net_g(real_point_cloud)
# generated_point_cloud has shape: torch.Size([1, 500, 3])
metrics_original = compute_metrics(generated_point_cloud, real_point_cloud, BATCH_SIZE)
# .cpu() is required before .numpy() whenever DEVICE is a CUDA device.
generated_point_cloud = generated_point_cloud.squeeze().detach().cpu().numpy()

print(f"Original Metrics: {metrics_original}")
visualize_point_cloud(generated_point_cloud)

In [None]:
# My idea: decode the latent from the previous cell's forward pass with net_g.decode() and
# compare its metrics with the forward-pass output. Relies on latent_point_cloud and
# real_point_cloud from the previous cell.
# NOTE(review): net_g.decode() also adds noise to the latent, so this is not a noise-free
# path either — TODO confirm whether decode(..., interpolating=True) is the noise-free variant.
with torch.no_grad():
    generated_output = net_g.decode(latent_point_cloud, 1, 500, DEVICE)
metrics_output = compute_metrics(generated_output, real_point_cloud, BATCH_SIZE)
# .cpu() is required before .numpy() whenever DEVICE is a CUDA device.
generated_output = generated_output.squeeze().detach().cpu().numpy()
print(f"Output Metrics: {metrics_output}")
visualize_point_cloud(generated_output)