ADT Pattern in Python
- https://github.com/sabrinahu5/program-synthesis/blob/main/interpreter/flashfill-interpreter/ff-interpreter.py

In [6]:
from pprint import pprint
from dataclasses import dataclass
from typing import List, Tuple, Optional, Dict

In [47]:
@dataclass(frozen=True)
class LType:
    pass

@dataclass(frozen=True)
class LBool(LType):
    pass

@dataclass(frozen=True)
class LInt(LType):
    pass

@dataclass(frozen=True)
class LNilType(LType):
    pass

@dataclass(frozen=True)
class LArrow(LType):
    param_type: LType
    return_type: LType

@dataclass(frozen=True)
class LProduct(LType):
    left: LType
    right: LType

@dataclass(frozen=True)
class LSum(LType):
    left: LType
    right: LType

@dataclass(frozen=True)
class LTerm:
    pass

@dataclass(frozen=True)
class LBegin(LTerm):
    terms: list[LTerm]

@dataclass(frozen=True)
class LVar(LTerm):
    name: str

@dataclass(frozen=True)
class LAbs(LTerm):
    param: LVar
    param_type: LType
    body: LTerm

@dataclass(frozen=True)
class LApp(LTerm):
    func: LTerm
    arg: LTerm

@dataclass(frozen=True)
class LTrue(LTerm):
    pass

@dataclass(frozen=True)
class LFalse(LTerm):
    pass

@dataclass(frozen=True)
class LIf(LTerm):
    condition: LTerm
    then_branch: LTerm
    else_branch: LTerm

@dataclass(frozen=True)
class LDefine(LTerm):
    name: LVar
    body: LTerm

@dataclass(frozen=True)
class LList(LType):
    element_type: LType

@dataclass(frozen=True)
class LNil(LTerm):
    pass

@dataclass(frozen=True)
class LCons(LTerm):
    car: LTerm
    cdr: LTerm

@dataclass(frozen=True)
class LNum(LTerm):
    value: int

@dataclass(frozen=True)
class LAdd(LTerm):
    left: LTerm
    right: LTerm

class ImmutableAssignDict(dict):
    def __setitem__(self, key, value):
        if key in self:
            raise KeyError(f"Cannot overwrite existing key: {key}")
        super().__setitem__(key, value)

TypeEnv = ImmutableAssignDict[LVar, LType]
EvalEnv = ImmutableAssignDict[LVar, LTerm]

In [56]:
def typecheck(gamma_in: TypeEnv, term: LTerm) -> LType:
    gamma = gamma_in.copy()
    if isinstance(term, LVar):
        if term in gamma:
            return gamma[term]
        else:
            raise RuntimeError(f"Unbound variable error: {term.name}")
    elif isinstance(term, LNil):
        return LNilType()
    elif isinstance(term, LNum):
        return LInt()
    elif isinstance(term, LTrue) or isinstance(term, LFalse):
        return LBool()
    elif isinstance(term, LAbs):
        param_type = term.param_type
        gamma[term.param] = param_type
        return LArrow(param_type, typecheck(gamma, term.body))
    elif isinstance(term, LApp):
        func_type = typecheck(gamma, term.func) # arrow type
        arg_type = typecheck(gamma, term.arg)
        if func_type.param_type == arg_type:
            return func_type.return_type
        else:
            raise RuntimeError(f"Application type error: {term}")
    elif isinstance(term, LIf):
        condition_type = typecheck(gamma, term.condition)
        then_type = typecheck(gamma, term.then_branch)
        else_type = typecheck(gamma, term.else_branch)
        if isinstance(condition_type, LBool) and then_type == else_type:
            return else_type
        else:
            raise RuntimeError(f"Conditional type error: {term}")
    elif isinstance(term, LCons):
        car_type = typecheck(gamma, term.car)
        cdr_type = typecheck(gamma, term.cdr)
        # Infer the type of LNil from the car
        if isinstance(cdr_type, LNilType):  # This is the case where the cdr is LNil
            return LList(car_type)

        if isinstance(cdr_type, LList) and cdr_type.element_type == car_type:
            return LList(car_type)
        else:
            raise RuntimeError("Type mismatch in cons cell")
    elif isinstance(term, LDefine):
        body_type = typecheck(gamma, term.body)
        gamma[term.name] = body_type
        return body_type
    elif isinstance(term, LAdd):
        left_type = typecheck(gamma, term.left)
        right_type = typecheck(gamma, term.right)
        if isinstance(left_type, LInt) and isinstance(right_type, LInt):
            return LInt()
        else:
            raise RuntimeError(f"Addition type error: {term}")
    else:
        raise RuntimeError(f"Unknown type case: {term}")

def evaluate(env_in: EvalEnv, term: LTerm) -> LTerm:
    env = env_in.copy()
    if isinstance(term, LAbs) or isinstance(term, LTrue) or isinstance(term, LFalse) or isinstance(term, LInt) or isinstance(term, LCons):
        return term
    elif isinstance(term, LVar):
        return env[term]
    elif isinstance(term, LIf):
        condition = evaluate(env, term.condition)
        if isinstance(condition, LTrue):
            return evaluate(env, term.then_branch)
        elif isinstance(condition, LFalse):
            return evaluate(env, term.else_branch)
        else:
            raise RuntimeError("Huh, the condition is supposed to only be true or false...")
    elif isinstance(term, LApp):
        reduced_func = evaluate(env, term.func) # left app
        reduced_arg = evaluate(env, term.arg) # right app
        # beta-reduction substitution
        closure_env = reduced_func.env.copy()
        closure_env[reduced_func.param] = reduced_arg
        return evaluate(closure_env, reduced_func.body)
    elif isinstance(term, LDefine):
        evaluated_body = evaluate(env, term.body)
        env[term.name] = evaluated_body
        return evaluated_body
    elif isinstance(term, LAdd):
        left_val = evaluate(env, term.left)
        right_val = evaluate(env, term.right)
        if isinstance(left_val, LNum) and isinstance(right_val, LNum):
            return LNum(left_val.value + right_val.value)
        else:
            raise RuntimeError("Runtime error in addition: non-numeric values")
    elif isinstance(term, LBegin):
        for subterm in term.terms[:-1]:
            env = evaluate(env, subterm)
        return evaluate(env, term.terms[-1])
    else:
        raise RuntimeError(f"Evaluation error: {term}")

In [51]:
# Constants
TRUE = LTrue()
FALSE = LFalse()

# Example terms
identity_bool = LAbs(LVar("x"), LBool(), LVar("x"))
applied_identity = LApp(identity_bool, TRUE)

complex_term_no_type_check = LIf(
    LApp(LAbs(LVar("x"), LBool(), LVar("x")), TRUE),
    LAbs(LVar("y"), LBool(), LVar("y")),
    FALSE
)

complex_term = LIf(
    LApp(LAbs(LVar("x"), LBool(), LVar("x")), TRUE),
    LAbs(LVar("x"), LBool(), LVar("x")),
    LAbs(LVar("y"), LBool(), LVar("y"))
)

In [53]:
pprint(complex_term)

LIf(condition=LApp(func=LAbs(param=LVar(name='x'),
                             param_type=LBool(),
                             body=LVar(name='x')),
                   arg=LTrue()),
    then_branch=LAbs(param=LVar(name='x'),
                     param_type=LBool(),
                     body=LVar(name='x')),
    else_branch=LAbs(param=LVar(name='y'),
                     param_type=LBool(),
                     body=LVar(name='y')))


In [54]:
gamma: TypeEnv = TypeEnv()
gamma[LVar("x")] = LBool()
basic_term = LVar("x")
a = LAbs(LVar("y"), LBool(), LVar("x"))
b = LApp(a, TRUE)
print(typecheck(gamma, complex_term))
print(typecheck(gamma, basic_term))
pprint(typecheck(gamma, a))
pprint(typecheck(gamma, b))

LArrow(param_type=LBool(), return_type=LBool())
LBool()
LArrow(param_type=LBool(), return_type=LBool())
LBool()


In [55]:
env: EvalEnv = EvalEnv()
x = LVar("x")
env[x] = TRUE
complexer_term = LApp(LIf(
    LApp(LAbs(LVar("x"), LBool(), LVar("x")), TRUE),
    LAbs(LVar("x"), LBool(), LVar("x")),
    LAbs(LVar("y"), LBool(), LVar("y"))
), LVar("x"))
print(evaluate(env, complexer_term))

AttributeError: 'LAbs' object has no attribute 'env'

In [44]:
list_example = LCons(
    car=LNum(1),
    cdr=LCons(
        car=LNum(2),
        cdr=LCons(
            car=LNum(3),
            cdr=LNil(),
        ),
    ),
)
list_example

LCons(car=LNum(value=1), cdr=LCons(car=LNum(value=2), cdr=LCons(car=LNum(value=3), cdr=LNil())))

In [45]:
typecheck(gamma, list_example), evaluate(env, list_example)

(LList(element_type=LInt()),
 LCons(car=LNum(value=1), cdr=LCons(car=LNum(value=2), cdr=LCons(car=LNum(value=3), cdr=LNil()))))

In [46]:
func_define = LDefine(
    name=LVar("increment"),
    body=LAbs(
        param=LVar("x"),
        param_type=LInt(),
        body=LAdd(LVar("x"), LNum(1))
    )
)

# Apply the function to an argument
func_application = LApp(
    func=LVar("increment"),
    arg=LNum(5)
)

# Type environment and evaluation environment
gamma = ImmutableAssignDict()
env = ImmutableAssignDict()

# Typecheck the function definition
print(typecheck(gamma, func_define))  # Should return LArrow(LInt(), LInt())

# Evaluate the function definition and application
evaluate(env, func_define)            # This will store `increment` in the environment
result = evaluate(env, func_application)  # Apply `increment` to 5

print(result)  # Should output LNum(6)

LArrow(param_type=LInt(), return_type=LInt())


KeyError: LVar(name='increment')