In [88]:
import torch
import ase
from ase.io import read, write
import numpy as np
import scipy
from ase.build import molecule, bulk
from ase.neighborlist import neighbor_list
from typing import Tuple

In [28]:
from torch_nl.ase_impl import ase_neighbor_list
from torch_nl.torch_impl import compute_images

In [87]:
diamond100?

# n^2 impl

##  test 

In [97]:
frame = molecule('OCHCHO')
frame = bulk('Si', 'diamond', a=6, cubic=True)
rcut = 3

pos = torch.from_numpy(frame.get_positions())
cell = torch.from_numpy(frame.get_cell().array)
pbc = torch.from_numpy(frame.get_pbc())

In [98]:
idx_i, idx_j, idx_S, dist = neighbor_list(
        "ijSd", frame, cutoff=rcut, self_interaction=False
    )
np.sort(dist)

array([2.59807621, 2.59807621, 2.59807621, 2.59807621, 2.59807621,
       2.59807621, 2.59807621, 2.59807621, 2.59807621, 2.59807621,
       2.59807621, 2.59807621, 2.59807621, 2.59807621, 2.59807621,
       2.59807621, 2.59807621, 2.59807621, 2.59807621, 2.59807621,
       2.59807621, 2.59807621, 2.59807621, 2.59807621, 2.59807621,
       2.59807621, 2.59807621, 2.59807621, 2.59807621, 2.59807621,
       2.59807621, 2.59807621])

In [99]:
def get_fully_connected_mapping(i_ids, j_ids) -> torch.Tensor:
    mapping = torch.cartesian_prod(i_ids, j_ids)
    mapping = mapping[mapping[:,0] != mapping[:,1], :]
    return mapping

def compute_images(
    positions: torch.Tensor,
    cell: torch.Tensor,
    pbc: torch.Tensor,
    cutoff: float,
    batch: torch.Tensor,
    n_atoms: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """TODO: add doc"""
    cell = cell.view((-1, 3, 3)).to(torch.float64)
    pbc = pbc.view((-1, 3))
    has_pbc = pbc.prod(dim=1, dtype=bool)
    print(has_pbc)
    reciprocal_cell = torch.zeros_like(cell)
    reciprocal_cell[has_pbc,:,:] = torch.linalg.inv(cell[has_pbc,:,:]).transpose(2, 1)
    # print('reciprocal_cell: ', reciprocal_cell.device)
    inv_distances = reciprocal_cell.norm(2, dim=-1)
    # print('inv_distances: ', inv_distances.device)
    num_repeats = torch.ceil(cutoff * inv_distances).to(torch.long)
    print(num_repeats)
    num_repeats_ = torch.where(pbc, num_repeats, torch.zeros_like(num_repeats))
    # print('num_repeats_: ', num_repeats_.device)
    ids = torch.arange(positions.shape[0], device=positions.device, dtype=torch.long)
    images, mapping, batch_images, shifts_expanded, shifts_idx_ = [], [], [], [], []
    for i_structure in range(num_repeats_.shape[0]):
        num_repeats = num_repeats_[i_structure]
        r1 = torch.arange(
            -num_repeats[0],
            num_repeats[0] + 1,
            device=cell.device,
            dtype=torch.long,
        )
        r2 = torch.arange(
            -num_repeats[1],
            num_repeats[1] + 1,
            device=cell.device,
            dtype=torch.long,
        )
        r3 = torch.arange(
            -num_repeats[2],
            num_repeats[2] + 1,
            device=cell.device,
            dtype=torch.long,
        )
        shifts_idx = torch.cartesian_prod(r1, r2, r3)
        shifts = torch.matmul(shifts_idx.to(cell.dtype), cell[i_structure])
        pos = positions[batch == i_structure]
        i_ids = ids[batch == i_structure]
        j_ids = i_ids.repeat(shifts.shape[0])
        print(i_ids.shape, j_ids.shape)
        s_mapping = get_fully_connected_mapping(i_ids, j_ids)
        print(n_atoms[i_structure])
        shift_expanded = shifts.repeat(1, n_atoms[i_structure]).view((-1, 3))
        pos_expanded = pos.repeat(shifts.shape[0], 1)

        images.append(pos_expanded + shift_expanded)
        mapping.append(s_mapping)
        batch_images.append(
            i_structure
            * torch.ones(
                images[-1].shape[0], dtype=torch.int64, device=cell.device
            )
        )
        shifts_expanded.append(shift_expanded)
        shifts_idx_.append(
            shifts_idx.repeat(1, n_atoms[i_structure]).view((-1, 3))
        )
    return (
        torch.cat(images, dim=0).to(positions.dtype),
        torch.cat(mapping, dim=0).t(),
        torch.cat(batch_images, dim=0),
        torch.cat(shifts_expanded, dim=0).to(positions.dtype),
        torch.cat(shifts_idx_, dim=0),
    )

In [100]:
images, mapping, batch_images, shifts_expanded, shifts_idx_ = compute_images(pos, 
                                                                             cell, pbc, rcut, 
                                                                             torch.zeros(pos.shape[0]), 
                                                                             pos.shape[0]*torch.ones(1, dtype=int))




d2 = (pos[mapping[0]] - images[mapping[1]]).square().sum(dim=1)
mask = d2 <= rcut*rcut
d = d2[mask].sqrt()
mapping = mapping[:, mask]

tensor([True])
tensor([[1, 1, 1]])
torch.Size([8]) torch.Size([216])
tensor(8)


In [101]:
np.sort(dist) - np.sort(d.numpy())

ValueError: operands could not be broadcast together with shapes (32,) (0,) 

In [106]:
d2.min(), rcut*rcut

(tensor(24.7500, dtype=torch.float64), 9)

In [103]:
dist.shape, d.shape, d2.shape

((32,), torch.Size([0]), torch.Size([1512]))

In [76]:
mapping.shape, images.shape, mapping

(torch.Size([2, 30]),
 torch.Size([6, 3]),
 tensor([[0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 4, 4,
          4, 5, 5, 5, 5, 5],
         [1, 2, 3, 4, 5, 0, 2, 3, 4, 5, 0, 1, 3, 4, 5, 0, 1, 2, 4, 5, 0, 1, 2, 3,
          5, 0, 1, 2, 3, 4]]))

In [27]:
aa = torch.arange(3).reshape(3,1)
torch.cat([aa,aa], dim=1)

tensor([[0, 0],
        [1, 1],
        [2, 2]])

In [22]:
torch.arange?

In [75]:
mapping

tensor([[0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 4, 4,
         4, 5, 5, 5, 5, 5],
        [1, 2, 3, 4, 5, 0, 2, 3, 4, 5, 0, 1, 3, 4, 5, 0, 1, 2, 4, 5, 0, 1, 2, 3,
         5, 0, 1, 2, 3, 4]])