In [1]:
from __future__ import annotations
from typing import Union, TypeVar
from metadsl_all import *

```python
if cond:
    a = b
    c = d
else:
    a = b + 1
    c = d + 1
```

In [2]:
@expression
def init_state() -> State:
    """
    Stub for the starting `State` of the example program.

    No implementation is provided; in this notebook a call to it appears as an
    `init_state` call node in the graph printed by `execute`.
    """

In [3]:
T = TypeVar("T")
@expression
def if_(cond: object, true_branch: T, false_branch: T) -> T:
    """
    Stub modelling the `if`/`else` of the example program.

    Both branches share the type `T`, which is also the declared result type.
    No implementation is provided here; the notebook uses it to combine the
    final state of each branch.
    """

In [4]:
@expression
def add(x: object, y: object) -> object:
    """
    Stub standing in for the `+` of the example program (e.g. `b + 1`).

    No implementation is provided; both operands and the result are typed as
    plain `object`.
    """

In [5]:
# Encode the example program as explicit state transitions. Every value is
# wrapped in `line(n, ...)` with the number of the source line it comes from;
# presumably this is what lets `get_lines` report the lines a result depends on.
state = init_state()

# True branch (lines 2-3 of the example): a = b; c = d
line_2_state = state.set_var("a", line(2, state.get_var("b")))
line_3_state = line_2_state.set_var("c", line(3, line_2_state.get_var("d")))

# False branch (lines 5-6 of the example): a = b + 1; c = d + 1
line_5_state = state.set_var("a", line(5, add(state.get_var("b"), 1)))
# Line 6 assigns `d + 1`, so it needs the same `add(..., 1)` wrapper as line 5.
line_6_state = line_5_state.set_var("c", line(6, add(line_5_state.get_var("d"), 1)))

# Line 1: combine the final state of each branch under the condition `cond`.
# `state` is rebound here; the cell stays re-runnable because it starts from
# a fresh `init_state()` above.
state = if_(
    line(1, state.get_var("cond")),
    line_3_state,
    line_6_state,
)
# Look up `a` in the combined state and show which lines it is tagged with.
res = state.namespace['a']
get_lines(res)

{1, 2, 3, 5, 6}

In [6]:
# Rebinds `res` to the output of `execute`; re-running this cell therefore
# feeds the previous result back in rather than the original expression.
res = execute(res)

Typez(definitions=None, nodes=[PrimitiveNode(id='5588165156287652350', type='int', repr='1'), CallNode(id='-824938096903300058', function='init_state\n', type_params=None, args=None, kwargs=None), PrimitiveNode(id='-7587250769671343037', type='str', repr='cond'), PrimitiveNode(id='8126027657902826506', type='str', repr='a'), PrimitiveNode(id='-7875468003596548032', type='int', repr='2'), PrimitiveNode(id='903956026342425513', type='str', repr='b'), PrimitiveNode(id='-949420366002552021', type='NoneType', repr='None'), PrimitiveNode(id='-8294602091862145234', type='str', repr='c'), PrimitiveNode(id='-7059664274030091654', type='int', repr='3'), PrimitiveNode(id='1081597251926427767', type='str', repr='d'), PrimitiveNode(id='3626853550258118895', type='int', repr='5'), PrimitiveNode(id='8005759971604257497', type='int', repr='6'), CallNode(id='-3932112222624709724', function='State.namespace\n', type_params=None, args=['-824938096903300058'], kwargs=None), CallNode(id='-87246864570436602

In [7]:
@register_ds
@rule
def if_optimizations(x: T, y: T, cond: object):
    """
    Rule that pushes an `if_` inside a call when both branches are calls to the
    same function and differ in at most one argument.

    Yields the pair `(if_(cond, x, y), res)`; presumably `rule` treats the first
    item as the pattern to match and the second as the replacement builder
    (TODO confirm against the metadsl `rule` documentation).
    """
    
    def res():
        """
        Build the merged expression for `if_(cond, x, y)`.

        Returns `x` unchanged when every arg and kwarg of `x` and `y` is equal.
        Otherwise returns a new expression of the same type and function in
        which only the single differing argument is wrapped in `if_(cond, ...)`.

        Raises `NoMatch` if either branch is not an `Expression`, if their
        types, functions, positional-argument counts or keyword names differ,
        or if more than one argument differs.
        """
        # Both branches must be expression nodes for the comparisons below.
        if not isinstance(x, Expression) or not isinstance(y, Expression):
            raise NoMatch()
        if type(x) != type(y):
            raise NoMatch()
        if x.function != y.function:
            raise NoMatch()
        
        # Same number of positional args and the same set of keyword names.
        if len(x.args) != len(y.args) or set(x.kwargs.keys()) != set(y.kwargs.keys()):
            raise NoMatch()
        
        # Holds the positional index or the keyword name of the one argument
        # that differs between the branches, once it has been found.
        differing_arg_or_kwarg = None
        for i in range(len(x.args)):
            if x.args[i] == y.args[i]:
                continue
            # A second difference means the branches cannot be merged here.
            if differing_arg_or_kwarg is not None:
                raise NoMatch()
            differing_arg_or_kwarg = i
        for k in x.kwargs.keys():
            if x.kwargs[k] == y.kwargs[k]:
                continue
            # Same single-difference limit, shared with the positional args.
            if differing_arg_or_kwarg  is not None:
                raise NoMatch()
            differing_arg_or_kwarg = k
        
        # All arguments are the same, so both branches are the same expression.
        if differing_arg_or_kwarg  is None:
            return x
        
        # Rebuild the call, wrapping only the differing argument in `if_`.
        return type(x)(
            function=x.function,
            args=[if_(cond, a, y.args[i]) if i == differing_arg_or_kwarg else a for i, a in enumerate(x.args)],
            kwargs={k: if_(cond, v, y.kwargs[k]) if k == differing_arg_or_kwarg else v for k, v in x.kwargs.items()},
        )
    # If if_ is applied to two calls of the same function that differ in only
    # one arg or kwarg, we can move the if_ inside the call.
    yield if_(cond, x, y), res
    # NOTE(review): the second rule that follows appears cut off mid-expression
    # and uses `l`, `r` and `k`, which are not parameters of this function; it
    # needs completing before this cell can run.
    
    yield if_(cond, l, r)[k], if_(cond

How do we move the `if_` inside the `Mapping.create`? When can `if_` be moved inside?

Of course we could move the getitem inside, but really we want to move the `if_` inside the create, no?

It's about independence vs dependence... 

In [8]:
# Second pass: `execute` is applied again to the already-executed `res`, now
# that `if_optimizations` has been defined and decorated with `@register_ds`.
res = execute(res)

Typez(definitions=None, nodes=[PrimitiveNode(id='5588165156287652350', type='int', repr='1'), CallNode(id='-824938096903300058', function='init_state\n', type_params=None, args=None, kwargs=None), PrimitiveNode(id='-7587250769671343037', type='str', repr='cond'), PrimitiveNode(id='-8294602091862145234', type='str', repr='c'), PrimitiveNode(id='-7059664274030091654', type='int', repr='3'), CallNode(id='4702451177623270837', function='Mapping.create\n', type_params=None, args=None, kwargs=None), PrimitiveNode(id='1081597251926427767', type='str', repr='d'), PrimitiveNode(id='8126027657902826506', type='str', repr='a'), PrimitiveNode(id='-7875468003596548032', type='int', repr='2'), PrimitiveNode(id='903956026342425513', type='str', repr='b'), PrimitiveNode(id='-949420366002552021', type='NoneType', repr='None'), PrimitiveNode(id='8005759971604257497', type='int', repr='6'), PrimitiveNode(id='3626853550258118895', type='int', repr='5'), CallNode(id='-3932112222624709724', function='State.

In [9]:
# Recompute the line set for the re-executed `res`, for comparison with the
# set shown by the first `get_lines` call.
get_lines(res)

{1, 2, 3, 5, 6}