In [1]:
!pip install numpy pandas networkx scikit-learn node2vec tqdm
!pip install torch torch-geometric torch-scatter torch-sparse torch-geometric-temporal
!pip install google-cloud-bigquery
!pip install google-cloud-storage

Collecting torch-geometric
  Using cached torch_geometric-2.7.0-py3-none-any.whl.metadata (63 kB)
Collecting torch-scatter
  Using cached torch_scatter-2.1.2.tar.gz (108 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting torch-sparse
  Using cached torch_sparse-0.6.18.tar.gz (209 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting torch-geometric-temporal
  Using cached torch_geometric_temporal-0.56.2-py3-none-any.whl.metadata (1.9 kB)
Using cached torch_geometric-2.7.0-py3-none-any.whl (1.3 MB)
Using cached torch_geometric_temporal-0.56.2-py3-none-any.whl (102 kB)
Building wheels for collected packages: torch-scatter, torch-sparse
  Building wheel for torch-scatter (setup.py) ... [?25l[?25hdone
  Created wheel for torch-scatter: filename=torch_scatter-2.1.2-cp312-cp312-linux_x86_64.whl size=3806389 sha256=943033add0679941f36546c4d6f9538c7cccf6f316ba4bc98d62a4d5d3449e4d
  Stored in directory: /root/.cache/pip/wheels/84/20/50/44800723f57cd798630e77b3ec

In [2]:
import torch
import os
TORCH_VERSION = torch.__version__.split('+')[0]
CUDA_VERSION = torch.version.cuda
if CUDA_VERSION:
    CUDA_VERSION = "cu" + CUDA_VERSION.replace('.', '')
else:
    CUDA_VERSION = "cpu"

In [3]:
from google.auth import default
creds, _ = default()

In [None]:
!gsutil ls

gs://cs224w-mimic-data/


## Prepare Preprocessed Data

In [4]:
# Retrieve graph from MIMIC preprocessing

from google.cloud import storage
import io
import pandas as pd

bucket_name = "cs224w-mimic-data"
edges_path = "disease_edges.csv"
nodes_path = "disease_nodes.csv"

client = storage.Client()
bucket = client.get_bucket(bucket_name)

edges_blob = bucket.blob(edges_path)
nodes_blob = bucket.blob(nodes_path)

edges_bytes = edges_blob.download_as_bytes()
nodes_bytes = nodes_blob.download_as_bytes()

edges_df = pd.read_csv(io.BytesIO(edges_bytes), parse_dates=["timestamp"])
nodes_df = pd.read_csv(io.BytesIO(nodes_bytes))

print(edges_df.head())
print(nodes_df.head())
print("Total edges:", len(edges_df))
print("Total unique diseases:", len(nodes_df))

     src    dst           timestamp
0    496  30981 2180-05-06 22:23:00
1  30981  07070 2180-05-06 22:23:00
2  07070   5723 2180-05-06 22:23:00
3   5723   5715 2180-05-06 22:23:00
4   5715  78959 2180-05-06 22:23:00
  disease_code
0          496
1        30981
2        07070
3         5723
4         5715
Total edges: 6141197
Total unique diseases: 28562


In [5]:
# Create Train/Test Split
import networkx as nx
import numpy as np

# sort by timestamp
edges_df = edges_df.sort_values("timestamp")

# Create a chronological 70/15/15 train/validation/test split
# Resets dataframe indexing to sorted order
train_ratio = 0.70
val_ratio = 0.15
test_ratio = 0.15
train_end_idx = int(train_ratio * len(edges_df))
val_end_idx = int((train_ratio + val_ratio) * len(edges_df))

# Create three splits
train_edges_df = edges_df.iloc[:train_end_idx].reset_index(drop=True)
val_edges_df = edges_df.iloc[train_end_idx:val_end_idx].reset_index(drop=True)
test_edges_df = edges_df.iloc[val_end_idx:].reset_index(drop=True)

# Verify splits
total_edges = len(edges_df)
train_pct = len(train_edges_df) / total_edges * 100
val_pct = len(val_edges_df) / total_edges * 100
test_pct = len(test_edges_df) / total_edges * 100

print(f"Total edges: {total_edges:,}")
print(f"Train: {len(train_edges_df):,} ({train_pct:.1f}%)")
print(f"Val:   {len(val_edges_df):,} ({val_pct:.1f}%)")
print(f"Test:  {len(test_edges_df):,} ({test_pct:.1f}%)")
print()

# Show temporal boundaries
print("Temporal boundaries:")
print(f"Train: {train_edges_df['timestamp'].min()} to {train_edges_df['timestamp'].max()}")
print(f"Val:   {val_edges_df['timestamp'].min()} to {val_edges_df['timestamp'].max()}")
print(f"Test:  {test_edges_df['timestamp'].min()} to {test_edges_df['timestamp'].max()}")
print()


Total edges: 6,141,197
Train: 4,298,837 (70.0%)
Val:   921,180 (15.0%)
Test:  921,180 (15.0%)

Temporal boundaries:
Train: 2105-10-04 17:26:00 to 2171-12-07 13:52:00
Val:   2171-12-07 13:52:00 to 2183-12-21 19:57:00
Test:  2183-12-21 19:57:00 to 2214-12-15 19:11:00



In [21]:
import numpy as np
import pandas as pd
from collections import defaultdict

# ---------- Generate positive samples -------------
train_pos_df = train_edges_df.copy()
val_pos_df = val_edges_df.copy()
test_pos_df = test_edges_df.copy()

# Convert to list format (for some baseline methods)
train_pos = train_edges_df[["src", "dst", "timestamp"]].values.tolist()
val_pos = val_edges_df[["src", "dst", "timestamp"]].values.tolist()
test_pos = test_edges_df[["src", "dst", "timestamp"]].values.tolist()

print(f"Positive samples:")
print(f" Train: {len(train_pos):,}")
print(f" Val:   {len(val_pos):,}")
print(f" Test:  {len(test_pos):,}")
print()

# ---------- Generate negative samples ----------

def negative_sampling(pos_edges_df, all_nodes, exclude_edges_df=None,
                           n_negatives_per_positive=1, seed=42):
    """
    Strategy:
    1. Generate many candidate negatives in large batches (vectorized)
    2. Filter out positives using set operations
    3. Sample exactly n negatives per positive

    Args:
        pos_edges_df: Positive edges DataFrame with columns ['src', 'dst', 'timestamp']
        all_nodes: Array of all disease codes
        exclude_edges_df: Edges to exclude from negatives
        n_negatives_per_positive: Number of negatives per positive
        seed: Random seed
    """

    np.random.seed(seed)
    n_pos = len(pos_edges_df)
    needed = n_pos * n_negatives_per_positive

    if exclude_edges_df is None:
        exclude_set = set(zip(pos_edges_df["src"], pos_edges_df["dst"]))
    else:
        exclude_set = set(zip(exclude_edges_df["src"], exclude_edges_df["dst"]))

    all_nodes = np.array(all_nodes)
    n_nodes = len(all_nodes)

    print(f"Need {needed:,} negatives... generating candidates...")

    # Generate large batch
    oversample_factor = 3
    n_candidates = needed * oversample_factor
    src_candidates = np.random.choice(all_nodes, size=n_candidates)
    dst_candidates = np.random.choice(all_nodes, size=n_candidates)

    # Filter self-loops
    mask = src_candidates != dst_candidates

    # Build array of tuples for filtering
    src_str = src_candidates.astype(str)
    dst_str = dst_candidates.astype(str)
    pairs = np.char.add(src_str, np.char.add(",", dst_str))

    exclude_pairs = set([f"{s},{d}" for (s, d) in exclude_set])
    mask &= np.array([p not in exclude_pairs for p in pairs])

    # Select valid negatives
    valid_src = src_candidates[mask]
    valid_dst = dst_candidates[mask]

    # Take exactly the first `needed`
    neg_src = valid_src[:needed]
    neg_dst = valid_dst[:needed]

    neg_df = pd.DataFrame({
        "src": neg_src,
        "dst": neg_dst,
        "timestamp": None
    })

    print(f"Generated {len(neg_df):,} negatives.")
    return neg_df


all_diseases = nodes_df["disease_code"].values

print(f"Generating training negative samples...")
train_neg_df = negative_sampling(
    train_pos_df,
    all_diseases,
    exclude_edges_df=train_edges_df,  # exclude pos train examples
    n_negatives_per_positive=1
)

print(f"Generating validation negative samples...")
train_and_val = pd.concat([train_edges_df, val_edges_df])
val_neg_df = negative_sampling(
    val_pos_df,
    all_diseases,
    exclude_edges_df=train_and_val  # exclude pos train + validation examples
)

print(f"Generating testing negative samples...")
test_neg_df = negative_sampling(
    test_pos_df,
    all_diseases,
    exclude_edges_df=edges_df  # exclude all positive examples
)

print()
print(f"Negative samples generated:")
print(f" Train: {len(train_neg_df):,}")
print(f" Val:   {len(val_neg_df):,}")
print(f" Test:  {len(test_neg_df):,}")
print()

Positive samples:
 Train: 4,298,837
 Val:   921,180
 Test:  921,180

Generating training negative samples...
Need 4,298,837 negatives... generating candidates...
Generated 4,298,837 negatives.
Generating validation negative samples...
Need 921,180 negatives... generating candidates...
Generated 921,180 negatives.
Generating testing negative samples...
Need 921,180 negatives... generating candidates...
Generated 921,180 negatives.

Negative samples generated:
 Train: 4,298,837
 Val:   921,180
 Test:  921,180



In [25]:
import os

def save_splits_local(train_pos_df, train_neg_df,
                     val_pos_df, val_neg_df,
                     test_pos_df, test_neg_df,
                     save_dir="/content/data_splits"):
    """
    Save all splits to local files (useful for Colab session)
    """


    print("Saving splits from local files")
    os.makedirs(save_dir, exist_ok=True)

    # Save positive edges
    train_pos_df.to_csv(f"{save_dir}/train_pos.csv", index=False)
    val_pos_df.to_csv(f"{save_dir}/val_pos.csv", index=False)
    test_pos_df.to_csv(f"{save_dir}/test_pos.csv", index=False)

    # Save negative edges
    train_neg_df.to_csv(f"{save_dir}/train_neg.csv", index=False)
    val_neg_df.to_csv(f"{save_dir}/val_neg.csv", index=False)
    test_neg_df.to_csv(f"{save_dir}/test_neg.csv", index=False)

    print(f"✓ Saved 6 files to {save_dir}/")
    print(f"  train_pos.csv: {len(train_pos_df):,} rows")
    print(f"  train_neg.csv: {len(train_neg_df):,} rows")
    print(f"  val_pos.csv:   {len(val_pos_df):,} rows")
    print(f"  val_neg.csv:   {len(val_neg_df):,} rows")
    print(f"  test_pos.csv:  {len(test_pos_df):,} rows")
    print(f"  test_neg.csv:  {len(test_neg_df):,} rows")

    # Calculate total size
    total_size = sum(os.path.getsize(f"{save_dir}/{f}")
                    for f in os.listdir(save_dir)) / (1024**2)
    print(f"  Total size: {total_size:.1f} MB")
    print()

def load_splits_local(save_dir="/content/splits"):
    """
    Load all splits from local files
    """
    print("Load splits from local files")

    train_pos_df = pd.read_csv(f"{save_dir}/train_pos.csv", parse_dates=["timestamp"])
    val_pos_df = pd.read_csv(f"{save_dir}/val_pos.csv", parse_dates=["timestamp"])
    test_pos_df = pd.read_csv(f"{save_dir}/test_pos.csv", parse_dates=["timestamp"])

    train_neg_df = pd.read_csv(f"{save_dir}/train_neg.csv")
    val_neg_df = pd.read_csv(f"{save_dir}/val_neg.csv")
    test_neg_df = pd.read_csv(f"{save_dir}/test_neg.csv")

    print(f"✓ Loaded 6 files from {save_dir}/")
    print(f"  train_pos: {len(train_pos_df):,} rows")
    print(f"  train_neg: {len(train_neg_df):,} rows")
    print(f"  val_pos:   {len(val_pos_df):,} rows")
    print(f"  val_neg:   {len(val_neg_df):,} rows")
    print(f"  test_pos:  {len(test_pos_df):,} rows")
    print(f"  test_neg:  {len(test_neg_df):,} rows")
    print()

    return train_pos_df, train_neg_df, val_pos_df, val_neg_df, test_pos_df, test_neg_df

save_splits_local(train_pos_df, train_neg_df, val_pos_df, val_neg_df,test_pos_df, test_neg_df,save_dir="/content/data_splits")

Saving splits from local files
✓ Saved 6 files to /content/data_splits/
  train_pos.csv: 4,298,837 rows
  train_neg.csv: 4,298,837 rows
  val_pos.csv:   921,180 rows
  val_neg.csv:   921,180 rows
  test_pos.csv:  921,180 rows
  test_neg.csv:  921,180 rows
  Total size: 263.2 MB



In [None]:
train_pos_df, train_neg_df, val_pos_df, val_neg_df, test_pos_df, test_neg_df = load_splits_local() # run when need

In [27]:
# --------- Create validation set with labels ---------
val_pos_w_label = val_pos_df[["src", "dst"]].copy()
val_pos_w_label["label"] = 1

val_neg_w_label = val_neg_df[["src", "dst"]].copy()
val_neg_w_label["label"] = 0

val_all = pd.concat([val_pos_w_label, val_neg_w_label], ignore_index=True)

# Test set with labels
test_pos_w_label = test_pos_df[["src", "dst"]].copy()
test_pos_w_label["label"] = 1

test_neg_w_label = test_neg_df[["src", "dst"]].copy()
test_neg_w_label["label"] = 0

test_all = pd.concat([test_pos_w_label, test_neg_w_label], ignore_index=True)

# Shuffle for better evaluation (optional)
val_all = val_all.sample(frac=1, random_state=42).reset_index(drop=True)
test_all = test_all.sample(frac=1, random_state=42).reset_index(drop=True)

print(f"Labeled datasets:")
print(f"  Validation: {len(val_all):,} edges ({val_all['label'].sum():,} positive, "
      f"{(val_all['label'] == 0).sum():,} negative)")
print(f"  Test:       {len(test_all):,} edges ({test_all['label'].sum():,} positive, "
      f"{(test_all['label'] == 0).sum():,} negative)")
print()


# ---------- Create training graph ------------

G_train = nx.DiGraph()
G_train.add_edges_from(train_edges_df[["src", "dst"]].values)

print(f"Training graph statistics:")
print(f"  Nodes: {G_train.number_of_nodes():,}")
print(f"  Edges: {G_train.number_of_edges():,}")
print(f"  Density: {nx.density(G_train):.6f}")

# Check for disconnected components
if nx.is_weakly_connected(G_train):
    print(f"  Graph is weakly connected")
else:
    num_components = nx.number_weakly_connected_components(G_train)
    print(f"  Graph has {num_components} weakly connected components")

print()

Labeled datasets:
  Validation: 1,842,360 edges (921,180 positive, 921,180 negative)
  Test:       1,842,360 edges (921,180 positive, 921,180 negative)

Training graph statistics:
  Nodes: 26,508
  Edges: 1,527,244
  Density: 0.002174
  Graph is weakly connected



##Baseline: Completely Random

In [None]:
import random

diseases = list(nodes_df["disease_code"])
y_true = test_all["label"].values
y_score_random = [random.random() for _ in range(len(y_true))]
auc_random = roc_auc_score(y_true, y_score_random)
ap_random = average_precision_score(y_true, y_score_random)

hits = 0
total = len(test_pos)
for src, dst, _ in test_pos:
    random_10_diseases = random.choices(diseases, k=10)
    if dst in random_10_diseases:
        hits += 1
hits_at_10_random = hits / total

print("Baseline: Completely Random")
print("ROC AUC:", auc_random)
print("Average Precision:", ap_random)
print("Hits@10:", hits_at_10_random)

Baseline: Completely Random
ROC AUC: 0.4999290437408782
Average Precision: 0.5002865647195095
Hits@10: 0.0003272975965609327


##Baseline: Most Frequent Next Hop

In [28]:
import pandas as pd
from collections import defaultdict

# count transition per src -> dst pair
transition_counts = train_pos_df.groupby(["src", "dst"]).size().reset_index(name="count")

# build transition frequency dictionary from training edges
next_disease_freq = defaultdict(dict)
for _, row in transition_counts.iterrows():
    next_disease_freq[row["src"]][row["dst"]] = row["count"]

print(f"Number of src diseases: {len(next_disease_freq)}")

total_edges = len(edges_df)
def most_frequent_next_hop_score(src, dst):
    # Just return the actual frequency/probability
    freq = next_disease_freq[src].get(dst, 0)
    total_from_src = sum(next_disease_freq[src].values())
    if total_from_src == 0:
        return 0
    return freq / total_from_src  # P(dst | src)

# score edges by how common they are
test_all["score_most_common"] = test_all.apply(
    lambda row: most_frequent_next_hop_score(row["src"], row["dst"]),
    axis=1
)
print(f"Snippet of test_all: {test_all[:5]}")

Number of src diseases: 26138
Snippet of test_all:        src      dst  label  score_most_common
0  T82897A    H7092      0           0.000000
1    V1271    V6284      1           0.000926
2     E785     I255      1           0.002132
3    66951  S30860A      0           0.000000
4   I97120  V584XXA      0           0.000000


In [29]:
from sklearn.metrics import roc_auc_score, average_precision_score

# evaluate baseline
y_true = test_all["label"].values
y_score = test_all["score_most_common"].values
top10_next = dict()
for src, freqs in next_disease_freq.items():
  top10_dst = sorted(freqs.items(), key=lambda x: x[1], reverse=True)[:10]
  top10_next[src] = top10_dst
top10_next = {src: {dst for dst, _ in lst} for src, lst in top10_next.items()}

auc = roc_auc_score(y_true, y_score)
ap = average_precision_score(y_true, y_score)

hits = 0
total = len(test_pos)
for src, dst, _ in test_pos:
    if src in top10_next and dst in top10_next[src]:
        hits += 1

hits_at_10 = hits / total


print("Baseline: Most Frequent Next Hop")
print("ROC AUC:", auc)
print("Average Precision:", ap)
print("Hits@10:", hits_at_10)


Baseline: Most Frequent Next Hop
ROC AUC: 0.879977854490979
Average Precision: 0.879977854490979
Hits@10: 0.17548361883670943


##Baseline: Degree Aware

In [30]:
import pandas as pd
import numpy as np

# in-degree from train edges
in_deg = train_edges_df["dst"].value_counts()  # disease_code -> count
max_deg = in_deg.max()

def degree_aware_random_score(src, dst):
    deg = in_deg.get(dst, 0)
    # normalized degree
    deg_norm = deg / max_deg
    # random score but scaled a bit by degree
    return deg_norm

# score edges by in degree freq
test_all["score_deg_rand"] = test_all.apply(
    lambda row: degree_aware_random_score(row["src"], row["dst"]),
    axis=1
)

In [31]:
from sklearn.metrics import roc_auc_score, average_precision_score
import numpy as np

y_true = test_all["label"].values
y_score = test_all["score_deg_rand"].values

auc = roc_auc_score(y_true, y_score)
ap  = average_precision_score(y_true, y_score)

print("Baseline: Degree Aware")
print("ROC AUC:", auc)
print("AP:", ap)

def compute_hits_at_k(test_df, score_col, k=10):
    """
    test_df: DataFrame with columns ['src', 'dst', 'label', score_col]
    score_col: name of the column with scores (e.g., 'score_deg_rand')
    k: cutoff for Hits@K
    """
    hits = []

    for src, group in test_df.groupby("src"):
        # sort by model score descending
        group_sorted = group.sort_values(score_col, ascending=False)
        top_k = group_sorted.head(k)

        # hit if any of the top-k edges is actually positive
        hit = (top_k["label"] == 1).any()
        hits.append(int(hit))

    return np.mean(hits)

hits10 = compute_hits_at_k(test_all, score_col="score_deg_rand", k=10)
print("Hits@10:", hits10)


Baseline: Degree Aware
ROC AUC: 0.9654606428395175
AP: 0.9644292119282976
Hits@10: 0.572683985715286
