In [1]:
import torch
import ase
from ase.io import read, write
import numpy as np
import scipy
from ase.build import molecule, bulk
from ase.neighborlist import neighbor_list
from typing import Tuple

In [2]:
from torch_nl.ase_impl import ase_neighbor_list
from torch_nl.torch_impl import compute_images

# n^2 impl

##  test 

In [212]:
frames = [molecule('OCHCHO'), molecule('C3H9C')]
# frames = [molecule('C3H9C')]
frames += [bulk('Si', 'diamond', a=6, cubic=True)]
frames += [bulk('Si', 'diamond', a=4, cubic=True)]
# frames = [bulk('Si', 'diamond', a=6.1, cubic=True)]
rcut = 3

# pos = torch.from_numpy(frame.get_positions())
# cell = torch.from_numpy(frame.get_cell().array)
# pbc = torch.from_numpy(frame.get_pbc())
n_atoms = [0]
pos = []
cell = []
pbc = []
for ff in frames:
    n_atoms.append(len(ff))
    pos.append(torch.from_numpy(ff.get_positions()))
    cell.append(torch.from_numpy(ff.get_cell().array))
    pbc.append(torch.from_numpy(ff.get_pbc()))
pos = torch.cat(pos)
cell = torch.cat(cell)
pbc = torch.cat(pbc)
stride = torch.from_numpy(np.cumsum(n_atoms))
batch = torch.zeros(pos.shape[0],dtype=torch.long)
for ii,(st,nd) in enumerate(zip(stride[:-1],stride[1:])):
    batch[st:nd] = ii
n_atoms = torch.Tensor(n_atoms[1:]).to(dtype=torch.long)
batch,n_atoms, pos.shape, cell.shape, pbc.shape

(tensor([0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2,
         2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3]),
 tensor([ 6, 13,  8,  8]),
 torch.Size([35, 3]),
 torch.Size([12, 3]),
 torch.Size([12]))

In [213]:
dist = []
mm = []
for frame in frames:
    idx_i, idx_j, idx_S, dd = neighbor_list(
            "ijSd", frame, cutoff=rcut, self_interaction=False
        )
    dist.append(np.sort(dd))
    mm.append((idx_i, idx_j))

In [219]:
def get_fully_connected_mapping(i_ids, shifts_idx) -> torch.Tensor:
    n_atom = i_ids.shape[0]
    n_atom2 = n_atom * n_atom
    n_cell_image = shifts_idx.shape[0]
    j_ids = torch.repeat_interleave(i_ids, n_cell_image)
    mapping = torch.cartesian_prod(i_ids, j_ids)
    shifts_idx = shifts_idx.repeat((n_atom2, 1))
    
    mask = torch.ones(mapping.shape[0], dtype=bool, device=i_ids.device)
    ids = n_cell_image*torch.arange(n_atom, device=i_ids.device) \
                + torch.arange(0, mapping.shape[0], n_atom*n_cell_image, device=i_ids.device)
    # print(n_atom*n_cell_image, ids)
    mask[ids] = False
    mapping = mapping[mask, :]
    shifts_idx = shifts_idx[mask]
    return mapping, shifts_idx

def compute_images(
    positions: torch.Tensor,
    cell: torch.Tensor,
    pbc: torch.Tensor,
    cutoff: float,
    batch: torch.Tensor,
    n_atoms: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """TODO: add doc"""
    cell = cell.view((-1, 3, 3)).to(torch.float32)
    pbc = pbc.view((-1, 3))
    dtype = cell.dtype
    has_pbc = pbc.prod(dim=1, dtype=bool)
    stride = torch.zeros(n_atoms.shape[0]+1,dtype=torch.long)
    stride[1:] = torch.cumsum(n_atoms, dim=0,dtype=torch.long)
    reciprocal_cell = torch.zeros_like(cell)
    reciprocal_cell[has_pbc,:,:] = torch.linalg.inv(cell[has_pbc,:,:]).transpose(2, 1)
    inv_distances = reciprocal_cell.norm(2, dim=-1)
    num_repeats = torch.ceil(cutoff * inv_distances).to(torch.long)
    num_repeats_ = torch.where(pbc, num_repeats, torch.zeros_like(num_repeats))
    ids = torch.arange(positions.shape[0], device=positions.device, dtype=torch.long)
    images, mapping, batch_mapping, shifts_expanded, shifts_idx_ = [], [], [], [], []
    for i_structure in range(n_atoms.shape[0]):
        num_repeats = num_repeats_[i_structure]
        reps = []
        for ii in range(3):
            r1 = torch.arange(
                -num_repeats[ii],
                num_repeats[ii] + 1,
                device=cell.device,
                dtype=dtype,
            )
            _, indices = torch.sort(torch.abs(r1))
            reps.append(r1[indices])
        shifts_idx = torch.cartesian_prod(*reps)
        n_cell_image = shifts_idx.shape[0]
        n_atom = n_atoms[i_structure]
        pos = positions[stride[i_structure]:stride[i_structure+1]]
        i_ids = ids[stride[i_structure]:stride[i_structure+1]]
        
        s_mapping, shifts_idx = get_fully_connected_mapping(i_ids, shifts_idx)
        mapping.append(s_mapping)
        batch_mapping.append(
            i_structure
            * torch.ones(
                s_mapping.shape[0], dtype=torch.long, device=cell.device
            )
        )
        shifts_idx_.append(shifts_idx)
    return (
        torch.cat(mapping, dim=0).t(),
        torch.cat(batch_mapping, dim=0),
        torch.cat(shifts_idx_, dim=0),
    )

In [221]:
from typing import Optional
def compute_distances(
    pos: torch.Tensor,
    mapping: torch.Tensor,
    cell_shifts: Optional[torch.Tensor] = None,
):
    assert mapping.dim() == 2
    assert mapping.shape[0] == 2

    if cell_shifts is None:
        dr = pos[mapping[1]] - pos[mapping[0]]
    else:
        dr = pos[mapping[1]] - pos[mapping[0]] + cell_shifts

    return dr.norm(p=2, dim=1)

def compute_cell_shifts(cell, shifts_idx, batch_mapping):
    if cell is None:
        cell_shifts = None
    else:
        cell_shifts = torch.einsum("jn,jnm->jm", shifts_idx, cell.view(-1, 3, 3)[batch_mapping])
    return cell_shifts

def compute_strict_nl_n2(pos, cell, mapping, batch_mapping,  shifts_idx):
    cell_shifts = compute_cell_shifts(cell, shifts_idx, batch_mapping)
    d2 = (pos[mapping[0]] - pos[mapping[1]] - cell_shifts).square().sum(dim=1)
    mask = d2 <= rcut*rcut
    mapping = mapping[:, mask]
    mapping_batch = batch_mapping[mask]
    shifts_idx = shifts_idx[mask]
    return mapping, mapping_batch, shifts_idx # , d2[mask].sqrt()

def compute_nl(pos, cell, pbc, rcut, batch, method='n2'):
    n_atoms = torch.bincount(batch)
    mapping, batch_mapping, shifts_idx = compute_images(pos, 
                                     cell, pbc, rcut, 
                                     batch, 
                                     n_atoms)
    if method == 'n2':
        mapping, mapping_batch, shifts_idx = compute_strict_nl_n2(pos, cell, mapping, batch_mapping,  shifts_idx)
    
    return mapping, mapping_batch, shifts_idx

In [222]:
mapping, batch_mapping, shifts_idx = compute_images(pos, 
                                     cell, pbc, rcut, 
                                     batch, 
                                     n_atoms)

cell_shifts = compute_cell_shifts(cell, shifts_idx, batch_mapping)
d2 = (pos[mapping[0]] - pos[mapping[1]] - cell_shifts ).square().sum(dim=1)
mask = d2 <= rcut*rcut
d = d2[mask].sqrt()
mapping = mapping[:, mask]
mapping_batch = batch_mapping[mask]
d.shape

RuntimeError: expected scalar type Float but found Double

In [216]:
dist[0].shape,dist[0]

((26,),
 array([1.10563938, 1.10563938, 1.10563938, 1.10563938, 1.22295086,
        1.22295086, 1.22295086, 1.22295086, 1.51286   , 1.51286   ,
        2.05241169, 2.05241169, 2.05241169, 2.05241169, 2.22229623,
        2.22229623, 2.22229623, 2.22229623, 2.38769877, 2.38769877,
        2.38769877, 2.38769877, 2.61851372, 2.61851372, 2.61851372,
        2.61851372]))

In [217]:
images.shape, pos.shape, mapping.shape, shifts_expanded.shape, shifts_idx.shape

(torch.Size([216, 3]),
 torch.Size([35, 3]),
 torch.Size([2, 288]),
 torch.Size([216, 3]),
 torch.Size([3626, 3]))

In [218]:
for ii, dd in enumerate(dist):
    print(np.allclose(np.sort(dd), np.sort(d[mapping_batch == ii].numpy())))

True
True
True
True


In [12]:
d2.min(), rcut*rcut

(tensor(1.1977, dtype=torch.float64), 9)

In [35]:
dist[0].shape, d.shape, d2.shape

((32,), torch.Size([23]), torch.Size([1720]))

In [36]:
mapping.shape, images.shape, mapping

(torch.Size([2, 23]),
 torch.Size([216, 3]),
 tensor([[ 0,  1,  1,  1,  1,  2,  2,  3,  3,  3,  3,  4,  4,  5,  5,  5,  5,  6,
           6,  7,  7,  7,  7],
         [ 1,  0,  2,  4,  6,  1,  3,  2, 14, 28, 32,  1,  5,  4, 14, 74, 80,  1,
           7,  6, 28, 74, 96]]))

In [27]:
aa = torch.arange(3).reshape(3,1)
torch.cat([aa,aa], dim=1)

tensor([[0, 0],
        [1, 1],
        [2, 2]])

In [22]:
torch.arange?

In [75]:
mapping

tensor([[0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 4, 4,
         4, 5, 5, 5, 5, 5],
        [1, 2, 3, 4, 5, 0, 2, 3, 4, 5, 0, 1, 3, 4, 5, 0, 1, 2, 4, 5, 0, 1, 2, 3,
         5, 0, 1, 2, 3, 4]])

# linked cell 

In [None]:
from torch_nl import