In [2]:
import pandas as pd
import collections
import numpy as np
import sys
import torch

sys.path.append("../")


from preprocessing.protein_chemistry import list_atoms,list_atoms_types,VanDerWaalsRadii,atom_mass,atom_type_to_index,atom_to_index,index_to_type,atom_type_mass
from preprocessing.protein_chemistry import residue_dictionary,hetresidue_field

from modeling.graph.frames import get_aa_frameCloud, get_atom_frameCloud
from preprocessing import sequence_utils



def binarize_categorical(matrix, n_classes, out=None):
    """One-hot encode a vector of integer class labels.

    Args:
        matrix: array of class labels; it is cast to ``int32`` before use.
        n_classes: number of columns of the one-hot output.
        out: optional pre-allocated ``(len(matrix), n_classes)`` array that is
            written in place. When omitted a boolean array of zeros is created.

    Returns:
        ``out`` with ``out[i, matrix[i]]`` set to 1 for every label inside
        ``[0, n_classes)``. Rows whose label is out of range are left untouched.
    """
    n_rows = matrix.shape[0]
    labels = matrix.astype(np.int32)
    if out is None:
        out = np.zeros([n_rows, n_classes], dtype=np.bool_)
    # Labels outside [0, n_classes) (e.g. -1 for "unknown") are skipped.
    in_range = ~((labels < 0) | (labels >= n_classes))
    row_ids = np.arange(n_rows)[in_range]
    col_ids = labels[in_range]
    out[row_ids, col_ids] = 1
    return out

def readData(file_path):
  """Parse the atom records of a PDB file into per-residue heavy-atom coordinates.

  Lines are decoded by fixed column position (the PDB ATOM/HETATM layout)
  rather than by whitespace splitting, so adjacent fields that touch each
  other (e.g. ``-100.123-200.456`` or chain ``A`` followed by residue 1000)
  are still read correctly. Every record that is not ATOM/HETATM (REMARK,
  MODEL, TER, END, blank lines, ...) is ignored, wherever it appears.

  Args:
    file_path: path of the PDB text file.

  Returns:
    ``collections.OrderedDict`` mapping ``"<resSeq>_<resName>"`` (e.g.
    ``"12_GLY"``) to an ``OrderedDict`` of ``atom name -> np.float64 array
    of shape (3,)``, in file order. Hydrogen atoms are dropped. If two
    residues produce the same key (e.g. same number and name on two
    chains) the later one replaces the earlier one.

  Raises:
    ValueError: if an atom record is too short to hold coordinates or the
      coordinate fields are not numeric.
  """
  amino_dict = collections.OrderedDict()
  atom_dict = collections.OrderedDict()
  current_residue = None  # (chain, resSeq, iCode) of the residue being filled
  current_key = None      # "<resSeq>_<resName>" key of that residue

  with open(file=file_path, mode='r') as pdb_file:
    for line_number, raw_line in enumerate(pdb_file, start=1):
      if raw_line[:6] not in ('ATOM  ', 'HETATM'):
        continue
      line = raw_line.rstrip('\r\n')
      if len(line) < 54:
        raise ValueError(
          'Truncated atom record at line %d of %s' % (line_number, file_path))

      atom_name = line[12:16].strip()
      residue_name = line[17:20].strip()
      residue_number = line[22:26].strip()
      residue_id = (line[21], residue_number, line[26])

      # A change of chain / residue number / insertion code closes the
      # residue that was being accumulated.
      if current_residue is not None and residue_id != current_residue:
        amino_dict[current_key] = atom_dict
        atom_dict = collections.OrderedDict()
      current_residue = residue_id
      current_key = residue_number + '_' + residue_name

      # Element symbol lives in columns 77-78; when the file omits it, fall
      # back to the first letter of the atom name (leading digits skipped).
      element = line[76:78].strip().upper()
      if not element:
        element = atom_name.lstrip('0123456789')[:1].upper()
      if element == 'H':
        continue

      atom_dict[atom_name] = np.array(
        [float(line[30:38]), float(line[38:46]), float(line[46:54])],
        dtype=np.float64)

  # Flush the last residue of the file.
  if current_residue is not None:
    amino_dict[current_key] = atom_dict
  return amino_dict

def processData(amino_dict):
  """Convert the per-residue atom dictionaries into sequence and index lists.

  Args:
    amino_dict: mapping ``"<resSeq>_<resName>" -> {atom name: coordinate}``
      as produced by ``readData``.

  Returns:
    Tuple ``(sequence, all_coordinates, all_atoms, all_atom_types)`` with
    one entry per residue, in iteration order of ``amino_dict``:
      * sequence: string built from ``residue_dictionary[resName]``;
      * all_coordinates: list of stacked coordinate arrays, one row per atom;
      * all_atoms: list of lists of ``atom_to_index[atom name]``;
      * all_atom_types: list of lists of ``atom_type_to_index`` looked up by
        the first character of each atom name.
  """
  residue_letters = []
  all_coordinates, all_atoms, all_atom_types = [], [], []

  for residue_key, atoms_by_name in amino_dict.items():
    _, residue_name = residue_key.split("_")
    residue_letters.append(residue_dictionary[residue_name])

    atom_names = list(atoms_by_name.keys())
    # List((3,)) ==> (atoms, 3)
    coordinates = np.stack(list(atoms_by_name.values()), axis=0)
    # (atoms,)
    atom_ids = [atom_to_index[name] for name in atom_names]
    atom_type_ids = [atom_type_to_index[name[0]] for name in atom_names]

    all_coordinates.append(coordinates)
    all_atoms.append(atom_ids)
    all_atom_types.append(atom_type_ids)

  return "".join(residue_letters), all_coordinates, all_atoms, all_atom_types
  

def getdData(file_paths):
  """Read and process several PDB files.

  Each path is parsed with ``readData`` and converted with ``processData``;
  the four per-file results are collected into four parallel lists.

  Args:
    file_paths: iterable of PDB file paths.

  Returns:
    Tuple ``(batch_sequences, batch_all_coordinates, batch_all_atoms,
    batch_all_atom_types)``; each list has one entry per input file.
  """
  batch_sequences, batch_all_coordinates = [], []
  batch_all_atoms, batch_all_atom_types = [], []
  collectors = (batch_sequences, batch_all_coordinates, batch_all_atoms, batch_all_atom_types)

  for pdb_path in file_paths:
    sequence, coordinates, atoms, atom_types = processData(readData(pdb_path))
    for collector, value in zip(collectors, (sequence, coordinates, atoms, atom_types)):
      collector.append(value)

  return collectors



# --- Load one structure and build residue-level / atom-level frame clouds ---
file_paths = ["../dataset/P44_relaxed_rank_002_alphafold2_ptm_model_2_seed_000.pdb",]
batch_sequences, batch_all_coordinates, batch_all_atoms, batch_all_atom_types = getdData(file_paths)

# Only the first (and here only) structure of the batch is used below.
sequence = batch_sequences[0]
all_coordinates, all_atoms = batch_all_coordinates[0], batch_all_atoms[0]

aa_clouds, aa_triplets, aa_indices = get_aa_frameCloud(all_coordinates, all_atoms)

# One-hot residue attributes. The class count comes from `nsequence_features`
# instead of a second hard-coded 20, so the two can no longer drift apart.
nsequence_features = 20
aa_attributes = binarize_categorical(
    sequence_utils.seq2num(sequence)[0], nsequence_features)


atom_clouds, atom_triplets, atom_attributes, atom_indices = get_atom_frameCloud(sequence, all_coordinates, all_atoms)


########################################
from modeling.graph.neighborhoods import FrameBuilder

# Add a leading batch axis of size 1 to every array handed to the model.
tensor_aa_clouds = torch.Tensor(aa_clouds).unsqueeze(0)
# Triplets are cast to integer (long) after the batch axis is added.
tensor_aa_triplets = torch.Tensor(aa_triplets).unsqueeze(0).long()

tensor_aa_indices = torch.Tensor(aa_indices).unsqueeze(0)

# `config` is also read by later cells when constructing LocalNeighborhood.
config = None
frame_builder = FrameBuilder(config)

inputs = [tensor_aa_clouds, tensor_aa_triplets]
frames = frame_builder(inputs)


In [None]:
# Sanity check: number of per-residue entries in the coordinate / atom lists,
# then the shape of each array next to its batched tensor counterpart.
print(len(all_coordinates), len(all_atoms))
print(aa_clouds.shape, tensor_aa_clouds.shape)
print(aa_indices.shape, tensor_aa_indices.shape)

In [14]:
# Inspect the two input-format lists of the LocalNeighborhood instance.
# NOTE: `local_neighborhood` is only created in the later cell that imports
# LocalNeighborhood; on a fresh kernel that cell has to run before this one.
local_neighborhood.first_format, local_neighborhood.second_format

(['frame', 'index'], ['frame', 'index'])

In [None]:
# Show the configured coordinates, the first input format, and the position of
# the 'index' entry inside it (list.index raises ValueError if it is absent).
# NOTE: depends on `local_neighborhood`, which is defined in a later cell.
local_neighborhood.coordinates, local_neighborhood.first_format, local_neighborhood.first_format.index('index')

In [4]:
########################################
from modeling.graph.neighborhoods import LocalNeighborhood

# Coordinate channels requested from the neighborhood layer.
coordinates=['euclidian', 'index_distance', 'ZdotZ', 'ZdotDelta']

local_neighborhood = LocalNeighborhood(config, Kmax=16, coordinates=coordinates, self_neighborhood=True, index_distance_max=8, nrotations=1)

# One-hot residue attributes with a leading batch axis of size 1.
tensor_aa_attributes = torch.Tensor(aa_attributes)
tensor_aa_attributes = tensor_aa_attributes.unsqueeze(0)

input2localneighborhood = [frames, tensor_aa_indices, tensor_aa_attributes]
output = local_neighborhood(input2localneighborhood)

# output[0] holds the coordinate tensors, output[1] the gathered attributes.
neighbor_coordinates, neighbors_attributes = output[0][0], output[1]

# The original cell had a bare `a.shape, b.shape` expression here; being in the
# middle of the cell it was never displayed, so it was dropped. The loop below
# prints the shape of every coordinate tensor instead.
for out in output[0]:
    print(out.shape)

torch.Size([1, 518, 16, 3])
torch.Size([1, 518, 16, 1])
torch.Size([1, 518, 16, 1])
torch.Size([1, 518, 16, 1])
torch.Size([1, 518, 16, 1])


In [None]:
# Shapes of the FrameBuilder output and of the batched one-hot attribute tensor.
frames.shape, tensor_aa_attributes.shape

In [None]:
# Shape of the batched one-hot attribute tensor (also shown by the cell above).
tensor_aa_attributes.shape

In [None]:
from modeling.graph.layers import Linear

# Build one project `Linear` block and display its repr.
# Presumably the first two arguments are (input width, output width), i.e.
# 20 one-hot features -> hidden_dim — confirm against modeling.graph.layers.
hidden_dim = 64
norm = "GN"
ng = 1
l1 = Linear(20, hidden_dim, norm=norm, ng=ng, act=False)
l1

In [None]:
# Flatten (batch, residues, 20) to (batch * residues, 20) for the layer, then
# restore a batch axis of size 1. The feature width is taken from the layer
# output itself rather than a hard-coded 64, so it follows `hidden_dim`.
x = l1(tensor_aa_attributes.reshape(-1, 20))
x = x.reshape(1, -1, x.shape[-1])

In [None]:
# Gather neighbor features for the projected attributes `x`.
# NOTE(review): the `local_neighborhood` built earlier requests
# 'index_distance' and its recorded first_format is ['frame', 'index'], yet no
# index tensor is passed here — confirm this call matches the instance in use.
input2localneighborhood = [frames, x]
output = local_neighborhood(input2localneighborhood)

# output[0][0] is the first coordinate tensor, output[1] the neighbor attributes.
neighbor_coordinates, neighbors_attributes = output[0][0], output[1]

In [None]:
# Shape of the gathered neighbor-attribute tensor.
neighbors_attributes.shape

In [None]:
# Compare index 0 and index 13 of the last axis, over the whole third axis, for
# element [0, 0]. Assumes axes are (batch, residue, neighbor, channel) — TODO confirm.
neighbors_attributes[0,0,:,0],neighbors_attributes[0,0,:,13]

In [None]:
# Sum over the second-to-last axis (presumably the neighbor axis — confirm).
x2= torch.sum(neighbors_attributes, -2)

In [None]:
# Scalar sum of last-axis index 13 over the third axis for element [0, 0].
torch.sum(neighbors_attributes[0,0,:,13])

In [None]:
import numpy as np
import os
import sys
# `fractions.gcd` was deprecated in Python 3.5 and removed in Python 3.9;
# `math.gcd` is the standard-library replacement.
from math import gcd
from numbers import Number

import torch
from torch import Tensor, nn
from torch.nn import functional as F

In [None]:
class MapNet(nn.Module):
  """Residual network that mixes each residue's features with its neighbors'.

  The input features are embedded to ``hidden_dim`` channels, then passed
  through ``num_layers`` residual blocks. Each block adds a linear
  projection of the node features to the sum of the projected features
  gathered by ``LocalNeighborhood``, followed by GroupNorm, ReLU, a project
  ``Linear`` block, the residual connection and a final ReLU.

  Args:
    hidden_dim: number of feature channels inside the network.
    config: forwarded as first argument to ``LocalNeighborhood``. Defaults
      to ``None``; it is now an explicit argument instead of being read
      from a notebook-level global.
    input_dim: width of the input features (default 20).
    num_layers: number of residual blocks (default 4).
  """

  def __init__(self, hidden_dim, config=None, input_dim=20, num_layers=4):
    super(MapNet, self).__init__()
    # Local import: `fractions.gcd` no longer exists on Python >= 3.9.
    from math import gcd

    norm = "GN"
    ng = 1

    self.input = nn.Sequential(
      nn.Linear(input_dim, hidden_dim),
      nn.ReLU(inplace=True),
      Linear(hidden_dim, hidden_dim, norm=norm, ng=ng, act=False),
    )
    self.relu = nn.ReLU(inplace=True)

    # A single neighborhood layer is shared by every block.
    self.local_neighborhood = LocalNeighborhood(
      config, Kmax=16, coordinates=['euclidian',],
      self_neighborhood=True, index_distance_max=8, nrotations=1)

    self.fuse = nn.ModuleList()
    self.edge = nn.ModuleList()
    self.norm = nn.ModuleList()
    self.ctr2 = nn.ModuleList()
    # Layers are created block by block (fuse, edge, norm, ctr2) so the
    # order of random weight initialisation matches the original code.
    for _ in range(num_layers):
      self.fuse.append(nn.Linear(hidden_dim, hidden_dim, bias=False))
      self.edge.append(nn.Linear(hidden_dim, hidden_dim, bias=False))
      self.norm.append(nn.GroupNorm(gcd(ng, hidden_dim), hidden_dim))
      self.ctr2.append(Linear(hidden_dim, hidden_dim, norm=norm, ng=ng, act=False))

  def forward(self, x, frames, ):
    """Run the network.

    Args:
      x: 3-D tensor; the first two axes are treated as (batch, sequence)
        and the last as the input feature axis.
      frames: frame tensor handed unchanged to ``LocalNeighborhood``.

    Returns:
      Tensor of shape ``(batch * sequence, hidden_dim)``.
    """
    bs_dim, seq_dim, _ = x.shape
    x = x.reshape(bs_dim * seq_dim, -1)
    x = self.input(x)  # (bs*seq, input_dim) => (bs*seq, hidden)
    x = self.relu(x)

    res = x  # (bs*seq, h)
    blocks = zip(self.fuse, self.edge, self.norm, self.ctr2)
    for fuse_layer, edge_layer, norm_layer, ctr2_layer in blocks:
      x_node = fuse_layer(x)
      # The neighborhood layer is fed a (bs, seq, h) tensor.
      x_edge = edge_layer(x).reshape(bs_dim, seq_dim, -1)

      output = self.local_neighborhood([frames, x_edge])
      # output[1]: gathered neighbor attributes; summed over axis -2.
      neighbors_attributes = output[1]
      aggregated = torch.sum(neighbors_attributes, -2)
      aggregated = aggregated.reshape(bs_dim * seq_dim, -1)

      x = x_node + aggregated
      x = norm_layer(x)
      x = self.relu(x)

      x = ctr2_layer(x)
      x += res
      x = self.relu(x)
      res = x

    return x

# Instantiate the network with 64 hidden channels.
net = MapNet(64)

In [None]:
# Forward pass on the one-hot residue attributes; display the output shape.
# Note: this rebinds `out`, which an earlier cell used as a loop variable.
out = net(tensor_aa_attributes, frames)
out.shape

In [37]:
# Scratch check: build a torch.Size from a concatenated list and read back the
# shape of a tensor created from it (recorded output: [2, 3, 1]).
torch.Tensor(torch.Size([2,3] + [1])).shape

torch.Size([2, 3, 1])

In [34]:
# Scratch check: torch.Size objects can be concatenated with `+`.
# Note: this rebinds the notebook-level `x`; `centers` is not used afterwards.
x = torch.rand(2,10,32,5)
centers = torch.rand(5,3)
x.shape[:-1] + torch.Size([5,3])


torch.Size([2, 10, 32, 5, 3])

In [39]:
# Scratch check of broadcasting: (2,10,6,1,5,32) * (1,5,5,32) gives
# (2,10,6,5,5,32) per the recorded output. Rebinds the notebook-level `x`.
x = torch.rand(2,10,6,1,5,32)
y = torch.rand(1,5,5,32)
(x*y).shape

torch.Size([2, 10, 6, 5, 5, 32])