In [None]:
import torch

from rnampnn.utils.data import analyse_dataset

# Print/compute summary statistics for the dataset via the project helper.
# NOTE(review): called with no arguments, so it relies entirely on the helper's
# own defaults (presumably a default data directory) — confirm in rnampnn.utils.data.
# `import torch` above is not used by this cell.
analyse_dataset()

In [None]:
from rnampnn.utils.data import RNADataModule

# Build the project's data module with its default configuration and prepare
# the splits used for training (stage="fit").
# NOTE(review): `data` is reused by the next cell (data.train_dataloader()).
data = RNADataModule()
data.setup(stage="fit")

In [None]:
# Grab the training dataloader and print just its first batch as a quick
# sanity check of what the collate step produces; the loop exits after one
# batch.
train = data.train_dataloader()
for first_batch in train:
    print(first_batch)
    break

In [None]:
from rnampnn.utils.data import RNADataset

# Dataset root, relative to the notebook's working directory. Kept as a named
# constant so it can be changed in one place instead of editing the call.
DATA_DIR = "data/"
dataset = RNADataset.from_path(DATA_DIR)


In [None]:
# Report the dataset size before and after slice augmentation.
# NOTE(review): this appears to mutate `dataset` in place (its length is
# compared across the call), so re-running this cell presumably augments
# again — it is not idempotent. Re-run the loading cell first if needed.
size_before = len(dataset)
print(size_before)
dataset.slice_augmentation(num_gen=100)
size_after = len(dataset)
print(size_after)

In [None]:
import torch
def dist(X, mask, eps=1E-6, top_k=5, verbose=False):
    """Find the ``top_k`` nearest neighbours of every point, ignoring padding.

    Args:
        X: Coordinates indexed as (batch, length, coord); pairwise squared
            differences are summed over the coordinate axis.
        mask: (batch, length) tensor, expected to be 1 for real points and
            0 for padding.
        eps: Added under the square root, which is why a point's distance to
            itself comes out as ``sqrt(eps)`` instead of 0 (presumably to keep
            the sqrt differentiable at zero).
        top_k: Number of neighbours to return, clipped to ``length``.
            Defaults to 5, the value this helper previously hard-coded.
        verbose: If True, print the intermediate tensors/shapes that this
            helper used to print unconditionally.

    Returns:
        (D_neighbors, E_idx): distances and indices of the
        ``min(top_k, length)`` closest points in each row, in ascending order.
        Every pair involving a padded point is pushed strictly above the
        largest value in its row, so padded neighbours always sort last.
    """
    mask_2D = torch.unsqueeze(mask, 1) * torch.unsqueeze(mask, 2)
    dX = torch.unsqueeze(X, 1) - torch.unsqueeze(X, 2)
    dX = torch.sum(dX ** 2, 3)
    # Padded pairs get a large constant instead of their meaningless distance.
    D = (1. - mask_2D) * 10000 + mask_2D * torch.sqrt(dX + eps)
    if verbose:
        print("D:", D)

    # Adding the row maximum (+1) on top of the padded entries guarantees they
    # exceed every real distance in the row, even for large coordinates.
    D_max, _ = torch.max(D, -1, keepdim=True)
    if verbose:
        print("D_max:", D_max.shape)
    D_adjust = D + (1. - mask_2D) * (D_max + 1)
    if verbose:
        print("D_adjust:", D_adjust, D_adjust.shape)

    k = min(top_k, D_adjust.shape[-1])
    D_neighbors, E_idx = torch.topk(D_adjust, k, dim=-1, largest=False)
    if verbose:
        print("D_neighbors:", D_neighbors.shape)
        print("E_idx:", E_idx.shape)
    return D_neighbors, E_idx

# Seed so the random coordinates (and thus the printed neighbours) are
# reproducible under Restart & Run All.
torch.manual_seed(0)

# 10 real points followed by 5 zero-padded ones; keep only slot 0 along dim 2
# so X becomes (batch=1, length=15, 3).
X = torch.cat((torch.rand(1, 10, 6, 3), torch.zeros(1, 5, 6, 3)), dim=1)
X = X[:, :, 0, :]
print(X.shape)
mask = torch.cat((torch.ones(1, 10), torch.zeros(1, 5)), dim=1)
D_neighbors, E_idx = dist(X, mask)
print(D_neighbors, E_idx)

In [2]:
from rnampnn.model.rnampnn import AtomFeature, AtomMPNN, AtomPooling
import torch

# Smoke test: push a random, partially padded batch through AtomFeature ->
# AtomMPNN -> AtomPooling and print the pooled result.
# Seed first so both the random inputs and the layers' initial weights are
# reproducible under Restart & Run All.
torch.manual_seed(0)

NUM_RES, NUM_PAD = 32, 4   # real / padded entries along dim 1
NUM_ATOM_SLOTS = 7         # size of dim 2 of X (presumably atoms per residue)
ATOM_HIDDEN_DIM = 32       # passed to both AtomMPNN and AtomPooling; keep in sync
RAW_FEATURE_DIM = 8

atom_feature = AtomFeature(num_atom_neighbour=30)
X = torch.cat((torch.rand(1, NUM_RES, NUM_ATOM_SLOTS, 3),
               torch.zeros(1, NUM_PAD, NUM_ATOM_SLOTS, 3)), dim=1)
mask = torch.cat((torch.ones(1, NUM_RES), torch.zeros(1, NUM_PAD)), dim=1)
atom_coords, atom_mask, encode, dist_neighbors, edge_index = atom_feature(X, mask)

atom_mpnn = AtomMPNN(atom_hidden_dim=ATOM_HIDDEN_DIM, num_layers=2)
updated_encode, _, _, _ = atom_mpnn(encode, atom_mask, dist_neighbors, edge_index)

# raw_feature must have the same real/padded layout along dim 1 as X.
raw_feature = torch.cat((torch.rand(1, NUM_RES, RAW_FEATURE_DIM),
                         torch.zeros(1, NUM_PAD, RAW_FEATURE_DIM)), dim=1)
atom_pooling = AtomPooling(raw_feature_dim=RAW_FEATURE_DIM,
                           atom_hidden_dim=ATOM_HIDDEN_DIM,
                           num_layers=2)
pooling = atom_pooling(updated_encode, atom_mask, raw_feature)
print(pooling)

tensor([[[-0.0383, -0.0812,  0.0097,  ...,  0.1096, -0.0273,  0.0492],
         [-0.0463, -0.0595, -0.0170,  ...,  0.0239, -0.0100, -0.0092],
         [-0.0438, -0.0757, -0.0109,  ...,  0.1274, -0.0072,  0.0564],
         ...,
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]]],
       grad_fn=<SumBackward1>)


In [2]:
from rnampnn.model.rnampnn import ResFeature
import torch

# Smoke test of ResFeature's internal helpers on a random, partially padded
# batch. Seed so the printed tensors are reproducible.
torch.manual_seed(0)

NUM_RES, NUM_PAD, NUM_ATOM_SLOTS = 32, 4, 7   # real / padded / size of dim 2
X = torch.cat((torch.rand(1, NUM_RES, NUM_ATOM_SLOTS, 3),
               torch.zeros(1, NUM_PAD, NUM_ATOM_SLOTS, 3)), dim=1)
mask = torch.cat((torch.ones(1, NUM_RES), torch.zeros(1, NUM_PAD)), dim=1)

# NOTE(review): num_neighbour=40 is larger than the 36 entries in this batch;
# the 1e6 distances / -1 indices in the saved output suggest padded or unused
# neighbour slots are filled with sentinels — confirm in ResFeature.
res_feature = ResFeature(num_neighbour=40)

# These are underscore-prefixed (private) helpers, called directly here for
# inspection; the public forward path may combine them differently.
dist_neighbour, edge_index = res_feature._get_res_graph(X, mask)
print('dist_neighbour: ', dist_neighbour)
print('edge_index: ', edge_index)

cos = res_feature._inside_angles(X, mask)
print('cos: ', cos)

# Named inside_dist rather than `dist` so it does not rebind the dist()
# helper function defined in an earlier cell of this notebook.
inside_dist = res_feature._inside_dist(X, mask)
print('inside_dist: ', inside_dist)

dist_neighbour:  tensor([[[1.0000e-03, 1.0366e-01, 1.1080e-01,  ..., 1.0000e+06,
          1.0000e+06, 1.0000e+06],
         [1.0000e-03, 8.0849e-02, 8.7057e-02,  ..., 1.0000e+06,
          1.0000e+06, 1.0000e+06],
         [1.0000e-03, 4.7565e-02, 9.6630e-02,  ..., 1.0000e+06,
          1.0000e+06, 1.0000e+06],
         ...,
         [1.0000e+06, 1.0000e+06, 1.0000e+06,  ..., 1.0000e+06,
          1.0000e+06, 1.0000e+06],
         [1.0000e+06, 1.0000e+06, 1.0000e+06,  ..., 1.0000e+06,
          1.0000e+06, 1.0000e+06],
         [1.0000e+06, 1.0000e+06, 1.0000e+06,  ..., 1.0000e+06,
          1.0000e+06, 1.0000e+06]]])
edge_index:  tensor([[[ 0,  8, 27,  ..., -1, -1, -1],
         [ 1, 26, 10,  ..., -1, -1, -1],
         [ 2, 10,  3,  ..., -1, -1, -1],
         ...,
         [-1, -1, -1,  ..., -1, -1, -1],
         [-1, -1, -1,  ..., -1, -1, -1],
         [-1, -1, -1,  ..., -1, -1, -1]]])
cos:  tensor([[[-0.5537,  0.1459, -0.8908, -0.0627],
         [-0.7664, -0.5576, -0.8504, -0.7487]