Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/diffusers/configuration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import os
import re
from collections import OrderedDict
from pathlib import PosixPath
from pathlib import Path
from typing import Any, Dict, Tuple, Union

import numpy as np
Expand Down Expand Up @@ -587,8 +587,8 @@ def to_json_string(self) -> str:
def to_json_saveable(value):
if isinstance(value, np.ndarray):
value = value.tolist()
elif isinstance(value, PosixPath):
value = str(value)
elif isinstance(value, Path):
value = value.as_posix()
return value

config_dict = {k: to_json_saveable(v) for k, v in config_dict.items()}
Expand Down
18 changes: 18 additions & 0 deletions tests/others/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import tempfile
import unittest
from pathlib import Path

from diffusers import (
DDIMScheduler,
Expand Down Expand Up @@ -91,6 +93,14 @@ def __init__(
pass


class SampleObjectPaths(ConfigMixin):
config_name = "config.json"

@register_to_config
def __init__(self, test_file_1=Path("foo/bar"), test_file_2=Path("foo bar\\bar")):
pass


class ConfigTester(unittest.TestCase):
def test_load_not_from_mixin(self):
with self.assertRaises(ValueError):
Expand Down Expand Up @@ -286,3 +296,11 @@ def test_use_default_values(self):

# Nevertheless "e" should still be correctly loaded to [1, 3] from SampleObject2 instead of defaulting to [1, 5]
assert new_config_2.config.e == [1, 3]

def test_check_path_types(self):
# Verify that we get a string returned from a WindowsPath or PosixPath (depending on system)
config = SampleObjectPaths()
json_string = config.to_json_string()
result = json.loads(json_string)
assert result["test_file_1"] == config.config.test_file_1.as_posix()
assert result["test_file_2"] == config.config.test_file_2.as_posix()