# Benchmark Your Device

In [1]:
import random
import json
import time
from datetime import datetime
import pytz
from collections import defaultdict
from threading import Lock
import concurrent.futures

# Assumed original data structures
# Ticker symbols used both to generate test messages and to filter incoming trades.
SYMBOLS = [
    "AAPL", "TSLA", "GOOG", "SPY", "HOOD",
    "NVDA", "F", "QQQ", "AMD", "TSM",
]

# Tick storage: minute key -> symbol -> list of tick records.
# Both levels are created on first access.
raw_ticks = defaultdict(lambda: defaultdict(list))

# Held by on_message while it appends to raw_ticks.
raw_ticks_lock = Lock()


# Simulate trade messages
def generate_test_message(symbol="AAPL", num_trades=100):
    """Build a JSON-encoded synthetic trade message.

    Every trade in the batch is identical: the given symbol, price 100.0,
    volume 10, and one shared timestamp taken from the current wall-clock
    time in integer milliseconds.

    Args:
        symbol: Ticker symbol written to each trade's "s" field.
        num_trades: Number of trade entries placed in "data".

    Returns:
        A JSON string of the form {"type": "trade", "data": [...]}.
    """
    # One timestamp for the whole batch, captured before any trade is built.
    stamp_ms = int(time.time() * 1000)
    batch = [
        {"s": symbol, "p": 100.0, "v": 10, "t": stamp_ms}
        for _ in range(num_trades)
    ]
    payload = {"type": "trade", "data": batch}
    return json.dumps(payload)


# Wrap the on_message function to measure its throughput
def test_throughput(on_message_fn, num_batches=1000, trades_per_batch=100):
    """Measure sequential message-handling throughput of ``on_message_fn``.

    All messages are generated before the timer starts, so only the handler
    calls themselves are timed. Results are printed, not returned.

    Args:
        on_message_fn: Callable invoked as ``on_message_fn(None, message)``
            for each generated JSON message.
        num_batches: Number of messages to generate and process.
        trades_per_batch: Number of trades contained in each message.
    """
    msgs = [
        generate_test_message(
            symbol=random.choice(SYMBOLS), num_trades=trades_per_batch
        )
        for _ in range(num_batches)
    ]

    # perf_counter is monotonic and high-resolution; time.time() can be coarse
    # or jump when the system clock is adjusted, which distorts short runs.
    start = time.perf_counter()
    for msg in msgs:
        on_message_fn(None, msg)
    elapsed = time.perf_counter() - start

    total_trades = num_batches * trades_per_batch
    # Guard against a zero reading for tiny workloads instead of raising
    # ZeroDivisionError.
    throughput = total_trades / elapsed if elapsed > 0 else float("inf")
    print(f"Processed {total_trades} trades in {elapsed:.2f} seconds")
    print(f"Throughput: {throughput:.2f} trades/sec")


def test_peak_throughput(on_message_fn, num_threads=100, trades_per_thread=1000):
    """Measure throughput when many messages are handled concurrently.

    One message is generated per thread, and each message is submitted to a
    thread pool as its own task. The timed region covers pool creation, task
    submission, and waiting for every task to finish. Results are printed,
    not returned.

    Args:
        on_message_fn: Callable invoked as ``on_message_fn(None, message)``
            from worker threads.
        num_threads: Number of messages to submit; also the pool's
            ``max_workers``.
        trades_per_thread: Number of trades contained in each message.

    Raises:
        ValueError: If ``num_threads`` is not positive (raised by
            ``ThreadPoolExecutor``).
        Exception: Any exception raised by ``on_message_fn`` in a worker is
            re-raised after timing, so failed work is not reported as
            processed.
    """
    msgs = [
        generate_test_message(
            symbol=random.choice(SYMBOLS), num_trades=trades_per_thread
        )
        for _ in range(num_threads)
    ]

    start = time.perf_counter()
    with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor:
        # Pass each message explicitly. A closure over a shared loop variable
        # would make every task process the same (last generated) message.
        futures = [executor.submit(on_message_fn, None, msg) for msg in msgs]
        concurrent.futures.wait(futures)
    elapsed = time.perf_counter() - start

    # wait() never re-raises; result() surfaces the first handler failure.
    for future in futures:
        future.result()

    total_trades = num_threads * trades_per_thread
    # Guard against a zero reading for tiny workloads instead of raising
    # ZeroDivisionError.
    throughput = total_trades / elapsed if elapsed > 0 else float("inf")
    print(f"[Peak Load] Processed {total_trades} trades in {elapsed:.4f} seconds")
    print(f"[Peak Load] Peak throughput: {throughput:.2f} trades/sec")




# Exactly the same as the on_message handler used in the websocket code
def on_message (ws, message):
    """Parse a JSON trade message and store each trade in ``raw_ticks``.

    Messages whose "type" is not "trade" are ignored, as are trades whose
    symbol is not in ``SYMBOLS``. Each accepted trade is appended to
    ``raw_ticks[minute_key][symbol]`` while holding ``raw_ticks_lock``.

    Args:
        ws: Not used in the body; the benchmark helpers pass ``None``.
        message: JSON text with a "type" field and, for trade messages, a
            "data" list of trades carrying "s", "p", "v" and "t" fields.
    """
    msg = json.loads (message)
    if msg.get ('type') != 'trade':
        return
    for trade in msg['data']:
        sym, price, vol, ts = trade['s'], trade['p'], trade['v'], trade['t']
        if sym not in SYMBOLS:
            continue
        # Convert to US Eastern time and floor to the minute.
        # ts is divided by 1000 before use as epoch seconds, i.e. it is
        # treated as milliseconds.
        dt = datetime.fromtimestamp (ts / 1000, pytz.UTC) \
            .astimezone (pytz.timezone ('America/New_York'))
        minute_ts = dt.replace (second=0, microsecond=0) \
            .strftime ('%Y-%m-%d %H:%M:00')

        # The lock is taken once per trade, around the append only.
        with raw_ticks_lock:
            raw_ticks[minute_ts][sym].append ({
                # Full (unfloored) timezone-aware datetime of the trade.
                'timestamp': dt,
                'price': price,
                'volume': vol
            })

### (1) Test your overall processing speed (TCP receiving)

In [None]:
# Run the test
test_throughput (on_message, num_batches=1000, trades_per_batch=100)

### (2) Test your peak processing speed (TCP receiving)

In [None]:
test_peak_throughput (on_message, num_threads=1000, trades_per_thread=100)