In [1]:
import os
import time
import socket
import threading
import torch
import numpy as np
from datetime import datetime
from argparse import ArgumentParser
os.environ['PYGAME_HIDE_SUPPORT_PROMPT'] = "hide"
from pygame.time import Clock
import pickle

from articulate.math import *
from mobileposer.models import *
from mobileposer.utils.model_utils import *
from mobileposer.config import *

import coremltools as ct
import pytorch_lightning as pl

In [2]:
from mobileposer.models.rnn import RNN
class JointsBase(nn.Module):
    """
    Inputs: N IMUs.
    Outputs: 24 Joint positions. 
    """

    def __init__(self, n_imu, seq_length):
        super().__init__()
        # self.joints = net.joints.joints
        self.joints = RNN(n_imu, 24 * 3, 256, seq_length)

    def forward(self, batch, input_lengths: Tensor):
        # forward joint model
        joints, _, _ = self.joints(batch)
        return joints

class PoserBase(nn.Module):
    """
    Inputs: N IMUs.
    Outputs: SMPL Pose Parameters (as 6D Rotations).
    """
    def __init__(self, n_output_joints, n_imu, n_reduced, seq_length):
        super().__init__()
        # self.pose = net.pose.pose
        self.pose = RNN(n_output_joints*3 + n_imu, n_reduced*6, 256, seq_length)

    def forward(self, batch, input_lengths: Tensor):
        # forward the pose prediction model
        pred_pose, _, _ = self.pose(batch)
        return pred_pose

class VelocityBase(nn.Module):
    """
    Inputs: N IMUs.
    Outputs: Per-Frame Root Velocity. 
    """

    def __init__(self, n_output_joints, n_imu, seq_length):
        super().__init__()

        # model definitions
        # self.vel = net.velocity.vel
        self.vel = RNN(n_output_joints * 3 + n_imu, 24 * 3, 256, bidirectional=False, seq_length=seq_length)

    def forward(self, batch, h, c, input_lengths:Tensor):
        # forward velocity model
        vel, _, state = self.vel(batch, (h,c))
        h_out, c_out = state[0].detach(), state[1].detach()
        return vel, h_out, c_out
    
class FootContactBase(nn.Module):
    """
    Inputs: N IMUs.
    Outputs: Foot Contact Probability ([s_lfoot, s_rfoot]).
    """

    def __init__(self, n_output_joints, n_imu, seq_length):
        super().__init__()
        # self.footcontact = net.foot_contact.footcontact
        self.footcontact = RNN(n_output_joints * 3 + n_imu, 2, 64, seq_length=seq_length)

    def forward(self, batch, input_lengths: Tensor):
        # forward foot contact model
        foot_contact, _, _ = self.footcontact(batch)
        return foot_contact

In [31]:
class MobilePoserBase(nn.Module):
    # def __init__(self, net, body_model, joints_model, pose_model, contact_model, velocity_model, n_reduced, ignored):
    #     super().__init__()

    #     #constants
    #     self.n_reduced = n_reduced
    #     self.ignored = ignored

    #     #core model layers
    #     self.bodymodel = body_model
    #     self.joints = joints_model
    #     self.pose = pose_model
    #     self.foot_contact = contact_model
    #     self.velocity = velocity_model
    def __init__(self, n_reduced, ignored, n_imu, n_output_joints, seq_length):
        super().__init__()

        #constants
        self.n_reduced = n_reduced
        self.ignored = ignored

        #core model layers
        self.joints = JointsBase(n_imu=n_imu, seq_length=seq_length)
        self.pose = PoserBase(n_imu=n_imu, n_output_joints=n_output_joints, n_reduced=n_reduced, seq_length=seq_length)
        self.foot_contact = FootContactBase(n_output_joints=n_output_joints, n_imu=n_imu, seq_length=seq_length)
        self.velocity = VelocityBase(n_imu=n_imu, n_output_joints=n_output_joints, seq_length=seq_length)
    
    def normalize_tensor(self, x: torch.Tensor, dim: int=-1):
        norm = torch.norm(x, dim=dim, keepdim=True)
        normalized_x = x / norm
        return normalized_x
    
    def quaternion_to_rotation_matrix(self, q: torch.Tensor):
        q = self.normalize_tensor(q.view(-1, 4))
        a, b, c, d = q[:, 0:1], q[:, 1:2], q[:, 2:3], q[:, 3:4]
        r = torch.cat((- 2 * c * c - 2 * d * d + 1, 2 * b * c - 2 * a * d, 2 * a * c + 2 * b * d,
                    2 * b * c + 2 * a * d, - 2 * b * b - 2 * d * d + 1, 2 * c * d - 2 * a * b,
                    2 * b * d - 2 * a * c, 2 * a * b + 2 * c * d, - 2 * b * b - 2 * c * c + 1), dim=1)
        return r.view(-1, 3, 3)
    
    def forward(self, 
                # imu: torch.Tensor, 
                ori_raw: torch.Tensor, 
                acc_raw: torch.Tensor, 
                acc_offsets: torch.Tensor, 
                smpl2imu: torch.Tensor, 
                device2bone: torch.Tensor,
                h: torch.Tensor,
                c: torch.Tensor):
        ori_raw = self.quaternion_to_rotation_matrix(ori_raw).view(-1, 5, 3, 3) # hardcoded n_imus = 5
        glb_acc = (smpl2imu.matmul(acc_raw.view(-1, 5, 3, 1)) - acc_offsets).view(-1, 5, 3) # hardcoded n_imus = 5
        glb_ori = smpl2imu.matmul(ori_raw).matmul(device2bone)

        # normalization 
        _acc = glb_acc.view(-1, 5, 3)[:, [1, 4, 3, 0, 2]] / 30 #hardcoded acc_scale = 30
        _ori = glb_ori.view(-1, 5, 3, 3)[:, [1, 4, 3, 0, 2]]

        acc = torch.zeros_like(_acc)
        ori = torch.zeros_like(_ori)

        # device combo
        # c = [1, 3] # hardcoded rw_rp': [1, 3]
        # combo = torch.tensor([3], dtype=torch.long) #hardcoded rp: [3]

        mask_1d = torch.tensor([0, 0, 0, 1, 0], dtype=acc_raw.dtype)  
        mask_1d = mask_1d.view(1, 5)                           
        mask_acc = mask_1d.unsqueeze(-1)                # [1,5,1]
        mask_ori = mask_1d.view(1, 5, 1, 1)

        acc = _acc * mask_acc      # [1,5,3], all zeros except channel 3
        ori = _ori * mask_ori      # [1,5,3,3], all zeros except channel 3
             
        imu_input = torch.cat([acc.flatten(1), ori.flatten(1)], dim=1).squeeze(0)

        # Pushinng this logical if statement to Swift
        imu = imu_input.repeat(45, 1)
        # imu = torch.cat((imu[1:], imu_input.view(1, -1))) #hardcoded num_total_frames = 45

        pred_pose, pred_joints, pred_vel, foot_contact, velocity_h, velocity_c = self.run_model(imu.unsqueeze(0), input_lengths=torch.tensor([45]), h=h, c=c)

        return pred_pose, pred_joints, pred_vel, foot_contact, velocity_h, velocity_c, imu.squeeze(0)
        # imu to store

    def run_model(self, 
                  batch: torch.Tensor, 
                  h: torch.Tensor, 
                  c: torch.Tensor, 
                  input_lengths: torch.Tensor):
        # forward the joint prediction model
        pred_joints = self.joints(batch, input_lengths)

        # forward the pose prediction model
        pose_input = torch.cat((pred_joints, batch), dim=-1)
        pred_pose = self.pose(pose_input, input_lengths)

        # forward the foot-ground contact probability model
        tran_input = torch.cat((pred_joints, batch), dim=-1)
        foot_contact = self.foot_contact(tran_input, input_lengths)

        # foward the foot-joint velocity model
        pred_vel, velocity_h, velocity_c = self.velocity(tran_input, h, c, input_lengths)
        pred_vel = pred_vel.squeeze(0)

        pred_pose, pred_joints, pred_vel, foot_contact = self.process_base_outputs(pred_pose, pred_joints, pred_vel, foot_contact)

        return pred_pose, pred_joints, pred_vel, foot_contact, velocity_h, velocity_c
    
    def rotation_matrix_to_axis_angle(self, r: torch.Tensor) -> torch.Tensor:
        """
        :param r: Tensor of shape (..., 3, 3), a batch of rotation matrices
        :return: Tensor of shape (..., 3), the corresponding axis-angle vectors
        """
        # Flatten batch dims
        R = r.view(-1, 3, 3)

        # 1) compute the trace → cos θ
        tr = R[..., 0, 0] + R[..., 1, 1] + R[..., 2, 2]
        cos_theta = (tr - 1.0) * 0.5
        cos_theta = cos_theta.clamp(-1.0, 1.0)

        # 2) recover θ
        theta = torch.acos(cos_theta)

        # 3) compute the "cross-differences" v = [R32-R23, R13-R31, R21-R12]
        rx = R[..., 2, 1] - R[..., 1, 2]
        ry = R[..., 0, 2] - R[..., 2, 0]
        rz = R[..., 1, 0] - R[..., 0, 1]
        v   = torch.stack((rx, ry, rz), dim=-1)

        # 4) normalize to get the rotation axis: axis = v / (2 sin θ)
        sin_theta = torch.sin(theta).clamp(min=1e-6).unsqueeze(-1)
        axis = v / (2.0 * sin_theta)

        # 5) axis-angle vector = axis * θ
        rot_vec = axis * theta.unsqueeze(-1)

        # return rot_vec.view(R.shape[0], 3)
        return rot_vec.flatten()
    
    def process_base_outputs(self, pose, pred_joints, vel, contact):
        
        pose = art.math.r6d_to_rotation_matrix(pose).reshape(-1, 24, 3, 3)
        
        # get pose
        curr_pose = pose[40]   # shape = [24, 3, 3]
        curr_pose = curr_pose.flatten(start_dim=2)
        curr_pose = self.rotation_matrix_to_axis_angle(curr_pose)

        # compute the joint positions from predicted pose
        joints = pred_joints.squeeze(0)[40].view(24, 3)

        # compute translation from foot-contact probability
        contact = contact[0][40]
        
        # velocity from network-based estimation
        root_vel = vel.view(-1, 24, 3)[:, 0]

        pred_vel = root_vel[40] / (30/2) #hardcoded fps = 30, vel_scale = 2

        # Need to implement in Swift

        # lfoot_pos, rfoot_pos = joints[10], joints[11]
        # if contact[0] > contact[1]:
        #     contact_vel = self.last_lfoot_pos - lfoot_pos + self.gravity_velocity
        # else:
        #     contact_vel = self.last_rfoot_pos - rfoot_pos + self.gravity_velocity
        # weight = self._prob_to_weight(contact.max())
        # velocity = art.math.lerp(pred_vel, contact_vel, weight)
        # current_foot_y = self.current_root_y + min(lfoot_pos[1].item(), rfoot_pos[1].item())
        # if current_foot_y + velocity[1].item() <= self.floor_y:
        #     velocity[1] = self.floor_y - current_foot_y

        # self.current_root_y += velocity[1].item()
        # self.last_lfoot_pos, self.last_rfoot_pos = lfoot_pos, rfoot_pos
        # self.last_root_pos += velocity

        return curr_pose, joints, pred_vel, contact

In [32]:
model = MobilePoserBase(
                    n_reduced=joint_set.n_full,
                    # n_reduced=joint_set.n_reduced,
                    ignored=joint_set.ignored,
                    n_imu=model_config.n_imu,
                    n_output_joints=model_config.n_output_joints,
                    seq_length=torch.tensor([model_config.past_frames+model_config.future_frames]))

In [33]:
model.load_state_dict(torch.load("/Users/brianchen/Research/MobilePoser/mobileposer/checkpoints/model_finetuned.pth", map_location="cpu"))
# base.load_state_dict(torch.load("/Users/brianchen/Research/MobilePoser/mobileposer/checkpoints/weights.pth"))
model.eval()

MobilePoserBase(
  (joints): JointsBase(
    (joints): RNN(
      (rnn): LSTM(256, 256, num_layers=2, batch_first=True, bidirectional=True)
      (linear1): Linear(in_features=60, out_features=256, bias=True)
      (linear2): Linear(in_features=512, out_features=72, bias=True)
      (dropout): Dropout(p=0.4, inplace=False)
    )
  )
  (pose): PoserBase(
    (pose): RNN(
      (rnn): LSTM(256, 256, num_layers=2, batch_first=True, bidirectional=True)
      (linear1): Linear(in_features=132, out_features=256, bias=True)
      (linear2): Linear(in_features=512, out_features=144, bias=True)
      (dropout): Dropout(p=0.4, inplace=False)
    )
  )
  (foot_contact): FootContactBase(
    (footcontact): RNN(
      (rnn): LSTM(64, 64, num_layers=2, batch_first=True, bidirectional=True)
      (linear1): Linear(in_features=132, out_features=64, bias=True)
      (linear2): Linear(in_features=128, out_features=2, bias=True)
      (dropout): Dropout(p=0.4, inplace=False)
    )
  )
  (velocity): Veloc

In [34]:
scripted_model = torch.jit.script(model)

In [35]:
num_past_frames = model_config.past_frames
num_future_frames = model_config.future_frames
num_total_frames = num_past_frames + num_future_frames
data = torch.zeros((60))
imu = data.repeat(num_total_frames, 1)

input_length = torch.tensor([num_total_frames])
imu_input = imu.unsqueeze(0)
h, c = (torch.zeros((2, 1, 256)), torch.zeros((2, 1, 256)))
imu_frames = torch.rand_like(torch.zeros((45, 60)))
ori_raw = torch.rand_like(torch.zeros((1, 5, 4)))
acc_raw = torch.rand_like(torch.zeros((1, 5, 3)))
smpl2imu = torch.rand_like(torch.zeros((3, 3)))
accOffset = torch.rand_like(torch.zeros((5, 3, 1)))
device2bone = torch.rand_like(torch.zeros((5, 3, 3)))

In [36]:
model.eval()
with torch.no_grad():
    out = model(
                # imu_frames, 
                ori_raw, 
                acc_raw,
                accOffset,
                smpl2imu,
                device2bone,
                h,
                c)

In [None]:
scripted_core = scripted_model.eval()
traced_core = torch.jit.trace(model, (
                                                # imu_frames, 
                                                ori_raw, 
                                                acc_raw,
                                                accOffset,
                                                smpl2imu,
                                                device2bone,
                                                h,
                                                c))
model_from_trace = ct.convert(
    traced_core,
    inputs=[
            # ct.TensorType(shape=imu_frames.shape), 
            ct.TensorType(shape=ori_raw.shape), 
            ct.TensorType(shape=acc_raw.shape), 
            ct.TensorType(shape=accOffset.shape),
            ct.TensorType(shape=smpl2imu.shape),
            ct.TensorType(shape=device2bone.shape),
            ct.TensorType(shape=h.shape),
            ct.TensorType(shape=c.shape)]
)

  mask_1d = torch.tensor([0, 0, 0, 1, 0], dtype=acc_raw.dtype)
  pred_pose, pred_joints, pred_vel, foot_contact, velocity_h, velocity_c = self.run_model(imu.unsqueeze(0), input_lengths=torch.tensor([45]), h=h, c=c)
When both 'convert_to' and 'minimum_deployment_target' not specified, 'convert_to' is set to "mlprogram" and 'minimum_deployment_target' is set to ct.target.iOS15 (which is same as ct.target.macOS12). Note: the model will not run on systems older than iOS15/macOS12/watchOS8/tvOS15. In order to make your model run on older system, please set the 'minimum_deployment_target' to iOS14/iOS13. Details please see the link: https://apple.github.io/coremltools/docs-guides/source/target-conversion-formats.html
Tuple detected at graph output. This will be flattened in the converted model.
Converting PyTorch Frontend ==> MIL Ops: 100%|█████████▉| 553/555 [00:00<00:00, 1491.99 ops/s]
Running MIL frontend_pytorch pipeline: 100%|██████████| 5/5 [00:00<00:00, 76.30 passes/s]
Running MIL def

In [39]:
model_from_trace.save("MobilePoserCompleteInitial.mlpackage")

In [39]:
import torch
import math
from typing import List
import cv2

import torch
import torch.nn as nn

class ProcessInputs(nn.Module):
    def __init__(self):
        super().__init__()
        
    def forward(self, 
                imu: torch.Tensor, 
                ori_raw: torch.Tensor, 
                acc_raw: torch.Tensor, 
                acc_offsets: torch.Tensor, 
                smpl2imu: torch.Tensor, 
                device2bone: torch.Tensor) -> torch.Tensor:
        ori_raw = quaternion_to_rotation_matrix(ori_raw).view(-1, 5, 3, 3) # hardcoded n_imus = 5
        glb_acc = (smpl2imu.matmul(acc_raw.view(-1, 5, 3, 1)) - acc_offsets).view(-1, 5, 3) # hardcoded n_imus = 5
        glb_ori = smpl2imu.matmul(ori_raw).matmul(device2bone)

        # normalization 
        _acc = glb_acc.view(-1, 5, 3)[:, [1, 4, 3, 0, 2]] / 30 #hardcoded acc_scale = 30
        _ori = glb_ori.view(-1, 5, 3, 3)[:, [1, 4, 3, 0, 2]]

        acc = torch.zeros_like(_acc)
        ori = torch.zeros_like(_ori)

        # device combo
        # c = [1, 3] # hardcoded rw_rp': [1, 3]
        c = [3] #hardcoded rp: [3]

        acc[:, c] = _acc[:, c] 
        ori[:, c] = _ori[:, c]
        
        imu_input = torch.cat([acc.flatten(1), ori.flatten(1)], dim=1).squeeze(0)

        # Pushinng this logical if statement to Swift
        # imu_input.repeat(45, 1) if imu is None else 
        imu = torch.cat((imu[1:], imu_input.view(1, -1))) #hardcoded num_total_frames = 45

        return imu.unsqueeze(0), torch.tensor([45]), imu.squeeze(0) #hardcoded num_total_frames = 45
        # imu_input, imu_shape, self.imu to store


class ProcessInputsInitial(nn.Module):
    def __init__(self):
        super().__init__()
        
    def forward(self,  
                ori_raw: torch.Tensor, 
                acc_raw: torch.Tensor, 
                acc_offsets: torch.Tensor, 
                smpl2imu: torch.Tensor, 
                device2bone: torch.Tensor) -> torch.Tensor:
        ori_raw = quaternion_to_rotation_matrix(ori_raw).view(-1, 5, 3, 3) # hardcoded n_imus = 5
        glb_acc = (smpl2imu.matmul(acc_raw.view(-1, 5, 3, 1)) - acc_offsets).view(-1, 5, 3) # hardcoded n_imus = 5
        glb_ori = smpl2imu.matmul(ori_raw).matmul(device2bone)

        # normalization 
        _acc = glb_acc.view(-1, 5, 3)[:, [1, 4, 3, 0, 2]] / 30 #hardcoded acc_scale = 30
        _ori = glb_ori.view(-1, 5, 3, 3)[:, [1, 4, 3, 0, 2]]

        acc = torch.zeros_like(_acc)
        ori = torch.zeros_like(_ori)

        # device combo
        # c = [1, 3] # hardcoded rw_rp': [1, 3]
        c = [3] #hardcoded rp: [3]

        acc[:, c] = _acc[:, c] 
        ori[:, c] = _ori[:, c]
        
        imu_input = torch.cat([acc.flatten(1), ori.flatten(1)], dim=1).squeeze(0)

        # Pushinng this logical if statement to Swift
        imu = imu_input.repeat(45, 1)

        return imu.unsqueeze(0), torch.tensor([45]), imu.squeeze(0) #hardcoded num_total_frames = 45
        # imu_input, imu_shape, self.imu to store

Orientation_shape: torch.Size([1, 5, 4]) \
Acceleration_shape: torch.Size([1, 5, 3]) \
SMPL2IMU_shape: torch.Size([3, 3]) \
AccOffset_shape: torch.Size([5, 3, 1]) \
Device2Bone_shape: torch.Size([5, 3, 3]) 

In [32]:
imu_frames = torch.rand_like(torch.zeros((45, 60)))
ori_raw = torch.rand_like(torch.zeros((1, 5, 4)))
acc_raw = torch.rand_like(torch.zeros((1, 5, 3)))
smpl2imu = torch.rand_like(torch.zeros((3, 3)))
accOffset = torch.rand_like(torch.zeros((5, 3, 1)))
device2bone = torch.rand_like(torch.zeros((5, 3, 3)))

In [41]:
process_func = ProcessInputs()
process_func.eval()
process_func_initial = ProcessInputsInitial()
process_func_initial.eval()

traced_func = torch.jit.trace(process_func, example_inputs=(imu_frames, ori_raw, acc_raw, accOffset, smpl2imu, device2bone))
process_func_model = ct.convert(
    traced_func,
    inputs=[ct.TensorType(shape=imu_frames.shape), 
            ct.TensorType(shape=ori_raw.shape), 
            ct.TensorType(shape=acc_raw.shape), 
            ct.TensorType(shape=accOffset.shape),
            ct.TensorType(shape=smpl2imu.shape), 
            ct.TensorType(shape=device2bone.shape)],
    convert_to="mlprogram")
process_func_model.save("ProcessInputs.mlpackage")

traced_func = torch.jit.trace(process_func_initial, example_inputs=(ori_raw, acc_raw, accOffset, smpl2imu, device2bone))
process_func_model_initial = ct.convert(
    traced_func,
    inputs=[ct.TensorType(shape=ori_raw.shape), 
            ct.TensorType(shape=acc_raw.shape), 
            ct.TensorType(shape=accOffset.shape),
            ct.TensorType(shape=smpl2imu.shape), 
            ct.TensorType(shape=device2bone.shape)],
    convert_to="mlprogram")
process_func_model_initial.save("ProcessInputsInitial.mlpackage")

  return imu.unsqueeze(0), torch.tensor([45]), imu.squeeze(0) #hardcoded num_total_frames = 45
Tuple detected at graph output. This will be flattened in the converted model.
Converting PyTorch Frontend ==> MIL Ops:  99%|█████████▉| 320/322 [00:00<00:00, 5726.23 ops/s]
Running MIL frontend_pytorch pipeline: 100%|██████████| 5/5 [00:00<00:00, 997.60 passes/s]
Running MIL default pipeline: 100%|██████████| 89/89 [00:00<00:00, 1712.00 passes/s]
Running MIL backend_mlprogram pipeline: 100%|██████████| 12/12 [00:00<00:00, 4426.32 passes/s]
  return imu.unsqueeze(0), torch.tensor([45]), imu.squeeze(0) #hardcoded num_total_frames = 45
Tuple detected at graph output. This will be flattened in the converted model.
Converting PyTorch Frontend ==> MIL Ops:  99%|█████████▉| 312/314 [00:00<00:00, 4404.85 ops/s]
Running MIL frontend_pytorch pipeline: 100%|██████████| 5/5 [00:00<00:00, 1231.81 passes/s]
Running MIL default pipeline: 100%|██████████| 89/89 [00:00<00:00, 2257.46 passes/s]
Running MIL ba

In [None]:
import torch

def quaternion_to_matrix(q: torch.Tensor) -> torch.Tensor:
    """
    Convert quaternion q=[x,y,z,w] (shape (...,4)) to a rotation matrix (...,3,3).
    We use torch.stack on literal Python lists when forming each row.
    """
    q = q.to(dtype=torch.float32)
    q = q / q.norm(dim=-1, keepdim=True)

    x, y, z, w = q.unbind(-1)

    xx = x*x; yy = y*y; zz = z*z; ww = w*w
    xy = x*y; xz = x*z; xw = x*w
    yz = y*z; yw = y*w; zw = z*w

    m00 = ww + xx - yy - zz
    m01 = 2*(xy - zw)
    m02 = 2*(xz + yw)

    m10 = 2*(xy + zw)
    m11 = ww - xx + yy - zz
    m12 = 2*(yz - xw)

    m20 = 2*(xz - yw)
    m21 = 2*(yz + xw)
    m22 = ww - xx - yy + zz

    row0 = torch.cat((m00.unsqueeze(-1), m01.unsqueeze(-1), m02.unsqueeze(-1)), dim=-1)  # shape (...,3)
    row1 = torch.cat((m10.unsqueeze(-1), m11.unsqueeze(-1), m12.unsqueeze(-1)), dim=-1)
    row2 = torch.cat((m20.unsqueeze(-1), m21.unsqueeze(-1), m22.unsqueeze(-1)), dim=-1)

    return torch.cat(
        (row0.unsqueeze(-2),  # shape (...,1,3)
         row1.unsqueeze(-2),  # shape (...,1,3)
         row2.unsqueeze(-2)), # shape (...,1,3)
        dim=-2                # final shape (...,3,3)
    )


def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
    """
    Convert a valid rotation matrix (shape (...,3,3)) into a quaternion [x,y,z,w] of shape (...,4),
    using only elementwise ops + torch.clamp + torch.sqrt + torch.sign + torch.stack.
    This avoids any concat or boolean‐mask indexing.
    """
    # 1) Unpack the 3×3 entries as scalars of shape (...)
    m00 = matrix[..., 0, 0]
    m01 = matrix[..., 0, 1]
    m02 = matrix[..., 0, 2]
    m10 = matrix[..., 1, 0]
    m11 = matrix[..., 1, 1]
    m12 = matrix[..., 1, 2]
    m20 = matrix[..., 2, 0]
    m21 = matrix[..., 2, 1]
    m22 = matrix[..., 2, 2]

    # 2) Compute trace
    trace = m00 + m11 + m22  

    # a) Clamp arguments to sqrt to be ≥ 0
    t0 = torch.clamp(1.0 + trace, min=0.0)                   
    t1 = torch.clamp(1.0 + m00 - m11 - m22, min=0.0)
    t2 = torch.clamp(1.0 - m00 + m11 - m22, min=0.0)
    t3 = torch.clamp(1.0 - m00 - m11 + m22, min=0.0)

    # b) Compute square roots
    sqrt0 = torch.sqrt(t0)  
    sqrt1 = torch.sqrt(t1)
    sqrt2 = torch.sqrt(t2)
    sqrt3 = torch.sqrt(t3)

    # c) Assemble each component
    qw = 0.5 * sqrt0
    qx = 0.5 * torch.sign(m21 - m12) * sqrt1
    qy = 0.5 * torch.sign(m02 - m20) * sqrt2
    qz = 0.5 * torch.sign(m10 - m01) * sqrt3

    quat = torch.stack([qx, qy, qz, qw], dim=-1) 

    return quat / quat.norm(dim=-1, keepdim=True)

In [169]:
test_ori = torch.rand_like(torch.zeros((4)))

print(R.from_quat(test_ori).as_matrix())
print(quaternion_to_matrix(test_ori.unsqueeze(0)))

test_matrix = torch.tensor(R.from_quat(test_ori).as_matrix())

print(R.from_matrix(test_matrix).as_quat())
print(matrix_to_quaternion(test_matrix.unsqueeze(0)))

[[-0.55687259 -0.04613942  0.82931542]
 [ 0.79003913  0.27876994  0.54600869]
 [-0.25638074  0.95924891 -0.1187874 ]]
tensor([[[-0.5569, -0.0461,  0.8293],
         [ 0.7900,  0.2788,  0.5460],
         [-0.2564,  0.9592, -0.1188]]])
[0.26605679 0.69900464 0.53835751 0.38830077]
tensor([[0.2661, 0.6990, 0.5384, 0.3883]], dtype=torch.float64)


In [201]:
torch.matmul(torch.eye(3), quaternion_to_matrix(test_ori.unsqueeze(0))).shape

torch.Size([1, 3, 3])

In [212]:
matrix_to_quaternion(torch.matmul(torch.eye(3), quaternion_to_matrix(test_ori.unsqueeze(0))))

tensor([[0.2661, 0.6990, 0.5384, 0.3883]])

In [213]:
matrix_to_quaternion(torch.matmul(torch.eye(3), quaternion_to_matrix(test_ori)))

tensor([0.2661, 0.6990, 0.5384, 0.3883])

In [240]:
from scipy.spatial.transform import Rotation as R
class Sensor2Global(nn.Module):
    def __init__(self):
        super().__init__()
    
    def sensor2global(self, ori, acc):

        global_inertial_frame = torch.eye(3) # hardcoded
        og_mat = quaternion_to_matrix(ori)

        # global_mat = torch.matmul(global_inertial_frame.T, og_mat)
        global_mat = torch.matmul(global_inertial_frame, og_mat)
        global_ori = matrix_to_quaternion(global_mat).squeeze(0)       
        acc_ref   = torch.matmul(og_mat, acc.unsqueeze(-1)).squeeze(-1)
        # global_acc = torch.matmul(global_inertial_frame.T, acc_ref.unsqueeze(-1)).squeeze(-1)
        global_acc = torch.matmul(global_inertial_frame, acc_ref.unsqueeze(-1)).squeeze(-1)

        return global_ori, global_acc.squeeze(0)
    
    def forward(self, all_ori, all_acc):

        dev1_ori, dev1_acc = self.sensor2global(all_ori[0].unsqueeze(0), all_acc[0])
        dev2_ori, dev2_acc = self.sensor2global(all_ori[1].unsqueeze(0), all_acc[1])
        dev3_ori, dev3_acc = self.sensor2global(all_ori[2].unsqueeze(0), all_acc[2])
        dev4_ori, dev4_acc = self.sensor2global(all_ori[3].unsqueeze(0), all_acc[3])
        dev5_ori, dev5_acc = self.sensor2global(all_ori[4].unsqueeze(0), all_acc[4])

        all_ori_global = torch.stack([dev1_ori, dev2_ori, dev3_ori, dev4_ori, dev5_ori], dim=0).unsqueeze(0)
        all_acc_global = torch.stack([dev1_acc, dev2_acc, dev3_acc, dev4_acc, dev5_acc], dim=0).unsqueeze(0)

        return all_ori_global, all_acc_global

In [241]:
sensor2global_func = Sensor2Global()
sensor2global_func.eval()
ori_input = torch.rand((5,4))
acc_input = torch.rand((5,3))
traced_func = torch.jit.trace(sensor2global_func, example_inputs=(ori_input, acc_input))

In [242]:
sensor2global_func_mlpackage = ct.convert(
    traced_func,
    inputs=[ct.TensorType(shape=ori_input.shape), 
            ct.TensorType(shape=acc_input.shape)],
    convert_to="mlprogram")

Tuple detected at graph output. This will be flattened in the converted model.
Converting PyTorch Frontend ==> MIL Ops: 100%|█████████▉| 1274/1276 [00:00<00:00, 5086.17 ops/s]
Running MIL frontend_pytorch pipeline: 100%|██████████| 5/5 [00:00<00:00, 63.50 passes/s]
Running MIL default pipeline: 100%|██████████| 89/89 [00:02<00:00, 36.36 passes/s]
Running MIL backend_mlprogram pipeline: 100%|██████████| 12/12 [00:00<00:00, 53.10 passes/s]


In [243]:
sensor2global_func_mlpackage.save("Sensor2Global.mlpackage")

In [244]:
sensor2global_func_mlpackage.predict(data={"all_ori": ori_input, "all_acc": acc_input})

{'var_1287': array([[[0.05899048, 0.5283203 , 0.2854004 , 0.7973633 ],
         [0.6791992 , 0.1619873 , 0.15344238, 0.69921875],
         [0.7011719 , 0.53759766, 0.3256836 , 0.33666992],
         [0.2175293 , 0.5493164 , 0.52685547, 0.61083984],
         [0.31958008, 0.6220703 , 0.7133789 , 0.043396  ]]], dtype=float32),
 'var_1292': array([[[ 0.66064453,  1.3183594 ,  0.05773926],
         [ 0.60058594, -0.28735352,  0.15100098],
         [ 0.8496094 , -0.01779175, -0.28198242],
         [ 0.36816406,  0.56640625,  0.8261719 ],
         [-0.2944336 ,  0.4206543 ,  0.24401855]]], dtype=float32)}

In [None]:
class Calibrator(nn.Module):
    def __init__(self):
        super().__init__()
    
    def forward(self, imu1_ori, oriMean, accMean):
        smpl2imu = quaternion_to_rotation_matrix(imu1_ori).view(3, 3).t()

        oris = quaternion_to_rotation_matrix(oriMean)
        device2bone = smpl2imu.matmul(oris).transpose(1, 2).matmul(torch.eye(3))
        acc_offsets = smpl2imu.matmul(accMean.unsqueeze(-1)) 

        return smpl2imu, device2bone, acc_offsets

In [24]:
imu1_ori_input = torch.rand((4,))
oriMean_input = torch.rand((5, 4))
accMean_input = torch.rand((5, 3))

calibrator  = Calibrator()
out = calibrator(imu1_ori_input, oriMean_input, accMean_input)

In [25]:
traced_func = torch.jit.trace(calibrator, example_inputs=(imu1_ori_input, oriMean_input, accMean_input))

calibrator_package = ct.convert(
    traced_func,
    inputs=[ct.TensorType(shape=imu1_ori_input.shape), 
            ct.TensorType(shape=oriMean_input.shape),
            ct.TensorType(shape=accMean_input.shape)],
    convert_to="mlprogram")

Model is not in eval mode. Consider calling '.eval()' on your model prior to conversion
Tuple detected at graph output. This will be flattened in the converted model.
Converting PyTorch Frontend ==> MIL Ops:  99%|█████████▉| 298/300 [00:00<00:00, 5315.16 ops/s]
Running MIL frontend_pytorch pipeline: 100%|██████████| 5/5 [00:00<00:00, 379.77 passes/s]
Running MIL default pipeline: 100%|██████████| 89/89 [00:00<00:00, 218.49 passes/s]
Running MIL backend_mlprogram pipeline: 100%|██████████| 12/12 [00:00<00:00, 186.29 passes/s]


In [174]:
calibrator_package.save("Calibrator.mlpackage")