In [331]:
from stringcase import pascalcase, snakecase, camelcase
from json import loads, load
from copy import deepcopy
from dataclasses import dataclass
import re

In [255]:
# Load the Swagger (OpenAPI 2.0) description that drives all code generation below.
# JSON is UTF-8 by specification, so decode it explicitly instead of relying on
# the platform's default encoding.
with open("../swagger.json", encoding="utf-8") as f:
    swagger = load(f)

In [256]:
# Inspect the top-level sections of the loaded Swagger document.
swagger.keys()

dict_keys(['swagger', 'info', 'basePath', 'paths', 'tags', 'consumes', 'host', 'produces', 'definitions', 'securityDefinitions', 'security'])

In [197]:
from enum import Enum
from typing import NamedTuple, List, Dict, Optional, Tuple, Any

In [483]:
# Endpoint path -> HTTP methods (lower-case, as they appear as keys under
# swagger["paths"][endpoint]) for which the generated `Request` impl gets
# `SIGNED = true`. Any (endpoint, method) pair not listed here is emitted
# with `SIGNED = false`.
SIGNED_ENDPOINTS = {
    "/announcement/urgent": ["get"],
    "/apiKey": ["get"],
    "/chat": ["post"],
    "/execution": ["get"],
    "/execution/tradeHistory": ["get"],
    "/globalNotification": ["get"],
    "/leaderboard/name": ["get"],
    "/order": ["get", "put", "post", "delete"],
    "/order/bulk": ["put", "post"],
    "/order/closePosition": ["post"],
    "/order/all": ["delete"],
    "/order/cancelAllAfter": ["post"],
    "/position": ["get"],
    "/position/isolate": ["post"],
    "/position/riskLimit": ["post"],
    "/position/transferMargin": ["post"],
    "/position/leverage": ["post"],
    "/user": ["get"],
    "/user/affiliateStatus": ["get"],
    "/user/commission": ["get"],
    "/user/communicationToken": ["post"],
    "/user/executionHistory": ["get"],
    "/user/depositAddress": ["get"],
    "/user/margin": ["get"],
    "/user/preferences": ["post"],
    "/user/quoteFillRatio": ["get"],
    "/user/requestWithdrawal": ["post"],
    "/user/wallet": ["get"],
    "/user/walletHistory": ["get"],
    "/user/walletSummary": ["get"],
    "/userEvent": ["get"],
}

# Swagger `format` of a "string"-typed value -> Rust type name.
# The `None` key is the lookup result for strings that carry no `format` at all.
string_formats = {
    "date-time": "DateTime<Utc>",
    "guid": "Uuid",
    None: "String",
    "JSON": "Value",
}

# Swagger `format` of a "number"-typed value -> Rust primitive type.
# NOTE: unlike `string_formats` there is no `None` entry, so a number without a
# `format` raises KeyError when it is looked up in FieldDef.from_swagger.
number_formats = {
    "int64": "i64",
    "int32": "i32",
    "double": "f64"
}

# Definitions

In [386]:
def rustify(s: str) -> str:
    """Turn a Swagger property/parameter name into a Rust field identifier.

    "ID" is first rewritten to "Id" and the result is passed through
    `snakecase`. The name `type` is a Rust keyword and is therefore returned
    as the raw identifier `r#type`.
    """
    snake = snakecase(s.replace("ID", "Id"))
    return "r#type" if snake == "type" else snake

In [516]:
@dataclass
class FieldDef:
    """One field of a generated Rust struct, built from a Swagger type description.

    `str(field)` renders the Rust source for the field: an optional `///` doc
    line, an optional `#[serde(...)]` attribute and the `pub name: Type` declaration.
    """

    # (rust_identifier, original_swagger_name); a serde `rename` is emitted when they differ.
    name: Tuple[str, str]
    # Rust type, without the `Option<...>` wrapper.
    ty: str
    # When True the rendered type is wrapped in `Option<...>`.
    optional: bool
    # Extra serde attributes; a value of None renders the bare key (e.g. `default`).
    modifiers: Dict[str, Optional[str]]
    # Swagger description, rendered as a single-line doc comment.
    desc: Optional[str] = None

    @staticmethod
    def _declared_optional(tydesc: Dict[str, Any]) -> bool:
        """Return True when `tydesc` carries an explicit, falsy `required` entry."""
        return "required" in tydesc and not tydesc["required"]

    @classmethod
    def from_swagger(cls, parent_name: str, name: str, tydesc: Dict[str, Any]) -> "Tuple[FieldDef, List[StructDef]]":
        """Build a field (plus any nested struct definitions) from a Swagger description.

        Args:
            parent_name: Name prefix handed to `StructDef.from_swagger` for nested objects.
            name: Swagger name of the property or parameter.
            tydesc: The Swagger schema/parameter dict; it must contain either
                `type` or `$ref`.

        Returns:
            `(field, structs)` where `structs` lists the struct definitions that
            had to be generated for nested `object` types (empty otherwise).

        Raises:
            RuntimeError: for a `type` value this generator does not handle.
            NotImplementedError: for a non-empty array default, a `$ref` that
                does not point into `#/definitions/`, or a description with
                neither `type` nor `$ref`.

        NOTE(review): `assert_keys` and `provided_enums` are not defined in the
        visible cells of this notebook; they must already exist in the kernel
        namespace when this method runs.
        """
        common_keys = ["required", "type", "default", "description"]
        if "type" in tydesc:
            ty = tydesc["type"]
            optional = cls._declared_optional(tydesc)
            desc = tydesc.get("description")

            if ty == "string":
                assert_keys(name, tydesc, *common_keys, "enum", "maxLength", "format")

                if name in provided_enums:
                    rty = f"super::public::{pascalcase(name)}"
                else:
                    rty = string_formats[tydesc.get("format")]

                return FieldDef((rustify(name), name), rty, optional, {}, desc), []

            elif ty == "number":
                assert_keys(name, tydesc, *common_keys, "format", "minimum")

                modifiers = {}
                # A zero default equals Rust's Default for numbers, so serde's
                # bare `default` attribute is enough.
                if "default" in tydesc and tydesc["default"] == 0:
                    modifiers["default"] = None

                rty = number_formats[tydesc.get("format")]

                return FieldDef((rustify(name), name), rty, optional, modifiers, desc), []

            elif ty == "boolean":
                assert_keys(name, tydesc, *common_keys)

                modifiers = {}
                # A falsy default equals Rust's Default for bool.
                if "default" in tydesc and not tydesc["default"]:
                    modifiers["default"] = None

                return FieldDef((rustify(name), name), "bool", optional, modifiers, desc), []

            elif ty == "object":
                assert_keys(name, tydesc, *common_keys, "properties")
                sdfs = StructDef.from_swagger(parent_name, name, tydesc)

                # An object without generated structs is kept as an untyped value;
                # otherwise the field refers to the first (outermost) struct.
                rty = sdfs[0].name if sdfs else "Value"
                return FieldDef((rustify(name), name), rty, optional, {}, desc), sdfs

            elif ty == "array":
                assert_keys(name, tydesc, *common_keys, "items")

                # The element description decides the inner type; optionality and
                # description are inherited from it as well.
                fdf, sdfs = FieldDef.from_swagger(parent_name, name, tydesc["items"])
                fdf.ty = f"Vec<{fdf.ty}>"

                if "default" not in tydesc or tydesc["default"] == []:
                    fdf.modifiers["default"] = None
                else:
                    raise NotImplementedError(tydesc)

                return fdf, sdfs

            elif ty == "null":
                return FieldDef((rustify(name), name), "()", False, {}, None), []
            else:
                raise RuntimeError(f"Unimplemented for {ty}")

        elif "$ref" in tydesc:
            assert_keys(name, tydesc, "$ref", *common_keys)
            ref = tydesc["$ref"]
            prefix = "#/definitions/"

            if ref.startswith(prefix):
                # Slice the prefix off. str.lstrip() would treat the prefix as a
                # *set of characters* and could also eat leading letters of the
                # definition name itself (e.g. "index" -> "x").
                ty = ref[len(prefix):]
                if ty == "x-any":
                    ty = "Value"
                return FieldDef((rustify(name), name), ty, False, {}, None), []
            else:
                raise NotImplementedError(f"Unsupported $ref target: {ref}")
        else:
            raise NotImplementedError(f"{name}, {tydesc}")

    def __str__(self) -> str:
        """Render the field as Rust source (doc comment, serde attribute, declaration)."""
        ty = f"Option<{self.ty}>" if self.optional else self.ty

        mods = []
        if self.name[0] != self.name[1]:
            mods.append(f"rename = \"{self.name[1]}\"")

        # Use this field's own modifiers (not those of some unrelated global).
        for mod, modv in self.modifiers.items():
            if modv is None:
                mods.append(mod)
            else:
                mods.append(f"{mod} = \"{modv}\"")

        serde_header = ""
        if mods:
            serde_header = f"#[serde({', '.join(mods)})] "
        field = serde_header + f"pub {self.name[0]}: {ty}"

        if self.desc:
            # Doc comments are emitted on a single line.
            desc = self.desc.replace("\n", " ")
            field = f"/// {desc}\n" + field

        return field
    
@dataclass
class StructDef:
    """A generated Rust struct: its name, fields and optional description.

    `str(struct)` renders the Rust source including the `#[derive(...)]` line.
    """

    name: str
    # Quoted so the class body does not need FieldDef to be bound while it executes.
    fields: "List[FieldDef]"
    # Not read by the visible code; always constructed as an empty list here.
    modifiers: List[Tuple[str, str]]
    desc: Optional[str] = None

    @classmethod
    def from_swagger(cls, parent_name: str, name: str, defs: Dict[str, Any]) -> "List[StructDef]":
        """Build struct definitions from a Swagger `object` schema.

        Args:
            parent_name: Prefix prepended to the generated struct name.
            name: Swagger name of the object; pascal-cased into the struct name.
            defs: The schema dict; `defs["type"]` must be "object".

        Returns:
            An empty list when the schema has no `properties` or is the
            `x-any` placeholder. Otherwise the struct for this object first,
            followed by every struct generated for nested objects.
        """
        assert defs["type"] == "object"

        if "properties" not in defs or name == "x-any":
            return []

        desc = defs.get("description")
        requires = set(defs.get("required", ()))

        sub_sdfs = []
        fdfs = []
        for subname, def_ in defs["properties"].items():
            fdf, sdf = FieldDef.from_swagger(f"{parent_name}{pascalcase(subname)}", subname, def_)
            sub_sdfs.extend(sdf)

            # Anything not listed as required and without a serde default must
            # tolerate being absent, hence Option<...>.
            if fdf.name[1] not in requires and "default" not in fdf.modifiers:
                fdf.optional = True

            fdfs.append(fdf)

        return [StructDef(f"{parent_name}{pascalcase(name)}", fdfs, [], desc), *sub_sdfs]

    def __str__(self) -> str:
        """Render the struct as Rust source."""
        if self.desc:
            flat = self.desc.replace("\n", " ")
            desc = f"\n/// {flat}"
        else:
            desc = ""

        fields = ",\n    ".join(str(fdef) for fdef in self.fields)

        derive_traits = ["Clone", "Debug", "Deserialize", "Serialize"]
        # `Default` can only be derived when every field has a default value
        # (Option<...> or an explicit serde default).
        if all(fdf.optional or "default" in fdf.modifiers for fdf in self.fields):
            derive_traits.append("Default")
        derive_traits = ", ".join(derive_traits)

        code = f"""
#[derive({derive_traits})]{desc}
pub struct {self.name} {{
    {fields}
}}
    """
        return code
            
class EnumDef(NamedTuple):
    """A string enum taken from a Swagger schema."""

    name: str
    # Capitalized variant names, index-aligned with `originals`.
    variants: List[str]
    # Variant strings exactly as they appear in the Swagger `enum` list.
    originals: List[str]
    desc: Optional[str] = None

    @classmethod
    def from_swagger(cls, name: str, defs: Dict[str, Any]) -> "EnumDef":
        """Build an enum definition from a Swagger string-enum schema.

        Args:
            name: Name to give the enum.
            defs: Schema dict; it must consist of exactly `type` (equal to
                "string") and `enum`.

        Returns:
            An `EnumDef` whose `variants` are the `str.capitalize()`d enum
            values and whose `originals` are the untouched values. An empty
            `enum` list yields two empty lists.
        """
        assert defs["type"] == "string"
        # "type" is necessarily present (it was just read), so it has to be part
        # of the accepted key set; any additional key is unsupported.
        assert set(defs.keys()) == {"type", "enum"}

        originals = list(defs["enum"])
        variants = [variant.capitalize() for variant in originals]

        return cls(name, variants, originals, None)

In [517]:
# Flat list of every struct generated from swagger["definitions"]; each
# definition contributes its own struct followed by its nested ones.
defs_ = []

for name, defs in swagger["definitions"].items():
    defs_ += StructDef.from_swagger("", name, defs)

In [518]:
# Write the generated model structs. Rust source files must be UTF-8, so the
# encoding is fixed explicitly rather than taken from the platform default
# (Swagger descriptions may contain non-ASCII characters).
with open("../src/models/definitions.rs", "w", encoding="utf-8") as f:
    f.write("""use chrono::{DateTime, Utc};
use serde_json::Value;
use uuid::Uuid;
use serde::{Deserialize, Serialize};
""")
    f.write("\n".join([str(d) for d in defs_]))


# Paths

In [519]:
# One Rust snippet per (endpoint, method): the request struct, any response
# structs, and the `impl Request` block tying them together.
codes = []

for endpoint, operations in swagger["paths"].items():
    for method, op in operations.items():
        # ---- Request ----
        umethod = method.upper()
        cmethod = method.capitalize()

        struct_name = pascalcase(endpoint.lstrip("/").replace("/", "_"))
        struct_name = struct_name.replace("_L2", "L2")

        # The summary ends up behind a single `///`, so it must stay on one
        # line (same treatment as FieldDef / StructDef descriptions).
        desc = op.get("summary", "No description").replace("\n", " ")

        signed = "false"
        if method in SIGNED_ENDPOINTS.get(endpoint, []):
            signed = "true"

        # Operations without a `parameters` entry simply have no payload.
        parameters = op.get("parameters", [])
        has_payload = "true" if parameters else "false"

        fields = []
        for param in parameters:
            # Work on a copy: the Swagger document itself must stay untouched.
            tydesc = deepcopy(param)
            tydesc.pop("in")
            name = tydesc.pop("name")

            fdf, sdf = FieldDef.from_swagger("", name, tydesc)
            # Request parameters are expected to be flat (no nested structs).
            assert len(sdf) == 0

            fields.append(fdf)

        derive_traits = ["Clone", "Debug", "Deserialize", "Serialize"]
        # `Default` is derivable only when every field is optional or defaulted
        # (trivially true for parameterless requests).
        if all(fdf.optional or "default" in fdf.modifiers for fdf in fields):
            derive_traits.append("Default")
        derive_traits = ", ".join(derive_traits)

        fields = ",\n".join([str(field) for field in fields])

        req_struct_name = f"{cmethod}{struct_name}Request"

        if has_payload == "true":
            req_struct_header = f"""pub struct {req_struct_name} {{
{fields}
}}"""
        else:
            # No parameters: emit a unit struct.
            req_struct_header = f"""pub struct {req_struct_name};"""

        req_def = f"""#[derive({derive_traits})]
/// {desc}
{req_struct_header}
"""
        # ---- Response ----
        resp_struct_name = f"{cmethod}{struct_name}Response"

        resp_defs = op["responses"]["200"]["schema"]

        fdf, sdfs = FieldDef.from_swagger("", resp_struct_name, resp_defs)

        resp_def = "\n".join([str(sdf) for sdf in sdfs])

        # When no dedicated response struct was generated, the response type is
        # whatever the schema resolved to (a $ref'd definition, Vec<...>, Value, ...).
        if not sdfs:
            resp_struct_name = f"{fdf.ty}"

        code = f"""{req_def}
{resp_def}

impl Request for {req_struct_name} {{
    const METHOD: Method = Method::{umethod};
    const SIGNED: bool = {signed};
    const ENDPOINT: &'static str = "{endpoint}";
    const HAS_PAYLOAD: bool = {has_payload};
    type Response = {resp_struct_name};
}}
        """
        codes.append(code)

# Write the generated request/response code. Rust source files must be UTF-8,
# so the encoding is fixed explicitly rather than taken from the platform default.
with open("../src/models/requests.rs", "w", encoding="utf-8") as f:
    f.write("""use http::Method;
use super::Request;
use super::definitions::*;
use serde_json::Value;
use serde::{Deserialize, Serialize};
use chrono::{DateTime, Utc};
""")
    f.write("\n".join(codes))