ADT Pattern in Python
- https://github.com/sabrinahu5/program-synthesis/blob/main/interpreter/flashfill-interpreter/ff-interpreter.py

In [3]:
from pprint import pprint
from dataclasses import dataclass
from typing import List, Tuple, Optional, Dict

In [4]:
@dataclass(frozen=True)
class LType:
    pass

@dataclass(frozen=True)
class LBool(LType):
    pass

@dataclass(frozen=True)
class LArrow(LType):
    param_type: LType
    return_type: LType

@dataclass(frozen=True)
class LTerm:
    pass

@dataclass(frozen=True)
class LVar(LTerm):
    name: str

@dataclass(frozen=True)
class LAbs(LTerm):
    param: LVar
    param_type: LType
    body: LTerm

@dataclass(frozen=True)
class LApp(LTerm):
    func: LTerm
    arg: LTerm

@dataclass(frozen=True)
class LTrue(LTerm):
    pass

@dataclass(frozen=True)
class LFalse(LTerm):
    pass

@dataclass(frozen=True)
class LIf(LTerm):
    condition: LTerm
    then_branch: LTerm
    else_branch: LTerm

class ImmutableAssignDict(dict):
    def __setitem__(self, key, value):
        if key in self:
            raise KeyError(f"Cannot overwrite existing key: {key}")
        super().__setitem__(key, value)

TypeEnv = ImmutableAssignDict[LVar, LType]

In [14]:
class LTypeError(Exception):
    pass

def typecheck(gamma_in: TypeEnv, term: LTerm) -> Optional[LType]:
    gamma = gamma_in.copy()
    if isinstance(term, LVar):
        if term in gamma:
            return gamma[term]
        else:
            raise LTypeError(f"Unbound variable: {term.name}")
    elif isinstance(term, LTrue) or isinstance(term, LFalse):
        return LBool()
    elif isinstance(term, LAbs):
        param_type = term.param_type
        gamma[term.param] = param_type
        return LArrow(param_type, typecheck(gamma, term.body))
    elif isinstance(term, LApp):
        func_type = typecheck(gamma, term.func)
        arg_type = typecheck(gamma, term.arg)
        pass

In [15]:
# Constants
TRUE = LTrue()
FALSE = LFalse()

# Example terms
identity_bool = LAbs(LVar("x"), LBool(), LVar("x"))
applied_identity = LApp(identity_bool, TRUE)

complex_term = LIf(
    LApp(LAbs(LVar("x"), LBool(), LVar("x")), TRUE),
    LAbs(LVar("y"), LBool(), LVar("y")),
    FALSE
)

In [7]:
pprint(complex_term)

LIf(condition=LApp(func=LAbs(param=LVar(name='x'),
                             param_type=LBool(),
                             body=LVar(name='x')),
                   arg=LTrue()),
    then_branch=LAbs(param=LVar(name='y'),
                     param_type=LBool(),
                     body=LVar(name='y')),
    else_branch=LFalse())


In [16]:
gamma: TypeEnv = TypeEnv()
gamma[LVar("x")] = LBool()
basic_term = LVar("x")
print(typecheck(gamma, basic_term))
print(typecheck(gamma, complex_term))

LBool()
None
