Skip to content

Commit

Permalink
Merge pull request #6 from IzaakWN/master
Browse files Browse the repository at this point in the history
Rebase & fix bug in JSON encoder
  • Loading branch information
IzaakWN committed Mar 26, 2021
2 parents b98d669 + c7cf136 commit 4dc8790
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 10 deletions.
5 changes: 3 additions & 2 deletions src/correctionlib/JSONEncoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,12 @@ def encode(self, obj: Any) -> str:
else: # break long list into multiple lines
nlines = math.ceil(len(obj) / float(self.maxlistlen))
maxlen = int(len(obj) / nlines)
for i in range(0, nlines):
for i in range(0, nlines + 1):
line = []
for item in obj[i * maxlen : (i + 1) * maxlen]:
line.append(json.dumps(item))
output.append(", ".join(line))
if line:
output.append(", ".join(line))
if not retval:
lines = (",\n" + indent_str).join(output) # lines between brackets
if (
Expand Down
18 changes: 16 additions & 2 deletions src/correctionlib/schemav2.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Any, List, Optional, Union

from pydantic import BaseModel, Field, validator
from pydantic import BaseModel, Field, StrictInt, StrictStr, validator

try:
from typing import Literal # type: ignore
Expand Down Expand Up @@ -158,7 +158,7 @@ class CategoryItem(Model):
The key type must match the type of the Category input variable
"""

key: Union[int, str]
key: Union[StrictInt, StrictStr]
value: Content


Expand All @@ -172,6 +172,20 @@ class Category(Model):
content: List[CategoryItem]
default: Optional[Content]

@validator("content")
def validate_content(cls, content: List[CategoryItem]) -> List[CategoryItem]:
if len(content):
keytype = type(content[0].key)
if not all(isinstance(item.key, keytype) for item in content):
raise ValueError(
f"Keys in the Category node do not have a homogenous type, expected all {keytype}"
)

keys = {item.key for item in content}
if len(keys) != len(content):
raise ValueError("Duplicate keys detected in Category node")
return content


Transform.update_forward_refs()
Binning.update_forward_refs()
Expand Down
38 changes: 38 additions & 0 deletions tests/test_issue058.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import pytest

from correctionlib import schemav2 as schema


def test_categoryitem_coercion():
x = schema.CategoryItem(key="30", value=1.0)
assert x.key == "30"
x = schema.CategoryItem(key=30, value=1.0)
assert x.key == 30
with pytest.raises(ValueError):
x = schema.CategoryItem(key=30.0, value=1.0)
with pytest.raises(ValueError):
x = schema.CategoryItem(key=b"30", value=1.0)
x = schema.CategoryItem(key="30xyz", value=1.0)
assert x.key == "30xyz"


def test_cat_valid():
with pytest.raises(ValueError):
schema.Category(
nodetype="category",
input="x",
content=[
schema.CategoryItem(key="30xyz", value=1.0),
schema.CategoryItem(key=30, value=1.0),
],
)

with pytest.raises(ValueError):
schema.Category(
nodetype="category",
input="x",
content=[
schema.CategoryItem(key="30xyz", value=1.0),
schema.CategoryItem(key="30xyz", value=1.0),
],
)
23 changes: 17 additions & 6 deletions tests/test_jsonencoder.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from correctionlib.JSONEncoder import dumps
from correctionlib.JSONEncoder import dumps, json


def test_jsonencode():
Expand Down Expand Up @@ -125,6 +125,8 @@ def test_jsonencode():
breakbrackets=False,
)

retrieved = json.loads(formatted)

expected = """\
{
"layer1": {
Expand Down Expand Up @@ -168,11 +170,13 @@ def test_jsonencode():
"n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z"
],
[ "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", "n",
"o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z", "1", "2"
"o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z", "1", "2",
"3"
],
[ "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", "n", "o", "p", "q",
"r", "s", "t", "u", "v", "w", "x", "y", "z", "a", "b", "c", "d", "e", "f", "g", "h",
"i", "j", "k", "l", "m", "n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y"
"i", "j", "k", "l", "m", "n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y",
"z"
],
[ "this is short", "very short" ],
[ "this is medium long", "verily, can you see?" ],
Expand Down Expand Up @@ -206,7 +210,8 @@ def test_jsonencode():
14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26
],
[ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26
14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26,
27
],
[ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30
Expand All @@ -223,7 +228,8 @@ def test_jsonencode():
],
[ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,
18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34,
35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51
35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51,
52
]
],
"layer3_6": [ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20 ],
Expand All @@ -244,4 +250,9 @@ def test_jsonencode():
}
}
}"""
assert formatted == expected, f"Found:\n {formatted}"
assert (
formatted == expected
), f"Formatted does not match expected:\nExpected: {expected}\nFormatted: {formatted}"
assert (
retrieved == data
), f"Data before and after encoding do not match:\nBefore: {data}\nFormatted: {formatted}"

0 comments on commit 4dc8790

Please sign in to comment.