In [7]:
from src.get_modelnet40.load_data import get_dls_for_viz, get_train_test_dls
from src.it_net.it_net import ITNet
import torch

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import numpy as np
import pyvista as pv
from tqdm import tqdm

In [2]:
import plotly.graph_objects as go
import plotly.express as px
import matplotlib.pyplot as plt
from plotly.subplots import make_subplots
import random

In [3]:
from src.pretrain_utils.transforms import quaternion_to_matrix, apply_tfm, compose_tfms
from src.pretrain_utils.corruptions import tfm_from_rand_pose, create_random_transform

In [4]:
# Prefer the Apple-silicon GPU, then CUDA, then CPU, so the notebook also runs
# on machines without MPS (the original hard-coded "mps").
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
batch_size = 256

# Seed every RNG the notebook may draw from (random transforms presumably use
# torch's RNG; the sample choice below uses `random`) so that
# Restart & Run All reproduces the same figures.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [5]:
# Load the trained IT-Net weights. `map_location` remaps tensors saved on a
# different device (e.g. a CUDA training box) onto the current `device`;
# without it, loading such a checkpoint fails on machines lacking that device.
checkpoint = torch.load("results/it_net_continue/checkpoint.pth", map_location=device)

model = ITNet(channel=3, num_iters=5).to(device)
model.load_state_dict(checkpoint["model"])
# Switch layers with train/eval differences (dropout, batch-norm, ...) to
# inference behaviour; eval() returns the module itself.
model = model.eval()

In [8]:
# Build the train / val / test DataLoaders and pull a single test batch that
# every visualisation below is drawn from.
(train_loader, val_loader, test_loader) = get_train_test_dls(batch_size=batch_size)
test_batches = iter(test_loader)
batch = next(test_batches)

In [9]:
# Cast to float32 BEFORE moving to the device (MPS cannot hold float64), then
# swap the last two axes into the channels-first layout the model consumes.
# NOTE(review): assumes batch["pointcloud"] is (batch, points, 3), giving
# (batch, 3, points) here; this matches the (256, 3, 1024) model output shown below.
batch_point_clouds = (
    batch["pointcloud"]
    .to(torch.float32)
    .to(device)
    .transpose(1, 2)
)

In [10]:
# Draw one random transform per cloud, apply it, and run IT-Net on the result.
# The transform count is taken from the tensor itself rather than the
# configured `batch_size`, so a short (e.g. final) batch still lines up.
# NOTE(review): 30 / 0 are create_random_transform's magnitude arguments —
# presumably rotation (degrees) and translation; confirm in pretrain_utils.
tfm = create_random_transform(
    batch_point_clouds.shape[0], 30, 0, batch_point_clouds.dtype
).to(batch_point_clouds.device)
view_1 = apply_tfm(batch_point_clouds, tfm)

# Pure inference: skip autograd bookkeeping to save (GPU) memory.
with torch.no_grad():
    view_1_post, _, _ = model(view_1)

In [11]:
# Apply a second, freshly drawn random transform on top of view_1 and run
# IT-Net again. The transform count comes from view_1 itself so a partial
# batch still matches.
# NOTE(review): 30 / 0 are create_random_transform's magnitude arguments —
# presumably rotation (degrees) and translation; confirm in pretrain_utils.
tfm = create_random_transform(view_1.shape[0], 30, 0, view_1.dtype).to(view_1.device)
view_2 = apply_tfm(view_1, tfm)

# Inference only: no autograd graph needed.
with torch.no_grad():
    view_2_post, _, _ = model(view_2)
view_2_post.shape

torch.Size([256, 3, 1024])

# Visualize: randomly transformed inputs vs. IT-Net outputs

In [12]:
def visualize_rotate(data):
    """Wrap ``data`` in a Plotly figure whose camera can orbit the z-axis.

    A "Play" button is attached; pressing it steps through animation frames
    that move the camera eye around the vertical axis at height 0.8, starting
    from the (x, y) position (1.25, 1.25).

    Args:
        data: Trace(s) accepted by ``go.Figure(data=...)``, e.g. ``go.Scatter3d``.

    Returns:
        The ``go.Figure`` with the animation frames attached (not displayed).
    """
    eye_x, eye_y, eye_z = 1.25, 1.25, 0.8
    # Treat the horizontal eye position as a complex number so that a rotation
    # about z is just a multiplication by exp(i * angle).
    eye_xy = eye_x + 1j * eye_y

    def camera_frame(angle):
        rotated = np.exp(1j * angle) * eye_xy
        eye = dict(x=np.real(rotated), y=np.imag(rotated), z=eye_z)
        return dict(layout=dict(scene=dict(camera=dict(eye=eye))))

    # Angles advance in 0.1 rad steps and are negated, so the eye turns in the
    # negative direction about z.
    frames = [camera_frame(-t) for t in np.arange(0, 10.26, 0.1)]

    play_button = dict(
        label='Play',
        method='animate',
        args=[None, dict(frame=dict(duration=50, redraw=True),
                         transition=dict(duration=0),
                         fromcurrent=True,
                         mode='immediate')],
    )
    menu = dict(type='buttons',
                showactive=False,
                y=1,
                x=0.8,
                xanchor='left',
                yanchor='bottom',
                pad=dict(t=45, r=10),
                buttons=[play_button])

    return go.Figure(data=data, layout=go.Layout(updatemenus=[menu]), frames=frames)


def pcshow(xs_list, ys_list, zs_list, min_xyz, max_xyz):
    """Display point clouds side by side, one 3-D scatter subplot per cloud.

    Every subplot is drawn with the same fixed axis ranges, so the clouds can
    be compared visually (e.g. network input vs. network output).

    Args:
        xs_list, ys_list, zs_list: Parallel sequences; element ``i`` of each holds
            the x, y and z coordinates of point cloud ``i``.
        min_xyz: ``[x_min, y_min, z_min]`` lower axis limits for every subplot.
        max_xyz: ``[x_max, y_max, z_max]`` upper axis limits for every subplot.

    Raises:
        ValueError: If no clouds are given or the three coordinate lists differ
            in length.
    """
    xs_list, ys_list, zs_list = list(xs_list), list(ys_list), list(zs_list)
    n_clouds = len(xs_list)
    if n_clouds == 0:
        raise ValueError("pcshow needs at least one point cloud")
    if len(ys_list) != n_clouds or len(zs_list) != n_clouds:
        raise ValueError(
            f"coordinate lists differ in length: {n_clouds} xs, "
            f"{len(ys_list)} ys, {len(zs_list)} zs"
        )

    # One row with one 3-D subplot per cloud (no longer hard-coded to two).
    fig = make_subplots(rows=1, cols=n_clouds,
                        specs=[[{'type': 'scatter3d'} for _ in range(n_clouds)]])

    for i, (xs, ys, zs) in enumerate(zip(xs_list, ys_list, zs_list), start=1):
        fig.add_trace(go.Scatter3d(x=xs, y=ys, z=zs, mode='markers',
                                   name=f'Point Cloud {i}', marker=dict(size=2)),
                      row=1, col=i)

    # update_scenes configures EVERY subplot scene (scene, scene2, ...).
    # update_layout(scene=...) only reached the first subplot, which left the
    # others auto-ranged and not comparable.
    # aspectmode='data' draws the axes in proportion to their (fixed) ranges.
    # An explicit aspectratio would be ignored in this mode, so it is not set.
    fig.update_scenes(aspectmode='data',
                      xaxis=dict(range=[min_xyz[0], max_xyz[0]]),
                      yaxis=dict(range=[min_xyz[1], max_xyz[1]]),
                      zaxis=dict(range=[min_xyz[2], max_xyz[2]]))

    fig.show()

In [13]:
# Shared axis limits for every point-cloud plot below: the cube [-1, 1]^3.
min_xyz, max_xyz = [-1] * 3, [1] * 3

In [21]:
# Collect the batch positions whose label equals `wanted_category`, then pick
# one of them at random to visualise.
wanted_category = 1

# Iterate over the labels actually present, not range(batch_size); a short
# batch would otherwise raise IndexError. int() accepts 0-d tensors, ints
# and numpy scalars alike.
indices = [
    position
    for position, label in enumerate(batch["category"])
    if int(label) == wanted_category
]

print(indices)

if not indices:
    raise ValueError(f"no samples of category {wanted_category} in this batch")
index = random.choice(indices)
index

[100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149]


105

In [22]:
# Left: the randomly transformed input cloud; right: IT-Net's output for the
# same sample. Each tensor slice is (3, points), so transposing gives one row
# per point.
cloud1, cloud2 = (
    views[index].transpose(0, 1).detach().cpu().numpy()
    for views in (view_1, view_1_post)
)

# Transposing back to (3, points) lets the coordinate columns unpack directly.
x1, y1, z1 = cloud1.T
x2, y2, z2 = cloud2.T

pcshow([x1, x2], [y1, y2], [z1, z2], min_xyz, max_xyz)

In [24]:
# Left: the twice-transformed input cloud; right: IT-Net's output for the
# same sample. Each tensor slice is (3, points), so transposing gives one row
# per point.
cloud1, cloud2 = (
    views[index].transpose(0, 1).detach().cpu().numpy()
    for views in (view_2, view_2_post)
)

# Transposing back to (3, points) lets the coordinate columns unpack directly.
x1, y1, z1 = cloud1.T
x2, y2, z2 = cloud2.T

pcshow([x1, x2], [y1, y2], [z1, z2], min_xyz, max_xyz)