Skip to content

Commit

Permalink
Support reading .ckpt files
Browse files Browse the repository at this point in the history
  • Loading branch information
joeyballentine committed Sep 8, 2023
1 parent d621298 commit e2e606e
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 3 deletions.
2 changes: 1 addition & 1 deletion backend/src/nodes/properties/inputs/file_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def PthFileInput(primary_input: bool = False) -> FileInput:
input_type_name="PthFile",
label="Model",
file_kind="pth",
filetypes=[".pt", ".pth"],
filetypes=[".pt", ".pth", ".ckpt"],
primary_input=primary_input,
)

Expand Down
24 changes: 22 additions & 2 deletions backend/src/packages/chaiNNer_pytorch/pytorch/io/load_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,17 +65,37 @@ def load_model_node(path: str) -> Tuple[PyTorchModel, str, str]:

try:
logger.debug(f"Reading state dict from path: {path}")
extension = os.path.splitext(path)[1].lower()

if os.path.splitext(path)[1].lower() == ".pt":
if extension == ".pt":
state_dict = torch.jit.load( # type: ignore
path, map_location=pytorch_device
).state_dict()
else:
elif extension == ".pth":
state_dict = torch.load(
path,
map_location=pytorch_device,
pickle_module=RestrictedUnpickle, # type: ignore
)
elif extension == ".ckpt":
checkpoint = torch.load(
path,
map_location=pytorch_device,
pickle_module=RestrictedUnpickle, # type: ignore
)
if "state_dict" in checkpoint:
state_dict = {}
for i, j in checkpoint["state_dict"].items():
if "netG." in i:
key = i.replace("netG.", "")
state_dict[key] = j
else:
# Assume it's a state dict, might as well
state_dict = checkpoint
else:
raise ValueError(
f"Unsupported model file extension {extension}. Please try a supported model type."
)

model = load_state_dict(state_dict)

Expand Down

0 comments on commit e2e606e

Please sign in to comment.