In [11]:
import sigrun
from sigrun import polars as pl
from sigrun.polars import col as c
from sigrun.workflow import factor

from sigrun import virgo as sv
VIRGO_TOKEN = "ltIIImLNjbF4V6Qk"
sv.init("virgo", VIRGO_TOKEN)

import datetime

import os
os.environ["SIGRUN_IO_ENGINE"] = "false"

import logging
logging.basicConfig(level=logging.INFO)

import numpy as np
from black_litterman import BLModel


In [12]:
today = datetime.date(2025, 3, 28)
yesterday = datetime.date(2025, 3, 27)
from_date = datetime.date(2025, 3, 1)

def hot_pivoted(start_date, end_date):
    lf_hot = sv.table.read("stock.ths_hot_list", partitions = { "time_window":"24h"},
        start_date=start_date.strftime("%Y-%m-%d"),
        end_date=end_date.strftime("%Y-%m-%d"),
    )

    lf_hot = (lf_hot
                # .filter(  (c("date").dt.date() < pl.date("2025-03-30"))
                #         | ((c("date").dt.date() == pl.date("2025-03-30")) & (c("time").str.strptime(pl.Time, "%H:%M:%S")<pl.time(10, 0, 0)))
                #     )
        .with_columns_expr(c("date").dt.strftime("%Y-%m-%d").alias("date_str"))
        .with_columns((c("date_str")+" "+c("time")).alias("datetime_str"))
        .with_columns(c("datetime_str").str.strptime(pl.Datetime, "%Y-%m-%d %H:%M:%S").alias("datetime"))
        .with_columns(c("datetime").cast(pl.Datetime("ns")))
        .with_columns(c("rate").cast(pl.Float64))
        .with_columns(rate=pl.when(c("rate")==0.0).then(None).otherwise(c("rate")))
        .drop(["date", "time", "date_str", "datetime_str"])
        .select("datetime", "symbol", "rate", "type")
    )
    hot_pivoted = (lf_hot.collect().pivot(values="rate", index=["symbol", "datetime"], columns="type"))

    hot_types = ["normal", "skyrocket", "tech", "value", "trend"]
    for col in hot_types:
        if col not in hot_pivoted.columns:
            hot_pivoted = hot_pivoted.with_columns(pl.lit(None, dtype=pl.Float64).alias(col))

    hot_pivoted = (hot_pivoted#.group_by("symbol")
                    # .agg(
                    #     normal = c("normal").max(),
                    #     skyrocket = c("skyrocket").max(),
                    #     tech = c("tech").max(),
                    #     value = c("value").max(),
                    #     trend = c("trend").max(),
                    # )
                    .with_columns(
                        pl.when(pl.col(col).is_null()).then(0.0).otherwise(pl.col(col)).alias(col)
                        for col in hot_types
                    )
                    .lazy()
    )
    return hot_pivoted

print(hot_pivoted(from_date, today).collect())


def price_ret(start_date, end_date):
    lf_price = (sv.stock.bars("all", "1m", from_date=start_date, to_date=end_date)
        # .filter((c("date")<today) | (((c("date")==today) & (c("time")<36_000_000_000_000)))) #早上10点前数据
        .with_columns(ret=c("close")/c("pre_close") - 1)
        .with_columns(log_ret=np.log(c("close")/c("pre_close")))
        .with_columns(datetime=c("date")+c("time"))
        .select("symbol", "date", "datetime", "ret", "log_ret")
        # .group_by("symbol")
        # .agg([
        #     c("close").filter(c("date") == today).mean().alias("mean_close_today"),
        #     c("close").mean().alias("mean_close_all"),
        #     c("close").std().alias("std_close_all"),
        # ])
        # .select(
        #     "symbol",
        #     td_norm = (c("mean_close_today") - c("mean_close_all")) / c("std_close_all")
        # )
    )
    return lf_price
print(price_ret(from_date, today).collect())


df = hot_pivoted(from_date, today).join(
    price_ret(from_date, today),
    on=["symbol", "datetime"],
)

df.collect()

  [2m2025-07-08T09:11:04.358848505+08:00[0m [33m WARN[0m [1;33mtable[0m[33m:[33m310: [33mtransport error, retrying, reason: status: Unknown, message: "transport error", details: [], metadata: MetadataMap { headers: {} }[0m

  [2m2025-07-08T09:11:04.859616081+08:00[0m [33m WARN[0m [1;33mtable[0m[33m:[33m310: [33mtransport error, retrying, reason: status: Unknown, message: "transport error", details: [], metadata: MetadataMap { headers: {} }[0m

  [2m2025-07-08T09:11:05.900501679+08:00[0m [33m WARN[0m [1;33mtable[0m[33m:[33m310: [33mtransport error, retrying, reason: status: Unknown, message: "transport error", details: [], metadata: MetadataMap { headers: {} }[0m



  hot_pivoted = (lf_hot.collect().pivot(values="rate", index=["symbol", "datetime"], columns="type"))


shape: (185_860, 7)
┌────────────┬─────────────────────┬─────────────┬───────────┬─────────┬───────┬──────────┐
│ symbol     ┆ datetime            ┆ normal      ┆ skyrocket ┆ tech    ┆ trend ┆ value    │
│ ---        ┆ ---                 ┆ ---         ┆ ---       ┆ ---     ┆ ---   ┆ ---      │
│ str        ┆ datetime[ns]        ┆ f64         ┆ f64       ┆ f64     ┆ f64   ┆ f64      │
╞════════════╪═════════════════════╪═════════════╪═══════════╪═════════╪═══════╪══════════╡
│ 600126.SSE ┆ 2025-03-01 00:00:00 ┆ 5.0596897e7 ┆ 0.0       ┆ 0.0     ┆ 0.0   ┆ 112489.0 │
│ 002369.SZE ┆ 2025-03-01 00:00:00 ┆ 4.4714157e7 ┆ 0.0       ┆ 0.0     ┆ 0.0   ┆ 66976.0  │
│ 002261.SZE ┆ 2025-03-01 00:00:00 ┆ 2.946216e7  ┆ 0.0       ┆ 0.0     ┆ 0.0   ┆ 76490.0  │
│ 603496.SSE ┆ 2025-03-01 00:00:00 ┆ 2.9266911e7 ┆ 0.0       ┆ 72986.0 ┆ 0.0   ┆ 134973.0 │
│ 000759.SZE ┆ 2025-03-01 00:00:00 ┆ 2.7368897e7 ┆ 0.0       ┆ 32182.0 ┆ 0.0   ┆ 51222.0  │
│ …          ┆ …                   ┆ …           ┆ …        

symbol,datetime,normal,skyrocket,tech,trend,value,date,ret,log_ret
str,datetime[ns],f64,f64,f64,f64,f64,datetime[ns],f64,f64
"""000032.SZE""",2025-03-03 10:00:00,0.0,0.0,5812.0,0.0,0.0,2025-03-03 00:00:00,-0.003798,-0.003805
"""000032.SZE""",2025-03-03 11:00:00,0.0,0.0,4818.0,0.0,0.0,2025-03-03 00:00:00,0.003077,0.003072
"""000032.SZE""",2025-03-03 14:00:00,0.0,0.0,29817.0,0.0,0.0,2025-03-03 00:00:00,-0.00077,-0.00077
"""000032.SZE""",2025-03-03 15:00:00,0.0,0.0,10627.0,0.0,0.0,2025-03-03 00:00:00,-0.002311,-0.002314
"""000034.SZE""",2025-03-03 10:00:00,674126.0,0.0,0.0,0.0,22767.0,2025-03-03 00:00:00,-0.005382,-0.005396
…,…,…,…,…,…,…,…,…,…
"""688409.SSE""",2025-03-28 10:00:00,0.0,21732.5,0.0,0.0,0.0,2025-03-28 00:00:00,0.001311,0.001311
"""688605.SSE""",2025-03-28 15:00:00,0.0,56698.5,0.0,0.0,0.0,2025-03-28 00:00:00,0.000267,0.000267
"""688630.SSE""",2025-03-28 10:00:00,0.0,18278.0,0.0,0.0,0.0,2025-03-28 00:00:00,-0.001475,-0.001476
"""688981.SSE""",2025-03-28 10:00:00,4.719677e6,57855.5,0.0,0.0,47959.0,2025-03-28 00:00:00,-0.001423,-0.001424


In [9]:

def bl_prediction(df):
    window_size = 60

    # Collect all data into a DataFrame for groupby
    df = df.collect()

    # Prepare to store predictions
    predictions = []

    # Group by stock
    for stock_name, stock_df in df.group_by("symbol"):
        # Sort by date
        stock_df = stock_df.sort("datetime")
        hot_rate_cols = ["normal", "skyrocket", "tech", "value", "trend"]
        hot_rate_mat = stock_df.select(hot_rate_cols).to_numpy()
        returns_arr = stock_df["ret"].to_numpy()
        dates_arr = stock_df["date"].to_numpy()

        n = len(stock_df)
        for i in range(window_size, n - 1):
            window_hot_rate = hot_rate_mat[i-window_size:i, :]
            window_returns = returns_arr[i-window_size:i]
            next_hot_rate = hot_rate_mat[i+1, :]
            next_date = dates_arr[i+1]

            bl = BLModel(window_hot_rate, window_returns)
            bl.fit()
            pred = bl.predict_next_return(next_hot_rate)

            predictions.append({
                "stock": stock_name,
                "date": next_date,
                "predicted_return": pred
            })

    # Convert predictions to Polars DataFrame
    pred_df = pl.DataFrame(predictions)
    return pred_df

print(bl_prediction(df))


  posterior = (tau_var * view + self.omega * self.prior_mean) / (tau_var + self.omega)


shape: (404, 3)
┌────────────────┬───────────┬──────────────────┐
│ stock          ┆ date      ┆ predicted_return │
│ ---            ┆ ---       ┆ ---              │
│ list[str]      ┆ f64       ┆ f64              │
╞════════════════╪═══════════╪══════════════════╡
│ ["600126.SSE"] ┆ 1.7428e18 ┆ -0.000182        │
│ ["600126.SSE"] ┆ 1.7428e18 ┆ -0.000315        │
│ ["600126.SSE"] ┆ 1.7428e18 ┆ -0.000121        │
│ ["600126.SSE"] ┆ 1.7429e18 ┆ -0.000125        │
│ ["600126.SSE"] ┆ 1.7429e18 ┆ -0.000109        │
│ …              ┆ …         ┆ …                │
│ ["300059.SZE"] ┆ 1.7430e18 ┆ -0.000127        │
│ ["300059.SZE"] ┆ 1.7431e18 ┆ -0.000142        │
│ ["300059.SZE"] ┆ 1.7431e18 ┆ -0.000135        │
│ ["300059.SZE"] ┆ 1.7431e18 ┆ -0.000169        │
│ ["300059.SZE"] ┆ 1.7431e18 ┆ -0.000155        │
└────────────────┴───────────┴──────────────────┘
