In [None]:
import os, numpy as np, pandas as pd
import matplotlib.pyplot as plt
import shap
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import f1_score
import torch, torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
from scipy.special import softmax

# Global matplotlib styling used by every figure produced in this notebook.
plt.rcParams['font.family'] = 'Times New Roman'
plt.rcParams['axes.unicode_minus'] = False

# Location of the Excel workbook to load (placeholder value — point it at the real file).
# A raw string literal cannot end with a backslash (the original r"...\" does not
# parse), so the same path value is spelled here with escaped backslashes.
PATH = "C:\\**\\**\\"
try:
    # Preferred reader for .xlsx workbooks.
    df = pd.read_excel(PATH, engine="openpyxl")
except FileNotFoundError:
    # A missing file cannot be fixed by switching engines; surface it directly.
    raise
except Exception:
    # Fall back to xlrd (legacy .xls). `Exception` instead of a bare `except`
    # so that KeyboardInterrupt / SystemExit are not swallowed.
    import xlrd  # fail fast with a clear ImportError if the fallback engine is missing
    df = pd.read_excel(PATH, engine="xlrd")

# Columns that must not be used as model inputs (they are removed from the
# feature matrix below; "Level_OA" is the prediction target).
exclude = {"Name","Level_OA","fold_id","G","P_PI3K","P_PPAR","P_ROS","P_LPS","P_OA","PF_C_num"}

# Keep only numeric columns. The original `dtype != 'O'` test also let
# datetime / categorical / string-extension columns through, which cannot be
# standardised or fed to the network.
X_cols = [c for c in df.columns
          if c not in exclude and pd.api.types.is_numeric_dtype(df[c])]
X_df = df[X_cols].copy()

# Target labels as a plain integer array.
y = df["Level_OA"].astype(int).values

# The original f-string was missing its opening quote and label (syntax error).
print(f"Number of features: {len(X_cols)}")
print("", X_cols[:8])

class TinyTabTransformer(nn.Module):
    """Small Transformer-encoder classifier for tabular feature vectors.

    Each input row is linearly projected to ``d_model`` dimensions, treated as a
    sequence of length one, passed through a stack of Transformer encoder
    layers, layer-normalised and finally mapped to ``n_classes`` logits by a
    two-layer MLP head.
    """

    def __init__(self, in_dim, n_classes=3, d_model=64, n_heads=4, n_layers=2, dropout=0.35):
        super().__init__()
        hidden_dim = d_model // 2

        # Feature vector -> model width.
        self.proj = nn.Linear(in_dim, d_model)

        # Encoder stack; the feed-forward width is twice the model width.
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=d_model,
                nhead=n_heads,
                dim_feedforward=2 * d_model,
                dropout=dropout,
                activation="gelu",
                batch_first=True,
            ),
            num_layers=n_layers,
        )

        self.norm = nn.LayerNorm(d_model)

        # Classification head: Linear -> GELU -> Dropout -> Linear.
        head_layers = [
            nn.Linear(d_model, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, n_classes),
        ]
        self.head = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return class logits for a batch of feature rows."""
        # Insert a length-1 sequence axis at position 1 for the encoder.
        token = torch.unsqueeze(self.proj(x), 1)
        encoded = self.encoder(token)
        # Drop the sequence axis again before normalisation and the head.
        pooled = self.norm(torch.squeeze(encoded, 1))
        return self.head(pooled)

def fit_tiny_model(X, y, seed=2025, val_size=0.1, epochs=500, patience=50,
                   d_model=64, n_heads=4, n_layers=2, dropout=0.35,
                   lr=1e-3, weight_decay=1e-4, batch_size=16):
    """Train a TinyTabTransformer with early stopping on a stratified hold-out split.

    Parameters
    ----------
    X : array-like, indexable by integer arrays after ``np.asarray``
        Feature matrix (rows are samples).
    y : array-like of int
        Class labels in ``1..3``; they are shifted to ``0..2`` for the loss.
    seed : int
        Seeds the train/validation split, the torch global RNG (weight
        initialisation, dropout) and the training-batch shuffling.
        Note: this calls ``torch.manual_seed`` and therefore resets the
        global torch RNG state.
    val_size : float
        Fraction of samples held out for validation.
    epochs, patience : int
        Maximum number of epochs, and the number of consecutive epochs
        without a validation macro-F1 improvement after which training stops.
    d_model, n_heads, n_layers, dropout :
        Forwarded to ``TinyTabTransformer``.
    lr, weight_decay, batch_size :
        Adam learning rate, Adam weight decay and mini-batch size.

    Returns
    -------
    (model, scaler, (train_idx, val_idx), best_f1)
        The model with the best validation weights restored and switched to
        eval mode, the ``StandardScaler`` fitted on the training rows, the
        split indices, and the best validation macro-F1 (``-1.0`` if no epoch
        was run).

    Raises
    ------
    ValueError
        If any label lies outside ``1..3``.
    """
    n_classes = 3
    X = np.asarray(X)
    y = np.asarray(y)

    # Fail early with a readable message; otherwise out-of-range targets only
    # surface later as an opaque error inside the loss.
    if y.size and (y.min() < 1 or y.max() > n_classes):
        raise ValueError(
            f"labels must lie in 1..{n_classes}, got range [{y.min()}, {y.max()}]"
        )

    # Make weight initialisation and dropout reproducible, not just the split.
    torch.manual_seed(seed)

    sss = StratifiedShuffleSplit(n_splits=1, test_size=val_size, random_state=seed)
    train_idx, val_idx = next(sss.split(X, y))
    X_tr, X_val = X[train_idx], X[val_idx]
    y_tr, y_val = y[train_idx], y[val_idx]

    # Standardise using training statistics only.
    scaler = StandardScaler()
    X_tr = scaler.fit_transform(X_tr)
    X_val = scaler.transform(X_val)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = TinyTabTransformer(in_dim=X_tr.shape[1], n_classes=n_classes,
                               d_model=d_model, n_heads=n_heads,
                               n_layers=n_layers, dropout=dropout).to(device)

    # Inverse-frequency class weights. A class absent from the training split
    # gets weight 0.0 (it contributes no loss terms anyway) instead of raising
    # KeyError as the original lookup did.
    classes, counts = np.unique(y_tr, return_counts=True)
    total = np.sum(counts)
    weight_map = {int(c): total / cnt for c, cnt in zip(classes, counts)}
    weights = torch.tensor([weight_map.get(k + 1, 0.0) for k in range(n_classes)],
                           dtype=torch.float32).to(device)
    criterion = nn.CrossEntropyLoss(weight=weights)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

    tr_ds = TensorDataset(torch.tensor(X_tr, dtype=torch.float32),
                          torch.tensor(y_tr - 1, dtype=torch.long))
    va_ds = TensorDataset(torch.tensor(X_val, dtype=torch.float32),
                          torch.tensor(y_val - 1, dtype=torch.long))
    # Dedicated generator so the batch order depends only on `seed`.
    shuffle_rng = torch.Generator().manual_seed(seed)
    tr_loader = DataLoader(tr_ds, batch_size=batch_size, shuffle=True,
                           generator=shuffle_rng)
    va_loader = DataLoader(va_ds, batch_size=batch_size, shuffle=False)

    best_f1, best_state, no_imp = -1.0, None, 0
    for ep in range(epochs):
        model.train()
        for bx, by in tr_loader:
            bx, by = bx.to(device), by.to(device)
            loss = criterion(model(bx), by)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Validation pass: macro-F1 over the whole hold-out set.
        model.eval()
        all_pred, all_true = [], []
        with torch.no_grad():
            for bx, by in va_loader:
                logits = model(bx.to(device)).cpu().numpy()
                all_pred.append(logits.argmax(axis=1))
                all_true.append(by.numpy())
        y_true = np.concatenate(all_true)
        y_pred = np.concatenate(all_pred)
        f1 = f1_score(y_true, y_pred, average='macro')

        if f1 > best_f1:
            # Snapshot on CPU so later optimisation steps cannot alter it.
            best_f1 = f1
            best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
            no_imp = 0
        else:
            no_imp += 1
            if no_imp >= patience:
                break

    # `best_state` stays None when epochs == 0; the original crashed on
    # `None.items()` in that case.
    if best_state is not None:
        model.load_state_dict(best_state)
    # Return in eval mode explicitly so callers get deterministic inference.
    model.eval()
    return model.to(device), scaler, (train_idx, val_idx), best_f1

# Train on the full feature matrix; fit_tiny_model performs its own internal
# train/validation split and returns the scaler fitted on the training part.
X_all = X_df.values
model, scaler, (tr_idx, va_idx), best_f1 = fit_tiny_model(X_all, y)
print(f" Macro-F1≈{best_f1:.3f}")

# Device the trained model lives on; inputs are moved there before inference.
device = next(model.parameters()).device
def predict_proba_level3(X_in_np):
    """Return the predicted probability of the third class for each input row.

    The rows are standardised with the module-level ``scaler``, passed through
    the module-level ``model``, and the softmax probability of output index 2
    is returned (index 2 corresponds to label 3, because labels are shifted by
    -1 during training).

    Parameters
    ----------
    X_in_np : array-like accepted by ``scaler.transform``
        Unscaled feature rows.

    Returns
    -------
    numpy.ndarray
        One probability per input row.
    """
    # Force eval mode so dropout can never make the explained function
    # stochastic, regardless of what state the model was left in.
    model.eval()
    X_std = scaler.transform(X_in_np)
    with torch.no_grad():
        logits = model(torch.tensor(X_std, dtype=torch.float32).to(device)).cpu().numpy()
    prob = softmax(logits, axis=1)
    return prob[:, 2]

# Summarise the data into at most 30 k-means centroids to use as the
# KernelExplainer background; cap the cluster count by the number of rows so
# small datasets do not fail.
n_bg = min(30, X_df.shape[0])
bg = shap.kmeans(X_df.values, n_bg)
explainer = shap.KernelExplainer(predict_proba_level3, bg)

# Explain every row of the dataset.
X_explain = X_df.values
shap_values = explainer.shap_values(X_explain)
X_explain_df = pd.DataFrame(X_explain, columns=X_cols)

outdir = r"C:\**"
os.makedirs(outdir, exist_ok=True)

# Global importance bar chart (top 20 features).
# NOTE(review): summary_plot may resize the current figure itself (its
# plot_size argument defaults to "auto") — confirm the final size is as intended.
plt.figure(figsize=(8,6))
shap.summary_plot(shap_values, X_explain_df, plot_type="bar", show=False, max_display=20)
plt.tight_layout()
plt.savefig(os.path.join(outdir, "global_feature_importance_bar_descriptors_top20.png"), dpi=300)
plt.close()

# Beeswarm summary plot (top 20 features).
plt.figure(figsize=(8,6))
shap.summary_plot(shap_values, X_explain_df, show=False, max_display=20)
plt.tight_layout()
plt.savefig(os.path.join(outdir, "shap_beeswarm_descriptors_top20.png"), dpi=300)
plt.close()

# Rank features by mean absolute SHAP value and draw a dependence plot for
# the top few.
mean_abs = np.abs(shap_values).mean(axis=0)
order = np.argsort(-mean_abs)
topK = 6
top_feats = [X_cols[i] for i in order[:topK]]
for feat in top_feats:
    # Pass the axes explicitly: without `ax`, dependence_plot opens its own
    # figure, so the requested size is ignored and the pre-created figure is
    # left open on every iteration.
    fig, ax = plt.subplots(figsize=(5.2, 4.2))
    shap.dependence_plot(feat, shap_values, X_explain_df, show=False,
                         interaction_index=None, ax=ax)
    fig.tight_layout()
    # Replace characters that are not allowed in Windows file names.
    safe_name = "".join("_" if ch in '<>:"/\\|?*' else ch for ch in str(feat))
    fig.savefig(os.path.join(outdir, f"dep_{safe_name}.png"), dpi=300)
    plt.close(fig)

print("")


In [None]:
import os, numpy as np, pandas as pd
import matplotlib.pyplot as plt

# Figure styling shared by the plots in this cell.
plt.rcParams.update({'font.family': 'Times New Roman', 'axes.unicode_minus': False})

# Output directory for the figures (created if missing).
outdir = r"C:\Users\81005\Desktop\CYH\3-PFAS\深度学习验证\fig_shap_descriptors"
os.makedirs(outdir, exist_ok=True)

# Rank features by mean absolute SHAP value and keep the topK largest.
topK = 20
mean_abs = np.mean(np.abs(shap_values), axis=0)   # one value per feature column
order = np.argsort(-mean_abs)                     # descending importance
top_idx = order[:topK]
top_feats = X_explain_df.columns[top_idx].tolist()
top_imp = mean_abs[top_idx]

# Figure (a): horizontal bar chart, most important feature at the top.
bar_colors = plt.cm.viridis(np.linspace(0.1, 0.9, len(top_feats)))
fig, ax = plt.subplots(figsize=(8, 5.5))
y_pos = np.arange(len(top_feats))[::-1]
ax.barh(y_pos, top_imp, color=bar_colors)
ax.set_yticks(y_pos)
ax.set_yticklabels(top_feats, fontsize=11)
ax.set_xlabel("Importance Score", fontsize=12)
ax.set_ylabel("Feature", fontsize=12)
ax.grid(axis='x', linestyle=':', alpha=0.3)
fig.tight_layout()
fig.savefig(os.path.join(outdir, "fig_a_global_bar_descriptors_top20.png"), dpi=300)
plt.close(fig)

# Colours for the low / medium / high feature-value bins. These names were
# used below but never defined, so the cell raised NameError on a fresh
# kernel; they are sampled from viridis to match figure (a).
c_low, c_med, c_high = (plt.cm.viridis(v) for v in (0.15, 0.50, 0.85))

# For each top feature, split the samples into terciles of the feature value
# and compute the mean |SHAP| of that feature inside each tercile.
bin_means = []
for f in top_feats:
    v = X_explain_df[f].values
    q1, q2 = np.quantile(v, [1/3, 2/3])
    low_m = v <= q1
    med_m = (v > q1) & (v <= q2)
    high_m = v > q2
    col_idx = X_explain_df.columns.get_loc(f)
    s = np.abs(shap_values[:, col_idx])
    # An empty bin (e.g. a constant feature) contributes 0 instead of NaN.
    bin_means.append([
        s[low_m].mean() if low_m.any() else 0.0,
        s[med_m].mean() if med_m.any() else 0.0,
        s[high_m].mean() if high_m.any() else 0.0,
    ])
bin_means = np.array(bin_means)

low = bin_means[:, 0]
med = bin_means[:, 1]
high = bin_means[:, 2]

# Figure (b): stacked horizontal bars, one segment per feature-value bin.
# NOTE(review): the total bar length is the sum of the three per-bin means,
# not the overall mean(|SHAP|) named in the x-axis label — confirm this is
# the intended quantity.
fig, ax = plt.subplots(figsize=(9, 5.7))
y_pos = np.arange(len(top_feats))[::-1]

ax.barh(y_pos, low, color=c_low, label="low")
ax.barh(y_pos, med, left=low, color=c_med, label="medium")
ax.barh(y_pos, high, left=low + med, color=c_high, label="high")

ax.set_yticks(y_pos)
ax.set_yticklabels(top_feats, fontsize=11)
ax.set_xlabel("mean(|SHAP value|) (average impact on model output magnitude)", fontsize=12)
ax.grid(axis='x', linestyle=':', alpha=0.3)

leg = ax.legend(frameon=False, loc="lower right", fontsize=11, title="Feature value")
plt.setp(leg.get_title(), fontsize=12)

fig.tight_layout()
fig.savefig(os.path.join(outdir, "fig_b_binned_meanabs_shap_top20.png"), dpi=300)
plt.close(fig)

print("")

