In [1]:
import rosbag
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import cv2
from cv_bridge import CvBridge
import ipyvolume as ipv
from decimal import Decimal
import matplotlib.pyplot as plt
from ipywidgets import interact, IntSlider
from scipy import interpolate
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.widgets import Slider


#from zed_interfaces.msg import ObjectsStamped


In [2]:
# READ BAG
# The recording is opened from <dataset_root><fileroot><filename>.bag
fileroot = '24_07_02/'
filename = '2024-07-02-BT-Maleen-01'

# Presumably the destination for processed body-tracking data (inferred from the name).
# NOTE(review): this path starts with /home/maleen/rosbags/research_data/, whereas the bag
# below is read from /home/maleen/research_data/ — confirm which root is intended.
saverpath = '/home/maleen/rosbags/research_data/Transformers/datasets/body_tracking_data/'

dataset_root = '/home/maleen/research_data/Transformers/datasets/'
bag = rosbag.Bag(f'{dataset_root}{fileroot}{filename}.bag')

In [3]:
# Read skeleton markers and camera images from the bag into in-memory lists.
NUM_MARKERS = 6  # every body_markers message is padded / truncated to this many joints

# Initialize variables
skeleton_size = 0
image_size = 0

skeleton_timestamps = []
skeleton_3d_data = []

image_timestamps = []
image_data = []

# NOTE(review): the marker topic has no leading '/', unlike the image topic. The recorded
# outputs show marker frames are read, so it apparently matches the bag; the same literal
# is used in the dispatch below — keep the two in sync if it is ever changed.
topiclist = ['/body_tracker/image', 'body_tracker/body_markers']


def ros_stamp_to_decimal(stamp):
    """Convert a ROS time stamp (``secs`` / ``nsecs`` fields) to exact ``Decimal`` seconds."""
    return stamp.secs + Decimal(stamp.nsecs) / 1000000000


def markers_to_positions(markers, num_markers=NUM_MARKERS):
    """Return exactly ``num_markers`` [x, y, z] positions taken from ``markers`` in order.

    A marker whose position is exactly (0, 0, 0) is treated as not tracked and, like
    every slot beyond ``len(markers)``, becomes [nan, nan, nan]. Markers past
    ``num_markers`` are ignored.
    """
    positions = []
    for i in range(num_markers):
        if i < len(markers):
            p = markers[i].pose.position
            if p.x == 0 and p.y == 0 and p.z == 0:
                positions.append([np.nan, np.nan, np.nan])
            else:
                positions.append([p.x, p.y, p.z])
        else:
            positions.append([np.nan, np.nan, np.nan])
    return positions


# Create a CvBridge object
bridge = CvBridge()

for topic, msg, t in bag.read_messages(topics=topiclist):
    if topic == 'body_tracker/body_markers':
        if len(msg.markers) > 0:
            # The whole skeleton message is stamped with its first marker's header time.
            skeleton_timestamps.append(ros_stamp_to_decimal(msg.markers[0].header.stamp))
            skeleton_3d_data.append(markers_to_positions(msg.markers))

    elif topic == '/body_tracker/image':
        image_timestamps.append(ros_stamp_to_decimal(msg.header.stamp))
        # 'passthrough' keeps the topic's native encoding / channel order.
        image_data.append(bridge.imgmsg_to_cv2(msg, desired_encoding='passthrough'))

In [4]:
# Spot-check one recorded frame: one [x, y, z] row per joint, untracked joints are NaN.
inspect_frame = 834
skeleton_3d_data[inspect_frame]

[[0.30453405011665297, 0.39559494516676713, 1.288],
 [0.2932540617004631, 0.1142400155210267, 1.411],
 [0.21349508838952558, -0.1379081670465976, 1.342],
 [nan, nan, nan],
 [-0.044247595584928724, -0.0053267533417909255, 1.181],
 [0.14658599171451236, -0.22402514126023318, 1.098]]

In [5]:
# Materialise the per-message lists as NumPy arrays: the skeleton becomes a
# (frames, markers, 3) numeric array, the Decimal timestamps an object array.
skeleton_3d_data, skeleton_timestamps = np.array(skeleton_3d_data), np.array(skeleton_timestamps)


In [6]:
def convert_timestamps_to_numeric(timestamps):
    """Convert timestamps (e.g. ``Decimal`` seconds) to a 1D float64 NumPy array.

    Values are absolute seconds converted one by one with ``float()``; they are NOT
    re-based to the first timestamp. At epoch-scale magnitudes (~1.7e9 s) float64
    resolves roughly 1e-7 s, which is ample for frame-rate data.

    Args:
        timestamps: iterable of ``Decimal`` (or any ``float()``-convertible) values.

    Returns:
        numpy.ndarray: float64 array with the same length and order as ``timestamps``.
    """
    return np.array([float(ts) for ts in timestamps], dtype=float)


def interpolate_missing_data_with_timestamps(data, timestamps, missing_fill=0.0):
    """Fill NaN gaps in a (frames, points, dof) array by interpolating over time.

    Each (point, dof) time series is handled independently:

    * 2+ valid samples: linear interpolation against the frame timestamps, with
      linear *extrapolation* before the first / after the last valid sample
      (long leading/trailing gaps can therefore drift away from the data);
    * exactly 1 valid sample: that value is repeated for every frame;
    * no valid samples: every frame is set to ``missing_fill``.

    Args:
        data (array_like): 3D array of shape (frames, points, dof); NaN marks missing.
        timestamps (array_like): 1D timestamps (``Decimal`` or numeric seconds), one per
            frame and in the same order as ``data``. They do not need to be sorted.
        missing_fill (float): value for series with no valid sample. Defaults to 0.0
            (the previous hard-coded behaviour); pass ``np.nan`` to keep them missing.

    Returns:
        numpy.ndarray: array of the same shape as ``data``. Floating-point input keeps
        its dtype; any other input is returned as float64.

    Raises:
        ValueError: if ``data`` is not 3D or ``timestamps`` does not have one entry per frame.
    """
    data = np.asarray(data)
    if data.ndim != 3:
        raise ValueError(f"data must be 3D (frames, points, dof), got shape {data.shape}")
    frames, points, dof = data.shape

    numeric_timestamps = convert_timestamps_to_numeric(timestamps)
    if numeric_timestamps.shape[0] != frames:
        raise ValueError(
            f"expected {frames} timestamps (one per frame), got {numeric_timestamps.shape[0]}"
        )

    # np.zeros_like(data) on integer input would silently truncate interpolated values.
    out_dtype = data.dtype if np.issubdtype(data.dtype, np.floating) else np.float64
    interpolated_data = np.zeros(data.shape, dtype=out_dtype)

    for point in range(points):
        for d in range(dof):
            series = data[:, point, d]
            valid = ~np.isnan(series)
            n_valid = np.count_nonzero(valid)

            if n_valid > 1:
                # interp1d sorts x internally (assume_sorted=False by default).
                f = interpolate.interp1d(numeric_timestamps[valid],
                                         series[valid],
                                         kind='linear',
                                         bounds_error=False,
                                         fill_value='extrapolate')
                interpolated_data[:, point, d] = f(numeric_timestamps)
            elif n_valid == 1:
                interpolated_data[:, point, d] = series[valid][0]
            else:
                interpolated_data[:, point, d] = missing_fill

    return interpolated_data

# Sort frames chronologically. The Decimal timestamps must be reordered with the SAME
# permutation as the skeleton frames: the interpolation below, the 3D viewer and the
# NaN-frame removal all pair skeleton_timestamps[i] with skeleton_3d_data[i].
numeric_timestamps = convert_timestamps_to_numeric(skeleton_timestamps)
sort_indices = np.argsort(numeric_timestamps, kind='stable')  # stable: ties keep bag order
numeric_timestamps = numeric_timestamps[sort_indices]
skeleton_timestamps = np.asarray(skeleton_timestamps)[sort_indices]
skeleton_3d_data = skeleton_3d_data[sort_indices]

# Apply the interpolation to the (now consistently ordered) frames and timestamps
interpolated_skeletal_data = interpolate_missing_data_with_timestamps(skeleton_3d_data, skeleton_timestamps)

# Print some statistics
print(f"Original data shape: {skeleton_3d_data.shape}")
print(f"Interpolated data shape: {interpolated_skeletal_data.shape}")
print(f"Number of NaN values in original data: {np.isnan(skeleton_3d_data).sum()}")
print(f"Number of NaN values in interpolated data: {np.isnan(interpolated_skeletal_data).sum()}")

# Visualization function
def plot_interpolation(original_data, interpolated_data, timestamps, point=0, dof=0):
    """Overlay the raw samples and the interpolated curve for one joint coordinate.

    Args:
        original_data: (frames, points, dof) array, possibly containing NaN gaps.
        interpolated_data: array of the same shape holding the gap-filled values.
        timestamps: per-frame timestamps (Decimal or numeric), one per frame.
        point (int): joint index to plot.
        dof (int): coordinate index along the last axis.
    """
    t = convert_timestamps_to_numeric(timestamps)
    raw = original_data[:, point, dof]
    filled = interpolated_data[:, point, dof]

    fig, ax = plt.subplots(figsize=(12, 6))
    ax.scatter(t, raw, label='Original Data', alpha=0.5)
    ax.plot(t, filled, 'r-', label='Interpolated Data')
    ax.set_title(f'Interpolation for Point {point}, DOF {dof}')
    ax.set_xlabel('Timestamp')
    ax.set_ylabel('Value')
    ax.legend()
    plt.show()

# Uncomment to plot
# plot_interpolation(skeleton_3d_data, interpolated_skeletal_data, skeleton_timestamps, point=0, dof=0)

Original data shape: (3225, 6, 3)
Interpolated data shape: (3225, 6, 3)
Number of NaN values in original data: 2277
Number of NaN values in interpolated data: 0


In [55]:
import plotly.graph_objects as go
import numpy as np
from ipywidgets import interact, IntSlider, Layout
from IPython.display import display

def _symmetric_axis_limits(skeleton_data, padding=0.1):
    """Compute origin-centred half-ranges (x_limit, y_limit, z_limit) enclosing every joint.

    Each axis' min/max over all frames and joints (NaNs ignored) is widened by
    ``padding`` times that axis' span. The limit is the larger absolute value of the
    padded bounds, so each axis can be drawn as ``[-limit, limit]``.
    """
    coords = np.asarray(skeleton_data, dtype=float)
    lo = np.nanmin(coords, axis=(0, 1))
    hi = np.nanmax(coords, axis=(0, 1))
    span = hi - lo
    lo = lo - span * padding
    hi = hi + span * padding
    x_limit, y_limit, z_limit = np.maximum(np.abs(lo), np.abs(hi))
    return x_limit, y_limit, z_limit


def visualize_skeleton_3d(skeleton_data, timestamps, connections=None, padding=0.1):
    """Show an interactive plotly 3D view of one skeleton frame, chosen with a slider.

    Args:
        skeleton_data: array of shape (frames, joints, 3). NaN coordinates are ignored
            when fixing the axis ranges (previously a single NaN made every range NaN).
        timestamps: per-frame timestamps, shown in the plot title.
        connections: optional iterable of (start, end) joint-index pairs drawn as lines.
            Defaults to the 6-joint arm layout used in this notebook.
        padding (float): fractional margin added to each axis span (0.1 = 10%).
    """
    if connections is None:
        connections = [
            (0, 1), (1, 2),  # Right arm
            (3, 4), (4, 5),  # Left arm
            (2, 3),          # Connection between arms
        ]
    connections = list(connections)

    # Fixed, origin-centred axis ranges so the view does not jump between frames.
    x_limit, y_limit, z_limit = _symmetric_axis_limits(skeleton_data, padding)

    def update_frame(frame):
        """Render frame ``frame`` as joint markers plus one line trace per connection."""
        frame_data = skeleton_data[frame]

        scatter = go.Scatter3d(
            x=frame_data[:, 0], y=frame_data[:, 1], z=frame_data[:, 2],
            mode='markers',
            marker=dict(size=6, color='blue'),
            name='Joints'
        )

        lines = [
            go.Scatter3d(
                x=[frame_data[start, 0], frame_data[end, 0]],
                y=[frame_data[start, 1], frame_data[end, 1]],
                z=[frame_data[start, 2], frame_data[end, 2]],
                mode='lines',
                line=dict(color='red', width=2),
                name=f'Connection {start}-{end}'
            )
            for start, end in connections
        ]

        layout = go.Layout(
            scene=dict(
                xaxis=dict(title='X', range=[-x_limit, x_limit], autorange=False),
                yaxis=dict(title='Y', range=[-y_limit, y_limit], autorange=False),
                zaxis=dict(title='Z', range=[-z_limit, z_limit], autorange=False),
                aspectmode='cube',  # draw the axes as a cube regardless of their ranges
                camera=dict(
                    eye=dict(x=1, y=-0.2, z=-2),  # eye sits mainly on the -z side of the scene
                    # NOTE(review): an all-zero 'up' vector defines no direction (plotly's
                    # default is z-up). Kept to preserve the current view — confirm intent.
                    up=dict(x=0, y=0, z=0)
                )
            ),
            title=f"Frame: {frame}, Timestamp: {timestamps[frame]}",
            height=700,
            width=700
        )

        fig = go.Figure(data=[scatter] + lines, layout=layout)
        fig.show()

    # Interactive widget with a long slider spanning every frame
    interact(
        update_frame,
        frame=IntSlider(
            min=0,
            max=len(skeleton_data)-1,
            step=1,
            value=0,
            description='Frame:',
            style={'description_width': 'initial'},
            layout=Layout(width='1250px')
        )
    )


# interpolated_skeletal_data: gap-filled (frames, 6, 3) array; skeleton_timestamps: per-frame times
visualize_skeleton_3d(interpolated_skeletal_data, skeleton_timestamps)

interactive(children=(IntSlider(value=0, description='Frame:', layout=Layout(width='1250px'), max=3224, style=…

In [57]:
def show_image(frame):
    """Display one recorded camera frame and print its shape, dtype and value range.

    Args:
        frame (int): index into the notebook-level ``image_data`` / ``image_timestamps``
            lists filled while reading the bag.
    """
    if frame < 0 or frame >= len(image_data):
        print("Invalid frame number")
        return

    image = image_data[frame]

    fig, ax = plt.subplots(figsize=(10, 8))
    if image.ndim == 2 or (image.ndim == 3 and image.shape[2] == 1):
        # Single-channel image: cv2.COLOR_BGR2RGB requires 3/4 channels and would raise.
        ax.imshow(image.reshape(image.shape[:2]), cmap='gray')
    else:
        # Assumes OpenCV BGR/BGRA channel order; 'passthrough' keeps the topic's native
        # encoding, so TODO confirm it against the bag.
        ax.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
    ax.set_title(f"Frame {frame}, Timestamp: {image_timestamps[frame]}")
    ax.axis('off')
    plt.show()

    print(f"Frame {frame} details:")
    print(f"Shape: {image.shape}")
    print(f"Data type: {image.dtype}")
    print(f"Min value: {np.min(image)}")
    print(f"Max value: {np.max(image)}")


# Interactive frame browser. An empty image list would give the slider max=-1 (< min),
# so only build the widget when frames were actually read.
if len(image_data) > 0:
    interact(
        show_image,
        frame=IntSlider(
            min=0,
            max=len(image_data)-1,
            step=1,
            description='Frame:'
        )
    )
else:
    print("No images were read from the bag")

# You can also directly call the function to display a specific frame
# For example, to show frame 5:
# show_image(5)

interactive(children=(IntSlider(value=0, description='Frame:', max=3224), Output()), _dom_classes=('widget-int…

In [44]:
# Report which frames / joints of the raw skeleton_3d_data contain NaNs, then build
# copies of the data and timestamps with those frames dropped.
nan_indices = np.isnan(skeleton_3d_data)

# Frame and joint index of every NaN coordinate (the coordinate axis is discarded).
nan_frame_idx, joints_with_nans, _ = np.where(nan_indices)
unique_joints_with_nans = set(joints_with_nans)

# Sorted, de-duplicated frames holding at least one NaN.
frames_with_nans = np.unique(nan_frame_idx)

print("Frames with NaN values:", frames_with_nans)
print("Joints with NaN values:", unique_joints_with_nans)

# Drop the same frames from both arrays (assumes skeleton_timestamps is in the same
# frame order as skeleton_3d_data).
skeleton_3d_cleaned = np.delete(skeleton_3d_data, frames_with_nans, axis=0)
skeleton_timestamps_cleaned = np.delete(skeleton_timestamps, frames_with_nans, axis=0)

print("Original shape:", np.shape(skeleton_3d_data))
print("New shape after removing frames with NaNs:", np.shape(skeleton_3d_cleaned))

Frames with NaN values: [   0    3    6   11   77   83   86  111  113  114  121  122  123  124
  125  126  127  189  203  206  208  209  210  211  212  214  216  242
  244  248  249  253  261  272  284  312  314  318  319  321  335  377
  436  444  447  451  455  464  474  475  476  477  478  504  505  506
  507  509  511  514  515  518  519  520  521  522  523  527  529  535
  540  541  542  543  545  561  562  563  566  567  568  569  570  573
  574  615  616  618  644  645  647  657  715  719  720  721  722  726
  728  748  753  771  778  783  826  830  831  834  835  836  837  840
  842  845  847  848  881  905  931  959  964  972  973  976  983  986
  988  991  998 1000 1002 1020 1036 1037 1051 1074 1077 1078 1087 1102
 1104 1129 1130 1131 1132 1133 1134 1135 1139 1144 1147 1152 1165 1167
 1170 1173 1175 1176 1180 1185 1186 1187 1188 1189 1190 1191 1196 1197
 1201 1202 1206 1212 1219 1222 1224 1247 1248 1252 1253 1269 1270 1271
 1273 1274 1275 1276 1278 1279 1282 1294 1304 1343 13

array([Decimal('1719891177.500607252'), Decimal('1719891177.613994121'),
       Decimal('1719891177.730127811'), ...,
       Decimal('1719891542.148417234'), Decimal('1719891542.254627943'),
       Decimal('1719891542.366212368')], dtype=object)