In [90]:
import torch
import numpy as np

In [65]:
# batch*n
def normalize_vector( v, return_mag =False):
    batch=v.shape[0]
    v_mag = torch.sqrt(v.pow(2).sum(1))# batch
    v_mag = v_mag.view(batch,1).expand(batch,v.shape[1])
    v = v/v_mag
    if(return_mag==True):
        return v, v_mag[:,0]
    else:
        return v

def transform_matrix_torch(axisAngle):
    
    batch = axisAngle.shape[0]
    batch,dim = axisAngle.shape
    axis = axisAngle[:,0:3]
    #axis_rotate = axisAngle[:,3:]
    axis_norm = normalize_vector(axis)
    #axis_rotate_norm = normalize_vector(axis_rotate)
    b1 = axis_norm[:,0]
    b2 = axis_norm[:,1]
    b3 = axis_norm[:,2]
    # Check for the specific input (0,0,1) and return identity matrix
    identity_mask = (b1 == 0) & (b2 == 0) & (b3 == 1)
    anti_identity_mask = (b1 == 0) & (b2 == 0) & (b3 == -1)
    identity_matrix = torch.eye(3, device=axisAngle.device).unsqueeze(0).repeat(axisAngle.shape[0], 1, 1)
    anti_identity_matrix = -identity_matrix
    print(identity_mask,anti_identity_mask)
    # Calculate matrix elements
    m00 = b3 + (b2**2 * (-1 + b3)) / (-1 + torch.abs(b3)**2)
    m01 = -(b1 * b2 * (-1 + b3)) / (-1 + torch.abs(b3)**2)
    m02 = b1
    
    m10 = -(b1 * b2 * (-1 + b3)) / (-1 + torch.abs(b3)**2)
    m11 = b3 + (b1**2 * (-1 + b3)) / (-1 + torch.abs(b3)**2)
    m12 = b2
    
    m20 = -b1 
    m21 = -b2
    m22 = b3
    
    # Stack the results to form the output matrix of shape (batch_size, 3, 3)
    output = torch.stack([
        torch.stack([m00, m01, m02], dim=-1),
        torch.stack([m10, m11, m12], dim=-1),
        torch.stack([m20, m21, m22], dim=-1)
    ], dim=1)
    
    # Replace the output with identity matrix where the input is (0,0,1)
    output[identity_mask] = identity_matrix[identity_mask]
    output[anti_identity_mask] = anti_identity_matrix[anti_identity_mask]
    
    return output


In [93]:
# Example usage:
input_data_torch = torch.tensor([[1,2,3],[0,0,1],[0,0,-1],[1,0,0],[0,1,0],[1,1,0],[1,0,1],[0,1,1]])
A=transform_matrix_torch(input_data_torch)[0]
#print(transform_matrix_torch(input_data_torch))

tensor([False,  True, False, False, False, False, False, False]) tensor([False, False,  True, False, False, False, False, False])


In [80]:
rot = transform_matrix_torch(input_data_torch)

tensor([False,  True,  True, False, False, False, False, False]) tensor([False, False, False, False, False, False, False, False])


In [97]:

vectors = torch.tensor([0,np.sqrt(2)/2,np.sqrt(2)/2]).unsqueeze(0).repeat(8,1).to(torch.float32)
print(vectors.shape)

torch.Size([8, 3])


In [98]:
# Perform batch matrix-vector multiplication
result = torch.bmm(rot, vectors.unsqueeze(-1)).squeeze(-1)
print(result)

tensor([[ 0.1329,  0.9729,  0.1890],
        [ 0.0000,  0.7071,  0.7071],
        [ 0.0000,  0.7071,  0.7071],
        [ 0.7071,  0.7071,  0.0000],
        [ 0.0000,  0.7071, -0.7071],
        [ 0.1464,  0.8536, -0.5000],
        [ 0.5000,  0.7071,  0.5000],
        [ 0.0000,  1.0000,  0.0000]])


In [96]:
print(rot)

tensor([[[ 0.9604, -0.0793,  0.2673],
         [-0.0793,  0.8414,  0.5345],
         [-0.2673, -0.5345,  0.8018]],

        [[ 1.0000,  0.0000,  0.0000],
         [ 0.0000,  1.0000,  0.0000],
         [ 0.0000,  0.0000,  1.0000]],

        [[ 1.0000,  0.0000,  0.0000],
         [ 0.0000,  1.0000,  0.0000],
         [ 0.0000,  0.0000,  1.0000]],

        [[ 0.0000, -0.0000,  1.0000],
         [-0.0000,  1.0000,  0.0000],
         [-1.0000, -0.0000,  0.0000]],

        [[ 1.0000, -0.0000,  0.0000],
         [-0.0000,  0.0000,  1.0000],
         [-0.0000, -1.0000,  0.0000]],

        [[ 0.5000, -0.5000,  0.7071],
         [-0.5000,  0.5000,  0.7071],
         [-0.7071, -0.7071,  0.0000]],

        [[ 0.7071, -0.0000,  0.7071],
         [-0.0000,  1.0000,  0.0000],
         [-0.7071, -0.0000,  0.7071]],

        [[ 1.0000, -0.0000,  0.0000],
         [-0.0000,  0.7071,  0.7071],
         [-0.0000, -0.7071,  0.7071]]])


In [None]:
a = {'name':'Adam','age':10}
print('name' in a)