In [2]:
"""
strategy_pipeline.py

Single-file NL -> DSL -> AST -> Python -> Backtest pipeline.

Requirements:
    pip install pandas numpy

Usage:
    python strategy_pipeline.py

It will run a sample demonstration at the bottom. Adjust `NL_INPUT` or
call functions programmatically to test other rules.

 Sakthiyogesh
"""

import re
import math
from dataclasses import dataclass
from typing import Any, List, Union, Tuple, Dict
import pandas as pd
import numpy as np
from datetime import datetime

# -----------------------
# Part A: Simple NL -> DSL
# -----------------------
def nl_to_dsl(nl: str) -> str:
    """
    Convert a small set of natural-language phrases into the DSL form.
    This is intentionally conservative: supports patterns like the examples in the PDF.
    """
    s = nl.strip().lower()
    # common phrase -> tokens
    # 1. buy/enter vs exit
    # We'll produce two-block DSL: ENTRY: ... \n EXIT: ...
    entry_clauses = []
    exit_clauses = []

    # patterns for entry
    # "buy when the close price is above the 20-day moving average and volume is above 1 million."
    m = re.search(r'(buy|enter).+close.+above.+(\d+)[- ]?day', s)
    if m:
        days = int(m.group(2))
        entry_clauses.append(f"close > sma(close,{days})")

    m2 = re.search(r'volume.+above\s+([\d,.kKmM]+)', s)
    if m2:
        val = parse_human_number(m2.group(1))
        entry_clauses.append(f"volume > {int(val)}")

    # cross above yesterday's high
    if "cross" in s and "yesterday" in s:
        entry_clauses.append("crosses_above(close, yesterday_high)")

    # RSI exit pattern
    m3 = re.search(r'rsi\s*\(?\s*(\d+)\s*\)?.*below\s*(\d+)', s)
    if m3:
        rsi_p = int(m3.group(1))
        val = int(m3.group(2))
        exit_clauses.append(f"rsi(close,{rsi_p}) < {val}")

    # fallback: try to detect simple comparisons
    if not entry_clauses and not exit_clauses:
        # try to extract "close > sma(close,20) and volume > 1000000"
        # naive convert numbers like 1M to 1000000
        s2 = s
        s2 = s2.replace("greater than", ">").replace("less than", "<")
        s2 = s2.replace("and", "AND").replace("or", "OR")
        # digits rewrite
        s2 = re.sub(r'([\d,.]+)\s*m\b', lambda mo: str(parse_human_number(mo.group(1) + 'm')), s2)
        # look for rsi pattern
        if "rsi" in s2 and "exit" in s2:
            exit_clauses.append(s2)
        elif "buy" in s2 or "entry" in s2:
            entry_clauses.append(s2)

    # build DSL
    dsl_lines = []
    if entry_clauses:
        dsl_lines.append("ENTRY:")
        dsl_lines.append(" AND ".join(entry_clauses))
    if exit_clauses:
        dsl_lines.append("EXIT:")
        dsl_lines.append(" AND ".join(exit_clauses))
    if not dsl_lines:
        raise ValueError("Could not map NL to DSL with current simple rules. Provide a clearer rule.")
    return "\n".join(dsl_lines)

def parse_human_number(s: str) -> float:
    s = s.strip().lower().replace(',', '')
    if s.endswith('m'):
        return float(s[:-1]) * 1_000_000
    if s.endswith('k'):
        return float(s[:-1]) * 1_000
    try:
        return float(s)
    except:
        return float(re.sub(r'[^\d.]', '', s))

# -----------------------
# Part B: DSL Tokenizer + Parser -> AST
# -----------------------

# We'll support:
# - IDENT (close, open, high, low, volume)
# - NUM (integer)
# - FUNCTION calls like sma(close,20), rsi(close,14)
# - operators: > < >= <= ==, AND, OR
# - special function: crosses_above(a,b)
# - parentheses

Token = Tuple[str, str]

token_specification = [
    ('NUMBER',   r'\d+(\.\d+)?'),   # Integer or decimal number
    ('FUNC',     r'[A-Za-z_][A-Za-z0-9_]*\s*\('),  # function name + '('
    ('ID',       r'[A-Za-z_][A-Za-z0-9_]*'), # identifiers
    ('GE',       r'>='),
    ('LE',       r'<='),
    ('GT',       r'>'),
    ('LT',       r'<'),
    ('EQ',       r'=='),
    ('LPAREN',   r'\('),
    ('RPAREN',   r'\)'),
    ('COMMA',    r','),
    ('AND',      r'\bAND\b'),
    ('OR',       r'\bOR\b'),
    ('WS',       r'\s+'),
    ('OTHER',    r'.'),
]

tok_regex = '|'.join('(?P<%s>%s)' % pair for pair in token_specification)
master_re = re.compile(tok_regex, flags=re.IGNORECASE)

def tokenize(s: str) -> List[Token]:
    pos = 0
    tokens = []
    while pos < len(s):
        m = master_re.match(s, pos)
        if not m:
            raise SyntaxError(f"Unexpected text at pos {pos}: {s[pos:]}")
        typ = m.lastgroup
        val = m.group(typ)
        pos = m.end()
        if typ == 'WS':
            continue
        if typ == 'FUNC':
            # FUNC includes trailing '('; normalize
            fname = val[:-1].strip()
            tokens.append(('FUNC', fname))
            tokens.append(('LPAREN','('))
            continue
        tokens.append((typ, val))
    return tokens

# AST node classes
@dataclass
class ASTNode:
    pass

@dataclass
class BinOp(ASTNode):
    op: str
    left: ASTNode
    right: ASTNode

@dataclass
class FuncCall(ASTNode):
    name: str
    args: List[ASTNode]

@dataclass
class Identifier(ASTNode):
    name: str

@dataclass
class Number(ASTNode):
    value: float

# Parser: recursive descent for expressions with precedence:
# comparisons (>,<,>=,<=,==) lowest than AND/OR? Actually boolean ops lower precedence than comparisons.
# We'll parse comparisons as nodes that return booleans; AND/OR combine them.

class Parser:
    def __init__(self, tokens: List[Token]):
        self.tokens = tokens
        self.pos = 0

    def peek(self) -> Token:
        return self.tokens[self.pos] if self.pos < len(self.tokens) else ('EOF','')

    def eat(self, expected_type: str = None) -> Token:
        tok = self.peek()
        if expected_type and tok[0].upper() != expected_type.upper():
            raise SyntaxError(f"Expected {expected_type} but got {tok}")
        self.pos += 1
        return tok

    def parse(self) -> ASTNode:
        return self.parse_or()

    def parse_or(self):
        node = self.parse_and()
        while True:
            tok = self.peek()
            if tok[0].upper() == 'OR':
                self.eat('OR')
                right = self.parse_and()
                node = BinOp('OR', node, right)
            else:
                break
        return node

    def parse_and(self):
        node = self.parse_comparison()
        while True:
            tok = self.peek()
            if tok[0].upper() == 'AND':
                self.eat('AND')
                right = self.parse_comparison()
                node = BinOp('AND', node, right)
            else:
                break
        return node

    def parse_comparison(self):
        left = self.parse_term()
        tok = self.peek()
        if tok[0] in ('GT','LT','GE','LE','EQ'):
            op = self.eat()[0]
            right = self.parse_term()
            return BinOp(op, left, right)
        return left

    def parse_term(self):
        tok = self.peek()
        if tok[0] == 'LPAREN':
            self.eat('LPAREN')
            node = self.parse()
            self.eat('RPAREN')
            return node
        if tok[0] == 'FUNC':
            name = self.eat('FUNC')[1]
            args = self.parse_arglist()
            return FuncCall(name.lower(), args)
        if tok[0] == 'ID':
            val = self.eat('ID')[1]
            return Identifier(val.lower())
        if tok[0] == 'NUMBER':
            v = float(self.eat('NUMBER')[1])
            return Number(v)
        # special support for "yesterday_high" token or similar IDs
        raise SyntaxError(f"Unexpected token {tok} at position {self.pos}")

    def parse_arglist(self) -> List[ASTNode]:
        args = []
        # we are after LPAREN
        # eat '(' already consumed by tokenizer as separate token (we did)
        # parse until RPAREN
        # NOTE: parser created FUNC token then left LPAREN token; here we assume current token is after that LPAREN
        # But our tokenizer emitted LPAREN already; ensure to consume if present
        # We'll expect '(' already consumed by parse_term by previous code; but for simplicity:
        # if current token is RPAREN, return []
        # Actually parse_term called FUNC then parse_arglist; after eating FUNC we left LPAREN in stream and
        # didn't consume it; so we need to eat LPAREN now:
        if self.peek()[0] == 'LPAREN':
            self.eat('LPAREN')
        while True:
            if self.peek()[0] == 'RPAREN':
                self.eat('RPAREN')
                break
            arg = self.parse()
            args.append(arg)
            if self.peek()[0] == 'COMMA':
                self.eat('COMMA')
                continue
            elif self.peek()[0] == 'RPAREN':
                self.eat('RPAREN')
                break
            else:
                # If expression doesn't have comma but next is RPAREN eventually, loop will break
                pass
        return args

def parse_dsl_block(text: str) -> Dict[str,str]:
    """
    Splits DSL into ENTRY/EXIT blocks and returns dict {'entry': '...', 'exit': '...'}
    """
    lines = [ln.strip() for ln in text.splitlines() if ln.strip()]
    cur = None
    blocks = {}
    for ln in lines:
        up = ln.upper()
        if up.startswith("ENTRY:"):
            cur = 'entry'
            rest = ln[len("ENTRY:"):].strip()
            blocks[cur] = rest if rest else ""
            continue
        if up.startswith("EXIT:"):
            cur = 'exit'
            rest = ln[len("EXIT:"):].strip()
            blocks[cur] = rest if rest else ""
            continue
        if cur is None:
            continue
        # append
        blocks[cur] = blocks.get(cur, "") + " " + ln
    return blocks

def build_ast_from_text(expr_text: str) -> ASTNode:
    tokens = tokenize(expr_text)
    p = Parser(tokens)
    ast = p.parse()
    return ast

# -----------------------
# Part C: AST -> pandas expression (code generator)
# -----------------------

# We'll generate a Python function that, given df, returns signals DataFrame with boolean columns 'entry' and 'exit'.
# To keep it readable, we will create helper functions and then evaluate expressions.

def ast_to_python_expr(ast: ASTNode, temp_vars: Dict[str,int]=None) -> str:
    """
    Convert AST into a python expression string that returns a pandas Series (boolean or numeric).
    Uses df[...] for series and helper functions (sma, rsi, cross_above).
    """
    if temp_vars is None:
        temp_vars = {'v':0}
    if isinstance(ast, Number):
        return repr(ast.value)
    if isinstance(ast, Identifier):
        name = ast.name.lower()
        # map some aliases
        if name == 'yesterday_high':
            return "df['high'].shift(1)"
        if name in ('open','high','low','close','volume'):
            return f"df['{name}']"
        # allow 'close[-1]' style? not here
        return f"df['{name}']"
    if isinstance(ast, FuncCall):
        n = ast.name.lower()
        args_expr = [ast_to_python_expr(a, temp_vars) for a in ast.args]
        if n == 'sma':
            # sma(series, period)
            return f"sma({args_expr[0]}, int({args_expr[1]}))"
        if n == 'rsi':
            return f"rsi({args_expr[0]}, int({args_expr[1]}))"
        if n in ('crosses_above','cross_above','cross_above'):
            return f"cross_above({args_expr[0]},{args_expr[1]})"
        # generic function fallback (not expected)
        return f"{n}({', '.join(args_expr)})"
    if isinstance(ast, BinOp):
        op = ast.op
        left = ast_to_python_expr(ast.left, temp_vars)
        right = ast_to_python_expr(ast.right, temp_vars)
        # map op tokens to python/pandas
        if op == 'AND':
            return f"({left}) & ({right})"
        if op == 'OR':
            return f"({left}) | ({right})"
        if op == 'GT' or op == '>':
            return f"({left}) > ({right})"
        if op == 'LT' or op == '<':
            return f"({left}) < ({right})"
        if op == 'GE' or op == '>=':
            return f"({left}) >= ({right})"
        if op == 'LE' or op == '<=':
            return f"({left}) <= ({right})"
        if op == 'EQ' or op == '==':
            return f"({left}) == ({right})"
        # catch-case: sometimes tokens are actual symbols
        if op in ('>','<','>=','<=','=='):
            return f"({left}) {op} ({right})"
        # fallback
        return f"({left}) {op} ({right})"

def generate_signal_function(entry_ast: Union[ASTNode,None], exit_ast: Union[ASTNode,None]) -> str:
    """
    Returns a string with a Python function definition 'generate_signals(df)' which uses helper functions.
    """
    body_lines = []
    body_lines.append("def generate_signals(df):")
    body_lines.append("    # expects df with columns: open,high,low,close,volume and datetime index or 'date' column")
    body_lines.append("    signals = pd.DataFrame(index=df.index)")
    if entry_ast:
        expr = ast_to_python_expr(entry_ast)
        body_lines.append(f"    signals['entry'] = {expr}")
    else:
        body_lines.append("    signals['entry'] = False")
    if exit_ast:
        expr2 = ast_to_python_expr(exit_ast)
        body_lines.append(f"    signals['exit'] = {expr2}")
    else:
        body_lines.append("    signals['exit'] = False")
    body_lines.append("    # ensure booleans")
    body_lines.append("    signals['entry'] = signals['entry'].fillna(False).astype(bool)")
    body_lines.append("    signals['exit'] = signals['exit'].fillna(False).astype(bool)")
    body_lines.append("    return signals")
    return "\n".join("    " + ln if i>0 else ln for i,ln in enumerate(body_lines))

# -----------------------
# Part D: Indicator helpers and cross detection
# -----------------------

def sma(series: pd.Series, period: int) -> pd.Series:
    period = max(1, int(period))   # Prevent 0 or negative values
    return series.rolling(period, min_periods=1).mean()

def rsi(series: pd.Series, period: int = 14) -> pd.Series:
    # classic RSI (Wilder)
    delta = series.diff()
    up = delta.clip(lower=0.0)
    down = -1 * delta.clip(upper=0.0)
    # Wilder smoothing
    ma_up = up.ewm(alpha=1/period, adjust=False, min_periods=period).mean()
    ma_down = down.ewm(alpha=1/period, adjust=False, min_periods=period).mean()
    rs = ma_up / (ma_down + 1e-12)
    rsi = 100 - (100 / (1 + rs))
    return rsi

def cross_above(s1: pd.Series, s2: pd.Series) -> pd.Series:
    prev = s1.shift(1)
    prev2 = s2.shift(1)
    return (s1 > s2) & (prev <= prev2)

# -----------------------
# Part E: Backtest simulator
# -----------------------

@dataclass
class Trade:
    entry_idx: Any
    exit_idx: Any
    entry_price: float
    exit_price: float
    pnl: float
    ret: float

def run_backtest(df: pd.DataFrame, signals: pd.DataFrame, verbose: bool=False) -> Dict:
    """
    Very simple backtester:
    - Start flat
    - Enter at next bar's open when signals['entry'] becomes True (for simplicity we enter at close of same bar)
    - Exit at next bar's close when signals['exit'] True
    """
    pos = 0
    trades: List[Trade] = []
    entry_price = None
    entry_idx = None

    for idx in df.index:
        row_entry = signals.loc[idx, 'entry']
        row_exit  = signals.loc[idx, 'exit']
        price = float(df.loc[idx, 'close'])

        if pos == 0 and row_entry:
            # enter long at current close (simple)
            pos = 1
            entry_price = price
            entry_idx = idx
            if verbose:
                print(f"ENTER at {idx} price {price}")
            continue  # don't exit same bar

        if pos == 1 and row_exit:
            exit_price = price
            exit_idx = idx
            pnl = exit_price - entry_price
            ret = pnl / entry_price if entry_price != 0 else 0.0
            trades.append(Trade(entry_idx, exit_idx, entry_price, exit_price, pnl, ret))
            pos = 0
            entry_price = None
            entry_idx = None
            if verbose:
                print(f"EXIT at {idx} price {exit_price} pnl {pnl}")
            continue

    # if still in position at the end, close at final close
    if pos == 1:
        final_price = float(df.iloc[-1]['close'])
        exit_idx = df.index[-1]
        pnl = final_price - entry_price
        ret = pnl / entry_price
        trades.append(Trade(entry_idx, exit_idx, entry_price, final_price, pnl, ret))

    # metrics
    total_pnl = sum(t.pnl for t in trades)
    total_return = (np.prod([1 + t.ret for t in trades]) - 1) if trades else 0.0
    equity = [0.0]
    # simple equity curve build: start at 1.0, apply returns sequentially
    eq = 1.0
    eq_curve = [eq]
    for t in trades:
        eq = eq * (1 + t.ret)
        eq_curve.append(eq)
    # max drawdown (from equity curve)
    eq_arr = np.array(eq_curve)
    peak = np.maximum.accumulate(eq_arr)
    dd = (eq_arr - peak) / peak
    max_dd = float(dd.min()) if len(dd)>0 else 0.0

    results = {
        'trades': trades,
        'total_pnl': total_pnl,
        'total_return': total_return,
        'num_trades': len(trades),
        'max_drawdown': max_dd,
    }
    return results

# Pretty print trades
def print_backtest_report(results):
    print("=== Backtest Report ===")
    print(f"Trades: {results['num_trades']}")
    print(f"Total PnL: {results['total_pnl']:.4f}")
    print(f"Total Return: {results['total_return']*100:.2f}%")
    print(f"Max Drawdown: {results['max_drawdown']*100:.2f}%")
    print("Trade Log:")
    for t in results['trades']:
        print(f"- Enter: {t.entry_idx} @{t.entry_price}  Exit: {t.exit_idx} @{t.exit_price}  PnL: {t.pnl:.4f}  Return: {t.ret*100:.2f}%")

# -----------------------
# Part F: End-to-End demo
# -----------------------

SAMPLE_DATA = """date,open,high,low,close,volume
2023-01-01,100,105,99,103,900000
2023-01-02,103,108,101,107,1200000
2023-01-03,107,110,106,109,1300000
2023-01-04,109,112,108,111,1500000
2023-01-05,111,115,110,114,2000000
2023-01-06,114,117,113,116,1800000
2023-01-07,116,119,115,118,1600000
2023-01-08,118,121,117,120,1700000
2023-01-09,120,122,119,121,900000
2023-01-10,121,125,120,124,2000000
"""

def load_sample_df(csv_text: str) -> pd.DataFrame:
    from io import StringIO
    df = pd.read_csv(StringIO(csv_text), parse_dates=['date'])
    df = df.set_index('date')
    return df

def end_to_end(nl_input: str, df: pd.DataFrame):
    print("Natural Language Input:")
    print(nl_input)
    dsl = nl_to_dsl(nl_input)
    print("\nGenerated DSL:")
    print(dsl)
    blocks = parse_dsl_block(dsl)
    entry_text = blocks.get('entry','').strip()
    exit_text = blocks.get('exit','').strip()
    print("\nParsed Blocks:")
    print("ENTRY:", entry_text)
    print("EXIT: ", exit_text)

    entry_ast = build_ast_from_text(entry_text) if entry_text else None
    exit_ast  = build_ast_from_text(exit_text) if exit_text else None

    # show ASTs (simple)
    print("\nAST (entry):", entry_ast)
    print("AST (exit):", exit_ast)

    # generate python function text
    func_code = generate_signal_function(entry_ast, exit_ast)
    # build full runtime environment
    runtime_globals = {
        'pd': pd,
        'np': np,
        'sma': sma,
        'rsi': rsi,
        'cross_above': cross_above,
        'df': df
    }
    # exec function
    exec(func_code, runtime_globals)
    generate_signals = runtime_globals['generate_signals']
    signals = generate_signals(df)

    print("\nSignals head:")
    print(signals.head(10))

    results = run_backtest(df, signals)
    print()
    print_backtest_report(results)
    return {
        'dsl': dsl,
        'entry_ast': entry_ast,
        'exit_ast': exit_ast,
        'signals': signals,
        'results': results,
        'func_code': func_code
    }

# If run as script, demo:
if __name__ == "__main__":
    # Example natural language input
    NL_INPUT = "Buy when the close price is above the 20-day moving average and volume is above 1M. Exit when RSI(14) is below 30."
    df = load_sample_df(SAMPLE_DATA)
    out = end_to_end(NL_INPUT, df)


Natural Language Input:
Buy when the close price is above the 20-day moving average and volume is above 1M. Exit when RSI(14) is below 30.

Generated DSL:
ENTRY:
close > sma(close,0) AND volume > 1
EXIT:
rsi(close,14) < 30

Parsed Blocks:
ENTRY: close > sma(close,0) AND volume > 1
EXIT:  rsi(close,14) < 30

AST (entry): BinOp(op='GT', left=Identifier(name='close'), right=FuncCall(name='sma', args=[Identifier(name='close'), Number(value=0.0)]))
AST (exit): BinOp(op='LT', left=FuncCall(name='rsi', args=[Identifier(name='close'), Number(value=14.0)]), right=Number(value=30.0))

Signals head:
            entry   exit
date                    
2023-01-01  False  False
2023-01-02  False  False
2023-01-03  False  False
2023-01-04  False  False
2023-01-05  False  False
2023-01-06  False  False
2023-01-07  False  False
2023-01-08  False  False
2023-01-09  False  False
2023-01-10  False  False

=== Backtest Report ===
Trades: 0
Total PnL: 0.0000
Total Return: 0.00%
Max Drawdown: 0.00%
Trade Log:
