In [None]:
!pip install lark

In [None]:
import logging
from init_notebook import *
import lark

# Let lark's WARNING-level (and above) log messages through to the notebook output.
# logging.WARNING is the documented level name; logging.WARN is a legacy alias with the same value.
lark.logger.setLevel(logging.WARNING)

In [None]:
# Lark grammar for compact conv-net specs such as "r(32x3p1arelu)-3x(64xk5p2s3-bn)-tanh".
# Elements are separated by "-":
#   <channels>x<attrs>   conv layer; channels is an int, or a float (e.g. ".5x", "1.5x")
#                        that add_layers treats as a multiplier of the incoming channels.
#                        attrs: k<kernel> (or a bare int), s<stride>, d<dilation>,
#                        p<padding>, a<activation>
#   <attrs>              default assignment (e.g. "k9s3arelu") used by later convs
#   Nx(<elements>)       repeat the group N times
#   r(<layer>)           residual wrapper around a single layer
#   bn / relu / gelu / sigmoid / tanh   batch-norm / activation layer
# UFLOAT is a single regex that needs a digit on at least one side of the point
# ("1.", ".5", "1.5"). The previous `UINT? "." /\d/*` also matched a bare ".", which
# reached float(".") in the Transformer and failed with a ValueError instead of a
# lark parse error.
grammar = r"""
start: elements
elements: element ("-" element)*
element: layer | loop | default_assignment
loop: UINT "x(" elements ")"
layer: conv | activation_layer | batch_norm_layer | residual_layer

activation_layer: ACTIVATION

batch_norm_layer: "bn"

residual_layer: "r(" layer ")"

default_assignment: (kernel_size | stride | dilation | padding | activation)+

conv: channels "x" (kernel_size | stride | dilation | padding | activation)*
channels: UINT | UFLOAT
kernel_size: "k" UINT | UINT
stride: "s" UINT
dilation: "d" UINT
padding: "p" UINT
activation: "a" ACTIVATION

ACTIVATION: "relu" | "gelu" | "sigmoid" | "tanh"
UINT: /0|[1-9]\d*/
UFLOAT: /(?:0|[1-9]\d*)\.\d*|\.\d+/
"""
class Transformer(lark.Transformer):
    """Terminal callbacks that convert raw tokens into Python values.

    Passed to ``lark.Lark(..., transformer=...)``, so with the LALR parser the
    conversions happen while the tree is being built.
    """

    def UINT(self, token: lark.Token):
        """Unsigned integer literal (channels, kernel size, stride, repeat count, ...) -> ``int``."""
        text = token.value
        return int(text)

    def UFLOAT(self, token: lark.Token):
        """Float literal such as ``1.5``, ``1.`` or ``.5`` -> ``float``."""
        text = token.value
        return float(text)

    def ACTIVATION(self, token: lark.Token):
        """Activation name -> whatever ``activation_to_module`` returns for it.

        NOTE(review): this runs once per token occurrence. Inside a repeated group
        (``3x(...)``) the same returned object is appended on every repetition,
        which is presumably fine for stateless activations; confirm.
        """
        name = token.value
        return activation_to_module(name)
        
# LALR parser with the Transformer applied during parsing, so UINT / UFLOAT /
# ACTIVATION tokens reach the tree already converted.
parser = lark.Lark(grammar, parser="lalr", transformer=Transformer())

# Spec to build. Other examples:
#   "r(32x3p1arelu)-3x(64xk5p2s3-bn)-tanh"
#   "k9s3arelu-24x-48xatanh-64x"
# Reset `tree` first so a failed parse cannot silently leave the tree from an
# earlier run in place for the cells below.
tree = None
try:
    tree = parser.parse("16x-3x(1.5x)-r(1.x)-.5x")
    print(tree.pretty())
except lark.exceptions.LarkError as e:
    # Show lark's message (position + expected tokens) instead of a full traceback.
    print(e)


In [None]:
class Context:
    """Mutable build state used while turning a parsed spec into modules.

    Holds the container that modules are appended to, the keyword defaults
    applied to each new conv, and the channel count produced by the last layer.
    """

    def __init__(
            self,
            layers: nn.Sequential,
            default_conv_attrs: dict,
            previous_channels: int = 3
    ):
        # Container that receives appended modules.
        self.layers = layers
        # Conv keyword defaults (e.g. kernel_size, stride); per-conv attributes override them.
        self.default_conv_attrs = default_conv_attrs
        # Channel count feeding the next layer.
        self.previous_channels = previous_channels

    def __copy__(self):
        """Shallow copy: ``layers`` and ``default_conv_attrs`` are shared, not duplicated."""
        fields = dict(
            layers=self.layers,
            default_conv_attrs=self.default_conv_attrs,
            previous_channels=self.previous_channels,
        )
        return type(self)(**fields)

    def replace(self, key: str, value: Any) -> "Context":
        """Return a shallow copy with attribute ``key`` set to ``value``; ``self`` is unchanged."""
        clone = self.__copy__()
        setattr(clone, key, value)
        return clone
        
def add_layers(
        context: Context,
        tree: lark.Tree,
):
    """Walk a parsed spec tree and append the matching modules to ``context.layers``.

    Mutates ``context`` in place:

    * appends modules to ``context.layers``;
    * updates ``context.previous_channels`` after each conv and residual block;
    * writes default-assignment values (e.g. ``k9s3arelu``) into
      ``context.default_conv_attrs`` so every later conv starts from them.

    Args:
        context: Build state (target container, conv defaults, current channel count).
        tree: A ``start``, ``elements``, ``element`` or ``layer`` node produced by
            ``parser`` with ``Transformer`` applied (terminal values already converted).

    Raises:
        ValueError: for a node type this function does not handle (e.g. after a
            grammar change), or a fractional channel spec that truncates to zero.
    """
    rule = tree.data.value

    if rule in ("start", "elements"):
        for child in tree.children:
            add_layers(context, child)

    elif rule == "element":
        node = tree.children[0]
        kind = node.data.value
        if kind == "loop":
            # loop: UINT "x(" elements ")" -> children are [repeat count, body]
            repeat, body = node.children[0], node.children[1]
            for _ in range(int(repeat)):
                add_layers(context, body)
        elif kind == "layer":
            add_layers(context, node)
        elif kind == "default_assignment":
            # Each child is a single-valued node such as stride -> [2].
            for conv_attr in node.children:
                context.default_conv_attrs[conv_attr.data.value] = conv_attr.children[0]
        else:
            raise ValueError(f"Unexpected element type: {kind!r}")

    elif rule == "layer":
        node = tree.children[0]
        kind = node.data.value
        if kind == "conv":
            _add_conv(context, node)
        elif kind == "activation_layer":
            act = node.children[0]
            if act is not None:
                context.layers.append(act)
        elif kind == "batch_norm_layer":
            context.layers.append(nn.BatchNorm2d(num_features=context.previous_channels))
        elif kind == "residual_layer":
            _add_residual(context, node)
        else:
            raise ValueError(f"Unexpected layer type: {kind!r}")

    else:
        raise ValueError(f"Unexpected node type: {rule!r}")


def _add_conv(context: Context, conv: lark.Tree):
    """Append a Conv2d (plus its optional activation) for one ``conv`` node."""
    # Start from the defaults; attributes written in the spec override them.
    conv_attrs = context.default_conv_attrs.copy()
    for conv_attr in conv.children:
        conv_attrs[conv_attr.data.value] = conv_attr.children[0]

    out_channels = conv_attrs.pop("channels")
    if isinstance(out_channels, float):
        # A float (e.g. ".5x") scales the incoming channel count.
        out_channels = int(context.previous_channels * out_channels)
        if out_channels < 1:
            raise ValueError(
                f"Channel multiplier yields {out_channels} channels "
                f"from {context.previous_channels} input channels"
            )
    # "activation" is not a Conv2d argument; it becomes a separate module after the conv.
    act = conv_attrs.pop("activation", None)
    context.layers.append(
        nn.Conv2d(in_channels=context.previous_channels, out_channels=out_channels, **conv_attrs)
    )
    if act is not None:
        context.layers.append(act)
    context.previous_channels = out_channels


def _add_residual(context: Context, residual: lark.Tree):
    """Build the wrapped layer in its own Sequential and append it wrapped in ResidualAdd."""
    # replace() returns a copy that shares default_conv_attrs but has a fresh container.
    sub_context = context.replace("layers", nn.Sequential())
    add_layers(sub_context, residual.children[0])
    context.layers.append(ResidualAdd(sub_context.layers))
    context.previous_channels = sub_context.previous_channels

# Build the network described by `tree`. Per-conv spec attributes and default
# assignments (e.g. "k9s3arelu") override these Conv2d keyword defaults.
# previous_channels=3 is the input channel count (presumably RGB images; confirm).
context = Context(
    layers=nn.Sequential(),
    default_conv_attrs=dict(
        channels=16,
        kernel_size=3,
        stride=1,
        dilation=1,
        padding=0,
        bias=True,
    ),
    previous_channels=3,
)
add_layers(context, tree)
context.layers

In [None]:
# BatchNorm2d reference: torch.nn.BatchNorm2d documentation (used above with
# num_features=context.previous_channels). For an interactive lookup run `nn.BatchNorm2d?`.