|
| 1 | +"""Generate JSON schema from a given dargs.Argument.""" |
| 2 | + |
| 3 | +from __future__ import annotations |
| 4 | + |
| 5 | +from typing import Any |
| 6 | + |
| 7 | +from dargs.dargs import Argument, _Flags |
| 8 | + |
| 9 | +try: |
| 10 | + from typing import get_origin |
| 11 | +except ImportError: |
| 12 | + from typing_extensions import get_origin |
| 13 | + |
| 14 | + |
| 15 | +def generate_json_schema(argument: Argument, id: str = "") -> dict: |
| 16 | + """Generate JSON schema from a given dargs.Argument. |
| 17 | +
|
| 18 | + Parameters |
| 19 | + ---------- |
| 20 | + argument : Argument |
| 21 | + The argument to generate JSON schema. |
| 22 | + id : str, optional |
| 23 | + The URL of the schema, by default "". |
| 24 | +
|
| 25 | + Returns |
| 26 | + ------- |
| 27 | + dict |
| 28 | + The JSON schema. Use :func:`json.dump` to save it to a file |
| 29 | + or :func:`json.dumps` to get a string. |
| 30 | +
|
| 31 | + Examples |
| 32 | + -------- |
| 33 | + Dump the JSON schema of DeePMD-kit to a file: |
| 34 | +
|
| 35 | + >>> from dargs.json_schema import generate_json_schema |
| 36 | + >>> from deepmd.utils.argcheck import gen_args |
| 37 | + >>> import json |
| 38 | + >>> from dargs import Argument |
| 39 | + >>> a = Argument("DeePMD-kit", dtype=dict, sub_fields=gen_args()) |
| 40 | + >>> schema = generate_json_schema(a) |
| 41 | + >>> with open("deepmd.json", "w") as f: |
| 42 | + ... json.dump(schema, f, indent=2) |
| 43 | + """ |
| 44 | + schema = { |
| 45 | + "$schema": "https://json-schema.org/draft/2020-12/schema", |
| 46 | + "$id": id, |
| 47 | + "title": argument.name, |
| 48 | + **_convert_single_argument(argument), |
| 49 | + } |
| 50 | + return schema |
| 51 | + |
| 52 | + |
| 53 | +def _convert_single_argument(argument: Argument) -> dict: |
| 54 | + """Convert a single argument to JSON schema. |
| 55 | +
|
| 56 | + Parameters |
| 57 | + ---------- |
| 58 | + argument : Argument |
| 59 | + The argument to convert. |
| 60 | +
|
| 61 | + Returns |
| 62 | + ------- |
| 63 | + dict |
| 64 | + The JSON schema of the argument. |
| 65 | + """ |
| 66 | + data = { |
| 67 | + "description": argument.doc, |
| 68 | + "type": list({_convert_types(tt) for tt in argument.dtype}), |
| 69 | + } |
| 70 | + if argument.default is not _Flags.NONE: |
| 71 | + data["default"] = argument.default |
| 72 | + properties = { |
| 73 | + **{ |
| 74 | + nn: _convert_single_argument(aa) |
| 75 | + for aa in argument.sub_fields.values() |
| 76 | + for nn in (aa.name, *aa.alias) |
| 77 | + }, |
| 78 | + **{ |
| 79 | + vv.flag_name: { |
| 80 | + "type": "string", |
| 81 | + "enum": list(vv.choice_dict.keys()) + list(vv.choice_alias.keys()), |
| 82 | + "default": vv.default_tag, |
| 83 | + "description": vv.doc, |
| 84 | + } |
| 85 | + for vv in argument.sub_variants.values() |
| 86 | + }, |
| 87 | + } |
| 88 | + required = [ |
| 89 | + aa.name |
| 90 | + for aa in argument.sub_fields.values() |
| 91 | + if not aa.optional and not aa.alias |
| 92 | + ] + [vv.flag_name for vv in argument.sub_variants.values() if not vv.optional] |
| 93 | + allof = [ |
| 94 | + { |
| 95 | + "if": { |
| 96 | + "oneOf": [ |
| 97 | + { |
| 98 | + "properties": {vv.flag_name: {"const": kkaa}}, |
| 99 | + } |
| 100 | + for kkaa in (kk, *aa.alias) |
| 101 | + ], |
| 102 | + "required": [vv.flag_name] |
| 103 | + if not (vv.optional and vv.default_tag == kk) |
| 104 | + else [], |
| 105 | + }, |
| 106 | + "then": _convert_single_argument(aa), |
| 107 | + } |
| 108 | + for vv in argument.sub_variants.values() |
| 109 | + for kk, aa in vv.choice_dict.items() |
| 110 | + ] |
| 111 | + allof += [ |
| 112 | + {"oneOf": [{"required": [nn]} for nn in (aa.name, *aa.alias)]} |
| 113 | + for aa in argument.sub_fields.values() |
| 114 | + if not aa.optional and aa.alias |
| 115 | + ] |
| 116 | + if not argument.repeat: |
| 117 | + data["properties"] = properties |
| 118 | + data["required"] = required |
| 119 | + if allof: |
| 120 | + data["allOf"] = allof |
| 121 | + else: |
| 122 | + data["items"] = { |
| 123 | + "type": "object", |
| 124 | + "properties": properties, |
| 125 | + "required": required, |
| 126 | + } |
| 127 | + if allof: |
| 128 | + data["items"]["allOf"] = allof |
| 129 | + return data |
| 130 | + |
| 131 | + |
| 132 | +def _convert_types(T: type | Any | None) -> str: |
| 133 | + """Convert a type to JSON schema type. |
| 134 | +
|
| 135 | + Parameters |
| 136 | + ---------- |
| 137 | + T : type | Any | None |
| 138 | + The type to convert. |
| 139 | +
|
| 140 | + Returns |
| 141 | + ------- |
| 142 | + str |
| 143 | + The JSON schema type. |
| 144 | + """ |
| 145 | + # string, number, integer, object, array, boolean, null |
| 146 | + if T is None or T is type(None): |
| 147 | + return "null" |
| 148 | + elif T is str: |
| 149 | + return "string" |
| 150 | + elif T in (int, float): |
| 151 | + return "number" |
| 152 | + elif T is bool: |
| 153 | + return "boolean" |
| 154 | + elif T is list or get_origin(T) is list: |
| 155 | + return "array" |
| 156 | + elif T is dict or get_origin(T) is dict: |
| 157 | + return "object" |
| 158 | + raise ValueError(f"Unknown type: {T}") |
0 commit comments