In [200]:
import pandas as pd
import polars as pl

from mintalib.samples import sample_prices
from mintalib.indicators import SMA, EMA, MACD, RSI


In [201]:
prices = sample_prices()
prices

Unnamed: 0_level_0,open,high,low,close,volume
date,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
1980-12-12,0.098943,0.099373,0.098943,0.098943,469033600
1980-12-15,0.094211,0.094211,0.093781,0.093781,175884800
1980-12-16,0.087328,0.087328,0.086898,0.086898,105728000
1980-12-17,0.089049,0.089479,0.089049,0.089049,86441600
1980-12-18,0.091630,0.092061,0.091630,0.091630,73449600
...,...,...,...,...,...
2024-10-15,233.610001,237.490005,232.369995,233.850006,64751400
2024-10-16,231.600006,232.119995,229.839996,231.779999,34082200
2024-10-17,233.429993,233.850006,230.520004,232.149994,32993800
2024-10-18,236.179993,236.179993,234.009995,235.000000,46431500


In [202]:
plprices = pl.from_dataframe(prices.reset_index())
plprices



date,open,high,low,close,volume
datetime[ns],f64,f64,f64,f64,i64
1980-12-12 00:00:00,0.098943,0.099373,0.098943,0.098943,469033600
1980-12-15 00:00:00,0.094211,0.094211,0.093781,0.093781,175884800
1980-12-16 00:00:00,0.087328,0.087328,0.086898,0.086898,105728000
1980-12-17 00:00:00,0.089049,0.089479,0.089049,0.089049,86441600
1980-12-18 00:00:00,0.09163,0.092061,0.09163,0.09163,73449600
…,…,…,…,…,…
2024-10-15 00:00:00,233.610001,237.490005,232.369995,233.850006,64751400
2024-10-16 00:00:00,231.600006,232.119995,229.839996,231.779999,34082200
2024-10-17 00:00:00,233.429993,233.850006,230.520004,232.149994,32993800
2024-10-18 00:00:00,236.179993,236.179993,234.009995,235.0,46431500


In [203]:
import functools

from abc import ABCMeta, abstractmethod


class Study(metaclass=ABCMeta):
    """callable/chainable with process method and composition"""

    __pandas_priority__ = 5000

    @abstractmethod
    def __call__(self, prices): ...

    def __or__(self, other):
        """pipe into callable"""

        if not callable(other):
            return NotImplemented

        return ChainedStudy(self, other)

    def __ror__(self, other):
        """apply to dataframe"""

        if not hasattr(other, "pipe"):
            return NotImplemented

        return other.pipe(self)

    def pipe(self, func, **kwargs):
        """pipe into callable with optional arguments"""

        if kwargs:
            func = functools.partial(func, **kwargs)

        return self | func


class ChainedStudy(Study):
    """chain of callables/studies"""

    funcs: tuple = ()

    def __init__(self, *funcs):
        for func in funcs:
            if not callable(func):
                raise TypeError(f"Argument {func!r} is not callable!")
        self.funcs = funcs

    def __repr__(self):
        return " | ".join(repr(fn) for fn in self.funcs)

    def __call__(self, prices):
        result = prices
        for func in self.funcs:
            if result is None:
                return
            result = func(result)

        return result

    def __or__(self, other):
        if not callable(other):
            return NotImplemented

        funcs = self.funcs + (other,)
        return self.__class__(*funcs)




In [213]:
from collections import Counter


class Update(Study):
    """Update Study"""

    items: tuple = ()
    select: bool = False

    @staticmethod
    def check_args(*args, **kwargs):
        args = args + tuple(kwargs.values())
        for arg in args:
            if not callable(arg):
                raise ValueError("Callable expected!")

    @staticmethod    
    def get_columns(result):
        if hasattr(result, 'columns'):
            return list(result.columns)
        elif hasattr(result, 'name'):
            name = result.name
            if name is not None:
                return [result.name]
        return ()


    def __init__(self, *args, **kwargs):
        self.check_args(*args, **kwargs)
        if kwargs:
            args = args + (kwargs,)
        self.items = args


    def __repr__(self):
        cname = self.__class__.__name__
        args = ", ".join(repr(item) for item in self.items)
        return f"{cname}({args})"


    def __call__(self, prices):
        if not hasattr(prices, 'columns'):
            raise ValueError("DataFrame expected!")

        target = getattr(prices, '__module__', None).partition('.')[0]

        if target == "pandas":
            return self.apply_pandas(prices)
        
        if target == "polars":
            return self.apply_polars(prices)
        
        raise ValueError(f"Unsupported DataFrame type: {target}")
       
    
    def apply_pandas(self, prices):
        counter = Counter()
        for item in self.items:
            if callable(item):
                result = item(prices)
                columns = self.get_columns(result)
                if columns:
                    counter.update(columns)
                else:
                    raise ValueError("Unnamed result in positional args!")
                prices = prices.join(result)
            elif isinstance(item, dict):
                counter.update(item.keys())
                prices = prices.assign(**item)
            else:
                tname = type(item).__name__
                raise ValueError(f"Unsupported item type {tname}!")

        if self.select:
            columns = list(counter)
            prices = prices.filter(columns)
            
        return prices

    def apply_polars(self, prices):
        counter = Counter()
        for item in self.items:
            if callable(item):
                result = item(prices)
                columns = self.get_columns(result)
                if columns:
                    counter.update(columns)
                else:
                    raise ValueError("Unnamed result in positional args!")
                prices = prices.with_columns(result)
            elif isinstance(item, dict):
                counter.update(item.keys())
                for k, v in item.items():
                    result = v(prices)
                    if hasattr(result, 'columns'):
                        raise ValueError("DataFrame result in keyword args!")
                    kwargs = {k: result}
                    prices = prices.with_columns(**kwargs)
            else:
                tname = type(item).__name__
                raise ValueError(f"Unsupported item type {tname}!")
        if self.select:
            columns = list(counter)
            prices = prices.select(*columns)
        return prices


class Select(Update):
    """Select Study"""

    select: bool = True



In [214]:
study = Update(MACD(), sma = SMA(20), sma2 = SMA(20, item="sma") )
study


Update(calc_macd(12, 26, 9), {'sma': calc_sma(20), 'sma2': calc_sma(20)})

In [215]:
study(prices)


Unnamed: 0_level_0,open,high,low,close,volume,macd,macdsignal,macdhist,sma,sma2
date,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
1980-12-12,0.098943,0.099373,0.098943,0.098943,469033600,,,,,
1980-12-15,0.094211,0.094211,0.093781,0.093781,175884800,,,,,
1980-12-16,0.087328,0.087328,0.086898,0.086898,105728000,,,,,
1980-12-17,0.089049,0.089479,0.089049,0.089049,86441600,,,,,
1980-12-18,0.091630,0.092061,0.091630,0.091630,73449600,,,,,
...,...,...,...,...,...,...,...,...,...,...
2024-10-15,233.610001,237.490005,232.369995,233.850006,64751400,1.815958,1.313965,0.501993,227.524000,224.533800
2024-10-16,231.600006,232.119995,229.839996,231.779999,34082200,1.941114,1.439395,0.501719,228.078500,224.768150
2024-10-17,233.429993,233.850006,230.520004,232.149994,32993800,2.046565,1.560829,0.485736,228.242500,225.004525
2024-10-18,236.179993,236.179993,234.009995,235.000000,46431500,2.333211,1.715305,0.617906,228.582500,225.248725


In [216]:
study(plprices)

date,open,high,low,close,volume,macd,macdsignal,macdhist,sma,sma2
datetime[ns],f64,f64,f64,f64,i64,f64,f64,f64,f64,f64
1980-12-12 00:00:00,0.098943,0.099373,0.098943,0.098943,469033600,,,,,
1980-12-15 00:00:00,0.094211,0.094211,0.093781,0.093781,175884800,,,,,
1980-12-16 00:00:00,0.087328,0.087328,0.086898,0.086898,105728000,,,,,
1980-12-17 00:00:00,0.089049,0.089479,0.089049,0.089049,86441600,,,,,
1980-12-18 00:00:00,0.09163,0.092061,0.09163,0.09163,73449600,,,,,
…,…,…,…,…,…,…,…,…,…,…
2024-10-15 00:00:00,233.610001,237.490005,232.369995,233.850006,64751400,1.815958,1.313965,0.501993,227.524,224.5338
2024-10-16 00:00:00,231.600006,232.119995,229.839996,231.779999,34082200,1.941114,1.439395,0.501719,228.0785,224.76815
2024-10-17 00:00:00,233.429993,233.850006,230.520004,232.149994,32993800,2.046565,1.560829,0.485736,228.2425,225.004525
2024-10-18 00:00:00,236.179993,236.179993,234.009995,235.0,46431500,2.333211,1.715305,0.617906,228.5825,225.248725


In [None]:

study = Select(MACD(), sma = SMA(20), sma2 = SMA(20, item="sma") )
study


In [209]:
study(prices)

Unnamed: 0_level_0,macd,macdsignal,macdhist,sma,sma2
date,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
1980-12-12,,,,,
1980-12-15,,,,,
1980-12-16,,,,,
1980-12-17,,,,,
1980-12-18,,,,,
...,...,...,...,...,...
2024-10-15,1.815958,1.313965,0.501993,227.524000,224.533800
2024-10-16,1.941114,1.439395,0.501719,228.078500,224.768150
2024-10-17,2.046565,1.560829,0.485736,228.242500,225.004525
2024-10-18,2.333211,1.715305,0.617906,228.582500,225.248725


In [210]:
study(prices)

Unnamed: 0_level_0,macd,macdsignal,macdhist,sma,sma2
date,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
1980-12-12,,,,,
1980-12-15,,,,,
1980-12-16,,,,,
1980-12-17,,,,,
1980-12-18,,,,,
...,...,...,...,...,...
2024-10-15,1.815958,1.313965,0.501993,227.524000,224.533800
2024-10-16,1.941114,1.439395,0.501719,228.078500,224.768150
2024-10-17,2.046565,1.560829,0.485736,228.242500,225.004525
2024-10-18,2.333211,1.715305,0.617906,228.582500,225.248725


In [211]:
prices | study

Unnamed: 0_level_0,macd,macdsignal,macdhist,sma,sma2
date,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
1980-12-12,,,,,
1980-12-15,,,,,
1980-12-16,,,,,
1980-12-17,,,,,
1980-12-18,,,,,
...,...,...,...,...,...
2024-10-15,1.815958,1.313965,0.501993,227.524000,224.533800
2024-10-16,1.941114,1.439395,0.501719,228.078500,224.768150
2024-10-17,2.046565,1.560829,0.485736,228.242500,225.004525
2024-10-18,2.333211,1.715305,0.617906,228.582500,225.248725


macd,macdsignal,macdhist,sma,sma2
f64,f64,f64,f64,f64
,,,,
,,,,
,,,,
,,,,
,,,,
…,…,…,…,…
1.815958,1.313965,0.501993,227.524,224.5338
1.941114,1.439395,0.501719,228.0785,224.76815
2.046565,1.560829,0.485736,228.2425,225.004525
2.333211,1.715305,0.617906,228.5825,225.248725
