In [72]:
from __future__ import annotations
from typing import Any, Union, Optional, Literal


_VARIABLE = "variable"
_DATA = "data"
_MODULE = "module"
_RESOURCE = "resource"
_PROPERTY = "property"
_MAP = "map"
_OUTPUT = "output"
_PROVIDER = "provider"
_GROUP_VALENCE = {
    _VARIABLE: 1,
    _DATA: 2,
    _MODULE: 1,
    _RESOURCE: 2,
    _PROPERTY: 1,
    _MAP: 0,
    _OUTPUT: 1,
    _PROVIDER: 1,
}
_GROUPS = list(_GROUP_VALENCE.keys())


class BlockError(Exception):
    """
    Exception to return due to issues defining blocks
    """


class Caller:
    def __init__(self, base: Any, call: str):
        self.base = base
        self.call = call

    def __str__(self) -> str:
        return f"{str(self.base)}.{self.call}"

    def __repr__(self) -> str:
        return self.__str__()


class Block:
    _tab_space = "    "
    _group_abbrv = {_VARIABLE: "var", _MAP: ""}

    def __init__(
        self,
        _group: str,
        *args: str,
        _tomap: bool = True,
        invisible_map: bool = False,
        **kwargs: Union[str, int, float, Block, bool, list, dict],
    ):
        self._group = _group
        self.group, self.group_abbrv, self.ids = self._group_id_reprs(_group, args, invisible_map)
        self.invisible_map = invisible_map
        self.properties = kwargs
        self._max_elements = 4
        self._tomap = _tomap
        self.dependencies = set()
        self._format_props()  # needs to run on start to capture all dependencies

    def _group_id_reprs(self, s: str, ids: tuple[str], invisible_map: bool = False) -> tuple[str, str, tuple[str]]:
        if s == _MAP:
            if not ids:
                return "", "", ids
            elif invisible_map:
                return ids[0], "", ""
        elif s == _PROPERTY:
            return ids[0], ids[0], ids[1:]
        return s, self._group_abbrv.get(s, s), ids

    def _write_ids(self) -> str:
        return " ".join([self.group] + [f'"{id}"' for id in self.ids]).strip() + " "

    def _write(self, comment: str = "", pad: int = 0) -> str:
        """
        Creates a string block that translates Block to Terraform
        """
        invis_map_insert = "= " if (self.invisible_map) else ""
        lines = [f"#{comment}"] if comment else []
        lines.append(self._write_ids() + invis_map_insert + "{")
        lines += self._format_props(pad)
        lines.append("}")
        return "\n".join(map(lambda s: pad * self._tab_space + s, lines))

    def _format_props(self, pad: int = 0) -> list[str]:
        """
        Returns a list of parameter lines
        """
        max_len = max(len(k) for k in self.properties.keys())
        basic_params = []
        property_params = []
        for k, v in self.properties.items():
            if isinstance(v, Block):
                self._add_dependencies(v)
            elif isinstance(v, Caller):
                self._add_dependencies(v.base)
            if isinstance(v, Block):
                property_params.append(v._write(pad=pad + 1))
            else:
                base_string = f"{self._tab_space}{k}{' ' * (max_len - len(k))} = "
                if isinstance(v, dict):
                    map_rep = self._map_rep(v)
                    basic_params.append(base_string + map_rep[0])
                    basic_params += list(
                        map(lambda b: len(base_string) * " " + b, map_rep[1:])
                    )
                elif isinstance(v, tuple or list):
                    basic_params.append(
                        base_string
                        + "tolist(["
                        + ",".join(self._parse(_v) for _v in v)
                        + "])"
                    )
                else:
                    basic_params.append(base_string + self._parse(v))
        return basic_params + property_params

    def __str__(self) -> str:
        return self.__repr__()

    def __repr__(self) -> str:
        return ".".join([self.group_abbrv] + list(self.ids)).strip(".")

    def _parse(self, s: Union[Caller, int, float, str, bool]) -> str:
        if isinstance(s, Caller):
            return str(s)
        elif isinstance(s, int or float):
            return str(s)
        elif isinstance(s, bool):
            return str(s).lower()
        else:
            return f'"{s}"'

    def __getitem__(self, attribute: str) -> str:
        return Caller(self, attribute)

    def _map_rep(self, d: dict, pad=0) -> list[str]:
        dummy_block = Block(_MAP, **d)
        if self._tomap:
            base = dummy_block._format_props(pad=pad + 6)
            base[0] = "tomap(" + base[0]
            base[-1] = base[-1] + ")"
        else:
            base = dummy_block._format_props(pad=pad)
        if len(base) < self._max_elements:
            base = [base[0] + ",".join(base[1:-1]) + base[-1]]
        return base

    def _validate(self):
        # TODO: Validate new blocks exist and have all required properties
        group_valence = _GROUP_VALENCE.get(self._group, len(self.ids))
        if group_valence != len(self.ids):
            raise BlockError(
                f'Group "{self._group}" requires valence {group_valence}, but got {len(self.ids)}.'
            )

    def _add_dependencies(self, v: Block):
        for sub in [v] + list(v.dependencies):
            if sub._group in [_VARIABLE, _DATA, _MODULE, _RESOURCE, _OUTPUT]:
                self.dependencies.add(v)


In [81]:
from itertools import filterfalse
from typing import Union, Optional
import functools
import os
from argparse import ArgumentParser



class DependencyError(Exception):
    """
    Exception to return due to issues resolving dependencies
    """


def resolve_dependencies(
    dependency_map: dict[str, set[str]], base_layer: set[str] = set()
) -> list[set[str]]:
    layers = []
    total_blocks = set(dependency_map.keys())
    while base_layer != total_blocks:
        next_layer = set(
            block
            for block in dependency_map
            if dependency_map[block].issubset(base_layer)
        )
        current_layer = next_layer - base_layer
        layers.append(current_layer)
        base_layer = next_layer
        if not current_layer:
            raise DependencyError(
                "Unable to resolve dependencies, ensure they are consistent and non-circular."
            )
    return layers


class Registry(dict):
    def __init__(self):
        super(Registry, self).__init__()

    def __setitem__(self, block_id: str, block: Block):
        if isinstance(block, Block):
            super(Registry, self).__setitem__(block_id, block)
        return self

    def __getitem__(self, block_id: str) -> Block:
        block = super(Registry, self).get(block_id, None)
        if block:
            return block
        else:
            raise BlockError(f"Block {block_id} is not registered in the registry.")

    def __str__(self):
        return str(list(super(Registry, self).keys()))

    def __repr__(self):
        return self.__str__()


class Group:
    _ignore_duplicates = False

    def __init__(self, group: str, registry: Registry):
        self.group = group
        self.registry = registry
        self.blocks = {}

    def __getitem__(self, block_id: str) -> Block:
        return self.blocks[block_id]

    def _update_tracking_and_return(self, new_block: Block):
        self.blocks[str(new_block)] = new_block
        if (str(new_block) not in self.registry) or self._ignore_duplicates:
            self.registry[str(new_block)] = new_block
        else:
            raise BlockError(
                f"Block {new_block} is already registered in the block registry."
            )
        return new_block

    def __call__(
        self,
        *ids,
        **kwargs: Optional[Union[str, int, float, Block, bool, list, dict]],
    ) -> Block:
        new_block = Block(self.group, *ids, **kwargs)
        return self._update_tracking_and_return(new_block)


class Providers:
    _ignore_duplicates = False
    
    def __init__(self, registry: Registry):
        self.registry = registry
        self.blocks = {}
        self.provider_options = {}
    
    def __getitem__(self, provider_block_id: str) -> Block:
        return self.provider_blocks[provider_block_id]

    def _update_tracking(self, provider: str, source: Optional[str] = None, version: Optional[str] = None, **kwargs):
        provider_block = Block(_PROVIDER, provider, **kwargs)
        required_provider_params = {}
        if source:
            required_provider_params["source"] = source
        if version:
            required_provider_params["version"] = version
        provider_options = Block(_MAP, provider, _tomap=False, invisible_map = True, **required_provider_params)
        if (str(provider_block) not in self.registry) or self._ignore_duplicates:
            self.registry[str(provider_block)] = provider_block
        else:
            raise BlockError(
                f"Provider {provider} is already registered in the block registry."
            )
        self.provider_options[provider] = provider_options
    
    def add(self, provider: str, source: Optional[str] = None, version: Optional[str] = None, **kwargs):
        self._update_tracking(provider, source, version, **kwargs)
    
    def build_provider(self):
        return Block("property", "terraform", _tomap=False, required_providers=Block(_PROPERTY, "required_providers", **self.provider_options))


class MetaFormer:
    def __init__(
        self,
        name: str = "main",
        isolate_module: bool = False,
        split_out: bool = False,
        registry: Optional[Registry] = None,
    ):
        if registry is not None:
            self.registry = registry
        else:
            self.registry = Registry()
        self.name = name
        self.isolate_module = isolate_module
        self._create_groups()

        # shortened aliases
        self.dat = self.data
        self.res = self.resource
        self.mod = self.module
        self.var = self.variable
        self.prop = self.property
        self.prov = self.provider

    def _create_groups(self):
        self.data = Group(_DATA, self.registry)
        self.resource = Group(_RESOURCE, self.registry)
        self.module = Group(_MODULE, self.registry)
        self.variable = Group(_VARIABLE, self.registry)
        self.property = Group(_PROPERTY, self.registry)
        self.output = Group(_OUTPUT, self.registry)
        self.provider = Providers(self.registry)
        return self

    def _clear_registry(self):
        self.registry = Registry()
        return self
    
    def import_provider(self, provider: str, source: Optional[str] = None, version: Optional[str] = None, **kwargs):
        self.provider.add(provider, source, version, **kwargs)

    def _collect_dependencies(self) -> dict[str, list[str]]:
        return {
            block_id: {str(dep_block) for dep_block in block.dependencies}
            for block_id, block in self.registry.items()
            if block._group != _PROPERTY
        }

    def _resolve_dependencies(self) -> list[set[str]]:
        return resolve_dependencies(self._collect_dependencies())

    def collect(self) -> list[Block]:
        dependencies = self._resolve_dependencies()
        return [self.provider.build_provider()] + functools.reduce(
            lambda m, n: self._sort(m) + self._sort(n), dependencies, []
        )

    def _sort(self, deps: set[str]) -> list[Block]:
        """
        Sort the dependencies putting in the following order:
            VARIABLES -> DATA -> RESOURCES -> MODULES -> OUTPUTS
        """
        deps = [self.registry[str(block_id)] for block_id in deps]
        return (
            list(filter(lambda b: b._group == _PROPERTY, deps))
            + list(filter(lambda b: b._group == _VARIABLE, deps))
            + list(filter(lambda b: b._group == _DATA, deps))
            + list(filter(lambda b: b._group == _RESOURCE, deps))
            + list(filter(lambda b: b._group == _MODULE, deps))
            + list(filter(lambda b: b._group == _OUTPUT, deps))
            + list(
                filter(
                    lambda b: b._group
                    not in [_VARIABLE, _DATA, _RESOURCE, _MODULE, _OUTPUT],
                    deps,
                )
            )
        )

    def _write(self):
        """
        Return the contents of the MetaForm object as a string
        """
        return "\n\n".join(block._write() for block in self.collect())

    def build(self):
        """
        Build out the new terraform scripts from the metaform commands
        """
        if self.isolate_module:
            main_path = os.path.join(os.path.realpath("__main__"), self.name)
            os.mkdir(main_path)
            with open(os.path.join(main_path, "main.tf"), "w") as f:
                f.write(self._write())
        else:
            with open(f"{self.name}.tf", "w") as f:
                f.write(self._write())


In [82]:
tf = MetaFormer()
tf.provider.add("aws", "hashicorp/aws", "~> 4.0", region="us-east-1a")

In [83]:
print(tf.registry['provider.aws']._write())

provider "aws" {
    region = "us-east-1a"
}


In [84]:
print(tf._write())

terraform {
    required_providers {
            aws = {
            source  = "hashicorp/aws"
            version = "~> 4.0"
        }
    }
}

provider "aws" {
    region = "us-east-1a"
}
