In [1]:
import torch
import ase
from ase.io import read, write
import numpy as np
import scipy
from ase.build import molecule, bulk
from ase.neighborlist import neighbor_list
from typing import Tuple

In [2]:
from torch_nl.ase_impl import ase_neighbor_list
from torch_nl.torch_impl import compute_images

# n^2 impl

##  test 

In [3]:
frames = [molecule('OCHCHO'), molecule('C3H9C')]
# frames = [molecule('C3H9C')]
# frames = [bulk('Si', 'diamond', a=6, cubic=True)]
rcut = 3

pos = torch.from_numpy(frame.get_positions())
cell = torch.from_numpy(frame.get_cell().array)
pbc = torch.from_numpy(frame.get_pbc())
n_atoms = [0]
pos = []
cell = []
pbc = []
for ff in frames:
    n_atoms.append(len(ff))
    pos.append(torch.from_numpy(ff.get_positions()))
    cell.append(torch.from_numpy(ff.get_cell().array))
    pbc.append(torch.from_numpy(ff.get_pbc()))
pos = torch.cat(pos)
cell = torch.cat(cell)
pbc = torch.cat(pbc)
stride = torch.from_numpy(np.cumsum(n_atoms))
batch = torch.zeros(pos.shape[0],dtype=torch.long)
for ii,(st,nd) in enumerate(zip(stride[:-1],stride[1:])):
    batch[st:nd] = ii
n_atoms = torch.Tensor(n_atoms[1:]).to(dtype=torch.long)
batch,n_atoms, pos.shape, cell.shape, pbc.shape

NameError: name 'frame' is not defined

In [173]:
dist = []
for frame in frames:
    idx_i, idx_j, idx_S, dd = neighbor_list(
            "ijSd", frame, cutoff=rcut, self_interaction=False
        )
    dist.append(np.sort(dd))

In [174]:
def get_fully_connected_mapping(i_ids, j_ids) -> torch.Tensor:
    mapping = torch.cartesian_prod(i_ids, j_ids)
    mapping = mapping[mapping[:,0] != mapping[:,1], :]
    return mapping

def compute_images(
    positions: torch.Tensor,
    cell: torch.Tensor,
    pbc: torch.Tensor,
    cutoff: float,
    batch: torch.Tensor,
    n_atoms: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """TODO: add doc"""
    cell = cell.view((-1, 3, 3)).to(torch.float64)
    pbc = pbc.view((-1, 3))
    has_pbc = pbc.prod(dim=1, dtype=bool)
    print(has_pbc)
    stride = torch.zeros(n_atoms.shape[0]+1,dtype=torch.long)
    stride[1:] = torch.cumsum(n_atoms, dim=0,dtype=torch.long)
    reciprocal_cell = torch.zeros_like(cell)
    reciprocal_cell[has_pbc,:,:] = torch.linalg.inv(cell[has_pbc,:,:]).transpose(2, 1)
    # print('reciprocal_cell: ', reciprocal_cell)
    inv_distances = reciprocal_cell.norm(2, dim=-1)
    # print('inv_distances: ', inv_distances)
    num_repeats = torch.ceil(cutoff * inv_distances).to(torch.long)
    print(num_repeats)
    num_repeats_ = torch.where(pbc, num_repeats, torch.zeros_like(num_repeats))
    # print('num_repeats_: ', num_repeats_.device)
    ids = torch.arange(positions.shape[0], device=positions.device, dtype=torch.long)
    images, mapping, batch_images, shifts_expanded, shifts_idx_ = [], [], [], [], []
    for i_structure in range(n_atoms.shape[0]):
        num_repeats = num_repeats_[i_structure]
        r1 = torch.arange(
            -num_repeats[0],
            num_repeats[0] + 1,
            device=cell.device,
            dtype=torch.long,
        )
        r2 = torch.arange(
            -num_repeats[1],
            num_repeats[1] + 1,
            device=cell.device,
            dtype=torch.long,
        )
        r3 = torch.arange(
            -num_repeats[2],
            num_repeats[2] + 1,
            device=cell.device,
            dtype=torch.long,
        )
        shifts_idx = torch.cartesian_prod(torch.sort(r1)[1], torch.sort(r2)[1], torch.sort(r3)[1])
        n_cell_image = shifts_idx.shape[0]
        shifts = torch.matmul(shifts_idx.to(cell.dtype), cell[i_structure])
        # stnd = slice(stride[i_structure], stride[i_structure+1])
        n_atom = n_atoms[i_structure]
        pos = positions[stride[i_structure]:stride[i_structure+1]]
        i_ids = ids[stride[i_structure]:stride[i_structure+1]]
        print(n_atom, n_cell_image)
        #ss = n_atom*torch.repeat_interleave(torch.arange(n_cell_image), n_atom) + stride[i_structure]
        # ss = torch.arange(n_cell_image*n_atom) + stride[i_structure]
        #print(ss)
        #j_ids = i_ids.repeat(shifts.shape[0]) + ss
        j_ids = torch.arange(n_cell_image*n_atom) + stride[i_structure]
        # j_ids += 
        print(i_ids.shape, j_ids.shape,stride[i_structure])
        s_mapping = get_fully_connected_mapping(i_ids, j_ids)
        # s_mapping[:,1] += n_atom*torch.repeat_interleave(torch.arange(n_atom), n_cell_image)
        print(n_atoms[i_structure])
        shift_expanded = shifts.repeat(1, n_atoms[i_structure]).view((-1, 3))
        pos_expanded = pos.repeat(shifts.shape[0], 1)

        images.append(pos_expanded + shift_expanded)
        mapping.append(s_mapping)
        batch_images.append(
            i_structure
            * torch.ones(
                images[-1].shape[0], dtype=torch.int64, device=cell.device
            )
        )
        shifts_expanded.append(shift_expanded)
        shifts_idx_.append(
            shifts_idx.repeat(1, n_atoms[i_structure]).view((-1, 3))
        )
    return (
        torch.cat(images, dim=0).to(positions.dtype),
        torch.cat(mapping, dim=0).t(),
        torch.cat(batch_images, dim=0),
        torch.cat(shifts_expanded, dim=0).to(positions.dtype),
        torch.cat(shifts_idx_, dim=0),
    )

In [175]:
images, mapping, batch_images, shifts_expanded, shifts_idx = compute_images(pos, 
                                                         cell, pbc, rcut, 
                                                         batch, 
                                                         n_atoms)


d2 = (pos[mapping[0]] - images[mapping[1]]).square().sum(dim=1)
mask = d2 <= rcut*rcut
d = d2[mask].sqrt()
mapping = mapping[:, mask]
mapping_batch = batch_images[mapping[0]]
d.shape

tensor([False, False])
reciprocal_cell:  tensor([[[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]],

        [[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]]], dtype=torch.float64)
inv_distances:  tensor([[0., 0., 0.],
        [0., 0., 0.]], dtype=torch.float64)
tensor([[0, 0, 0],
        [0, 0, 0]])
tensor(6) 1
torch.Size([6]) torch.Size([6]) tensor(0)
tensor(6)
tensor(13) 1
torch.Size([13]) torch.Size([13]) tensor(6)
tensor(13)


torch.Size([128])

In [139]:
images.shape, pos.shape, mapping.shape, shifts_expanded.shape, shifts_idx.shape

(torch.Size([6, 3]),
 torch.Size([13, 3]),
 torch.Size([2, 30]),
 torch.Size([6, 3]),
 torch.Size([6, 3]))

In [178]:
for ii, dd in enumerate(dist):
    print(np.allclose(np.sort(dd), np.sort(d[mapping_batch == ii].numpy())))

True
True


In [8]:
d2.min(), rcut*rcut

(tensor(24.7500, dtype=torch.float64), 9)

In [103]:
dist.shape, d.shape, d2.shape

((32,), torch.Size([0]), torch.Size([1512]))

In [76]:
mapping.shape, images.shape, mapping

(torch.Size([2, 30]),
 torch.Size([6, 3]),
 tensor([[0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 4, 4,
          4, 5, 5, 5, 5, 5],
         [1, 2, 3, 4, 5, 0, 2, 3, 4, 5, 0, 1, 3, 4, 5, 0, 1, 2, 4, 5, 0, 1, 2, 3,
          5, 0, 1, 2, 3, 4]]))

In [27]:
aa = torch.arange(3).reshape(3,1)
torch.cat([aa,aa], dim=1)

tensor([[0, 0],
        [1, 1],
        [2, 2]])

In [22]:
torch.arange?

In [75]:
mapping

tensor([[0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 4, 4,
         4, 5, 5, 5, 5, 5],
        [1, 2, 3, 4, 5, 0, 2, 3, 4, 5, 0, 1, 3, 4, 5, 0, 1, 2, 4, 5, 0, 1, 2, 3,
         5, 0, 1, 2, 3, 4]])