<a href="https://colab.research.google.com/github/david132313/A_shareStock_data_and_model/blob/main/Update_dataBase.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Cell 0 — Mount Drive

In [None]:
from google.colab import drive
# Mount Google Drive at /content/drive; the database path configured later in
# this notebook lives under this mount point.
drive.mount("/content/drive")


Mounted at /content/drive


# Cell 1 — 安装依赖

In [None]:
!pip -q install tushare tqdm pandas


[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/143.6 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m143.6/143.6 kB[0m [31m4.2 MB/s[0m eta [36m0:00:00[0m
[?25h

# Cell 2 — 配置路径 + TuShare Token

In [None]:
# Cell 2 — Config (env first; fallback to Colab Secrets)

import os

# Master database on Google Drive, and the local working copy used for fast writes.
DB_DRIVE = "/content/drive/MyDrive/AshareDB/db/ashare.sqlite"
DB_LOCAL = "/content/ashare.sqlite"

# Token lookup order: (1) environment variable, (2) Colab Secrets.
TUSHARE_TOKEN = os.getenv("TUSHARE_TOKEN", "")

if not TUSHARE_TOKEN:
    try:
        from google.colab import userdata
        # The name below must match EXACTLY the secret shown in the left "Secrets" panel.
        secret_value = userdata.get("E_TOKEN")
    except Exception:
        # Any error while importing or reading the secret is treated as "no token".
        secret_value = None
    TUSHARE_TOKEN = secret_value if secret_value else ""

# Stop here with setup instructions if neither source produced a token.
if not TUSHARE_TOKEN:
    raise RuntimeError(
        "Missing TuShare token.\n"
        "Option A (env): %env TUSHARE_TOKEN=YOUR_TOKEN\n"
        "Option B (Secrets): set a secret named E_TOKEN (or update the name in this cell)."
    )

# Echo the configuration; the token value itself is never printed.
print("DB_DRIVE:", DB_DRIVE)
print("DB_LOCAL:", DB_LOCAL)
print("TUSHARE_TOKEN loaded ✅ (hidden)")


DB_DRIVE: /content/drive/MyDrive/AshareDB/db/ashare.sqlite
DB_LOCAL: /content/ashare.sqlite
TUSHARE_TOKEN loaded ✅ (hidden)


# Cell 3 — 打开 DB（本地备份 + 建表确保存在）

In [None]:
import shutil, sqlite3
from datetime import datetime

# Copy the Drive database to the local working path.
# Any -wal/-shm side files left next to the local path by an earlier connection
# belong to the OLD local file; pairing them with the freshly copied main file
# could make SQLite replay stale WAL frames, so remove them before copying.
for stale_suffix in ("-wal", "-shm"):
    stale_path = DB_LOCAL + stale_suffix
    if os.path.exists(stale_path):
        os.remove(stale_path)

shutil.copy2(DB_DRIVE, DB_LOCAL)
print("copied DB to local:", DB_LOCAL)

def connect_sqlite(path: str) -> sqlite3.Connection:
    """Open the SQLite database at ``path`` and apply this notebook's PRAGMA settings.

    The connection is created with ``isolation_level=None`` (no implicit
    transactions; callers issue BEGIN/COMMIT themselves) and a 60 second
    connect timeout. The PRAGMA statements listed below are then executed in
    order before the connection is returned.

    Args:
        path: Filesystem path of the database file.

    Returns:
        The configured ``sqlite3.Connection``.
    """
    db = sqlite3.connect(path, timeout=60, isolation_level=None)
    pragma_statements = (
        "PRAGMA journal_mode=WAL;",
        "PRAGMA synchronous=NORMAL;",
        "PRAGMA temp_store=MEMORY;",
        "PRAGMA cache_size=-200000;",
        "PRAGMA busy_timeout=60000;",
    )
    for statement in pragma_statements:
        db.execute(statement)
    return db

# DDL executed with executescript() right after the connection is opened.
# Every statement uses IF NOT EXISTS, so running it against an existing DB is harmless.
#   security_map    - maps each ts_code (unique text) to an auto-incremented integer sec_id.
#   daily_price     - one row per (trade_date, sec_id) holding the daily price/volume fields.
#   ingest_manifest - one row per trade_date recording the source, row count, status,
#                     optional message and load timestamp of the ingest for that date.
# Two indexes on daily_price cover lookups by date and by (sec_id, date).
SCHEMA = """
CREATE TABLE IF NOT EXISTS security_map (
  sec_id  INTEGER PRIMARY KEY AUTOINCREMENT,
  ts_code TEXT NOT NULL UNIQUE
);

CREATE TABLE IF NOT EXISTS daily_price (
  trade_date INTEGER NOT NULL,
  sec_id     INTEGER NOT NULL,
  open       REAL,
  high       REAL,
  low        REAL,
  close      REAL,
  pre_close  REAL,
  change     REAL,
  pct_chg    REAL,
  vol        REAL,
  amount     REAL,
  PRIMARY KEY (trade_date, sec_id)
);

CREATE TABLE IF NOT EXISTS ingest_manifest (
  trade_date   INTEGER PRIMARY KEY,
  parquet_file TEXT NOT NULL,
  rows         INTEGER NOT NULL,
  status       TEXT NOT NULL,
  message      TEXT,
  loaded_at    TEXT NOT NULL
);

CREATE INDEX IF NOT EXISTS idx_daily_date     ON daily_price(trade_date);
CREATE INDEX IF NOT EXISTS idx_daily_sec_date ON daily_price(sec_id, trade_date);
"""

# Open the local working copy and make sure all tables and indexes exist.
conn = connect_sqlite(DB_LOCAL)
conn.executescript(SCHEMA)

# Newest trade_date already stored in daily_price (None when the table is empty).
row = conn.execute("SELECT MAX(trade_date) FROM daily_price;").fetchone()
if row and row[0] is not None:
    last_db_date = int(row[0])
else:
    last_db_date = None
print("DB last trade_date:", last_db_date)


copied DB to local: /content/ashare.sqlite
DB last trade_date: 20251231


# Cell 4 — 用 TuShare 找“缺失交易日列表”

In [None]:
import tushare as ts
import pandas as pd
from datetime import date

# Register the token with tushare, then create the Pro API client (`pro`)
# that the calendar query and the daily-price fetch in later cells call.
ts.set_token(TUSHARE_TOKEN)
pro = ts.pro_api()

def yyyymmdd(i: int) -> str:
    """Return the integer date ``i`` (e.g. 20260105) as its decimal string."""
    return f"{i!s}"

# Work out the query window: start on the day after the newest trade_date in the DB.
if last_db_date is None:
    # Empty DB (unlikely): fall back to a fixed start date; change it by hand if needed.
    start_date = "20100101"
else:
    # Let pandas do the calendar arithmetic (month/year rollover) instead of string math.
    last_day = pd.to_datetime(str(last_db_date), format="%Y%m%d")
    start_date = (last_day + pd.Timedelta(days=1)).strftime("%Y%m%d")

# end_date is today; trade_cal then returns the open days up to and including today.
end_date = pd.Timestamp.today().strftime("%Y%m%d")
print("Query trade_cal range:", start_date, "->", end_date)

cal = pro.trade_cal(exchange="SSE", start_date=start_date, end_date=end_date, is_open="1",
                    fields="cal_date,is_open")

# trade_cal can return the dates newest-first. Ingest oldest-first instead:
# the next run resumes from MAX(trade_date) in daily_price, so if a run stops
# partway through a newest-first list, the older missing days would be skipped
# on every later run.
open_dates = sorted(cal["cal_date"].astype(str))
print("open dates to consider:", len(open_dates), "first:", (open_dates[0] if open_dates else None),
      "last:", (open_dates[-1] if open_dates else None))

# Optional: also retry dates that an earlier run recorded as failed in the manifest
# (they are not picked up automatically once a later date has been stored):
# fail_dates = [str(r[0]) for r in conn.execute("SELECT trade_date FROM ingest_manifest WHERE status='fail' ORDER BY trade_date;").fetchall()]
# open_dates = sorted(set(open_dates + fail_dates))


Query trade_cal range: 20260101 -> 20260110
open dates to consider: 5 first: 20260109 last: 20260105


# Cell 5 — 增量更新入库（自动续跑/重试/一次性更新所有缺失日）

In [None]:
import time
from tqdm import tqdm

# Upsert for one daily_price row. Parameters are bound positionally in the
# column order listed in the INSERT. When a row with the same
# (trade_date, sec_id) already exists, every value column is overwritten with
# the incoming value, so re-ingesting a date replaces its data in place.
UPSERT_DAILY_SQL = """
INSERT INTO daily_price(
  trade_date, sec_id, open, high, low, close, pre_close, change, pct_chg, vol, amount
) VALUES (?,?,?,?,?,?,?,?,?,?,?)
ON CONFLICT(trade_date, sec_id) DO UPDATE SET
  open=excluded.open,
  high=excluded.high,
  low=excluded.low,
  close=excluded.close,
  pre_close=excluded.pre_close,
  change=excluded.change,
  pct_chg=excluded.pct_chg,
  vol=excluded.vol,
  amount=excluded.amount;
"""

# In-memory ts_code -> sec_id lookup, preloaded with every row of security_map.
# Each result row is a (ts_code, sec_id) pair, so it can feed dict() directly.
cache_ts2id = dict(conn.execute("SELECT ts_code, sec_id FROM security_map;"))

def ensure_sec_ids_cached(conn, codes):
    """Make sure every code in ``codes`` has a sec_id in the module-level cache.

    Codes missing from ``cache_ts2id`` are inserted into ``security_map``
    (``INSERT OR IGNORE``, so codes already in the table are left alone) and
    their ids are then read back into the cache. New codes are de-duplicated
    and inserted in sorted order so that sec_id assignment does not depend on
    set iteration order.

    Args:
        conn: Open SQLite connection; the caller owns transaction handling.
        codes: Iterable of ts_code strings.

    Returns:
        None. ``cache_ts2id`` is updated in place.
    """
    missing = sorted({code for code in codes if code not in cache_ts2id})
    if not missing:
        return

    conn.executemany("INSERT OR IGNORE INTO security_map(ts_code) VALUES (?);",
                     [(code,) for code in missing])

    # Read the ids back in chunks of at most 900 bound parameters per query
    # (presumably to stay below SQLite's bound-parameter limit).
    for start in range(0, len(missing), 900):
        batch = missing[start:start + 900]
        placeholders = ",".join(["?"] * len(batch))
        query = f"SELECT ts_code, sec_id FROM security_map WHERE ts_code IN ({placeholders});"
        # Each row is a (ts_code, sec_id) pair, which dict.update accepts directly.
        cache_ts2id.update(conn.execute(query, batch))

def already_ok(trade_date_int: int) -> bool:
    """Return True when ingest_manifest has a row for this date with status 'ok'.

    Args:
        trade_date_int: Trade date as an integer in YYYYMMDD form.

    Returns:
        False when there is no manifest row for the date or its status is not "ok".
    """
    cursor = conn.execute("SELECT status FROM ingest_manifest WHERE trade_date=?;", (trade_date_int,))
    manifest_row = cursor.fetchone()
    if manifest_row is None:
        return False
    return manifest_row[0] == "ok"

def fetch_daily_all(trade_date: str, max_retry=5, sleep_base=1.0) -> pd.DataFrame:
    """Fetch the whole-market daily bars for one trade date from TuShare, with retries.

    Calls ``pro.daily(trade_date=YYYYMMDD)``. On an exception it sleeps
    ``sleep_base * 2**attempt`` seconds and tries again; the exception from the
    final attempt is re-raised.

    Args:
        trade_date: Trade date as a YYYYMMDD string.
        max_retry: Maximum number of attempts. Values below 1 are treated as 1
            so the function never falls through and returns None implicitly.
        sleep_base: Base delay in seconds for the exponential backoff.

    Returns:
        Whatever ``pro.daily`` returns for that date.

    Raises:
        Exception: whatever the last failed ``pro.daily`` call raised.
    """
    attempts = max(1, max_retry)
    for attempt in range(attempts):
        try:
            return pro.daily(trade_date=trade_date,
                             fields="ts_code,trade_date,open,high,low,close,pre_close,change,pct_chg,vol,amount")
        except Exception:
            if attempt == attempts - 1:
                raise
            time.sleep(sleep_base * (2 ** attempt))

# One statement shared by every ingest_manifest write (success and failure alike).
MANIFEST_UPSERT_SQL = """
INSERT OR REPLACE INTO ingest_manifest(trade_date, parquet_file, rows, status, message, loaded_at)
VALUES (?,?,?,?,?,?)
"""

# Column order must match the placeholders in UPSERT_DAILY_SQL.
DAILY_COLUMNS = ["trade_date", "sec_id", "open", "high", "low", "close",
                 "pre_close", "change", "pct_chg", "vol", "amount"]


def _write_manifest(trade_date_int, date_str, n_rows, status, message):
    """Insert or replace the manifest row for one date (no transaction handling here)."""
    conn.execute(MANIFEST_UPSERT_SQL,
                 (trade_date_int, f"tushare:{date_str}", n_rows, status, message,
                  datetime.now().isoformat(timespec="seconds")))


def _rollback_and_resync_cache():
    """Roll back an open transaction, if any, and reload cache_ts2id from the table."""
    if conn.in_transaction:
        conn.execute("ROLLBACK;")
        # ensure_sec_ids_cached may have cached sec_ids whose INSERTs were just
        # rolled back; keeping them would let later dates write rows with ids
        # that do not exist (or get reassigned to a different ts_code).
        cache_ts2id.clear()
        cache_ts2id.update(conn.execute("SELECT ts_code, sec_id FROM security_map;"))


def _record_failure(trade_date_int, date_str, message):
    """Discard partial work for the date and persist a 'fail' manifest row."""
    _rollback_and_resync_cache()
    conn.execute("BEGIN;")
    _write_manifest(trade_date_int, date_str, 0, "fail", message)
    conn.execute("COMMIT;")


ok = fail = skip = 0

for d in tqdm(open_dates, desc="update daily"):
    td = int(d)
    if already_ok(td):
        skip += 1
        continue

    try:
        df = fetch_daily_all(d)
        if df is None or df.empty:
            # TuShare sometimes has no data yet for the newest date; record it and move on.
            _record_failure(td, d, "empty daily from tushare (maybe not available yet)")
            fail += 1
            continue

        # Normalise types before writing.
        df["trade_date"] = df["trade_date"].astype(str).astype(int)
        codes = df["ts_code"].astype(str).unique().tolist()

        # Mapping rows, price rows and the manifest row are committed together.
        conn.execute("BEGIN;")
        ensure_sec_ids_cached(conn, codes)
        df["sec_id"] = df["ts_code"].astype(str).map(cache_ts2id).astype(int)

        daily_rows = df[DAILY_COLUMNS]
        # Cast to object first so missing values really become None (SQL NULL);
        # recent pandas versions keep NaN when where(..., None) is applied to float columns.
        daily_rows = daily_rows.astype(object).where(daily_rows.notna(), None)

        conn.executemany(UPSERT_DAILY_SQL, daily_rows.itertuples(index=False, name=None))
        _write_manifest(td, d, int(len(daily_rows)), "ok", None)
        conn.execute("COMMIT;")
        ok += 1

    except Exception as e:
        # Undo the partial write, then record why this date failed and keep going.
        _record_failure(td, d, repr(e))
        fail += 1
    except BaseException:
        # e.g. KeyboardInterrupt: do not leave a transaction open on the shared
        # connection, otherwise the next BEGIN on a re-run of this cell fails.
        _rollback_and_resync_cache()
        raise

print("update done | ok:", ok, "fail:", fail, "skip:", skip, "| mapping cache:", len(cache_ts2id))


update daily: 100%|██████████| 5/5 [00:11<00:00,  2.39s/it]

update done | ok: 5 fail: 0 skip: 0 | mapping cache: 5790





# Cell 6 — 一次性写回 Drive（生成“干净主库”，避免 -wal/-shm 残留）

In [None]:
import sqlite3, os

def backup_local_to_drive(local_conn: sqlite3.Connection, drive_path: str):
    """Write a full copy of ``local_conn``'s database to ``drive_path``.

    The copy is made with the SQLite backup API into ``drive_path + ".tmp"``
    and then moved over ``drive_path`` with ``os.replace``, so the existing
    file at ``drive_path`` is only replaced after the backup has finished.

    A leftover ``.tmp`` file from an interrupted earlier run is deleted first
    so the backup always starts from a fresh file, and the ``.tmp`` file is
    removed again if the backup or the final rename fails.

    Args:
        local_conn: Open connection to the source (local) database.
        drive_path: Destination path of the master database file.

    Raises:
        Whatever ``Connection.backup``, ``commit`` or ``os.replace`` raise;
        the file at ``drive_path`` is left as it was in that case.
    """
    tmp = drive_path + ".tmp"
    if os.path.exists(tmp):
        os.remove(tmp)

    try:
        dst = sqlite3.connect(tmp)
        try:
            local_conn.backup(dst)
            dst.commit()
        finally:
            dst.close()
        os.replace(tmp, drive_path)
    except BaseException:
        # Do not leave a partial copy behind for the next run to trip over.
        if os.path.exists(tmp):
            os.remove(tmp)
        raise

# Write the local working database back to Drive in one shot.
backup_local_to_drive(conn, DB_DRIVE)

# Report the size of the file now on Drive.
drive_size_mb = os.path.getsize(DB_DRIVE) / 1024 / 1024
print("saved to Drive:", DB_DRIVE, "size(MB)=", drive_size_mb)

# Note: this check runs on the local connection, not on the Drive copy.
integrity_result = conn.execute("PRAGMA integrity_check;").fetchone()[0]
print("integrity_check:", integrity_result)


saved to Drive: /content/drive/MyDrive/AshareDB/db/ashare.sqlite size(MB)= 2298.1640625
integrity_check: ok


# Cell 7 — 验证更新结果（日期范围/失败日）

In [None]:
import pandas as pd

# Verification queries, printed in this order:
#   1. overall date range and row count of daily_price
#   2. number of manifest rows per status
#   3. the 20 most recent failed dates with their source and message
verification_queries = (
    """
SELECT MIN(trade_date) AS min_date, MAX(trade_date) AS max_date, COUNT(*) AS total_rows
FROM daily_price;
""",
    """
SELECT status, COUNT(*) AS n
FROM ingest_manifest
GROUP BY status;
""",
    """
SELECT trade_date, parquet_file, message
FROM ingest_manifest
WHERE status='fail'
ORDER BY trade_date DESC
LIMIT 20;
""",
)

for verification_sql in verification_queries:
    print(pd.read_sql_query(verification_sql, conn))

# Last cell of the workflow: release the local database connection.
conn.close()


   min_date  max_date  total_rows
0  20000104  20260109    16501185
  status     n
0     ok  6306
Empty DataFrame
Columns: [trade_date, parquet_file, message]
Index: []


#