Skip to content

Commit

Permalink
Edit interface of model saver
Browse files Browse the repository at this point in the history
  • Loading branch information
Parzival-05 committed Jun 3, 2024
1 parent f49bb6c commit e93c530
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions AIAgent/run_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,20 +69,22 @@ def next(self):
self.total_epochs += 1


def model_saver(epochs_info: EpochsInfo, svm_name: str, dir: Path, model: object):
def model_saver(
epochs_info: EpochsInfo, svm_name: str, dir: Path, model: torch.nn.Module
):
"""
Use it to save your torch model with some info in the following format: "{`total_epochs` + `epoch`}_{`svm_name`}" in `dir` directory
Parameters
----------
:param EpochsInfo `epochs_info`: EpochsInfo's instance
:param str `svm_name`: name of svm
:param object `model`: model to save
:param torch.nn.Module `model`: model to save
:param Path `dir`: directory
"""
path_to_model = os.path.join(dir, f"{epochs_info.total_epochs}_{svm_name}")
torch.save(model, Path(path_to_model))
torch.save(model.state_dict(), Path(path_to_model))


def run_training(
Expand Down Expand Up @@ -228,7 +230,7 @@ def objective(
criterion=criterion,
)
torch.cuda.empty_cache()
model_saver(model=model.state_dict())
model_saver(model=model)

model.eval()
dataset.switch_to("val")
Expand Down

0 comments on commit e93c530

Please sign in to comment.