Skip to content

Commit

Permalink
enable allegro/nequip to work with float32 and float64
Browse files Browse the repository at this point in the history
tests for nequip, allegro double prec.

include torch allow_tf32

included torch freeze model

updated energy of nequip water test with double prec.

nequip/allegro: update parsing input and torch api
  • Loading branch information
gabriele16 authored and oschuett committed May 25, 2023
1 parent a11bcb3 commit 3a7d654
Show file tree
Hide file tree
Showing 17 changed files with 874 additions and 84 deletions.
Binary file added data/Allegro/water-gra-film-double.pth
Binary file not shown.
Binary file added data/NequIP/water-double.pth
Binary file not shown.
101 changes: 85 additions & 16 deletions src/force_fields_input.F
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,8 @@ MODULE force_fields_input
USE shell_potential_types, ONLY: shell_p_create,&
shell_p_type
USE string_utilities, ONLY: uppercase
USE torch_api, ONLY: torch_model_read_metadata
USE torch_api, ONLY: torch_allow_tf32,&
torch_model_read_metadata
#include "./base/base_uses.f90"

IMPLICIT NONE
Expand Down Expand Up @@ -707,7 +708,7 @@ END SUBROUTINE read_quip_section
!> \param nonbonded ...
!> \param section ...
!> \param start ...
!> \author teo
!> \author Gabriele Tocci
! **************************************************************************************************
SUBROUTINE read_nequip_section(nonbonded, section, start)
TYPE(pair_potential_p_type), POINTER :: nonbonded
Expand Down Expand Up @@ -755,7 +756,7 @@ END SUBROUTINE read_nequip_section
!> \param nonbonded ...
!> \param section ...
!> \param start ...
!> \author teo
!> \author Gabriele Tocci
! **************************************************************************************************
SUBROUTINE read_allegro_section(nonbonded, section, start)
TYPE(pair_potential_p_type), POINTER :: nonbonded
Expand Down Expand Up @@ -2371,15 +2372,16 @@ END SUBROUTINE read_eam_data
! **************************************************************************************************
!> \brief reads NequIP potential from .pth file
!> \param nequip ...
!> \author Gabriele Tocci
! **************************************************************************************************
SUBROUTINE read_nequip_data(nequip)
TYPE(nequip_pot_type), POINTER :: nequip
CHARACTER(len=*), PARAMETER :: routineN = 'read_nequip_data'
CHARACTER(LEN=default_path_length) :: cutoff_str
CHARACTER(LEN=default_path_length) :: allow_tf32_str, config_str, cutoff_str
INTEGER :: handle
LOGICAL :: found
LOGICAL :: allow_tf32, found
CALL timeset(routineN, handle)
Expand All @@ -2395,30 +2397,40 @@ SUBROUTINE read_nequip_data(nequip)
READ (cutoff_str, *) nequip%rcutsq
nequip%rcutsq = cp_unit_to_cp2k(nequip%rcutsq, nequip%unit_coords)
nequip%rcutsq = nequip%rcutsq*nequip%rcutsq
nequip%unit_coords_val = 1.0_dp
nequip%unit_coords_val = cp_unit_to_cp2k(nequip%unit_coords_val, nequip%unit_coords)
nequip%unit_forces_val = 1.0_dp
nequip%unit_forces_val = cp_unit_to_cp2k(nequip%unit_forces_val, nequip%unit_forces)
nequip%unit_energy_val = 1.0_dp
nequip%unit_energy_val = cp_unit_to_cp2k(nequip%unit_energy_val, nequip%unit_energy)
nequip%unit_cell_val = 1.0_dp
nequip%unit_cell_val = cp_unit_to_cp2k(nequip%unit_cell_val, nequip%unit_cell)
! look in config which contains all the .yaml file options to see if we use float32 or float64
config_str = torch_model_read_metadata(nequip%nequip_file_name, "config")
CALL read_default_dtype(config_str, nequip%do_nequip_sp)
allow_tf32_str = torch_model_read_metadata(nequip%nequip_file_name, "allow_tf32")
allow_tf32 = (TRIM(allow_tf32_str) == "1")
IF (TRIM(allow_tf32_str) /= "1" .AND. TRIM(allow_tf32_str) /= "0") THEN
CALL cp_abort(__LOCATION__, &
"The value for allow_tf32 <"//TRIM(allow_tf32_str)// &
"> is not supported. Check the .yaml and .pth files.")
END IF
CALL torch_allow_tf32(allow_tf32)
CALL timestop(handle)
END SUBROUTINE read_nequip_data
! **************************************************************************************************
!> \brief reads ALLEGRO potential from .pth file
!> \param allegro ...
!> \author Gabriele Tocci
! **************************************************************************************************
SUBROUTINE read_allegro_data(allegro)
TYPE(allegro_pot_type), POINTER :: allegro
CHARACTER(len=*), PARAMETER :: routineN = 'read_allegro_data'
CHARACTER(LEN=default_path_length) :: cutoff_str
CHARACTER(LEN=default_path_length) :: allow_tf32_str, config_str, cutoff_str
INTEGER :: handle
LOGICAL :: found
LOGICAL :: allow_tf32, found
CALL timeset(routineN, handle)
Expand All @@ -2429,23 +2441,80 @@ SUBROUTINE read_allegro_data(allegro)
"> not found.")
END IF
allegro%allegro_version = torch_model_read_metadata(allegro%allegro_file_name, "nequip_version")
allegro%nequip_version = torch_model_read_metadata(allegro%allegro_file_name, "nequip_version")
IF (allegro%nequip_version == "") THEN
CALL cp_abort(__LOCATION__, &
"Allegro model file <"//TRIM(allegro%allegro_file_name)// &
"> has not been deployed; did you forget to run `nequip-deploy`?")
END IF
cutoff_str = torch_model_read_metadata(allegro%allegro_file_name, "r_max")
READ (cutoff_str, *) allegro%rcutsq
allegro%rcutsq = cp_unit_to_cp2k(allegro%rcutsq, allegro%unit_coords)
allegro%rcutsq = allegro%rcutsq*allegro%rcutsq
allegro%unit_coords_val = 1.0_dp
allegro%unit_coords_val = cp_unit_to_cp2k(allegro%unit_coords_val, allegro%unit_coords)
allegro%unit_forces_val = 1.0_dp
allegro%unit_forces_val = cp_unit_to_cp2k(allegro%unit_forces_val, allegro%unit_forces)
allegro%unit_energy_val = 1.0_dp
allegro%unit_energy_val = cp_unit_to_cp2k(allegro%unit_energy_val, allegro%unit_energy)
allegro%unit_cell_val = 1.0_dp
allegro%unit_cell_val = cp_unit_to_cp2k(allegro%unit_cell_val, allegro%unit_cell)
! look in config which contains all the .yaml file options to see if we use float32 or float64
config_str = torch_model_read_metadata(allegro%allegro_file_name, "config")
CALL read_default_dtype(config_str, allegro%do_allegro_sp)
allow_tf32_str = torch_model_read_metadata(allegro%allegro_file_name, "allow_tf32")
allow_tf32 = (TRIM(allow_tf32_str) == "1")
IF (TRIM(allow_tf32_str) /= "1" .AND. TRIM(allow_tf32_str) /= "0") THEN
CALL cp_abort(__LOCATION__, &
"The value for allow_tf32 <"//TRIM(allow_tf32_str)// &
"> is not supported. Check the .yaml and .pth files.")
END IF
CALL torch_allow_tf32(allow_tf32)
CALL timestop(handle)
END SUBROUTINE read_allegro_data
! **************************************************************************************************
!> \brief reads the default_dtype used in the Allegro/NequIP model by parsing the config file
!> \param config_str ...
!> \param do_model_sp ...
!> \author Gabriele Tocci
! **************************************************************************************************
SUBROUTINE read_default_dtype(config_str, do_model_sp)
CHARACTER(LEN=default_path_length) :: config_str
LOGICAL :: do_model_sp
CHARACTER(len=*), PARAMETER :: routineN = 'read_default_dtype'
INTEGER :: handle, i, idx, len_config
CALL timeset(routineN, handle)
len_config = LEN_TRIM(config_str)
idx = INDEX(config_str, "default_dtype:")
IF (idx /= 0) THEN
i = idx + 14 ! skip over "default_dtype:"
DO WHILE (i <= len_config .AND. config_str(i:i) == " ")
i = i + 1 ! skip over any whitespace
END DO
IF (i > len_config) THEN
CALL cp_abort(__LOCATION__, &
"No default_dtype found, check the Nequip/Allegro .yaml or .pth files."// &
" Default_dtype should be either <float32> or <float64>.")
ELSE IF (config_str(i:i + 6) == "float32") THEN
do_model_sp = .TRUE.
ELSE IF (config_str(i:i + 6) == "float64") THEN
do_model_sp = .FALSE.
ELSE
CALL cp_abort(__LOCATION__, &
"The default_dtype should be either <float32> or <float64>."// &
" Check the NequIP/Allegro .yaml and .pth files.")
END IF
END IF
CALL timestop(handle)
END SUBROUTINE read_default_dtype
! **************************************************************************************************
!> \brief reads TABPOT potential from file
!> \param tab ...
Expand Down
8 changes: 6 additions & 2 deletions src/input_cp2k_mm.F
Original file line number Diff line number Diff line change
Expand Up @@ -1481,7 +1481,9 @@ SUBROUTINE create_NEQUIP_section(section)
CALL keyword_create(keyword, __LOCATION__, name="ATOMS", &
description="Defines the atomic kinds involved in the NEQUIP potential. "// &
"Provide a list of each element.", &
"Provide a list of each element, making sure that the mapping from the ATOMS list "// &
"to NequIP atom types is correct. This mapping should also be consistent for the "// &
"atomic coordinates as specified in the sections COORDS or TOPOLOGY.", &
usage="ATOMS {KIND 1} {KIND 2} .. {KIND N}", type_of_var=char_t, &
n_var=-1)
CALL section_add_keyword(section, keyword)
Expand Down Expand Up @@ -1542,7 +1544,9 @@ SUBROUTINE create_ALLEGRO_section(section)
CALL keyword_create(keyword, __LOCATION__, name="ATOMS", &
description="Defines the atomic kinds involved in the ALLEGRO potential. "// &
"Provide a list of each element.", &
"Provide a list of each element, making sure that the mapping from the ATOMS list "// &
"to NequIP atom types is correct. This mapping should also be consistent for the "// &
"atomic coordinates as specified in the sections COORDS or TOPOLOGY.", &
usage="ATOMS {KIND 1} {KIND 2} .. {KIND N}", type_of_var=char_t, &
n_var=-1)
CALL section_add_keyword(section, keyword)
Expand Down
76 changes: 51 additions & 25 deletions src/manybody_allegro.F
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ MODULE manybody_allegro
torch_dict_release,&
torch_dict_type,&
torch_model_eval,&
torch_model_freeze,&
torch_model_load
USE util, ONLY: sort
#include "./base/base_uses.f90"
Expand Down Expand Up @@ -244,18 +245,21 @@ SUBROUTINE allegro_energy_store_force_virial(nonbonded, particle_set, cell, atom
INTEGER, DIMENSION(:, :), POINTER :: list, sort_list
LOGICAL, ALLOCATABLE :: use_atom(:)
REAL(kind=dp) :: drij, lattice(3, 3), rab2_max, rij(3)
REAL(KIND=dp), DIMENSION(3) :: cell_v, cvi
REAL(kind=sp), ALLOCATABLE :: pos(:, :)
REAL(kind=sp), ALLOCATABLE, DIMENSION(:, :) :: edge_cell_shifts, new_edge_cell_shifts
REAL(sp), DIMENSION(:, :), POINTER :: atomic_energy, forces
REAL(kind=dp), ALLOCATABLE, DIMENSION(:, :) :: edge_cell_shifts, new_edge_cell_shifts, &
pos
REAL(kind=dp), DIMENSION(3) :: cell_v, cvi
REAL(kind=dp), DIMENSION(:, :), POINTER :: atomic_energy, forces
REAL(kind=sp) :: lattice_sp(3, 3)
REAL(kind=sp), ALLOCATABLE, DIMENSION(:, :) :: new_edge_cell_shifts_sp, pos_sp
REAL(kind=sp), DIMENSION(:, :), POINTER :: atomic_energy_sp, forces_sp
TYPE(allegro_data_type), POINTER :: allegro_data
TYPE(neighbor_kind_pairs_type), POINTER :: neighbor_kind_pair
TYPE(pair_potential_single_type), POINTER :: pot
TYPE(torch_dict_type) :: inputs, outputs

CALL timeset(routineN, handle)

NULLIFY (atomic_energy, forces)
NULLIFY (atomic_energy, forces, atomic_energy_sp, forces_sp)
n_atoms = SIZE(particle_set)
ALLOCATE (use_atom(n_atoms))
use_atom = .FALSE.
Expand All @@ -265,7 +269,6 @@ SUBROUTINE allegro_energy_store_force_virial(nonbonded, particle_set, cell, atom
pot => potparm%pot(ikind, jkind)%pot
DO i = 1, SIZE(pot%type)
IF (pot%type(i) /= allegro_type) CYCLE
IF (.NOT. ASSOCIATED(allegro)) allegro => pot%set(i)%allegro
DO iat = 1, n_atoms
IF (particle_set(iat)%atomic_kind%kind_number == ikind .OR. &
particle_set(iat)%atomic_kind%kind_number == jkind) use_atom(iat) = .TRUE.
Expand All @@ -282,6 +285,7 @@ SUBROUTINE allegro_energy_store_force_virial(nonbonded, particle_set, cell, atom
CALL fist_nonbond_env_set(fist_nonbond_env, allegro_data=allegro_data)
NULLIFY (allegro_data%use_indices, allegro_data%force)
CALL torch_model_load(allegro_data%model, pot%set(1)%allegro%allegro_file_name)
CALL torch_model_freeze(allegro_data%model)
END IF
IF (ASSOCIATED(allegro_data%force)) THEN
IF (SIZE(allegro_data%force, 2) /= n_atoms_use) THEN
Expand Down Expand Up @@ -366,7 +370,7 @@ SUBROUTINE allegro_energy_store_force_virial(nonbonded, particle_set, cell, atom
IF (drij <= rab2_max) THEN
nedges = nedges + 1
edge_index(:, nedges) = [atom_a - 1, atom_b - 1]
edge_cell_shifts(:, nedges) = REAL(cvi, kind=sp)
edge_cell_shifts(:, nedges) = cvi
END IF
END DO
ifirst = ilast + 1
Expand All @@ -378,17 +382,21 @@ SUBROUTINE allegro_energy_store_force_virial(nonbonded, particle_set, cell, atom
END DO Kind_Group_Loop_Allegro
END DO

allegro => pot%set(1)%allegro

ALLOCATE (temp_edge_index(2, nedges))
temp_edge_index(:, :) = edge_index(:, :nedges)
ALLOCATE (new_edge_cell_shifts(3, nedges))
new_edge_cell_shifts(:, :) = edge_cell_shifts(:, :nedges)
DEALLOCATE (edge_cell_shifts)

ALLOCATE (t_edge_index(nedges, 2))

t_edge_index(:, :) = TRANSPOSE(temp_edge_index)
DEALLOCATE (temp_edge_index, edge_index)

lattice = cell%hmat/pot%set(1)%allegro%unit_cell_val
lattice_sp = REAL(lattice, kind=sp)

iat_use = 0
ALLOCATE (pos(3, n_atoms_use), atom_types(n_atoms_use))
Expand All @@ -397,38 +405,56 @@ SUBROUTINE allegro_energy_store_force_virial(nonbonded, particle_set, cell, atom
IF (.NOT. use_atom(iat)) CYCLE
iat_use = iat_use + 1
atom_types(iat_use) = particle_set(iat)%atomic_kind%kind_number - 1
pos(:, iat) = REAL(r_last_update_pbc(iat)%r(:)/pot%set(1)%allegro%unit_coords_val, kind=sp)
pos(:, iat) = r_last_update_pbc(iat)%r(:)/allegro%unit_coords_val
END DO

CALL torch_dict_create(inputs)
CALL torch_dict_insert(inputs, "pos", pos)

IF (allegro%do_allegro_sp) THEN
ALLOCATE (new_edge_cell_shifts_sp(3, nedges), pos_sp(3, n_atoms_use))
new_edge_cell_shifts_sp(:, :) = REAL(new_edge_cell_shifts(:, :), kind=sp)
pos_sp(:, :) = REAL(pos(:, :), kind=sp)
DEALLOCATE (pos, new_edge_cell_shifts)
CALL torch_dict_insert(inputs, "pos", pos_sp)
CALL torch_dict_insert(inputs, "edge_cell_shift", new_edge_cell_shifts_sp)
CALL torch_dict_insert(inputs, "cell", lattice_sp)
ELSE
CALL torch_dict_insert(inputs, "pos", pos)
CALL torch_dict_insert(inputs, "edge_cell_shift", new_edge_cell_shifts)
CALL torch_dict_insert(inputs, "cell", lattice)
END IF

CALL torch_dict_insert(inputs, "edge_index", t_edge_index)
CALL torch_dict_insert(inputs, "edge_cell_shift", new_edge_cell_shifts)
CALL torch_dict_insert(inputs, "cell", REAL(lattice, kind=sp))
CALL torch_dict_insert(inputs, "atom_types", atom_types)

CALL torch_dict_create(outputs)

CALL torch_model_eval(allegro_data%model, inputs, outputs)

CALL torch_dict_get(outputs, "atomic_energy", atomic_energy)
CALL torch_dict_get(outputs, "forces", forces)

allegro_data%force(:, :) = REAL(forces(:, :), kind=dp)*allegro%unit_forces_val

pot_allegro = 0.0_dp

DO iat_use = 1, SIZE(unique_list_a)
i = unique_list_a(iat_use)
pot_allegro = pot_allegro + REAL(atomic_energy(1, i), kind=dp)*allegro%unit_energy_val
END DO
IF (allegro%do_allegro_sp) THEN
CALL torch_dict_get(outputs, "atomic_energy", atomic_energy_sp)
CALL torch_dict_get(outputs, "forces", forces_sp)
allegro_data%force(:, :) = REAL(forces_sp(:, :), kind=dp)*allegro%unit_forces_val
DO iat_use = 1, SIZE(unique_list_a)
i = unique_list_a(iat_use)
pot_allegro = pot_allegro + REAL(atomic_energy_sp(1, i), kind=dp)*allegro%unit_energy_val
END DO
DEALLOCATE (forces_sp, atomic_energy_sp, new_edge_cell_shifts_sp, pos_sp)
ELSE
CALL torch_dict_get(outputs, "atomic_energy", atomic_energy)
CALL torch_dict_get(outputs, "forces", forces)
allegro_data%force(:, :) = forces(:, :)*allegro%unit_forces_val
DO iat_use = 1, SIZE(unique_list_a)
i = unique_list_a(iat_use)
pot_allegro = pot_allegro + atomic_energy(1, i)*allegro%unit_energy_val
END DO
DEALLOCATE (forces, atomic_energy, pos, new_edge_cell_shifts)
END IF

CALL torch_dict_release(inputs)
CALL torch_dict_release(outputs)

DEALLOCATE (pos, t_edge_index, edge_cell_shifts, atom_types, atomic_energy)
DEALLOCATE (new_edge_cell_shifts)
DEALLOCATE (forces)
DEALLOCATE (t_edge_index, atom_types)

CALL timestop(handle)
END SUBROUTINE allegro_energy_store_force_virial
Expand Down

0 comments on commit 3a7d654

Please sign in to comment.