In [None]:
import inspect

from functools import update_wrapper

import polars as pl
from polars.datatypes import Struct, Float64


from mintalib import core
from mintalib import functions as fx
from mintalib.samples import sample_prices


def expression_method(calc_func):
    """Create a decorator that binds a signature-only stub method to ``calc_func``.

    The decorated stub contributes only its name and signature; its body is
    never executed. Calling the resulting method binds the arguments against
    that signature, drops ``self``, and returns
    ``self._expr.map_batches(...)`` which runs ``calc_func(batch, **params)``
    on each batch, aliased to the method name. NaN values in the result are
    converted to nulls via ``fill_nan(None)``.

    Args:
        calc_func: calculation routine. If its first parameter is named
            ``prices``, the source expression is wrapped with ``pl.struct`` and
            unnested back into a DataFrame before ``calc_func`` is called. An
            optional ``metadata`` mapping with an ``output_names`` entry
            declares a multi-output result, typed as a struct of Float64 fields.

    Returns:
        A decorator for methods of a class whose instances keep the source
        expression in ``self._expr`` (e.g. a polars expression namespace).

    Raises:
        TypeError: at call time, if ``calc_func`` returns a plain (non-named)
            tuple while declaring no ``output_names``.
    """
    calc_sig = inspect.signature(calc_func)
    # Default of None so a parameterless calc_func doesn't leak StopIteration.
    first_param = next(iter(calc_sig.parameters.values()), None)
    force_struct = first_param is not None and first_param.name == 'prices'

    # The output schema depends only on calc_func, so compute it once here.
    metadata = getattr(calc_func, 'metadata', None) or {}
    output_names = tuple(metadata.get('output_names', ()))
    output_type = Struct({n: Float64 for n in output_names}) if output_names else Float64

    def decorator(func):
        name = func.__name__
        sig = inspect.signature(func)

        def wrapper(*args, **kwargs):
            binding = sig.bind(*args, **kwargs)
            binding.apply_defaults()
            params = dict(binding.arguments)

            self = params.pop('self', None)
            source = self._expr

            if force_struct:
                # Pack the selected columns so calc_func receives them together.
                source = pl.struct(source)

            def batch_func(batch):
                if force_struct:
                    # Undo the struct packing: calc_func gets named columns.
                    batch = batch.struct.unnest()

                output = calc_func(batch, **params)

                if not isinstance(output, tuple):
                    return pl.Series(output).fill_nan(None)

                # Multi-output: namedtuples carry their own field names; plain
                # tuples are matched positionally against the declared names.
                if hasattr(output, '_asdict'):
                    columns = output._asdict()
                elif output_names:
                    columns = dict(zip(output_names, output))
                else:
                    raise TypeError(
                        f"{name}: calculation returned a plain tuple "
                        "but declares no output_names"
                    )
                return pl.DataFrame(columns).fill_nan(None).to_struct()

            return source.map_batches(batch_func, return_dtype=output_type).alias(name)

        # Copy __name__, __qualname__, __module__, ... from the stub, then expose
        # the calculation's docstring and the stub's signature instead.
        update_wrapper(wrapper, func)
        wrapper.__doc__ = calc_func.__doc__
        wrapper.__signature__ = sig
        return wrapper

    return decorator



@pl.api.register_expr_namespace("ta")
class MyExtension:
    """Technical-analysis methods exposed on polars expressions as ``expr.ta``.

    Each method below is a signature-only stub: ``expression_method`` replaces
    it with a wrapper that runs the matching ``mintalib.core`` calculation on
    the wrapped expression.
    """

    _expr: pl.Expr  # expression the ``ta`` accessor was used on

    def __init__(self, expr: pl.Expr) -> None:
        self._expr = expr

    # Stub bodies never run; only the method name and signature are used.

    @expression_method(core.calc_sma)
    def sma(self, period: int = 20): ...

    @expression_method(core.calc_ema)
    def ema(self, period: int = 20): ...




  class MyExtension:


In [11]:
# Convert mintalib's pandas sample straight to polars, keeping its index as the `date` column.
prices = pl.from_pandas(sample_prices(), include_index=True)
prices


date,open,high,low,close,volume
datetime[ns],f64,f64,f64,f64,i64
1980-12-12 00:00:00,0.098943,0.099373,0.098943,0.098943,469033600
1980-12-15 00:00:00,0.094211,0.094211,0.093781,0.093781,175884800
1980-12-16 00:00:00,0.087328,0.087328,0.086898,0.086898,105728000
1980-12-17 00:00:00,0.089049,0.089479,0.089049,0.089049,86441600
1980-12-18 00:00:00,0.09163,0.092061,0.09163,0.09163,73449600
…,…,…,…,…,…
2024-10-15 00:00:00,233.610001,237.490005,232.369995,233.850006,64751400
2024-10-16 00:00:00,231.600006,232.119995,229.839996,231.779999,34082200
2024-10-17 00:00:00,233.429993,233.850006,230.520004,232.149994,32993800
2024-10-18 00:00:00,236.179993,236.179993,234.009995,235.0,46431500


In [12]:
# 50-period EMA of the close through the custom `ta` namespace; the column is named after the method.
prices.select(
    pl.col("close").ta.ema(period=50),
)

ema
f64
""
""
""
""
""
…
223.931848
224.239619
224.54983
224.95964
