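ADT Pattern in Python

A small simply typed lambda calculus (booleans, integers, lists, `define`/`begin`) whose syntax and types are modelled as algebraic data types with frozen dataclasses, together with a type checker (`typecheck`) and an environment-based evaluator (`evaluate`).

- Reference: https://github.com/sabrinahu5/program-synthesis/blob/main/interpreter/flashfill-interpreter/ff-interpreter.py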

In [1]:
from pprint import pprint
from dataclasses import dataclass
from typing import List, Tuple, Optional, Dict

In [2]:
class ImmutableAssignDict(dict):
    """A ``dict`` that allows adding keys but refuses to overwrite existing ones.

    Assigning to a key that is already present -- via ``d[k] = v``,
    ``d.update(...)`` or ``d |= ...`` -- raises ``KeyError``.

    ``copy()`` is inherited from ``dict`` and therefore returns a plain
    ``dict`` without this guard, so a copy can freely rebind (shadow) keys.
    """

    def __setitem__(self, key, value):
        if key in self:
            raise KeyError(f"Cannot overwrite existing key: {key}")
        super().__setitem__(key, value)

    def update(self, *args, **kwargs):
        """Add several keys at once; raise ``KeyError`` if any already exists.

        ``dict.update`` does not go through ``__setitem__``, so without this
        override it would silently overwrite keys. All incoming keys are
        checked before any is inserted, so a rejected update changes nothing.
        """
        incoming = dict(*args, **kwargs)
        for key in incoming:
            if key in self:
                raise KeyError(f"Cannot overwrite existing key: {key}")
        super().update(incoming)

    def __ior__(self, other):
        # `d |= other` would otherwise bypass the guard exactly like dict.update.
        self.update(other)
        return self


# ---------------------------------------------------------------------------
# Types
# ---------------------------------------------------------------------------

@dataclass(frozen=True)
class LType:
    """Base class of all types in the object language."""
    pass

@dataclass(frozen=True)
class LBool(LType):
    """The boolean type."""
    pass

@dataclass(frozen=True)
class LInt(LType):
    """The integer type."""
    pass

@dataclass(frozen=True)
class LNilType(LType):
    """Type of the empty-list literal ``LNil`` before any element type is known."""
    pass

@dataclass(frozen=True)
class LList(LType):
    """Homogeneous list type whose elements all have ``element_type``."""
    element_type: LType

@dataclass(frozen=True)
class LArrow(LType):
    """Function type ``param_type -> return_type``."""
    param_type: LType
    return_type: LType

@dataclass(frozen=True)
class LProduct(LType):
    """Product (pair) type; no term in this notebook constructs one yet."""
    left: LType
    right: LType

@dataclass(frozen=True)
class LSum(LType):
    """Sum (tagged union) type; no term in this notebook constructs one yet."""
    left: LType
    right: LType


# ---------------------------------------------------------------------------
# Terms
# ---------------------------------------------------------------------------

@dataclass(frozen=True)
class LTerm:
    """Base class of all terms (expressions and runtime values)."""
    pass

@dataclass(frozen=True)
class LBegin(LTerm):
    """A sequence of terms; its value is the value of the last term."""
    terms: tuple[LTerm, ...]  # any sequence is accepted and stored as a tuple

    def __post_init__(self):
        # Callers pass a list. Store a tuple instead so this frozen dataclass
        # is really immutable and hashable like every other node.
        object.__setattr__(self, "terms", tuple(self.terms))

@dataclass(frozen=True)
class LVar(LTerm):
    """A variable reference; also the key type of typing/evaluation environments."""
    name: str

@dataclass(frozen=True)
class LAbs(LTerm):
    """Lambda abstraction ``λparam:param_type. body``."""
    param: LVar
    param_type: LType
    body: LTerm

@dataclass(frozen=True)
class LApp(LTerm):
    """Function application ``func arg``."""
    func: LTerm
    arg: LTerm

@dataclass(frozen=True)
class LClosure(LTerm):
    """Runtime value of an ``LAbs``: the lambda plus the environment it captured.

    ``env`` is a mutable ``dict``, so ``LClosure`` instances are not hashable.
    """
    param: LVar
    param_type: LType
    body: LTerm
    # evaluate() stores `env_in.copy()` here, which is a plain dict even when
    # the caller passed an ImmutableAssignDict.
    env: Dict[LVar, LTerm]

@dataclass(frozen=True)
class LTrue(LTerm):
    """The boolean literal ``true``."""
    pass

@dataclass(frozen=True)
class LFalse(LTerm):
    """The boolean literal ``false``."""
    pass

@dataclass(frozen=True)
class LIf(LTerm):
    """Conditional ``if condition then then_branch else else_branch``."""
    condition: LTerm
    then_branch: LTerm
    else_branch: LTerm

@dataclass(frozen=True)
class LDefine(LTerm):
    """Binds ``name`` to the value of ``body``; its own value is that of ``body``."""
    name: LVar
    body: LTerm

@dataclass(frozen=True)
class LNil(LTerm):
    """The empty list."""
    pass

@dataclass(frozen=True)
class LCons(LTerm):
    """A list cell with head ``car`` and tail ``cdr``."""
    car: LTerm
    cdr: LTerm

@dataclass(frozen=True)
class LNum(LTerm):
    """An integer literal."""
    value: int

@dataclass(frozen=True)
class LAdd(LTerm):
    """Integer addition ``left + right``."""
    left: LTerm
    right: LTerm

@dataclass(frozen=True)
class LCar(LTerm):
    """Head of the list ``lst``."""
    lst: LTerm

@dataclass(frozen=True)
class LCdr(LTerm):
    """Tail of the list ``lst``."""
    lst: LTerm

# Environments map variables to their types (typechecking) or values (evaluation).
TypeEnv = ImmutableAssignDict[LVar, LType]
EvalEnv = ImmutableAssignDict[LVar, LTerm]

In [27]:
def typecheck(gamma_in: TypeEnv, term: LTerm) -> LType:
    """Return the type of ``term`` under the typing environment ``gamma_in``.

    ``gamma_in`` is never mutated: each call works on ``gamma_in.copy()``, so
    bindings introduced by an ``LAbs`` parameter or an ``LDefine`` stay local
    to that call. The exception is ``LBegin``: the binding of each direct
    ``LDefine`` child is added to the environment used for the terms that
    follow it, matching how ``evaluate`` threads its environment through a
    begin block.

    Raises:
        RuntimeError: on an unbound variable, any type mismatch, an empty
            ``LBegin``, or an unrecognised term.
    """
    gamma = gamma_in.copy()  # private scope for this call
    if isinstance(term, LVar):
        if term in gamma:
            return gamma[term]
        else:
            raise RuntimeError(f"Unbound variable error: {term.name}")
    elif isinstance(term, LNil):
        # The element type is unknown until the nil is used as a cons tail.
        return LNilType()
    elif isinstance(term, LNum):
        return LInt()
    elif isinstance(term, (LTrue, LFalse)):
        return LBool()
    elif isinstance(term, LAbs):
        param_type = term.param_type
        gamma[term.param] = param_type  # may shadow an outer binding of the same name
        return LArrow(param_type, typecheck(gamma, term.body))
    elif isinstance(term, LApp):
        func_type = typecheck(gamma, term.func)
        arg_type = typecheck(gamma, term.arg)
        if isinstance(func_type, LArrow) and func_type.param_type == arg_type:
            return func_type.return_type
        else:
            raise RuntimeError(f"Application type error: {term}")
    elif isinstance(term, LIf):
        condition_type = typecheck(gamma, term.condition)
        then_type = typecheck(gamma, term.then_branch)
        else_type = typecheck(gamma, term.else_branch)
        if isinstance(condition_type, LBool) and then_type == else_type:
            return else_type
        else:
            raise RuntimeError(f"Conditional type error: {term}")
    elif isinstance(term, LCons):
        car_type = typecheck(gamma, term.car)
        cdr_type = typecheck(gamma, term.cdr)
        if isinstance(cdr_type, LNilType):
            return LList(car_type)  # infer the list's element type from the head
        if isinstance(cdr_type, LList) and cdr_type.element_type == car_type:
            return LList(car_type)
        raise RuntimeError("Type mismatch in cons cell")
    elif isinstance(term, LDefine):
        body_type = typecheck(gamma, term.body)
        gamma[term.name] = body_type
        return body_type
    elif isinstance(term, LAdd):
        left_type = typecheck(gamma, term.left)
        right_type = typecheck(gamma, term.right)
        if isinstance(left_type, LInt) and isinstance(right_type, LInt):
            return LInt()
        else:
            raise RuntimeError(f"Addition type error: {term}")
    elif isinstance(term, LCar):
        lst_type = typecheck(gamma, term.lst)
        if isinstance(lst_type, LList):
            return lst_type.element_type
        else:
            raise RuntimeError(f"Car applied to a non-list: {term}")
    elif isinstance(term, LCdr):
        lst_type = typecheck(gamma, term.lst)
        if isinstance(lst_type, LList):
            return lst_type  # the tail of a list has the same list type
        else:
            raise RuntimeError(f"Cdr applied to a non-list: {term}")
    elif isinstance(term, LBegin):
        if not term.terms:
            raise RuntimeError(f"Empty begin block: {term}")
        for t in term.terms[:-1]:
            t_type = typecheck(gamma, t)
            if isinstance(t, LDefine):
                # The recursive call bound the name only in its own copy of
                # the environment; bind it here so later terms can use it.
                gamma[t.name] = t_type
        return typecheck(gamma, term.terms[-1])
    else:
        raise RuntimeError(f"Unknown type case: {term}")

def evaluate(env_in: EvalEnv, term: LTerm) -> tuple[EvalEnv, LTerm]:
    """Evaluate ``term`` in ``env_in`` and return ``(env, value)``.

    ``env_in`` is not mutated; each call works on ``env_in.copy()``. The
    returned environment is that copy, possibly extended: ``LDefine`` adds
    its binding to it, and ``LBegin`` passes each sub-term's returned
    environment on to the next sub-term, which is how definitions become
    visible later in a begin block.

    Raises:
        RuntimeError: on an unbound variable, a non-boolean condition,
            applying a non-closure, adding non-numbers, car/cdr of a
            non-cons value, an empty ``LBegin``, or an unrecognised term.
    """
    env = env_in.copy()  # work on a copy of the environment
    if isinstance(term, LAbs):
        # The closure captures `env` itself, which is also the environment
        # returned to the caller. When an enclosing LDefine adds its binding
        # to that returned env, the name also becomes visible inside the
        # closure, so a defined function can refer to itself at run time.
        return env, LClosure(term.param, term.param_type, term.body, env)
    elif isinstance(term, LClosure):
        return env, term
    elif isinstance(term, LVar):
        if term in env:
            return env, env[term]
        else:
            raise RuntimeError(f"Unbound variable error: {term.name}")
    elif isinstance(term, LIf):
        updated_env, condition = evaluate(env, term.condition)
        if isinstance(condition, LTrue):
            return evaluate(updated_env, term.then_branch)
        elif isinstance(condition, LFalse):
            return evaluate(updated_env, term.else_branch)
        else:
            raise RuntimeError("Condition is supposed to only be true or false")
    elif isinstance(term, LApp):
        _, reduced_func = evaluate(env, term.func)
        _, reduced_arg = evaluate(env, term.arg)
        if not isinstance(reduced_func, LClosure):
            raise RuntimeError("Attempting to apply a non-closure")
        # Lexical scoping: run the body in the closure's captured environment
        # (not the caller's), extended with the parameter binding.
        closure_env = reduced_func.env.copy()
        closure_env[reduced_func.param] = reduced_arg
        return evaluate(closure_env, reduced_func.body)
    elif isinstance(term, LDefine):
        updated_env, evaluated_body = evaluate(env, term.body)
        updated_env[term.name] = evaluated_body
        return updated_env, evaluated_body
    elif isinstance(term, LBegin):
        if not term.terms:
            raise RuntimeError(f"Evaluation error: empty begin block {term}")
        for subterm in term.terms[:-1]:
            env, _ = evaluate(env, subterm)  # thread definitions forward
        return evaluate(env, term.terms[-1])
    elif isinstance(term, LAdd):
        _, left_val = evaluate(env, term.left)
        _, right_val = evaluate(env, term.right)
        if isinstance(left_val, LNum) and isinstance(right_val, LNum):
            return env, LNum(left_val.value + right_val.value)
        else:
            raise RuntimeError("Runtime error in addition: non-numeric values")
    elif isinstance(term, LCons):
        _, car_val = evaluate(env, term.car)
        _, cdr_val = evaluate(env, term.cdr)
        return env, LCons(car_val, cdr_val)
    elif isinstance(term, LCar):
        _, lst_val = evaluate(env, term.lst)
        if isinstance(lst_val, LCons):
            return env, lst_val.car
        else:
            raise RuntimeError("Car operation on non-cons cell")
    elif isinstance(term, LCdr):
        _, lst_val = evaluate(env, term.lst)
        if isinstance(lst_val, LCons):
            return env, lst_val.cdr
        else:
            raise RuntimeError("Cdr operation on non-cons cell")
    elif isinstance(term, (LNil, LNum, LTrue, LFalse)):
        return env, term  # literals are already values
    else:
        raise RuntimeError(f"Evaluation error: {term}")

In [15]:
# Boolean constants shared by the examples below
TRUE = LTrue()
FALSE = LFalse()

# λx:Bool. x, and that identity applied to true
identity_bool = LAbs(param=LVar("x"), param_type=LBool(), body=LVar("x"))
applied_identity = LApp(func=identity_bool, arg=TRUE)

# Deliberately ill-typed: the then-branch is a Bool -> Bool function while the
# else-branch is a Bool, so typecheck rejects it. Evaluation still succeeds
# because only the branch selected by the condition is evaluated.
complex_term_no_type_check = LIf(
    condition=applied_identity,
    then_branch=LAbs(param=LVar("y"), param_type=LBool(), body=LVar("y")),
    else_branch=FALSE,
)

# Well-typed: both branches are Bool -> Bool functions.
complex_term = LIf(
    condition=applied_identity,
    then_branch=identity_bool,
    else_branch=LAbs(param=LVar("y"), param_type=LBool(), body=LVar("y")),
)

In [16]:
pprint(complex_term, indent=1, width=80)  # pretty-print the nested AST (pprint defaults spelled out)

LIf(condition=LApp(func=LAbs(param=LVar(name='x'),
                             param_type=LBool(),
                             body=LVar(name='x')),
                   arg=LTrue()),
    then_branch=LAbs(param=LVar(name='x'),
                     param_type=LBool(),
                     body=LVar(name='x')),
    else_branch=LAbs(param=LVar(name='y'),
                     param_type=LBool(),
                     body=LVar(name='y')))


In [17]:
# Typing context in which the free variable x has type Bool
gamma: TypeEnv = TypeEnv({LVar("x"): LBool()})

basic_term = LVar("x")                    # looked up directly in gamma
a = LAbs(LVar("y"), LBool(), basic_term)  # λy:Bool. x   (x is free, typed via gamma)
b = LApp(a, TRUE)                         # (λy:Bool. x) true

print(typecheck(gamma, complex_term))     # Bool -> Bool
print(typecheck(gamma, basic_term))       # Bool
pprint(typecheck(gamma, a))               # Bool -> Bool
pprint(typecheck(gamma, b))               # Bool

LArrow(param_type=LBool(), return_type=LBool())
LBool()
LArrow(param_type=LBool(), return_type=LBool())
LBool()


In [18]:
# Evaluation environment in which x is bound to true
x = LVar("x")
env: EvalEnv = EvalEnv({x: TRUE})

# complex_term evaluates to the identity closure λx:Bool. x; apply it to x
complexer_term = LApp(complex_term, x)
print(evaluate(env, complexer_term))

({LVar(name='x'): LTrue()}, LTrue())


In [19]:
# The list [1, 2, 3] as nested cons cells terminated by nil
list_example = LCons(LNum(1), LCons(LNum(2), LCons(LNum(3), LNil())))
list_example

LCons(car=LNum(value=1), cdr=LCons(car=LNum(value=2), cdr=LCons(car=LNum(value=3), cdr=LNil())))

In [20]:
(typecheck(gamma, list_example), evaluate(env, list_example))  # (type, (env, value))

(LList(element_type=LInt()),
 ({LVar(name='x'): LTrue()},
  LCons(car=LNum(value=1), cdr=LCons(car=LNum(value=2), cdr=LCons(car=LNum(value=3), cdr=LNil())))))

In [21]:
# increment = λx:Int. x + 1
func_define = LDefine(
    LVar("increment"),
    LAbs(LVar("x"), LInt(), LAdd(LVar("x"), LNum(1))),
)

# increment 5
func_application = LApp(LVar("increment"), LNum(5))

# Define increment, then apply it; evaluate threads the definition forward
begin_expr = LBegin([func_define, func_application])

# Start again from empty environments
gamma = TypeEnv()
env = EvalEnv()

print(typecheck(gamma, func_define))  # expected: LArrow(LInt(), LInt())

res_env, res_val = evaluate(env, begin_expr)

print(res_val)  # expected: LNum(6)

LArrow(param_type=LInt(), return_type=LInt())
LNum(value=6)


In [22]:
# Lambda that captures the value of x in its environment
x = LVar("x")
# ((λx. (λy. x + y))(10))(5) => 10 + 5 = 15
# Expected result: the closure captures x = 10, and adds 5 to it
closure_example = LApp(
    LAbs(x, LInt(), LApp(LAbs(LVar("y"), LInt(), LAdd(x, LVar("y"))), LNum(5))),
    LNum(10)
)

res_env, res_val = evaluate(env, closure_example)
print(res_val)  # Output should be LNum(15)

LNum(value=15)


In [23]:
# Define myList = [1, 2, 3] using LCons and LNil
list_var = LVar("myList")
define_list = LDefine(
    list_var,
    LCons(LNum(1), LCons(LNum(2), LCons(LNum(3), LNil())))
)

# Define a function that ignores its list argument and returns a constant:
# simple_fn = λlst: List Int. 42
simple_fn_var = LVar("simple_fn")
simple_fn = LDefine(
    simple_fn_var,
    LAbs(
        LVar("lst"),
        LList(LInt()),  # parameter type: list of integers
        LNum(42)        # body ignores `lst`
    )
)

# Define the list and the function, then apply simple_fn to myList
begin_expr = LBegin(
    [
        define_list,
        simple_fn,
        LApp(simple_fn_var, list_var),
    ]
)

# Evaluate in a fresh environment; expected result: LNum(42)
_, result = evaluate(ImmutableAssignDict(), begin_expr)
print(result)

LNum(value=42)


In [29]:
# myList = [1, 2, 3]
list_var = LVar("myList")
define_list = LDefine(list_var, LCons(LNum(1), LCons(LNum(2), LCons(LNum(3), LNil()))))

car_expr = LCar(list_var)  # head of myList -> 1
cdr_expr = LCdr(list_var)  # tail of myList -> [2, 3]

# Each block first defines myList, then projects out its head or tail
begin_car_expr = LBegin([define_list, car_expr])
begin_cdr_expr = LBegin([define_list, cdr_expr])

# Evaluate each block in a fresh environment
_, car_result = evaluate(EvalEnv(), begin_car_expr)
print("Head of the list:", car_result)  # expected: LNum(1)

_, cdr_result = evaluate(EvalEnv(), begin_cdr_expr)
print("Tail of the list:", cdr_result)  # expected: LCons(LNum(2), LCons(LNum(3), LNil()))

Head of the list: LNum(value=1)
Tail of the list: LCons(car=LNum(value=2), cdr=LCons(car=LNum(value=3), cdr=LNil()))


In [37]:
# Pre-bind the type of myList in gamma. gamma is an ImmutableAssignDict, so
# assigning an existing key raises KeyError; guard explicitly so the cell can
# be re-run, instead of a bare `except` that would hide any other error too.
if list_var not in gamma:
    gamma[list_var] = LList(LInt())

gamma

{LVar(name='myList'): LList(element_type=LInt())}

In [38]:
typecheck(gamma_in=gamma, term=begin_cdr_expr)  # tail of a List Int is a List Int

LList(element_type=LInt())