In [9]:
from __future__ import annotations
from abc import ABC, abstractmethod
from contextlib import contextmanager
from dataclasses import dataclass, field
from fractions import Fraction
from types import SimpleNamespace
from typing import Any, Callable, Protocol, Union, cast, overload
from typing_extensions import Self, TypeGuard
import sympy
from varname import varname, ImproperUseError

# from ..physics import Unit
# from ..lang import Text


# TODO: Use real Text class
class Text:
    def __init__(self, *args, **kwargs) -> None:
        self.args = args
        self.kwargs = kwargs

    @property
    def default(self):
        return self.args[0]


VarIdentifier = Any #Union[str, sympy.Symbol]

def is_var_identifier(value: Any) -> TypeGuard[VarIdentifier]:
    return True
    return isinstance(value, (str, sympy.Symbol))


def var_name(identifier: VarIdentifier):
    if isinstance(identifier, sympy.Symbol):
        return identifier.name
    else:
        return str(identifier)


def _infer_scope_name():
    try:
        return cast(str, varname(2))
    except ImproperUseError:
        return "unknown"
        # raise ValueError(
        #     f"Could not infer scope name. Please specify one or directly assign the scope to a variable to use its name."
        # )


class CancelScopeAction(Exception):
    pass


class Module:
    def __init__(self, next: Module) -> None:
        self.next = next

    @property
    def last(self):
        if self.next is self:
            return self
        else:
            return self.next.last

    def on_set_prop(self, var: VarIdentifier, prop: str, value: Any):
        self.next.on_set_prop(var, prop, value)

    def on_define(self, name: str, props: dict[str, Any]):
        self.next.on_define(name, props)

    @property
    def known_props(self) -> set[str]:
        return set()


class PropModule(Module):
    prop: str

    def on_set(self, var: VarIdentifier, value: Any) -> None:
        self.next.on_set_prop(var, self.prop, value)

    def on_set_prop(self, var: VarIdentifier, prop: str, value: Any):
        if prop == self.prop:
            self.on_set(var, value)
        else:
            super().on_set_prop(var, prop, value)

    @property
    def known_props(self) -> set[str]:
        return {self.prop}


class EndModule(Module):
    def __init__(self, space: Space) -> None:
        self.space = space
        super().__init__(self)

    def on_set_prop(self, var: VarIdentifier, prop: str, value: Any):
        self.space.get_props(var)[prop] = value

    def on_define(self, name: str, props: dict[str, Any]):
        self.space.vars[name] = {}

        for prop, value in props.items():
            self.space.set_prop(name, prop, value)


class Parsable(Protocol):
    @classmethod
    @abstractmethod
    def parse(cls, input: Any) -> Self:
        pass


class PropParseModule(PropModule):
    @abstractmethod
    def parse(self, input: Any):
        pass

    def on_set(self, var: VarIdentifier, value: Any):
        super().on_set(var, self.parse(value))


class VarSymbol(sympy.Symbol):
    space: Space

    @property
    def props(self):
        return self.space.get_props(self)

    def __getattr__(self, name: str):
        if name in self.props:
            return self.props[name]
        else:
            return super().__Symbol_getattr__(name)  # type: ignore


@dataclass
class Space:
    name: str = field(default_factory=_infer_scope_name)
    vars: dict[str, dict[str, Any]] = field(default_factory=dict)
    modules: list[Module] = field(default_factory=list)
    # prop_parsers: dict[str, list[Callable[[Any], Any]]] = field(default_factory=dict)

    def __post_init__(self) -> None:
        self.modules.append(EndModule(self))

    def has(self, var: VarIdentifier):
        return var_name(var) in self.vars

    def get_props(self, var: VarIdentifier):
        name = var_name(var)
        if name not in self.vars:
            raise ValueError(f"Var {name} is not defined.")
        return self.vars[name]

    def get_prop(self, var: VarIdentifier, prop: str) -> Any:
        return self.get_props(var).get(prop, None)

    def set_prop(self, var: VarIdentifier, prop: str, value: Any):
        self.start_module.on_set_prop(var, prop, value)

    def create(self, name: str | None = None, **props):
        if name is None:
            name = cast(str, varname())

        if name in self.vars:
            raise ValueError(f"Var {name} is already defined.")

        self.start_module.on_define(name, props)
        return self.get_var(name)

    def get_var(self, var: VarIdentifier):
        name = var_name(var)
        if name not in self.vars:
            raise ValueError(f"Var {name} is not defined.")

        symbol = VarSymbol(name)
        symbol.space = self
        return symbol

    # def var(self, name: str | None = None, **props):
    #     if len(props) == 0 and name in self.vars:
    #         return self.get_var(name)
    #     else:
    #         return self.create_var(name, **props)

    @overload
    def get(self, var: VarIdentifier) -> VarSymbol:
        ...

    @overload
    def get(self, var: VarIdentifier, prop: str) -> Any:
        ...

    def get(self, var: VarIdentifier, prop: str | None = None):
        if prop is None:
            return self.get_var(var)
        else:
            return self.get_prop(var, prop)

    def use(self, *module_factories: Callable[[Module], Module]):
        for module_factory in module_factories:
            next = self.modules[-1]
            module = module_factory(next)
            if not isinstance(module, Module):
                raise ValueError(
                    f"Module factory did not return a Module. Got {type(module)}."
                )
            self.modules.append(module)

    @property
    def start_module(self):
        return self.modules[-1]

    @property
    def known_props(self):
        return set.union(*(module.known_props for module in self.modules))

    def __getattr__(self, name: str):
        if name.startswith("set_"):
            prop = name[4:]
            return lambda var, value=None: cast(Any, self.set_prop(var, prop, value))
        elif name.startswith("get_"):
            prop = name[4:]
            return lambda var, value=None: cast(Any, self.get_prop(var, prop))
        elif self.has(name):
            return self.get_var(name)
        else:
            raise AttributeError(f"Space has no attribute '{name}'.")

    def __str__(self) -> str:
        return f"{self.name}({len(self.vars)} vars, {len(self.modules) - 1} modules)"  # Ignore end module

    def __dir__(self):
        keys = list(super().__dir__())

        for prop in self.known_props:
            keys.append(f"get_{prop}")
            keys.append(f"set_{prop}")

        keys.extend(self.vars.keys())

        return keys



class QuantitySystem:
    labels: list[Text]

    def __init__(self, *labels) -> None:
        self.labels = [Text(label) for label in labels]

    def quantity(self, label, components: list[Fraction | int | str] | None = None, /, **components_kwargs: Fraction | int | str):
        if components is None:
            components = [
                components_kwargs.get(label.default, 0) for label in self.labels
            ]
        elif len(components_kwargs) > 0:
            raise ValueError("Cannot specify both components and components_kwargs.")
        
        return Quantity(label, self, components)

class Quantity:
    label: Text
    system: QuantitySystem
    components: list[Fraction]

    def __init__(self, label, system: QuantitySystem, components: list[Fraction | int | str]) -> None:
        if len(components) != len(system.labels):
            raise ValueError(f"System {system} requires {len(system.labels)} components, got {len(components)}.")

        self.label = Text(label)
        self.system = system
        self.components = [Fraction(c) for c in components]


    def unit(self, label, *, factor = 1, offset = 0):
        return Unit_(label, self, factor=factor, offset=offset)

    def __str__(self) -> str:
        s = ""
        for label, component in zip(self.system.labels, self.components):
            s += label.args[0] + "^" + str(component) + " "
        return s.strip()


class Unit_:
    label: Text
    quantity: Quantity
    factor: float
    offset: float

    def __init__(self, label, quantity: Quantity, *, factor = 1, offset = 0) -> None:
        self.label = Text(label)
        self.quantity = quantity
        self.factor = factor
        self.offset = offset

    



class si:
    quantity_system = QuantitySystem("T", "L", "M", "Θ", "N", "J")


    time = quantity_system.quantity("t", T=1)
    length = quantity_system.quantity("l", L=1)
    mass = quantity_system.quantity("m", M=1)


    second = time.unit("s")



class Dim:
    L: Fraction
    M: Fraction
    T: Fraction
    Θ: Fraction
    N: Fraction
    J: Fraction

    def __init__(self, **powers: Fraction | int | str) -> None:
        self.powers = {d: p if isinstance(p, Fraction) else Fraction(p) for d, p in powers.items()}

    def __getitem__(self, name: str):
        return self.powers.get(name, Fraction(0))

    def __getattr__(self, name: str):
        return self[name]

    def __mul__(self, other: Any) -> Dim:
        if isinstance(other, Dim):
            return Dim(*{d: self[d] + other[d] for d in set().union(self.powers, other.powers)})
        else:
            return self

    def __rmul__(self, other: Any) -> Dim:
        return self * other

    def __pow__(self, other: Any):
        if isinstance(other, (int, float, Fraction)):
            return Dim(*{d: p * other for d, p in self.powers.items()})
        else:
            raise ValueError(f"Cannot raise Dim to the power of {other}.")

    
class DimModule(PropModule):
    prop = "dim"



class LabelModule(Module):
    def on_define(self, name: str, props: dict[str, Any]):
        if "label" not in props:
            props |= {"label": name}  # TODO: Use Text(name)
        super().on_define(name, props)

    @property
    def known_props(self) -> set[str]:
        return {"label"}

units = Space()
units.use(LabelModule, DimModule)
units.create("1", dim=Dim(), label=Text("1", long="one"))
units.create("m", dim=Dim(L=1), label=Text("m", long="metre"))
units.create("s", dim=Dim(T=1), label=Text("s", long="second"))


computed_count = -1

class UnitModule(PropParseModule):
    prop = "unit"

    def parse(self, input: Any):
        if isinstance(input, sympy.Expr):
            global computed_count
            computed_count += 1
            return units.create(f"computed_{computed_count}", expr=input)
        elif is_var_identifier(input):
            return units.get(input)
        else:
            raise ValueError(f"Cannot parse Unit from {input}.")

    def on_define(self, name: str, props: dict[str, Any]):
        if "unit" not in props:
            if "expr" in props:
                expr: sympy.Expr = props["expr"]
                args = list(expr.free_symbols)
                f = sympy.lambdify(args, props["expr"])
                unit = f(*(arg.unit for arg in args))
                props |= {"unit": unit}

            else:
                props |= {"unit": 1}

        super().on_define(name, props)

In [10]:
vars = Space()
vars.use(UnitModule, LabelModule)

x = vars.create(unit=units.m)
t = vars.create(unit=units.s)
x.unit

computed_0

In [None]:


v = vars.create(expr=x / t)
v.unit