In [None]:
# ## whole 40 df
import pandas as pd
import os
import glob
# ---------- CONFIG ----------
folder_path = "mouse_GRN"
file_pattern = os.path.join(folder_path, "whole_*.txt")
trrust_path = "trrust_rawdata.mouse.tsv"
nes_threshold = 3.0
genie3_threshold = 0.003
MODEL_NAME      = "gemini-2.0-pro-exp-02-05"
MODEL_PROVIDER  = "google"
N_NEG_PER_POS   = 1        # class balance; 1 negative for each positive
OUTFILE     = "mouse_GRN/3_llm_reasoning_log_Stomach.jsonl"     # one JSON row per question
TASK_DF_LOCATION = "mouse_GRN/3_Stomach_tasks.csv" ### modify this
RESULT_OUTPUT_LOCATION = "mouse_GRN/3_Stomach_score.txt"
TEST_EVAL = "mouse_GRN/3_Stomach_cutoff_test.csv"
# GCN model parameters
PREDICT_CONTEXT = "Stomach"
epochs = 20               ###### increase  this 
MAX_PROMPT_LEN  = 4096     # guardrail
LEARNING_RATE=1e-2
### LLM parameters:
BINARY_CUTOFF = 0.2

  from pandas.core import (


In [2]:
# --- 1.  LOAD DATA ----------------------------------------------------------
trrust_df = pd.read_csv(
    trrust_path,
    sep="\t",
    names=["TF", "Target", "Mode", "PMID"]
)
all_dfs = []

for file_path in glob.glob(file_pattern):
    df = pd.read_csv(file_path, sep="\t")
    
    # Filter for "High" confidence
    df_filtered = df[df["Confidence"] == "High"].copy()
    
    # Additional filters
    df_filtered = df_filtered[
        (df_filtered["NES"] >= nes_threshold) & 
        (df_filtered["Genie3Weight"].notnull()) & 
        (df_filtered["Genie3Weight"] >= genie3_threshold)
    ]
    
    # Extract context from filename
    filename = os.path.basename(file_path)
    context = filename.replace("whole_", "").replace("-regulons.txt", "")
    
    # Add context column
    df_filtered["Context"] = context
    
    all_dfs.append(df_filtered)

# Combine all filtered dataframes
df_combined = pd.concat(all_dfs, ignore_index=True)

df_combined["Context"] = df_combined["Context"].astype("category")

In [5]:

import pandas as pd, numpy as np, re, ast, json
from sklearn.metrics import precision_recall_fscore_support, roc_auc_score
from utils.LLM import query_llm

# --- 2.  TASK BUILDER -------------------------------------------------------
# ---------------------------------------------------------------------
#  2‑bis.  TASK BUILDER  (train‑pool  →  held‑out graph)
# ---------------------------------------------------------------------
def build_tasks_multi(df, trrust,
                      train_ctxs,           # list of training Context names
                      test_ctx,             # single held‑out Context name
                      n_neg_per_pos=1,
                      max_known=50):  
    """Return a list[dict] where each dict is one binary TF-gene question."""
    
    df_train = df[df.Context.isin(train_ctxs)].copy()
    df_test  = df.query("Context == @test_ctx").copy()
    
    tasks = []
    
    # ---- TFs that have at least one edge in *both* pools --------------
    tf_common = set(df_train.TF).intersection(df_test.TF)
    
    for tf in tf_common:
        # ----------------  context A  (union over 39 graphs)  ----------
        known_A = (
            df_train.loc[df_train.TF == tf, "gene"]
                     .unique()
                     .tolist()
        )
        context_A = df_train.loc[df_train.TF == tf, "Context"].unique().tolist()
        if not known_A:
            continue
        
        # ----------------  context B  (held‑out graph)  ----------------
        cand_B = (
            df_test.loc[df_test.TF == tf, "gene"]
                    .unique()
                    .tolist()
        )
        if not cand_B:
            continue
        
        # ----------------  positives / negatives -----------------------
        pos_set = set(trrust[trrust.TF == tf].Target) & set(cand_B)
        neg_set = set(cand_B) - pos_set
        
        if len(pos_set) == 0 or len(neg_set) == 0:
            continue
        
        # balanced negative sampling
        rng = np.random.default_rng(0)  # reproducible
        n_neg = min(len(pos_set)*n_neg_per_pos, len(neg_set))
        neg_sample = rng.choice(list(neg_set), size=n_neg, replace=False)
        
        # build question dicts
        for gene, label in (
            list(zip(pos_set, [1]*len(pos_set))) +
            list(zip(neg_sample, [0]*len(neg_sample)))
        ):
            tasks.append({
                "TF"        : tf,
                "gene"      : gene,
                "context_A" : context_A,      
                "context_B" : test_ctx,
                "known_A"   : known_A[:max_known],
                "label"     : label
            })
    return tasks


train_contexts  = [c for c in df_combined.Context.unique()
                   if c != PREDICT_CONTEXT]
tasks = build_tasks_multi(df_combined, trrust_df,
                          train_ctxs=train_contexts,
                          test_ctx=PREDICT_CONTEXT,
                          n_neg_per_pos=1)

print(f"Total binary questions: {len(tasks)}  "
      f"({sum(t['label'] for t in tasks)} positives)")

Total binary questions: 46  (23 positives)


In [3]:


# ---------------------------------------------------------------------
# 0.  Imports & helpers
# ---------------------------------------------------------------------
import torch, torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn   import GCNConv
import numpy  as np
import pandas as pd
from sklearn.metrics import confusion_matrix, precision_recall_fscore_support, roc_auc_score
import random, warnings

warnings.filterwarnings("ignore", category=UserWarning)   # PyG verbosity

# ---------------------------------------------------------------------
# 1.  Select which graphs are train vs. test
# ---------------------------------------------------------------------
# -- YOUR df_combined must contain at least the columns
#    'TF', 'gene', 'Genie3Weight', 'Context'
# ---------------------------------------------------------------------
# 39 training contexts  ->  put the names in a list
train_contexts = [ctx for ctx in df_combined["Context"].unique().tolist() if ctx != PREDICT_CONTEXT]
test_context   = PREDICT_CONTEXT          # held‑out graph

df_train = df_combined[df_combined.Context.isin(train_contexts)].copy()
df_test  = df_combined.query("Context == @test_context").copy()

print(f"train graphs = {train_contexts!r}")
print(f"test  graph  = {test_context!r}")
print(f"train edges = {len(df_train):,d} | test edges = {len(df_test):,d}")

# ---------------------------------------------------------------------
# 2.  Shared node index across *all* graphs
# ---------------------------------------------------------------------
all_nodes = pd.Index(df_combined.TF).union(df_combined.gene)
node2idx  = {n: i for i, n in enumerate(all_nodes)}
num_nodes = len(all_nodes)

def edges_to_index(df):
    src = df.TF  .map(node2idx).to_numpy()
    dst = df.gene.map(node2idx).to_numpy()
    return torch.as_tensor(np.vstack([src, dst]), dtype=torch.long)

edge_index_train  = edges_to_index(df_train)
edge_weight_train = torch.tensor(df_train.Genie3Weight.values,
                                 dtype=torch.float32)

edge_index_test   = edges_to_index(df_test)          # for later

# ---------------------------------------------------------------------
# 3.  Node features – simple trainable embeddings
# ---------------------------------------------------------------------
feat_dim = 128
x_embed  = torch.nn.Embedding(num_nodes, feat_dim)

# PyG Data object that *includes* edge weights
data = Data(x=x_embed.weight,
            edge_index=edge_index_train,
            edge_weight=edge_weight_train)

# ---------------------------------------------------------------------
# 4.  GCN encoder + dot‑product decoder (unchanged)
# ---------------------------------------------------------------------
class GCNLink(torch.nn.Module):
    def __init__(self, in_dim, hid=64):
        super().__init__()
        self.conv1 = GCNConv(in_dim, hid)
        self.conv2 = GCNConv(hid,  hid)

    def forward(self, x, edge_index, w):
        h = F.relu(self.conv1(x, edge_index, w))
        h = self.conv2(h, edge_index, w)
        return h

def dot_score(h, pairs):                       # pairs = 2×N indices
    return (h[pairs[0]] * h[pairs[1]]).sum(dim=-1)

# ---------------------------------------------------------------------
# 5.  Negative‑edge sampler (uniform corruption, unchanged)
# ---------------------------------------------------------------------
pos_set = set(zip(edge_index_train[0].tolist(),
                  edge_index_train[1].tolist()))

def sample_neg(num_neg):
    u = torch.randint(0, num_nodes, (num_neg,))
    v = torch.randint(0, num_nodes, (num_neg,))
    mask = torch.tensor([(u[i].item(), v[i].item()) not in pos_set
                         for i in range(num_neg)])
    return torch.stack([u[mask], v[mask]], 0)

# ---------------------------------------------------------------------
# 6.  Training loop
# ---------------------------------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
data   = data.to(device)
model  = GCNLink(feat_dim).to(device)

opt    = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

for epoch in range(1, epochs + 1):
    model.train(); opt.zero_grad()
    h = model(data.x, data.edge_index, data.edge_weight)

    # positive & negative scores
    pos_s = dot_score(h, data.edge_index)
    neg_i = sample_neg(pos_s.size(0)).to(device)
    neg_s = dot_score(h, neg_i)

    y_true = torch.cat([torch.ones_like(pos_s), torch.zeros_like(neg_s)])
    y_pred = torch.cat([pos_s,              neg_s            ])
    loss   = F.binary_cross_entropy_with_logits(y_pred, y_true)

    loss.backward(); opt.step()

    if epoch % 20 == 0:
        print(f"epoch {epoch:03d} | loss = {loss.item():.4f}")

  _torch_pytree._register_pytree_node(


train graphs = ['MammaryGland.Involution', 'NeonatalMuscle', 'MammaryGland.Virgin', 'MesenchymalStemCells', 'FetalKidney', 'FetalLiver', 'PeripheralBlood', 'Spleen', 'FetalLung', 'MammaryGland.Virgin.CD45', 'MammaryGland.Lactation', 'Liver', 'BoneMarrowcKit', 'NeonatalRib', 'Kidney', 'MammaryGland.Pregnancy', 'SmallIntestine', 'Testis', 'Ovary', 'Prostate', 'TrophoblastStemCells', 'FetalStomach', 'FetalIntestine', 'Lung', 'BoneMarrow', 'EmbryonicMesenchyme', 'Bladder', 'Brain', 'NeonatalSkin', 'NeonatalHeart', 'NeonatalCalvaria', 'MammaryGland.Involution.CD45', 'Muscle', 'FetalBrain', 'EmbryonicStemCells', 'Thymus', 'Placenta', 'Uterus', 'NeonatalPancreas', 'SmallIntestine.CD45']
test  graph  = 'Stomach'
train edges = 337,025 | test edges = 12,783
epoch 020 | loss = 0.6387


In [None]:
# ---------------------------------------------------------------------
# 7.  Embeddings for *all* nodes after training
# ---------------------------------------------------------------------
model.eval()
with torch.no_grad():
    H = model(data.x, data.edge_index, data.edge_weight).cpu()

# ---------------------------------------------------------------------
# 8.  Helper: evaluate on a paired‑question list (unchanged API)
# ---------------------------------------------------------------------
from torch.nn.functional import sigmoid

def gcn_predict(tasks, H, node2idx, thresh=0.5):
    y_true, y_pred, y_score = [], [], []
    for t in tasks:
        i_tf   = node2idx.get(t["TF"])
        i_gene = node2idx.get(t["gene"])
        if i_tf is None or i_gene is None:      # unseen node guard
            continue
        logit = torch.dot(H[i_tf], H[i_gene]).item()
        prob  = sigmoid(torch.tensor(logit)).item()
        y_true.append(t["label"])
        y_score.append(prob)
        y_pred.append(1 if prob >= thresh else 0)
    return np.array(y_true), np.array(y_pred), np.array(y_score)

# ---------------------------------------------------------------------
# 9.  Example evaluation
#     (replace `tasks` with your actual paired‑question list)
# ---------------------------------------------------------------------
y_true_gcn_whole, y_pred_gcn_whole, y_score_gcn_whole = gcn_predict(tasks, H, node2idx)

def evaluate(y_true, y_pred, y_score):
    p,r,f,_ = precision_recall_fscore_support(
        y_true, y_pred, average="binary", zero_division=0
    )
    try:
        auc = roc_auc_score(y_true, y_score)
    except ValueError:
        auc = float("nan")
    print(f"Precision = {p:.2f}\nRecall    = {r:.2f}"
          f"\nF1 score  = {f:.2f}\nAUROC     = {auc:.2f}")
    return (f"{p:.2f}",f"{r:.2f}",f"{f:.2f}",f"{auc:.2f}")
    

Precision,Recall,F1, AUROC = evaluate(y_true_gcn_whole, y_pred_gcn_whole, y_score_gcn_whole)

Total binary questions: 46  (23 positives)
Precision = 0.51
Recall    = 0.96
F1 score  = 0.67
AUROC     = 0.63


## trying GAT

In [6]:
# ---------------------------------------------------------------------
# 4‑bis.  GAT encoder + dot‑product decoder
#         (replace the old Section 4 with this block)
# ---------------------------------------------------------------------
from torch_geometric.nn import GATConv          # NEW import

class GATLink(torch.nn.Module):
    """
    Two layer GAT.
      • First layer: 8 attention heads, ELU activation, concat=True  
      • Second layer: 1 head, concat=False to return hidden dim size
      • Edge weights are passed as a 1D edge_attr so the attention
        mechanism can learn to incorporate Genie3 scores.
    """
    def __init__(self, in_dim, hid=64, heads=8, dropout=0.1):
        super().__init__()
        self.conv1 = GATConv(
            in_channels=in_dim,
            out_channels=hid,
            heads=heads,
            edge_dim=1,          # <‑‑ one scalar per edge
            dropout=dropout
        )
        self.conv2 = GATConv(
            in_channels=hid * heads,   # because concat=True above
            out_channels=hid,
            heads=1,
            concat=False,              # keep hidden size = hid
            edge_dim=1,
            dropout=dropout
        )

    def forward(self, x, edge_index, edge_weight):
        edge_attr = edge_weight.unsqueeze(-1)        # shape [E, 1]
        h = F.elu(self.conv1(x, edge_index, edge_attr))
        h = self.conv2(h, edge_index, edge_attr)
        return h

In [7]:
learning_rate_2 = 5e-3
epoch_2 = 40

In [None]:
model = GATLink(feat_dim, hid=64, heads=8, dropout=0.1).to(device)
opt    = torch.optim.Adam(model.parameters(), lr=learning_rate_2)

for epoch in range(1, epochs + 1):
    model.train(); opt.zero_grad()
    h = model(data.x, data.edge_index, data.edge_weight)

    # positive & negative scores
    pos_s = dot_score(h, data.edge_index)
    neg_i = sample_neg(pos_s.size(0)).to(device)
    neg_s = dot_score(h, neg_i)

    y_true = torch.cat([torch.ones_like(pos_s), torch.zeros_like(neg_s)])
    y_pred = torch.cat([pos_s,              neg_s            ])
    loss   = F.binary_cross_entropy_with_logits(y_pred, y_true)

    loss.backward(); opt.step()

    if epoch % 2 == 0:
        print(f"epoch {epoch:03d} | loss = {loss.item():.4f}")

In [None]:
# ---------------------------------------------------------------------
# 7.  Embeddings for *all* nodes after training
# ---------------------------------------------------------------------
model.eval()
with torch.no_grad():
    H = model(data.x, data.edge_index, data.edge_weight).cpu()

# ---------------------------------------------------------------------
# 8.  Helper: evaluate on a paired‑question list (unchanged API)
# ---------------------------------------------------------------------
from torch.nn.functional import sigmoid

def gcn_predict(tasks, H, node2idx, thresh=0.5):
    y_true, y_pred, y_score = [], [], []
    for t in tasks:
        i_tf   = node2idx.get(t["TF"])
        i_gene = node2idx.get(t["gene"])
        if i_tf is None or i_gene is None:      # unseen node guard
            continue
        logit = torch.dot(H[i_tf], H[i_gene]).item()
        prob  = sigmoid(torch.tensor(logit)).item()
        y_true.append(t["label"])
        y_score.append(prob)
        y_pred.append(1 if prob >= thresh else 0)
    return np.array(y_true), np.array(y_pred), np.array(y_score)

# ---------------------------------------------------------------------
# 9.  Example evaluation
#     (replace `tasks` with your actual paired‑question list)
# ---------------------------------------------------------------------
y_true_gcn_whole, y_pred_gcn_whole, y_score_gcn_whole = gcn_predict(tasks, H, node2idx)

def evaluate(y_true, y_pred, y_score):
    p,r,f,_ = precision_recall_fscore_support(
        y_true, y_pred, average="binary", zero_division=0
    )
    try:
        auc = roc_auc_score(y_true, y_score)
    except ValueError:
        auc = float("nan")
    print(f"Precision = {p:.2f}\nRecall    = {r:.2f}"
          f"\nF1 score  = {f:.2f}\nAUROC     = {auc:.2f}")
    return (f"{p:.2f}",f"{r:.2f}",f"{f:.2f}",f"{auc:.2f}")
    

Precision,Recall,F1, AUROC = evaluate(y_true_gcn_whole, y_pred_gcn_whole, y_score_gcn_whole)