In [None]:
# Mount Google Drive so the notebook can reach the project files (Colab-only API).
from google.colab import drive
drive.mount('/content/drive')

# Switch the working directory to the project folder; later cells read the CSVs by bare file name.
%cd /content/drive/MyDrive/Capstone

Mounted at /content/drive
/content/drive/MyDrive/Capstone


In [None]:
import pandas as pd

# 1. Load the 401k transaction table (timestamps parsed on read) and show its schema.
tx401k = pd.read_csv(
    "401k_transactions.csv",
    parse_dates=["transactionTimestamp"],
)
print(list(tx401k.columns))

# Load the 401k account table and show its schema.
acct401k = pd.read_csv("401k_accounts.csv")
print(list(acct401k.columns))


['accountId', 'transactionId', 'transactionTimestamp', 'description', 'amount', 'shares', 'securityId', 'securityIdType', 'symbol', 'type']
['accountId', 'accountNumberDisplay', 'accountType', 'currency', 'description', 'nickname', 'productName', 'status']


In [None]:
# Merge: attach the account attributes to every transaction (left join keeps all transactions).
merged = pd.merge(tx401k, acct401k, how="left", on="accountId")

# Sanity checks: row count after the join, transactions without a matching account,
# and the number of distinct accounts on each side.
print("合并后总行数:", len(merged))
print("accountType 缺失数:", merged["accountType"].isnull().sum())
print("accounts 中一共有多少唯一 accountId:", acct401k["accountId"].nunique())
print("tx 中一共有多少唯一 accountId:", tx401k["accountId"].nunique())

合并后总行数: 3000
accountType 缺失数: 0
accounts 中一共有多少唯一 accountId: 100
tx 中一共有多少唯一 accountId: 100


In [None]:
# Group once by account type and reuse the grouping for both summaries.
by_account_type = merged.groupby("accountType")

# Transaction volume per account type.
print(by_account_type["transactionId"].count())

# Amount distribution per account type.
print(by_account_type["amount"].describe())

accountType
401K    3000
Name: transactionId, dtype: int64
              count        mean         std  min     25%     50%     75%  \
accountType                                                                
401K         3000.0  391.864213  331.526023  0.0  47.465  349.44  683.45   

                max  
accountType          
401K         999.86  


In [None]:
# Aggregation example: one row of summary features per account.
account_aggs = {
    "tx_count": pd.NamedAgg(column="transactionId", aggfunc="count"),
    "tx_sum_amount": pd.NamedAgg(column="amount", aggfunc="sum"),
    "tx_mean_amount": pd.NamedAgg(column="amount", aggfunc="mean"),
    "tx_std_amount": pd.NamedAgg(column="amount", aggfunc="std"),
    "unique_symbols": pd.NamedAgg(column="symbol", aggfunc="nunique"),
    "shares_sum": pd.NamedAgg(column="shares", aggfunc="sum"),
    "shares_mean": pd.NamedAgg(column="shares", aggfunc="mean"),
}
feat = merged.groupby("accountId").agg(**account_aggs).reset_index()

# Bring the account attributes back onto the aggregated rows.
account_attrs = acct401k[["accountId", "accountType", "currency", "status"]]
feat = feat.merge(account_attrs, how="left", on="accountId")
print(feat)

     accountId  tx_count  tx_sum_amount  tx_mean_amount  tx_std_amount  \
0   1859000000        30       10635.98      354.532667     336.990197   
1   1859000001        30       11007.79      366.926333     268.871439   
2   1859000002        30       11063.40      368.780000     348.811723   
3   1859000003        30       10650.32      355.010667     328.036633   
4   1859000004        30       14799.57      493.319000     317.129980   
..         ...       ...            ...             ...            ...   
95  1859000095        30       12462.81      415.427000     345.942591   
96  1859000096        30       14887.00      496.233333     339.271172   
97  1859000097        30       11614.11      387.137000     298.031822   
98  1859000098        30       11005.98      366.866000     336.047640   
99  1859000099        30       11856.14      395.204667     346.323019   

    unique_symbols  shares_sum  shares_mean accountType  \
0                1      927.66    30.922000        4

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.ensemble import IsolationForest
from sklearn.neighbors import LocalOutlierFactor

# —— 1. Feature matrix & scaling ——
# Missing values (e.g. an undefined std) become 0, then log1p compresses the heavy right tail.
num_cols = [
    "tx_count",
    "tx_sum_amount",
    "tx_mean_amount",
    "tx_std_amount",
    "unique_symbols",
    "shares_sum",
    "shares_mean",
]
X = feat[num_cols].fillna(0)
X_log = np.log1p(X)

# —— 2. IsolationForest ——
# contamination was raised to 2% to see whether more accounts get caught.
iso = IsolationForest(n_estimators=200, contamination=0.02, random_state=42)
feat["iso_label"] = iso.fit_predict(X_log)           # -1: anomaly, 1: normal
feat["iso_score"] = iso.decision_function(X_log)     # lower = more anomalous

print("IF 异常账号数：", feat["iso_label"].eq(-1).sum())

# —— 3. Local Outlier Factor ——
lof = LocalOutlierFactor(n_neighbors=20, contamination=0.02)
# fit_predict yields -1 / 1 labels directly
feat["lof_label"] = lof.fit_predict(X_log)
# sign flipped so that a larger lof_score means more anomalous
feat["lof_score"] = -lof.negative_outlier_factor_

print("LOF 异常账号数：", feat["lof_label"].eq(-1).sum())

# —— 4. Fusion: intersection candidates ——
# Only accounts flagged by BOTH IF and LOF count as high-confidence candidates.
flagged_by_both = feat["iso_label"].eq(-1) & feat["lof_label"].eq(-1)
feat["hybrid_anomaly"] = flagged_by_both.astype(int)
print("高置信度异常账号数：", feat["hybrid_anomaly"].sum())

# —— 5. 单笔交易 z‐score ——
def safe_z(x):
    """Z-score a Series against its own mean and sample standard deviation.

    Parameters
    ----------
    x : pandas.Series
        Values of one group (here: the transaction amounts of one account).

    Returns
    -------
    pandas.Series
        ``(x - mean) / std`` on the same index as ``x``. When the standard
        deviation is zero (all values equal) or undefined (fewer than two
        values, where pandas returns NaN), a float Series of zeros is
        returned so that such groups never produce NaN or infinite scores.
    """
    sd = x.std()
    # NaN std (single-row or empty group) must be caught explicitly: `NaN == 0` is False.
    if pd.isna(sd) or sd == 0:
        return pd.Series(0.0, index=x.index)
    return (x - x.mean()) / sd

# Z-score every transaction amount against its own account's mean / std.
merged["amt_zscore"] = merged.groupby("accountId")["amount"].transform(safe_z)

# A transaction is "extreme" when it lies more than 2.5 standard deviations from its account mean.
merged["amt_outlier"] = merged["amt_zscore"].abs() > 2.5

# 2) Look at the extreme transactions and the accounts they belong to.
extremes = merged.loc[merged["amt_outlier"],
                      ["accountId","transactionId","amount","amt_zscore"]]
print("All extreme transactions:\n", extremes)

# 3) Check whether the hybrid (IF ∩ LOF) account has "extreme" transactions under a looser threshold.
#    The intersection can be empty, so guard before taking the first id
#    (an unguarded `.iloc[0]` raises IndexError in that case).
hybrid_ids = feat.loc[feat["hybrid_anomaly"]==1, "accountId"]
if hybrid_ids.empty:
    print("\nNo account was flagged by both IF and LOF; skipping the 2σ check.")
else:
    anom_id = hybrid_ids.iloc[0]
    tx_for_anom = merged[merged["accountId"]==anom_id]
    # try a looser threshold, e.g. 2σ
    anom_hits = tx_for_anom[tx_for_anom["amt_zscore"].abs() > 2.0]
    print(f"\nAccount {anom_id} | 2σ outliers:", anom_hits)
print("\n极端交易笔数：", merged["amt_outlier"].sum())

# —— 6. Extreme transactions of the high-confidence anomalous accounts ——
anom_ids = feat.loc[feat["hybrid_anomaly"] == 1, "accountId"]
tx_anom = merged[merged["accountId"].isin(anom_ids)]
print("\n", tx_anom[tx_anom["amt_outlier"]].sort_values("amt_zscore", ascending=False).head(10))


IF 异常账号数： 2
LOF 异常账号数： 2
高置信度异常账号数： 1
All extreme transactions:
        accountId  transactionId  amount  amt_zscore
1460  1859000048     6789900671  987.50    2.714656
1953  1859000065     1693329722  999.86    2.516928

Account 1859000031 | 2σ outliers: Empty DataFrame
Columns: [accountId, transactionId, transactionTimestamp, description_x, amount, shares, securityId, securityIdType, symbol, type, accountNumberDisplay, accountType, currency, description_y, nickname, productName, status, amt_zscore, amt_outlier]
Index: []

极端交易笔数： 2

 Empty DataFrame
Columns: [accountId, transactionId, transactionTimestamp, description_x, amount, shares, securityId, securityIdType, symbol, type, accountNumberDisplay, accountType, currency, description_y, nickname, productName, status, amt_zscore, amt_outlier]
Index: []


In [None]:
# Accounts that own at least one extreme transaction.
extreme_ids = extremes["accountId"].unique().tolist()
# Flag those accounts on the feature table.
feat["extreme_flag"] = feat["accountId"].isin(extreme_ids)

# Union of all anomaly signals: IF, LOF, or an extreme transaction.
flagged_by_if = feat["iso_label"].eq(-1)
flagged_by_lof = feat["lof_label"].eq(-1)
feat["final_anomaly"] = flagged_by_if | flagged_by_lof | feat["extreme_flag"]
print("最终待审查账号数：", feat["final_anomaly"].sum())

# 1) Keep only the accounts marked for review.
anom_feat = feat.loc[feat["final_anomaly"]].copy()

# 2) Print the aggregated features of every flagged account.
print("异常账户详情：")
print(anom_feat)

最终待审查账号数： 5
异常账户详情：
     accountId  tx_count  tx_sum_amount  tx_mean_amount  tx_std_amount  \
20  1859000020        30       14011.75      467.058333     356.283951   
31  1859000031        30       10603.42      353.447333     373.015074   
48  1859000048        30        8387.55      279.585000     260.775198   
65  1859000065        30        8229.84      274.328000     288.260895   
70  1859000070        30       11290.99      376.366333     302.121632   

    unique_symbols  shares_sum  shares_mean accountType  \
20               1     1096.64    36.554667        401K   
31               1     1285.44    42.848000        401K   
48               1      594.37    19.812333        401K   
65               1      727.91    24.263667        401K   
70               1      547.54    18.251333        401K   

                   currency status  iso_label  iso_score  lof_label  \
20  {'currencyCode': 'USD'}   OPEN         -1  -0.005157          1   
31  {'currencyCode': 'USD'}   OPEN    

人工复核：查看它们的交易明细和描述，确认是否真是风险操作。

打标签：给真异常打上 1，给合理业务打上 0，积累弱标签。

迭代特征：根据人工检查发现的模式，增加新特征（如周末大额、发薪日前后交易等），再跑一轮模型。

## 为每个账户构建时序序列

In [None]:
from sklearn.preprocessing import MinMaxScaler

def build_amount_sequences(merged_df, pad_value=0.0):
    """Build one chronologically ordered, min-max scaled amount sequence per account.

    Parameters
    ----------
    merged_df : pandas.DataFrame
        Must contain ``accountId``, ``transactionTimestamp`` and ``amount``.
    pad_value : float, default 0.0
        Value written (after scaling) into the trailing positions of accounts
        that have fewer transactions than the longest account. Accounts of
        equal length are never padded, so the result for uniform data is
        unaffected by this argument.

    Returns
    -------
    ids : list
        Account ids, in group order; ``ids[i]`` belongs to ``X[i]``.
    X : numpy.ndarray
        Float array of shape ``(n_accounts, timesteps, 1)`` with amounts scaled
        to ``[0, 1]`` using the global min and max over all transactions.
        An empty input yields ``([], array of shape (0, 0, 1))``.
    """
    ids = []
    seqs = []
    for aid, df in merged_df.groupby("accountId"):
        ordered = df.sort_values("transactionTimestamp")
        seqs.append(ordered["amount"].to_numpy(dtype=float))
        ids.append(aid)

    if not seqs:
        return ids, np.empty((0, 0, 1))

    # Global min-max over every real transaction (NaNs ignored), computed before
    # padding so that pad cells cannot distort the scale.
    all_amounts = np.concatenate(seqs)
    lo = np.nanmin(all_amounts)
    span = np.nanmax(all_amounts) - lo
    if not span > 0:
        # constant (or all-NaN) data: avoid dividing by zero, everything maps to 0
        span = 1.0

    # Stacking sequences of unequal length directly would give a ragged array
    # with no second axis, so fill a fixed-width matrix row by row instead.
    timesteps = max(len(s) for s in seqs)
    X = np.full((len(seqs), timesteps), float(pad_value))
    for row, s in enumerate(seqs):
        X[row, :len(s)] = (s - lo) / span

    # 3-D layout [samples, window length, feature dim]; amount is the only feature.
    return ids, X.reshape(len(seqs), timesteps, 1)

In [None]:
# Build the scaled per-account amount sequences; ids_401k[i] is the accountId of row i in X_401k.
ids_401k, X_401k = build_amount_sequences(merged)

In [None]:
! pip install pyod

Collecting pyod
  Downloading pyod-2.0.5-py3-none-any.whl.metadata (46 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/46.3 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m46.3/46.3 kB[0m [31m2.1 MB/s[0m eta [36m0:00:00[0m
Downloading pyod-2.0.5-py3-none-any.whl (200 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m200.6/200.6 kB[0m [31m7.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: pyod
Successfully installed pyod-2.0.5


In [None]:
# 1. Drop the trailing singleton feature axis:
#    (n_accounts, timesteps, 1) -> (n_accounts, timesteps)
X_seq = X_401k.reshape(X_401k.shape[:2])

# 2. Standardise (optional, but recommended)
from sklearn.preprocessing import StandardScaler
scaler = StandardScaler()
X_seq_scaled = scaler.fit(X_seq).transform(X_seq)

In [None]:
# 3. Import and train the PyOD AutoEncoder
from pyod.models.auto_encoder import AutoEncoder

ae_params = dict(
    # training setup
    contamination=0.02,            # 2% of the rows get labelled as anomalies
    preprocessing=True,            # per the original note: applies a StandardScaler again internally
    lr=0.001,                      # learning rate
    epoch_num=50,                  # training epochs
    batch_size=16,                 # mini-batch size
    optimizer_name='adam',
    random_state=42,
    verbose=1,
    # network architecture
    hidden_neuron_list=[64, 32, 32, 64],
    hidden_activation_name='relu',
    batch_norm=True,
    dropout_rate=0.2,
)
ae = AutoEncoder(**ae_params)

# Train
ae.fit(X_seq_scaled)

# Labels and scores keyed by account id (rows of X_seq_scaled follow ids_401k)
ae_labels = dict(zip(ids_401k, ae.labels_))             # 0 normal, 1 anomaly
ae_scores = dict(zip(ids_401k, ae.decision_scores_))    # higher = more anomalous

# Map back onto feat; accounts without a sequence fall back to 0
account_ids = feat["accountId"]
feat["ae_label"] = account_ids.map(ae_labels).fillna(0).astype(int)
feat["ae_score"] = account_ids.map(ae_scores).fillna(0)

Training: 100%|██████████| 50/50 [00:02<00:00, 22.80it/s]


In [None]:
# Refresh final_anomaly: an account is queued for review if ANY detector flags it.
flag_if = feat["iso_label"].eq(-1)        # IsolationForest
flag_lof = feat["lof_label"].eq(-1)       # LOF
flag_extreme = feat["extreme_flag"]       # single extreme transaction
flag_ae = feat["ae_label"].eq(1)          # deep AutoEncoder

feat["final_anomaly"] = flag_if | flag_lof | flag_extreme | flag_ae
print("401k → 最终待审查账号数：", feat["final_anomaly"].sum())

401k → 最终待审查账号数： 7


In [None]:
# 1) List the columns available on the merged frame
print(list(merged.columns))

['accountId', 'transactionId', 'transactionTimestamp', 'description_x', 'amount', 'shares', 'securityId', 'securityIdType', 'symbol', 'type', 'accountNumberDisplay', 'accountType', 'currency', 'description_y', 'nickname', 'productName', 'status', 'amt_zscore', 'amt_outlier']


In [None]:
# 1. Collect the ids of every flagged account
anom_ids = feat["accountId"][feat["final_anomaly"]].tolist()
print("待审查账号：", anom_ids)

# 2. Their aggregated features together with the AE / IF / LOF scores
is_flagged_account = feat["accountId"].isin(anom_ids)
anom_feat = feat[is_flagged_account]
display(anom_feat)

# 3. All raw transactions belonging to those accounts
anom_tx = merged.loc[merged["accountId"].isin(anom_ids)]

# 4. Key fields only, ordered by accountId (ascending) then amount (descending)
#    to make manual review easier
cols = ["accountId", "transactionTimestamp", "transactionId", "amount", "type"]
review_view = anom_tx[cols].sort_values(
    by=["accountId", "amount"],
    ascending=[True, False],
)
display(review_view)


待审查账号： [1859000020, 1859000031, 1859000048, 1859000065, 1859000066, 1859000070, 1859000088]


Unnamed: 0,accountId,tx_count,tx_sum_amount,tx_mean_amount,tx_std_amount,unique_symbols,shares_sum,shares_mean,accountType,currency,status,iso_label,iso_score,lof_label,lof_score,hybrid_anomaly,extreme_flag,final_anomaly,ae_label,ae_score
20,1859000020,30,14011.75,467.058333,356.283951,1,1096.64,36.554667,401K,{'currencyCode': 'USD'},OPEN,-1,-0.005157,1,1.360421,0,False,True,0,5.830788
31,1859000031,30,10603.42,353.447333,373.015074,1,1285.44,42.848,401K,{'currencyCode': 'USD'},OPEN,-1,-0.068932,-1,2.077149,1,False,True,0,5.922033
48,1859000048,30,8387.55,279.585,260.775198,1,594.37,19.812333,401K,{'currencyCode': 'USD'},OPEN,1,0.000862,1,1.516537,0,True,True,0,4.753519
65,1859000065,30,8229.84,274.328,288.260895,1,727.91,24.263667,401K,{'currencyCode': 'USD'},OPEN,1,0.081,1,1.308412,0,True,True,0,4.493028
66,1859000066,30,16541.27,551.375667,376.56126,1,885.51,29.517,401K,{'currencyCode': 'USD'},OPEN,1,0.01236,1,1.351026,0,False,True,1,6.304177
70,1859000070,30,11290.99,376.366333,302.121632,1,547.54,18.251333,401K,{'currencyCode': 'USD'},OPEN,1,0.014878,-1,1.6339,0,False,True,0,4.897698
88,1859000088,30,14030.39,467.679667,387.400985,1,803.96,26.798667,401K,{'currencyCode': 'USD'},OPEN,1,0.099994,1,1.099072,0,False,True,1,6.291749


Unnamed: 0,accountId,transactionTimestamp,transactionId,amount,type
601,1859000020,2023-05-28,8820038015,978.04,CONTRIBUTION
620,1859000020,2020-09-20,1308085105,909.66,CONTRIBUTION
613,1859000020,2021-01-14,4157955812,909.21,EMPLOYER_MATCH
609,1859000020,2021-08-24,3350917614,906.10,CONTRIBUTION
622,1859000020,2022-02-21,3082913303,890.22,EMPLOYER_MATCH
...,...,...,...,...,...
2642,1859000088,2020-06-28,4950728840,0.00,REALLOCATION
2644,1859000088,2020-06-21,2702783627,0.00,REALLOCATION
2649,1859000088,2020-08-21,7228893896,0.00,REALLOCATION
2650,1859000088,2020-01-09,6940154613,0.00,REALLOCATION


## DEMO


In [None]:
# === 401k anomaly review mini-demo (ipywidgets) ===
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import ipywidgets as w
from IPython.display import display, clear_output
from sklearn.ensemble import IsolationForest
from sklearn.neighbors import LocalOutlierFactor
from sklearn.preprocessing import StandardScaler

# Default figure size for the demo's per-account time-series plot (wide and short).
plt.rcParams["figure.figsize"] = (8, 2.8)

# ---------- Config (aligned with the main pipeline) ----------
# Contamination rates, i.e. the share of accounts each detector is asked to flag.
IF_CONT = 0.02
LOF_CONT = 0.02
AE_CONT  = 0.02   # the AutoEncoder was run at 2% earlier in this notebook, so keep 2%
RANDOM_SEED = 42

# ---------- Helpers ----------

from sklearn.ensemble import IsolationForest
from sklearn.neighbors import LocalOutlierFactor
from sklearn.preprocessing import StandardScaler
import numpy as np
import pandas as pd

def full_detect_on_uploaded_401k(merged_df, use_ae=True, contamination=0.02, random_state=42):
    """
    Run the full account-level anomaly pipeline on a merged transactions/accounts frame.

    Pipeline: per-account aggregation -> log1p -> IsolationForest / LOF ->
    per-transaction z-score flag -> (optional) PyOD AutoEncoder on amount
    sequences -> OR-combined ``final_anomaly``.

    Parameters
    ----------
    merged_df : pandas.DataFrame
        Needs ``accountId`` and ``amount``. ``transactionId``, ``symbol`` and
        ``shares`` are used when present; otherwise stand-in columns fill the
        same feature slots. A timestamp column (``transactionTimestamp``,
        ``timestamp`` or ``time``) is used to order the AutoEncoder sequences
        when available; without one the input row order is kept.
    use_ae : bool, default True
        Whether to train the AutoEncoder. When False, ``ae_label`` is 0 and
        ``ae_score`` is 0.0 for every account.
    contamination : float, default 0.02
        Expected share of anomalies, passed to IsolationForest, LOF and the
        AutoEncoder.
    random_state : int, default 42
        Seed passed to IsolationForest and the AutoEncoder.

    Returns
    -------
    pandas.DataFrame
        One row per account with the aggregated features, ``extreme_flag``,
        ``iso_label``/``iso_score``, ``lof_label``/``lof_score``,
        ``ae_label``/``ae_score`` and the boolean ``final_anomaly``.
    """
    # 1) Account-level features (kept as close as possible to the in-memory pipeline)
    feat = merged_df.groupby("accountId").agg(
        tx_count=("transactionId", "count") if "transactionId" in merged_df.columns else ("amount","count"),
        tx_sum_amount=("amount","sum"),
        tx_mean_amount=("amount","mean"),
        tx_std_amount=("amount","std"),
        unique_symbols=("symbol","nunique") if "symbol" in merged_df.columns else ("accountId","nunique"),
        shares_sum = ("shares","sum")  if "shares" in merged_df.columns else ("amount","sum"),
        shares_mean= ("shares","mean") if "shares" in merged_df.columns else ("amount","mean"),
    ).reset_index()

    # 2) Per-transaction amount z-score -> accounts owning an extreme transaction
    def safe_z(x):
        s = x.std()
        # `s > 0` is False for NaN as well, so single-row groups get zeros
        return (x - x.mean())/s if (s is not None and s>0) else pd.Series(0, index=x.index)

    tmp = merged_df.copy()
    tmp["amt_z"] = tmp.groupby("accountId")["amount"].transform(safe_z)
    tmp["amt_outlier"] = tmp["amt_z"].abs() > 2.5
    extreme_ids = tmp.loc[tmp["amt_outlier"], "accountId"].unique().tolist()
    feat["extreme_flag"] = feat["accountId"].isin(extreme_ids)

    # 3) IF / LOF on log1p-transformed features
    num_cols = ["tx_count","tx_sum_amount","tx_mean_amount","tx_std_amount",
                "unique_symbols","shares_sum","shares_mean"]
    X = feat[num_cols].fillna(0)
    X_log = np.log1p(X)  # assumes non-negative inputs (401k amounts), as in the in-memory pipeline

    # The caller-supplied contamination / seed are used here; previously the
    # arguments were accepted but never read.
    iso = IsolationForest(contamination=contamination,
                          random_state=random_state,
                          n_estimators=200,
                          max_samples='auto')

    feat["iso_label"] = iso.fit_predict(X_log)
    feat["iso_score"] = iso.decision_function(X_log)

    lof = LocalOutlierFactor(n_neighbors=20, contamination=contamination)
    feat["lof_label"] = lof.fit_predict(X_log)
    feat["lof_score"]  = -lof.negative_outlier_factor_

    # 4) (optional) AutoEncoder on each account's chronological amount sequence
    if use_ae:
        # imported lazily so the function still works without PyOD when use_ae=False
        from pyod.models.auto_encoder import AutoEncoder

        # Uploaded files may name the timestamp column differently; fall back
        # to the existing row order when none of the known names is present.
        ts_col = _safe_col(merged_df, ["transactionTimestamp", "timestamp", "time"])
        sort_keys = ["accountId", ts_col] if ts_col else ["accountId"]
        df_sorted = merged_df.sort_values(sort_keys, kind="stable")

        ids = []
        seqs = []
        for aid, df in df_sorted.groupby("accountId"):
            ids.append(aid)
            seqs.append(df["amount"].to_numpy(dtype=float))

        # Accounts can have different numbers of transactions; zero-pad the
        # shorter sequences at the end so the matrix is rectangular.
        n_steps = max(len(s) for s in seqs)
        X_seq = np.zeros((len(seqs), n_steps))             # (n_accounts, timesteps)
        for row, s in enumerate(seqs):
            X_seq[row, :len(s)] = s

        scaler = StandardScaler()
        X_seq_scaled = scaler.fit_transform(X_seq)

        # Same PyOD AutoEncoder settings as the in-memory pipeline
        ae = AutoEncoder(
            contamination=contamination,
            preprocessing=True,
            lr=0.001,
            epoch_num=50,
            batch_size=16,
            optimizer_name='adam',
            random_state=random_state,
            verbose=0,
            hidden_neuron_list=[64,32,32,64],
            hidden_activation_name='relu',
            batch_norm=True,
            dropout_rate=0.2
        )
        ae.fit(X_seq_scaled)
        ae_labels = dict(zip(ids, ae.labels_))
        ae_scores = dict(zip(ids, ae.decision_scores_))
        feat["ae_label"] = feat["accountId"].map(ae_labels).fillna(0).astype(int)
        feat["ae_score"] = feat["accountId"].map(ae_scores).fillna(0)
    else:
        feat["ae_label"] = 0
        feat["ae_score"] = 0.0

    # 5) Final flag: OR over every detector
    feat["final_anomaly"] = (
        (feat["iso_label"] == -1) |
        (feat["lof_label"] == -1) |
        (feat["extreme_flag"])    |
        (feat["ae_label"] == 1)
    )
    return feat

def _safe_col(df, candidates, default=None):
    """Pick the first name in ``candidates`` that is a column of ``df``.

    Candidates are tried in the given order; ``default`` is returned when
    none of them is present.
    """
    return next((name for name in candidates if name in df.columns), default)

def _ensure_final_anomaly(feat_df):
    """Guarantee a ``final_anomaly`` column, deriving it from existing flags if absent.

    If the column already exists the input frame is returned untouched.
    Otherwise every available detector column is turned into a boolean flag
    (``iso_label``/``lof_label`` equal to -1, ``extreme_flag`` cast to bool,
    ``lstm_label``/``ae_label`` equal to 1) and the flags are OR-ed together
    on a copy of the frame.

    Raises
    ------
    ValueError
        When none of the detector columns is present.
    """
    if "final_anomaly" in feat_df.columns:
        return feat_df

    # (column name, rule turning that column into a boolean anomaly flag)
    flag_rules = (
        ("iso_label", lambda col: col == -1),
        ("lof_label", lambda col: col == -1),
        ("extreme_flag", lambda col: col.astype(bool)),
        ("lstm_label", lambda col: col == 1),
        ("ae_label", lambda col: col == 1),
    )
    flags = [rule(feat_df[name]) for name, rule in flag_rules if name in feat_df.columns]

    if not flags:
        raise ValueError("No anomaly flags found on feat. Please run detection steps first.")

    result = feat_df.copy()
    result["final_anomaly"] = np.logical_or.reduce(flags)
    return result

def _merge_uploaded(tx_df, acct_df):
    """Left-join uploaded transactions with accounts on ``accountId``.

    Parameters
    ----------
    tx_df : pandas.DataFrame
        Transactions; must contain ``accountId`` and a timestamp column named
        ``transactionTimestamp``, ``timestamp`` or ``time``.
    acct_df : pandas.DataFrame
        Accounts; must contain ``accountId``.

    Returns
    -------
    (pandas.DataFrame, str)
        The merged frame (overlapping column names get ``_tx`` / ``_acct``
        suffixes) and the name of the detected timestamp column, which holds
        datetimes in the result (unparseable values become NaT).

    Raises
    ------
    ValueError
        If no timestamp column is found or ``accountId`` is missing on either side.

    The caller's ``tx_df`` is not modified.
    """
    # guess the timestamp column
    ts_col = _safe_col(tx_df, ["transactionTimestamp", "timestamp", "time"])
    if ts_col is None:
        raise ValueError("Cannot find a timestamp column (e.g., 'transactionTimestamp').")
    # pandas' own dtype check also accepts tz-aware / extension dtypes, which
    # np.issubdtype cannot interpret.
    if not pd.api.types.is_datetime64_any_dtype(tx_df[ts_col]):
        # assign() returns a new frame, leaving the uploaded frame untouched
        tx_df = tx_df.assign(**{ts_col: pd.to_datetime(tx_df[ts_col], errors="coerce")})

    if "accountId" not in tx_df.columns or "accountId" not in acct_df.columns:
        raise ValueError("Both CSVs must include 'accountId'.")

    merged_df = tx_df.merge(acct_df, on="accountId", how="left", suffixes=("_tx", "_acct"))
    return merged_df, ts_col

def _show_account_detail(merged_df, account_id):
    """Display one account's 20 largest transactions and plot its amounts over time.

    Prints a short notice and returns when the account has no rows. Optional
    columns (timestamp, transaction id, type, description) are shown only if
    present; the plot is drawn only when a timestamp column exists.
    """
    acct_tx = merged_df.loc[merged_df["accountId"] == account_id].copy()
    if acct_tx.empty:
        print("No transactions for this account.")
        return

    # best-effort lookup of columns whose names vary between data sources
    ts_col = _safe_col(acct_tx, ["transactionTimestamp", "timestamp", "time"])
    type_col = _safe_col(acct_tx, ["type", "transactionType"])
    desc_col = _safe_col(acct_tx, ["description_x", "description_tx", "description"])

    # fixed display order; entries that are unavailable are dropped
    wanted = [
        "accountId",
        ts_col,
        "transactionId" if "transactionId" in acct_tx.columns else None,
        "amount",
        type_col,
        desc_col,
    ]
    cols = [name for name in wanted if name]

    # detail table: top 20 by amount
    ranked = acct_tx.sort_values("amount", ascending=False)
    display(ranked[cols].head(20).reset_index(drop=True))

    # time series of amounts
    if not ts_col:
        return
    chrono = acct_tx.sort_values(ts_col)
    plt.plot(chrono[ts_col], chrono["amount"], marker="o")
    plt.title(f"Account {account_id} — Amount over Time")
    plt.xlabel("Time")
    plt.ylabel("Amount")
    plt.xticks(rotation=45)
    plt.tight_layout()
    plt.show()

# ---------- UI ----------
# Data source selector: "mem" reuses the notebook's merged/feat variables,
# "upload" runs detection on two uploaded CSV files.
st_source = w.ToggleButtons(
    options=[("Use in-memory (merged/feat)", "mem"),
             ("Upload CSVs (transactions + accounts)", "upload")],
    description="Source:",
)

# Options and inputs for the upload path.
use_ae_ck = w.Checkbox(value=True, description="Use AutoEncoder (slower)")
u_tx = w.FileUpload(accept=".csv", multiple=False, description="Upload 401k transactions CSV")
u_ac = w.FileUpload(accept=".csv", multiple=False, description="Upload 401k accounts CSV")

# Run button plus the Output area that on_run renders into.
btn_run = w.Button(description="Run 401k review", button_style="primary")
out = w.Output()

def on_run(_):
    """Click handler for the "Run 401k review" button.

    Depending on the source toggle it either reuses the notebook-level
    ``merged`` / ``feat`` frames or reads the two uploaded CSVs, merges them
    and runs ``full_detect_on_uploaded_401k``. It then renders, inside ``out``,
    the list of flagged accounts with a picker for per-account detail; when
    nothing is flagged, a picker over the first 50 account ids is offered
    instead. Any exception is reported as a one-line message in ``out``
    rather than propagated, so the widget stays usable.
    """
    import io  # local import: only this handler needs it

    def _first_upload_bytes(uploader):
        """Bytes of the first uploaded file.

        ``FileUpload.value`` is a dict keyed by file name in ipywidgets 7 and a
        tuple of file records in ipywidgets 8; both expose the payload under
        ``"content"`` (bytes or memoryview).
        """
        value = uploader.value
        first = next(iter(value.values())) if isinstance(value, dict) else value[0]
        return bytes(first["content"])

    with out:
        clear_output()
        try:
            if st_source.value == "mem":
                # —— use the merged / feat frames already in memory ——
                if "merged" not in globals() or "feat" not in globals():
                    print("Could not find variables 'merged' and 'feat' in memory. "
                          "Switch to 'Upload CSVs' or run the earlier pipeline cells.")
                    return
                merged_df = merged.copy()
                feat_df   = feat.copy()
                ts_col = _safe_col(merged_df, ["transactionTimestamp", "timestamp", "time"])
                # make sure final_anomaly exists (synthesised from existing flags if missing)
                try:
                    feat_df = _ensure_final_anomaly(feat_df)
                except Exception as e:
                    print(f"{e}\nShowing aggregates only.")

            else:
                # —— use the uploaded CSVs ——
                if (len(u_tx.value) == 0) or (len(u_ac.value) == 0):
                    print("Please upload both transactions and accounts CSV files.")
                    return

                # read the uploads
                tx_df = pd.read_csv(io.BytesIO(_first_upload_bytes(u_tx)))
                ac_df = pd.read_csv(io.BytesIO(_first_upload_bytes(u_ac)))

                # merge
                merged_df, ts_col = _merge_uploaded(tx_df, ac_df)

                # run detection on the uploaded data to obtain final_anomaly
                feat_df = full_detect_on_uploaded_401k(merged_df, use_ae=use_ae_ck.value, contamination=0.02, random_state=42)
                # report what actually ran: the AE step is skipped when the box is unchecked
                ae_note = "with AE" if use_ae_ck.value else "without AE"
                print(f"Full detection ({ae_note}) finished on uploaded data.")

            # —— list of flagged accounts ——
            if "final_anomaly" in feat_df.columns and feat_df["final_anomaly"].any():
                anom = feat_df[feat_df["final_anomaly"]].copy()
                keep_cols = [c for c in [
                    "iso_score","lof_score","ae_score","hybrid_anomaly","extreme_flag",
                    "tx_sum_amount","tx_mean_amount","tx_std_amount","tx_count"
                ] if c in anom.columns]

                # account picker for the detail view
                acc_dd = w.Dropdown(options=anom["accountId"].tolist(), description="Account:")
                btn_show = w.Button(description="Show detail")
                box2 = w.HBox([acc_dd, btn_show])

                def _render_summary():
                    # summary table + picker; redrawn after every detail request
                    print(f"Accounts flagged for review: {len(anom)}")
                    display(anom[["accountId"] + keep_cols].reset_index(drop=True))
                    display(box2)

                _render_summary()

                def _on_show(__):
                    with out:
                        clear_output(wait=True)
                        _render_summary()
                        _show_account_detail(merged_df, acc_dd.value)

                btn_show.on_click(_on_show)

            else:
                print("No accounts currently flagged (or no flags available).")
                # still allow browsing any account's transactions
                all_ids = merged_df["accountId"].unique().tolist()
                if not all_ids:
                    return
                # the picker is capped at the first 50 ids
                acc_dd = w.Dropdown(options=all_ids[:50], description="Account:")
                btn_show = w.Button(description="Show detail")
                picker = w.HBox([acc_dd, btn_show])
                display(picker)

                def _on_show_any(__):
                    with out:
                        clear_output(wait=True)
                        # redraw the picker so another account can be chosen afterwards
                        display(picker)
                        _show_account_detail(merged_df, acc_dd.value)
                btn_show.on_click(_on_show_any)

        except Exception as e:
            # best-effort UI: surface the problem in the output area instead of a traceback
            print("Error:", e)


# Wire the run button to its handler.
btn_run.on_click(on_run)

# Layout: source toggle + AE checkbox, then the two upload buttons,
# then the run button, then the shared output area.
box = w.VBox([
    w.HBox([st_source, use_ae_ck]),
    w.HBox([u_tx, u_ac]),
    btn_run,
    out
])
display(box)


VBox(children=(HBox(children=(ToggleButtons(description='Source:', options=(('Use in-memory (merged/feat)', 'm…