In [3]:
import os, gc, warnings, numpy as np, pandas as pd, torch, torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import average_precision_score, accuracy_score

# Silence every Python warning for the rest of the notebook session.
warnings.filterwarnings("ignore")
# Global device handle: GPU when available, otherwise CPU.
DEV = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", DEV)

# Input locations. BASE_DIR holds the feature and label CSVs, EDGE_DIR the
# edge lists (see build_graph).
# NOTE(review): BASE_DIR is spelled "My Drive" while EDGE_DIR is spelled
# "MyDrive" — confirm that both mount paths are intended.
BASE_DIR      = "/content/drive/My Drive/....."
EDGE_DIR      = "/content/drive/MyDrive/....."
# File-name templates; "{c}" is replaced by a cancer code from CANCERS.
EDGE_TPL      = "ppi_{c}_plus_knn3_bidirectional.csv"
FEAT_TPL      = "features_for_{c}.csv"
LABEL_TPL     = "{c}_labels(0_1).csv"

# Cancer codes for which a graph is built.
CANCERS       = ["BRCA","BLCA","LUAD","LIHC","PRAD","CESC",
                 "COAD","STAD","THCA","LUSC","UCEC","ESCA"]

# Hyper-parameters and output folder.
# NOTE(review): none of the names below (SHARED_DIM .. PT_SAVE_DIR) is
# referenced by the cells visible in this notebook section; the training
# functions take their own defaults — confirm whether they are still needed.
SHARED_DIM   = 64
HIDDEN_DIM   = 64
DROPOUT      = 0.5
PT_EPOCHS    = 300
FT_EPOCHS    = 200
PT_SAVE_DIR  = f"{BASE_DIR}/pretrain_adapters"
# Side effect: creates the folder on disk when the cell runs.
os.makedirs(PT_SAVE_DIR, exist_ok=True)


Device: cuda


In [4]:
from torch_geometric.utils import add_self_loops

def build_graph(cancer_code: str,
                scaler: StandardScaler | None = None) -> Data:
    """Build the PyG graph of one cancer type from its edge, feature and label CSVs.

    The node set is the sorted union of the ``gene1`` and ``gene2`` columns of
    the edge file; genes that only appear in the feature or label file are
    ignored. Features are standardised, genes without a feature row receive
    the mean of their neighbours' (original) features, and genes without a
    label get ``-1``.

    Parameters
    ----------
    cancer_code : str
        Code substituted into ``EDGE_TPL``, ``FEAT_TPL`` and ``LABEL_TPL``.
    scaler : StandardScaler or None
        Already-fitted scaler to apply to the feature table. When ``None`` a
        new ``StandardScaler`` is fitted on every row of the feature file.

    Returns
    -------
    Data
        ``x``          float32 tensor ``[N, d]``
        ``edge_index`` long tensor ``[2, E]`` (``[2, 0]`` if the file has no rows)
        ``y``          long tensor ``[N]``; label-file value, or ``-1`` if unlabeled
        ``gene_names`` list of length ``N`` (node index -> gene name)
    """
    edge_path  = os.path.join(EDGE_DIR, EDGE_TPL.format(c=cancer_code))
    feat_path  = os.path.join(BASE_DIR, FEAT_TPL.format(c=cancer_code))
    label_path = os.path.join(BASE_DIR, LABEL_TPL.format(c=cancer_code))

    edges_df = pd.read_csv(edge_path)
    feat_df  = pd.read_csv(feat_path, index_col=0).fillna(0)
    label_df = pd.read_csv(label_path)

    # Node index = position in the alphabetically sorted gene list.
    genes = sorted(set(edges_df['gene1']) | set(edges_df['gene2']))
    n2i   = {g: i for i, g in enumerate(genes)}
    N     = len(genes)

    # Every endpoint is in n2i by construction, so no membership filter is
    # needed. reshape(-1, 2) keeps the [2, E] layout even when E == 0.
    pairs = [[n2i[a], n2i[b]]
             for a, b in edges_df[['gene1', 'gene2']].values]
    edge_index = (torch.tensor(pairs, dtype=torch.long)
                  .reshape(-1, 2).t().contiguous())

    # Standardise features; an explicit None check avoids relying on the
    # truthiness of a user-supplied scaler object.
    d = feat_df.shape[1]
    if scaler is None:
        scaler = StandardScaler().fit(feat_df.values)
    X_scaled = scaler.transform(feat_df.values)

    # Scatter the scaled rows into node order; rows for genes outside the
    # graph are skipped. A repeated gene keeps its last row.
    X        = np.zeros((N, d), dtype=np.float32)
    has_feat = np.zeros(N, dtype=bool)
    for row_pos, g in enumerate(feat_df.index):
        idx = n2i.get(g)
        if idx is not None:
            X[idx] = X_scaled[row_pos]
            has_feat[idx] = True

    # Undirected adjacency lists (each edge row contributes both directions).
    neigh = [[] for _ in range(N)]
    for s, t in edge_index.t().tolist():
        neigh[s].append(t)
        neigh[t].append(s)

    # Impute missing feature rows from neighbours that had a real feature
    # row. has_feat is not updated here, so imputed values never propagate.
    for i in np.flatnonzero(~has_feat):
        donors = [n for n in neigh[i] if has_feat[n]]
        if donors:
            X[i] = X[donors].mean(axis=0)

    # Labels: -1 marks "no label"; label rows for unknown genes are ignored.
    y = torch.full((N,), -1, dtype=torch.long)
    for g, lab in zip(label_df['Gene'], label_df['Labels']):
        idx = n2i.get(g)
        if idx is not None:
            y[idx] = int(lab)

    data = Data(x=torch.tensor(X), edge_index=edge_index, y=y)
    data.gene_names = genes
    return data


In [5]:
# Build one graph per cancer code, then record each graph's feature width.
graphs = {}
for c in CANCERS:
    graphs[c] = build_graph(c)
input_dims = {code: g.num_node_features for code, g in graphs.items()}
print("Feature dimensionality for each cancer type:", input_dims)


Feature dimensionality for each cancer type: {'BRCA': 11, 'BLCA': 10, 'LUAD': 11, 'LIHC': 15, 'PRAD': 12, 'CESC': 18, 'COAD': 13, 'STAD': 12, 'THCA': 9, 'LUSC': 11, 'UCEC': 12, 'ESCA': 9}


In [16]:
import torch, torch.nn as nn, torch.nn.functional as F
from torch_geometric.nn import GCNConv

class GCN_AE(nn.Module):
    """Two-layer GCN auto-encoder: in_dim -> hid_dim -> in_dim.

    ``conv1`` encodes the node attributes, ``conv2`` maps the hidden
    representation back to the input width so the output can be compared
    with the original attributes.
    """

    def __init__(self, in_dim, hid_dim):
        super().__init__()
        self.conv1 = GCNConv(in_dim, hid_dim)   # encoder layer
        self.conv2 = GCNConv(hid_dim, in_dim)   # decoder layer

    def forward(self, x, edge_index):
        """Return the reconstructed node attributes (same width as ``x``)."""
        hidden = self.conv1(x, edge_index).relu()
        return self.conv2(hidden, edge_index)

class GCN_Classifier_with_MLP(nn.Module):
    """Two GCN layers followed by an MLP head producing one logit per node.

    Parameters
    ----------
    in_dim : int
        Width of the input node features.
    hid_dim : int
        Output width of both GCN layers.
    mlp_hidden : int
        Width of the hidden layer in the MLP head.
    dropout : float
        Dropout probability applied at the start of the MLP head.

    The sub-module names and the layout of ``mlp`` are kept unchanged so that
    state dicts saved by earlier runs still load.
    """

    def __init__(self, in_dim, hid_dim=64, mlp_hidden=32, dropout=0.5):
        super().__init__()
        self.gcn1 = GCNConv(in_dim, hid_dim)
        self.gcn2 = GCNConv(hid_dim, hid_dim)
        self.mlp = nn.Sequential(
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hid_dim, mlp_hidden),
            nn.ReLU(),
            nn.Linear(mlp_hidden, 1)
        )

    def forward(self, x, edge_index):
        """Return raw logits of shape ``[num_nodes]`` (no sigmoid applied)."""
        h = F.relu(self.gcn1(x, edge_index))
        h = F.relu(self.gcn2(h, edge_index))
        # Drop only the trailing size-1 dim: a bare squeeze() would also
        # remove the node dim for a single-node graph and return a 0-d tensor.
        return self.mlp(h).squeeze(-1)


In [7]:
def stage1_attribute_pretrain(data, hid=64, epoch=200, lr=1e-3, save_path=None):
    """Stage 1: train a GCN_AE to reconstruct the node attributes of ``data``.

    Parameters
    ----------
    data : graph object exposing ``x``, ``edge_index`` and ``num_node_features``
    hid : hidden width of the auto-encoder
    epoch : number of full-graph optimisation steps
    lr : Adam learning rate
    save_path : if truthy, the final state dict is written there

    Returns
    -------
    The trained auto-encoder's ``state_dict()``.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    data = data.to(device)

    autoencoder = GCN_AE(data.num_node_features, hid).to(device)
    optimizer = torch.optim.Adam(autoencoder.parameters(), lr=lr)
    criterion = nn.MSELoss()

    # The model is never switched to eval mode, so one train() call suffices.
    autoencoder.train()
    step = 0
    while step < epoch:
        step += 1
        optimizer.zero_grad()
        reconstruction = autoencoder(data.x, data.edge_index)
        recon_loss = criterion(reconstruction, data.x)
        recon_loss.backward()
        optimizer.step()

        # Log the first step and every 20th step.
        if step == 1 or step % 20 == 0:
            print(f"[Stage1] Ep{step:03d}/{epoch} | MSE {recon_loss.item():.4f}")

    weights = autoencoder.state_dict()
    if save_path:
        torch.save(weights, save_path)
    return weights


In [10]:
import random, numpy as np
from sklearn.metrics import average_precision_score, accuracy_score

def stage2_label_pretrain(graphs, target, attr_state_dict,
                          hid=64, epochs=300, ckpt_dir="./ckpt", seed=None):
    """Stage 2: pre-train the classifier on synthetic labels for ``target``.

    Synthetic positives are the nodes of the target graph whose gene is
    labeled 1 in any *other* graph of ``graphs``. Synthetic negatives are
    sampled, in equal number, from target nodes that are unlabeled
    (``y == -1``) and not synthetic positives. The first GCN layer is
    initialised from the ``conv1.*`` entries of ``attr_state_dict``.

    Parameters
    ----------
    graphs : dict mapping cancer code -> graph (with ``gene_names`` and ``y``)
    target : key of the graph to train on
    attr_state_dict : state dict returned by ``stage1_attribute_pretrain``
    hid : hidden width; must match the width used in stage 1
    epochs : number of full-graph optimisation steps
    ckpt_dir : folder for checkpoints (created if missing)
    seed : optional int; when given, seeds the negative sampling and torch so
        the run is repeatable. ``None`` keeps the global RNG state.

    Returns
    -------
    The trained ``GCN_Classifier_with_MLP``.

    Raises
    ------
    ValueError
        If no synthetic positive node exists in the target graph.
    """
    dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    data = graphs[target].to(dev)

    sampler = random
    if seed is not None:
        sampler = random.Random(seed)
        torch.manual_seed(seed)

    # Genes labeled positive in any other cancer type.
    pos_genes = {g for c, gx in graphs.items() if c != target
                 for g, l in zip(gx.gene_names, gx.y.tolist()) if l == 1}
    pos_idx = [i for i, g in enumerate(data.gene_names) if g in pos_genes]
    if not pos_idx:
        raise ValueError(f"No synthetic positive nodes found for target {target!r}")

    # One vectorised comparison and a set lookup instead of a per-node
    # tensor comparison and a list scan; the pool order is unchanged.
    pos_set  = set(pos_idx)
    unlabeled = torch.nonzero(data.y == -1, as_tuple=True)[0].tolist()
    neg_pool = [i for i in unlabeled if i not in pos_set]

    # random.sample raises if asked for more items than the pool holds.
    n_neg = min(len(pos_idx), len(neg_pool))
    if n_neg < len(pos_idx):
        print(f"Stage2: only {len(neg_pool)} candidate negatives available")
    neg_idx = sampler.sample(neg_pool, k=n_neg)

    print(f"Stage2: pos {len(pos_idx)} | neg {len(neg_idx)}")

    # -1 = not used for training, 1 / 0 = synthetic positive / negative.
    y_syn = torch.full((data.num_nodes,), -1, dtype=torch.float32, device=dev)
    y_syn[torch.tensor(pos_idx, dtype=torch.long, device=dev)] = 1.0
    y_syn[torch.tensor(neg_idx, dtype=torch.long, device=dev)] = 0.0
    train_mask = y_syn != -1

    model = GCN_Classifier_with_MLP(data.num_node_features, hid).to(dev)
    # Transfer the stage-1 encoder layer: strip the "conv1." prefix so the
    # keys match the GCNConv module itself.
    prefix = 'conv1.'
    model.gcn1.load_state_dict({k[len(prefix):]: v
                                for k, v in attr_state_dict.items()
                                if k.startswith(prefix)})

    opt = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=5e-4)
    bce = nn.BCEWithLogitsLoss()

    # torch.save does not create missing folders.
    os.makedirs(ckpt_dir, exist_ok=True)

    for ep in range(1, epochs + 1):
        model.train(); opt.zero_grad()
        logit = model(data.x, data.edge_index)
        loss  = bce(logit[train_mask], y_syn[train_mask])
        loss.backward(); opt.step()

        # Log the first epoch and every 20th (which covers 100/200/300).
        if ep % 20 == 0 or ep == 1:
            model.eval()
            with torch.no_grad():
                p = torch.sigmoid(model(data.x, data.edge_index)[train_mask]).cpu()
                au = average_precision_score(y_syn[train_mask].cpu(), p)
            print(f"[Stage2] Ep{ep:03d}/{epochs} | loss {loss.item():.4f} | AUPRC {au:.3f}")

        # Checkpoint at the fixed milestones and always at the final epoch,
        # so runs with another epoch count still leave their weights on disk.
        if ep in (100, 200, 300) or ep == epochs:
            ck = f"{ckpt_dir}/pretrain_{target}_ep{ep}.pth"
            torch.save(model.state_dict(), ck)
            print(" saved", ck)

    return model


In [None]:
SAVE_DIR = "/content/drive/MyDrive/pretrain_Graph"
os.makedirs(SAVE_DIR, exist_ok=True)

# Stage 1: attribute reconstruction on the BLCA graph; keeps the learned
# weights in memory and writes them to SAVE_DIR.
attr_w = stage1_attribute_pretrain(
    graphs["BLCA"],
    hid=64,
    epoch=200,
    lr=1e-3,
    save_path=f"{SAVE_DIR}/attr_BLCA_ep200.pth",
)

# Stage 2: synthetic-label pre-training on BLCA, initialised from stage 1;
# checkpoints go to SAVE_DIR.
stage2_label_pretrain(
    graphs,
    "BLCA",
    attr_w,
    hid=64,
    epochs=300,
    ckpt_dir=SAVE_DIR,
)


In [29]:
import os, numpy as np, torch, torch.nn as nn, torch.nn.functional as F
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import accuracy_score, average_precision_score

def finetune_gcn_classifier(
        data: Data,
        ckpt_path: str,
        epochs: int = 200,
        lr: float = 1e-3,
        freeze_epochs: int = 20,
        patience: int = 40,
        weight_decay: float = 5e-4,
        seed: int = 42,
        save_dir: str = None):
    """Fine-tune a pre-trained GCN classifier with 5-fold cross-validation.

    For every fold the labeled nodes (``y != -1``) are split into
    train+validation and test by ``StratifiedKFold``. The first 80 % of the
    train+validation indices (in the order the splitter returns them) are
    used for training and the rest for validation. The model starts from the
    checkpoint at ``ckpt_path`` and ``gcn1`` stays frozen for the first
    ``freeze_epochs`` epochs. Training stops early on validation loss, and the
    weights with the best validation loss are restored before testing.

    Parameters
    ----------
    data : graph with ``x``, ``edge_index`` and ``y`` (``-1`` = unlabeled)
    ckpt_path : state dict of a ``GCN_Classifier_with_MLP`` (hid_dim 64)
    epochs : maximum number of epochs per fold
    lr : Adam learning rate while ``gcn1`` is frozen; ``lr / 10`` afterwards
    freeze_epochs : number of initial epochs with ``gcn1`` frozen
    patience : epochs without validation improvement before stopping
    weight_decay : Adam weight decay
    seed : seed for torch, numpy and the fold splitter
    save_dir : if truthy, ``best_fold{k}.pth`` is written there per fold

    Returns
    -------
    None. Per-fold and summary metrics are printed.
    """
    torch.manual_seed(seed); np.random.seed(seed)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    data   = data.to(device)

    labeled_idx = torch.where(data.y != -1)[0]
    y_np        = data.y[labeled_idx].cpu().numpy()
    skf         = StratifiedKFold(n_splits=5, shuffle=True, random_state=seed)

    def mask(idx):
        """Boolean node mask that is True at the given node indices."""
        m = torch.zeros(data.num_nodes, dtype=torch.bool, device=device)
        m[idx] = True
        return m

    def snapshot(module):
        """Detached copy of the module's weights.

        ``state_dict()`` returns references to the live parameters, which the
        optimiser keeps updating in place; cloning is required for the
        snapshot to stay at the best epoch.
        """
        return {k: v.detach().clone() for k, v in module.state_dict().items()}

    # The checkpoint is identical for every fold: read it once.
    # load_state_dict copies values, so sharing the dict across folds is safe.
    pretrained_state = torch.load(ckpt_path, map_location=device)

    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    all_metrics = []

    for fold, (trval, te) in enumerate(skf.split(np.arange(len(y_np)), y_np), 1):
        print(f"\n Fold {fold}/5 -------------------------------")

        # Map positions within the labeled subset back to node indices.
        trval = labeled_idx[trval]; te = labeled_idx[te]
        split = int(0.8*len(trval))
        tr_idx, va_idx = trval[:split], trval[split:]
        m_tr, m_va, m_te = map(mask, [tr_idx, va_idx, te])

        model = GCN_Classifier_with_MLP(in_dim=data.num_node_features, hid_dim=64).to(device)
        model.load_state_dict(pretrained_state, strict=True)
        print(f" Loaded checkpoint {ckpt_path}")

        if freeze_epochs > 0:
            for p in model.gcn1.parameters(): p.requires_grad_(False)
            print(f"  Freeze encoder {freeze_epochs} epoch đầu")

        opt = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()),
                               lr=lr, weight_decay=weight_decay)

        # Class-imbalance weight: (#negatives / #positives) in the train split.
        y_tr  = data.y[m_tr]
        n_neg = int((y_tr == 0).sum())
        n_pos = int((y_tr == 1).sum())
        loss_fn = nn.BCEWithLogitsLoss(
            pos_weight=torch.tensor(n_neg / max(1, n_pos), device=device))

        best_state, best_val, wait = None, 1e9, 0
        for ep in range(1, epochs+1):

            # Unfreeze gcn1 and restart the optimiser with a 10x smaller LR.
            # NOTE: with freeze_epochs == 0 this fires at epoch 1, so the
            # whole run then uses lr / 10.
            if ep == freeze_epochs+1:
                for p in model.gcn1.parameters(): p.requires_grad_(True)
                opt = torch.optim.Adam(model.parameters(), lr=lr/10, weight_decay=weight_decay)
                print("  Unfreeze encoder")

            model.train(); opt.zero_grad()
            logit = model(data.x, data.edge_index)
            loss  = loss_fn(logit[m_tr], data.y[m_tr].float())
            loss.backward(); opt.step()

            model.eval()
            with torch.no_grad():
                logit_va = model(data.x, data.edge_index)[m_va]
                val_loss = loss_fn(logit_va, data.y[m_va].float()).item()

            if val_loss < best_val:
                best_val, best_state, wait = val_loss, snapshot(model), 0
            else:
                wait += 1
            if wait >= patience:
                print(f"  Early-stop tại epoch {ep}")
                break

            if ep % 20 == 0 or ep == 1:
                with torch.no_grad():
                    probs_va = torch.sigmoid(logit_va).cpu()
                    au_va = average_precision_score(data.y[m_va].cpu(), probs_va)
                    print(f"Ep{ep:03d} | TrainL {loss.item():.4f} | ValL {val_loss:.4f} | ValAUPRC {au_va:.3f}")

        # No epoch improved on the initial value (e.g. epochs == 0 or a NaN
        # validation loss): fall back to the current weights.
        if best_state is None:
            best_state = snapshot(model)

        model.load_state_dict(best_state)
        model.eval()
        with torch.no_grad():
            probs_te = torch.sigmoid(model(data.x, data.edge_index)[m_te]).cpu()
            acc_te   = accuracy_score(data.y[m_te].cpu(), (probs_te>0.5).long())
            au_te    = average_precision_score(data.y[m_te].cpu(), probs_te)
            print(f"✅ Fold {fold} TEST | Acc {acc_te:.3f} | AUPRC {au_te:.3f}")
            all_metrics.append((acc_te, au_te))

        if save_dir:
            torch.save(best_state, f"{save_dir}/best_fold{fold}.pth")

    accs, auprs = zip(*all_metrics)
    print(f"\n📊 5-fold summary ⇒  Acc {np.mean(accs):.3f} ± {np.std(accs):.3f} | "
          f"AUPRC {np.mean(auprs):.3f} ± {np.std(auprs):.3f}")


In [None]:
# Stage-2 checkpoint to start from, and the graph to fine-tune on.
CKPT_STAGE2 = "/content/drive/MyDrive/pretrain_Graph/pretrain_BLCA_ep300.pth"
DATA_cancer   = graphs["BLCA"]

# 5-fold fine-tuning on BLCA; best weights per fold are written to save_dir.
finetune_gcn_classifier(
    DATA_cancer,
    CKPT_STAGE2,
    epochs=300,
    lr=1e-3,
    freeze_epochs=20,
    patience=40,
    save_dir="/content/drive/MyDrive/BLCA_all_finetune_graph",
)
