In [None]:
# | default_exp codegen

In [None]:
# | hide
%load_ext autoreload
%autoreload 2

from fastcore.test import *
from typing import Optional, Generator
from jinja2 import UndefinedError
from sal.loaders import xml_to_data
from sal.utils import files
from textwrap import dedent

import jupyter_black
import tempfile
import nbdev.showdoc as showdoc

jupyter_black.load()

In [None]:
# | export
import abc

from pydantic import BaseModel
from typing import Any, Callable
from dataclasses import dataclass
from pathlib import Path
from textwrap import dedent

from sal.core import Data
from sal.loaders import xml_file_to_data
from sal.templates import (
    Renderer,
    TemplateLoader,
    TemplateRenderer,
    MissingTemplateException,
)

# Code Generation

## What does code generation mean in sal?

In its basic form, code generation combines XML files (converted to `Data` structures) with Jinja templates to render code. Later we will also introduce template frontmatter.

For this, we need a basic structure to work with. As an example, we'll be working with a hypothetical "model":

In [None]:
# Example structure used throughout this notebook: a `model` node with three
# `field` children.
struct: Data = xml_to_data("""
    <model name="User">
        <field name="id" type="integer"/>
        <field name="username" type="char"/>
        <field name="email" type="email"/>
    </model>
""")

...and the basic templates used with this structure are:

In [None]:
# Template for `model` nodes: the class header, then each child rendered
# through the `render` filter.
model = (
    "class {{ name }}Model(models.Model):\n"
    "    {%- for child in children %}\n"
    "    {{ child | render }}\n"
    "    {%- endfor %}\n"
)

# Template for `field` nodes: `type` is title-cased to build the field class
# name (e.g. type="char" -> `CharField`).
field = "{{ name }} = models.{{ type | title }}Field()"

## Code generator I (jinja only)

Now that we can render `jinja2` templates, we can build a basic code generator.

In [None]:
# | exporti

@dataclass
class WriteFileResult:
    """Deferred request to write `content` to the file located at `to`."""

    # Destination file path.
    to: str
    # Text that will be written to the destination.
    content: str

class SalAction(abc.ABC):
    """Base class for nodes that `Sal` handles with code instead of a template.

    `Sal` dispatches a node to the action whose `name` equals the node's name.
    """

    @property
    @abc.abstractmethod
    def name(self) -> str:
        """Node (XML tag) name handled by this action, e.g. ``"to-file"``."""

    @abc.abstractmethod
    def process_data(
        self, sal: "Sal", data: Data
    ) -> tuple[Any, WriteFileResult | None]:
        """Handle `data` on behalf of `sal`.

        Returns:
            A pair ``(result, side_effect)``. ``result`` is what `Sal` returns
            for this node: a rendered string, or another value such as the
            list produced for groups. ``side_effect`` is an optional
            `WriteFileResult` that `Sal` queues and performs after processing,
            or ``None``.
        """

    def __str__(self) -> str:
        return f"action:{self.name}"


class ToFileAction(SalAction):
    """Action for ``<to-file to="...">`` nodes.

    Renders the node with the renderer's default template and requests that
    the rendered text be written to the path given by the ``to`` attribute.
    """

    name = "to-file"

    def process_data(self, sal: "Sal", data: Data) -> tuple[str, WriteFileResult]:
        # Validate before rendering. A missing (or empty) destination is a
        # configuration error, so fail fast instead of rendering the whole
        # subtree first, or failing later inside `open("")`.
        to = data.attrs.get("to")
        if not to:
            raise RuntimeError(
                "To save to file you need to define the 'to' attribute with a filepath"
            )

        rendered = sal.renderer.render(data, template=Renderer.DEFAULT_TEMPLATE)

        # The write itself is deferred: `Sal` queues this result and performs
        # it once processing has finished.
        return rendered, WriteFileResult(to=to, content=rendered)


class GroupAction(SalAction):
    """Action for ``<group>`` nodes: process each child and return the results as a list."""

    name = "group"

    def process_data(self, sal: "Sal", data: Data) -> tuple[list[Any], None]:
        # Use `sal.process_data` rather than `sal.process`. The enclosing
        # `Sal.process` call has already pre-processed the tree it was given
        # and flushes queued side effects (e.g. file writes) when it finishes.
        # Re-entering `process` per child would redo the pre-processing and
        # could perform the same write more than once.
        return [sal.process_data(child) for child in data.children], None

In [None]:
# | export

class Config(BaseModel):
    """Settings used to build a `Sal` instance."""

    # Directories holding the ``<name>.jinja2`` templates; the first one is
    # where a missing template is reported as expected.
    template_directories: list[Path]
    # Extra filters exposed to templates, keyed by filter name.
    filters: dict[str, Callable] = {}

class Sal:
    """Code generator that turns `Data` trees into rendered text.

    Ordinary nodes are rendered with the template named after the node (via
    `renderer`). Nodes whose name matches a registered action (`to-file`,
    `group`) are handled by that action instead. Side effects requested by
    actions (e.g. writing a file) are queued in `action_results` and performed
    after the tree has been processed.
    """

    def __init__(self, config: Config, renderer: Renderer):
        self.config = config
        self.renderer = renderer
        # Built-in actions. A `ToStringAction` is planned but not implemented yet.
        self.actions = [ToFileAction(), GroupAction()]
        # Pending side effects queued by actions; flushed by `process_action_results`.
        self.action_results: list[WriteFileResult] = []

    @property
    def action_names(self) -> list[str]:
        """Names of nodes that are handled by an action instead of a template."""
        return [action.name for action in self.actions]

    def pre_process_data(self, data: Data) -> Data:
        """Merge each non-action node's template metadata into its ``attrs``.

        The tree is mutated in place (callers that need the original should
        pass a clone) and returned for convenience.
        """
        action_names = set(self.action_names)  # constant for the whole walk
        for node, _ in data:
            if node.name in action_names:
                continue
            node.attrs.update(self.renderer.get_metadata_for_template(node.name, node))
        return data

    def process_data(self, data: Data) -> str | Any:
        """Process a single, already pre-processed node.

        If `data` names an action, that action handles it and any side effect
        it returns is queued in `action_results`. Otherwise the node is
        rendered with its template.

        Raises:
            RuntimeError: if a required template is missing. The message
                contains a starter template and the path where it is expected.
        """
        try:
            for action in self.actions:
                if data.name == action.name:
                    ret, action_result = action.process_data(self, data)
                    if action_result is not None:
                        self.action_results.append(action_result)
                    return ret
            return self.renderer.process(data)
        except MissingTemplateException as e:
            raise RuntimeError(self._missing_template_message(e.name)) from e

    def _missing_template_message(self, name: str) -> str:
        """Build the help text shown when the template `name` does not exist."""
        directories = self.config.template_directories
        filename = f"{name}.jinja2"
        # Fall back to the bare filename when no directory is configured, so
        # reporting the missing template does not itself raise IndexError.
        path = Path(directories[0]) / filename if directories else Path(filename)
        # Dedent the static text *before* interpolating. A multi-line default
        # template inserted first would break dedent's common-prefix detection.
        message = dedent(
            """
            The template `{name}` was not found. Here's a default template to
            get you started:

            {template}

            ---
            at: {path}
            """
        ).strip()
        return message.format(
            name=name, template=self.renderer.DEFAULT_TEMPLATE, path=path
        )

    def process_xml_from_filename(self, file: str) -> str | Any:
        """Load the XML file `file` into a `Data` tree and process it."""
        struct: Data = xml_file_to_data(file)
        return self.process(struct)

    def process_action_results(self) -> None:
        """Perform every queued side effect, then leave the queue empty.

        The queue is swapped out before anything is performed, so each result
        is handled exactly once. This holds even when `process` is re-entered
        from inside an action or the instance is reused for another tree.

        Raises:
            RuntimeError: for a queued result of an unsupported type.
        """
        pending, self.action_results = self.action_results, []
        for action_result in pending:
            if isinstance(action_result, WriteFileResult):
                print(f"writing to {action_result.to}: '{action_result.content}'")
                with open(action_result.to, "w", encoding="utf-8") as h:
                    h.write(action_result.content)
            else:
                raise RuntimeError(f"Unsupported action {action_result}")

    def process(self, data: Data) -> str | Any:
        """Process a whole `Data` tree (mutating it in place) and return the result."""
        return self._process(data)

    # TODO support snapshots
    def _process(self, data: Data) -> str | Any:
        """Pre-process `data`, render/dispatch it, then flush queued side effects."""
        data = self.pre_process_data(data)
        result = self.process_data(data)
        self.process_action_results()
        return result

    @classmethod
    def from_config(
        cls,
        *,
        template_directories: list[Path],
        filters: dict[str, Callable] | None = None
    ) -> "Sal":
        """Build a `Sal` whose templates are loaded from `template_directories`.

        `filters` maps filter names to callables made available to templates.
        """
        config = Config(template_directories=template_directories, filters=filters or {})

        repository = TemplateLoader.from_directories(config.template_directories)

        template_renderer = Renderer(
            repository=repository, renderer=TemplateRenderer(), filters=config.filters
        )
        return cls(config, template_renderer)

It's important to note that a parent should be able to trigger the rendering of its children (this ensures the recursive nature of template rendering). Look at the `model` template for an example:

In [None]:
# The same `model` template as above, repeated here for reference: the
# parent renders each child through the `render` filter.
model = (
    "class {{ name }}Model(models.Model):\n"
    "    {%- for child in children %}\n"
    "    {{ child | render }}\n"
    "    {%- endfor %}\n"
)

# Leaf template for `field` nodes (no children to render).
field = "{{ name }} = models.{{ type | title }}Field()"

In [None]:
# | hide


# Render `struct` (cloned, because processing updates node attrs in place) and
# check that the model template pulls in each field via the `render` filter.
# NOTE(review): `files` (sal.utils) presumably creates these template files for
# the duration of the block — confirm.
with files(
    {
        "/tmp/templates/model.jinja2": model,
        "/tmp/templates/field.jinja2": field,
    }
):
    sal = Sal.from_config(template_directories=["/tmp/templates"])

    test_eq(
        sal.process(struct.clone()).strip(),
        dedent(
            """
    class UserModel(models.Model):
        id = models.IntegerField()
        username = models.CharField()
        email = models.EmailField()
    """
        ).strip(),
    )


# Nested `group` nodes: each `group` yields a list with one entry per child,
# so two levels of grouping put the rendered model at `[0][0]`.
struct2: Data = xml_to_data(
    """
<group>
<group>
    <model name="User">
        <field name="id" type="integer"/>
        <field name="username" type="char"/>
        <field name="email" type="email"/>
    </model>
</group>
</group>"""
)


with files(
    {
        "/tmp/templates/model.jinja2": model,
        "/tmp/templates/field.jinja2": field,
    }
):
    sal = Sal.from_config(template_directories=["/tmp/templates"])
    test_eq(
        sal.process(struct2.clone())[0][0].strip(),
        dedent(
            """
    class UserModel(models.Model):
        id = models.IntegerField()
        username = models.CharField()
        email = models.EmailField()
    """
        ).strip(),
    )

**todo: document to-file**

**todo: document `group`**

We are missing one more thing: we need to be able to save the result to a file, and we'd like to describe that in the XML instead of writing code to get the job done. So, here's a new struct:

In [None]:
# Processing a `to-file` node returns the rendered text and also writes it to
# the path given in its `to` attribute (here a throwaway temporary file).
destination = tempfile.NamedTemporaryFile()

s_file = xml_to_data(
    f"""
<to-file to="{destination.name}">
    <model name="User">
        <field name="id" type="integer"/>
        <field name="username" type="char"/>
        <field name="email" type="email"/>
    </model>
</to-file>"""
)

with files(
    {
        "/tmp/templates/model.jinja2": model,
        "/tmp/templates/field.jinja2": field,
    }
):
    sal = Sal.from_config(template_directories=["/tmp/templates"])
    print(sal.process(s_file.clone()))

In [None]:
# | hide

# `to-file` nested inside a `group`: the file must still be written. It is read
# back and checked after the template files' context has exited.
destination = tempfile.NamedTemporaryFile()

s_file = xml_to_data(
    f"""
<group>
<to-file to="{destination.name}">
    <model name="User">
        <field name="id" type="integer"/>
        <field name="username" type="char"/>
        <field name="email" type="email"/>
    </model>
</to-file>
</group>"""
)

with files(
    {
        "/tmp/templates/model.jinja2": model,
        "/tmp/templates/field.jinja2": field,
    }
):
    sal = Sal.from_config(template_directories=["/tmp/templates"])
    result = sal.process(s_file)

with open(destination.name, "r") as h:
    test_eq(
        h.read(),
        """
class UserModel(models.Model):
    id = models.IntegerField()
    username = models.CharField()
    email = models.EmailField()
    """.strip(),
    )

> To make this even more powerful, we can use `frontmatter` to embed metadata into the templates themselves and merge it with the attributes of the node.

> Going further, the frontmatter can reference any attribute from the struct, so it needs to be extracted in raw form, rendered, and only then parsed. But first, we need new templates.

In [None]:
model = """
---
reference:  "sigla-{{ node.attrs.name | lower }}-model"
---
class {{ name }}Model(models.Model): # {{ reference }}
    {% for child in children -%}
    {{ child | render }}
    {% endfor %}
"""

field = """
---
reference:  "sigla-{{ node.name | lower }}-model"
---
{{ name }} = models.{{ type | title }}Field() 
"""


with files(
    {
        "/tmp/templates/model.jinja2": model,
        "/tmp/templates/field.jinja2": field,
    }
):
    sal = Sal.from_config(template_directories=["/tmp/templates"])
    test_eq(
        sal.process(struct.clone()).strip(),
        dedent(
            """
    class UserModel(models.Model): # sigla-user-model
        id = models.IntegerField()
        username = models.CharField()
        email = models.EmailField()
    """
        ).strip(),
    )

In [None]:
# | hide

# A `to-file` nested inside a `group`, used by the next cell to check that
# frontmatter is applied to nodes that sit below action nodes.
destination = tempfile.NamedTemporaryFile()
s_file = xml_to_data(
    f"""
    <group>
<to-file to="{destination.name}">
    <model name="User">
        <field name="id" type="integer"/>
        <field name="username" type="char"/>
        <field name="email" type="email"/>
    </model>
</to-file>
    </group>
"""
)

In [None]:
# With the frontmatter templates, the file written from the nested `to-file`
# must include the rendered `reference` comment.
with files(
    {
        "/tmp/templates/model.jinja2": model,
        "/tmp/templates/field.jinja2": field,
    }
):
    sal = Sal.from_config(template_directories=["/tmp/templates"])
    sal.process(s_file)

    with open(destination.name, "r") as h:
        test_eq(
            h.read().strip(),
            dedent(
                """
    class UserModel(models.Model): # sigla-user-model
        id = models.IntegerField()
        username = models.CharField()
        email = models.EmailField()
        """
            ).strip(),
        )

In [None]:
# Frontmatter attributes can be aggregated across a subtree with a custom
# filter. Each `a`/`b` template declares the import it needs, and the `W`
# template collects them, flattens them (`sum(None, [])`) and de-duplicates
# them (`unique`).
xml = xml_to_data(
    """
<to-file to="/tmp/results.txt">
    <W>
        <a/>
        <a/>
        <b/>
    </W>
</to-file>
"""
)


w = """
---
---
{%- for i in node|imports|sum(None, [])|unique %}
{{ i }}
{%- endfor %}


class W:
    {%- for child in children %}
    {{ child | render }}
    {%- endfor %}
    
"""


a = """
---
imports: 
    - from AAA import A
---
a = AAA()
"""

b = """
---
imports: 
    - from BBB import B
---
b = BBB()
"""


def imports(data: Data) -> list:
    """Return the non-empty ``imports`` attributes of `data` and its descendants."""
    return [d.attrs["imports"] for d, _ in data if d.attrs.get("imports")]


with files(
    {
        "/tmp/templates/W.jinja2": w,
        "/tmp/templates/a.jinja2": a,
        "/tmp/templates/b.jinja2": b,
        "/tmp/results.txt": " ",
    }
):
    sal = Sal.from_config(
        template_directories=["/tmp/templates"], filters={"imports": imports}
    )
    res = sal.process(xml)

    # `test_eq`, like the other cells, rather than a bare `assert`: it reports
    # both values on failure and is not stripped under `python -O`.
    test_eq(
        res.strip(),
        dedent(
            """
    from AAA import A
    from BBB import B


    class W:
        a = AAA()
        a = AAA()
        b = BBB()
    """
        ).strip(),
    )

---

In [None]:
# | hide
import nbdev

# Export this notebook's `# | export` / `# | exporti` cells to the library module.
nbdev.nbdev_export()