# Code Generation

In [None]:
#| default_exp codegen

In [None]:
# | export
from sal.loaders import xml_to_data
from pathlib import Path
from sal.core import Data, iter_data, render, FrontMatter
from jinja2 import Environment, BaseLoader, Template
from typing import Optional, Any
import abc
from jinja2 import StrictUndefined
from textwrap import dedent
from yaml.parser import ParserError
from black import format_str, FileMode

In [None]:
# | hide
import tempfile
from fastcore.test import *
from jinja2 import UndefinedError
import nbdev.showdoc as showdoc

## What does code generation mean with sal?

In its basic form, it combines XML files converted to `Data` structures with jinja templates to render code. Later we will also introduce some frontmatter.

For this, we need a basic structure to generate code from. As an example, we'll be working with a hypothetical "model":

In [None]:
# Example input: a <model> element with three <field> children, parsed from
# XML into a `Data` tree. It is reused by the examples below.
struct: Data = xml_to_data(
    """
    <model name="User">
        <field name="id" type="integer"/>
        <field name="username" type="char"/>
        <field name="email" type="email"/>
    </model>
    """
)

...and the basic templates used with this structure are:

In [None]:
# Template for <model> nodes: emits the class header, then passes each child
# through the `render` filter on its own indented line.
model = (
    "class {{ name }}Model(models.Model):\n"
    "    {%- for child in children %}\n"
    "    {{ child | render }}\n"
    "    {%- endfor %}\n"
)

# Template for <field> nodes: a single `<name> = models.<Type>Field()` line,
# with the `type` attribute title-cased by the `title` filter.
field = "{{ name }} = models.{{ type | title }}Field()"

## Template rendering with jinja2

Let's wrap the rendering function from `core` in a usable class

In [None]:
# | export
class JinjaTemplateRenderer:
    """Object-style wrapper around the module-level `render` helper."""

    def render(self, template=None, **kwargs) -> str:
        """Render `template`, forwarding `kwargs` unchanged to `render`.

        Raises:
            RuntimeError: when no template is supplied.
        """
        if template is not None:
            return render(template, **kwargs)
        raise RuntimeError("Missing template")

## Template loading

We will need a way to get the templates

In [None]:
# | export
class TemplateLoader(abc.ABC):
    """Abstract interface for objects that hand out template sources by name."""

    @abc.abstractmethod
    def get_template(self, name: str) -> str:
        """Return the template registered under `name`.

        Kept as a separate method so that a subclass can override the
        template before it is returned.
        """
        raise NotImplementedError()

class MissingTemplate(Exception):
    """Raised when a template is requested under a name that is not known."""

    def __init__(self, name):
        # Keep the requested name on the exception so handlers can read it.
        self.name = name
        message = f"The template {name} is missing"
        super().__init__(message)

In [None]:
#| export
class InMemoryTemplateLoader(TemplateLoader):
    """Template loader backed by a mapping of template name -> template source."""

    def __init__(self, *args, templates=None, **kwargs):
        """Store the `templates` mapping.

        Args:
            templates: mapping of template name to template source. When
                omitted, the loader starts empty.
        """
        super().__init__(*args, **kwargs)
        # Leaving this as None made every lookup fail with an AttributeError
        # instead of the documented MissingTemplate.
        self.templates = {} if templates is None else templates

    def get_template(self, name: str) -> str:
        """Return the template stored under `name`.

        Raises:
            MissingTemplate: when `name` is not a key of `self.templates`.
        """
        if name not in self.templates:
            raise MissingTemplate(name)
        return self.templates[name]

    @classmethod
    def from_directory(cls, directory):
        """Build a loader from every `*.jinja2` file directly inside `directory`.

        The template name is the file name with the trailing `.jinja2`
        removed. Files are read as UTF-8 so the result does not depend on the
        platform's default encoding.
        """
        suffix = ".jinja2"
        templates_raw = {
            # Slice off only the trailing suffix; `str.replace` would also
            # drop any ".jinja2" occurring earlier in the file name.
            path.name[: -len(suffix)]: path.read_text(encoding="utf-8")
            for path in Path(directory).glob(f"*{suffix}")
        }
        return cls(templates=templates_raw)

## Tying rendering and loading together

And finally, put these two together to form a class that renders a `Data` instance

In [None]:
# | export
class Renderer:
    """Render a `Data` node by combining a template repository with a template renderer."""

    # Pass-through template: renders every child of a node, in order, and
    # emits nothing for the node itself. It is NOT applied automatically;
    # callers pass it explicitly as `template` for nodes that have no
    # template of their own.
    DEFAULT_TEMPLATE = "{% for child in children %}{{ child | render }}{% endfor %}"

    def __init__(
        self,
        *,
        renderer: Optional[JinjaTemplateRenderer] = None,
        repository: Optional[TemplateLoader] = None,
        filters=None
    ):
        """
        Args:
            renderer: object whose `render(template=..., **context)` produces
                the output. Defaults to a `JinjaTemplateRenderer`.
            repository: source of templates, looked up by node name. May be
                omitted only if every `render` call passes a template.
            filters: extra filters handed to the renderer, as a dict.
        """
        # A missing renderer used to be stored as None and only blew up later
        # with "'NoneType' object has no attribute 'render'".
        self.renderer = JinjaTemplateRenderer() if renderer is None else renderer
        self.repository = repository
        self.filters = filters or {}

    def render(self, data: Data, template: Optional[str] = None) -> str:
        """Render `data` with `template`, or with the repository template for `data.name`.

        The template context is made of the node attributes (as top-level
        variables), `node` (the node itself) and `children`. This method is
        also passed on as the `render` filter, so a template can render its
        children recursively; it replaces any user filter with that name.

        Raises:
            RuntimeError: when no template is given and no repository is set.
        """
        if template is None:
            if self.repository is None:
                raise RuntimeError(
                    f"No template given for {data.name!r} and no repository configured"
                )
            template = self.repository.get_template(data.name)

        return self.renderer.render(
            template=template,
            **data.attrs,
            filters={**self.filters, "render": self.render},
            node=data,
            children=data.children,
        )

    def process(self, data: Data) -> str:
        """Entry point: render `data` using its repository template."""
        return self.render(data)

The entry point for this class, after `__init__`, is the `process` method

## Code generator I (jinja only)

Now that we can render `jinja2`, we can make a basic code generator

In [None]:
# | export
class SalBasic:
    """Jinja-only code generator that dispatches a `Data` node on its name.

    Node names with a dedicated action:
      * ``to-file``: render the children and write the result to the path in
        the ``to`` attribute.
      * ``black``: render the children and format the result with black.
      * ``wrapper``: process each child separately and return the list of
        results.
    Any other node is rendered through the renderer's `process` method.
    """

    def __init__(self, renderer: Optional[Renderer] = None):
        self.renderer = Renderer() if renderer is None else renderer

    def pre_process_data(self, data: Data):
        """Hook run by `process` before dispatching; returns `data` unchanged here."""
        return data

    def action_default(self, data: Data):
        """Render `data` with the template registered for its name."""
        return self.renderer.process(data)

    def action_to_file(self, data: Data):
        """Render the children of `data`, write them to `data.attrs["to"]` and return the text.

        Raises:
            KeyError: when the node has no ``to`` attribute.
        """
        # Read the target first so a missing attribute fails before any
        # rendering work is done.
        to = data.attrs["to"]
        rendered = self.renderer.render(data, template=Renderer.DEFAULT_TEMPLATE)
        # Explicit encoding keeps the written file independent of the locale.
        with open(to, "w", encoding="utf-8") as h:
            h.write(rendered)
        return rendered

    def action_black(self, data: Data):
        """Render the children of `data` and return the black-formatted result."""
        rendered = self.renderer.render(data, template=Renderer.DEFAULT_TEMPLATE)
        return format_str(rendered, mode=FileMode())

    def process_data(self, data: Data):
        """Run the action matching `data.name`.

        Returns a string, except for ``wrapper`` nodes, which return a list
        holding one result per child.
        """
        if data.name == "to-file":
            return self.action_to_file(data)
        if data.name == "black":
            return self.action_black(data)
        if data.name == "wrapper":
            return [self.process(child) for child in data.children]
        return self.action_default(data)

    def process(self, data: Data):
        """Entry point: pre-process `data`, then dispatch it."""
        data = self.pre_process_data(data)
        return self.process_data(data)

It's important to note that a parent should be able to trigger the rendering of its children (this ensures the recursive nature of the template rendering). Look at the `model` template for an example:

In [None]:
# Same two templates as at the top of the notebook, repeated next to the
# explanation: the parent template renders each child via the `render` filter.
model = (
    "class {{ name }}Model(models.Model):\n"
    "    {%- for child in children %}\n"
    "    {{ child | render }}\n"
    "    {%- endfor %}\n"
)

# Leaf template: it has no children to render.
field = "{{ name }} = models.{{ type | title }}Field()"

In [None]:
# | hide
# Register the templates under the node names they apply to.
repository = InMemoryTemplateLoader(
    templates={
        "model": model,
        "field": field,
    }
)
# Renderer that looks templates up in `repository` and renders them with jinja.
template_renderer = Renderer(repository=repository, renderer=JinjaTemplateRenderer())

With this, here's a basic jinja2-based code generator using the hard coded templates:

In [None]:
# Process a clone of the example struct and show the generated code.
sal = SalBasic(template_renderer)
print(sal.process(struct.clone()))

class UserModel(models.Model):
    id = models.IntegerField()
    username = models.CharField()
    email = models.EmailField()


In [None]:
# | hide
# Regression check: the generated code, ignoring surrounding whitespace, must
# match the expected class definition.
sal = SalBasic(template_renderer)
test_eq(
    sal.process(struct.clone()).strip(),
    """
class UserModel(models.Model):
    id = models.IntegerField()
    username = models.CharField()
    email = models.EmailField()
""".strip(),
)

**todo: document to-file**

**todo: document black**

**todo: document wrapper**

In [None]:
# | hide
# Same model as before, nested inside a <wrapper> node.
struct2: Data = xml_to_data(
    """
<wrapper>
    <model name="User">
        <field name="id" type="integer"/>
        <field name="username" type="char"/>
        <field name="email" type="email"/>
    </model>
</wrapper>"""
)

# A wrapper yields one result per child, so the model output is element 0.
sal = SalBasic(template_renderer)
test_eq(
    sal.process(struct2.clone())[0].strip(),
    """
class UserModel(models.Model):
    id = models.IntegerField()
    username = models.CharField()
    email = models.EmailField()
""".strip(),
)

We are missing one more thing: we need to be able to save the result to a file, and we'd like to have that information in the XML rather than having to write code to get the job done. So, here's a new struct:

In [None]:
# Write into a temporary directory that is removed when the block exits. An
# open NamedTemporaryFile cannot be reopened by name on every platform and was
# never closed explicitly.
with tempfile.TemporaryDirectory() as tmp_dir:
    destination = Path(tmp_dir) / "models.py"

    # The <to-file> node carries the output path in its `to` attribute.
    s_file = xml_to_data(
        f"""
<to-file to="{destination}">
    <model name="User">
        <field name="id" type="integer"/>
        <field name="username" type="char"/>
        <field name="email" type="email"/>
    </model>
</to-file>"""
    )

    sal = SalBasic(template_renderer)
    print(sal.process(s_file.clone()))

class UserModel(models.Model):
    id = models.IntegerField()
    username = models.CharField()
    email = models.EmailField()


In [None]:
# | hide

# Use a temporary directory rather than an open NamedTemporaryFile: the file is
# written and read back by name, which is not portable while the handle is
# still open, and the directory is cleaned up deterministically.
with tempfile.TemporaryDirectory() as tmp_dir:
    destination = Path(tmp_dir) / "models.py"

    s_file = xml_to_data(f"""
<to-file to="{destination}">
    <model name="User">
        <field name="id" type="integer"/>
        <field name="username" type="char"/>
        <field name="email" type="email"/>
    </model>
</to-file>""")

    sal = SalBasic(template_renderer)
    sal.process(s_file)

    # Compare ignoring surrounding whitespace, like the other file-based test.
    with open(destination, "r") as h:
        test_eq(
            h.read().strip(),
            """
class UserModel(models.Model):
    id = models.IntegerField()
    username = models.CharField()
    email = models.EmailField()
    """.strip(),
        )

## Code generator II (jinja + frontmatter)

To make this even more powerful, we can use `frontmatter` to embed metadata into the templates themselves and merge it with the attributes of the node. 

The frontmatter can reference any attribute from the struct, so it needs to be extracted in a raw format, rendered, and then parsed. But first, we need new templates:

In [None]:
# <model> template with a front matter header (between the `---` lines). The
# header is itself a jinja template and defines `reference`, which the body
# then uses in the trailing comment.
model = """
---
reference:  "sigla-{{ node.attrs.name | lower }}-model"
---
class {{ name }}Model(models.Model): # {{ reference }}
    {% for child in children -%}
    {{ child | render }}
    {% endfor %}
"""

# <field> template with its own front matter; its body does not use `reference`.
# NOTE(review): this header reads `node.name` while the model header reads
# `node.attrs.name` -- confirm the difference is intended.
field = """
---
reference:  "sigla-{{ node.name | lower }}-model"
---
{{ name }} = models.{{ type | title }}Field() 
"""

In [None]:
# | hide
# The raw front matter is the header text between the `---` lines, returned
# before any jinja rendering has been applied to it.
fm = FrontMatter().get_raw_frontmatter(model)
assert fm == 'reference:  "sigla-{{ node.attrs.name | lower }}-model"'

**todo**

In [None]:
# | export

class FrontMatterMixin:
    """Mixin for template loaders whose templates start with a front matter header.

    It must be listed before the concrete loader in the bases, so that
    `super().get_template` reaches the loader that actually holds the
    templates.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.frontmatter_handler = FrontMatter()

    def get_template(self, data: str, frontmatter: bool = False) -> str:
        """Return one half of the template registered under `data`.

        Args:
            data: the template name. The parameter keeps its historical name,
                but it is forwarded as the `name` of the underlying loader.
            frontmatter: when true, return the raw (unrendered) front matter;
                otherwise return the template content without the front
                matter.
        """
        template = super().get_template(data)
        if frontmatter:
            return self.frontmatter_handler.get_raw_frontmatter(template)
        return self.frontmatter_handler.get_content(template)

    
class FrontMatterInMemoryTemplateLoader(FrontMatterMixin, InMemoryTemplateLoader):
    """In-memory template loader whose `get_template` goes through `FrontMatterMixin`."""

**todo**

In [None]:
# | export
class Sal(SalBasic):
    """Code generator that merges each template's rendered front matter into the node attributes before rendering."""

    # Node names that `SalBasic.process_data` handles with a dedicated action
    # and that therefore have no template (and no front matter) of their own.
    STRUCTURAL_NODES = ("to-file", "black", "wrapper")

    def get_frontmatter_attributes_for_data(self, template: str, data: Data) -> dict:
        """Render the raw front matter `template` against `data` and parse the result."""
        rendered = self.renderer.render(data, template)
        return self.renderer.repository.frontmatter_handler.parse(rendered)

    def pre_process_data(self, data: Data):
        """Update the attributes of every templated node with its front matter values.

        Every node yielded by `iter_data(data)` is visited. Structural nodes
        are skipped; "wrapper" used to be missing from that list, so a wrapper
        tree failed on the lookup of a "wrapper" template.
        """
        for d, _ in iter_data(data):
            if d.name in self.STRUCTURAL_NODES:
                continue

            # load the raw (unrendered) front matter of the node's template
            template = self.renderer.repository.get_template(
                d.name, frontmatter=True
            )
            # render it against the node and parse the result
            new_attributes = self.get_frontmatter_attributes_for_data(template, d)

            # Guard against a parser that yields nothing for an empty front
            # matter block; `dict.update(None)` would raise.
            if new_attributes:
                d.attrs.update(new_attributes)
        return data

In [None]:
# Same node names as before, now bound to the front-matter templates and served
# by the front-matter-aware loader.
repository = FrontMatterInMemoryTemplateLoader(
    templates={
        "model": model,
        "field": field,
    }
)
template_renderer2 = Renderer(repository=repository, renderer=JinjaTemplateRenderer())

# The `reference` value computed in the front matter shows up in the output.
sal = Sal(template_renderer2)
print(sal.process(struct.clone()))

class UserModel(models.Model): # sigla-user-model
    id = models.IntegerField()
    username = models.CharField()
    email = models.EmailField()
    


In [None]:
#| hide
# Regression check: the front matter `reference` is rendered into the class
# header comment.
sal = Sal(template_renderer2)
test_eq(
    sal.process(struct.clone()).strip(),
    """
class UserModel(models.Model): # sigla-user-model
    id = models.IntegerField()
    username = models.CharField()
    email = models.EmailField()
""".strip(),
)

In [None]:
# | hide

# Use a temporary directory rather than an open NamedTemporaryFile: the file is
# written and read back by name, which is not portable while the handle is
# still open, and the directory is cleaned up deterministically.
with tempfile.TemporaryDirectory() as tmp_dir:
    destination = Path(tmp_dir) / "models.py"

    s_file = xml_to_data(
        f"""
<to-file to="{destination}">
    <model name="User">
        <field name="id" type="integer"/>
        <field name="username" type="char"/>
        <field name="email" type="email"/>
    </model>
</to-file>"""
    )

    sal = Sal(template_renderer2)
    sal.process(s_file)

    # The written file must contain the front-matter-aware output.
    with open(destination, "r") as h:
        test_eq(
            h.read().strip(),
            """
class UserModel(models.Model): # sigla-user-model
    id = models.IntegerField()
    username = models.CharField()
    email = models.EmailField()
    """.strip(),
        )

In [None]:
# A <black> node wrapping a W node with three children; two share the `a`
# template.
xml = xml_to_data("""
<black>
    <W>
        <a/>
        <a/>
        <b/>
    </W>
</black>
""")


# W template with empty front matter. The first loop passes the node through
# the custom `imports` filter, concatenates the resulting lists with
# `sum(None, [])`, and de-duplicates them with `unique`.
w = """
---
---
{%- for i in node|imports|sum(None, [])|unique %}
{{ i }}
{%- endfor %}

class W:
    {%- for child in children %}
    {{ child | render }}
    {%- endfor %}
    
"""


# Leaf templates declaring an `imports` list in their front matter. The odd
# spacing in `AAA( )` is kept as-is; the expected output below has `AAA()`.
a = """
---
imports: 
    - from AAA import A
---
a = AAA( )
"""
b = """
---
imports: 
    - from BBB import B
---
b = BBB()
"""

# Register the three templates under their node names.
repository = FrontMatterInMemoryTemplateLoader(
    templates={
        "W": w,
        "a": a,
        "b": b,
    }
)


def imports(data: Data):
    """Collect the non-empty `imports` attribute of every node yielded by `iter_data(data)`.

    Returns a list with one entry per node that has a truthy `imports`
    attribute, in iteration order.
    """
    collected = []
    for node, _ in iter_data(data):
        declared = node.attrs.get('imports')
        if declared:
            collected.append(declared)
    return collected


# Expose the `imports` function to the templates as a filter of the same name.
template_renderer2 = Renderer(repository=repository, renderer=JinjaTemplateRenderer(), filters={'imports': imports})


# The root is a <black> node, so the rendered text is formatted with black
# before being compared.
sal = Sal(template_renderer2)
res = sal.process(xml)
test_eq(res.strip(), """
from AAA import A
from BBB import B


class W:
    a = AAA()
    a = AAA()
    b = BBB()
""".strip())


'\nfrom AAA import A\nfrom BBB import B\n\nclass W:\n    a = AAA( )\n    a = AAA( )\n    b = BBB()'


---

In [None]:
# | hide
import nbdev

# Run nbdev's export step (presumably writes the exported cells to the module
# named by the `default_exp` directive at the top -- confirm against nbdev docs).
nbdev.nbdev_export()