In [1]:
import sys
if '/mnt/c/Users/mm851/PycharmProjects/ProteinLearning' not in sys.path:
    sys.path = ['/mnt/c/Users/mm851/PycharmProjects/ProteinLearning'] + sys.path

In [2]:
from protein_learning.protein_utils.align.kabsch_align import kabsch_align
from protein_learning.protein_utils.align.per_residue import get_per_res_alignment
import torch
import numpy as np

# Test kabsch alignment

In [3]:
def rotation(a, b, c, device = 'cpu'):
    cos, sin = torch.cos, torch.sin
    sina, sinb, sinc = sin(a), sin(b), sin(c)
    cosa, cosb, cosc = cos(a), cos(b), cos(c)
    r = [torch.cat([cosc * cosb - sinc * sina * sinb,
          -sinc * cosa,
          cosc * sinb + sinc * sina * cosb]),
         torch.cat([sinc * cosb + cosc * sina * sinb,
          cosc * cosa,
          sinc * sinb - cosc * sina * cosb]),
         torch.cat([-cosa * sinb, sina, cosa * cosb])]
    return torch.vstack(r)

def random_rotation(n = 1):
    return torch.cat([rotation(*torch.randn(3,1)).unsqueeze(0) for _ in range(n)], dim=0)

### simplest case - both sets of coordinates are the same

In [4]:
b,n,a,c = 1,4,16,3
align_to = torch.randn(n,a,c)
align_from = torch.clone(align_to)
aligned_to, aligned_from = kabsch_align(align_to = align_to, align_from=align_from)
# align_to coords should not change
assert torch.allclose(align_to,aligned_to)
#both sets of coordinates should be the same
assert torch.allclose(aligned_to,aligned_from,atol=1e-5), f"norm : {torch.norm(aligned_to-aligned_from)}"

### apply rotation

In [5]:
b,n,a,c = 1,4,16,3
rotn = random_rotation()
align_to = torch.randn(n,a,c)
align_from = torch.einsum("bij,nai->naj",rotn,align_to)
aligned_to, aligned_from = kabsch_align(align_to = align_to, align_from=align_from)
# align_to coords should not change
print("norm(align_to-align_from) :",torch.norm(align_to-align_from))
print("norm(align_to-aligned_to) :",torch.norm(aligned_to-align_to))
print("norm(aligned_to-aligned_from) :",torch.norm(aligned_to-aligned_from))
print("norm(aligned_from-align_from) :",torch.norm(aligned_from - align_from))

assert torch.allclose(align_to,aligned_to)
#both sets of coordinates should be the same
assert torch.allclose(aligned_to,aligned_from,atol=1e-5), f"norm : {torch.norm(aligned_to-aligned_from)}"


norm(align_to-align_from) : tensor(11.4945)
norm(align_to-aligned_to) : tensor(0.)
norm(aligned_to-aligned_from) : tensor(3.7099e-06)
norm(aligned_from-align_from) : tensor(11.4945)


### apply rotation and translation

In [6]:
b,n,a,c = 1,4,16,3
rotn = random_rotation()
align_to = torch.randn(n,a,c)
align_from = torch.einsum("bij,nai->naj",rotn,align_to) + torch.randn(1,1,3)
aligned_to, aligned_from = kabsch_align(align_to = align_to, align_from=align_from)
# align_to coords should not change
print("norm(align_to-align_from) :",torch.norm(align_to-align_from))
print("norm(align_to-aligned_to) :",torch.norm(aligned_to-align_to))
print("norm(aligned_to-aligned_from) :",torch.norm(aligned_to-aligned_from))
print("norm(aligned_from-align_from) :",torch.norm(aligned_from - align_from))

assert torch.allclose(align_to,aligned_to)
#both sets of coordinates should be the same
assert torch.allclose(aligned_to,aligned_from,atol=1e-5), f"norm : {torch.norm(aligned_to-aligned_from)}"


norm(align_to-align_from) : tensor(21.9438)
norm(align_to-aligned_to) : tensor(0.)
norm(aligned_to-aligned_from) : tensor(3.0045e-06)
norm(aligned_from-align_from) : tensor(21.9438)


# Test Per-Residue Alignment

In [16]:
b,n,a,c = 1,16,6,3
rotn = random_rotation(n=n)
align_to = torch.randn(n,a,c)
align_from = torch.einsum("nij,nai->naj",rotn,align_to) + torch.randn(n,1,3)
aligned_from = get_per_res_alignment(align_to = align_to[:,:3,:], align_from=align_from[:,:3,:], to_align=align_from)
aligned_to = align_to
# align_to coords should not change
print("norm(align_to-align_from) :",torch.norm(align_to-align_from))
print("norm(align_to-aligned_to) :",torch.norm(aligned_to-align_to))
print("norm(aligned_to-aligned_from) :",torch.norm(aligned_to-aligned_from))

assert torch.allclose(align_to,aligned_to)
#both sets of coordinates should be the same
assert torch.allclose(aligned_to,aligned_from,atol=1e-5), f"norm : {torch.norm(aligned_to-aligned_from)}"


norm(align_to-align_from) : tensor(23.9387)
norm(align_to-aligned_to) : tensor(0.)
norm(aligned_to-aligned_from) : tensor(5.0102e-06)
