In [13]:
import torch


def generate_nbr_list(xyz, cutoff, cell, index_tuple=None, ex_pairs=None, get_dis=False):    
    device = xyz.device    
    dis_mat = (xyz[..., None, :, :] - xyz[..., :, None, :])
    print(xyz)
    print(xyz[..., None, :, :])
    print(xyz[..., :, None :])
    print(xyz[..., :, None :])
    # project the position vector onto the cell basis 
    reduced_dis = dis_mat.matmul(cell.inverse())
    # using minimal image conv
    offsets_add = -(reduced_dis > torch.Tensor([0.5, 0.5, 0.5]).to(device)).to(torch.float).to(device)
    offsets_sub = (reduced_dis < -torch.Tensor([0.5, 0.5, 0.5]).to(device)).to(torch.float).to(device)
    offsets = offsets_add + offsets_sub
    
    dis_mat = dis_mat + offsets.matmul(cell)
    
    dis_sq = torch.triu( dis_mat.pow(2).sum(-1) )
    mask = (dis_sq < cutoff ** 2) & (dis_sq != 0)
    nbr_list = torch.nonzero( torch.triu(mask.to(torch.long)), as_tuple=False)

    if get_dis:
        return nbr_list, dis_sq[mask].sqrt(), offsets[nbr_list[:, 0], nbr_list[:, 1], :]
    else:
        return nbr_list, offsets[nbr_list[:, 0], nbr_list[:, 1], :]

# === Test Data ===
# Define atomic positions (3 atoms in 3D space)
xyz = torch.tensor([[0.0, 0.0, 0.0],  
                    [1.0, 0.0, 0.0],  
                    [0.0, 10.0, 0.0]], dtype=torch.float32)

xyz = xyz.unsqueeze(0)  # Add batch dimension (1, N, 3)

# Define cutoff distance
cutoff = 1.5

# Define cubic unit cell (size 3x3x3)
cell = torch.eye(3) * 3.0

# Run function and print results
nbr_list, distances, offsets = generate_nbr_list(xyz, cutoff, cell, get_dis=True)

# Print results
print("\n=== Neighbor List ===")
print(nbr_list)

print("\n=== Distances ===")
print(distances)

print("\n=== Offsets ===")
print(offsets)


tensor([[[ 0.,  0.,  0.],
         [ 1.,  0.,  0.],
         [ 0., 10.,  0.]]])
tensor([[[[ 0.,  0.,  0.],
          [ 1.,  0.,  0.],
          [ 0., 10.,  0.]]]])
tensor([[[ 0.,  0.,  0.],
         [ 1.,  0.,  0.],
         [ 0., 10.,  0.]]])
tensor([[[ 0.,  0.,  0.],
         [ 1.,  0.,  0.],
         [ 0., 10.,  0.]]])

=== Neighbor List ===
tensor([[0, 0, 1]])

=== Distances ===
tensor([1.])

=== Offsets ===
tensor([[[ 0.,  0.,  0.],
         [ 0.,  0.,  0.],
         [ 0., -1.,  0.]]])


In [2]:
import numpy as np
np.random.rand()

0.21626000619598806