Skip to content

Commit

Permalink
Merge pull request #24 from nasaharvest/path_updates
Browse files Browse the repository at this point in the history
Allow a custom model path
  • Loading branch information
gabrieltseng committed Oct 18, 2023
2 parents 445d7ca + 1801342 commit 182d590
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions presto/presto.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import math
from copy import deepcopy
from pathlib import Path
from typing import Optional, Sized, Tuple, Union, cast

import numpy as np
Expand Down Expand Up @@ -785,7 +786,7 @@ def construct_finetuning_model(
return model

@classmethod
def load_pretrained(cls):
def load_pretrained(cls, model_path: Union[str, Path] = default_model_path):
model = cls.construct()
model.load_state_dict(torch.load(default_model_path, map_location=device))
model.load_state_dict(torch.load(model_path, map_location=device))
return model

0 comments on commit 182d590

Please sign in to comment.