Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/PACKAGES/metatomic/in.metatomic
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ mass Ni 58.693

velocity all create 123 42

pair_style metatomic nickel-lj.pt
pair_style metatomic nickel-lj.pt uncertainty_threshold off
# pair_style metatomic nickel-lj-extensions.pt extensions collected-extensions/
pair_coeff * * 28

Expand Down
Binary file modified examples/PACKAGES/metatomic/nickel-lj.pt
Binary file not shown.
114 changes: 80 additions & 34 deletions src/KOKKOS/metatomic_system_kokkos.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,74 @@
Filippo Bigi <filippo.bigi@epfl.ch>
------------------------------------------------------------------------- */
#include "metatomic_system_kokkos.h"

#include "metatomic_timer.h"

#include "domain.h"
#include "error.h"

#include "comm.h"
#include "atom_kokkos.h"

#include <torch/cuda.h>

using namespace LAMMPS_NS;

/// Compute the inverse of the cell matrix of the system, accounting for
/// non-periodic directions by setting the corresponding rows to an unit vector
/// orthogonal to the periodic directions. This is used to compute the cell
/// shifts of neighbor pairs.
static torch::Tensor cell_inverse(const metatomic_torch::System& system) {
auto cell = system->cell().clone();
auto periodic = system->pbc();

// find number of periodic directions and their indices
int n_periodic = 0;
int periodic_idx_1 = -1;
int periodic_idx_2 = -1;
for (int i = 0; i < 3; ++i) {
if (periodic[i].item<bool>()) {
n_periodic += 1;
if (periodic_idx_1 == -1) {
periodic_idx_1 = i;
} else if (periodic_idx_2 == -1) {
periodic_idx_2 = i;
}
}
}

// adjust the box matrix to have a simple orthogonal dimension along
// non-periodic directions
if (n_periodic == 0) {
return torch::eye(3, cell.options());
} else if (n_periodic == 1) {
assert(periodic_idx_1 != -1);
// Make the two non-periodic directions orthogonal to the periodic one
auto a = cell[periodic_idx_1];
auto b = torch::tensor({0, 1, 0}, cell.options());
if (torch::abs(torch::dot(a / a.norm(), b)).item<double>() > 0.9) {
b = torch::tensor({0, 0, 1}, cell.options());
}
auto c = torch::cross(a, b);
c /= c.norm();
b = torch::cross(c, a);
b /= b.norm();

// Assign back to the cell picking the "non-periodic" indices without ifs
cell[(periodic_idx_1 + 1) % 3] = b;
cell[(periodic_idx_1 + 2) % 3] = c;
} else if (n_periodic == 2) {
assert(periodic_idx_1 != -1 && periodic_idx_2 != -1);
// Make the one non-periodic direction orthogonal to the two periodic ones
auto a = cell[periodic_idx_1];
auto b = cell[periodic_idx_2];
auto c = torch::cross(a, b);
c /= c.norm();

// Assign back to the matrix picking the "non-periodic" index without ifs
cell[(3 - periodic_idx_1 - periodic_idx_2)] = c;
}

return cell.inverse();
}

template<typename T, class DeviceType>
using UnmanagedView = Kokkos::View<T, Kokkos::LayoutRight, DeviceType, Kokkos::MemoryTraits<Kokkos::Unmanaged>>;

Expand All @@ -47,11 +103,9 @@ MetatomicSystemAdaptorKokkos<DeviceType>::MetatomicSystemAdaptorKokkos(LAMMPS *l
this->strain = torch::eye(3, tensor_options);
}

#include "comm.h"

template<class DeviceType>
void MetatomicSystemAdaptorKokkos<DeviceType>::setup_neighbors_remap_kk(metatomic_torch::System& system, NeighListKokkos<DeviceType>* list) {
auto _ = MetatomicTimer("converting kokkos neighbors with ghosts remapping");
void MetatomicSystemAdaptorKokkos<DeviceType>::setup_neighbors_kk(metatomic_torch::System& system, NeighListKokkos<DeviceType>* list) {
auto _ = MetatomicTimer("converting kokkos neighbors list");
auto dtype = system->positions().scalar_type();

auto total_n_atoms = atomKK->nlocal + atomKK->nghost;
Expand Down Expand Up @@ -144,7 +198,7 @@ void MetatomicSystemAdaptorKokkos<DeviceType>::setup_neighbors_remap_kk(metatomi
);

auto x = system->positions().detach();
auto cell_inverse = system->cell().detach().inverse();
auto cell_inv = cell_inverse(system);

// convert from LAMMPS NL format to metatomic NL format
auto expanded_arange = torch::arange(
Expand Down Expand Up @@ -213,13 +267,11 @@ void MetatomicSystemAdaptorKokkos<DeviceType>::setup_neighbors_remap_kk(metatomi
auto distances_filt = distances.index({cutoff_mask, torch::indexing::Slice()});

// find filtered interatomic vectors using the original atoms
auto original_distances_filtered =
x.index_select(0, neighbors_original_id_filt)
- x.index_select(0, centers_original_id_filt);
auto original_distances_filtered = x.index_select(0, neighbors_original_id_filt) - x.index_select(0, centers_original_id_filt);

// cell shifts
auto pair_shifts = distances_filt - original_distances_filtered;
auto cell_shifts = pair_shifts.matmul(cell_inverse);
auto cell_shifts = pair_shifts.matmul(cell_inv);
cell_shifts = torch::round(cell_shifts).to(torch::kInt32);

if (full_list) {
Expand Down Expand Up @@ -308,7 +360,6 @@ template<class DeviceType>
metatomic_torch::System MetatomicSystemAdaptorKokkos<DeviceType>::system_from_lmp(
NeighList* list,
bool do_virial,
bool remap_pairs,
torch::ScalarType dtype,
torch::Device device
) {
Expand Down Expand Up @@ -348,41 +399,36 @@ metatomic_torch::System MetatomicSystemAdaptorKokkos<DeviceType>::system_from_lm
auto system_positions = this->positions.to(dtype);
cell = cell.to(dtype);

// Periodic boundary conditions handling.
auto pbc = torch::tensor(
{domain->xperiodic, domain->yperiodic, domain->zperiodic},
torch::TensorOptions().dtype(torch::kBool).device(this->device_)
);

cell.index_put_(
{torch::logical_not(pbc)},
torch::tensor({0.0}, torch::TensorOptions().dtype(dtype).device(this->device_))
);

if (do_virial) {
auto model_strain = this->strain.to(dtype);

// pretend to scale positions/cell by the strain so that
// it enters the computational graph.
// scale positions/cell by the strain so that it enters the
// computational graph.
system_positions = system_positions.matmul(model_strain);
cell = cell.matmul(model_strain);
}

// Periodic boundary conditions handling.
//
// While Metatomic models can support mixed PBC settings, we currently
// assume that the system is fully periodic and we throw an error otherwise
if (!domain->xperiodic || !domain->yperiodic || !domain->zperiodic) {
error->one(FLERR, "metatomic/kk currently requires a fully periodic system");
}
auto pbc = torch::tensor(
{domain->xperiodic, domain->yperiodic, domain->zperiodic},
torch::TensorOptions().dtype(torch::kBool).device(this->device_)
);

auto system = torch::make_intrusive<metatomic_torch::SystemHolder>(
atomic_types_,
system_positions,
cell,
pbc
);

if (remap_pairs) {
auto* kk_list = dynamic_cast<NeighListKokkos<DeviceType>*>(list);
assert(kk_list != nullptr);
this->setup_neighbors_remap_kk(system, kk_list);
} else {
error->one(FLERR, "the kokkos version of metatomic requires remap_pairs to be true");
}
auto* kk_list = dynamic_cast<NeighListKokkos<DeviceType>*>(list);
assert(kk_list != nullptr);
this->setup_neighbors_kk(system, kk_list);

return system;
}
Expand Down
3 changes: 1 addition & 2 deletions src/KOKKOS/metatomic_system_kokkos.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,12 +82,11 @@ class MetatomicSystemAdaptorKokkos : public MetatomicSystemAdaptor {
metatomic_torch::System system_from_lmp(
NeighList* list,
bool do_virial,
bool remap_pairs,
torch::ScalarType dtype,
torch::Device device
) override;

void setup_neighbors_remap_kk(metatomic_torch::System& system, NeighListKokkos<DeviceType>* list);
void setup_neighbors_kk(metatomic_torch::System& system, NeighListKokkos<DeviceType>* list);

private:
/// Torch device corresponding to the kokkos `DeviceType`
Expand Down
Loading