In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import yaml
import torch
import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Make the project's `src` directory the working directory before importing the
# local packages (`data`, `notebooks`) in the following lines/cells.
# NOTE(review): hardcoded absolute cluster path -- the notebook only runs on a
# machine where this directory exists; consider making it configurable.
os.chdir('/data/rsg/nlp/sdobers/amine/diffdock-protein/src')

from data import load_data, get_data


In [4]:
from notebooks.utils_notebooks import Dict2Class

In [5]:
# Checkpoint directory: `args.yaml` is read from here and the same path is
# later assigned to `args.checkpoint_path`.
PATH = '/data/rsg/nlp/sdobers/amine/diffdock-protein/ckpts/dips_medium_model/'

In [6]:
# GPU index that is copied into `args.gpu` when the run arguments are patched.
# NOTE(review): despite the name, nothing visible in this notebook sets the
# CUDA_VISIBLE_DEVICES environment variable from this value.
CUDA_VISIBLE_DEVICE = 1

In [7]:
# Read the training-time arguments stored next to the checkpoint and wrap them
# so that they can be accessed as attributes.
args_yaml_path = os.path.join(PATH, 'args.yaml')
with open(args_yaml_path) as args_file:
    args = Dict2Class(yaml.safe_load(args_file))

# Overrides for this notebook session: single GPU, validation-only data file,
# checkpoint location and orientation features switched on.
args.checkpoint_path = PATH
args.gpu = CUDA_VISIBLE_DEVICE
args.num_gpu = 1
args.use_orientation_features = True
args.data_file = args.data_file.replace('data_file', 'data_file_only_val')

In [24]:
# Load the raw data using the patched `args`; the loaded entries are read from
# the returned object's `.data` attribute in the next cell.
dips_loader_instance = load_data(args)


data loading: 100%|█| 985/985 [00:00<00:00, 921157


11:27:50 Loaded cached ESM embeddings
11:27:50 finished tokenizing residues with ESM
11:27:50 finished tokenizing all inputs
11:27:50 985 entries loaded


In [25]:
# Loaded entries; the cells below iterate over `data.values()` and index each
# value with "graph" and 'ligand'.
data = dips_loader_instance.data

In [None]:
def compute_edge_feat_ori_feat(complex_graph, key):
    """Build the 12-dimensional orientation feature of every `key`-`key` edge.

    For each edge (src -> dst) the local frame of the destination node, whose
    rows are (n_i, u_i, v_i), is used as a 3x3 basis to express:

      * p_ij: the offset ``pos[src] - pos[dst]``
      * q_ij: the n vector of the source node
      * k_ij: the u vector of the source node
      * t_ij: the v vector of the source node

    Args:
        complex_graph: graph object for which ``complex_graph[key]`` exposes
            ``pos``, ``n_i_feat``, ``u_i_feat`` and ``v_i_feat`` (one row of 3
            values per node) and ``complex_graph[key, key].edge_index`` holds
            the (2, num_edges) source/destination indices.
            NOTE(review): ``pos`` is assumed to be the residue location that
            the earlier numpy draft called ``residue_representatives_loc_feat``
            -- confirm.
        key: node type to process, e.g. ``"ligand"`` or ``"receptor"``.

    Returns:
        float32 tensor of shape (num_edges, 12); each row is the concatenation
        ``[p_ij, q_ij, k_ij, t_ij]``. A graph without edges yields shape (0, 12).
    """
    src, dst = complex_graph[key, key].edge_index
    node_store = complex_graph[key]

    # Work in float32 throughout so that inputs of mixed precision (e.g.
    # float64 coordinates with float32 frames) can be multiplied together.
    pos = node_store.pos.to(torch.float32)
    n_i_feat = node_store.n_i_feat.to(torch.float32)
    u_i_feat = node_store.u_i_feat.to(torch.float32)
    v_i_feat = node_store.v_i_feat.to(torch.float32)

    # One 3x3 basis per edge with n_i, u_i, v_i of the destination node as rows.
    basis = torch.stack((n_i_feat[dst], u_i_feat[dst], v_i_feat[dst]), dim=1)

    def _to_local_frame(vectors):
        # Batched matrix-vector product: (E, 3, 3) @ (E, 3, 1) -> (E, 3)
        return torch.matmul(basis, vectors.unsqueeze(-1)).squeeze(-1)

    p_ij = _to_local_frame(pos[src] - pos[dst])
    q_ij = _to_local_frame(n_i_feat[src])
    k_ij = _to_local_frame(u_i_feat[src])
    t_ij = _to_local_frame(v_i_feat[src])

    return torch.cat((p_ij, q_ij, k_ij, t_ij), dim=1)  # shape (num_edges, 12)


In [32]:
# Peek at the first entry only: print its graph, the ligand n_i vectors and the
# source / destination indices of the ligand-ligand edges.
for item in list(data.values())[:1]:
    first_graph = item["graph"]
    print(first_graph)
    print(first_graph["ligand"].n_i_feat)
    src, dst = first_graph["ligand", "ligand"].edge_index
    print(src)
    print(dst)

HeteroData(
  name='nu/2nuu.pdb1_9.dill',
  center=[1, 3],
  [1mreceptor[0m={
    pos=[409, 3],
    x=[409, 1281],
    n_i_feat=[409, 3],
    u_i_feat=[409, 3],
    v_i_feat=[409, 3]
  },
  [1mligand[0m={
    pos=[112, 3],
    x=[112, 1281],
    n_i_feat=[112, 3],
    u_i_feat=[112, 3],
    v_i_feat=[112, 3]
  },
  [1m(receptor, contact, receptor)[0m={ edge_index=[2, 8180] },
  [1m(ligand, contact, ligand)[0m={ edge_index=[2, 2240] }
)
tensor([[-0.7323,  0.6726, -0.1068],
        [ 0.7958, -0.5957,  0.1093],
        [-0.9843, -0.0075, -0.1764],
        [ 0.9216,  0.1597,  0.3538],
        [-0.8532, -0.2970, -0.4288],
        [ 0.7782,  0.4185,  0.4683],
        [-0.5723, -0.6468, -0.5042],
        [ 0.6834,  0.6877,  0.2449],
        [ 0.8122, -0.5833,  0.0121],
        [-0.2689, -0.3585,  0.8940],
        [-0.0857, -0.9942, -0.0655],
        [-0.7687, -0.1191, -0.6284],
        [-0.0693,  0.3524,  0.9333],
        [ 0.0722, -0.8384,  0.5402],
        [-0.7932, -0.5983, -0.1136

In [14]:
# Sanity check: 409 receptor nodes x 20 equals the 8180 receptor-receptor edges
# reported in the printed graph (presumably 20 neighbours per residue -- TODO
# confirm against the graph construction code).
409*20

8180

In [73]:
def compute_orientation_vectors(c_alpha_coords, n_coords, c_coords):
    """Compute the local frame (n_i, u_i, v_i) of every residue from its backbone.

    For each residue (row) the vectors are:

      * u_i = unit(N  - CA)
      * t_i = unit(C  - CA)
      * n_i = unit(u_i x t_i)
      * v_i = n_i x u_i

    The computation is vectorized over residues, so an empty input yields three
    empty (0, 3) tensors instead of failing.

    Args:
        c_alpha_coords: tensor of shape (num_residues, 3) with C-alpha positions.
        n_coords: tensor of the same shape with backbone N positions.
        c_coords: tensor of the same shape with backbone C positions.

    Returns:
        Tuple ``(n_i_feat, u_i_feat, v_i_feat)``, each of shape (num_residues, 3).

    Raises:
        AssertionError: if the input shapes differ, or if any v_i is not of
            unit length (which also triggers for degenerate residues whose
            backbone atoms coincide or are collinear, since those produce NaN).
    """
    assert c_alpha_coords.shape == n_coords.shape == c_coords.shape, \
        "compute_orientation_vectors: CA, N and C coordinates must have the same shape"

    ca_to_n = n_coords - c_alpha_coords
    ca_to_c = c_coords - c_alpha_coords
    u_i_feat = ca_to_n / torch.linalg.vector_norm(ca_to_n, dim=-1, keepdim=True)
    t_i_feat = ca_to_c / torch.linalg.vector_norm(ca_to_c, dim=-1, keepdim=True)

    # Normal of the N-CA-C plane; computed once and then normalised.
    plane_normal = torch.linalg.cross(u_i_feat, t_i_feat)
    n_i_feat = plane_normal / torch.linalg.vector_norm(plane_normal, dim=-1, keepdim=True)

    # n_i and u_i are orthonormal, so their cross product must be a unit vector.
    v_i_feat = torch.linalg.cross(n_i_feat, u_i_feat)
    v_norm_error = torch.abs(torch.linalg.vector_norm(v_i_feat, dim=-1) - 1.)
    assert torch.all(v_norm_error < 1e-5), \
        "protein utils protein_to_graph_dips, v_i is not unit-norm"

    assert n_i_feat.shape == u_i_feat.shape == v_i_feat.shape

    return n_i_feat, u_i_feat, v_i_feat



In [77]:
# Smoke test: rebuild the orientation frames of every ligand from its backbone
# atoms. The asserts here and inside compute_orientation_vectors flag entries
# whose CA / N / C counts differ or whose frames are degenerate.
for item in tqdm.tqdm(data.values()):
    all_res, all_atom, all_pos = item['ligand']

    # Collect the coordinates of the backbone atoms; the first field of each
    # atom record is matched against the atom names "CA", "N" and "C".
    c_alpha_coords = []
    n_coords = []
    c_coords = []
    for i, a in enumerate(all_atom):
        if a[0] == "CA":
            c_alpha_coords.append(all_pos[i])
        elif a[0] == "N":
            n_coords.append(all_pos[i])
        elif a[0] == "C":
            c_coords.append(all_pos[i])

    c_alpha_coords = torch.stack(c_alpha_coords)
    n_coords = torch.stack(n_coords)
    c_coords = torch.stack(c_coords)

    # Every residue has to contribute exactly one CA, one N and one C atom.
    assert c_alpha_coords.shape == n_coords.shape == c_coords.shape

    # Create orientation vectors
    n_i_feat, u_i_feat, v_i_feat = compute_orientation_vectors(c_alpha_coords, n_coords, c_coords)

100%|██████████████████████████████████████████████████████████████| 985/985 [00:09<00:00, 104.71it/s]


In [45]:
# Scratch: pick two coordinate rows (indices 4 and 20) from `all_pos`, i.e. from
# whichever entry the backbone loop cell last left in scope, to try torch.stack.
l = [all_pos[4], all_pos[20]]

In [46]:
# Display the two picked coordinate tensors.
l

[tensor([ 36.6910,  26.7940, -15.2200], dtype=torch.float64),
 tensor([ 33.1740,  19.8830, -10.5650], dtype=torch.float64)]

In [67]:
# Stack the list of 1-D coordinate tensors into a single 2-D tensor (one row per element).
st = torch.stack(l)

In [70]:
# Number of stacked rows -- one per list element.
st.shape[0]

2

In [72]:
# Shape of a single stacked row (one xyz coordinate).
st[1].shape

torch.Size([3])