In [66]:
import pandas as pd
import math
import numpy as np

# Global counter for progress prints.
state_count = 0
STATE_PRINT_INTERVAL = 10000  # Print progress every 10,000 states

# --- Helper Functions with Caching ---

def get_ask_cost_cached(group, q):
    cache = group.setdefault("ask_cache", {})
    if q in cache:
        return cache[q]
    ask_orders = group["ask_orders"]
    remaining = q
    total_cost = 0
    for order in ask_orders:
        if remaining <= 0:
            break
        trade = min(order["volume"], remaining)
        total_cost += trade * order["price"]
        remaining -= trade
    value = total_cost if remaining == 0 else None
    cache[q] = value
    return value

def get_bid_revenue_cached(group, q):
    cache = group.setdefault("bid_cache", {})
    if q in cache:
        return cache[q]
    bid_orders = group["bid_orders"]
    remaining = q
    total_rev = 0
    for order in bid_orders:
        if remaining <= 0:
            break
        trade = min(order["volume"], remaining)
        total_rev += trade * order["price"]
        remaining -= trade
    value = total_rev if remaining == 0 else None
    cache[q] = value
    return value

def group_orders_by_timestamp(df):
    groups = []
    for ts, group_df in df.groupby("timestamp"):
        orders = group_df.to_dict("records")
        ask_orders = sorted([o for o in orders if o["side"].lower() == "ask"],
                             key=lambda x: x["price"])
        bid_orders = sorted([o for o in orders if o["side"].lower() == "bid"],
                             key=lambda x: -x["price"])
        groups.append({
            "timestamp": ts,
            "ask_orders": ask_orders,
            "bid_orders": bid_orders,
            "total_ask": sum(o["volume"] for o in ask_orders),
            "total_bid": sum(o["volume"] for o in bid_orders),
            "ask_cache": {},
            "bid_cache": {}
        })
    groups.sort(key=lambda g: g["timestamp"])
    return groups

# --- Candidate Volumes by Price ---
def candidate_volumes(orders):
    """
    Given a list of orders (assumed sorted either ascending or descending by price),
    returns a sorted list of cumulative volumes aggregated at each distinct price level.
    """
    volumes = []
    cum = 0
    last_price = None
    for order in orders:
        cum += order["volume"]
        if order["price"] != last_price:
            volumes.append(cum)
            last_price = order["price"]
    if cum not in volumes:
        volumes.append(cum)
    return sorted(volumes)

# --- Candidate Set Function (based on aggregated volumes) ---
def candidate_set(cur):
    """
    Returns a list of candidate (x, y) decision pairs for the current group.
    x is the volume to take from the ask side (buy) and y from the bid side (sell).

    The candidate set always includes:
      - The full arbitrage candidate: (min(total_ask, total_bid), min(total_ask, total_bid))
      - Extreme candidates: (total_ask, 0) and (0, total_bid) and (0,0)
      - In addition, for the ask side, it adds candidate x values from the cumulative volumes
        (aggregated by distinct ask prices) and similarly for the bid side.
    """
    avail_buy = cur["total_ask"]
    avail_sell = cur["total_bid"]
    candidates = set()
    # Full arbitrage candidate:
    arb = min(avail_buy, avail_sell)
    candidates.add((arb, arb))
    # Extreme candidates:
    candidates.add((avail_buy, 0))
    candidates.add((0, avail_sell))
    candidates.add((0, 0))
    # Candidate volumes from ask and bid orders:
    Vx = candidate_volumes(cur["ask_orders"])
    Vy = candidate_volumes(cur["bid_orders"])
    for x in Vx:
        for y in Vy:
            if x <= avail_buy and y <= avail_sell:
                candidates.add((x, y))
    # Also add some fraction-based candidates:
    for fraction in [0.25, 0.5, 0.75]:
        x_val = int(round(avail_buy * fraction))
        y_val = int(round(avail_sell * fraction))
        candidates.add((x_val, 0))
        candidates.add((0, y_val))
        candidates.add((x_val, y_val))
    return list(candidates)

# --- Processing a Group's Decision (with Coarse Cost Basis) ---
def process_decision(pos, c, x, y, group):
    """
    Given current state (pos, cost basis c) and a decision (x units buy, y units sell)
    in the given group, returns (profit, new_pos, new_c). New cost basis is rounded to an integer.
    If pos ≠ 0 then c must be provided; otherwise, returns None.
    """
    ask_cost = lambda q: get_ask_cost_cached(group, q)
    bid_rev = lambda q: get_bid_revenue_cached(group, q)
    new_pos = pos + x - y
    profit = 0
    new_c = None

    if pos > 0:
        if c is None:
            return None
        close_qty = min(y, pos)
        if close_qty > 0:
            rev = bid_rev(close_qty)
            if rev is None:
                return None
            profit += rev - close_qty * c
        leftover = pos - close_qty
        extra_sell = max(0, y - pos)
        if leftover + x >= extra_sell:
            new_pos = leftover + x - extra_sell
            total_long = leftover + x
            if x > 0:
                cost_new = ask_cost(x)
                if cost_new is None:
                    return None
            else:
                cost_new = 0
            new_c = ((leftover * c) + cost_new) / total_long if total_long > 0 else None
        else:
            extra = extra_sell - (leftover + x)
            new_pos = -extra
            rev_extra = bid_rev(extra)
            if extra <= 0 or rev_extra is None:
                return None
            new_c = rev_extra / extra
        if new_c is not None:
            new_c = round(new_c)
        return (profit, new_pos, new_c)

    elif pos < 0:
        if c is None:
            return None
        close_qty = min(x, -pos)
        if close_qty > 0:
            cost_new = ask_cost(close_qty)
            if cost_new is None:
                return None
            profit += close_qty * c - cost_new
        leftover = (-pos) - close_qty
        extra_buy = max(0, x - (-pos))
        if leftover + extra_buy >= y:
            new_pos = -(leftover + extra_buy - y)
            total_short = leftover + y
            if y > 0:
                rev_sell = bid_rev(y)
                if rev_sell is None:
                    return None
            else:
                rev_sell = 0
            new_c = ((leftover * c) + rev_sell) / (leftover + y) if (leftover + y) > 0 else None
        else:
            extra = y - (leftover + extra_buy)
            new_pos = extra
            cost_extra = ask_cost(extra)
            if extra <= 0 or cost_extra is None:
                return None
            new_c = cost_extra / extra
        if new_c is not None:
            new_c = round(new_c)
        return (profit, new_pos, new_c)

    else:
        # Starting flat.
        arbitrage_qty = min(x, y)
        if arbitrage_qty > 0:
            cost_arb = ask_cost(arbitrage_qty)
            rev_arb = bid_rev(arbitrage_qty)
            if cost_arb is None or rev_arb is None:
                return None
            profit += rev_arb - cost_arb
        if x > y:
            extra = x - y
            total_cost = ask_cost(x)
            base_cost = ask_cost(arbitrage_qty) if arbitrage_qty > 0 else 0
            if (x > 0 and total_cost is None) or (arbitrage_qty > 0 and base_cost is None):
                return None
            new_c = (total_cost - base_cost) / extra if extra > 0 else None
            new_pos = extra
        elif y > x:
            extra = y - x
            total_rev = bid_rev(y)
            base_rev = bid_rev(arbitrage_qty) if arbitrage_qty > 0 else 0
            if (y > 0 and total_rev is None) or (arbitrage_qty > 0 and base_rev is None):
                return None
            new_c = (total_rev - base_rev) / extra if extra > 0 else None
            new_pos = -extra
        else:
            new_pos = 0
            new_c = None
        if new_c is not None:
            new_c = round(new_c)
        return (profit, new_pos, new_c)

# --- Simulate Individual Trades from a Group ---
def simulate_trades(group, x, y):
    """
    Simulates the individual trades executed for a decision in a given group:
      - Fills x units on the ask side (producing buy trades with positive quantity)
      - Fills y units on the bid side (producing sell trades with negative quantity)
    Returns a list of individual trade records (with 'timestamp', 'price', and 'quantity').
    """
    trades = []
    remaining_buy = x
    for order in group["ask_orders"]:
        if remaining_buy <= 0:
            break
        trade_qty = min(order["volume"], remaining_buy)
        trades.append({
            "timestamp": group["timestamp"],
            "price": order["price"],
            "quantity": trade_qty  # positive for buy
        })
        remaining_buy -= trade_qty
    remaining_sell = y
    for order in group["bid_orders"]:
        if remaining_sell <= 0:
            break
        trade_qty = min(order["volume"], remaining_sell)
        trades.append({
            "timestamp": group["timestamp"],
            "price": order["price"],
            "quantity": -trade_qty  # negative for sell
        })
        remaining_sell -= trade_qty
    return trades

# --- Faster Approximate DP with Enriched Candidate Set ---
def optimize_pnl_for_product(df, max_position):
    """
    Uses an approximate DP which leverages:
      - Coarse cost basis rounding (to an integer)
      - A candidate set that is enriched with cumulative volumes aggregated per price level,
        as well as fixed extreme and fraction-based candidates.
      - It returns the approximate maximum realized profit and a flat list of individual trade records.
    """
    groups = group_orders_by_timestamp(df)
    n_groups = len(groups)
    memo = {}
    decision = {}

    global state_count
    state_count = 0

    def dp(g, pos, c):
        global state_count
        state_count += 1
        if state_count % STATE_PRINT_INTERVAL == 0:
            print(f"Processed {state_count} states; group: {g}, pos: {pos}, cost: {c}")
        if g == n_groups:
            return 0
        key = (g, pos, None if c is None else round(c, 0))
        if key in memo:
            return memo[key]
        cur = groups[g]
        best = -math.inf
        best_decision = (0, 0)
        cand = candidate_set(cur)
        for x, y in cand:
            new_pos = pos + x - y
            if abs(new_pos) > max_position:
                continue
            res = process_decision(pos, c, x, y, cur)
            if res is None:
                continue
            profit_current, new_state_pos, new_state_c = res
            candidate_profit = profit_current + dp(g + 1, new_state_pos, new_state_c)
            if candidate_profit > best:
                best = candidate_profit
                best_decision = (x, y)
        memo[key] = best
        decision[key] = best_decision
        return best

    max_profit = dp(0, 0, None)

    all_trades = []
    def reconstruct(g, pos, c):
        if g == n_groups:
            return
        key = (g, pos, None if c is None else round(c, 0))
        x, y = decision.get(key, (0, 0))
        cur = groups[g]
        # Simulate and record all individual trades for this group.
        if (x, y) != (0, 0):
            trades = simulate_trades(cur, x, y)
            all_trades.extend(trades)
        res = process_decision(pos, c, x, y, cur)
        if res is None:
            new_pos, new_c = pos, c
        else:
            _, new_pos, new_c = res
        reconstruct(g + 1, new_pos, new_c)
    reconstruct(0, 0, None)
    return max_profit, all_trades

# --- Grouping by Product and Running the Approximate DP ---
def optimize_orderbook(csv_file, max_position):
    df = pd.read_csv(csv_file)
    results = {}
    for product, group in df.groupby("product"):
        df_prod = group.sort_values("timestamp").reset_index(drop=True)
        profit, trades = optimize_pnl_for_product(df_prod, max_position)
        results[product] = {"profit": profit, "trades": trades}
    return results


# --- Example Usage ---
if __name__ == '__main__':
    csv_file = 'combined_book_unk_kelp_original.csv'  # adjust as needed
    max_position = 50
    results = optimize_orderbook(csv_file, max_position)
    for product, info in results.items():
        print(f"Product: {product}")
        print("Approximate final realized profit:", info["profit"])
        print("Individual trade records:")
        for trade in info["trades"]:
            print(trade)
        print("-" * 40)


Processed 10000 states; group: 997, pos: -2, cost: 2036
Processed 20000 states; group: 993, pos: -16, cost: None
Processed 30000 states; group: 1000, pos: -12, cost: 2034
Processed 40000 states; group: 995, pos: 30, cost: 2038
Processed 50000 states; group: 985, pos: 12, cost: 2037
Processed 60000 states; group: 981, pos: 15, cost: 2037
Processed 70000 states; group: 979, pos: -13, cost: 2035
Processed 80000 states; group: 988, pos: 6, cost: 2038
Processed 90000 states; group: 978, pos: 41, cost: 2037
Processed 100000 states; group: 990, pos: 33, cost: 2037
Processed 110000 states; group: 972, pos: 27, cost: 2038
Processed 120000 states; group: 968, pos: -20, cost: 2034
Processed 130000 states; group: 968, pos: 5, cost: 2036
Processed 140000 states; group: 965, pos: 32, cost: 2036
Processed 150000 states; group: 990, pos: -9, cost: 2036
Processed 160000 states; group: 961, pos: 35, cost: 2038
Processed 170000 states; group: 958, pos: 49, cost: 2037
Processed 180000 states; group: 957, 

In [70]:
results['SQUID_INK']

{'profit': 21911,
 'trades': [{'timestamp': 100, 'price': 1837, 'quantity': -2},
  {'timestamp': 100, 'price': 1836, 'quantity': -2},
  {'timestamp': 100, 'price': 1835, 'quantity': -2},
  {'timestamp': 300, 'price': 1837, 'quantity': 37},
  {'timestamp': 300, 'price': 1836, 'quantity': -2},
  {'timestamp': 400, 'price': 1836, 'quantity': 7},
  {'timestamp': 500, 'price': 1834, 'quantity': -6},
  {'timestamp': 600, 'price': 1833, 'quantity': 14},
  {'timestamp': 600, 'price': 1834, 'quantity': -10},
  {'timestamp': 600, 'price': 1832, 'quantity': -36},
  {'timestamp': 700, 'price': 1837, 'quantity': 2},
  {'timestamp': 700, 'price': 1838, 'quantity': 4},
  {'timestamp': 1400, 'price': 1842, 'quantity': 9},
  {'timestamp': 1400, 'price': 1841, 'quantity': -9},
  {'timestamp': 1500, 'price': 1844, 'quantity': -2},
  {'timestamp': 1500, 'price': 1842, 'quantity': -5},
  {'timestamp': 1600, 'price': 1844, 'quantity': -4},
  {'timestamp': 1600, 'price': 1843, 'quantity': -28},
  {'timestamp

In [68]:
kelp_trades = pd.DataFrame(results['SQUID_INK']['trades'])
kelp_trades['string'] = kelp_trades.apply(lambda row: f"{row['timestamp']},{row['price']},{row['quantity']}", axis=1)
kelp = ";".join(kelp_trades['string'])


In [69]:
kelp

'100,1837,-2;100,1836,-2;100,1835,-2;300,1837,37;300,1836,-2;400,1836,7;500,1834,-6;600,1833,14;600,1834,-10;600,1832,-36;700,1837,2;700,1838,4;1400,1842,9;1400,1841,-9;1500,1844,-2;1500,1842,-5;1600,1844,-4;1600,1843,-28;1700,1843,16;1700,1844,17;1800,1845,6;2000,1846,-13;2100,1847,10;2300,1848,-7;2400,1849,10;2600,1850,-1;2600,1849,-22;2700,1851,12;2800,1849,-7;2900,1851,10;3200,1849,8;4100,1857,6;4100,1855,-10;4200,1854,-7;4500,1854,10;4600,1858,6;4600,1860,1;4600,1857,-7;4800,1861,-2;4800,1860,-7;4800,1859,-23;4900,1859,9;5000,1860,-12;5000,1859,-9;5300,1860,20;5400,1860,8;5500,1859,11;5600,1860,-5;5600,1859,-3;5900,1860,12;6100,1862,6;6100,1859,-6;6200,1861,-2;6200,1859,-6;6300,1861,-9;6500,1860,-10;6600,1860,-22;6700,1861,14;6700,1862,31;6700,1861,-1;9100,1858,6;9300,1857,12;9500,1856,26;9500,1854,-1;9500,1853,-25;9700,1856,10;9700,1857,5;9700,1858,8;9900,1858,-6;10100,1859,8;10300,1860,-1;10300,1858,-7;10500,1859,6;10600,1858,-6;10800,1857,6;10800,1859,2;11000,1864,-2;11000,1863