In [32]:
import torch
from torch import Tensor
from typing import Tuple, Union, List

# Bounds used to clamp cosines before torch.acos, so rounding error cannot push
# them outside [-1, 1] (where acos would return NaN).
# NOTE(review): 1 - 1e-9 is not representable in float32 (it rounds to 1.0),
# so for float32 tensors these bounds presumably act as exactly +/-1; only
# float64 inputs are kept strictly inside (-1, 1).
cos_max = 1 - 1e-9
cos_min = -cos_max

In [89]:
def signed_dihedral_all_12(ps, *masks):
    """
    Signed dihedral angle for every pairing of a (p0, p1) bond with a (p2, p3) bond.

    Returns a tensor ``out`` of shape ``(n, m)`` with

        out[i, j] = dihedral(p0[i], p1[i], p2[j], p3[j])

    in radians, as produced by ``torch.atan2`` (range [-pi, pi]).
    Central-bond vectors ``p2[j] - p1[i]`` whose norm is below 1e-7 are
    divided by 1 instead of by their norm.

    :param ps: indexable source of points; without masks p0..p3 are
        ps[0]..ps[3]. p0/p1 are (n, 3) and p2/p3 are (m, 3) tensors.
    :param masks: optional; when given, p0..p3 are ps[masks[0]]..ps[masks[3]].
    :return: (n, m) tensor of signed dihedral angles.
    """
    if not masks:
        p0, p1, p2, p3 = ps[0], ps[1], ps[2], ps[3]
    else:
        p0, p1, p2, p3 = ps[masks[0]], ps[masks[1]], ps[masks[2]], ps[masks[3]]
    b0 = (p0 - p1).unsqueeze(1)              # (n, 1, 3): first bond, indexed by i
    b1 = p2.unsqueeze(0) - p1.unsqueeze(1)   # (n, m, 3): central bond p1[i] -> p2[j]
    b2 = (p3 - p2).unsqueeze(0)              # (1, m, 3): last bond, indexed by j
    nrm = torch.norm(b1, dim=-1, keepdim=True)
    # Replace near-zero norms by 1 out-of-place. Writing into `nrm` or `b1`
    # in place would corrupt tensors autograd saved for the norm's backward.
    nrm = torch.where(nrm < 1e-7, torch.ones_like(nrm), nrm)
    b1 = b1 / nrm
    # Project the outer bonds onto the plane orthogonal to the central bond.
    v = b0 - torch.sum(b0 * b1, dim=-1, keepdim=True) * b1
    w = b2 - torch.sum(b2 * b1, dim=-1, keepdim=True) * b1
    x = torch.sum(v * w, dim=-1)
    # Explicit dim: without it torch.cross uses the FIRST size-3 axis, which
    # is wrong whenever n == 3 or m == 3.
    y = torch.sum(torch.cross(b1, v, dim=-1) * w, dim=-1)
    return torch.atan2(y, x)

In [183]:
def signed_dihedral_all_12_batched(ps, eps=None):
    """
    Batched, broadcasting version of ``signed_dihedral_all_12``.

    Returns a tensor ``out`` of shape ``(..., n, m)`` with

        out[..., i, j] = dihedral(p0[..., i, :], p1[..., i, :],
                                  p2[..., j, :], p3[..., j, :])

    in radians, as produced by ``torch.atan2`` (range [-pi, pi]).

    :param ps: sequence of four tensors p0, p1, p2, p3; p0/p1 are (..., n, 3)
        and p2/p3 are (..., m, 3).
    :param eps: lower bound applied to the central-bond norm before
        normalising. None (default) uses the module-level ``min_norm_clamp``.
    :return: (..., n, m) tensor of signed dihedral angles.
    """
    if eps is None:
        eps = min_norm_clamp
    p0, p1, p2, p3 = ps
    b0 = (p0 - p1).unsqueeze(-2)                 # (..., n, 1, 3)
    b1 = p2.unsqueeze(-3) - p1.unsqueeze(-2)     # (..., n, m, 3)
    b2 = (p3 - p2).unsqueeze(-3)                 # (..., 1, m, 3)
    b1 = b1 / torch.norm(b1, dim=-1, keepdim=True).clamp_min(eps)
    # Project the outer bonds onto the plane orthogonal to the central bond.
    v = b0 - torch.sum(b0 * b1, dim=-1, keepdim=True) * b1
    w = b2 - torch.sum(b2 * b1, dim=-1, keepdim=True) * b1
    x = torch.sum(v * w, dim=-1)
    # Explicit dim: without it torch.cross uses the FIRST size-3 axis, which
    # is wrong whenever a batch or point axis has size 3.
    y = torch.sum(torch.cross(b1, v, dim=-1) * w, dim=-1)
    return torch.atan2(y, x)

In [184]:
def signed_dihedral_all_123(ps, *masks):
    """
    Signed dihedral angle of each (p0, p1, p2) triple with every p3 point.

    Returns a tensor ``out`` of shape ``(n, m)`` with

        out[i, j] = dihedral(p0[i], p1[i], p2[i], p3[j])

    in radians, as produced by ``torch.atan2`` (range [-pi, pi]).
    Central-bond vectors ``p2[i] - p1[i]`` with zero norm are left unscaled.

    :param ps: indexable source of points; without masks p0..p3 are
        ps[0]..ps[3]. p0/p1/p2 are (n, 3) and p3 is (m, 3).
    :param masks: optional; when given, p0..p3 are ps[masks[0]]..ps[masks[3]].
    :return: (n, m) tensor of signed dihedral angles.
    """
    if not masks:
        p0, p1, p2, p3 = ps[0], ps[1], ps[2], ps[3]
    else:
        p0, p1, p2, p3 = ps[masks[0]], ps[masks[1]], ps[masks[2]], ps[masks[3]]
    b0 = p0 - p1                                # (n, 3)
    b1 = p2 - p1                                # (n, 3)
    b2 = p3.unsqueeze(0) - p2.unsqueeze(1)      # (n, m, 3): p2[i] -> p3[j]
    nrm = torch.norm(b1, dim=-1, keepdim=True)
    # Only non-zero norms divide. This is done out-of-place because an in-place
    # update of `b1` would corrupt the tensor autograd saved for the norm.
    b1 = b1 / torch.where(nrm > 0, nrm, torch.ones_like(nrm))
    v = b0 - torch.sum(b0 * b1, dim=-1, keepdim=True) * b1          # (n, 3)
    b1_pair = b1.unsqueeze(1)                                        # (n, 1, 3)
    w = b2 - torch.sum(b2 * b1_pair, dim=-1, keepdim=True) * b1_pair  # (n, m, 3)
    x = torch.sum(v.unsqueeze(1) * w, dim=-1)
    # Explicit dim: without it torch.cross uses the FIRST size-3 axis, which
    # is wrong whenever n == 3.
    y = torch.sum(torch.cross(b1, v, dim=-1).unsqueeze(1) * w, dim=-1)
    return torch.atan2(y, x)

In [185]:
# Lower bound applied to vector norms before dividing by them, so coincident
# points do not cause a division by zero.
min_norm_clamp = 1e-7


def signed_dihedral_all_123_batched(ps, eps=None):
    """
    Batched, broadcasting version of ``signed_dihedral_all_123``.

    Returns a tensor ``out`` of shape ``(..., n, m)`` with

        out[..., i, j] = dihedral(p0[..., i, :], p1[..., i, :],
                                  p2[..., i, :], p3[..., j, :])

    in radians, as produced by ``torch.atan2`` (range [-pi, pi]).

    :param ps: sequence of four tensors p0, p1, p2, p3; p0/p1/p2 are
        (..., n, 3) and p3 is (..., m, 3).
    :param eps: lower bound applied to the central-bond norm before
        normalising. None (default) uses the module-level ``min_norm_clamp``.
    :return: (..., n, m) tensor of signed dihedral angles.
    """
    if eps is None:
        eps = min_norm_clamp
    p0, p1, p2, p3 = ps
    b0 = p0 - p1                                  # (..., n, 3)
    b1 = p2 - p1                                  # (..., n, 3)
    b2 = p3.unsqueeze(-3) - p2.unsqueeze(-2)      # (..., n, m, 3)
    b1 = b1 / torch.norm(b1, dim=-1, keepdim=True).clamp_min(eps)
    v = b0 - torch.sum(b0 * b1, dim=-1, keepdim=True) * b1
    b1_pair = b1.unsqueeze(-2)                    # (..., n, 1, 3)
    w = b2 - torch.sum(b2 * b1_pair, dim=-1, keepdim=True) * b1_pair
    x = torch.sum(v.unsqueeze(-2) * w, dim=-1)
    # Explicit dim: without it torch.cross uses the FIRST size-3 axis, which
    # is wrong whenever a batch or point axis has size 3.
    y = torch.sum(torch.cross(b1, v, dim=-1).unsqueeze(-2) * w, dim=-1)
    return torch.atan2(y, x)

In [197]:
def unsigned_angle_all(ps, *masks):
    """
    All-pairs bond angle at vertex p1.

    With p0, p1, p2 taken from ps (see below), returns a tensor M of shape
    (n, m) where

        b01 = p0[i] - p1[i]
        b21 = p2[j] - p1[i]
        M[i, j] = arccos(dot(b01, b21) / (||b01|| * ||b21||))

    i.e. the angle at p1[i] between the segments p1[i]-p0[i] and p1[i]-p2[j],
    in radians. When either segment has zero length the denominator is
    replaced by 1, so the cosine is 0 and the angle is pi/2. The cosine is
    clamped to [cos_min, cos_max] before arccos.

    :param ps: indexable source of points; without masks p0, p1, p2 are
        ps[0], ps[1], ps[2], i.e. (n, d), (n, d) and (m, d) tensors.
    :param masks: optional; when given, p0, p1, p2 are ps[masks[0]],
        ps[masks[1]], ps[masks[2]].
    :return: (n, m) tensor of angles.
    """
    if masks:
        head, vertex, tail = (ps[masks[k]] for k in range(3))
    else:
        head, vertex, tail = (ps[k] for k in range(3))

    arm_head = head - vertex                              # (n, d)
    arm_tail = tail.unsqueeze(0) - vertex.unsqueeze(1)    # (n, m, d)
    dots = (arm_head.unsqueeze(1) * arm_tail).sum(dim=2)
    denom = torch.norm(arm_head, dim=1).unsqueeze(1) * torch.norm(arm_tail, dim=2)
    # Degenerate (zero-length) segments: divide by 1 instead of 0.
    denom = torch.where(denom == 0, torch.ones_like(denom), denom)
    return torch.acos(torch.clamp(dots / denom, cos_min, cos_max))

In [198]:
def unsigned_angle_all_batched(ps, *masks):
    """
    Batched, broadcasting all-pairs bond angle at vertex p1.

    Returns a tensor M of shape (..., n, m) where

        b01 = p0[..., i, :] - p1[..., i, :]
        b21 = p2[..., j, :] - p1[..., i, :]
        M[..., i, j] = arccos(dot(b01, b21) / (||b01|| * ||b21||))

    i.e. the angle at p1[i] between the segments p1[i]-p0[i] and p1[i]-p2[j],
    in radians. The norm product is clamped below by ``min_norm_clamp``, and
    the cosine is clamped to [cos_min, cos_max] before arccos.

    :param ps: indexable source of points; without masks p0, p1, p2 are
        ps[0], ps[1], ps[2], i.e. (..., n, d), (..., n, d) and (..., m, d) tensors.
    :param masks: optional; when given, p0, p1, p2 are ps[masks[0]],
        ps[masks[1]], ps[masks[2]] (the same convention as
        ``unsigned_angle_all``). Previously these were silently ignored.
    :return: (..., n, m) tensor of angles.
    """
    if masks:
        p0, p1, p2 = ps[masks[0]], ps[masks[1]], ps[masks[2]]
    else:
        p0, p1, p2 = ps[0], ps[1], ps[2]
    b01 = p0 - p1                                   # (..., n, d)
    b21 = p2.unsqueeze(-3) - p1.unsqueeze(-2)       # (..., n, m, d)
    dots = torch.sum(b01.unsqueeze(-2) * b21, dim=-1)
    n01 = torch.norm(b01, dim=-1, keepdim=True)     # (..., n, 1)
    n21 = torch.norm(b21, dim=-1)                   # (..., n, m)
    prods = torch.clamp_min(n01 * n21, min_norm_clamp)
    cos_theta = torch.clamp(dots / prods, cos_min, cos_max)
    return torch.acos(cos_theta)

In [199]:
import torch

# Fixed seed so the random test points (and the printed diagnostics that
# follow) are reproducible on Restart & Run All.
torch.manual_seed(0)
b, n = 2, 20
ps = [torch.randn(b, n, 3) for _ in range(4)]

# (unbatched, batched) implementations that should agree. Select one pair by
# key; previously three pairs were assigned in sequence and only the last
# assignment took effect.
fn_pairs = {
    "dihedral_123": (signed_dihedral_all_123, signed_dihedral_all_123_batched),
    "dihedral_12": (signed_dihedral_all_12, signed_dihedral_all_12_batched),
    "angle": (unsigned_angle_all, unsigned_angle_all_batched),
}
unbatched_fn, batched_fn = fn_pairs["angle"]

In [200]:
batched_out = batched_fn(ps)
# Reference: the unbatched implementation applied to one batch element at a time.
unbatched_out = torch.stack([unbatched_fn([p[i] for p in ps]) for i in range(b)], dim=0)
# The batched implementation fed the same per-element inputs (no leading batch dim).
unbatched_batched_out = torch.stack([batched_fn([p[i] for p in ps]) for i in range(b)], dim=0)

print(f"batched : {batched_out.shape}, unbatched: {unbatched_out.shape}")
print(torch.norm(batched_out[0]-unbatched_out[0]))
print(torch.norm(batched_out-unbatched_out))
print(torch.norm(unbatched_batched_out-unbatched_out))

# Fail loudly on a mismatch instead of relying on eyeballing the printed norms.
torch.testing.assert_close(batched_out, unbatched_out)
torch.testing.assert_close(unbatched_batched_out, unbatched_out)
print(torch.norm(unbatched_batched_out-unbatched_out))

batched : torch.Size([2, 20, 20]), unbatched: torch.Size([2, 20, 20])
tensor(0.)
tensor(0.)
tensor(0.)


In [194]:
# Exploratory comparison of the two batched dihedral functions with reordered
# points. NOTE(review): presumably checking whether one can be expressed via
# the other; the recorded output (non-zero norm) shows they do not agree.
#   o1[..., i, j] uses points (p0[i], p1[i], p2[j], p3[j])  -- split two/two
#   o2[..., i, j] uses points (p0[i], p3[i], p2[i], p1[j])  -- three follow i
# Because the index structures differ, equality is not expected in general.
p0,p1,p2,p3 = ps
o1 = signed_dihedral_all_12_batched(ps)
o2 = signed_dihedral_all_123_batched([p0, p3, p2, p1])
torch.norm(o1-o2)

tensor(56.3383)