In [2]:
# %%
# Import necessary libraries
import sys
import torch
import numpy as np
import open3d as o3d
import matplotlib.pyplot as plt
from IPython.display import display, clear_output
import time
import os
import importlib.util
import ipywidgets as widgets
from IPython.display import display
import plotly.graph_objects as go


In [3]:
# %%
# Specify the path to run.py
run_py_path = "./run.py"  # Adjust this path if run.py is located elsewhere

# Check if run.py exists
if not os.path.exists(run_py_path):
    raise FileNotFoundError(f"Cannot find run.py at {run_py_path}")

# Dynamically load run.py as a module named 'grig'.
spec = importlib.util.spec_from_file_location("grig", run_py_path)
if spec is None or spec.loader is None:
    # spec_from_file_location returns None when it cannot pick a loader for the path.
    raise ImportError(f"Cannot create an import spec for {run_py_path}")
grig = importlib.util.module_from_spec(spec)
# Register the module before executing it, as in the importlib recipe for importing
# a source file directly. Without this, anything that resolves 'grig' through
# sys.modules (pickling/unpickling run.py's classes and enums, some dataclass and
# typing introspection) cannot find the module.
sys.modules["grig"] = grig
try:
    spec.loader.exec_module(grig)
except BaseException:
    # Mirror the regular import system: don't leave a half-initialised module registered.
    sys.modules.pop("grig", None)
    raise

print("Successfully imported run.py as 'grig' module.")
# %%
# Location of the saved scene/clustering checkpoint; adjust if it lives elsewhere.
clustering_filepath = "clustering.npz"

# Fail fast with a clear message instead of an error from inside the loader.
if not os.path.exists(clustering_filepath):
    raise FileNotFoundError(f"Cannot find clustering.npz at {clustering_filepath}")

# The loader hands back a (scene_data, clustering) pair.
scene_data, clustering = grig.load_scene_and_clustering(clustering_filepath)

print(f"Loaded scene_data and clustering from {clustering_filepath}")
print(f"Number of timesteps: {len(scene_data)}")
print(f"Number of clusters: {clustering.num_clusters}")


Successfully imported run.py as 'grig' module.
Loaded scene_data and clustering from clustering.npz
Number of timesteps: 150
Number of clusters: 28


In [8]:
import numpy as np
import plotly.graph_objects as go
import matplotlib.cm as cm
import os
import importlib.util

# %%
# `grig` (run.py) is loaded once by the earlier loading cell and reused here.
# Executing run.py a second time would create a new, distinct module object whose
# classes and enums (e.g. BodyPart) are not the same objects as the ones that
# built `clustering`, so identity/isinstance comparisons between them could fail.

# Saved assignments file; only read when USE_SAVED_ASSIGNMENTS is True.
assignments_filepath = "cluster_assignments.npz"
# Set to True to load the best saved assignments instead of recomputing them.
USE_SAVED_ASSIGNMENTS = False

# 1. Body parts that must all receive a cluster for the skeleton to be complete.
required_parts = [
    grig.BodyPart.TORSO, grig.BodyPart.HEAD, grig.BodyPart.NECK, 
    grig.BodyPart.UPPER_ARM_LEFT, grig.BodyPart.UPPER_ARM_RIGHT,
    grig.BodyPart.LOWER_ARM_LEFT, grig.BodyPart.LOWER_ARM_RIGHT, 
    grig.BodyPart.HAND_LEFT, grig.BodyPart.HAND_RIGHT,
    grig.BodyPart.UPPER_LEG_LEFT, grig.BodyPart.UPPER_LEG_RIGHT, 
    grig.BodyPart.LOWER_LEG_LEFT, grig.BodyPart.LOWER_LEG_RIGHT,
    grig.BodyPart.FOOT_LEFT, grig.BodyPart.FOOT_RIGHT, 
    grig.BodyPart.CHEST, grig.BodyPart.WAIST, 
    grig.BodyPart.SHOULDERS_LEFT, grig.BodyPart.SHOULDERS_RIGHT, 
    grig.BodyPart.CLAVICLE_LEFT, grig.BodyPart.CLAVICLE_RIGHT
]

# Hand-specified cluster indices excluded from assignment (and later from plotting).
ignored_clusters = [1, 16, 24, 4, 22, 27, 8, 18]

# 2. Assign clusters to body parts, or load a previously saved assignment.
if USE_SAVED_ASSIGNMENTS:
    cluster_assignments = grig.assign_clusters_to_body_parts_from_file(assignments_filepath)
else:
    cluster_assignments = grig.assign_clusters_to_body_parts_enforced(
        clustering, scene_data, threshold=1, ignored_clusters=ignored_clusters
    )

if not grig.validate_assignments(cluster_assignments, required_parts):
    print("Warning: Not all required body parts have been assigned.")

# 3. Find skeleton chains from the cluster assignments.
cluster_chains = grig.find_skeleton_chains(clustering, cluster_assignments)

# 4. Compute joints from the skeleton chains.
joints_t = grig.compute_joints(clustering, cluster_chains)

print(f"Number of joint chains found: {len(cluster_chains)}")
print(f"Number of joints computed: {joints_t.shape[1]}")


Successfully imported run.py as 'grig' module.
Combined feature shape: (28, 9)
Ignored Cluster 1 has been excluded from assignments.
Ignored Cluster 16 has been excluded from assignments.
Ignored Cluster 24 has been excluded from assignments.
Ignored Cluster 4 has been excluded from assignments.
Ignored Cluster 22 has been excluded from assignments.
Ignored Cluster 27 has been excluded from assignments.
Ignored Cluster 8 has been excluded from assignments.
Ignored Cluster 18 has been excluded from assignments.
Assigned BodyPart.TORSO to cluster 5.

Processing BodyPart.torso (Cluster 5)
Assigned BodyPart.clavicle_left to Cluster 25 (Distance: 0.373).
Assigned BodyPart.clavicle_right to Cluster 0 (Distance: 0.375).
Assigned BodyPart.head to Cluster 3 (Distance: 0.386).
Assigned BodyPart.upper_leg_left to Cluster 13 (Distance: 0.287).
Assigned BodyPart.upper_leg_right to Cluster 20 (Distance: 0.313).

Processing BodyPart.clavicle_left (Cluster 25)
Assigned BodyPart.shoulders_left to Clust

In [9]:
import numpy as np
import plotly.graph_objects as go
import matplotlib.cm as cm
import os
import importlib.util

centers_t = clustering.centers.cpu().numpy()  # Shape: (T, K, 3)

# Plotly needs NumPy/list data. joints_t's type is not visible in this cell, so a
# torch tensor (anything exposing .detach) is moved to the CPU and converted, and
# anything else is passed through np.asarray.
if hasattr(joints_t, "detach"):
    joints = joints_t.detach().cpu().numpy()  # Shape: (T, num_joints, 3)
else:
    joints = np.asarray(joints_t)  # Shape: (T, num_joints, 3)
num_timesteps = centers_t.shape[0]
num_clusters = centers_t.shape[1]
num_joints = joints.shape[1]

# One color per cluster sampled from the 'turbo' colormap.
# - pyplot.get_cmap is used because matplotlib.cm.get_cmap was removed in Matplotlib 3.9.
# - Plotly wants CSS color strings; matplotlib returns (r, g, b) floats in [0, 1],
#   and an array of such triples is not interpreted by Plotly as per-point colors.
cmap = plt.get_cmap('turbo', num_clusters)
cluster_colors = [
    'rgb({}, {}, {})'.format(*(int(round(255 * channel)) for channel in cmap(i)[:3]))
    for i in range(num_clusters)
]

# A single color string paints every joint marker red. A 3-number list such as
# [1, 0, 0] would instead be read as numeric values for only the first three joints.
joint_color = 'red'

# Consecutive cluster pairs along each skeleton chain, skipping any segment that
# touches an ignored cluster. Chains shorter than two clusters yield no segments.
chain_lines = []
for chain in cluster_chains:
    for start_idx, end_idx in zip(chain[:-1], chain[1:]):
        if start_idx in ignored_clusters or end_idx in ignored_clusters:
            continue
        chain_lines.append((start_idx, end_idx))

def create_cluster_trace(centers, colors, ignored):
    """Build a 3D marker trace of the cluster centers that are not ignored.

    Args:
        centers: array indexed as ``centers[k, 0:3]`` (one xyz row per cluster).
        colors: per-cluster colors, in the same order as ``centers``.
        ignored: cluster indices to leave out of the trace.

    Returns:
        A ``go.Scatter3d`` marker trace named 'Cluster Centers'.
    """
    keep = np.isin(np.arange(len(centers)), ignored, invert=True)
    kept_colors = np.asarray(colors)[keep]
    marker_style = {'size': 5, 'color': kept_colors, 'opacity': 0.8}
    return go.Scatter3d(
        x=centers[keep, 0],
        y=centers[keep, 1],
        z=centers[keep, 2],
        mode='markers',
        marker=marker_style,
        name='Cluster Centers',
    )

def create_joints_trace(joints_at_t):
    """Build a 3D marker trace of the joint positions at one timestep.

    Args:
        joints_at_t: array indexed as ``joints_at_t[j, 0:3]`` (one xyz row per joint).

    Returns:
        A ``go.Scatter3d`` marker trace named 'Joints', colored with the
        notebook-level ``joint_color``.
    """
    xs, ys, zs = (joints_at_t[:, axis] for axis in range(3))
    marker_style = {'size': 3, 'color': joint_color, 'opacity': 0.9}
    return go.Scatter3d(x=xs, y=ys, z=zs, mode='markers', marker=marker_style, name='Joints')

def create_chains_trace(centers_at_t, chain_lines):
    """Build one line trace that draws every chain segment between cluster centers.

    Segments are emitted as ``start, end, None`` triples so Plotly renders them as
    disconnected line pieces inside a single trace.

    Args:
        centers_at_t: array indexed as ``centers_at_t[k, 0:3]`` (one xyz row per cluster).
        chain_lines: iterable of ``(start_idx, end_idx)`` cluster-index pairs.

    Returns:
        A ``go.Scatter3d`` line trace named 'Chains'. When ``chain_lines`` is empty
        the trace simply has no points. Returning ``None`` instead would be rejected by
        Plotly's figure/frame data validation, and keeping a trace here keeps the
        number of traces identical in the initial figure and in every animation frame.
    """
    x_lines, y_lines, z_lines = [], [], []
    for start, end in chain_lines:
        for coords, axis in ((x_lines, 0), (y_lines, 1), (z_lines, 2)):
            coords.extend([centers_at_t[start, axis], centers_at_t[end, axis], None])

    return go.Scatter3d(
        x=x_lines,
        y=y_lines,
        z=z_lines,
        mode='lines',
        line=dict(
            color='black',
            width=2
        ),
        name='Chains'
    )

# Traces for the first timestep, shown before the animation starts.
initial_centers = centers_t[0]
initial_joints = joints[0]

cluster_trace = create_cluster_trace(initial_centers, cluster_colors, ignored_clusters)
joints_trace = create_joints_trace(initial_joints)
chains_trace = create_chains_trace(initial_centers, chain_lines)


def _drop_missing(traces):
    """Return ``traces`` without ``None`` entries.

    Plotly's figure/frame data validation rejects non-trace items, so this guards
    against a trace builder returning ``None`` (e.g. a chains trace with no segments).
    """
    return [trace for trace in traces if trace is not None]


def _fixed_axis_ranges(point_sets, pad_frac=0.05):
    """Return ``[x_range, y_range, z_range]`` covering every finite point in ``point_sets``.

    Each range is ``[min - pad, max + pad]`` with ``pad = pad_frac * extent``.
    Each element of ``point_sets`` is reshaped to ``(-1, 3)`` xyz rows. If there are
    no finite points, ``[None, None, None]`` is returned so Plotly keeps autoranging.
    """
    points = np.concatenate([np.asarray(p, dtype=float).reshape(-1, 3) for p in point_sets])
    points = points[np.isfinite(points).all(axis=1)]
    if points.size == 0:
        return [None, None, None]
    lo, hi = points.min(axis=0), points.max(axis=0)
    pad = pad_frac * np.maximum(hi - lo, 1e-6)  # avoid a zero-width range on flat data
    return [[float(a), float(b)] for a, b in zip(lo - pad, hi + pad)]


# Fix the scene ranges to the extent of everything plotted over ALL timesteps, so the
# 3D view does not rescale from frame to frame during playback. With
# aspectmode='data' the axes are then drawn in proportion to these fixed ranges.
kept_clusters = np.isin(np.arange(num_clusters), ignored_clusters, invert=True)
axis_ranges = _fixed_axis_ranges([centers_t[:, kept_clusters], joints])

frames = [
    go.Frame(
        data=_drop_missing([
            create_cluster_trace(centers_t[t], cluster_colors, ignored_clusters),
            create_joints_trace(joints[t]),
            create_chains_trace(centers_t[t], chain_lines),
        ]),
        name=str(t),
    )
    for t in range(num_timesteps)
]

fig = go.Figure(
    data=_drop_missing([cluster_trace, joints_trace, chains_trace]),
    layout=go.Layout(
        title="Cluster Centers, Joints, and Chains",
        width=600,
        height=600,
        scene=dict(
            xaxis=dict(title='X', range=axis_ranges[0]),
            yaxis=dict(title='Y', range=axis_ranges[1]),
            zaxis=dict(title='Z', range=axis_ranges[2]),
            aspectmode='data'
        ),
        updatemenus=[
            dict(
                type="buttons",
                buttons=[
                    # 3D traces need redraw=True to update during playback.
                    dict(label="Play",
                         method="animate",
                         args=[None, {"frame": {"duration": 50, "redraw": True},
                                      "fromcurrent": True}]),
                    # Animating to [None] with zero duration stops playback in place.
                    dict(label="Pause",
                         method="animate",
                         args=[[None], {"frame": {"duration": 0, "redraw": False},
                                        "mode": "immediate",
                                        "transition": {"duration": 0}}])
                ],
                showactive=False,
                x=0,
                y=1.05,
                xanchor="left",
                yanchor="top"
            )
        ]
    ),
    frames=frames
)

# Timestep slider: each step jumps to the frame with the matching name.
fig.update_layout(
    sliders=[
        dict(
            steps=[
                dict(
                    method='animate',
                    args=[
                        [str(t)],
                        dict(mode='immediate', frame=dict(duration=50, redraw=True), transition=dict(duration=0))
                    ],
                    label=str(t)
                )
                for t in range(num_timesteps)
            ],
            transition=dict(duration=0),
            x=0.1,
            y=0,
            currentvalue=dict(font=dict(size=12), prefix='Timestep: ', visible=True, xanchor='center'),
            len=0.9
        )
    ]
)
fig.show()
