In [None]:
import sys; sys.path.append("..")
import Bio.PDB.PDBParser
from Bio.PDB.Polypeptide import protein_letters_3to1

import numpy as np
import torch

import open3d as o3d

In [None]:
def extract_point_clouds_N_Ca_C_O(pdb_filename: str, cath_id: str):
  """Extract N, CA, C and O atom coordinates from a single-model PDB file.

  Args:
    pdb_filename: Path to the PDB file to parse.
    cath_id: Identifier given to the parsed Bio.PDB structure.

  Returns:
    A list of four float32 tensors ``[N, CA, C, O]``, each of shape (k, 3),
    holding the coordinates of every atom with that (stripped) name, in the
    order Bio.PDB iterates residues and atoms. An atom name that never occurs
    yields a (0, 3) tensor.

  Raises:
    AssertionError: if the structure does not contain exactly one model.
  """
  # QUIET=True suppresses Bio.PDB's construction warnings for this parser only,
  # instead of installing a process-wide "ignore all warnings" filter.
  pdb_parser = Bio.PDB.PDBParser(QUIET=True)
  structure = pdb_parser.get_structure(cath_id, pdb_filename)

  # Expect only one model per structure.
  assert len(structure) == 1, (
    f"expected exactly 1 model in {pdb_filename}, found {len(structure)}"
  )

  # Atom name -> collected coordinates; atoms with any other name are ignored.
  coords = {"N": [], "CA": [], "C": [], "O": []}

  # NOTE(review): get_residues() also yields hetero residues (waters, ions,
  # ligands) if the file contains any, and their atoms named N/CA/C/O are
  # collected too (e.g. water oxygens, calcium ions named CA). Filter on
  # residue.id[0] == " " if only standard residues are wanted.
  for residue in structure.get_residues():
    for atom in residue.get_atoms():
      bucket = coords.get(atom.get_fullname().strip())
      if bucket is not None:
        bucket.append(atom.get_coord())

  # Stack through numpy: building a tensor from a Python list of ndarrays is
  # slow, and the explicit reshape keeps empty results shaped (0, 3).
  return [
    torch.from_numpy(np.asarray(coords[name], dtype=np.float32).reshape(-1, 3))
    for name in ("N", "CA", "C", "O")
  ]


def center_and_scale_unit_sphere(points: torch.Tensor) -> torch.Tensor:
  """
  Centers points on their centroid and rescales by the bounding-box diagonal.

  Each point p is mapped to ``2 * (p - mean) / diag + 0.5``, where ``diag`` is
  the Euclidean length of the axis-aligned bounding-box diagonal.

  NOTE(review): despite the name, the result is not confined to a unit sphere
  around the origin. Before the shift, every point is within distance 2 of the
  centroid, and the final ``+ 0.5`` then offsets every coordinate. Confirm the
  intended range before relying on it.

  Args:
    points: (N, D) tensor with one point per row; assumed N >= 1.

  Returns:
    Tensor with the same shape as ``points``.
  """
  mu = points.mean(dim=0)
  extent = points.amax(dim=0) - points.amin(dim=0)
  diag = extent.norm()
  # A single point (or all-identical points) has a zero-length diagonal. The
  # centered coordinates are then all zero, so any non-zero divisor gives the
  # same answer and avoids 0/0 = NaN.
  if diag == 0:
    diag = torch.ones_like(diag)
  return 2 * (points - mu) / diag + 0.5


def center_and_scale_unit_box(points: torch.Tensor) -> torch.Tensor:
  """
  Scales all of the points uniformly so that they are in the range [0, 1].

  The cloud is centered on the midpoint of its axis-aligned bounding box,
  multiplied by 1 / (largest bounding-box side + 1), and shifted by 0.5. The
  same factor is used on every axis, so the shape is preserved, and every
  coordinate lands strictly inside (0, 1).

  Args:
    points: (N, D) tensor with one point per row; assumed N >= 1.

  Returns:
    Tensor with the same shape as ``points``.
  """
  # Min and max corners of the bounding box.
  vmax = points.amax(dim=0)
  vmin = points.amin(dim=0)
  # The +1 keeps the largest side strictly below 1 after scaling and avoids
  # dividing by zero for degenerate clouds. It is in the input's units
  # (presumably Angstrom for PDB coordinates), so it pads small clouds more.
  sf = 1.0 / (vmax - vmin + 1).max()
  # Center on the bbox midpoint, not the centroid. A centroid can sit far from
  # the middle of a skewed cloud and would push points outside [0, 1].
  center = (vmax + vmin) / 2
  return (points - center) * sf + 0.5


def create_occupancy_grid(points: torch.Tensor, G: int = 100) -> torch.Tensor:
  """Create a 3D occupancy grid from a collection of points.

  The points are voxelized with Open3D using a voxel edge length of 1/G, and
  each voxel Open3D reports adds 1 to its cell of a (G, G, G) tensor.

  NOTE(review): confirm how Open3D chooses the voxel-grid origin (likely from
  the cloud's own bounds rather than (0, 0, 0)). This matters when grids built
  from different clouds are expected to line up.

  Args:
    points: (N, 3) coordinates (tensor or array-like), presumably already
      normalized into the unit box, e.g. by ``center_and_scale_unit_box``.
    G: Number of voxels along each axis.

  Returns:
    Float tensor of shape (G, G, G). All zeros when ``points`` is empty.

  Raises:
    IndexError: if a voxel index falls outside the G^3 grid.
  """
  O = torch.zeros((G, G, G))

  # Open3D's Vector3dVector takes a float64 (N, 3) array. The reshape also
  # turns an empty 1-D input into shape (0, 3).
  xyz = np.asarray(points, dtype=np.float64).reshape(-1, 3)
  if len(xyz) == 0:
    # np.stack([]) would raise; an empty cloud simply occupies nothing.
    return O

  pcd = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(xyz))
  voxel_grid = o3d.geometry.VoxelGrid.create_from_point_cloud(pcd, voxel_size=(1 / G))
  indices = np.stack([vx.grid_index for vx in voxel_grid.get_voxels()])

  if indices.min() < 0 or indices.max() >= G:
    raise IndexError(
      f"voxel index out of range for a {G}^3 grid "
      f"(min {indices.min()}, max {indices.max()}); "
      "the points probably span more than the unit box"
    )

  # One vectorized scatter instead of a Python loop. accumulate=True keeps
  # the per-voxel "+= 1" semantics even if an index were ever repeated.
  idx = torch.as_tensor(indices, dtype=torch.long)
  O.index_put_((idx[:, 0], idx[:, 1], idx[:, 2]), torch.ones(len(idx)), accumulate=True)
  return O
  

In [None]:
# Load the backbone atoms of one CATH domain. The structure id passed to
# Bio.PDB is derived from the same id as the file path, so they cannot drift.
cath_id = "16pkA02"
pdb_filename = f"../data/pdb_share/{cath_id}.pdb"
points_N, points_Ca, points_C, points_O = extract_point_clouds_N_Ca_C_O(pdb_filename, cath_id)

In [None]:
# Normalize each backbone atom cloud into the unit box, then voxelize N.
# NOTE(review): each cloud gets its own center and scale factor, so the four
# normalized clouds no longer share a coordinate frame. Normalize their
# concatenation instead if they must stay aligned (e.g. as grid channels).
points_N_cent, points_C_cent, points_Ca_cent, points_O_cent = (
  center_and_scale_unit_box(cloud)
  for cloud in (points_N, points_C, points_Ca, points_O)
)

O_N = create_occupancy_grid(points_N_cent, G=512)

In [None]:
# Per-axis maxima (values and argmax indices) of the normalized N cloud.
torch.max(points_N_cent, dim=0)

In [None]:
# Visualize the points.
import numpy as np
import matplotlib.pyplot as plt

def plot_point_cloud(points_C, points_N, points_Ca, points_O):
  """Scatter-plot the four backbone atom clouds on one 3D axes.

  Args:
    points_C, points_N, points_Ca, points_O: (k, 3) coordinate tensors or
      arrays, one per atom type; columns are x, y, z.

  Returns:
    The matplotlib 3D axes holding the plot.
  """
  fig = plt.figure()
  ax = fig.add_subplot(111, projection='3d')

  # The 4th positional argument of Axes3D.scatter is `zdir`, not a colour, so
  # colours must be passed by keyword. C and O previously both used 'red'; C
  # now has its own colour so the atom types stay distinguishable.
  clouds = [
    (points_C, 'blue', 'C'),
    (points_N, 'green', 'N'),
    (points_Ca, 'gray', 'CA'),
    (points_O, 'red', 'O'),
  ]
  for pts, color, label in clouds:
    ax.scatter(pts[:, 0], pts[:, 1], pts[:, 2], color=color, label=label)

  ax.set_xlabel('x')
  ax.set_ylabel('y')
  ax.set_zlabel('z')
  ax.set_title('Backbone atom point clouds')
  ax.legend()
  return ax

plot_point_cloud(points_C, points_N, points_Ca, points_O);

In [None]:
# Visualize the occupancy grid.
import numpy as np
import matplotlib.pyplot as plt
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')

# Use the resolution O_N was actually built with (G=512 above) rather than a
# separately hard-coded value, so voxel indices map back into the unit box.
G = O_N.shape[0]
points = O_N.nonzero() / G

xline = points[:,0]
yline = points[:,1]
zline = points[:,2]
# Colour must be passed by keyword: the 4th positional arg of Axes3D.scatter is `zdir`.
ax.scatter(xline, yline, zline, color='gray')
ax.set_title('Occupied voxels of O_N');

In [None]:
v = o3d.utility.Vector3dVector(points_Ca_cent)
pcd = o3d.geometry.PointCloud(v)

# Voxelize the normalized C-alpha cloud step by step, essentially the steps of
# create_occupancy_grid inlined, so `indices` and `O` can be inspected below.
grid_dim = 100
voxel_grid = o3d.geometry.VoxelGrid.create_from_point_cloud(pcd, voxel_size=1/grid_dim)
voxels = voxel_grid.get_voxels()  # list of voxels
indices = np.stack([voxel.grid_index for voxel in voxels])

O = torch.zeros((grid_dim, grid_dim, grid_dim))
for idx in indices:
  O[idx[0], idx[1], idx[2]] += 1

In [None]:
# Total of all counts accumulated in the C-alpha occupancy grid O.
torch.sum(O)

In [None]:
# Voxel grid indices from the C-alpha voxelization above, one row per voxel.
indices

In [None]:
O = torch.zeros(3, 3, 3)

# A single 2-D integer index tensor selects along dim 0 only: each entry picks
# a whole (3, 3) slab, giving shape (3, 3, 3, 3) rather than three voxels.
# To address voxels by (i, j, k) rows, use O[tuple(rows.T)] instead.
# NOTE(review): this cell rebinds the global `O`, replacing the occupancy grid
# built earlier; cells that read `O` afterwards see this 3x3x3 tensor.
O[torch.tensor([
  [0, 0, 1],
  [0, 0, 0],
  [0, 0, 2],
], dtype=torch.long)]

In [None]:
# Shape of the current global O.
O.size()

In [None]:
# Shape after stacking two copies of O along a new trailing (channel) axis.
torch.stack([O, O], dim=-1).shape

In [None]:
import sys; sys.path.append("..")
from gvpgnn.datasets import ProteinVoxelDataset

In [None]:
# Build the project's voxel dataset over the challenge test set with
# voxel_grid_dim=256.
# NOTE(review): the meaning of the positional `None` and "cpu" arguments is
# defined in gvpgnn.datasets (not visible here); presumably an optional
# setting and the target device. Confirm, and consider passing them as keywords.
d = ProteinVoxelDataset("../data/challenge_test_set/", None, "cpu", voxel_grid_dim=256)

In [None]:
# Iterate over every dataset item and print its position and "name" field.
# NOTE(review): assumes each item is dict-like with a "name" key, and that
# iterating may build every sample's voxel grid (cost unknown from here).
# `i` and `data` stay bound as notebook globals after the loop.
for i, data in enumerate(d):
  print(i, data["name"])