# Let's Build a Quant Trading Strategy - Part 3

In [668]:
# y_hat = model(x)
# orders = strategy(y_hat)
# execute(orders)

In [669]:
# Part 1 = Research
# Part 2 = Strategy
# Part 3 = Implementation 

## Goals

In [670]:
# 1. Code it all together in Python (Model + Strategy) (ideally a statically-typed language like rust)
# 2. Show to build a trading system in a scalable way (very easy to create new strategies + streams data)
# 3. Put it live with real money to see how it performs

## Recap

### Part 1: Research

In [671]:
import models
import torch
import research

model = models.LinearModel(3)
model.load_state_dict(torch.load('model_weights.pth', weights_only=True))
model.eval()

LinearModel(
  (linear): Linear(in_features=3, out_features=1, bias=True)
)

### AR(3) Model to predict future log return - 12h forecast horizon

In [672]:
research.print_model_params(model)

linear.weight:
[[-0.10395038 -0.06726477  0.02827305]]
linear.bias:
[0.00067121]


### Part 2: Strategy Recap

In [673]:
# ~14% without any optimization
# 1. Compounding Trade Sizing
# 2. Leverage
# ~14% to >40%

### Fundamental Building Block: Tick

In [674]:
from abc import ABC, abstractmethod
from typing import Generic, TypeVar

T = TypeVar('T')  # input type
R = TypeVar('R')  # output type

class Tick(ABC, Generic[T, R]):
    @abstractmethod
    def on_tick(self, val: T) -> R:
        """Handle a new tick and optionally return a result."""
        pass

### Sliding Window: The fundamental data structure

In [675]:
from collections import deque
from typing import Deque, Optional
import numpy as np

class DequeWindow(Tick[T, Optional[T]], Generic[T]):
    def __init__(self, n: int):
        self._data: Deque[T] = deque(maxlen=n)

    def on_tick(self, val: T) -> Optional[T]:
        """Append a value and return the oldest value dropped (if any)."""
        dropped = None
        if self.is_full():
            dropped = self._data[0]
        self._data.append(val)
        return dropped
    
    def is_full(self) -> bool:
        return self._data.maxlen == len(self._data)

    def append_left(self, val: T) -> Optional[T]:
        dropped = None
        if self.is_full():
            dropped = self._data[-1]
        self._data.appendleft(val)
        return dropped
    
    def to_numpy(self) -> np.ndarray:
        return np.array(self._data)
    
    def __repr__(self) -> str:
        cls_name = self.__class__.__name__
        return f"{cls_name}(capacity={self._data.maxlen}, values={list(self._data)})"

In [676]:
w = DequeWindow(3)
w

DequeWindow(capacity=3, values=[])

In [677]:
w = DequeWindow(3)
w.on_tick(1)
w

DequeWindow(capacity=3, values=[1])

In [678]:
w = DequeWindow(3)
w.on_tick(1)
w.on_tick(2)
w.on_tick(3)
w

DequeWindow(capacity=3, values=[1, 2, 3])

In [679]:
w = DequeWindow(3)
w.on_tick(1)
w.on_tick(2)
w.on_tick(3)
w.on_tick(4)
w

DequeWindow(capacity=3, values=[2, 3, 4])

In [680]:
for i in range(1000000):
    w.on_tick(i)
w

DequeWindow(capacity=3, values=[999997, 999998, 999999])

### Array-based Window

In [681]:
import numpy as np
from typing import Optional

class NumpyWindow(Tick[T, Optional[T]]):
    def __init__(self, n: int, dtype=np.float64):
        if n <= 0:
            raise ValueError("Capacity must be positive.")
        self._capacity = n
        self._data = np.zeros(n, dtype=dtype)
        self._size = 0

    def on_tick(self, val: float) -> Optional[float]:
        dropped = None

        if self._size < self._capacity:
            self._data[self._size] = val
            self._size += 1
        else:
            dropped = self._data[0]
            # shift left in-place
            for i in range(1, self._capacity):
                self._data[i - 1] = self._data[i]
            self._data[-1] = val

        return dropped


    def __getitem__(self, idx: int) -> float:
        """Index access (0 = oldest)."""
        if not 0 <= idx < self._size:
            raise IndexError("Index out of range.")
        return self._data[idx]

    def __len__(self) -> int:
        return self._size

    def capacity(self) -> int:
        return self._capacity

    def is_full(self) -> bool:
        return self._size == self._capacity

    def values(self) -> np.ndarray:
        return self._data[:self._size]

    def __repr__(self) -> str:
        vals = self.values().tolist()
        return f"{self.__class__.__name__}(capacity={self._capacity}, size={self._size}, values={vals})"


### Benchmark Numpy Window vs Deque Window

In [682]:
def benchmark_window(window, n):
    for i in range(n):
        window.on_tick(i)

window_size = 10
n = 5000000

In [683]:
%%time
benchmark_window(NumpyWindow(window_size), n)

CPU times: user 5.12 s, sys: 91.4 ms, total: 5.21 s
Wall time: 5.99 s


In [684]:
%%time
benchmark_window(DequeWindow(window_size), n)

CPU times: user 536 ms, sys: 6.94 ms, total: 543 ms
Wall time: 544 ms


### Stream the Last Known Value 

In [685]:
class Last(Tick[T, T], Generic[T]):
    def __init__(self):
        self._value: Optional[T] = None

    def on_tick(self, val: T) -> Optional[T]:
        self._value = val
        return val

    def __repr__(self) -> str:
        cls_name = self.__class__.__name__
        return f"{cls_name}(value={self._value})"

In [686]:
last_val = Last()
last_val

Last(value=None)

In [687]:
last_val = Last()
for i in range(5):
    last_val.on_tick(i)
last_val

Last(value=4)

## Streaming Log Returns

### Recap

In [688]:
ts = [100, 120, 100]
log_returns = [
    np.log(ts[1] / ts[0]),
    np.log(ts[2] / ts[1])
]
log_returns

[np.float64(0.1823215567939546), np.float64(-0.1823215567939546)]

In [689]:
np.sum(log_returns)

np.float64(0.0)

In [690]:
import numpy as np

class LogReturn(Tick[float, Optional[float]], Generic[T]):
    def __init__(self):
        self._window = NumpyWindow(2)

    def on_tick(self, val: float) -> Optional[float]:
        self._window.on_tick(val)
        if self._window.is_full():
            return np.log(self._window[1] / self._window[0])
        else:
            return None
        
    def __repr__(self) -> str:
        cls_name = self.__class__.__name__
        return f"{cls_name}(window={self._window})"
    

In [691]:
f = LogReturn()
f.on_tick(100.0)
f

LogReturn(window=NumpyWindow(capacity=2, size=1, values=[100.0]))

In [692]:
v = f.on_tick(120.0)
v

np.float64(0.1823215567939546)

In [693]:
f

LogReturn(window=NumpyWindow(capacity=2, size=2, values=[100.0, 120.0]))

In [694]:
v = f.on_tick(100.0)
v

np.float64(-0.1823215567939546)

In [695]:
f

LogReturn(window=NumpyWindow(capacity=2, size=2, values=[120.0, 100.0]))

### Streaming Auto-Regressive Log Returns Lags

### Recap

In [696]:
time_series = [0.1, -0.2, -0.3]
lag_1 = time_series[-1]
lag_1

-0.3

In [697]:
lag_2 = time_series[-2]
lag_2

-0.2

In [698]:
lag_3 = time_series[-3]
lag_3

0.1

In [699]:
class LogReturnLags(Tick[float, torch.Tensor]):
    def __init__(self, no_lags: int):
        self._lags = DequeWindow(no_lags)
        self._log_return = LogReturn()
    
    def on_tick(self, val: float) -> torch.Tensor | None:
        log_ret = self._log_return.on_tick(val)
        if log_ret is not None:
            self._lags.append_left(log_ret)
            return torch.tensor(self._lags.to_numpy(), dtype=torch.float32) if self._lags.is_full() else None
        else:
            return None
        
    def __repr__(self) -> str:
        cls_name = self.__class__.__name__
        return f"{cls_name}(lags={self._lags}, log_return={self._log_return})" 

In [700]:
lags = LogReturnLags(3)
v = lags.on_tick(90)
lags

LogReturnLags(lags=DequeWindow(capacity=3, values=[]), log_return=LogReturn(window=NumpyWindow(capacity=2, size=1, values=[90.0])))

In [701]:
lags.on_tick(100)
lags

LogReturnLags(lags=DequeWindow(capacity=3, values=[np.float64(0.10536051565782635)]), log_return=LogReturn(window=NumpyWindow(capacity=2, size=2, values=[90.0, 100.0])))

In [702]:
lags.on_tick(150)
lags

LogReturnLags(lags=DequeWindow(capacity=3, values=[np.float64(0.4054651081081644), np.float64(0.10536051565782635)]), log_return=LogReturn(window=NumpyWindow(capacity=2, size=2, values=[100.0, 150.0])))

In [703]:
lags.on_tick(110)

tensor([-0.3102,  0.4055,  0.1054])

In [704]:
[np.log(110/150), np.log(150/100),np.log(100/90)]

[np.float64(-0.3101549283038396),
 np.float64(0.4054651081081644),
 np.float64(0.10536051565782635)]

In [705]:
lags = LogReturnLags(3)
lags.on_tick(90)
lags.on_tick(100)
lags.on_tick(150)
lags.on_tick(110)
features = lags.on_tick(160)
features

tensor([ 0.3747, -0.3102,  0.4055])

### Streaming features into our model

In [706]:
X = features
with torch.no_grad():
    y_hat = model(X)
y_hat

tensor([-0.0060])

In [707]:
y_hat[0]

tensor(-0.0060)

## Build the Trading System

### Using Decimal to represent money

In [708]:
val = 0.1
total = 0.0
for i in range(10):
    total += val
total

0.9999999999999999

In [709]:
from decimal import Decimal
dp = Decimal('0.2')
val = Decimal(0.1).quantize(dp)
total = Decimal(0.0).quantize(dp)
for i in range(10):
    total += val
total

Decimal('1.0')

In [710]:
from dataclasses import dataclass

@dataclass(frozen=True)
class Order:
    sym: str
    signed_qty: Decimal

    def __str__(self) -> str:
        sign = "LONG" if self.signed_qty > 0 else "SHORT"
        return f"Order({sign} {self.signed_qty} {self.sym})"

In [711]:
from decimal import Decimal

def decimal_sign(d: Decimal) -> int:
    return 1 if d > Decimal(0) else -1

def is_long(x: Decimal) -> bool:
    return decimal_sign(x) > 0

@dataclass(frozen=True)
class Trade:
    sym: str
    signed_qty: Decimal
    price: Decimal
    pnl: Decimal

    def __str__(self) -> str:
        sign = "LONG" if is_long(self.signed_qty) else "SHORT"
        return f"Trade({sign} {self.signed_qty} {self.sym} {self.price} {self.pnl})"

    def is_long(self) -> bool:
        return is_long(self.signed_qty)




In [712]:
@dataclass
class Position:
    sym: str
    signed_qty: Decimal
    price: Decimal

    def close(self) -> "Order":
        return Order(self.sym, -self.signed_qty)
    
    def is_long(self) -> bool:
        return is_long(self.signed_qty)
    
    def unrealized_pnl(self, current_price: Decimal) -> Decimal:
        entry_val = self.price * self.signed_qty
        exit_val = current_price * -self.signed_qty
        return entry_val + exit_val


In [713]:
from abc import ABC, abstractmethod
from decimal import Decimal

class Account(ABC):
    @abstractmethod
    def balance(self) -> Decimal:
        pass

    @abstractmethod
    def get_position(self, sym: str) -> Optional[Position]:
        pass   

In [714]:
from decimal import Decimal
from typing import Dict, List, Optional

class TestAccount(Account):
    """A simulated account for testing or paper trading."""

    def __init__(self, _balance: Decimal) -> None:
        self._balance = _balance
        self._positions: Dict[str, Position] = {}
        self._trades: List[Trade] = []

    def balance(self) -> Decimal:
        return self._balance

    def get_position(self, sym) -> Optional[Position]:
        return self._positions.get(sym)

    def __repr__(self) -> str:
        return f"TestAccount(balance={self._balance}, positions={self._positions}, trades={self._trades})"


In [715]:
acc = TestAccount(Decimal(50.0))
acc.balance()

Decimal('50')

In [716]:
acc

TestAccount(balance=50, positions={}, trades=[])

### Model an Exchange

In [717]:
from abc import abstractmethod
from decimal import Decimal

class Exchange(Account):
    """Abstract base class representing a trading exchange/broker."""

    @abstractmethod
    def market_order(self, sym: str, signed_qty: Decimal, price: Decimal) -> Trade:
        """Execute a market order and return a Trade result."""
        pass

    @abstractmethod
    def limit_order(self, sym: str, signed_qty: Decimal, price: Decimal, post_only: bool = False) -> Optional[Trade]:
        """Execute a limit order and return a Trade if it crosses book."""
        pass

In [718]:
from typing import Dict,List

class TestExchange(Exchange):
    _account: TestAccount

    def __init__(self, account: TestAccount):
        self._account = account

    def market_order(self, sym: str, signed_qty: Decimal, price: Decimal) -> "Trade":        
        # Update balance and position
        trade = self._update_position(sym, signed_qty, price)
        self._account._balance += trade.pnl
        self._account._trades.append(trade)
        return trade
    
    def _update_position(self, sym: str, signed_qty, price: Decimal) -> Trade:
        position = self._account._positions.pop(sym, None)
        pnl = Decimal(0.0)
        if position is not None:
            entry_val = position.price * position.signed_qty
            exit_val = price * position.signed_qty              
            pnl = exit_val - entry_val  
        else:
            self._account._positions[sym] = Position(sym, signed_qty, price) 
        return Trade(sym, signed_qty, price, pnl)    

    def limit_order(self, sym, signed_qty, price, post_only = False):
        raise Exception("not yet implemented")
    
    def balance(self) -> Decimal:
        return self._account.balance()
    
    def get_position(self, sym) -> Optional[Position]:
        return self._account.get_positions(sym)
    
    def __repr__(self) -> str:
        return f"TestExchange(balance={self.balance()}, positions={self._account._positions}, trades={self._account._trades})"

### Open Position

In [719]:
exchange = TestExchange(TestAccount(Decimal(50.0)))

price = Decimal(10)
qty = Decimal(5.0)
exchange.market_order('BTCUSDT', qty, Decimal(price))

Trade(sym='BTCUSDT', signed_qty=Decimal('5'), price=Decimal('10'), pnl=Decimal('0'))

In [720]:
exchange

TestExchange(balance=50, positions={'BTCUSDT': Position(sym='BTCUSDT', signed_qty=Decimal('5'), price=Decimal('10'))}, trades=[Trade(sym='BTCUSDT', signed_qty=Decimal('5'), price=Decimal('10'), pnl=Decimal('0'))])

### Close Position

In [721]:
price = Decimal(15.0)
exchange.market_order('BTCUSDT', -qty, price)

Trade(sym='BTCUSDT', signed_qty=Decimal('-5'), price=Decimal('15'), pnl=Decimal('25'))

In [722]:
exchange

TestExchange(balance=75, positions={}, trades=[Trade(sym='BTCUSDT', signed_qty=Decimal('5'), price=Decimal('10'), pnl=Decimal('0')), Trade(sym='BTCUSDT', signed_qty=Decimal('-5'), price=Decimal('15'), pnl=Decimal('25'))])

In [723]:
entry_notional_value = Decimal(5) * Decimal(10)
entry_notional_value

Decimal('50')

In [724]:
exit_notional_val = Decimal(5) * Decimal(15)
exit_notional_val

Decimal('75')

In [725]:
exit_notional_val - entry_notional_value

Decimal('25')

### Build Strategy API

In [726]:
class Strategy(ABC):
    @abstractmethod
    def on_tick(self, price: float, account: Account) -> Optional[List[Order]]:
        pass

### Implement our strategy

In [727]:
import torch.nn as nn

class BasicTakerStrat(Strategy):
    def __init__(self, 
                 sym: str,
                 model: nn.Module, 
                 log_return_lags: LogReturnLags, 
                 scale_factor: Decimal = None) -> None:
        self.sym = sym
        self.model = model
        self.log_return_lags = log_return_lags
        if scale_factor is None:
            scale_factor = Decimal(1.0)
        self.scale_factor = Decimal(scale_factor)

    def _signed_compound_trade_size(self, y_hat: float, account: Account, cur_price: Decimal, position: Optional[Position]) -> Decimal:
        dir_signal = np.sign(y_hat)
        cur_balance =  account.balance()
        unrealized_balance = cur_balance + (position.unrealized_pnl(cur_price) if position else Decimal(0.0))
        qty = unrealized_balance / cur_price
        signed_qty = Decimal(dir_signal) * qty
        return signed_qty * self.scale_factor

    def _create_orders(self, y_hat: torch.Tensor, account: Account, price: Decimal) -> List[Order]:
        position = account.get_position(self.sym)
        signed_trade_size = self._signed_compound_trade_size(y_hat.item(), account, price, position)
        open_order = Order(self.sym, signed_trade_size)
        if position is not None:
            close_order = Order(position.sym, -position.signed_qty)
            return [close_order, open_order]
        return [open_order]      

    def on_tick(self, price: float, account: Account) -> List[Order]:
        X = self.log_return_lags.on_tick(price)
        if X is not None:
            with torch.no_grad():                
                y_hat = self.model(X)
                orders = self._create_orders(y_hat, account, Decimal(price))
                return orders
        return []

In [728]:
# Window to stream lagged log returns
lags = LogReturnLags(3)
# Create Account
acc = TestAccount(Decimal(100.0))
# Create strategy
strat = BasicTakerStrat('BTCUSDT', model, lags, Decimal(1.0))

# First 12 hour interval - 2025/10/20 00:00
strat.on_tick(10.0, acc)

[]

In [729]:
# Second 12 hour interval - 2025/10/20 12:00
strat.on_tick(120.0, acc)

[]

In [730]:
# Third 12 hour interval - 2025/10/21 00:00
strat.on_tick(90.0, acc)

[]

In [731]:
# Fourth 12 hour interval - 2025/10/21 12:00
orders = strat.on_tick(100, acc)
orders

[Order(sym='BTCUSDT', signed_qty=Decimal('1'))]

### Execute Order

In [732]:
exchange = TestExchange(acc)
order = orders[0]
exchange.market_order(order.sym, order.signed_qty, 100)

Trade(sym='BTCUSDT', signed_qty=Decimal('1'), price=100, pnl=Decimal('0'))

In [733]:
exchange

TestExchange(balance=100, positions={'BTCUSDT': Position(sym='BTCUSDT', signed_qty=Decimal('1'), price=100)}, trades=[Trade(sym='BTCUSDT', signed_qty=Decimal('1'), price=100, pnl=Decimal('0'))])

In [734]:
orders = strat.on_tick(115, acc)
orders

[Order(sym='BTCUSDT', signed_qty=Decimal('-1')),
 Order(sym='BTCUSDT', signed_qty=Decimal('-0.7391304347826086956521739130'))]

In [735]:
order = orders[0]
exchange.market_order(order.sym, order.signed_qty, 115)

Trade(sym='BTCUSDT', signed_qty=Decimal('-1'), price=115, pnl=Decimal('15'))

In [736]:
exchange

TestExchange(balance=115, positions={}, trades=[Trade(sym='BTCUSDT', signed_qty=Decimal('1'), price=100, pnl=Decimal('0')), Trade(sym='BTCUSDT', signed_qty=Decimal('-1'), price=115, pnl=Decimal('15'))])

In [737]:
order = orders[1]
exchange.market_order(order.sym, order.signed_qty, 115)

Trade(sym='BTCUSDT', signed_qty=Decimal('-0.7391304347826086956521739130'), price=115, pnl=Decimal('0'))

In [738]:
exchange

TestExchange(balance=115, positions={'BTCUSDT': Position(sym='BTCUSDT', signed_qty=Decimal('-0.7391304347826086956521739130'), price=115)}, trades=[Trade(sym='BTCUSDT', signed_qty=Decimal('1'), price=100, pnl=Decimal('0')), Trade(sym='BTCUSDT', signed_qty=Decimal('-1'), price=115, pnl=Decimal('15')), Trade(sym='BTCUSDT', signed_qty=Decimal('-0.7391304347826086956521739130'), price=115, pnl=Decimal('0'))])

In [739]:
orders = strat.on_tick(100, acc)
orders

[Order(sym='BTCUSDT', signed_qty=Decimal('0.7391304347826086956521739130')),
 Order(sym='BTCUSDT', signed_qty=Decimal('1.039130434782608695652173913'))]

In [740]:
order = orders[0]
exchange.market_order(order.sym, order.signed_qty, 100)

Trade(sym='BTCUSDT', signed_qty=Decimal('0.7391304347826086956521739130'), price=100, pnl=Decimal('11.08695652173913043478260870'))

In [741]:
exchange

TestExchange(balance=126.0869565217391304347826087, positions={}, trades=[Trade(sym='BTCUSDT', signed_qty=Decimal('1'), price=100, pnl=Decimal('0')), Trade(sym='BTCUSDT', signed_qty=Decimal('-1'), price=115, pnl=Decimal('15')), Trade(sym='BTCUSDT', signed_qty=Decimal('-0.7391304347826086956521739130'), price=115, pnl=Decimal('0')), Trade(sym='BTCUSDT', signed_qty=Decimal('0.7391304347826086956521739130'), price=100, pnl=Decimal('11.08695652173913043478260870'))])

In [742]:
order = orders[1]
exchange.market_order(order.sym, order.signed_qty, 100)

Trade(sym='BTCUSDT', signed_qty=Decimal('1.039130434782608695652173913'), price=100, pnl=Decimal('0'))

In [743]:
exchange

TestExchange(balance=126.0869565217391304347826087, positions={'BTCUSDT': Position(sym='BTCUSDT', signed_qty=Decimal('1.039130434782608695652173913'), price=100)}, trades=[Trade(sym='BTCUSDT', signed_qty=Decimal('1'), price=100, pnl=Decimal('0')), Trade(sym='BTCUSDT', signed_qty=Decimal('-1'), price=115, pnl=Decimal('15')), Trade(sym='BTCUSDT', signed_qty=Decimal('-0.7391304347826086956521739130'), price=115, pnl=Decimal('0')), Trade(sym='BTCUSDT', signed_qty=Decimal('0.7391304347826086956521739130'), price=100, pnl=Decimal('11.08695652173913043478260870')), Trade(sym='BTCUSDT', signed_qty=Decimal('1.039130434782608695652173913'), price=100, pnl=Decimal('0'))])