In [1]:
from typing import Tuple
import numpy as np
from scipy import linalg, io
import torch
from torch.utils.data import Dataset

In [2]:
CHARGES_LIST_QM9 = [1, 6, 7, 8, 9]
CHARGES_LIST_QM7 = [1, 6, 7, 8, 16]

class PointCloudMoleculeDataSet(Dataset):
    def __init__(self, coords_cart: np.ndarray, charges: np.ndarray, energies: np.ndarray) -> None:
        """
        coords_cart has shape (n_samples, max_n_atoms, 3)
        charges has shape (n_samples, max_n_atoms)
        energies has shape (n_samples,)
        """
        # print(charges.shape)
        self._coords_cart = coords_cart
        self._charges = charges
        self.n_samples = self._coords_cart.shape[0]
        self.n_atoms = np.sum(charges != 0, axis=1)
        # print(self.n_atoms.shape)
        self.energies = energies
        self.coords_aligned = None
        self.one_hot_point_features = None
        self.U_matrices = None

    def __len__(self) -> int:
        return self.n_samples

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
        coords_out = self.coords_aligned[index, :self.n_atoms[index]]
        charge_features_out = self.one_hot_point_features[index, :self.n_atoms[index]]
        energies_out = self.energies[index]
        return (coords_out, charge_features_out, energies_out)

    def align_coords_cart(self) -> None:
        out = np.full_like(self._coords_cart, np.nan)
        out_U_mats = np.zeros((self.n_samples, 3, 3))

        for i in range(self.n_samples):
            n_atoms_i = self.n_atoms[i]
            coords_i = self._coords_cart[i, :n_atoms_i]
            coords_i = coords_i - np.mean(coords_i, axis=0)
            U, _, _ = linalg.svd(coords_i.transpose(), full_matrices=False)
            coords_aligned = np.matmul(U.transpose(), coords_i.transpose()).transpose()
            out[i, :n_atoms_i] = coords_aligned
            out_U_mats[i] = U

        self.coords_aligned = torch.Tensor(out)
        self.U_matrices = torch.Tensor(out_U_mats)
        

    def charges_to_one_hot_QM7(self) -> None:
        out = np.full((self.n_samples, 
                        self._charges.shape[1], 
                        len(CHARGES_LIST_QM7)), np.nan)
        charges_lst = CHARGES_LIST_QM7
        charges_lst_arr = np.array(CHARGES_LIST_QM7)
        for i in range(self.n_samples):
            n_atoms_i = self.n_atoms[i]
            out[i, :n_atoms_i] = np.zeros_like(out[i, :n_atoms_i])
            charges_i = self._charges[i, :n_atoms_i]
            col_idxes = [charges_lst.index(x) for x in charges_i]
            for atom_idx, charge_col_idx in enumerate(col_idxes):
                out[i, atom_idx, charge_col_idx] = 1.
        self.one_hot_point_features = torch.Tensor(out)

    def charges_to_one_hot_QM9(self) -> None:
        pass


In [3]:
def test_charges_to_one_hot() -> None:
    
    coords = np.full((2, 5, 3), np.nan)
    coords[0, :4] = np.random.normal(size=(4, 3))
    coords[1, :3] = np.random.normal(size=(3, 3))
    charges = np.array([[1, 1, 6, 6, 0],
                        [7, 7, 1, 0, 0]])
    
    energies = np.random.normal(size=2)
    
    x = PointCloudMoleculeDataSet(coords, charges, energies)
    
    x.align_coords_cart()
    x.charges_to_one_hot_QM7()
    
    expected_one_hot_encoding = np.full((charges.shape[0], charges.shape[1], len(CHARGES_LIST_QM7)), np.nan)
    expected_one_hot_encoding[0, :4] = np.array([[1, 0, 0, 0, 0],
                                                [1, 0, 0, 0, 0],
                                                [0, 1, 0, 0, 0],
                                                [0, 1, 0, 0, 0]])
    expected_one_hot_encoding[1, :3] = np.array([[0, 0, 1, 0, 0],
                                                [0, 0, 1, 0, 0],
                                                [1, 0, 0, 0, 0]])
    assert np.allclose(x.one_hot_point_features, 
                       expected_one_hot_encoding, equal_nan=True), "{}, {}".format(x.one_hot_point_features[1].numpy(),
                                                                expected_one_hot_encoding[1])
    
test_charges_to_one_hot()

In [32]:
def farthest_point_sample(xyz: torch.Tensor, npoint: int) -> torch.Tensor:
    """
    Uses a farthest point sampling scheme to downsample the point cloud
    Input:
        xyz (torch.Tensor): pointcloud data, has shape [B, N, 3]
        npoint: number of samples
    Return:
        centroids: sampled pointcloud index, which has shape [B, npoint]
    """
    print(xyz)
    device = xyz.device
    B, N, C = xyz.shape
    centroids = torch.zeros(B, npoint, dtype=torch.long).to(device)
    distance = torch.ones(B, N).to(device) * 1e10
    # List of random integers in [0, N] with length B
#     farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device) 
    
    # Prevent the initial choice from being a NaNed Out Row
    n_points_in_cloud = torch.sum(torch.logical_not(torch.isnan(xyz[:, :, 0])), axis=1)
#     print("N_POINTS", n_points_in_cloud)
    rand_draws = torch.rand(size=(B,)).to(device)
#     print("RAND DRAWS", rand_draws)
    scaled_rand_draws = torch.mul(n_points_in_cloud, rand_draws)
#     print("SCALED R", scaled_rand_draws)
    farthest = torch.floor(scaled_rand_draws).type(torch.long)
#     print("FURTHEST", farthest)
#     print([x.data for x in n_points_in_cloud])
#     farthest = torch.Tensor([torch.randint(0, x.data, dtype=torch.long) for x in n_points_in_cloud]).to(device)
    
    batch_indices = torch.arange(B, dtype=torch.long).to(device)
    for i in range(npoint):
        # Start with the random indices in the 0th column.
        centroids[:, i] = farthest
        
        # The centroid thing has one xyz location for each element in the batch
        centroid = xyz[batch_indices, farthest, :].view(B, 1, 3)
        
        # Dist finds distance squared between each XYZ point and the centroid
        dist = torch.sum((xyz - centroid) ** 2, -1)
#         print("DIST", dist)
        mask = dist < distance
        mask += torch.isnan(dist)
        distance[mask] = torch.nan_to_num(dist[mask]) # Caps distances at 1e10
#         print("DISTANCE", distance)
        farthest = torch.max(distance, -1)[1]
    return centroids

In [40]:
def square_distance(src, dst):
    """
    Calculate Euclid distance between each two points.

    src^T * dst = xn * xm + yn * ym + zn * zm；
    sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn;
    sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm;
    dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2
         = sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst

    Input:
        src: source points, [B, N, C]
        dst: target points, [B, M, C]
    Output:
        dist: per-point square distance, [B, N, M]
    """
    B, N, _ = src.shape
    _, M, _ = dst.shape
    dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
    dist += torch.sum(src ** 2, -1).view(B, N, 1)
    dist += torch.sum(dst ** 2, -1).view(B, 1, M)
    return dist

def query_ball_point(radius: float, 
                        nsample: int, 
                        xyz: torch.Tensor, 
                        query_centroids: torch.Tensor) -> torch.Tensor:
    """
    Input:
        radius (float): local region radius
        nsample (int): max sample number in local region
        xyz (torch.Tensor): all points, [B, N, 3]
        query_centroids (torch.Tensor): query points, [B, S, 3]
    Return:
        group_idx: grouped points index, [B, S, nsample]
    """
    device = xyz.device
    B, N, C = xyz.shape
    _, S, _ = query_centroids.shape
    group_idx = torch.arange(N, 
                             dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1]) # Shape [B, S, N]
    sqrdists = square_distance(query_centroids, xyz) # Shape [B, S, N]
    group_idx[sqrdists > radius ** 2] = N
    group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample]
    group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample])
    mask = group_idx == N
    group_idx[mask] = group_first[mask]
    return group_idx

In [49]:
def test_query_ball_point_simple_input() -> None:
    
    in_xyz = torch.Tensor([[[0, 1, 0],[1, 0, 0]]])
    in_new_xyz = torch.Tensor([[[0, 1, 0]]])
    
    out = query_ball_point(1., 2, in_xyz, in_new_xyz)
    print(out)
    
test_query_ball_point_simple_input()

tensor([[[0, 0]]])


In [33]:
in_xyz_arr = np.full((2, 7, 3), np.nan)
in_xyz_arr[0, :6] = np.random.normal(size=(6, 3))
in_xyz_arr[1, :5] = np.random.normal(size=(5, 3))

in_tensor = torch.Tensor(in_xyz_arr)
# in_t = torch.Tensor(in_arr)

In [36]:
def test_FPS_simple_input_with_nans() -> None:
    in_arr = np.array([[[0, 0, 1], # 0
                       [0, 0, 1.1], # 1
                       [1, 0, 0], # 2
                       [np.nan, np.nan, np.nan]]]) # 3
    in_tensor = torch.Tensor(in_arr)
    print(in_tensor.shape)
    out = farthest_point_sample(in_tensor, 2)
    print("OUT", out)
    if out[0, 0] in [0, 1]:
        assert out[0, 1] == 2
    else:
        assert out[0, 1] == 1
    return out
out = test_FPS_simple_input_with_nans()
print("OUT", out)

torch.Size([1, 4, 3])
tensor([[[0.0000, 0.0000, 1.0000],
         [0.0000, 0.0000, 1.1000],
         [1.0000, 0.0000, 0.0000],
         [   nan,    nan,    nan]]])
N_POINTS tensor([3])
RAND DRAWS tensor([0.2754])
SCALED R tensor([0.8261])
FURTHEST tensor([0])
DIST tensor([[0.0000, 0.0100, 2.0000,    nan]])
DISTANCE tensor([[0.0000, 0.0100, 2.0000, 0.0000]])
DIST tensor([[2.0000, 2.2100, 0.0000,    nan]])
DISTANCE tensor([[0.0000, 0.0100, 0.0000, 0.0000]])
OUT tensor([[0, 2]])
OUT tensor([[0, 2]])


In [116]:
x = torch.sum(torch.logical_not(torch.isnan(in_tensor[:, :, 0])), axis=1)

In [117]:
x

tensor([6, 5])

In [112]:
torch.randint(torch.zeros_like(x), x)

TypeError: randint(): argument 'high' (position 1) must be int, not Tensor

In [82]:
def test_FPS_simple_input() -> None:
    in_arr = np.array([[[0, 0, 1], # 0
                       [0, 0, 1.1], # 1
                       [1, 0, 0], # 2
                       [1.1, 0, 0]]]) # 3
    in_tensor = torch.Tensor(in_arr)
    out = farthest_point_sample(in_tensor, 2)
    
    if out[0, 0] in [2, 3]:
        assert out[0, 1] == 1
    else:
        assert out[0, 1] == 3
#     return out
out = test_FPS_simple_input()
print("OUT", out)

DIST tensor([[2.0000, 2.2100, 0.0000, 0.0100]])
DISTANCE tensor([[2.0000, 2.2100, 0.0000, 0.0100]])
torch.return_types.max(
values=tensor([2.2100]),
indices=tensor([1]))
DIST tensor([[0.0100, 0.0000, 2.2100, 2.4200]])
DISTANCE tensor([[0.0100, 0.0000, 0.0000, 0.0100]])
torch.return_types.max(
values=tensor([0.0100]),
indices=tensor([0]))
OUT tensor([[2, 1]])


In [73]:
def test_FPS_input_NaNs() -> None:
    """
    Tests that inputting NaNs responds in well-defined behavior for
    the furthest_point_sampling
    """
    
    
    in_xyz_arr = np.full((2, 7, 3), np.nan)
    in_xyz_arr[0, :6] = np.random.normal(size=(6, 3))
    in_xyz_arr[1, :5] = np.random.normal(size=(5, 3))
    
    in_tensor = torch.Tensor(in_xyz_arr)
    
    out = farthest_point_sample(in_tensor, 2)
    return out

out = test_FPS_input_NaNs()

print(out)
print(out[0])
print(out[1])

tensor([[2.8269, 3.2433, 4.1188, 7.2420, 0.0000, 8.7664,    nan],
        [   nan,    nan,    nan,    nan,    nan,    nan,    nan]])
tensor([[2.8269, 3.2433, 4.1188, 7.2420, 0.0000, 8.7664,    nan],
        [   nan,    nan,    nan,    nan,    nan,    nan,    nan]])
tensor([[   nan,    nan,    nan,    nan,    nan,    nan,    nan],
        [0.0000, 3.7153, 4.2614, 7.8100, 4.8693,    nan,    nan]])
tensor([[nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan]])
tensor([[4, 6],
        [6, 0]])
tensor([4, 6])
tensor([6, 0])
