In [17]:
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import numpy as np
import os
import shutil
import imageio
import torch

no_cuda = False
use_cuda = not no_cuda and torch.cuda.is_available()
device_name  = "cuda:0" if use_cuda else "cpu"
device = torch.device(device_name)

In [18]:
@ torch.no_grad()
def loading(loaded):
    control_action_extracted = loaded['control_action']
    position_extracted = loaded['position']
    
    try:
        target_extracted = loaded['target']
        target = target_extracted
        position = position_extracted
        control_action = control_action_extracted

    except KeyError:
        target = []
        position = position_extracted
        control_action = control_action_extracted
    
    return control_action,position,target

In [21]:
# Convert Euler angles to rotation matrix
def euler_to_rotation_matrix(euler):
    phi, theta, psi = euler
    R_x = np.array([[1, 0, 0],
                    [0, np.cos(phi), -np.sin(phi)],
                    [0, np.sin(phi), np.cos(phi)]])
    R_y = np.array([[np.cos(theta), 0, np.sin(theta)],
                    [0, 1, 0],
                    [-np.sin(theta), 0, np.cos(theta)]])
    R_z = np.array([[np.cos(psi), -np.sin(psi), 0],
                    [np.sin(psi), np.cos(psi), 0],
                    [0, 0, 1]])
    return np.dot(R_z, np.dot(R_y, R_x))

# Create a folder to save frames
frame_folder = "frames"
os.makedirs(frame_folder, exist_ok=True)


In [22]:
name_of_test_file = './Tasks/chirp.pt'
chirp = torch.load(name_of_test_file,map_location=device) #

name_of_test_file = './Tasks/circle.pt'
fixed_cyrcle = torch.load(name_of_test_file,map_location=device) #

name_of_test_file = './Tasks/multisinusoidal.pt'
multisinusoidal = torch.load(name_of_test_file,map_location=device) #

name_of_test_file = './Tasks/vertical_spiral.pt'
variable_spyral = torch.load(name_of_test_file,map_location=device) #

control_action_chirp,position_chirp,target_nominal = loading(chirp)
control_action_multisinusoidal,position_multisinusoidal,target_nominal = loading(multisinusoidal)

# OSC TASKS
control_action_fixed_cyrcle,position_fixed_cyrcle,target_fixed_cyrcle = loading(fixed_cyrcle)
control_action_variable_spyral,position_variable_spyral,target_variable_spyral = loading(variable_spyral)

In [23]:
def quaternion_to_euler(quaternion):
    """
    Convert quaternion to Euler angles.
    """
    x, y, z, w = quaternion[3], quaternion[4], quaternion[5], quaternion[6]
    
    # Roll (x-axis rotation)
    sinr_cosp = 2 * (w * x + y * z)
    cosr_cosp = 1 - 2 * (x**2 + y**2)
    roll = torch.atan2(sinr_cosp, cosr_cosp)

    # Pitch (y-axis rotation)
    sinp = 2 * (w * y - z * x)
    if torch.abs(sinp) >= 1:
        pitch = torch.sign(sinp) * np.pi / 2  # Use 90 degrees if out of range
    else:
        pitch = torch.asin(sinp)

    # Yaw (z-axis rotation)
    siny_cosp = 2 * (w * z + x * y)
    cosy_cosp = 1 - 2 * (y**2 + z**2)
    yaw = torch.atan2(siny_cosp, cosy_cosp)

    return roll, pitch, yaw

def plot_coordinate_system(ax, origin, rotation_matrix):
    """
    Plot a coordinate system.
    """
    axes = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
    rotated_axes = np.dot(axes, rotation_matrix)
    for i, color in enumerate(['r', 'g', 'b']):
        ax.quiver(origin[0], origin[1], origin[2], rotated_axes[i, 0], rotated_axes[i, 1], rotated_axes[i, 2], color=color)

def plot_orientation(position_fixed_circle,name_gif):
    """
    Plot orientation from quaternions using Euler angles.
    """
    # Create a folder to save frames
    frame_folder = "frames"
    os.makedirs(frame_folder, exist_ok=True)
    
    frames = []
    for idx, position in enumerate(position_fixed_circle):
        euler_angles = quaternion_to_euler(position)
        
        rotation_matrix = euler_to_rotation_matrix(euler_angles)
        
        fig, ax = plt.subplots(1, 1, dpi=150, subplot_kw={'projection': '3d'})

        # Plot the coordinate system at the origin
        plot_coordinate_system(ax, [0, 0, 0], np.eye(3))

        # Plot the coordinate system on the body
        plot_coordinate_system(ax, position[:3], rotation_matrix)

        # Plot the body as a point
        ax.scatter(position[0], position[1], position[2], color='k')

        # Plot dashed yellow line from origin to body
        ax.plot([0, position[0]], [0, position[1]], [0, position[2]], color='magenta', linestyle='--')

        # Set plot limits
        ax.set_xlim([0, 1.5])
        ax.set_ylim([0, 1.5])
        ax.set_zlim([0, 1.5])

        # Set labels
        ax.set_xlabel('X')
        ax.set_ylabel('Y')
        ax.set_zlabel('Z')
        ax.view_init(elev=15, azim=75, roll=0)

        # Save plot as image
        filename = os.path.join(frame_folder, f"frame_{idx}.png")
        plt.savefig(filename)
        plt.close()

        frames.append(imageio.imread(filename))

    # Create GIF
    imageio.mimsave(f'{name_gif}.gif', frames, duration=0.1)


In [24]:
position_chirp=position_chirp[:,43,:7].cpu()
plot_orientation(position_chirp,'output_chirp')

  frames.append(imageio.imread(filename))
