Skip to content

Commit

Permalink
Update JSON schema. (#5982)
Browse files Browse the repository at this point in the history
* Update JSON schema for pseudo huber.
* Update JSON model schema.
  • Loading branch information
trivialfis committed Aug 5, 2020
1 parent 9c93531 commit 8599f87
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 0 deletions.
46 changes: 46 additions & 0 deletions doc/model.schema
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,17 @@
}
}
},
"aft_loss_param": {
"type": "object",
"properties": {
"aft_loss_distribution": {
"type": "string"
},
"aft_loss_distribution_scale": {
"type": "string"
}
}
},
"softmax_multiclass_param": {
"type": "object",
"properties": {
Expand Down Expand Up @@ -273,6 +284,17 @@
"reg_loss_param"
]
},
{
"type": "object",
"properties": {
"name": { "const": "reg:pseudohubererror" },
"reg_loss_param": { "$ref": "#/definitions/reg_loss_param"}
},
"required": [
"name",
"reg_loss_param"
]
},
{
"type": "object",
"properties": {
Expand All @@ -284,6 +306,17 @@
"reg_loss_param"
]
},
{
"type": "object",
"properties": {
"name": { "const": "reg:linear" },
"reg_loss_param": { "$ref": "#/definitions/reg_loss_param"}
},
"required": [
"name",
"reg_loss_param"
]
},
{
"type": "object",
"properties": {
Expand Down Expand Up @@ -420,6 +453,19 @@
"name",
"lambda_rank_param"
]
},
{
"type": "object",
"properties": {
"name": {"const": "survival:aft"},
"aft_loss_param": { "$ref": "#/definitions/aft_loss_param"}
}
},
{
"type": "object",
"properties": {
"name": {"const": "binary:hinge"}
}
}
]
},
Expand Down
19 changes: 19 additions & 0 deletions tests/python/test_basic_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,25 @@ def test_json_io_schema(self):
schema=schema)
os.remove(model_path)

try:
xgb.train({'objective': 'foo'}, dtrain, num_boost_round=1)
except ValueError as e:
e_str = str(e)
beg = e_str.find('Objective candidate')
end = e_str.find('Stack trace')
e_str = e_str[beg: end]
e_str = e_str.strip()
splited = e_str.splitlines()
objectives = [s.split(': ')[1] for s in splited]
j_objectives = schema['properties']['learner']['properties'][
'objective']['oneOf']
objectives_from_schema = set()
for j_obj in j_objectives:
objectives_from_schema.add(
j_obj['properties']['name']['const'])
objectives = set(objectives)
assert objectives == objectives_from_schema

@pytest.mark.skipif(**tm.no_json_schema())
def test_json_dump_schema(self):
import jsonschema
Expand Down

0 comments on commit 8599f87

Please sign in to comment.