Skip to content

Commit

Permalink
fix: remove model_def_script from AtomicModel (#3449)
Browse files Browse the repository at this point in the history
After #3438, `model_def_script` is no more saved in AtomicModel.

---------

Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
  • Loading branch information
njzjz committed Mar 12, 2024
1 parent bc35ac9 commit 9bcae14
Show file tree
Hide file tree
Showing 9 changed files with 14 additions and 19 deletions.
4 changes: 0 additions & 4 deletions deepmd/dpmodel/atomic_model/make_base_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,10 +178,6 @@ def do_grad_(self, var_name: str, base: str) -> bool:
return self.fitting_output_def()[var_name].c_differentiable
return self.fitting_output_def()[var_name].r_differentiable

def get_model_def_script(self) -> str:
# TODO: implement this method; saved to model
raise NotImplementedError

setattr(BAM, fwd_method_name, BAM.fwd)
delattr(BAM, "fwd")

Expand Down
7 changes: 7 additions & 0 deletions deepmd/dpmodel/model/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,3 +172,10 @@ class BaseModel(make_base_model()):
deepmd.dpmodel.model.base_model.BaseBaseModel
Backend-independent BaseModel class.
"""

def __init__(self) -> None:
self.model_def_script = ""

def get_model_def_script(self) -> str:
"""Get the model definition script."""
return self.model_def_script
5 changes: 1 addition & 4 deletions deepmd/dpmodel/model/make_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def __init__(
atomic_model_: Optional[T_AtomicModel] = None,
**kwargs,
):
BaseModel.__init__(self)
if atomic_model_ is not None:
self.atomic_model: T_AtomicModel = atomic_model_
else:
Expand Down Expand Up @@ -452,10 +453,6 @@ def get_nnei(self) -> int:
"""Returns the total number of selected neighboring atoms in the cut-off radius."""
return self.atomic_model.get_nnei()

def get_model_def_script(self) -> str:
"""Get the model definition script."""
return self.atomic_model.get_model_def_script()

def get_sel(self) -> List[int]:
"""Returns the number of selected atoms for each type."""
return self.atomic_model.get_sel()
Expand Down
3 changes: 0 additions & 3 deletions deepmd/pt/model/atomic_model/base_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,6 @@ def reinit_pair_exclude(
else:
self.pair_excl = PairExcludeMask(self.get_ntypes(), self.pair_exclude_types)

def get_model_def_script(self) -> str:
return self.model_def_script

def atomic_output_def(self) -> FittingOutputDef:
old_def = self.fitting_output_def()
if self.atom_excl is None:
Expand Down
1 change: 0 additions & 1 deletion deepmd/pt/model/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ def __init__(
**kwargs,
):
torch.nn.Module.__init__(self)
self.model_def_script = ""
ntypes = len(type_map)
self.type_map = type_map
self.ntypes = ntypes
Expand Down
1 change: 0 additions & 1 deletion deepmd/pt/model/atomic_model/linear_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,6 @@ def __init__(
):
models = [dp_model, zbl_model]
super().__init__(models, type_map, **kwargs)
self.model_def_script = ""

self.sw_rmin = sw_rmin
self.sw_rmax = sw_rmax
Expand Down
1 change: 0 additions & 1 deletion deepmd/pt/model/atomic_model/pairtab_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,6 @@ def __init__(
**kwargs,
):
torch.nn.Module.__init__(self)
self.model_def_script = ""
self.tab_file = tab_file
self.rcut = rcut
self.tab = self._set_pairtab(tab_file, rcut)
Expand Down
5 changes: 0 additions & 5 deletions deepmd/pt/model/model/make_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,11 +469,6 @@ def get_nnei(self) -> int:
"""Returns the total number of selected neighboring atoms in the cut-off radius."""
return self.atomic_model.get_nnei()

@torch.jit.export
def get_model_def_script(self) -> str:
"""Get the model definition script."""
return self.atomic_model.get_model_def_script()

def atomic_output_def(self) -> FittingOutputDef:
"""Get the output def of the atomic model."""
return self.atomic_model.atomic_output_def()
Expand Down
6 changes: 6 additions & 0 deletions deepmd/pt/model/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ class BaseModel(torch.nn.Module, make_base_model()):
def __init__(self, *args, **kwargs):
"""Construct a basic model for different tasks."""
torch.nn.Module.__init__(self)
self.model_def_script = ""

def compute_or_load_stat(
self,
Expand All @@ -39,3 +40,8 @@ def compute_or_load_stat(
The path to the statistics files.
"""
raise NotImplementedError

@torch.jit.export
def get_model_def_script(self) -> str:
"""Get the model definition script."""
return self.model_def_script

0 comments on commit 9bcae14

Please sign in to comment.