<a href="https://colab.research.google.com/github/mosamdabhi/3dlfm/blob/main/demo_vis_collab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
## Relevant installations

import subprocess
import sys
import os
import requests
from zipfile import ZipFile

import importlib


def install(package):
    """Install ``package`` (a pip distribution name) into the running interpreter's environment."""
    subprocess.check_call([sys.executable, "-m", "pip", "install", package])

# Third-party modules this notebook needs, keyed by import name and mapped to the
# name pip installs them under. The two differ for Pillow, which is imported as
# ``PIL`` (``pip install PIL`` fails). Standard-library modules such as shutil and
# random are always available and need no entry here.
required_packages = {
    "numpy": "numpy",
    "plotly": "plotly",
    "PIL": "pillow",
    "tqdm": "tqdm",
    "gdown": "gdown",
    "nbformat": "nbformat",
}

for module_name, pip_name in required_packages.items():
    try:
        __import__(module_name)
    except ImportError:
        install(pip_name)

# Packages installed while the kernel is running may not be found by the import
# system until its finder caches are refreshed (needed before ``import gdown`` below).
importlib.invalidate_caches()

import gdown

def download_and_extract(file_id, target_folder):
    """Download a zip archive from Google Drive and extract it into ``target_folder``.

    Nothing is downloaded if ``target_folder`` already exists. The temporary
    ``<target_folder>.zip`` is always removed afterwards, even on failure. If
    extraction fails, any partially extracted folder is deleted so the next run
    downloads again instead of reporting "already exists".

    Args:
        file_id: Google Drive file id of the zip archive.
        target_folder: directory to extract into; also used to name the temporary zip.

    Raises:
        RuntimeError: if no archive was written to disk by the download.
        zipfile.BadZipFile: if the downloaded file is not a valid zip archive.
    """
    import shutil

    if os.path.exists(target_folder):
        print(f"{target_folder} already exists, no need to download.")
        return

    zip_path = f"{target_folder}.zip"
    print(f"Downloading {zip_path}...")
    try:
        # Using gdown to download the file from Google Drive
        gdown.download(f'https://drive.google.com/uc?id={file_id}', zip_path, quiet=False)
        # Depending on version, gdown may report a failed download without raising;
        # check the file actually exists instead of failing later inside ZipFile.
        if not os.path.exists(zip_path):
            raise RuntimeError(f"Download of Google Drive file id {file_id} failed; {zip_path} was not created.")

        print(f"Extracting {zip_path}...")
        try:
            with ZipFile(zip_path, 'r') as zip_ref:
                zip_ref.extractall(target_folder)
        except Exception:
            # A half-extracted folder would make every later run skip the download.
            shutil.rmtree(target_folder, ignore_errors=True)
            raise
    finally:
        if os.path.exists(zip_path):
            os.remove(zip_path)

    print(f"Data downloaded and extracted to {target_folder}")

In [2]:
## Download the data
# Google Drive id of the demo archive, and the local folder it is unpacked into.
file_id, target_folder = '1VK2ewrOAPR3ZEjW9gBGupr5GfKTtpviC', 'data'
download_and_extract(file_id=file_id, target_folder=target_folder)


Downloading data.zip...


Downloading...
From: https://drive.google.com/uc?id=1VK2ewrOAPR3ZEjW9gBGupr5GfKTtpviC
To: /content/data.zip
100%|██████████| 52.3M/52.3M [00:00<00:00, 143MB/s]


Extracting data.zip...
Data downloaded and extracted to data


In [3]:
### This cell selects the object category to visualize
import os
import pickle
import random
import shutil

import numpy as np
import plotly.graph_objects as go
from PIL import Image
from tqdm import tqdm

# Object categories for which demo data is provided.
SUPPORTED_CATEGORIES = [
    "aeroplane", "bicycle", "boat", "bottle", "bus", "car", "chair", "cow", "diningtable", "dog",
    "hippo", "horse_zebra", "motorbike", "openmonkey", "sofa", "cats", "train", "tvmonitor",
    "wholebodyh36m", "cheetah", "chimpanzee", "clownfish", "colobusmonkey", "fish", "tiger",
]

### Ask to choose from the list of categories
# Normalise stray whitespace / capitalisation so e.g. "Cow " still selects "cow".
category_name = input("Pick object category: ").strip().lower()

# An unsupported category aborts this cell; re-run it to choose again.
if category_name not in SUPPORTED_CATEGORIES:
    raise ValueError(
        "Category not supported. Please choose from the list of categories stated in the above Markdown cell: "
        + ", ".join(SUPPORTED_CATEGORIES)
    )
print("Great! You have chosen {}. Let us visualize the 3D reconstruction of {} object category.".format(category_name, category_name))

Pick object category: colobusmonkey
Great! You have chosen colobusmonkey. Let us visualize the 3D reconstruction of colobusmonkey object category.


In [4]:
### This cell defines required functions
# Skeleton edges, as [start, end] keypoint-index pairs, for every supported dataset.
# Datasets that share a keypoint layout share one edge list.
_QUADRUPED_CONNECTIONS = [
    [0, 24], [0, 20], [1, 21], [1, 24], [7, 25], [19, 25], [6, 17],
    [4, 15], [3, 14], [9, 15], [8, 14], [9, 13], [8, 12],
    [2, 23], [2, 22], [2, 24], [11, 17], [10, 16], [5, 16],
    [7, 10], [7, 11], [13, 18], [12, 18], [7, 18], [24, 18],
]

# MBW primate layout (colobus monkey and chimpanzee use the same keypoints).
_PRIMATE_CONNECTIONS = [
    [0, 1], [1, 2], [2, 3], [3, 4], [1, 5], [5, 6], [6, 7], [1, 8],
    [8, 9], [9, 10], [10, 11], [11, 12], [9, 13], [13, 14], [14, 15],
]

_JOINT_CONNECTIONS = {
    "cheetah": [
        [2, 0], [0, 1], [2, 3], [3, 4],
        [4, 5], [6, 7], [7, 8], [2, 6],
        [2, 9], [9, 10], [10, 11], [3, 15],
        [15, 16], [16, 17], [3, 12], [12, 13], [13, 14],
    ],
    "openmonkey": [
        [0, 1], [1, 2], [2, 3], [3, 4], [2, 5], [5, 6], [2, 7], [7, 8], [8, 9], [7, 10], [10, 11], [7, 12],
    ],
    "wholebodyh36m": [
        # Body (wrists 9 / 10 link to the hand roots 91 / 112)
        [0, 1], [1, 3], [0, 2], [2, 4],
        [5, 7], [7, 9], [9, 91],
        [6, 8], [8, 10], [10, 112],
        [5, 6], [6, 12], [11, 12], [5, 11],
        [12, 14], [14, 16], [16, 20], [16, 21], [16, 22],
        [11, 13], [13, 15], [15, 17], [15, 18], [15, 19],

        # Face
        [23, 24], [24, 25], [25, 26], [26, 27], [27, 28], [28, 29], [29, 30], [30, 31], [31, 32], [32, 33], [33, 34],
        [34, 35], [35, 36], [36, 37], [37, 38], [38, 39], [40, 41], [41, 42], [42, 43], [43, 44], [59, 60], [60, 61],
        [61, 62], [62, 63], [63, 64], [59, 64], [45, 46], [46, 47], [47, 48], [48, 49], [65, 66], [66, 67], [67, 68],
        [68, 69], [69, 70], [65, 70], [50, 51], [51, 52], [52, 53], [54, 55], [55, 56], [56, 57], [57, 58], [71, 72],
        [72, 73], [73, 74], [74, 75], [75, 76], [76, 77], [77, 78], [78, 79], [79, 80], [80, 81], [81, 82], [82, 83],
        [83, 84], [84, 85], [85, 86], [86, 87], [87, 88], [88, 89], [89, 90],

        # Left hand
        [91, 92], [92, 93], [93, 94], [94, 95], [91, 96], [96, 97], [97, 98], [98, 99], [91, 100], [100, 101], [101, 102],
        [102, 103], [91, 104], [104, 105], [105, 106], [106, 107], [91, 108], [108, 109], [109, 110], [110, 111],

        # Right hand
        [112, 113], [113, 114], [114, 115], [115, 116], [112, 117], [117, 118], [118, 119], [119, 120], [112, 121],
        [121, 122], [122, 123], [123, 124], [112, 125], [125, 126], [126, 127], [127, 128], [112, 129], [129, 130],
        [130, 131], [131, 132],
    ],
    "aeroplane": [[2, 5], [1, 4], [5, 3], [3, 7], [7, 0], [0, 5], [5, 7], [5, 6], [6, 0], [6, 3], [2, 4], [2, 1]],
    "bicycle": [[0, 3], [0, 7], [0, 2], [0, 6], [0, 10], [9, 10], [4, 10], [8, 10], [1, 9], [5, 9]],
    "cow": _QUADRUPED_CONNECTIONS,
    "horse": _QUADRUPED_CONNECTIONS,
    "hippo": _QUADRUPED_CONNECTIONS,
    "dog": _QUADRUPED_CONNECTIONS,
    "cats": _QUADRUPED_CONNECTIONS,
    "horse_zebra": _QUADRUPED_CONNECTIONS,
    "boat": [[0, 2], [0, 3], [0, 1], [1, 2], [1, 3], [2, 4], [3, 5], [4, 5], [1, 5], [1, 4]],
    "bottle": [[0, 1], [1, 2], [0, 2], [3, 4], [3, 5], [4, 5], [1, 4], [0, 3], [2, 5], [1, 6], [0, 6], [2, 6]],
    "bus": [
        [5, 7], [4, 5], [6, 7], [4, 6], [1, 5], [1, 3], [3, 7], [0, 1],
        [2, 3], [0, 2], [2, 10], [0, 8], [8, 9], [10, 11], [6, 11], [4, 9],
    ],
    "car": [
        [0, 8], [0, 4], [4, 10], [8, 10],
        [10, 9], [9, 11], [8, 11], [11, 6],
        [9, 2], [2, 6], [4, 1], [5, 1],
        [0, 5], [5, 7], [1, 3], [7, 3], [3, 2], [7, 6],
    ],
    "diningtable": [[0, 2], [4, 6], [1, 3], [5, 7], [1, 5], [3, 7], [0, 4], [2, 6], [0, 1], [2, 3], [4, 5], [6, 7]],
    "tvmonitor": [[5, 7], [4, 5], [4, 6], [6, 7], [0, 1], [0, 2], [2, 3], [1, 3], [3, 7], [1, 5], [2, 6], [0, 4]],
    "train": [[4, 5], [4, 6], [6, 7], [5, 7], [0, 1], [1, 3], [2, 3], [0, 2], [1, 5], [0, 4], [2, 6], [3, 7]],
    "motorbike": [
        [6, 2], [2, 9], [2, 3], [3, 8], [5, 8],
        [3, 5], [2, 1], [1, 0], [0, 7], [0, 4],
        [4, 7], [1, 4], [1, 7], [1, 5], [1, 8],
    ],
    "sofa": [
        [1, 5], [5, 4], [4, 6], [6, 2], [2, 0],
        [1, 0], [0, 4], [1, 3], [7, 5], [2, 3],
        [3, 7], [9, 7], [7, 6], [6, 8], [8, 9],
    ],
    "chair": [
        [7, 3], [6, 2], [9, 5], [8, 4], [7, 9],
        [8, 6], [6, 7], [9, 8], [9, 1], [8, 0], [1, 0],
    ],
    # MBW datasets
    "colobusmonkey": _PRIMATE_CONNECTIONS,
    "chimpanzee": _PRIMATE_CONNECTIONS,
    "tiger": [[0, 1], [1, 2], [2, 3], [3, 4], [4, 5], [1, 6], [6, 7], [1, 8], [8, 9], [3, 10], [10, 11], [3, 12], [12, 13]],
    "clownfish": [[0, 1], [1, 2], [2, 3], [1, 4], [1, 5]],
    "fish": [
        [0, 1], [1, 2], [2, 3], [1, 3], [3, 4], [4, 5], [5, 6], [6, 7],
        [5, 7], [5, 8], [8, 9], [9, 10], [8, 10], [10, 11], [11, 0],
    ],
    "seahorse": [[0, 1], [1, 2], [2, 3], [1, 3], [3, 4], [4, 5]],
}


def retrieve_joint_connections(dataset):
    """Return the skeleton edges for ``dataset`` as a list of ``[start, end]`` keypoint-index pairs.

    Args:
        dataset: dataset / object-category name, e.g. ``"cow"`` or ``"colobusmonkey"``.

    Returns:
        A newly built list of two-element lists; callers may modify it freely.

    Raises:
        ValueError: if ``dataset`` has no skeleton definition.
    """
    try:
        connections = _JOINT_CONNECTIONS[dataset]
    except KeyError:
        raise ValueError(
            "No joint connections defined for dataset {!r}. Known datasets: {}".format(
                dataset, ", ".join(sorted(_JOINT_CONNECTIONS)))
        ) from None
    # Copy so that mutating the result cannot corrupt the shared table.
    return [list(pair) for pair in connections]


def plot_3d_skeleton(predictions_3d, labels_3d, joint_connections, range_scale=2500, masks=None):
    """Visualize predicted (blue) and ground-truth (red) 3D skeletons in one plotly scene.

    Args:
        predictions_3d: predicted keypoint coordinates. They are transposed into three
            coordinate rows, so presumably shaped ``(num_keypoints, 3)`` — TODO confirm.
        labels_3d: ground-truth coordinates in the same layout as ``predictions_3d``.
            If ``None`` or all zeros, the ground-truth skeleton is not drawn. All
            zeros is what the demo passes when it multiplies the labels by
            ``overlay_labels=False``.
        joint_connections: ``[start, end]`` keypoint-index pairs drawn as bones.
        range_scale: kept for backward compatibility and ignored; the scene uses
            ``aspectmode="data"`` so the axes fit the plotted points.
        masks: optional per-keypoint values; only keypoints whose mask equals 1.0
            are drawn, and only bones whose two endpoints are both visible.
    """
    # Filter joint connections based on the mask
    if masks is not None:
        updated_connections = [connection for connection in joint_connections
                               if masks[connection[0]] == 1.0 and masks[connection[1]] == 1.0]
    else:
        updated_connections = joint_connections

    traces = []
    # Predicted skeleton
    traces.extend(get_trace3d(updated_connections, predictions_3d, 'blue', 'blue', "Predicted KP", masks=masks))
    # Ground truth skeleton. An all-zero array means "labels disabled"; drawing it would place a
    # red skeleton collapsed at the origin and stretch the data-fitted axes to include it.
    if labels_3d is not None and np.any(labels_3d):
        traces.extend(get_trace3d(updated_connections, labels_3d, 'red', 'red', "Groundtruth KP", masks=masks))

    fig = go.Figure(data=traces)
    fig.update_layout(scene=dict(aspectmode="data"))
    fig.update_layout(
        scene=dict(
            xaxis=dict(title='', showticklabels=False),
            yaxis=dict(title='', showticklabels=False),
            zaxis=dict(title='', showticklabels=False)
        )
    )
    fig.show()


def get_trace3d(joint_connections, points3d, point_color, line_color, name, masks=None):
    """Build the plotly traces (keypoint markers, then bone lines) for one 3D skeleton.

    Columns 1 and 2 of ``points3d`` are swapped for display: column 1 is plotted on
    plotly's z axis and column 2 on its y axis. Markers are drawn only for keypoints
    whose mask equals 1.0 (all keypoints when ``masks`` is None). Bone lines are
    drawn for every pair in ``joint_connections`` from the unmasked coordinates.

    Returns:
        ``[markers_trace, lines_trace]``.
    """
    visible_points = points3d if masks is None else points3d[masks == 1.0]

    marker_x, marker_z, marker_y = visible_points.T  # column 1 -> plot z, column 2 -> plot y
    every_x, every_z, every_y = points3d.T

    markers = go.Scatter3d(
        x=marker_x, y=marker_y, z=marker_z,
        mode='markers',
        name=name,
        marker=dict(symbol='circle', size=6, color=point_color),
    )

    # Plotly breaks a line trace at None, so each bone contributes [start, end, None].
    seg_x, seg_y, seg_z = [], [], []
    for first, second in joint_connections:
        seg_x += [every_x[first], every_x[second], None]
        seg_y += [every_y[first], every_y[second], None]
        seg_z += [every_z[first], every_z[second], None]

    bones = go.Scatter3d(
        x=seg_x, y=seg_y, z=seg_z,
        mode='lines',
        name=name,
        line=dict(width=6, color=line_color),
    )

    return [markers, bones]


def plot_2d_skeleton(predictions_2d, labels_2d, joint_connections, masks=None):
    """Show predicted (blue) and ground-truth (red) 2D skeletons in a 700x700 plotly figure.

    When ``masks`` is given, a bone is drawn only if both of its endpoints have mask
    1.0. The masks are also forwarded to ``get_trace2d`` for marker filtering.
    """
    if masks is None:
        bones = joint_connections
    else:
        bones = [pair for pair in joint_connections
                 if masks[pair[0]] == 1.0 and masks[pair[1]] == 1.0]

    predicted = get_trace2d(bones, predictions_2d, 'blue', 'blue', "Predicted KP", masks=masks)
    groundtruth = get_trace2d(bones, labels_2d, 'red', 'red', "Groundtruth KP", masks=masks)

    figure = go.Figure(
        data=[*predicted, *groundtruth],
        layout=go.Layout(
            width=700,
            height=700,
            margin=dict(r=20, l=10, b=10, t=10),
        ),
    )
    figure.show()

def get_trace2d(joint_connections, points2d, point_color, line_color, name, masks=None, get_lines=None):
    """Build plotly traces for one 2D skeleton: keypoint markers and, optionally, bone lines.

    Args:
        joint_connections: ``[start, end]`` keypoint-index pairs drawn as bones.
        points2d: keypoint coordinates; column 0 is plotted as x and column 1 as y.
        point_color / line_color: colours of the markers and of the bones.
        name: legend name shared by both traces.
        masks: optional per-keypoint values; markers are drawn only where the mask
            equals 1.0. Bones always use the unmasked coordinates.
        get_lines: an explicit falsy value (e.g. ``False``) returns only the markers.
            ``None`` (default) or a truthy value also returns the bone lines.

    Returns:
        ``[markers_trace]`` or ``[markers_trace, lines_trace]``.
    """
    shown_points = points2d if masks is None else points2d[masks == 1.0]

    marker_x, marker_y = shown_points.T
    every_x, every_y = points2d.T

    markers = go.Scatter(
        x=marker_x, y=marker_y,
        mode='markers',
        name=name,
        marker=dict(symbol='circle', size=6, color=point_color),
    )

    # Plotly breaks a line trace at None, so each bone contributes [start, end, None].
    seg_x, seg_y = [], []
    for first, second in joint_connections:
        seg_x += [every_x[first], every_x[second], None]
        seg_y += [every_y[first], every_y[second], None]

    bones = go.Scatter(
        x=seg_x, y=seg_y,
        mode='lines',
        name=name,
        line=dict(width=2, color=line_color),
    )

    if get_lines is not None and not get_lines:
        return [markers]
    return [markers, bones]



def plot_2d_skeleton_on_image(predictions_2d, labels_2d, joint_connections, image_path, masks=None, get_lines=None):
    """Overlay predicted (blue) and ground-truth (red) 2D skeletons on the image at ``image_path``.

    The figure is sized to the image in pixels. The y axis runs from ``height`` down
    to 0, so (0, 0) is the image's top-left corner and keypoints are in pixel
    coordinates.

    Args:
        predictions_2d / labels_2d: keypoint coordinates (column 0 = x, column 1 = y).
        joint_connections: ``[start, end]`` keypoint-index pairs drawn as bones.
        image_path: path of the background image.
        masks: optional per-keypoint values; only keypoints / bones whose endpoints
            have mask 1.0 are drawn.
        get_lines: forwarded to ``get_trace2d``; ``False`` hides the bone lines.
    """
    # Read the pixels into memory and close the file right away; a bare Image.open()
    # keeps the file handle open until the object happens to be garbage-collected.
    with Image.open(image_path) as source_image:
        image = source_image.copy()
    width, height = image.size

    # Filter joint connections based on the mask
    if masks is not None:
        updated_connections = [connection for connection in joint_connections
                               if masks[connection[0]] == 1.0 and masks[connection[1]] == 1.0]
    else:
        updated_connections = joint_connections

    traces = [
        # Invisible markers at the image extents so the axes span the whole image.
        go.Scatter(
            x=[0, width],
            y=[0, height],
            mode="markers",
            marker_opacity=0,
            hoverinfo="none",
            showlegend=False
        )
    ]
    # Predicted skeleton
    traces.extend(get_trace2d(updated_connections, predictions_2d, 'blue', 'blue', None, masks, get_lines=get_lines))
    # Ground truth skeleton
    traces.extend(get_trace2d(updated_connections, labels_2d, 'red', 'red', None, masks, get_lines=get_lines))

    layout = go.Layout(
        width=width,
        height=height,
        xaxis=dict(showgrid=False, zeroline=False, showticklabels=False, range=[0, width]),
        yaxis=dict(showgrid=False, zeroline=False, showticklabels=False, range=[height, 0], scaleanchor="x"),
        images=[go.layout.Image(source=image, xref="x", yref="y", x=0, y=0, sizex=width, sizey=height,
                                sizing="stretch", opacity=1.0, layer="below")],
        margin=dict(r=10, l=10, b=10, t=10),
        hovermode="closest",
        showlegend=False,  # Hide legend
    )

    fig = go.Figure(data=traces, layout=layout)
    fig.show()

In [5]:
# This cell loads the input and predicted data.
# Categories listed below come from the in-distribution validation split; every other
# category (cheetah and the MBW animals) is read from the OOD validation split.
if category_name in ("aeroplane", "bicycle", "boat", "bottle", "bus", "car", "chair", "cow",
                     "diningtable", "dog", "hippo", "horse_zebra", "motorbike", "openmonkey",
                     "sofa", "cats", "train", "tvmonitor", "wholebodyh36m"):
    base_directory = "data/demo-data/validation/"
else:
    base_directory = "data/demo-data/ood_validation/"
    print("Caution: We will be operating over OOD data so performance may not be as good.")


data_path = f"{base_directory}{category_name}/final/{category_name}.pkl"
# NOTE: pickle.load can execute arbitrary code; only load archives from a source you trust.
with open(data_path, 'rb') as f:
    data = pickle.load(f)

inputs_2d, labels_3d, outputs_3d, image_path = (
    data['inputs_2d'], data['labels_3d'], data['outputs_3d'], data['image_path'])
# Without stored masks, every keypoint of every frame is treated as visible.
masks = data['masks'] if 'masks' in data else np.ones((inputs_2d.shape[0], inputs_2d.shape[1]))
joint_connections = retrieve_joint_connections(category_name)

## Image paths in the pickle are relative to the category directory
image_path = [''.join((base_directory, category_name, '/', path)) for path in image_path]

Caution: We will be operating over OOD data so performance may not be as good.


In [6]:
## Visualize the chosen frame
num_frames = len(inputs_2d)
print("Total number of available frames for {} category are: {}".format(category_name, num_frames))

# Validate the index explicitly: negative numbers would silently wrap around and
# too-large ones would fail deep inside the plotting code.
frame_input = input("Please choose a frame number between 0 and {}: ".format(num_frames - 1)).strip()
try:
    frame_number = int(frame_input)
except ValueError:
    raise ValueError("Frame number must be an integer, got {!r}.".format(frame_input)) from None
if not 0 <= frame_number < num_frames:
    raise ValueError("Frame number must be between 0 and {}, got {}.".format(num_frames - 1, frame_number))
print("The chosen frame is: {}".format(frame_number))

# Only keypoint markers (no bone lines) are overlaid on the image.
get_lines = False
# For these categories the 3D skeletons are negated (multiplied by -1) before display,
# presumably so they appear upright in the plotly scene.
visualization_flip = -1 if category_name in ("aeroplane", "bicycle", "clownfish", "tiger", "cheetah") else 1

# For OOD data the (pseudo) ground-truth labels are hidden by default. To show them,
# set overlay_labels = True after the block below.
overlay_labels = True
if base_directory == "data/demo-data/ood_validation/":
    overlay_labels = False
    print("-"*100)
    if category_name in ("colobusmonkey", "chimpanzee", "tiger", "clownfish", "fish", "seahorse"):
        print("Caution: \n We are operating over OOD data captured in the wild via MBW (NeurIPS 2022) dataset. \n `pseudo` groundtruth labels are available for this OOD data generated using 2-view bootstrapping method given in this paper.")
    else:
        print("Caution: \n We are operating over OOD data captured via AcinoSet (ICRA, 2021) dataset. \n `pseudo` groundtruth labels are available using 6-view triangulation method given in this paper.")
        print("Note: We use 2D/3D keypoints from the following paper: \nhttps://arxiv.org/pdf/2103.13282.pdf")
    print("-"*100)
    print("By default, `pseudo` groundtruth labels are disabled for visualization. \n Please turn on `pseudo` groundtruth labels using `overlay_labels=True` flag")

# With overlay_labels False the label array becomes all zeros (labels switched off).
plot_3d_skeleton(visualization_flip*outputs_3d[frame_number], visualization_flip*labels_3d[frame_number]*overlay_labels, joint_connections, range_scale=2500, masks=masks[frame_number])
# The 2D input keypoints are passed as both prediction and label, so the image shows the inputs.
plot_2d_skeleton_on_image(inputs_2d[frame_number], inputs_2d[frame_number], joint_connections, image_path[frame_number], masks=masks[frame_number], get_lines=get_lines)


Total number of available frames for colobusmonkey category are: 10
Please choose a frame number between 0 and 9
Please choose a frame number between 0 and 92
The chosen frame is: 2
----------------------------------------------------------------------------------------------------
Caution: 
 We are operating over OOD data captured in the wild via MBW (NeurIPS 2022) dataset. 
 `pseudo` groundtruth labels are available for this OOD data generated using 2-view bootstrapping method given in this paper.
----------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------
By default, `pseudo` groundtruth labels are disabled for visualization. 
 Please turn on `pseudo` groundtruth labels using `overlay_labels=True` flag
