In [1]:
# hott_simplified.py
# A small, pedagogical type-checker for simplified HoTT / Univalent Foundations.
# Date: 2025-10-23

from __future__ import annotations
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union

# ---------- AST / Terms ----------
@dataclass(frozen=True)
class Term: ...
@dataclass(frozen=True)
class TypeUniverse(Term):
    # Universe levels as integers (very small model)
    level: int = 0

@dataclass(frozen=True)
class Var(Term):
    name: str

@dataclass(frozen=True)
class Pi(Term):
    # (x : A) -> B(x)
    x: str
    A: Term
    B: Term  # B may refer to x

@dataclass(frozen=True)
class Lambda(Term):
    x: str
    body: Term

@dataclass(frozen=True)
class App(Term):
    fn: Term
    arg: Term

@dataclass(frozen=True)
class Sigma(Term):
    x: str
    A: Term
    B: Term

@dataclass(frozen=True)
class Pair(Term):
    fst: Term
    snd: Term

@dataclass(frozen=True)
class Path(Term):
    # Path(A; a, b) : type of paths between a and b in A
    A: Term
    a: Term
    b: Term

@dataclass(frozen=True)
class Refl(Term):
    # refl_a : Path(A; a, a)
    a: Term

@dataclass(frozen=True)
class Hole(Term):
    # placeholder for incomplete terms
    name: str

# Naive representation for 'equivalence' used with univalence postulate
@dataclass(frozen=True)
class Equiv(Term):
    A: Term
    B: Term
    f: Term  # forward function
    g: Term  # inverse-ish function (naive)

# Represent a user postulate/axiom by name -> type
@dataclass
class Postulate:
    name: str
    typ: Term

# ---------- Substitution utility ----------
def subst(term: Term, var: str, value: Term) -> Term:
    """Capture-avoiding substitution (very small - assumes distinct var names)."""
    # For pedagogy, assume no variable capture; nested lambdas may reuse names — user must avoid.
    if isinstance(term, Var):
        return value if term.name == var else term
    if isinstance(term, TypeUniverse):
        return term
    if isinstance(term, Pi):
        if term.x == var:
            return term
        return Pi(term.x, subst(term.A, var, value), subst(term.B, var, value))
    if isinstance(term, Lambda):
        if term.x == var:
            return term
        return Lambda(term.x, subst(term.body, var, value))
    if isinstance(term, App):
        return App(subst(term.fn, var, value), subst(term.arg, var, value))
    if isinstance(term, Sigma):
        if term.x == var:
            return term
        return Sigma(term.x, subst(term.A, var, value), subst(term.B, var, value))
    if isinstance(term, Pair):
        return Pair(subst(term.fst, var, value), subst(term.snd, var, value))
    if isinstance(term, Path):
        return Path(subst(term.A, var, value), subst(term.a, var, value), subst(term.b, var, value))
    if isinstance(term, Refl):
        return Refl(subst(term.a, var, value))
    if isinstance(term, Equiv):
        return Equiv(subst(term.A, var, value), subst(term.B, var, value),
                     subst(term.f, var, value), subst(term.g, var, value))
    if isinstance(term, Hole):
        return term
    raise NotImplementedError(f"subst for {type(term)}")

# ---------- Context and environment ----------
@dataclass
class Ctx:
    # variable name -> type
    gamma: Dict[str, Term]
    postulates: Dict[str, Postulate]
    # Known equivalences (for the naive univalence handling):
    known_equivs: List[Equiv]

    def extend(self, name: str, typ: Term) -> "Ctx":
        new = Ctx(self.gamma.copy(), self.postulates.copy(), list(self.known_equivs))
        new.gamma[name] = typ
        return new

    def add_postulate(self, name: str, typ: Term):
        self.postulates[name] = Postulate(name, typ)

    def add_equiv(self, e: Equiv):
        self.known_equivs.append(e)

# ---------- Judgmental conversion (very conservative) ----------
def conv(ctx: Ctx, A: Term, B: Term) -> bool:
    """
    Convertibility / judgmental equality in this simple system.
    This is intentionally conservative: alpha-equivalence, beta-reduction for simple cases,
    and trivial refl-based path normalization only.
    """
    # alpha-eq by structure for our simple AST:
    if type(A) != type(B):
        # Special case: allow TypeUniverse(level n) equal to TypeUniverse(level m) if n==m
        if isinstance(A, TypeUniverse) and isinstance(B, TypeUniverse):
            return A.level == B.level
        return False
    if isinstance(A, Var):
        return A.name == B.name
    if isinstance(A, TypeUniverse):
        return A.level == B.level
    if isinstance(A, Pi):
        # compare A types, then B with same binder name
        return conv(ctx, A.A, B.A) and conv(ctx, A.B, subst(B.B, B.x, Var(A.x)))
    if isinstance(A, Lambda) and isinstance(B, Lambda):
        return conv(ctx, A.body, B.body)
    if isinstance(A, App) and isinstance(B, App):
        # try beta-reduction if fn is Lambda
        if isinstance(A.fn, Lambda):
            reduced = subst(A.fn.body, A.fn.x, A.arg)
            return conv(ctx, reduced, B)
        if isinstance(B.fn, Lambda):
            reduced = subst(B.fn.body, B.fn.x, B.arg)
            return conv(ctx, A, reduced)
        return conv(ctx, A.fn, B.fn) and conv(ctx, A.arg, B.arg)
    if isinstance(A, Sigma):
        return conv(ctx, A.A, B.A) and conv(ctx, A.B, B.B)
    if isinstance(A, Pair):
        return conv(ctx, A.fst, B.fst) and conv(ctx, A.snd, B.snd)
    if isinstance(A, Path):
        return conv(ctx, A.A, B.A) and conv(ctx, A.a, B.a) and conv(ctx, A.b, B.b)
    if isinstance(A, Refl) and isinstance(B, Refl):
        return conv(ctx, A.a, B.a)
    if isinstance(A, Equiv) and isinstance(B, Equiv):
        return conv(ctx, A.A, B.A) and conv(ctx, A.B, B.B) and conv(ctx, A.f, B.f) and conv(ctx, A.g, B.g)
    if isinstance(A, Hole) and isinstance(B, Hole):
        return A.name == B.name
    # default conservative
    return False

# ---------- Bidirectional type-checking ----------
class TypeErrorEx(Exception):
    pass

def infer(ctx: Ctx, t: Term) -> Term:
    """Synthesize type of term t under context ctx."""
    if isinstance(t, Var):
        if t.name in ctx.gamma:
            return ctx.gamma[t.name]
        raise TypeErrorEx(f"Unknown variable {t.name}")
    if isinstance(t, TypeUniverse):
        # Type : Type_{level+1}
        return TypeUniverse(t.level + 1)
    if isinstance(t, Pi):
        # check that A : Type_i and B : Type_j under extended context
        A_ty = infer(ctx, t.A)
        if not isinstance(A_ty, TypeUniverse):
            raise TypeErrorEx("Pi domain not a universe")
        ctx2 = ctx.extend(t.x, t.A)
        B_ty = infer(ctx2, t.B)
        if not isinstance(B_ty, TypeUniverse):
            raise TypeErrorEx("Pi codomain not a universe")
        # universe level is max(A,B)
        return TypeUniverse(max(A_ty.level, B_ty.level))
    if isinstance(t, Lambda):
        raise TypeErrorEx("Cannot infer type of bare lambda; use check(expected_type)")
    if isinstance(t, App):
        fn_ty = infer(ctx, t.fn)
        if isinstance(fn_ty, Pi):
            check(ctx, t.arg, fn_ty.A)
            return subst(fn_ty.B, fn_ty.x, t.arg)
        else:
            raise TypeErrorEx(f"Expected function type in application; got {fn_ty}")
    if isinstance(t, Sigma):
        A_ty = infer(ctx, t.A)
        if not isinstance(A_ty, TypeUniverse):
            raise TypeErrorEx("Sigma domain not universe")
        ctx2 = ctx.extend(t.x, t.A)
        B_ty = infer(ctx2, t.B)
        if not isinstance(B_ty, TypeUniverse):
            raise TypeErrorEx("Sigma codomain not universe")
        return TypeUniverse(max(A_ty.level, B_ty.level))
    if isinstance(t, Pair):
        A = infer(ctx, t.fst)
        if not isinstance(A, TypeUniverse) and A is None:
            # first component's type must be known by context
            raise TypeErrorEx("Cannot infer pair fst type")
        # Synthesize as Sigma type only if we can get type of snd from context (not implemented)
        raise TypeErrorEx("Cannot infer pair type; use check with Sigma")
    if isinstance(t, Path):
        # Path(A; a, b) : TypeUniverse(level of A)
        A_ty = infer(ctx, t.A)
        if not isinstance(A_ty, TypeUniverse):
            raise TypeErrorEx("Path over non-universe")
        check(ctx, t.a, t.A)
        check(ctx, t.b, t.A)
        return TypeUniverse(A_ty.level)
    if isinstance(t, Refl):
        # infer type: Path(A; a, a) when we can determine A from a
        typ = infer(ctx, t.a)
        # If `typ` is a universe type (i.e., t.a is a type), then refl is path in a universe — allow
        # but more common: a : A, so we want Path(A; a, a)
        # For simplicity, if the inferred type of a is not helpful, we need user to provide expected type.
        # We'll allow a : A when a's type is known in ctx in some variable name? Conservative:
        raise TypeErrorEx("Cannot infer type of refl; use check with Path(A, a, a)")
    if isinstance(t, Equiv):
        # an Equiv is not a type but a witness; infer its 'type' as a sigma-ish structure
        # For simplicity, treat as a Postulated object; return TypeUniverse
        return TypeUniverse(0)
    if isinstance(t, Hole):
        raise TypeErrorEx(f"Cannot infer hole {t.name}")
    raise TypeErrorEx(f"infer not implemented for {type(t)}")

def check(ctx: Ctx, t: Term, expected: Term) -> bool:
    """Check that term t has type expected under ctx."""
    # If expected is a Pi and t is a lambda, check the body under extended context
    if isinstance(expected, Pi) and isinstance(t, Lambda):
        ctx2 = ctx.extend(t.x, expected.A)
        return check(ctx2, t.body, expected.B)
    if isinstance(t, Lambda) and isinstance(expected, TypeUniverse):
        raise TypeErrorEx("Cannot check lambda against a universe")
    # If expected is a Path and t is Refl: check refl(a) : Path(A; a, a)
    if isinstance(expected, Path) and isinstance(t, Refl):
        # ensure t.a : A
        check(ctx, t.a, expected.A)
        if not conv(ctx, t.a, expected.a) or not conv(ctx, t.a, expected.b):
            # require that expected.a and expected.b are judgmentally equal to t.a
            raise TypeErrorEx("refl term endpoints mismatch expected Path endpoints")
        return True
    # If expected is Path and t is a Path term: ensure components check
    if isinstance(expected, Path) and isinstance(t, Path):
        check(ctx, t.A, TypeUniverse(0))
        check(ctx, t.a, t.A)
        check(ctx, t.b, t.A)
        return conv(ctx, expected.A, t.A) and conv(ctx, expected.a, t.a) and conv(ctx, expected.b, t.b)
    # If expected is Universe and term is a type (Pi/Sigma/Universe) we can infer its universe
    if isinstance(expected, TypeUniverse):
        if isinstance(t, Pi) or isinstance(t, Sigma) or isinstance(t, TypeUniverse):
            # synthesize and compare universe levels
            t_ty = infer(ctx, t)
            if not isinstance(t_ty, TypeUniverse):
                raise TypeErrorEx("Expected a universe type but inference failed")
            # require universe levels compatible
            return t_ty.level <= expected.level
        # if t is an Equiv, it's OK as a witness, accept (postulate)
        if isinstance(t, Equiv):
            return True
    # fallback: try to infer and conv
    try:
        t_ty = infer(ctx, t)
        if conv(ctx, t_ty, expected):
            return True
    except TypeErrorEx:
        pass
    raise TypeErrorEx(f"Type check failed: term {t} does not have expected type {expected}")

# ---------- Univalence (naive) ----------
def ua(ctx: Ctx, e: Equiv) -> Term:
    """
    Naive Univalence: from an equivalence between types A and B produce a Path(A, B).
    In real HoTT this is an axiom or theorem requiring higher structure; here we model it as a postulate.
    """
    # Check that e.f : A -> B and e.g : B -> A (naively)
    # Build Path(A, B) as a postulated term whose type-check will be Path(Type, A, B)
    # Verify domain are types
    check(ctx, e.A, TypeUniverse(0))
    check(ctx, e.B, TypeUniverse(0))
    # we won't attempt to verify homotopy inverses; we simply register the equivalence
    ctx.add_equiv(e)
    # Return a special term representing "ua(e)"
    return EquivPath(e)

@dataclass(frozen=True)
class EquivPath(Term):
    equiv: Equiv

# Extend conv to compare EquivPath
def conv(ctx: Ctx, A: Term, B: Term) -> bool:
    # reuse previous conv but add EquivPath handling
    # call the earlier conv implementation with different name; to avoid recursion override, reimplement minimally:
    if type(A) != type(B):
        if isinstance(A, TypeUniverse) and isinstance(B, TypeUniverse):
            return A.level == B.level
        # special: EquivPath vs Path? treat EquivPath(e) as Path(A,B)
        if isinstance(A, EquivPath) and isinstance(B, Path):
            return conv(ctx, A.equiv.A, B.A) and conv(ctx, A.equiv.B, B.A) and conv(ctx, A.equiv.B, B.b)  # naive - but keep consistent
        if isinstance(B, EquivPath) and isinstance(A, Path):
            return conv(ctx, A.A, B.equiv.A) and conv(ctx, A.a, B.equiv.A) and conv(ctx, A.b, B.equiv.B)
        return False
    # structural compare for common cases (same as before)
    if isinstance(A, Var):
        return A.name == B.name
    if isinstance(A, TypeUniverse):
        return A.level == B.level
    if isinstance(A, Pi):
        return conv(ctx, A.A, B.A) and conv(ctx, A.B, B.B)
    if isinstance(A, Lambda) and isinstance(B, Lambda):
        return conv(ctx, A.body, B.body)
    if isinstance(A, App) and isinstance(B, App):
        return conv(ctx, A.fn, B.fn) and conv(ctx, A.arg, B.arg)
    if isinstance(A, Sigma):
        return conv(ctx, A.A, B.A) and conv(ctx, A.B, B.B)
    if isinstance(A, Pair):
        return conv(ctx, A.fst, B.fst) and conv(ctx, A.snd, B.snd)
    if isinstance(A, Path):
        return conv(ctx, A.A, B.A) and conv(ctx, A.a, B.a) and conv(ctx, A.b, B.b)
    if isinstance(A, Refl) and isinstance(B, Refl):
        return conv(ctx, A.a, B.a)
    if isinstance(A, Equiv) and isinstance(B, Equiv):
        return conv(ctx, A.A, B.A) and conv(ctx, A.B, B.B) and conv(ctx, A.f, B.f) and conv(ctx, A.g, B.g)
    if isinstance(A, EquivPath) and isinstance(B, EquivPath):
        return conv(ctx, A.equiv.A, B.equiv.A) and conv(ctx, A.equiv.B, B.equiv.B)
    if isinstance(A, EquivPath) and isinstance(B, Path):
        # treat as Path(A,B) for convenience
        return conv(ctx, A.equiv.A, B.A) and conv(ctx, A.equiv.B, B.b)
    if isinstance(B, EquivPath) and isinstance(A, Path):
        return conv(ctx, A.A, B.equiv.A) and conv(ctx, A.b, B.equiv.B)
    if isinstance(A, Hole) and isinstance(B, Hole):
        return A.name == B.name
    return False

# ---------- Higher Inductive Types (skeleton) ----------
# We give an example: the circle S1 with constructors base : S1 and loop : Path(S1, base, base)
@dataclass(frozen=True)
class HitDecl:
    name: str
    typ: Term
    constructors: List[Tuple[str, Term]]  # name and type

# A registry for HITs
HIT_REGISTRY: Dict[str, HitDecl] = {}

def declare_circle():
    # S1 : Type
    S1 = Var("S1_type")
    base = Var("base")
    loop_path = Path(S1, base, base)
    decl = HitDecl("S1", S1, [("base", S1), ("loop", loop_path)])
    HIT_REGISTRY["S1"] = decl
    return decl

# ---------- Small interface helpers ----------
def pretty(t: Term) -> str:
    if isinstance(t, Var): return t.name
    if isinstance(t, TypeUniverse): return f"Type{t.level}"
    if isinstance(t, Pi): return f"(Π {t.x}:{pretty(t.A)}. {pretty(t.B)})"
    if isinstance(t, Lambda): return f"(λ{t.x}. {pretty(t.body)})"
    if isinstance(t, App): return f"({pretty(t.fn)} {pretty(t.arg)})"
    if isinstance(t, Sigma): return f"(Σ {t.x}:{pretty(t.A)}. {pretty(t.B)})"
    if isinstance(t, Pair): return f"({pretty(t.fst)}, {pretty(t.snd)})"
    if isinstance(t, Path): return f"Path({pretty(t.A)}; {pretty(t.a)}, {pretty(t.b)})"
    if isinstance(t, Refl): return f"refl({pretty(t.a)})"
    if isinstance(t, Equiv): return f"Equiv({pretty(t.A)}↔{pretty(t.B)})"
    if isinstance(t, EquivPath): return f"ua({pretty(t.equiv.A)}↔{pretty(t.equiv.B)})"
    if isinstance(t, Hole): return f"?{t.name}"
    return str(t)

# ---------- Examples / small tests ----------
def examples():
    ctx = Ctx({}, {}, [])
    # Universe levels
    U0 = TypeUniverse(0)
    U1 = TypeUniverse(1)

    # Declare A : Type0, B : Type0
    A = Var("A")
    B = Var("B")
    ctx = ctx.extend("A", U0).extend("B", U0)

    # Example: identity function id : Π (x : A). A
    pi_A_to_A = Pi("x", A, A)
    id_term = Lambda("x", Var("x"))
    # check id : Π x:A. A
    try:
        ok = check(ctx.extend("A", U0), id_term, pi_A_to_A)
        print("id : Π x:A. A  -- ok")
    except TypeErrorEx as e:
        print("id checking failed:", e)

    # Path example: refl
    # Suppose we have a : A (variable)
    ctx2 = ctx.extend("a", A)
    refl_a = Refl(Var("a"))
    path_type = Path(A, Var("a"), Var("a"))
    try:
        ok = check(ctx2, refl_a, path_type)
        print("refl(a) : Path(A,a,a) -- ok")
    except TypeErrorEx as e:
        print("refl checking failed:", e)

    # Naive univalence:
    # Suppose we have f : A -> B and g : B -> A as Vars (user-provided)
    ctx3 = ctx.extend("f", Pi("x", A, B)).extend("g", Pi("y", B, A))
    e = Equiv(A, B, Var("f"), Var("g"))
    # Use ua(e) to yield an EquivPath
    try:
        p = ua(ctx3, e)
        # type-check: ua(e) should be usable where a Path between A and B is expected
        expected = Path(TypeUniverse(0), A, B)  # Path(Type, A, B) (A and B are themselves types)
        # check that p is of that path type — we emulate by asserting components match
        print("ua(e) produced:", pretty(p))
        # naive conv check: EquivPath is treated like a Path(A,B)
        print("conv Path(A,B) with ua(e)?", conv(ctx3, expected, Path(A, A, B)))
    except TypeErrorEx as e:
        print("ua failed:", e)

    # HIT declaration
    s1 = declare_circle()
    print("Declared HIT:", s1.name, "with constructors:", s1.constructors)

if __name__ == "__main__":
    examples()


id : Π x:A. A  -- ok
refl(a) : Path(A,a,a) -- ok
ua(e) produced: ua(A↔B)
conv Path(A,B) with ua(e)? False
Declared HIT: S1 with constructors: [('base', Var(name='S1_type')), ('loop', Path(A=Var(name='S1_type'), a=Var(name='base'), b=Var(name='base')))]
