ADT Pattern in Python
- https://github.com/sabrinahu5/program-synthesis/blob/main/interpreter/flashfill-interpreter/ff-interpreter.py

In [2]:
from pprint import pprint
from dataclasses import dataclass
from typing import List, Tuple, Optional, Dict

In [3]:
@dataclass(frozen=True)
class LType:
    """Common base class of the type ADT; it carries no fields of its own."""

@dataclass(frozen=True)
class LBool(LType):
    """The boolean type. It has no fields, so all instances compare equal."""

@dataclass(frozen=True)
class LArrow(LType):
    """Function type from ``param_type`` to ``return_type``."""

    # Type of the argument the function accepts.
    param_type: LType
    # Type of the value the function produces.
    return_type: LType

@dataclass(frozen=True)
class LTerm:
    """Common base class of the term ADT; it carries no fields of its own."""

@dataclass(frozen=True)
class LVar(LTerm):
    """Variable reference, identified solely by its name.

    Being a frozen dataclass it is hashable and compared by value, which is
    what lets variables serve as keys of the typing/evaluation environments.
    """

    # Identifier of the variable; two LVar with the same name are equal.
    name: str

@dataclass(frozen=True)
class LAbs(LTerm):
    """Lambda abstraction with an explicitly annotated parameter."""

    # Variable bound by this abstraction.
    param: LVar
    # Declared type of ``param``.
    param_type: LType
    # Term in which ``param`` is in scope.
    body: LTerm

@dataclass(frozen=True)
class LApp(LTerm):
    """Application of ``func`` to ``arg``."""

    # Term in function position.
    func: LTerm
    # Term passed as the argument.
    arg: LTerm

@dataclass(frozen=True)
class LTrue(LTerm):
    """The boolean literal ``true``."""

@dataclass(frozen=True)
class LFalse(LTerm):
    """The boolean literal ``false``."""

@dataclass(frozen=True)
class LIf(LTerm):
    """Conditional expression selecting one of two branches."""

    # Term that decides which branch is taken.
    condition: LTerm
    # Term used when the condition is true.
    then_branch: LTerm
    # Term used when the condition is false.
    else_branch: LTerm

class ImmutableAssignDict(dict):
    """A ``dict`` whose keys can be assigned once via ``d[key] = value``.

    Only item assignment is guarded. Other mutators inherited from ``dict``
    (``update``, ``|=``, ``del`` ...) are not intercepted, and ``copy()``
    returns a plain ``dict`` without the guard.
    """

    def __setitem__(self, key, value):
        """Store ``value`` under ``key``; raise ``KeyError`` if ``key`` exists."""
        if key not in self:
            super().__setitem__(key, value)
            return
        raise KeyError(f"Cannot overwrite existing key: {key}")

# Environment aliases. They are used both as annotations and, further down,
# called like constructors (``TypeEnv()`` / ``EvalEnv()``).
# NOTE(review): subscripting a dict subclass relies on the generic-alias
# support of built-in containers (Python 3.9+) -- confirm the target version.
# Typing context: variable -> its type.
TypeEnv = ImmutableAssignDict[LVar, LType]
# Evaluation environment: variable -> the term it is bound to.
EvalEnv = ImmutableAssignDict[LVar, LTerm]

In [4]:
class LTypeError(Exception):
    """Raised by the typechecker when a term cannot be given a type."""

def typecheck(gamma_in: TypeEnv, term: LTerm) -> LType:
    """Compute the type of ``term`` under the typing context ``gamma_in``.

    Args:
        gamma_in: Mapping from variables to their types. It is never mutated;
            an abstraction extends a private copy for its body.
        term: The term to typecheck.

    Returns:
        The type of ``term``.

    Raises:
        LTypeError: If a variable is unbound, the function position of an
            application is not an arrow type or its parameter type differs
            from the argument type, a conditional has a non-boolean condition
            or branches of different types, or ``term`` is not a known term.
    """
    if isinstance(term, LVar):
        if term in gamma_in:
            return gamma_in[term]
        raise LTypeError(f"Unbound variable error: {term.name}")

    if isinstance(term, (LTrue, LFalse)):
        return LBool()

    if isinstance(term, LAbs):
        # Extend a plain-dict copy so the caller's context stays untouched and
        # an inner binder may shadow a variable that is already in scope.
        gamma = dict(gamma_in)
        gamma[term.param] = term.param_type
        return LArrow(term.param_type, typecheck(gamma, term.body))

    if isinstance(term, LApp):
        func_type = typecheck(gamma_in, term.func)
        arg_type = typecheck(gamma_in, term.arg)
        # Only an arrow type can be applied; anything else (e.g. applying a
        # boolean) is a type error rather than an attribute lookup failure.
        if isinstance(func_type, LArrow) and func_type.param_type == arg_type:
            return func_type.return_type
        raise LTypeError(f"Application type error: {term}")

    if isinstance(term, LIf):
        condition_type = typecheck(gamma_in, term.condition)
        then_type = typecheck(gamma_in, term.then_branch)
        else_type = typecheck(gamma_in, term.else_branch)
        if isinstance(condition_type, LBool) and then_type == else_type:
            return else_type
        raise LTypeError(f"Conditional type error: {term}")

    raise LTypeError(f"Unknown type case: {term}")

def evaluate(env_in: EvalEnv, term: LTerm) -> LTerm:
    """Reduce ``term`` to a value using the bindings in ``env_in``.

    Abstractions and boolean literals are values and are returned unchanged.
    In an application the function is evaluated before the argument, and the
    body then runs with the parameter bound to the evaluated argument.

    NOTE: abstractions are returned without capturing the environment. A
    lambda that outlives the application which bound one of its free
    variables will therefore not find that variable later; for example
    ``((lambda x. lambda y. x) true) false`` ends in a ``KeyError``.

    Args:
        env_in: Mapping from variables to the terms they are bound to. It is
            never mutated; an application extends a private copy.
        term: The term to evaluate.

    Returns:
        The value ``term`` reduces to.

    Raises:
        KeyError: If a variable is not bound in the environment.
        RuntimeError: If a condition does not reduce to a boolean literal, the
            function position of an application does not reduce to an
            abstraction, or ``term`` is not a known term.
    """
    if isinstance(term, (LAbs, LTrue, LFalse)):
        # Already a value.
        return term

    if isinstance(term, LVar):
        return env_in[term]

    if isinstance(term, LIf):
        condition = evaluate(env_in, term.condition)
        if isinstance(condition, LTrue):
            return evaluate(env_in, term.then_branch)
        if isinstance(condition, LFalse):
            return evaluate(env_in, term.else_branch)
        raise RuntimeError("Huh, the condition is supposed to only be true or false...")

    if isinstance(term, LApp):
        reduced_func = evaluate(env_in, term.func)
        reduced_arg = evaluate(env_in, term.arg)
        if not isinstance(reduced_func, LAbs):
            raise RuntimeError(f"Cannot apply a non-function value: {reduced_func}")
        # Bind the parameter in a plain-dict copy so the caller's environment
        # stays untouched and an existing binding of the same name is shadowed.
        env = dict(env_in)
        env[reduced_func.param] = reduced_arg
        return evaluate(env, reduced_func.body)

    # Mirror typecheck: an unrecognised term is an error, not a silent None.
    raise RuntimeError(f"Unknown term case: {term}")

In [5]:
# Boolean literals shared by the examples.
TRUE = LTrue()
FALSE = LFalse()

# lambda x: Bool. x
identity_bool = LAbs(param=LVar("x"), param_type=LBool(), body=LVar("x"))
# (lambda x: Bool. x) true
applied_identity = LApp(func=identity_bool, arg=TRUE)

# Conditional whose branches have different types (a function vs. a boolean),
# so typecheck rejects it.
complex_term_no_type_check = LIf(
    condition=LApp(
        func=LAbs(param=LVar("x"), param_type=LBool(), body=LVar("x")),
        arg=TRUE,
    ),
    then_branch=LAbs(param=LVar("y"), param_type=LBool(), body=LVar("y")),
    else_branch=FALSE,
)

# Same condition, but both branches are Bool -> Bool identity functions.
complex_term = LIf(
    condition=LApp(
        func=LAbs(param=LVar("x"), param_type=LBool(), body=LVar("x")),
        arg=TRUE,
    ),
    then_branch=LAbs(param=LVar("x"), param_type=LBool(), body=LVar("x")),
    else_branch=LAbs(param=LVar("y"), param_type=LBool(), body=LVar("y")),
)

In [6]:
# Pretty-print the nested term; the dataclass-generated repr names every field.
pprint(complex_term)

LIf(condition=LApp(func=LAbs(param=LVar(name='x'),
                             param_type=LBool(),
                             body=LVar(name='x')),
                   arg=LTrue()),
    then_branch=LAbs(param=LVar(name='x'),
                     param_type=LBool(),
                     body=LVar(name='x')),
    else_branch=LAbs(param=LVar(name='y'),
                     param_type=LBool(),
                     body=LVar(name='y')))


In [7]:
# Typing context with a single assumption: x has type Bool.
gamma: TypeEnv = TypeEnv()
gamma[LVar("x")] = LBool()
# A bare variable, typed by looking it up in gamma.
basic_term = LVar("x")
# Abstraction over y whose body is the free variable x (typed via gamma).
a = LAbs(LVar("y"), LBool(), LVar("x"))
# Application of that abstraction to a boolean literal.
b = LApp(a, TRUE)
print(typecheck(gamma, complex_term))
print(typecheck(gamma, basic_term))
pprint(typecheck(gamma, a))
pprint(typecheck(gamma, b))

LArrow(param_type=LBool(), return_type=LBool())
LBool()
LArrow(param_type=LBool(), return_type=LBool())
LBool()


In [8]:
# Evaluation environment in which x is bound to the literal true.
env: EvalEnv = EvalEnv()
x = LVar("x")
env[x] = TRUE
# Apply whichever function the conditional selects to the variable x.
complexer_term = LApp(LIf(
    LApp(LAbs(LVar("x"), LBool(), LVar("x")), TRUE),
    LAbs(LVar("x"), LBool(), LVar("x")),
    LAbs(LVar("y"), LBool(), LVar("y"))
), LVar("x"))
print(evaluate(env, complexer_term))

LTrue()
