In [1]:
import os
import random
from functools import wraps

import numpy as np
import torch

from cg_objs.cgs import *
from architecture.typical_embedding import *
from utils.file_and_folder_ops import remove_path_after_folder
from geom.transformations import sample_r3, sample_so3, zyz_euler_angles_from_rotation_matrix

In [2]:
# Reproducibility: seed every RNG this notebook visibly draws from.
SEED = 0
random.seed(SEED)        # random.choice picks the example residue below
torch.manual_seed(SEED)  # NOTE(review): assumes EmbeddingLayer init and sample_r3/sample_so3 use torch's RNG — confirm

# Embedding configuration (passed straight to EmbeddingLayer).
n_c = 64
l_max = 1
n_t = len(CG_ids)
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
# Use the l_max variable so the layer and the later (l_max + 1) ** 2 reshapes agree.
layer = EmbeddingLayer(n_c=n_c, l_max=l_max, n_t=n_t, device=device)

In [3]:
# Pick a random entry of CGs and push it through the embedding layer.
eg_aa = random.choice([*CGs])
print(eg_aa, CGs[eg_aa])
ej_ai = layer(CGs[eg_aa])
sl = ej_ai.size(0)  # number of rows of the embedding output (3 in the recorded run)
print(ej_ai.shape)

F (4, 11, 18)
torch.Size([3, 256])


In [4]:
# One random translation (R^3) and one random rotation (SO(3)) per embedding row.
tsj_ai, Rsj_ai = sample_r3(sl, device), sample_so3(sl, device)

In [5]:
# Show the sampled translations and rotations together.
(tsj_ai, Rsj_ai)

(tensor([[-1.5883,  0.8510, -0.7424],
         [ 0.2390,  1.8061, -0.8610],
         [-0.8601,  0.3638,  1.4979]]),
 tensor([[[ 0.4316,  0.3200, -0.8434],
          [ 0.6244,  0.5689,  0.5353],
          [ 0.6511, -0.7576,  0.0457]],
 
         [[-0.9537,  0.2811,  0.1070],
          [-0.1969, -0.3146, -0.9286],
          [-0.2274, -0.9067,  0.3553]],
 
         [[-0.1867,  0.8704, -0.4555],
          [-0.3583, -0.4920, -0.7934],
          [-0.9147,  0.0150,  0.4038]]]))

In [6]:
import os
import warnings

# Locate the repository's data directory from the notebook's working directory.
cwd = os.getcwd()
base = remove_path_after_folder(cwd, 'equifold')
DATA_PATH = os.path.join(base, 'data')

# Load the precomputed J matrices used by wigner_d_matrix (indexed by degree).
# Prefer the torch file; fall back to the NumPy file if it cannot be loaded.
pt_path = os.path.join(DATA_PATH, 'J_dense.pt')
npy_path = os.path.join(DATA_PATH, 'J_dense.npy')
try:
    path = pt_path
    # SECURITY: torch.load unpickles arbitrary objects — only load trusted files.
    Jd = torch.load(path)
except Exception as err:  # deliberate fallback; a bare except would also swallow KeyboardInterrupt
    warnings.warn(f"Could not load {pt_path} ({type(err).__name__}: {err}); falling back to {npy_path}")
    path = npy_path
    # SECURITY: allow_pickle=True can execute code embedded in the file — trusted data only.
    Jd_np = np.load(path, allow_pickle = True)
    Jd = list(map(torch.from_numpy, Jd_np))

In [7]:
# ZYZ Euler angles of every sampled rotation in Rsj_ai.
alpha_beta_gamma = zyz_euler_angles_from_rotation_matrix(Rsj_ai, device)
# Later cells index alpha[i], beta[i], gamma[i] per residue and nothing else defines
# these names, so unpack them here to keep the notebook runnable on a fresh kernel.
# NOTE(review): assumes the function returns (alpha, beta, gamma), each indexed by residue
# (a 3-tuple, or a tensor of shape (3, sl)). If it returns shape (sl, 3) instead, use
# alpha_beta_gamma.unbind(-1) — confirm against geom.transformations.
alpha, beta, gamma = alpha_beta_gamma

In [36]:
from functools import wraps

def cast_torch_tensor(fn):
    """Decorator: make sure the first argument of ``fn`` is a torch tensor.

    Non-tensor values are wrapped with ``torch.tensor`` using the current default
    floating dtype; tensors are passed through untouched. Any further positional or
    keyword arguments are forwarded unchanged.

    Returns:
        The wrapping function (with ``fn``'s metadata copied via ``functools.wraps``).
    """
    @wraps(fn)
    def inner(t, *args, **kwargs):
        if not torch.is_tensor(t):
            t = torch.tensor(t, dtype = torch.get_default_dtype())
        return fn(t, *args, **kwargs)
    # The wrapper must be returned; a bare `return` would replace the decorated name with None.
    return inner

def default(val, d):
    """Return ``val`` unless it is None, in which case return the fallback ``d``."""
    if exists(val):
        return val
    return d

def exists(val):
    """True iff ``val`` is not None (falsy values such as 0 or '' still count as present)."""
    return not (val is None)

def to_order(degree):
    """Side length 2 * degree + 1 of a degree-``degree`` Wigner-D / z-rotation matrix."""
    return degree * 2 + 1

def wigner_d_matrix(degree, alpha, beta, gamma, dtype = None, device = None):
    """Create the Wigner-D matrix of degree ``degree`` for one set of ZYZ Euler angles.

    Computed as ``z_rot_mat(alpha) @ J @ z_rot_mat(beta) @ J @ z_rot_mat(gamma)`` with
    ``J = Jd[degree]``.

    Args:
        degree: representation degree l; the result is (2l+1) x (2l+1).
        alpha, beta, gamma: scalar angles (Python numbers or 0-d tensors).
        dtype: dtype of the result. Defaults to ``alpha``'s dtype when it is a
            floating-point tensor, otherwise ``torch.get_default_dtype()``.
        device: device of the result. Defaults to ``alpha``'s device (CPU for numbers).

    Returns:
        Tensor of shape (2l+1, 2l+1).
    """
    # Put J and all three angles on one dtype/device so the matmul chain is valid.
    # NB: Tensor.type(None) returns a type *string*, so it cannot express "keep dtype".
    alpha = torch.as_tensor(alpha, device = device)
    if dtype is None:
        dtype = alpha.dtype if alpha.is_floating_point() else torch.get_default_dtype()
    device = alpha.device
    alpha, beta, gamma = (torch.as_tensor(a, dtype = dtype, device = device) for a in (alpha, beta, gamma))
    J = Jd[degree].to(device = device, dtype = dtype)
    order = to_order(degree)
    x_a = z_rot_mat(alpha, degree)
    x_b = z_rot_mat(beta, degree)
    x_c = z_rot_mat(gamma, degree)
    res = x_a @ J @ x_b @ J @ x_c
    return res.view(order, order)

def z_rot_mat(angle, l):
    """Degree-``l`` z-rotation matrix used by ``wigner_d_matrix``.

    Returns a (2l+1) x (2l+1) tensor (same dtype/device as ``angle``) holding
    ``cos(k * angle)`` on the main diagonal and ``sin(k * angle)`` on the
    anti-diagonal, where k runs l, l-1, ..., -l down the rows. All other entries
    are zero.
    """
    size = to_order(l)
    dev, dt = angle.device, angle.dtype
    rows = torch.arange(0, size, 1, dtype=torch.long, device=dev)
    anti_cols = torch.arange(2 * l, -1, -1, dtype=torch.long, device=dev)
    freqs = torch.arange(l, -l - 1, -1, dtype=dt, device=dev)[None]
    phases = freqs * angle[None]
    out = angle.new_zeros((size, size))
    # Write sin first, then cos: the centre entry (k = 0) lies on both diagonals and
    # must end up as cos(0) = 1.
    out[rows, anti_cols] = torch.sin(phases)
    out[rows, rows] = torch.cos(phases)
    return out


In [39]:
# Degree-1 Wigner-D matrix for residue 0's rotation. In the recorded run it equals
# Rsj_ai[0] with rows and columns reordered as (1, 2, 0).
wigner_d_matrix(degree=1, alpha=alpha[0], beta=beta[0], gamma=gamma[0],
                dtype=torch.float32, device=device)

tensor([[ 0.5689,  0.5353,  0.6244],
        [-0.7576,  0.0457,  0.6511],
        [ 0.3200, -0.8434,  0.4316]])

In [38]:
# Degree-1 Wigner-D matrix for residue 1's rotation (same (1, 2, 0) reordering of
# Rsj_ai[1] as seen for residue 0 in the recorded outputs).
wigner_d_matrix(degree=1, alpha=alpha[1], beta=beta[1], gamma=gamma[1],
                dtype=torch.float32, device=device)

tensor([[-0.3146, -0.9286, -0.1969],
        [-0.9067,  0.3553, -0.2274],
        [ 0.2811,  0.1070, -0.9537]])

In [40]:
def direct_sum(*matrices):
    r"""Block-diagonal direct sum of ``matrices`` over their last two dimensions.

    Leading (batch) dimensions are taken from the first matrix. Each input is copied
    into its own diagonal block of a zero tensor made with ``matrices[0].new_zeros``,
    so the result inherits the first matrix's dtype and device.
    """
    batch_dims = list(matrices[0].shape[:-2])
    rows = sum(mat.size(-2) for mat in matrices)
    cols = sum(mat.size(-1) for mat in matrices)
    result = matrices[0].new_zeros(batch_dims + [rows, cols])
    r0 = c0 = 0
    for mat in matrices:
        h, w = mat.shape[-2:]
        result[..., r0:r0 + h, c0:c0 + w] = mat
        r0, c0 = r0 + h, c0 + w
    return result

In [109]:
# Per-degree Wigner-D blocks for residue 0, in the same order as before
# (degree 1 first, then the trivial degree-0 block).
# NOTE(review): if ej_ai stores its degree-0 component before the degree-1 ones,
# this list should be reversed — confirm the layout produced by EmbeddingLayer.
D1 = wigner_d_matrix(1, alpha[0], beta[0], gamma[0], dtype = torch.float32, device = device)
# The degree-0 block is the 1x1 identity. Build it with D1's dtype/device: an int64
# CPU [[1]] would make direct_sum (which takes its dtype from matrices[0]) return an
# integer tensor, silently truncating D1, whenever it is listed first.
D0 = torch.ones(1, 1, dtype = D1.dtype, device = D1.device)
matrices = [D1, D0]
matrices

[tensor([[ 0.5689,  0.5353,  0.6244],
         [-0.7576,  0.0457,  0.6511],
         [ 0.3200, -0.8434,  0.4316]]),
 tensor([[1]])]

In [110]:
# Total number of rows of the block-diagonal matrix.
m = sum(block.shape[-2] for block in matrices)
m

4

In [111]:
# Total number of columns of the block-diagonal matrix.
n = sum(block.shape[-1] for block in matrices)
n

4

In [112]:
# Leading batch dims of the inputs (empty for plain 2-D matrices). Defined in this cell
# so it does not depend on a name that only a later cell creates.
front_indices = matrices[0].shape[:-2]
total_shape = list(front_indices) + [m, n]
total_shape

[4, 4]

In [113]:
# Block-diagonal rotation acting on all (l_max + 1) ** 2 components at once.
# Reuse the direct_sum helper instead of re-inlining its body here.
out = direct_sum(*matrices)
n_components = (l_max + 1) ** 2
assert out.shape == (n_components, n_components), (
    f"direct sum is {tuple(out.shape)}, expected {n_components}x{n_components}"
)

In [127]:
# Apply the block-diagonal rotation to residue 0's features, viewed as
# (n_c, (l_max + 1) ** 2) and transposed so the components run along the rows.
# NOTE(review): assumes each channel's (l_max + 1) ** 2 components are contiguous in
# ej_ai[0] — confirm against EmbeddingLayer's output layout.
torch.matmul(out, ej_ai[0].reshape(n_c, (l_max + 1) ** 2).t()).shape

torch.Size([4, 64])