In [1]:
import numpy
import torch
import torch as th
import torch.nn.functional as F

In [2]:
# Adapted from the Structured Transformer (StructTrans) reference code:
# https://github.com/jingraham/neurips19-graph-protein-design
def nan_to_num(tensor: torch.Tensor, nan=0.0) -> torch.Tensor:
    """Return a copy of ``tensor`` with every NaN entry replaced by ``nan``.

    Only NaNs are replaced; +/-inf values are left untouched (unlike
    ``torch.nan_to_num``, which also clamps infinities).

    The input is NOT modified. An in-place ``masked_fill_`` would clobber the
    caller's tensor and raises a RuntimeError on leaf tensors that require grad.

    Args:
        tensor: input tensor of any shape.
        nan: value written in place of each NaN.

    Returns:
        A new tensor with the same shape and dtype as ``tensor``.
    """
    return tensor.masked_fill(torch.isnan(tensor), nan)
     
 
def _normalize(tensor: torch.Tensor, dim=-1) -> torch.Tensor:
    """Divide ``tensor`` by its L2 norm along ``dim``.

    Uses ``F.normalize`` with its default ``eps``. Any NaN entries left in
    the result are then replaced with 0 by ``nan_to_num``.
    """
    unit = F.normalize(tensor, p=2.0, dim=dim)
    return nan_to_num(unit)


In [3]:
def _quaternions(R: th.Tensor) -> th.Tensor:
    """Convert rotation matrices to unit quaternions.

    Args:
        R: tensor of shape ``(..., 3, 3)``. Presumably proper rotation
            matrices (orthonormal, det +1); this is not checked here.

    Returns:
        Tensor of shape ``(..., 4)`` ordered ``(x, y, z, w)``. It is
        L2-normalised along the last dim, with NaNs replaced by 0.
    """
    # Flatten the trailing 3x3 into 9 entries, named row-major:
    # ii = R[..., 0, 0], ij = R[..., 0, 1], ..., kk = R[..., 2, 2].
    ii, ij, ik, ji, jj, jk, ki, kj, kk = th.unbind(R.reshape(*R.shape[:-2], -1), -1)

    # |x|, |y|, |z| from the diagonal: |q_a| = 0.5 * sqrt(|1 + 2 R_aa - trace|).
    proto_q = th.stack((
        +ii - jj - kk,
        -ii + jj - kk,
        -ii - jj + kk,
    ), dim=-1)
    # Out-of-place scaling on purpose: sqrt saves its output for backward.
    # An in-place ``.mul_`` on that output made ``.backward()`` raise a
    # "modified by an inplace operation" RuntimeError.
    magnitudes = 0.5 * (proto_q + 1.0).abs().sqrt()

    # Signs of x, y, z come from the antisymmetric part of R (w is kept >= 0).
    signs = th.stack([kj - jk, ik - ki, ji - ij], -1).sign()
    xyz = signs * magnitudes
    w = 0.5 * F.relu(1 + ii + jj + kk).sqrt()
    Q = th.cat((xyz, w[..., None]), -1)

    return _normalize(Q, dim=-1)


In [5]:
def quaternions(R):
    """Convert rotation matrices to unit quaternions ordered ``(x, y, z, w)``.

    Args:
        R: tensor of shape ``(..., 3, 3)`` with any number of leading batch
            dimensions. Presumably proper rotation matrices; this is not
            checked here.

    Returns:
        Tensor of shape ``(..., 4)``, L2-normalised along the last dim with
        NaNs replaced by 0.
    """
    diag = torch.diagonal(R, dim1=-2, dim2=-1)
    Rxx, Ryy, Rzz = diag.unbind(-1)

    # |x|, |y|, |z| from the diagonal: |q_a| = 0.5 * sqrt(|1 + 2 R_aa - trace|).
    magnitudes = 0.5 * torch.sqrt(torch.abs(1 + torch.stack([
          Rxx - Ryy - Rzz,
        - Rxx + Ryy - Rzz,
        - Rxx - Ryy + Rzz
    ], -1)))

    def _R(i, j):
        # Ellipsis indexing works for any batch rank. The former
        # R[:, :, :, i, j] only worked for exactly 3 leading dims.
        return R[..., i, j]

    # Signs of x, y, z come from the antisymmetric part of R (w is kept >= 0).
    signs = torch.sign(torch.stack([
        _R(2, 1) - _R(1, 2),
        _R(0, 2) - _R(2, 0),
        _R(1, 0) - _R(0, 1)
    ], -1))

    xyz = signs * magnitudes
    w = torch.sqrt(F.relu(1 + diag.sum(-1, keepdim=True))) / 2.
    Q = torch.cat((xyz, w), -1)

    return _normalize(Q, dim=-1)
