# In [None]:  (notebook cell prompt left over from export)
from __future__ import annotations
import re
from dataclasses import dataclass
from typing import Dict, List, Tuple, Optional
import sys
import io

import pynini
from pynini import Fst
from pynini.lib import utf8

# Patterns for the backslash escapes recognised in rule text.
# "\uXXXX": exactly four hex digits; group 1 holds the digits.
_U_HEX = re.compile(r"\\u([0-9A-Fa-f]{4})")
# "\UXXXXXX" .. "\UXXXXXXXX": six to eight hex digits; group 1 holds the digits.
_U_HEX_LONG = re.compile(r"\\U([0-9A-Fa-f]{6,8})")
# A backslash followed by one punctuation character from the listed class
# (brackets, backslash, slash, dash, caret, parentheses, braces, underscore,
# dot, star, plus, question mark, pipe); group 1 holds the bare character.
_ESCAPED = re.compile(r"\\([][\\/\-^(){}_.*+?|])")

def _decode_escapes(s: str) -> str:
    """Decodes CLDR-style escapes."""
    def rpl4(m): return chr(int(m.group(1), 16))
    def rpl8(m): return chr(int(m.group(1), 16))
    s = _U_HEX.sub(rpl4, s)
    s = _U_HEX_LONG.sub(rpl8, s)
    s = _ESCAPED.sub(lambda m: m.group(1), s)
    return s

def _char_range(a: str, b: str) -> List[str]:
    return [chr(cp) for cp in range(ord(a), ord(b) + 1)]

def _parse_unicode_set(text: str) -> Tuple[bool, List[str]]:
    r"""Parse a simple CLDR-style character set such as ``[a-z]`` or ``[^abc]``.

    Supported syntax: single characters, ``x-y`` ranges, a leading ``^`` for
    negation, ``\uXXXX`` / ``\UXXXXXXXX`` code-point escapes, and a backslash
    before any other character to take that character literally. Unescaped
    whitespace is ignored, as in UnicodeSet patterns; write ``\ `` to include
    a space. A ``-`` with nothing after it is an ordinary member.
    Nested sets, properties and set operators are not handled.

    Args:
        text: The set, including its surrounding square brackets.

    Returns:
        ``(negated, items)`` where ``negated`` is True for ``[^...]`` and
        ``items`` lists the member characters in source order (duplicates
        are not removed).

    Raises:
        ValueError: if ``text`` is not wrapped in ``[`` ... ``]`` or a range
            has its endpoints in descending order.
    """
    # Explicit raise instead of ``assert`` so the check survives ``python -O``.
    if not (text.startswith("[") and text.endswith("]")):
        raise ValueError(f"not a set: {text!r}")
    inner = text[1:-1]
    neg = inner.startswith("^")
    if neg:
        inner = inner[1:]
    end = len(inner)

    def skip_ws(ix: int) -> int:
        # Advance past unescaped whitespace; returns ``ix`` unchanged at/after the end.
        while ix < end and inner[ix].isspace():
            ix += 1
        return ix

    def read_char(ix: int) -> Tuple[str, int]:
        # Read one (possibly escaped) character at ``ix``; return it with the next index.
        if inner[ix] == "\\":
            m4 = _U_HEX.match(inner, ix)
            if m4:
                return (chr(int(m4.group(1), 16)), m4.end())
            m8 = _U_HEX_LONG.match(inner, ix)
            if m8:
                return (chr(int(m8.group(1), 16)), m8.end())
            if ix + 1 < end:
                return (inner[ix + 1], ix + 2)
            return ("\\", ix + 1)  # trailing lone backslash is kept as itself
        return (inner[ix], ix + 1)

    items: List[str] = []
    i = skip_ws(0)
    while i < end:
        c1, j = read_char(i)
        dash = skip_ws(j)
        hi_start = skip_ws(dash + 1)
        # A dash only forms a range when another character follows it.
        if dash < end and inner[dash] == "-" and hi_start < end:
            c2, k = read_char(hi_start)
            if ord(c2) < ord(c1):
                raise ValueError(f"reversed range {c1!r}-{c2!r} in set {text!r}")
            items.extend(_char_range(c1, c2))
            i = skip_ws(k)
        else:
            items.append(c1)
            i = skip_ws(j)
    return (neg, items)

def _acceptor(s: str) -> Fst:
    """Build an acceptor for the literal string ``s`` using the "utf8" token type.

    The string is passed through ``pynini.escape`` first so that literal
    ``[``, ``]`` and backslash characters in ``s`` are accepted as ordinary
    characters instead of being interpreted by pynini's string compiler
    (which otherwise reads ``[...]`` as a generated-symbol reference).

    NOTE(review): the "utf8" token type labels arcs by code point, while
    ``pynini.lib.utf8.VALID_UTF8_CHAR`` (used as the alphabet elsewhere in
    this file) is presumably byte-level - confirm the two agree for
    non-ASCII text.
    """
    return pynini.accep(pynini.escape(s), token_type="utf8")

def _transducer(inp: str, out: str) -> Fst:
    """Build a transducer that maps the literal string ``inp`` to ``out``.

    Both strings are compiled with the "utf8" token type and combined with
    ``pynini.cross``.
    """
    upper = pynini.accep(inp, token_type="utf8")
    lower = pynini.accep(out, token_type="utf8")
    return pynini.cross(upper, lower)


@dataclass
class Rule:
    """One rewrite rule parsed from a line shaped like ``lhs OP rhs [/ left _ right] ;``."""
    # Text before the operator, whitespace-stripped.
    lhs: str
    # Text after the operator (up to any "/" context), whitespace-stripped.
    rhs: str
    # The operator exactly as written: ">", "<" or "<>".
    op: str
    # Context written between "/" and "_"; None when absent or empty.
    left: Optional[str] = None
    # Context written after "_"; None when absent or empty.
    right: Optional[str] = None

@dataclass
class Directive:
    """A ``::`` directive line."""
    # "null", "nfd", "nfc", "filter" or "unknown", as assigned by _parse_line.
    kind: str
    # For "filter": the set text without the trailing ";". For "unknown": the
    # directive body as written. None for the other kinds.
    payload: Optional[str] = None

@dataclass
class VarAssign:
    """A variable definition line of the form ``$name = expr ;``."""
    # Variable name including its leading "$".
    name: str
    # Right-hand side text, still uncompiled.
    expr: str

# A parsed line: a tag ("var", "dir" or "rule") paired with the matching
# VarAssign, Directive or Rule instance.
Line = Tuple[str, object]

# Matches one rule line: ``lhs OP rhs [/ left _ right] ;``.
#   * lhs and rhs are matched lazily and must contain at least one character
#     (which may be whitespace), so lhs ends at the first "<>", ">" or "<".
#   * The optional context starts at "/" and is split at "_" into the named
#     groups "L" and "R"; either side may be empty.
#   * The line must end with ";" followed only by optional whitespace.
_RULE_CTX_RE = re.compile(r"""
    ^\s*
    (?P<lhs>.+?)
    \s*
    (?P<op><>|>|<)
    \s*
    (?P<rhs>.+?)
    (?:
        /
        \s*
        (?P<L>.*?)
        \s*
        _
        \s*
        (?P<R>.*?)
    )?
    \s*;
    \s*$
""", re.VERBOSE)

# Matches a variable definition: ``$name = expr`` with an optional trailing ";".
# The name is "$" followed by ASCII letters, digits or underscores; expr is
# matched lazily and must contain at least one character.
_VAR_RE = re.compile(r"""
    ^\s*
    (?P<name>\$[A-Za-z0-9_]+)
    \s*=\s*
    (?P<expr>.+?)
    \s*;?\s*$
""", re.VERBOSE)

def _parse_line(line: str) -> Optional[Line]:
    """Classify one line of rule text.

    Returns:
        * ``("var", VarAssign)`` for ``$name = expr ;``
        * ``("dir", Directive)`` for lines starting with ``::``
        * ``("rule", Rule)`` for ``lhs OP rhs [/ left _ right] ;``
        * ``None`` for blank lines, lines starting with ``#`` and anything
          that matches none of the forms above (such lines are dropped
          without any diagnostic).
    """
    s = line.strip()
    if not s or s.startswith("#"):
        return None

    mvar = _VAR_RE.match(s)
    if mvar and mvar.group("expr").strip():
        expr = mvar.group("expr").strip().rstrip(';')
        return ("var", VarAssign(name=mvar.group("name"), expr=expr))

    if s.startswith("::"):
        body = s[2:].strip()
        if body.endswith(";"):
            # Compare with the terminator and surrounding padding removed so
            # that ":: NFD ;" is recognised exactly like "::NFD;".
            name = body[:-1].strip()
            keyword = name.lower()
            if keyword in ("null", "nfd", "nfc"):
                return ("dir", Directive(kind=keyword))
            if name.startswith("["):
                return ("dir", Directive(kind="filter", payload=name))
        return ("dir", Directive(kind="unknown", payload=body))

    m = _RULE_CTX_RE.match(s)
    if m is None:
        return None
    left_ctx = m.group("L")
    right_ctx = m.group("R")
    return ("rule", Rule(
        lhs=m.group("lhs").strip(),
        rhs=m.group("rhs").strip(),
        op=m.group("op"),
        # Absent and empty contexts are both normalised to None.
        left=(left_ctx.strip() if left_ctx else None),
        right=(right_ctx.strip() if right_ctx else None),
    ))

def parse_cldr(text: str) -> List[Line]:
    """Parse rule text line by line, keeping only the lines ``_parse_line`` recognises.

    Lines for which ``_parse_line`` returns ``None`` are omitted; the order
    of the remaining entries follows the input.
    """
    parsed = (_parse_line(raw) for raw in text.splitlines())
    return [entry for entry in parsed if entry]

class Env:
    """Compilation environment: variable bindings plus the shared alphabet FSTs.

    Attributes are plain instance fields:
      * ``vars``         - compiled FST for each ``$name`` defined so far;
      * ``sigma``        - acceptor for one character (a private copy of
                           ``utf8.VALID_UTF8_CHAR``);
      * ``sigma_star``   - acceptor for any sequence of such characters;
      * ``I_sigma_star`` - identity mapping over ``sigma_star``.
    """

    def __init__(self) -> None:
        self.vars: Dict[str, Fst] = {}
        # Work on a copy: Fst.optimize() and Fst.closure() modify the FST they
        # are called on, and the library constant must not be altered.
        self.sigma: Fst = utf8.VALID_UTF8_CHAR.copy().optimize()
        # Copy again before closing, otherwise ``sigma`` itself would be
        # turned into its closure and stop meaning "exactly one character".
        self.sigma_star: Fst = self.sigma.copy().closure().optimize()

        # An acceptor already maps every string it accepts to itself, so the
        # identity over sigma* is just sigma* again. (pynini.cross(sigma,
        # sigma) would instead relate every character to every character.)
        self.I_sigma_star: Fst = self.sigma_star.copy()

    def get(self, name: str) -> Fst:
        """Return the FST bound to ``name``.

        Raises:
            KeyError: if ``name`` has not been defined.
        """
        if name not in self.vars:
            raise KeyError(f"Undefined variable {name}")
        return self.vars[name]

def _compile_atom(expr: str, env: Env) -> Fst:
    """Compile a single atom: a ``$variable``, a ``[...]`` set or a literal.

    * ``$name`` returns the FST stored in ``env`` (``KeyError`` if undefined).
    * ``[...]`` becomes the union of its members, or ``env.sigma`` minus that
      union when negated. An empty set yields an empty ``pynini.Fst()``; an
      empty negated set yields ``env.sigma``.
    * Anything else is escape-decoded and compiled as a literal string.
    """
    atom = expr.strip()
    if atom.startswith("$"):
        return env.get(atom)

    looks_like_set = atom.startswith("[") and atom.endswith("]")
    if not looks_like_set:
        return _acceptor(_decode_escapes(atom))

    negated, members = _parse_unicode_set(atom)
    if not members:
        return env.sigma if negated else pynini.Fst()
    member_fst = pynini.union(*map(_acceptor, members)).optimize()
    return (env.sigma - member_fst).optimize() if negated else member_fst

def _split_atoms(expr: str) -> List[str]:
    """Split ``expr`` into atoms at whitespace that lies outside ``[...]``."""
    atoms: List[str] = []
    current: List[str] = []
    depth = 0
    i = 0
    n = len(expr)
    while i < n:
        ch = expr[i]
        if ch == "\\" and i + 1 < n and not expr[i + 1].isspace():
            # Copy the backslash together with the character it escapes so an
            # escaped "[" or "]" never changes the bracket depth. The pair
            # stays escaped; decoding happens when the atom is compiled.
            current.append(expr[i:i + 2])
            i += 2
            continue
        if ch.isspace() and depth == 0:
            # Atom boundary; runs of whitespace produce no empty atoms.
            if current:
                atoms.append("".join(current))
                current = []
        else:
            if ch == "[":
                depth += 1
            elif ch == "]" and depth > 0:
                depth -= 1
            current.append(ch)
        i += 1
    if current:
        atoms.append("".join(current))
    return atoms


def _compile_seq(expr: Optional[str], env: Env) -> Fst:
    """Compile a whitespace-separated sequence of atoms into one acceptor.

    Atoms (literals, ``[...]`` sets, ``$variables``) are compiled with
    ``_compile_atom`` and concatenated left to right. Whitespace inside a
    set does not separate atoms. ``None``, an empty string or an
    all-whitespace string yields the acceptor for the empty string.
    """
    if not expr:
        return _acceptor("")
    atoms = _split_atoms(expr)
    if not atoms:
        return _acceptor("")

    # Concatenation of the atom FSTs ("+" on pynini FSTs builds a new FST).
    fst = _compile_atom(atoms[0], env)
    for atom in atoms[1:]:
        fst = fst + _compile_atom(atom, env)
    return fst.optimize()

def compile_lines(lines: List[Line]) -> Fst:
    """Compile parsed lines into a single cascade transducer.

    All variable definitions are compiled first, in input order, then every
    rule is turned into a context-dependent rewrite and composed onto the
    cascade in input order. Directive lines are ignored.

    Rule handling by operator:
      * ``>``  and ``<>`` rewrite ``lhs`` to ``rhs`` between the left and
        right contexts. For ``<>`` only this forward half is used: applying
        the inverse immediately afterwards would undo the forward rewrite.
      * ``<``  rewrites ``rhs`` to ``lhs`` with the two contexts swapped,
        as this function has always done.
        NOTE(review): ICU treats ``<`` as a reverse-direction-only rule and
        the context swap looks suspicious - confirm the intended semantics.

    Rules whose left- or right-hand side compiles to an empty FST are
    skipped with a printed warning.

    Raises:
        KeyError: if an expression references an undefined variable.
    """
    env = Env()

    variables: List[VarAssign] = [p for kind, p in lines if kind == "var"]  # type: ignore[misc]
    for va in variables:
        env.vars[va.name] = _compile_seq(va.expr, env).optimize()

    rules: List[Rule] = [p for kind, p in lines if kind == "rule"]  # type: ignore[misc]

    # Start from the identity over sigma* so an empty rule list passes text through.
    cascade = env.I_sigma_star.copy()

    for r in rules:
        lhs = _compile_seq(r.lhs, env)
        rhs = _compile_seq(r.rhs, env)
        # _compile_seq already returns the empty-string acceptor for a missing context.
        left_ctx = _compile_seq(r.left, env)
        right_ctx = _compile_seq(r.right, env)

        if lhs.num_states() == 0:
            print(f"Warning: Empty LHS FST for rule {r.lhs} > {r.rhs}. Skipping.")
            continue
        if rhs.num_states() == 0:
            print(f"Warning: Empty RHS FST for rule {r.lhs} > {r.rhs}. Skipping.")
            continue

        # The compiled acceptors are used directly. They may accept several
        # strings (sets, variables), so they must not be round-tripped
        # through Fst.string(), which only works for single-string FSTs.
        if r.op in (">", "<>"):
            tau = pynini.cross(lhs, rhs)
            rewrite = pynini.cdrewrite(tau, left_ctx, right_ctx, env.sigma_star)
        elif r.op == "<":
            tau = pynini.cross(rhs, lhs)
            rewrite = pynini.cdrewrite(tau, right_ctx, left_ctx, env.sigma_star)
        else:
            # Unknown operator: nothing to apply.
            continue
        cascade = pynini.compose(cascade, rewrite)

    # Only epsilon removal is performed here; callers optimize the result further.
    cascade.rmepsilon()
    return cascade

def prefix_language_token(lang_token: str) -> Fst:
    """Return a transducer that consumes ``lang_token`` and emits nothing.

    Used to strip a language tag from the front of the input.
    """
    token = _acceptor(lang_token)
    nothing = _acceptor("")
    return pynini.cross(token, nothing)

def build_multilingual_transducer(lang_to_rules_text: Dict[str, str], token_fmt: str = "<{lang}>") -> Fst:
    """Build one transducer that dispatches on a leading language token.

    For each ``lang -> rule text`` entry the rules are parsed and compiled,
    then prefixed with a transducer that deletes ``token_fmt.format(lang=lang)``.
    The per-language transducers are unioned and optimized.

    Args:
        lang_to_rules_text: Rule text keyed by language identifier.
        token_fmt: Format string for the language token; must contain
            ``{lang}``.

    Returns:
        The optimized union. With no languages, a transducer that accepts
        only the empty string.
    """
    unified: Optional[Fst] = None
    for lang, text in lang_to_rules_text.items():
        rules_fst = compile_lines(parse_cldr(text))
        # Sequence: [ <lang> : "" ] . [ rules_for_lang ]
        prefix = prefix_language_token(token_fmt.format(lang=lang))
        lt = pynini.concat(prefix, rules_fst).optimize()
        unified = lt if unified is None else pynini.union(unified, lt)
    if unified is None:
        # The empty-string acceptor maps "" to "". Built with the same helper
        # as the rest of the file, avoiding the older ``pynini.transducer``
        # name, which is not available alongside ``accep``/``cross``.
        unified = _acceptor("")
    return unified.optimize()

def save_fst(fst: Fst, path: str) -> None:
    """Write ``fst`` to ``path`` using the FST's own ``write`` method."""
    fst.write(path)

def demo():
    """Build a two-language transducer, save it and print sample outputs.

    Side effects: writes ``multilang.fst`` to the current directory and
    prints one line per test input.
    """
    cldr_en = r"""
        $V = [aeiou] ; 
        th > θ ;
        sh > ʃ ;
        ng > ŋ ;
        ch > t͡ʃ ;
        c > s / _ [ei] ; 
        $V > a ;
    """

    cldr_es = r"""
        $V = [aeiou] ;
        ll > ʎ ;
        ñ > ɲ ;
        qu > k ;
        c > s / _ [ei] ; 
        c > k ;
        z > s ;
        $V > a ;
    """

    lang_rules = {"eng": cldr_en, "spa": cldr_es}

    fst = build_multilingual_transducer(lang_rules, token_fmt="<{lang}>")
    save_fst(fst, "multilang.fst")

    tests = [
        "<eng>thing",
        "<eng>mashing",
        "<spa>llama",
        "<spa>quiza",
        "<spa>cita",
        "<spa>cuna",
    ]

    def apply(s: str) -> str:
        """Return the best output for ``s``, or "<no-path>" when nothing matches."""
        lattice = pynini.compose(_acceptor(s), fst)
        if lattice.start() == pynini.NO_STATE_ID:
            return "<no-path>"

        # Keep the single best path and look only at its output side.
        best = pynini.shortestpath(lattice)
        best.project("output")
        best.rmepsilon()

        try:
            return best.string(token_type="utf8")
        except pynini.FstOpError:
            # The lattice had a start state but no accepting path, so there
            # is no single output string to read back.
            return "<no-path>"

    for t in tests:
        print(f"{t} -> {apply(t)}")

    print("\nWrote multilang.fst")

# Run the demonstration only when this file is executed as a script.
if __name__ == "__main__":
    demo()