# Forgather

A notebook for experimenting with Forgather's syntax.

In [1]:
import sys, os
modules_path = os.path.join('..', 'src')
if modules_path not in sys.path: sys.path.insert(0, modules_path)

from pprint import pp, pformat

from IPython import display as ds

from forgather.latent import Latent
from forgather.config import ConfigEnvironment
from forgather.preprocess import PPEnvironment
from forgather.codegen import generate_code
from forgather.yaml_encoder import to_yaml
import forgather.nb.notebooks as nb

# Show common syntax definition.
with open(os.path.join('..', 'docs', 'syntax.md'), 'r') as f:
    display(ds.Markdown(f.read()))

# Forgather Syntax Reference

Forgather defines a domain-specific language for the dynamic construction of Python objects using a combination of Jinja2, YAML, and a few extensions.

This guide will focus on the extensions to these languages. For details on YAML and Jinja2, see:

- [Jinja2 Template Designer Documentation](https://jinja.palletsprojects.com/en/3.1.x/templates/)
- [YAML 1.1](https://yaml.org/spec/1.1/)
- [PyYAML Documentation](https://pyyaml.org/wiki/PyYAMLDocumentation)

## Jinja2 Extensions
---
### The Preprocessor

There is a custom Jinja2 preprocessor which implements an extended version of Jinja2's [Line Statements](https://jinja.palletsprojects.com/en/3.1.x/templates/#line-statements). These are implemented via regex substitution, where the match is converted to normal Jinja syntax.


- \#\# : Line Comment
- \-\- : Line Statement
- << : Line Statement w/ left-trim
- \>> : Line Statement w/ right-trim
- == : Print Command
- '=>' : Print Command w/ right-trim

Example Input:

```jinja2
## If 'do_loop' is True, then output a list of numbers.
-- if do_loop:
    -- for i in range(how_many): ## Loop 'how_many' times.
        == '- ' + i|string
    -- endfor
<< endif
```

Is translated to:

```jinja2
{# If 'do_loop' is True, then output a list of numbers. #}
{% if do_loop: %}
{% for i in range(how_many): %}
{{ '- ' + i|string }}
{% endfor %}
{%- endif %}
```

Output, when passed: do_loop=True, how_many=3
```yaml
- 0
- 1
- 2

```


Normal Jinja2 syntax works just fine too. I just find that the normal syntax is visually difficult to parse (without syntax-highlighting) and is awkward to type.

More Formally

```python
line_comment = r'(.*)\s+#{2,}.*'
line_statement = r'\s*(--|<<|>>|==|=>)\s(.*)'

# Substitutions, keyed by the matched prefix (re_match[1]):
substitutions = {
    '--': r"{% " + re_match[2] + r" %}",
    '<<': r"{%- " + re_match[2] + r" %}",
    '>>': r"{% " + re_match[2] + r" -%}",
    '==': r"{{ " + re_match[2] + r" }}",
    '=>': r"{{ " + re_match[2] + r" -}}",
}
```
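The substitution rules above can be sketched in plain Python. This is a minimal illustration, not Forgather's actual implementation: the `translate_line` and `translate` helpers are hypothetical, and the handling of whole-line `##` comments is an assumption inferred from the example translation.

```python
import re

# Patterns copied from the definitions above.
LINE_COMMENT = re.compile(r'(.*)\s+#{2,}.*')
LINE_STATEMENT = re.compile(r'\s*(--|<<|>>|==|=>)\s(.*)')

# Opening and closing Jinja2 delimiters for each line-statement prefix.
DELIMITERS = {
    '--': ('{% ', ' %}'),
    '<<': ('{%- ', ' %}'),
    '>>': ('{% ', ' -%}'),
    '==': ('{{ ', ' }}'),
    '=>': ('{{ ', ' -}}'),
}


def translate_line(line: str) -> str:
    """Translate one line of extended syntax into normal Jinja2 syntax."""
    stripped = line.strip()
    # Assumption: a line holding only a '##' comment becomes a Jinja2 comment.
    if stripped.startswith('##'):
        return '{# ' + stripped.lstrip('#').strip() + ' #}'
    # Drop a trailing '## ...' comment, keeping the text before it.
    match = LINE_COMMENT.fullmatch(line)
    if match:
        line = match[1]
    # Rewrite a line statement; any other line passes through unchanged.
    match = LINE_STATEMENT.fullmatch(line)
    if match is None:
        return line
    opening, closing = DELIMITERS[match[1]]
    return opening + match[2].strip() + closing


def translate(source: str) -> str:
    """Translate every line of a template."""
    return '\n'.join(translate_line(line) for line in source.split('\n'))


print(translate("-- if do_loop:\n    == '- ' + i|string\n<< endif"))
```

Lines that do not start with one of the five prefixes, such as ordinary YAML, are returned untouched, which is why the extended syntax can be mixed freely with normal Jinja2.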

---
### Jinja2 Globals

A number of globals have been introduced to the Jinja2 environment to assist with pre-processing.

- isotime() : Returns ISO formatted local-time, with 1-second resolution ("%Y-%m-%dT%H:%M:%S")
- utcisotime() : As with isotime(), but UTC time.
- filetime(): Generates a local-time string suitable to be concatenated with a file-name. ("%Y-%m-%dT%H-%M-%S")
- utcfiletime() : As filetime(), but in UTC time.
- now() : Get datetime.datetime.now()
- utcnow() : Get datetime.datetime.utcnow()
- joinpath(*names) : Join a list of file-path segments via os.path.join()
- normpath(path) : Normalize a file path; os.path.normpath()
- abspath(path) : Convert path to absolute path; os.path.abspath()
- relpath(path) : Convert a path to a relative path; os.path.relpath()
- repr(obj) : Get Python representation of object; repr()
- modname_from_path(module_name) : Given a module file path, return the module name
- user_home_dir() : Return absolute path of user's home directory  
- getcwd() : Get the current working directory
- forgather_config_dir() : Get the platform-specific config directory for Forgather.

The following functions come from platformdirs (https://pypi.org/project/platformdirs/):
- user_data_dir()
- user_cache_dir()
- user_config_dir()
- site_data_dir()
- site_config_dir()
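To make the two time formats listed for `isotime()` and `filetime()` concrete, here is a small sketch using only the standard library. The function bodies and the optional `when` parameter are assumptions added for illustration; only the format strings come from the list above.

```python
from datetime import datetime, timezone
from typing import Optional

ISO_FORMAT = "%Y-%m-%dT%H:%M:%S"   # format listed for isotime() / utcisotime()
FILE_FORMAT = "%Y-%m-%dT%H-%M-%S"  # format listed for filetime() / utcfiletime()


def isotime(when: Optional[datetime] = None) -> str:
    """Local time, ISO formatted with 1-second resolution."""
    return (when or datetime.now()).strftime(ISO_FORMAT)


def utcfiletime(when: Optional[datetime] = None) -> str:
    """UTC time in a form that avoids ':' so it can be used in a file name."""
    return (when or datetime.now(timezone.utc)).strftime(FILE_FORMAT)


stamp = datetime(2025, 6, 7, 4, 25, 44)
print(isotime(stamp))      # 2025-06-07T04:25:44
print(utcfiletime(stamp))  # 2025-06-07T04-25-44
```

The only difference between the two formats is the separator in the time portion: the file-name variant replaces the colons with hyphens.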

---
### Custom File Loader

A custom loader, derived from the FileSystemLoader, is defined. This loader has a syntax for splitting a single loaded template file into multiple sub-templates.

The primary use-case for this syntax is [template inheritance](https://jinja.palletsprojects.com/en/3.1.x/templates/#template-inheritance), which disallows multiple-inheritance. If you inherit from a template and include a template which is derived from another, Jinja2 does not allow you to directly override blocks from the included template. You can get around this by creating another template, which overrides the desired blocks, and is included by the top-level template.

Normally, this would require creating another template file, but who needs that!? That's much more difficult to work with.

```jinja2
## This is the main template
-- extends 'base_template.jinja'

## Override block 'foo' from 'base_template.jinja'
-- block foo
    -- include 'foo.bar' ## Include the sub-template
-- endblock


##--------------------- foo.bar ---------------------
## This is a sub-template named 'foo.bar'
-- extends 'some_other_base_template.jinja'

## Override block 'bar' from 'some_other_base_template.jinja'
-- block bar
    ## ... stuff
-- endblock
```

More formally, the syntax for splitting a document is:

```python
split_on = r"\n#\s*-{3,}\s*([\w./]+)\s*-{3,}\n"
```

Note: You can't split a template defined via a Python string, as this bypasses the Loader; only file templates may be split like this.
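The splitting step can be sketched with `re.split()`, which keeps the captured sub-template names in its result. This is an illustration under assumptions, not the loader's actual code: `split_templates` and the `"main"` key for the leading template are hypothetical. The sample marker uses a single leading `#`, which is what the pattern as written matches.

```python
import re

# Pattern copied from the definition above.
SPLIT_ON = re.compile(r"\n#\s*-{3,}\s*([\w./]+)\s*-{3,}\n")


def split_templates(source: str, main_name: str = "main") -> dict:
    """Split one template file into a {name: body} mapping."""
    # With a capture group, re.split() yields [body, name, body, name, ...].
    parts = SPLIT_ON.split(source)
    templates = {main_name: parts[0]}
    for name, body in zip(parts[1::2], parts[2::2]):
        templates[name] = body
    return templates


source = (
    "-- extends 'base_template.jinja'\n"
    "#--------------- foo.bar ---------------\n"
    "-- extends 'some_other_base_template.jinja'\n"
)
for name, body in split_templates(source).items():
    print(repr(name), repr(body))
```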

---
## YAML

### Dot-Name Elision
YAML does not have a way of defining an object without also constructing it. This can be inconvenient, as it may not be known ahead of time where the first use of an object will be, and YAML requires that the definition occur at this point.

To work around this, if the root-node is a mapping, we delete all keys containing strings starting with a dot. Once the object has been defined, YAML does not care if we delete the original definition/instance. My convention is to use ".define", but any name, starting with a dot, will work.

By convention, the primary output object of such a mapping is named "main".

```yaml
# Define points
.define: &pt1 { x: 0, y: 0 }
.define: &pt2 { x: 5, y: 0 }
.define: &pt3 { x: 0, y: 5 }

main:
    # A list of lines, each defined by a pair of points.
    - [ *pt1, *pt2 ]
    - [ *pt2, *pt3 ]
    - [ *pt3, *pt1 ]
```

Constructed graph...

```python
graph()

{'main': [[{'x': 0, 'y': 0}, {'x': 5, 'y': 0}],
          [{'x': 5, 'y': 0}, {'x': 0, 'y': 5}],
          [{'x': 0, 'y': 5}, {'x': 0, 'y': 0}]]}
```

While not apparent from the representation, the points in the lines are not copies; they are all references to the original three points from the definition. There are only three point objects present in the graph!
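The two ideas in this section, dropping dot-named keys and sharing referenced objects, can be shown with plain Python dictionaries. This is a sketch of the behaviour described, not Forgather's code; `elide_dot_names` is a hypothetical helper and the hand-built `root` mapping stands in for what a YAML loader would produce.

```python
def elide_dot_names(root: dict) -> dict:
    """Drop every top-level key that is a string starting with a dot."""
    return {
        key: value
        for key, value in root.items()
        if not (isinstance(key, str) and key.startswith("."))
    }


# Stand-ins for the anchored points; aliases refer to these same objects.
pt1, pt2, pt3 = {"x": 0, "y": 0}, {"x": 5, "y": 0}, {"x": 0, "y": 5}

# A loaded root mapping: a dot-named definition plus the 'main' output.
root = {".define": pt3, "main": [[pt1, pt2], [pt2, pt3], [pt3, pt1]]}

graph = elide_dot_names(root)
print(list(graph))                                  # ['main']
print(graph["main"][0][1] is graph["main"][1][0])   # True: both are pt2
```

Removing the `.define` key does not affect the lines, because they hold references to the point objects rather than to the key that first defined them.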

---
### YAML Types

Of the standard YAML 1.1 types, only those which can be resolved implicitly (without specifying the tag) are supported.

YAML 1.1 Tag : Python Type / Examples
- !!null : None
    - null
- !!bool : bool
    - True
    - False
- !!int : int
    - 2
    - -6
- !!float : float
    - 2.0
    - 1.2e-4
- !!str : str
    - "Hello"
    -  world
- !!seq : list
    - \[ 1, 2, 3 \] 
- !!map : dict
    - { x: 1, y: 12 }

The following standard types are presently unsupported:
- !!binary
- !!timestamp
- !!omap, !!pairs
- !!set -- TODO: Implement me!

---
Complex types are instead supported through Forgather specific tags:

#### !tuple : Named Tuple

Syntax: !tuple\[:@name\] \<sequence\>

Construct a named Python tuple from a YAML sequence

```yaml
!tuple:@my_tuple [ 1, 2, 3 ]
```

```python
graph()
(1, 2, 3)
```

---

#### !list : Named List

Syntax: !list\[:@name\] \<sequence\>

Construct a named Python list from a YAML sequence

```yaml
!list:@my_list [ 1, 2, 3 ]

```

```python
graph()
[1, 2, 3]
```

---

#### !dict : Named Dictionary

Syntax: !dict\[:@\<name\>\] \<mapping\>

Construct a named Python dict from a YAML mapping

```yaml
!dict:@my_dict
    foo: 1
    bar: 2
    baz: 3
```

```python
graph()
{'foo': 1, 'bar': 2, 'baz': 3}
```

---
#### !var

Syntax: !var "\<var-name\>" | { name: \<var-name\>, default: \<default-value\> }

This declares a global variable, which can be substituted anywhere in the graph.

```yaml
document = """
point: !dict
    x: !var "x" # Define a variable named 'x'
    y: !var # Define a variable named 'y' with a default value of 16
        name: y
        default: 16
"""
```

The global context is passed in as the special 'context_vars' argument, a dictionary, when constructing the graph.

```python
graph.point(context_vars=dict(x=2.0))
{'x': 2.0, 'y': 16}
```
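The substitution of variables with defaults can be pictured as a recursive walk over the graph. The sketch below is a simplified model, not Forgather's implementation: the `Var` class and `resolve` function are hypothetical, and raising `KeyError` for a variable with neither a value nor a default is an assumption.

```python
from dataclasses import dataclass
from typing import Any

_MISSING = object()  # sentinel meaning "no default was given"


@dataclass
class Var:
    """Hypothetical stand-in for a '!var' node."""
    name: str
    default: Any = _MISSING


def resolve(node: Any, context_vars: dict) -> Any:
    """Recursively replace Var placeholders with values from context_vars."""
    if isinstance(node, Var):
        if node.name in context_vars:
            return context_vars[node.name]
        if node.default is _MISSING:
            raise KeyError(f"undefined variable: {node.name}")
        return node.default
    if isinstance(node, dict):
        return {key: resolve(value, context_vars) for key, value in node.items()}
    if isinstance(node, (list, tuple)):
        return type(node)(resolve(item, context_vars) for item in node)
    return node


point = {"x": Var("x"), "y": Var("y", default=16)}
print(resolve(point, {"x": 2.0}))  # {'x': 2.0, 'y': 16}
```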

---
#### !call

Alias: !singleton

Syntax: !call:\<import-spec\>\[@\<name\>\] (\<sequence\> | \<mapping\> | ({ args: \<sequence\>, kwargs: \<mapping\> }))

This is a callable object with only a single instance; any alias refers to the same object instance.

```yaml
# Construct three random ints, all having the same value.
- &random_int !call:random:randrange:@random_int [ 1000 ]
- *random_int
- *random_int
```

```python
graph()

[247, 247, 247]
```

The "SingletonNode" will generally be your 'go-to' for constructing objects, as the semantics mirror what is expected for YAML anchors and aliases.

However, there are a few exceptions...

---
#### !factory

Syntax: !factory:\<import-spec\>\[@\<name\>\] (\<sequence\> | \<mapping\> | ({ args: \<sequence\>, kwargs: \<mapping\> }))

This is a callable object which instantiates a new instance everywhere it appears in the graph.

```yaml
# Construct three random ints, all (probably) having different values.
- &random_int !factory:random:randrange [ 1000 ]
- *random_int
- *random_int
```

Constructed...
```python
graph()

[99, 366, 116]
```
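The difference between `!call` and `!factory` comes down to whether the result is cached. The following sketch models that with two small classes; `Singleton` and `Factory` here are hypothetical stand-ins rather than Forgather's node classes, and `itertools.count` replaces `random.randrange` so the output is deterministic.

```python
import itertools


class Singleton:
    """Hypothetical stand-in for '!call': construct once, then reuse."""

    def __init__(self, func, *args, **kwargs):
        self.func, self.args, self.kwargs = func, args, kwargs
        self._built = False
        self._value = None

    def __call__(self):
        if not self._built:
            self._value = self.func(*self.args, **self.kwargs)
            self._built = True
        return self._value


class Factory:
    """Hypothetical stand-in for '!factory': construct on every use."""

    def __init__(self, func, *args, **kwargs):
        self.func, self.args, self.kwargs = func, args, kwargs

    def __call__(self):
        return self.func(*self.args, **self.kwargs)


counter = itertools.count()      # deterministic substitute for randrange
once = Singleton(next, counter)
print([once(), once(), once()])  # [0, 0, 0]

each = Factory(next, counter)
print([each(), each(), each()])  # [1, 2, 3]
```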

---
#### !partial

Alias (deprecated): !lambda

Syntax: !partial:\<import-spec\>\[@\<name\>\] (\<sequence\> | \<mapping\> | ({ args: \<sequence\>, kwargs: \<mapping\> }))

This constructs a callable object with the same semantics as a Python partial function: the provided positional and keyword arguments are passed to the function when it is called. If additional arguments are given at call time, the positional args are appended and the keyword args are merged.

See: https://docs.python.org/3/library/functools.html

```yaml
!partial:pow [ 2 ]
```

```python
graph(3)
8

# This is equivalent to:
pow(2, 3)
```
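For comparison, the same behaviour with Python's own `functools.partial`, which is the model referenced above. The `int` example is added only to show keyword merging; in `functools.partial`, a keyword given at call time overrides a bound keyword of the same name.

```python
from functools import partial

# Bind the first positional argument, as '!partial:pow [ 2 ]' does.
two_to_the = partial(pow, 2)
print(two_to_the(3))          # 8, i.e. pow(2, 3)

# Keyword arguments given at call time are merged with the bound ones.
to_int = partial(int, base=2)
print(to_int("101"))          # 5
print(to_int("101", base=8))  # 65
```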


---
### CallableNodes


SingletonNode, FactoryNode, and LambdaNode all derive from the abstract base class "CallableNode." A CallableNode can call any Python function, including class constructors. As Python differentiates between positional args and kwargs, making use of both requires the following syntax:

```yaml
!singleton:random:sample
    args:
        - ['red', 'blue']
        - 5
    kwargs:
        counts: [4, 2]
```

Generally speaking, you can omit the explicit 'args' and 'kwargs' names, as long as the syntax is unambiguous.

```yaml
- !singleton:torch:tensor
    - 2
    - 2
- !singleton:random:binomialvariate { n: 1, p: 0.5 }
```
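One way to picture the rule is as a normalisation step that turns a node's YAML value into an `(args, kwargs)` pair. The helper below is a hypothetical sketch of such a rule, not Forgather's actual logic. It also shows where ambiguity comes from: a mapping whose only keys are `args` and `kwargs` is always read as the explicit form.

```python
from typing import Any, Tuple


def split_arguments(value: Any) -> Tuple[list, dict]:
    """Normalise a node's YAML value into (args, kwargs)."""
    if isinstance(value, (list, tuple)):
        # A sequence supplies positional arguments only.
        return list(value), {}
    if isinstance(value, dict):
        if value and set(value) <= {"args", "kwargs"}:
            # The explicit form: { args: [...], kwargs: {...} }
            return list(value.get("args", [])), dict(value.get("kwargs", {}))
        # Any other mapping supplies keyword arguments only.
        return [], dict(value)
    raise TypeError(f"expected a sequence or mapping, got {type(value).__name__}")


print(split_arguments([2, 2]))
print(split_arguments({"n": 1, "p": 0.5}))
print(split_arguments({"args": [["red", "blue"], 5], "kwargs": {"counts": [4, 2]}}))
```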

---
#### CallableNode Tag Syntax

The part of the YAML tag after the first ':' provides the information required to locate and import the requested Callable.

In the simplest case, a [built-in](https://docs.python.org/3/library/functions.html) Python callable just needs to specify the name of the built-in.

```yaml
!singleton:tuple [ 1, 2, 3 ]
```

When the Callable is defined in a module, a second ':' is used to separate the module name from the name within the module.

```yaml
# See: https://docs.python.org/3/library/operator.html
!singleton:operator:mod [ 365, 7 ]
```

You can also dynamically import a name from a file.

```yaml
# Import 'MyClass' from a Python source file.
!singleton:/path/to/my/pymodule.py:MyClass [ "foo", "bar" ]
```

When using a file-import, which itself has relative imports, you will need to specify which directories to search for relative imports:

```yaml
# Import 'MyClass' from a file, listing the directories to search for its relative imports.
!singleton:/path/to/my/pymodule.py:MyClass 
    args: [ "foo", "bar" ]
    kwargs:
        submodule_searchpath:
            - "/path/to/my/"
            - "/path/to/shared/modules/"
```
The keyword argument "submodule_searchpath" has a special meaning in this context and will not be passed to the called object.
The import system treats all of the directories in the list as a union, thus "pymodule.py" can perform a relative import from any of these directories.
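The first two forms of the import spec, a bare built-in name and `module:name`, can be resolved with the standard library alone. The `resolve_import_spec` helper below is a hypothetical sketch; it does not handle the file-path form or `submodule_searchpath`, and it is not Forgather's importer.

```python
import builtins
import importlib
from typing import Any, Callable


def resolve_import_spec(spec: str) -> Callable[..., Any]:
    """Resolve 'name' (a built-in) or 'module:name' to a Python callable."""
    module_name, sep, attr = spec.rpartition(":")
    if not sep:
        # No ':' present, so the spec names a built-in.
        return getattr(builtins, attr)
    # Otherwise import the module and look the name up inside it.
    return getattr(importlib.import_module(module_name), attr)


print(resolve_import_spec("tuple")([1, 2, 3]))      # (1, 2, 3)
print(resolve_import_spec("operator:mod")(365, 7))  # 1
```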

---
#### Named Callable Nodes

CallableNodes may be given an explicit name. The name serves the same purpose as the YAML anchor/alias, but PyYAML does not make this information available through the tag API. While it is feasible to hack PyYAML, doing so is risky. For now, there is a somewhat redundant interface for specifying node names.

When a node has been assigned an explicit name, it will always be rendered as an explicit definition in the Python and YAML code generators, to improve readability. Naming a node is entirely optional.

A callable node's tag may end with '@\<name\>' which will assign a name to the node.

```yaml
.define: &foobar !singleton:dict@foobar
    foo: 1
    bar: 2
    baz: |
        She sells sea shells
        by the sea shore
main:
    - *foobar
```

When rendered as Python:

```python
def construct(
):
    foobar = {
        'foo': 1,
        'bar': 2,
        'baz': (
                'She sells sea shells\n'
                'by the sea shore\n'
            ),
    }
    
    return {
        'main': [
            foobar,
        ],
    }
```

And without the name, the object definition becomes anonymous:

```yaml
.define: &foobar !singleton:dict
...
```

```python
def construct(
):
    return {
        'main': [
            {
                'foo': 1,
                'bar': 2,
                'baz': (
                        'She sells sea shells\n'
                        'by the sea shore\n'
                    ),
            },
        ],
    }
```

## Low Level API

*Basic Usage*

```python
# Imports
from forgather.config import ConfigEnvironment

# Construct a configuration environment
env = ConfigEnvironment()

# Define a configuration
document = """
!call:torch:randn [ 2, 2 ]
"""

# Convert the configuration to a graph
graph = env.load_from_string(document).config

# Construct the graph
graph()
tensor([[ 0.0090,  0.0064],
        [-1.1638,  0.7066]])
```

### Create Config Environment

A configuration environment is required to construct configurations from YAML/Jinja2 inputs; it contains the information needed to locate Jinja2 templates by name and defines the global variables available to templates.

```python
from forgather.config import ConfigEnvironment
...
ConfigEnvironment(
    searchpath: Iterable[str | os.PathLike] | str | os.PathLike = tuple("."),
    pp_environment: Environment = None,
    global_vars: Dict[str, Any] = None,
):
```

- searchpath: A list of directories to search for templates in.
- pp_environment: Override the default Jinja2 environment class with another implementation.
- global_vars: Jinja2 global variables visible to all templates.

```python
env = ConfigEnvironment("./templates/")
```

### Define Input

A configuration document consists of a combination of YAML and Jinja2 syntax. Typically, a config template would be loaded from a file, but for testing we can create a template directly from a Python string.

Both the Jinja2 template and the configuration may accept variables.

### Convert Document to Graph

```python
class ConfigEnvironment:
... 
    def load(
        self,
        config_path: os.PathLike | str,
        /,
        **kwargs,
    ) -> Config:
...
    def load_from_string(
        self,
        config: str,
        /,
        **kwargs,
    ) -> Config:
```

- load: Load a template from a path; each directory in 'searchpath' is searched for the template.
    - config_path: The template path, relative to 'searchpath'.
    - kwargs: These are passed into the context of the template.
- load_from_string: As with load, but a Python string defines the template body; Note that this bypasses the template loader.
    - config: A Python string with a Jinja2 template.
    - kwargs: Passed to the template.

### Materializing the Graph

Construct the objects directly from the graph.

```python
from forgather.latent import Latent
...
def materialize(obj: Any, /, *args, context_vars: Dict=None, **kwargs):
```

Construct all objects in the graph, returning the constructed root-node.

context_vars: The global variables, which will be substituted by '!var' nodes.

If the root node is a partial function, *args and **kwargs are forwarded to the function.

Alternatively, if the root-node is not a dictionary, the following are equivalent:

```python
Latent.materialize(graph)

# Performs the same action, if the root-node is not a dictionary.
graph()
```

If the root-node is a dictionary...

```yaml
main: !partial:math:sqrt []
```

The dictionary elements can be accessed using dot-notation and constructed individually.

```python
graph.main(16)
4.0
```

### Convert Graph to YAML

Convert the node-graph to a YAML representation. This may not be exactly the same as it was in the source template, but should be semantically equivalent.

```python
from forgather.yaml_encoder import to_yaml
...
def to_yaml(obj: Any):
```

### Convert Graph to Python

This function takes the output from Latent.to_py(graph) and uses it to render Python code using a Jinja2 template. If the template is unspecified, an implicit "built-in" template is used, which will generate appropriate import and dynamic import statements, where required.

```python
from forgather.codegen import generate_code
...
def generate_code(
    obj,
    template_name: Optional[str] = None,
    template_str: Optional[str] = None,
    searchpath: Optional[List[str | os.PathLike] | str | os.PathLike] = ".",
    env=None,  # jinja2 environment or compatible API
    **kwargs,
) -> Any:
```

The default template accepts the following additional kwargs:

    factory_name: Optional[str]="construct", ; The name of the generated factory function.
    relaxed_kwargs: Optional[bool]=Undefined, ; if defined, **kwargs is added to the arg list
    
See 'help(generate_code)' for details.

## Trivial Examples

In [2]:
# Imports
from forgather.config import ConfigEnvironment

# Construct a configuration environment
env = ConfigEnvironment()

# Define a configuration
# Here, we construct a 2x2 random tensor.
document = """
!call:torch:randn [ 2, 2 ]
"""

# Convert the configuration to a graph
graph = env.load_from_string(document).config

# Construct the graph
graph()

tensor([[-0.3099, -0.7097],
        [ 1.0701,  0.2011]])

In [3]:
# Construct a function which computes the square-root of its argument.
graph = env.load_from_string("main: !partial:math:sqrt []").config
graph.main(16)

4.0

## Complex Example

The following template defines a simple (and somewhat incomplete) language model.

[./examples/model_def.yaml](./examples/model_def.yaml)

In [4]:
template_path = os.path.join('examples', 'model_def.yaml')
with open(template_path, 'r') as f:
    nb.display_codeblock("yaml", f.read(), "### Configuration Template")

### Configuration Template
```yaml
-- set ns = namespace()
-- from 'examples/formatting.jinja' import h1, h2, h3
-- filter trim() ## This removes whitespace before the header.

## Jina2 block definitions; we can override these in derived templates.
-- block meta_config
    -- set ns.model_src = '../model_src/bits/'
    -- set ns.config_name = 'Control'
    -- set ns.config_description = "Baseline Control"
    ## Example of variable set by jinja2 template.
    -- set ns.vocab_size = 1024
<< endblock meta_config


-- endfilter
-- block header
== h1(ns.config_name)
# {{ utcisotime() }}
# Description: {{ ns.config_description }}
# model_src = {{ ns.model_src }}
# Current Working Dir: "{{ getcwd() }}"
# Forgather Config Dir: "{{ abspath(forgather_config_dir()) }}"
<< endblock header


== h2("Model Definition")

== h3("Layer Norm Factory")

-- block layer_norm_factory
.define: &layer_norm_factory !lambda:torch.nn:LayerNorm@layer_norm_factory
    - !var "hidden_size"
<< endblock layer_norm_factory


== h3("Activation Factory")

-- block activation_factory
.define: &activation_factory !partial:torch.nn:ReLU@activation_factory []
<< endblock activation_factory


== h3("Feedforward Factory")

-- block feedforward_factory
.define: &feedforward_factory !partial:{{ns.model_src}}feedforward_layer.py:FeedforwardLayer@feedforward_factory
    activation_factory: *activation_factory
    d_model: !var "hidden_size"
    d_feedforward: !var "dim_feedforward"
<< endblock feedforward_factory


== h3("Attention Factory")

-- block attention_factory
.define: &attention_factory !partial:{{ns.model_src}}single_head_attn.py:SingleHeadAttn@attention_factory
    d_model: !var "hidden_size"
<< endblock attention_factory


== h3("Layer Factory")

-- block layer_factory
.define: &layer_factory !partial:{{ns.model_src}}pre_ln_layer.py:PreLNLayer@layer_factory
    feedforward_factory: *feedforward_factory
    attention_factory: *attention_factory
    norm_factory: *layer_norm_factory
<< endblock layer_factory


== h3("Layer Stack Factory")

-- block layer_stack_factory
.define: &layer_stack_factory !factory:{{ns.model_src}}layer_stack.py:LayerStack@layer_stack_factory
    layer_factory: *layer_factory
    post_norm_factory: *layer_norm_factory
    num_hidden_layers: !var "n_layers"
<< endblock layer_stack_factory


== h3("Model")

-- block model
## This block is not nearly as factored-out as the others, using inline-definiions.
.define: &model !call:{{ns.model_src}}causal_lm.py:CasualLM@model
    loss_fn: !factory:{{ns.model_src}}causal_loss.py:CausalLoss
    input_encoder: !factory:{{ns.model_src}}input_encoder.py:InputEncoder
        d_model: !var "hidden_size"
        vocab_size: {{ ns.vocab_size }}
    output_decoder: !factory:torch.nn:Linear [ {{ ns.vocab_size }}, !var "hidden_size" ]
    init_weights: !partial:{{ns.model_src}}init_weights.py:simple_weight_init
    layer_stack: *layer_stack_factory
<< endblock model


## Main output
main: *model

```



## Preprocess the Template

This will only run the Jinja preprocessor.

This more or less looks like the original template...

In [5]:
env = ConfigEnvironment()

pp_config = env.preprocess(template_path)
nb.display_codeblock("yaml", pp_config, "### Pre Processed Template")

### Pre Processed Template
```yaml
#---------------------------------------
#                 Control                
#---------------------------------------
# 2025-06-07T04:25:44
# Description: Baseline Control
# model_src = ../model_src/bits/
# Current Working Dir: "/home/dinalt/ai_assets/forgather/notebooks"
# Forgather Config Dir: "/home/dinalt/.config/forgather"

########### Model Definition ###########

# **Layer Norm Factory**

.define: &layer_norm_factory !lambda:torch.nn:LayerNorm@layer_norm_factory
    - !var "hidden_size"

# **Activation Factory**

.define: &activation_factory !partial:torch.nn:ReLU@activation_factory []

# **Feedforward Factory**

.define: &feedforward_factory !partial:../model_src/bits/feedforward_layer.py:FeedforwardLayer@feedforward_factory
    activation_factory: *activation_factory
    d_model: !var "hidden_size"
    d_feedforward: !var "dim_feedforward"

# **Attention Factory**

.define: &attention_factory !partial:../model_src/bits/single_head_attn.py:SingleHeadAttn@attention_factory
    d_model: !var "hidden_size"

# **Layer Factory**

.define: &layer_factory !partial:../model_src/bits/pre_ln_layer.py:PreLNLayer@layer_factory
    feedforward_factory: *feedforward_factory
    attention_factory: *attention_factory
    norm_factory: *layer_norm_factory

# **Layer Stack Factory**

.define: &layer_stack_factory !factory:../model_src/bits/layer_stack.py:LayerStack@layer_stack_factory
    layer_factory: *layer_factory
    post_norm_factory: *layer_norm_factory
    num_hidden_layers: !var "n_layers"

# **Model**

.define: &model !call:../model_src/bits/causal_lm.py:CasualLM@model
    loss_fn: !factory:../model_src/bits/causal_loss.py:CausalLoss
    input_encoder: !factory:../model_src/bits/input_encoder.py:InputEncoder
        d_model: !var "hidden_size"
        vocab_size: 1024
    output_decoder: !factory:torch.nn:Linear [ 1024, !var "hidden_size" ]
    init_weights: !partial:../model_src/bits/init_weights.py:simple_weight_init
    layer_stack: *layer_stack_factory

main: *model

```



## Construct Model Instance

In [6]:
graph = env.load(template_path).config

model_config = dict(
    hidden_size=64,
    dim_feedforward=256,
    n_layers=2,
)

graph.main(context_vars=model_config)

CasualLM(
  loss_fn=CausalLoss()
  (input_encoder): InputEncoder(
    d_model=64, vocab_size=1024
    (dropout): Dropout(p=0.1, inplace=False)
    (embedding): Embedding(1024, 64)
  )
  (output_decoder): Linear(in_features=1024, out_features=64, bias=True)
  (layer_stack): LayerStack(
    (layers): ModuleDict(
      (0): PreLNLayer(
        (feedforward): FeedforwardLayer(
          d_model=64, d_feedforward=256
          (linear1): Linear(in_features=64, out_features=256, bias=True)
          (dropout): Identity()
          (activation): ReLU()
          (linear2): Linear(in_features=256, out_features=64, bias=True)
        )
        (attention): SingleHeadAttn(
          d_model=64, bias=True
          (query_key_linear): Linear(in_features=64, out_features=64, bias=True)
          (value_linear): Linear(in_features=64, out_features=64, bias=True)
        )
        (norm1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((64,), eps=1e-05, elementwise_a

### Override Something

For our experiment, we will want to change just one variable.

In [7]:
experiment_config = """
-- extends 'examples/model_def.yaml'

-- block meta_config
    ## This includes the definition from the parent.
    == super()
    -- set ns.config_name = "No Bias"
    -- set ns.config_description = "Disabled bias in attention. Does it matter?"
<< endblock meta_config

-- block attention_factory
    == super()

    ## And add an override. We are essentially just appending more arguments to the definition.
    # Experiment override.
    bias: False
<< endblock attention_factory
"""

output = env.load_from_string(experiment_config)
nb.display_codeblock("yaml", output.pp_config, "#### Pre Processed Experiment Config")

#### Pre Processed Experiment Config
```yaml

#---------------------------------------
#                 No Bias                
#---------------------------------------
# 2025-06-07T04:26:00
# Description: Disabled bias in attention. Does it matter?
# model_src = ../model_src/bits/
# Current Working Dir: "/home/dinalt/ai_assets/forgather/notebooks"
# Forgather Config Dir: "/home/dinalt/.config/forgather"

########### Model Definition ###########

# **Layer Norm Factory**

.define: &layer_norm_factory !lambda:torch.nn:LayerNorm@layer_norm_factory
    - !var "hidden_size"

# **Activation Factory**

.define: &activation_factory !partial:torch.nn:ReLU@activation_factory []

# **Feedforward Factory**

.define: &feedforward_factory !partial:../model_src/bits/feedforward_layer.py:FeedforwardLayer@feedforward_factory
    activation_factory: *activation_factory
    d_model: !var "hidden_size"
    d_feedforward: !var "dim_feedforward"

# **Attention Factory**

.define: &attention_factory !partial:../model_src/bits/single_head_attn.py:SingleHeadAttn@attention_factory
    d_model: !var "hidden_size"

    # Experiment override.
    bias: False

# **Layer Factory**

.define: &layer_factory !partial:../model_src/bits/pre_ln_layer.py:PreLNLayer@layer_factory
    feedforward_factory: *feedforward_factory
    attention_factory: *attention_factory
    norm_factory: *layer_norm_factory

# **Layer Stack Factory**

.define: &layer_stack_factory !factory:../model_src/bits/layer_stack.py:LayerStack@layer_stack_factory
    layer_factory: *layer_factory
    post_norm_factory: *layer_norm_factory
    num_hidden_layers: !var "n_layers"

# **Model**

.define: &model !call:../model_src/bits/causal_lm.py:CasualLM@model
    loss_fn: !factory:../model_src/bits/causal_loss.py:CausalLoss
    input_encoder: !factory:../model_src/bits/input_encoder.py:InputEncoder
        d_model: !var "hidden_size"
        vocab_size: 1024
    output_decoder: !factory:torch.nn:Linear [ 1024, !var "hidden_size" ]
    init_weights: !partial:../model_src/bits/init_weights.py:simple_weight_init
    layer_stack: *layer_stack_factory

main: *model

```



## Construct Experiment Model

This model now has been modified. The bias is now disabled on the attention module.

In [8]:
graph = output.config
graph.main(context_vars=model_config)

CasualLM(
  loss_fn=CausalLoss()
  (input_encoder): InputEncoder(
    d_model=64, vocab_size=1024
    (dropout): Dropout(p=0.1, inplace=False)
    (embedding): Embedding(1024, 64)
  )
  (output_decoder): Linear(in_features=1024, out_features=64, bias=True)
  (layer_stack): LayerStack(
    (layers): ModuleDict(
      (0): PreLNLayer(
        (feedforward): FeedforwardLayer(
          d_model=64, d_feedforward=256
          (linear1): Linear(in_features=64, out_features=256, bias=True)
          (dropout): Identity()
          (activation): ReLU()
          (linear2): Linear(in_features=256, out_features=64, bias=True)
        )
        (attention): SingleHeadAttn(
          d_model=64, bias=False
          (query_key_linear): Linear(in_features=64, out_features=64, bias=False)
          (value_linear): Linear(in_features=64, out_features=64, bias=False)
        )
        (norm1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((64,), eps=1e-05, elementwis

### Implementation Override

Unlike most configuration systems, we can not only change numerical parameters, we can alter the implementation!

Let's replace the simple single-head attention module with a multihead-attention module.

In [9]:
experiment_config = """
-- extends 'examples/model_def.yaml'

-- block meta_config
    == super()
    -- set ns.config_name = "Multihead Attention"
    -- set ns.config_description = "Swapped singlehead attention for multihead attention."
    -- set ns.attention_heads = 2
<< endblock meta_config


-- block attention_factory
# Experiment Override.
.define: &attention_factory !partial:{{ns.model_src}}causal_multihead_attn.py:CausalMultiheadAttn@attention_factory
    d_model: !var "hidden_size"
    num_heads: {{ ns.attention_heads }}
<< endblock attention_factory
"""

output = env.load_from_string(experiment_config)
nb.display_codeblock("yaml", output.pp_config, "#### Pre Processed Experiment Config")

#### Pre Processed Experiment Config
```yaml

#---------------------------------------
#           Multihead Attention          
#---------------------------------------
# 2025-06-07T04:26:03
# Description: Swapped singlehead attention for multihead attention.
# model_src = ../model_src/bits/
# Current Working Dir: "/home/dinalt/ai_assets/forgather/notebooks"
# Forgather Config Dir: "/home/dinalt/.config/forgather"

########### Model Definition ###########

# **Layer Norm Factory**

.define: &layer_norm_factory !lambda:torch.nn:LayerNorm@layer_norm_factory
    - !var "hidden_size"

# **Activation Factory**

.define: &activation_factory !partial:torch.nn:ReLU@activation_factory []

# **Feedforward Factory**

.define: &feedforward_factory !partial:../model_src/bits/feedforward_layer.py:FeedforwardLayer@feedforward_factory
    activation_factory: *activation_factory
    d_model: !var "hidden_size"
    d_feedforward: !var "dim_feedforward"

# **Attention Factory**

# Experiment Override.
.define: &attention_factory !partial:../model_src/bits/causal_multihead_attn.py:CausalMultiheadAttn@attention_factory
    d_model: !var "hidden_size"
    num_heads: 2

# **Layer Factory**

.define: &layer_factory !partial:../model_src/bits/pre_ln_layer.py:PreLNLayer@layer_factory
    feedforward_factory: *feedforward_factory
    attention_factory: *attention_factory
    norm_factory: *layer_norm_factory

# **Layer Stack Factory**

.define: &layer_stack_factory !factory:../model_src/bits/layer_stack.py:LayerStack@layer_stack_factory
    layer_factory: *layer_factory
    post_norm_factory: *layer_norm_factory
    num_hidden_layers: !var "n_layers"

# **Model**

.define: &model !call:../model_src/bits/causal_lm.py:CasualLM@model
    loss_fn: !factory:../model_src/bits/causal_loss.py:CausalLoss
    input_encoder: !factory:../model_src/bits/input_encoder.py:InputEncoder
        d_model: !var "hidden_size"
        vocab_size: 1024
    output_decoder: !factory:torch.nn:Linear [ 1024, !var "hidden_size" ]
    init_weights: !partial:../model_src/bits/init_weights.py:simple_weight_init
    layer_stack: *layer_stack_factory

main: *model

```



### Examine the Graph

Internally, the processed configuration is represented as an abstract node graph.

In [10]:
nb.display_codeblock("python", pformat(graph), "### Node Graph")

### Node Graph
```python
{'main': SingletonNode('../model_src/bits/causal_lm.py:CasualLM', *(), identity='model', **{'loss_fn': FactoryNode('../model_src/bits/causal_loss.py:CausalLoss', *(), identity=140568966666448, **{}), 'input_encoder': FactoryNode('../model_src/bits/input_encoder.py:InputEncoder', *(), identity=140568966671056, **{'d_model': VarNode('hidden_size', identity=140568966666400, value=Undefined), 'vocab_size': 1024}), 'output_decoder': FactoryNode('torch.nn:Linear', *(1024, VarNode('hidden_size', identity=140568966661072, value=Undefined)), identity=140568966660784, **{}), 'init_weights': LambdaNode('../model_src/bits/init_weights.py:simple_weight_init', *(), identity=140568966672304, **{}), 'layer_stack': FactoryNode('../model_src/bits/layer_stack.py:LayerStack', *(), identity='layer_stack_factory', **{'layer_factory': LambdaNode('../model_src/bits/pre_ln_layer.py:PreLNLayer', *(), identity='layer_factory', **{'feedforward_factory': LambdaNode('../model_src/bits/feedforward_layer.py:FeedforwardLayer', *(), identity='feedforward_factory', **{'activation_factory': LambdaNode('torch.nn:ReLU', *(), identity='activation_factory', **{}), 'd_model': VarNode('hidden_size', identity=140568966674800, value=Undefined), 'd_feedforward': VarNode('dim_feedforward', identity=140568966661552, value=Undefined)}), 'attention_factory': LambdaNode('../model_src/bits/single_head_attn.py:SingleHeadAttn', *(), identity='attention_factory', **{'d_model': VarNode('hidden_size', identity=140568966660544, value=Undefined), 'bias': False}), 'norm_factory': LambdaNode('torch.nn:LayerNorm', *(VarNode('hidden_size', identity=140568966661600, value=Undefined),), identity='layer_norm_factory', **{})}), 'post_norm_factory': LambdaNode('torch.nn:LayerNorm', *(VarNode('hidden_size', identity=140568966661600, value=Undefined),), identity='layer_norm_factory', **{}), 'num_hidden_layers': VarNode('n_layers', identity=140568966671536, value=Undefined)})})}

```
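
To make the idea of a node graph more concrete, here is a minimal sketch of how such a graph could be evaluated. It is not Forgather's implementation: `Var`, `Call`, and `materialize` are hypothetical stand-ins, loosely modeled on the `VarNode`, `FactoryNode`, and `SingletonNode` names in the printout above. The assumption illustrated is that a singleton node is constructed once and shared, while a factory node is constructed anew each time it is reached.

```python
from dataclasses import dataclass, field
from typing import Any, Callable


@dataclass
class Var:
    """Hypothetical stand-in for a VarNode: a value supplied at build time."""
    name: str


@dataclass
class Call:
    """Hypothetical stand-in for a FactoryNode or SingletonNode: a deferred call."""
    target: Callable[..., Any]
    args: tuple = ()
    kwargs: dict = field(default_factory=dict)
    singleton: bool = False


def materialize(node, variables, cache=None):
    """Recursively turn a node graph into concrete Python objects."""
    cache = {} if cache is None else cache
    if isinstance(node, Var):
        return variables[node.name]
    if isinstance(node, Call):
        # A singleton is built once per materialize() run and then reused.
        if node.singleton and id(node) in cache:
            return cache[id(node)]
        args = [materialize(a, variables, cache) for a in node.args]
        kwargs = {k: materialize(v, variables, cache) for k, v in node.kwargs.items()}
        obj = node.target(*args, **kwargs)
        if node.singleton:
            cache[id(node)] = obj
        return obj
    if isinstance(node, dict):
        return {k: materialize(v, variables, cache) for k, v in node.items()}
    if isinstance(node, (list, tuple)):
        return type(node)(materialize(v, variables, cache) for v in node)
    return node  # Plain values (int, str, ...) pass through unchanged.


# The same node object referenced twice, as a YAML alias would do.
shared = Call(dict, kwargs={"size": Var("hidden_size")}, singleton=True)
fresh = Call(list)
graph = {"a": shared, "b": shared, "c": fresh, "d": fresh}

result = materialize(graph, {"hidden_size": 128})
print(result["a"] is result["b"])  # Singleton: one shared object.
print(result["c"] is result["d"])  # Factory: a new object per reference.
```

Keying the cache on the node's identity mirrors the `identity=` field visible in the printed graph, although the real nodes may track identity differently.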



### Convert Graph to YAML

Convert the node graph to a YAML representation. The result may not match the source template exactly, but it should be semantically equivalent.

In [11]:
nb.display_codeblock("yaml", to_yaml(graph))

```yaml
.define: &activation_factory !lambda:torch.nn:ReLU@activation_factory []

.define: &feedforward_factory !lambda:../model_src/bits/feedforward_layer.py:FeedforwardLayer@feedforward_factory
    activation_factory: *activation_factory
    d_model: !var 'hidden_size'
    d_feedforward: !var 'dim_feedforward'

.define: &attention_factory !lambda:../model_src/bits/single_head_attn.py:SingleHeadAttn@attention_factory
    d_model: !var 'hidden_size'
    bias: False

.define: &layer_norm_factory !lambda:torch.nn:LayerNorm@layer_norm_factory
    - !var 'hidden_size'

.define: &layer_factory !lambda:../model_src/bits/pre_ln_layer.py:PreLNLayer@layer_factory
    feedforward_factory: *feedforward_factory
    attention_factory: *attention_factory
    norm_factory: *layer_norm_factory

.define: &layer_stack_factory !factory:../model_src/bits/layer_stack.py:LayerStack@layer_stack_factory
    layer_factory: *layer_factory
    post_norm_factory: *layer_norm_factory
    num_hidden_layers: !var 'n_layers'

.define: &model !singleton:../model_src/bits/causal_lm.py:CasualLM@model
    loss_fn: !factory:../model_src/bits/causal_loss.py:CausalLoss []
    input_encoder: !factory:../model_src/bits/input_encoder.py:InputEncoder
        d_model: !var 'hidden_size'
        vocab_size: 1024
    output_decoder: !factory:torch.nn:Linear
        - 1024
        - !var 'hidden_size'
    init_weights: !lambda:../model_src/bits/init_weights.py:simple_weight_init []
    layer_stack: *layer_stack_factory


main: *model

```



### Convert Graph to Python

`generate_code` takes the output of `Latent.to_py(graph)` and uses it to render Python code with a Jinja2 template. If no template is specified, an implicit "built-in" template is used, which generates the appropriate import and dynamic import statements, where required.

In [12]:
from forgather.graph_encoder import NamePolicy # NamePolicy.REQUIRED | NamePolicy.ALL | NamePolicy.NAMED
generated_code = generate_code(graph.main, name_policy=None)
nb.display_codeblock("python", generated_code, "### Generated Code", )

### Generated Code
```python
from torch.nn import LayerNorm
from torch.nn import Linear
from torch.nn import ReLU
from importlib.util import spec_from_file_location, module_from_spec
import os
import sys
from functools import partial

# Import a dynamic module.
def dynimport(module, name, searchpath):
    module_path = module
    module_name = os.path.basename(module).split(".")[0]
    module_spec = spec_from_file_location(
        module_name,
        module_path,
        submodule_search_locations=searchpath,
    )
    mod = module_from_spec(module_spec)
    sys.modules[module_name] = mod
    module_spec.loader.exec_module(mod)
    for symbol in name.split("."):
        mod = getattr(mod, symbol)
    return mod

CasualLM = lambda: dynimport("../model_src/bits/causal_lm.py", "CasualLM", ())
CausalLoss = lambda: dynimport("../model_src/bits/causal_loss.py", "CausalLoss", ())
FeedforwardLayer = lambda: dynimport("../model_src/bits/feedforward_layer.py", "FeedforwardLayer", ())
simple_weight_init = lambda: dynimport("../model_src/bits/init_weights.py", "simple_weight_init", ())
InputEncoder = lambda: dynimport("../model_src/bits/input_encoder.py", "InputEncoder", ())
LayerStack = lambda: dynimport("../model_src/bits/layer_stack.py", "LayerStack", ())
PreLNLayer = lambda: dynimport("../model_src/bits/pre_ln_layer.py", "PreLNLayer", ())
SingleHeadAttn = lambda: dynimport("../model_src/bits/single_head_attn.py", "SingleHeadAttn", ())

def construct(
    dim_feedforward,
    hidden_size,
    n_layers,
):
    activation_factory = partial(ReLU, )

    feedforward_factory = partial(FeedforwardLayer(), 
        activation_factory=activation_factory,
        d_model=hidden_size,
        d_feedforward=dim_feedforward,
    )

    attention_factory = partial(SingleHeadAttn(), 
        d_model=hidden_size,
        bias=False,
    )

    layer_norm_factory = partial(LayerNorm, 
        hidden_size,
    )

    layer_factory = partial(PreLNLayer(), 
        feedforward_factory=feedforward_factory,
        attention_factory=attention_factory,
        norm_factory=layer_norm_factory,
    )

    layer_stack_factory = partial(LayerStack(), 
        layer_factory=layer_factory,
        post_norm_factory=layer_norm_factory,
        num_hidden_layers=n_layers,
    )

    model = CasualLM()(
        loss_fn=CausalLoss()(),
        input_encoder=InputEncoder()(
            d_model=hidden_size,
            vocab_size=1024,
        ),
        output_decoder=Linear(
            1024,
            hidden_size,
        ),
        init_weights=partial(simple_weight_init(), ),
        layer_stack=layer_stack_factory(),
    )
    
    return model

```
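
The generated `construct` function mixes two patterns: `partial(...)` for nodes that are passed along as factories, and a direct call for nodes that are built immediately. The sketch below reproduces that shape with toy classes. `Norm`, `Layer`, and `Stack` are hypothetical stand-ins, not the classes from `model_src`, and the assumption that a layer stack calls its `layer_factory` once per layer is mine rather than something the listing confirms.

```python
from functools import partial


class Norm:
    def __init__(self, size):
        self.size = size


class Layer:
    def __init__(self, norm_factory):
        # Each layer asks the factory for its own norm instance.
        self.norm = norm_factory()


class Stack:
    def __init__(self, layer_factory, post_norm_factory, num_hidden_layers):
        # Assumed behavior: one factory call per layer.
        self.layers = [layer_factory() for _ in range(num_hidden_layers)]
        self.post_norm = post_norm_factory()


def construct(hidden_size, n_layers):
    # Deferred: nothing is built yet, the arguments are only bound.
    norm_factory = partial(Norm, hidden_size)
    layer_factory = partial(Layer, norm_factory=norm_factory)

    # Immediate: the stack is constructed here and pulls from the factories.
    return Stack(
        layer_factory=layer_factory,
        post_norm_factory=norm_factory,
        num_hidden_layers=n_layers,
    )


stack = construct(hidden_size=128, n_layers=3)
print(len(stack.layers))
print(stack.layers[0].norm is stack.layers[1].norm)
```

Because the norm is passed as a factory rather than as an instance, every consumer gets a separate object with its own parameters, even though the factory itself is defined only once.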



### Custom Code Template

The above code is pretty generic. How about we wrap this class in a Hugging Face `PreTrainedModel`?  
[./examples/causal_lm.py](./examples/causal_lm.py)

In [13]:
generated_code = generate_code(graph.main, template_name="examples/causal_lm.py", model_type="my_model")
nb.display_codeblock("python", generated_code, "### Generated Code", )

### Generated Code
```python
# See: https://huggingface.co/docs/transformers/custom_models
# This is a template model, with the details filled-in by the code-generator.
from typing import Optional, Tuple

from functools import partial
from torch import nn, Tensor, LongTensor, FloatTensor
import torch
from transformers.modeling_outputs import CausalLMOutput
from transformers import (
    PreTrainedModel,
    PretrainedConfig,
    AutoConfig,
    AutoModelForCausalLM,
    GenerationMixin,
)

from torch.nn import LayerNorm
from torch.nn import Linear
from torch.nn import ReLU

from importlib.util import spec_from_file_location, module_from_spec
import os
import sys
from functools import partial

# Import a dynamic module.
def dynimport(module, name, searchpath):
    module_path = module
    module_name = os.path.basename(module).split(".")[0]
    module_spec = spec_from_file_location(
        module_name,
        module_path,
        submodule_search_locations=searchpath,
    )
    mod = module_from_spec(module_spec)
    sys.modules[module_name] = mod
    module_spec.loader.exec_module(mod)
    for symbol in name.split("."):
        mod = getattr(mod, symbol)
    return mod

CasualLM = lambda: dynimport("../model_src/bits/causal_lm.py", "CasualLM", ())
CausalLoss = lambda: dynimport("../model_src/bits/causal_loss.py", "CausalLoss", ())
FeedforwardLayer = lambda: dynimport("../model_src/bits/feedforward_layer.py", "FeedforwardLayer", ())
simple_weight_init = lambda: dynimport("../model_src/bits/init_weights.py", "simple_weight_init", ())
InputEncoder = lambda: dynimport("../model_src/bits/input_encoder.py", "InputEncoder", ())
LayerStack = lambda: dynimport("../model_src/bits/layer_stack.py", "LayerStack", ())
PreLNLayer = lambda: dynimport("../model_src/bits/pre_ln_layer.py", "PreLNLayer", ())
SingleHeadAttn = lambda: dynimport("../model_src/bits/single_head_attn.py", "SingleHeadAttn", ())

model_type = "my_model"


class DynamicCausalLMConfig(PretrainedConfig):
    model_type = model_type


class DynamicCasualLM(PreTrainedModel, GenerationMixin):
    config_class = DynamicCausalLMConfig
    model_type = model_type

    def __init__(self, config: PretrainedConfig):
        super().__init__(config)
        self.causal_lm = self.construct_model(**config.to_dict())
        if "torch_dtype" in config:
            self.to(config.torch_dtype)

    @staticmethod
    def construct_model(
        dim_feedforward,
        hidden_size,
        n_layers,
        **kwargs
    ):
        activation_factory = partial(ReLU, )

        feedforward_factory = partial(FeedforwardLayer(), 
            activation_factory=activation_factory,
            d_model=hidden_size,
            d_feedforward=dim_feedforward,
        )

        attention_factory = partial(SingleHeadAttn(), 
            d_model=hidden_size,
            bias=False,
        )

        layer_norm_factory = partial(LayerNorm, 
            hidden_size,
        )

        layer_factory = partial(PreLNLayer(), 
            feedforward_factory=feedforward_factory,
            attention_factory=attention_factory,
            norm_factory=layer_norm_factory,
        )

        layer_stack_factory = partial(LayerStack(), 
            layer_factory=layer_factory,
            post_norm_factory=layer_norm_factory,
            num_hidden_layers=n_layers,
        )

        model = CasualLM()(
            loss_fn=CausalLoss()(),
            input_encoder=InputEncoder()(
                d_model=hidden_size,
                vocab_size=1024,
            ),
            output_decoder=Linear(
                1024,
                hidden_size,
            ),
            init_weights=partial(simple_weight_init(), ),
            layer_stack=layer_stack_factory(),
        )
        
        return model

    def forward(
        self,
        input_ids: LongTensor,
        labels: Optional[LongTensor] = None,
        position_ids: Optional[LongTensor] = None,
        attention_mask: Optional[FloatTensor] = None,
        return_dict: bool = False,
        **kwargs,
    ) -> CausalLMOutput | Tuple[FloatTensor, dict[str, FloatTensor]] | FloatTensor:

        outputs = self.causal_lm(
            input_ids=input_ids,
            labels=labels,
            position_ids=position_ids,
            attention_mask=attention_mask,
            **kwargs,
        )

        # Return type depends on arguments.
        if return_dict:
            return CausalLMOutput(**outputs)
        elif labels is not None:
            return (outputs["loss"], outputs["logits"])
        else:
            return outputs["logits"]

    # Bare-minimum for HF text generation interface to work.
    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        attention_mask = kwargs.get("attention_mask", None)
        model_inputs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
        }
        return model_inputs


AutoConfig.register(model_type, DynamicCausalLMConfig)
AutoModelForCausalLM.register(DynamicCausalLMConfig, DynamicCasualLM)

```
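
Both generated listings wrap each file-based symbol in a lambda around `dynimport`, which is why constructors appear as double calls such as `CasualLM()(...)`: the first call loads the class, the second instantiates it. The following sketch exercises the same helper against a throwaway module. The file name `demo_bits.py` and the `Greeter` class are invented for the demonstration.

```python
import os
import sys
import tempfile
from importlib.util import spec_from_file_location, module_from_spec


def dynimport(module, name, searchpath):
    """Load the Python file at `module` and return its attribute `name`."""
    module_name = os.path.basename(module).split(".")[0]
    module_spec = spec_from_file_location(
        module_name,
        module,
        submodule_search_locations=searchpath,
    )
    mod = module_from_spec(module_spec)
    sys.modules[module_name] = mod
    module_spec.loader.exec_module(mod)
    # Support dotted names such as "Outer.Inner".
    for symbol in name.split("."):
        mod = getattr(mod, symbol)
    return mod


with tempfile.TemporaryDirectory() as tmp_dir:
    # Hypothetical module, written only for this demonstration.
    source_path = os.path.join(tmp_dir, "demo_bits.py")
    with open(source_path, "w") as f:
        f.write(
            "class Greeter:\n"
            "    def __init__(self, name):\n"
            "        self.name = name\n"
        )

    # Same pattern as the generated code: the lambda defers the import.
    Greeter = lambda: dynimport(source_path, "Greeter", ())

    # First call imports the class, second call constructs an instance.
    greeter = Greeter()(name="forgather")

    # Every call re-executes the module, so the class objects differ.
    first_class, second_class = Greeter(), Greeter()

print(greeter.name)
print(first_class is second_class)
```

One consequence worth keeping in mind is that `isinstance` checks across separately loaded copies of the same class will fail, since each `dynimport` call produces a new class object.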



## Execute Generated Code

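Execute the generated code, then use what it defines to construct the objects. With the built-in template this means calling the generated `construct` function; with the custom template used here, the model is built by instantiating `DynamicCausalLMConfig` and `DynamicCasualLM`.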

Note: Lambda nodes with args do not work at present, although `Latent.materialize()` does.

In [15]:
exec(generated_code)

In [16]:
model_config = DynamicCausalLMConfig(hidden_size=128, dim_feedforward=512, n_layers=3)
model = DynamicCasualLM(model_config)
model

DynamicCasualLM(
  (causal_lm): CasualLM(
    loss_fn=CausalLoss()
    (input_encoder): InputEncoder(
      d_model=128, vocab_size=1024
      (dropout): Dropout(p=0.1, inplace=False)
      (embedding): Embedding(1024, 128)
    )
    (output_decoder): Linear(in_features=1024, out_features=128, bias=True)
    (layer_stack): LayerStack(
      (layers): ModuleDict(
        (0): PreLNLayer(
          (feedforward): FeedforwardLayer(
            d_model=128, d_feedforward=512
            (linear1): Linear(in_features=128, out_features=512, bias=True)
            (dropout): Identity()
            (activation): ReLU()
            (linear2): Linear(in_features=512, out_features=128, bias=True)
          )
          (attention): SingleHeadAttn(
            d_model=128, bias=False
            (query_key_linear): Linear(in_features=128, out_features=128, bias=False)
            (value_linear): Linear(in_features=128, out_features=128, bias=False)
          )
          (norm1): LayerNorm((128,), e