In [2]:
from dataclasses import astuple
import random

from rdkit import Chem
import torch
from torch import Tensor, nn

from mol_gnn.featurizers import BaseMoleculeMolGraphFeaturizer
from mol_gnn.data import BatchMolGraph

In [3]:
smis = ["c1ccccc1", "CCCC"]
mgf = BaseMoleculeMolGraphFeaturizer()

mols = [Chem.MolFromSmiles(smi) for smi in smis]
mgs = [mgf(mol) for mol in mols]

bmg = BatchMolGraph(mgs)
V, E, edge_index, rev_index, batch = astuple(bmg)

In [4]:
W = nn.Linear(sum(mgf.shape), 100)
W_v, W_e = [nn.Linear(d, 100) for d in mgf.shape]

In [30]:
from torch_scatter import scatter, scatter_softmax, scatter_sum

In [37]:
src =  torch.tensor([0, 1, 0, 2, 0, 3, 1, 2])
dest = torch.tensor([1, 0, 2, 0, 3, 0, 2, 1])
edge_index = torch.stack([src, dest])
rev_index = torch.tensor([1, 0, 3, 2, 5, 4, 7, 6])

X = torch.arange(len(dest) * 4).view(-1, 4).float()
w = torch.ones(len(dest))
X

tensor([[ 0.,  1.,  2.,  3.],
        [ 4.,  5.,  6.,  7.],
        [ 8.,  9., 10., 11.],
        [12., 13., 14., 15.],
        [16., 17., 18., 19.],
        [20., 21., 22., 23.],
        [24., 25., 26., 27.],
        [28., 29., 30., 31.]])

In [41]:
X[rev_index]

tensor([[ 4.,  5.,  6.,  7.],
        [ 0.,  1.,  2.,  3.],
        [12., 13., 14., 15.],
        [ 8.,  9., 10., 11.],
        [20., 21., 22., 23.],
        [16., 17., 18., 19.],
        [28., 29., 30., 31.],
        [24., 25., 26., 27.]])

In [55]:
class LinearFacade(nn.Module):
    def __init__(self, input_dim, output_dim) -> None:
        super().__init__()

        self.output_dim = output_dim
    
    def forward(self, X: Tensor) -> Tensor:
        return torch.zeros(*X.shape[:-1], self.output_dim, device=X.device) 

In [65]:
%timeit nn.Dropout(1)(X)
L = LinearFacade(4, 4)
%timeit L(X)

7.4 µs ± 46 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
1.71 µs ± 7.77 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)


In [29]:
scatter_sum(X @ W_b, dest, 0)

tensor([[72., 78., 84., 90.],
        [56., 60., 64., 68.],
        [64., 68., 72., 76.],
        [32., 34., 36., 38.]])

IndexError: index 0 is out of bounds for dimension 0 with size 0

In [32]:
src =  torch.tensor([0, 1, 0, 2, 0, 3, 1, 2])
dest = torch.tensor([1, 0, 2, 0, 3, 0, 2, 1])
edge_index = torch.stack([src, dest])
rev_index = torch.tensor([1, 0, 3, 2, 5, 4, 7, 6])

X = torch.arange(len(dest) // 2 * 4).view(-1, 4).float()
w = torch.ones(len(dest))
X

tensor([[ 0.,  1.,  2.,  3.],
        [ 4.,  5.,  6.,  7.],
        [ 8.,  9., 10., 11.],
        [12., 13., 14., 15.]])

In [36]:
X[src], X[dest]

(tensor([[ 0.,  1.,  2.,  3.],
         [ 4.,  5.,  6.,  7.],
         [ 0.,  1.,  2.,  3.],
         [ 8.,  9., 10., 11.],
         [ 0.,  1.,  2.,  3.],
         [12., 13., 14., 15.],
         [ 4.,  5.,  6.,  7.],
         [ 8.,  9., 10., 11.]]),
 tensor([[ 4.,  5.,  6.,  7.],
         [ 0.,  1.,  2.,  3.],
         [ 8.,  9., 10., 11.],
         [ 0.,  1.,  2.,  3.],
         [12., 13., 14., 15.],
         [ 0.,  1.,  2.,  3.],
         [ 8.,  9., 10., 11.],
         [ 4.,  5.,  6.,  7.]]))

In [34]:
scatter_sum(X[src], dest, 0)

tensor([[24., 27., 30., 33.],
        [ 8., 10., 12., 14.],
        [ 4.,  6.,  8., 10.],
        [ 0.,  1.,  2.,  3.]])

In [None]:
scatter_sum(X @ W_b, dest, 0)

tensor([[72., 78., 84., 90.],
        [56., 60., 64., 68.],
        [64., 68., 72., 76.],
        [32., 34., 36., 38.]])

In [12]:
edge_index

tensor([[0, 1, 1, 2],
        [1, 0, 2, 1]])

In [15]:
scatter(X, edge_index[1], 0, reduce="sum")

tensor([[36, 39, 42, 45],
        [28, 30, 32, 34],
        [32, 34, 36, 38],
        [16, 17, 18, 19]])

In [7]:
W = nn.Linear(4, 1)

W(X)[rev_index], W(X[rev_index])

(tensor([[ 0.0616],
         [ 0.0762],
         [-0.0362],
         [ 0.2077]], grad_fn=<IndexBackward0>),
 tensor([[ 0.0616],
         [ 0.0762],
         [-0.0362],
         [ 0.2077]], grad_fn=<AddmmBackward0>))

In [73]:
scatter_softmax(X * 100, dest, 0, dim_size=len(X))

tensor([[1.0000e+00, 1.1028e-06, 1.6534e-07, 1.0000e+00],
        [1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00],
        [1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00],
        [1.2543e-27, 1.0000e+00, 1.0000e+00, 1.5541e-10]])

In [76]:
1 / scatter(w.unsqueeze(1), dest, 0, reduce="sum")[dest]

tensor([[0.5000],
        [1.0000],
        [1.0000],
        [0.5000]])

In [56]:
%timeit scatter(w / scatter(w, dest, 0, reduce="sum")[dest][:, None] * X, dest, 0, reduce="sum")

19.5 µs ± 643 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [9]:
def permute_tensor(X: Tensor, num_perm: int):
    b, V, d = X.shape

    random_indices = torch.stack(
        [
            torch.stack(
                [
                    torch.stack([torch.Tensor(random.sample(range(d), d)) for _ in range(V)])
                    for _ in range(b)
                ]
            )
            for _ in range(num_perm)
        ]
    ).long()
    tensor_repeated = X.reshape(1, *X.shape).repeat(num_perm, 1, 1, 1)

    return torch.gather(tensor_repeated, 3, random_indices)


class Permutation(nn.Module):
    def __init__(self, dim: int = 1):
        super().__init__()

        if dim < 1:
            raise ValueError(f"arg 'dim' must be greater than! got: {dim}")
        
        self.dim = dim

    def forward(self, X: Tensor):
        batch_perms = torch.stack(
            [torch.randperm(X.shape[self.dim]) for _ in range(len(X))]
        )
        index = batch_perms.unsqueeze(X.ndim - self.dim).expand(X.shape)

        return X.gather(self.dim, index)

In [158]:
X.ndim - 1
dim = 2
perms.unsqueeze(X.ndim - dim).shape

torch.Size([2, 1, 4])

In [152]:
X = torch.arange(2*3*4).view(2, 3, 4)#.expand(2, 3, 4)
X

tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11]],

        [[12, 13, 14, 15],
         [16, 17, 18, 19],
         [20, 21, 22, 23]]])

In [None]:
perms.view_as()

In [164]:
X = torch.rand(2, 3, 4)
dim = 1
batch_perms = torch.stack(
    [torch.randperm(X.shape[dim]) for _ in range(len(X))]
)
dest = batch_perms.unsqueeze(1).repeat_interleave(X.shape[1], dim=1).unsqueeze(-1)
print(X.shape)
print(batch_perms.shape)
# perms: 2   x   4
#     X: 2 x 3 x 4
# perms = add_dummy_dims(perms, X.shape).expand(X.shape)
# print(index.shape)
print(batch_perms.unsqueeze(X.ndim - dim).expand(X.shape).shape)

torch.Size([2, 3, 4])
torch.Size([2, 3])
torch.Size([2, 3, 4])


In [145]:
batch_perms.expand(X.shape)

RuntimeError: The expanded size of the tensor (3) must match the existing size (2) at non-singleton dimension 1.  Target sizes: [2, 3, 4].  Tensor sizes: [2, 4]

In [139]:
X.gather(3, dest)

tensor([[[[  2],
          [  6],
          [ 10],
          [ 18]],

         [[ 22],
          [ 26],
          [ 30],
          [ 38]],

         [[ 42],
          [ 46],
          [ 50],
          [ 58]]],


        [[[ 61],
          [ 68],
          [ 72],
          [ 75]],

         [[ 81],
          [ 88],
          [ 92],
          [ 95]],

         [[101],
          [108],
          [112],
          [115]]]])

In [130]:
X.shape

torch.Size([2, 3, 4, 5])

In [44]:
X[torch.arange(2)[:, None], batch_perms]

tensor([[[ 4,  5,  6,  7],
         [ 8,  9, 10, 11],
         [ 0,  1,  2,  3]],

        [[20, 21, 22, 23],
         [16, 17, 18, 19],
         [12, 13, 14, 15]]])

In [102]:
X

tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11]],

        [[12, 13, 14, 15],
         [16, 17, 18, 19],
         [20, 21, 22, 23]]])

In [73]:
A = torch.tensor(
    [[[1, 2],
      [3, 4]],
     [[5, 6],
      [7, 8]]]
)
I = torch.tensor(
    [[[1, 1],
      [0, 0]],
     [[0, 0],
      [1, 1]]]
)
# I = torch.tensor([1, 0, 2]).view()
A.gather(0, I)

tensor([[[5, 6],
         [3, 4]],

        [[1, 2],
         [7, 8]]])

In [70]:
torch.tensor([1, 0, 2]).view(-1, 1).repeat(1, 2)

tensor([[1, 1],
        [0, 0],
        [2, 2]])

In [106]:
batch_perms.shape

torch.Size([2, 3])

In [105]:
batch_perms[..., None].repeat(1, 1, 4)

tensor([[[1, 1, 1, 1],
         [2, 2, 2, 2],
         [0, 0, 0, 0]],

        [[2, 2, 2, 2],
         [1, 1, 1, 1],
         [0, 0, 0, 0]]])

In [112]:
batch_perms[..., None].repeat_interleave(X.shape[-1], dim=-1)

tensor([[[1, 1, 1, 1],
         [2, 2, 2, 2],
         [0, 0, 0, 0]],

        [[2, 2, 2, 2],
         [1, 1, 1, 1],
         [0, 0, 0, 0]]])

In [None]:
perms_final_shape = torch.tensor(
    [[1],
     [2], 0],
     [2, 1, 0]]
)