Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Allow custom templates for API and endpoint __init__ files. #442

Merged
merged 11 commits into from
Jun 28, 2021
1 change: 1 addition & 0 deletions end_to_end_tests/custom-templates-golden-record/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
my-test-api-client
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
""" Contains methods for accessing the API """

from typing import Type

from my_test_api_client.api.default import DefaultEndpoints
from my_test_api_client.api.parameters import ParametersEndpoints
from my_test_api_client.api.tests import TestsEndpoints


class MyTestApiClientApi:
@classmethod
def tests(cls) -> Type[TestsEndpoints]:
return TestsEndpoints

@classmethod
def default(cls) -> Type[DefaultEndpoints]:
return DefaultEndpoints

@classmethod
def parameters(cls) -> Type[ParametersEndpoints]:
return ParametersEndpoints
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
""" Contains methods for accessing the API Endpoints """

import types

from my_test_api_client.api.default import get_common_parameters, post_common_parameters


class DefaultEndpoints:
@classmethod
def get_common_parameters(cls) -> types.ModuleType:
return get_common_parameters

@classmethod
def post_common_parameters(cls) -> types.ModuleType:
return post_common_parameters
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
""" Contains methods for accessing the API Endpoints """

import types

from my_test_api_client.api.parameters import get_same_name_multiple_locations_param


class ParametersEndpoints:
@classmethod
def get_same_name_multiple_locations_param(cls) -> types.ModuleType:
return get_same_name_multiple_locations_param
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
""" Contains methods for accessing the API Endpoints """

import types

from my_test_api_client.api.tests import (
defaults_tests_defaults_post,
get_basic_list_of_booleans,
get_basic_list_of_floats,
get_basic_list_of_integers,
get_basic_list_of_strings,
get_user_list,
int_enum_tests_int_enum_post,
json_body_tests_json_body_post,
no_response_tests_no_response_get,
octet_stream_tests_octet_stream_get,
optional_value_tests_optional_query_param,
post_form_data,
test_inline_objects,
token_with_cookie_auth_token_with_cookie_get,
unsupported_content_tests_unsupported_content_get,
upload_file_tests_upload_post,
)


class TestsEndpoints:
@classmethod
def get_user_list(cls) -> types.ModuleType:
"""
Get a list of things
"""
return get_user_list

@classmethod
def get_basic_list_of_strings(cls) -> types.ModuleType:
"""
Get a list of strings
"""
return get_basic_list_of_strings

@classmethod
def get_basic_list_of_integers(cls) -> types.ModuleType:
"""
Get a list of integers
"""
return get_basic_list_of_integers

@classmethod
def get_basic_list_of_floats(cls) -> types.ModuleType:
"""
Get a list of floats
"""
return get_basic_list_of_floats

@classmethod
def get_basic_list_of_booleans(cls) -> types.ModuleType:
"""
Get a list of booleans
"""
return get_basic_list_of_booleans

@classmethod
def post_form_data(cls) -> types.ModuleType:
"""
Post form data
"""
return post_form_data

@classmethod
def upload_file_tests_upload_post(cls) -> types.ModuleType:
"""
Upload a file
"""
return upload_file_tests_upload_post

@classmethod
def json_body_tests_json_body_post(cls) -> types.ModuleType:
"""
Try sending a JSON body
"""
return json_body_tests_json_body_post

@classmethod
def defaults_tests_defaults_post(cls) -> types.ModuleType:
"""
Defaults
"""
return defaults_tests_defaults_post

@classmethod
def octet_stream_tests_octet_stream_get(cls) -> types.ModuleType:
"""
Octet Stream
"""
return octet_stream_tests_octet_stream_get

@classmethod
def no_response_tests_no_response_get(cls) -> types.ModuleType:
"""
No Response
"""
return no_response_tests_no_response_get

@classmethod
def unsupported_content_tests_unsupported_content_get(cls) -> types.ModuleType:
"""
Unsupported Content
"""
return unsupported_content_tests_unsupported_content_get

@classmethod
def int_enum_tests_int_enum_post(cls) -> types.ModuleType:
"""
Int Enum
"""
return int_enum_tests_int_enum_post

@classmethod
def test_inline_objects(cls) -> types.ModuleType:
"""
Test Inline Objects
"""
return test_inline_objects

@classmethod
def optional_value_tests_optional_query_param(cls) -> types.ModuleType:
"""
Test optional query parameters
"""
return optional_value_tests_optional_query_param

@classmethod
def token_with_cookie_auth_token_with_cookie_get(cls) -> types.ModuleType:
"""
Test optional cookie parameters
"""
return token_with_cookie_auth_token_with_cookie_get
55 changes: 54 additions & 1 deletion end_to_end_tests/regen_golden_record.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
""" Regenerate golden-record """
import filecmp
import os
import shutil
import tempfile
from pathlib import Path

from typer.testing import CliRunner

from openapi_python_client.cli import app

if __name__ == "__main__":

def regen_golden_record():
runner = CliRunner()
openapi_path = Path(__file__).parent / "openapi.json"

Expand All @@ -24,3 +28,52 @@
if result.exception:
raise result.exception
output_path.rename(gr_path)


def regen_custom_template_golden_record():
runner = CliRunner()
openapi_path = Path(__file__).parent / "openapi.json"
tpl_dir = Path(__file__).parent / "test_custom_templates"

gr_path = Path(__file__).parent / "golden-record"
tpl_gr_path = Path(__file__).parent / "custom-templates-golden-record"

output_path = Path(tempfile.mkdtemp())
config_path = Path(__file__).parent / "config.yml"

shutil.rmtree(tpl_gr_path, ignore_errors=True)

os.chdir(str(output_path.absolute()))
result = runner.invoke(
app, ["generate", f"--config={config_path}", f"--path={openapi_path}", f"--custom-template-path={tpl_dir}"]
)

if result.stdout:
generated_output_path = output_path / "my-test-api-client"
for f in generated_output_path.glob("**/*"): # nb: works for Windows and Unix
relative_to_generated = f.relative_to(generated_output_path)
gr_file = gr_path / relative_to_generated
if not gr_file.exists():
print(f"{gr_file} does not exist, ignoring")
continue

if not gr_file.is_file():
continue

if not filecmp.cmp(gr_file, f, shallow=False):
target_file = tpl_gr_path / relative_to_generated
target_dir = target_file.parent

target_dir.mkdir(parents=True, exist_ok=True)
shutil.copy(f"{f}", f"{target_file}")

shutil.rmtree(output_path, ignore_errors=True)

if result.exception:
shutil.rmtree(output_path, ignore_errors=True)
raise result.exception


if __name__ == "__main__":
regen_golden_record()
regen_custom_template_golden_record()
13 changes: 13 additions & 0 deletions end_to_end_tests/test_custom_templates/api_init.py.jinja
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
""" Contains methods for accessing the API """

from typing import Type
{% for tag in endpoint_collections_by_tag.keys() %}
from {{ package_name }}.api.{{ tag }} import {{ utils.pascal_case(tag) }}Endpoints
{% endfor %}

class {{ utils.pascal_case(package_name) }}Api:
{% for tag in endpoint_collections_by_tag.keys() %}
@classmethod
def {{ tag }}(cls) -> Type[{{ utils.pascal_case(tag) }}Endpoints]:
return {{ utils.pascal_case(tag) }}Endpoints
{% endfor %}
24 changes: 24 additions & 0 deletions end_to_end_tests/test_custom_templates/endpoint_init.py.jinja
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
""" Contains methods for accessing the API Endpoints """

import types
{% for endpoint in endpoint_collection.endpoints %}
from {{ package_name }}.api.{{ endpoint_collection.tag }} import {{ utils.snake_case(endpoint.name) }}
{% endfor %}

class {{ utils.pascal_case(endpoint_collection.tag) }}Endpoints:

{% for endpoint in endpoint_collection.endpoints %}

@classmethod
def {{ utils.snake_case(endpoint.name) }}(cls) -> types.ModuleType:
{% if endpoint.description %}
"""
{{ endpoint.description }}
"""
{% elif endpoint.summary %}
"""
{{ endpoint.summary }}
"""
{% endif %}
return {{ utils.snake_case(endpoint.name) }}
{% endfor %}
Loading