In [6]:
"""
NCI1 â†” NCI109 Domain Adaptation Template
----------------------------------------
1) Load NCI1 (source) and NCI109 (target)
2) Train a small GCN on source only
3) *** PLACEHOLDER *** for your (F)GW-based adaptation
4) Evaluate on target

Switch source/target by swapping dataset names at the top.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, global_add_pool
import torch_geometric.transforms as T
from sklearn.metrics import accuracy_score

# ----------------------------
# Config
# ----------------------------
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
DATA_ROOT = './data'
SOURCE_NAME = 'NCI1'
TARGET_NAME = 'NCI109'
BATCH_SIZE = 64
LR = 1e-3
EPOCHS = 200
HIDDEN = 64
TRAIN_RATIO = 0.8
OH_DEGREE = 8  # one-hot degree cap (keeps features simple & consistent)

In [7]:
# ----------------------------
# 1) Load datasets (+ robust node features)
# ----------------------------
# Use OneHotDegree to guarantee x exists even if raw features are missing.
import os, certifi
os.environ["SSL_CERT_FILE"] = certifi.where()
transform = T.OneHotDegree(max_degree=OH_DEGREE, cat=False)
src_ds = TUDataset(root=DATA_ROOT, name=SOURCE_NAME, transform=transform)
tgt_ds = TUDataset(root=DATA_ROOT, name=TARGET_NAME, transform=transform)

# Ensure identical feature dimensions across domains (replace, don't concat)
assert src_ds.num_features == tgt_ds.num_features, f"Feature dim mismatch: src={src_ds.num_features}, tgt={tgt_ds.num_features}. Use OneHotDegree(..., cat=False) or pad features."
in_dim = src_ds.num_features

print(f"Source: {SOURCE_NAME} -> {len(src_ds)} graphs, x_dim={src_ds.num_features}, classes={src_ds.num_classes}")
print(f"Target: {TARGET_NAME} -> {len(tgt_ds)} graphs, x_dim={tgt_ds.num_features}, classes={tgt_ds.num_classes}")

num_classes = src_ds.num_classes  # NCI1/NCI109 are binary

# Simple random split
def split_dataset(ds, ratio=TRAIN_RATIO):
    n = len(ds)
    perm = torch.randperm(n)
    n_tr = int(ratio * n)
    return ds[perm[:n_tr]], ds[perm[n_tr:]]

src_train, src_test = split_dataset(src_ds)
tgt_train, tgt_test = split_dataset(tgt_ds)

# Limit training graphs for speed / scalability
SRC_LIMIT = 200
TGT_LIMIT = 200
if len(src_train) > SRC_LIMIT:
    src_train = src_train[:SRC_LIMIT]
if len(tgt_train) > TGT_LIMIT:
    tgt_train = tgt_train[:TGT_LIMIT]
print(f"Using {len(src_train)} source-train and {len(tgt_train)} target-train graphs (capped at 200 each).")

src_loader = DataLoader(src_train, batch_size=BATCH_SIZE, shuffle=True)
src_test_loader = DataLoader(src_test, batch_size=BATCH_SIZE, shuffle=False)
tgt_loader = DataLoader(tgt_train, batch_size=BATCH_SIZE, shuffle=True)
tgt_test_loader = DataLoader(tgt_test, batch_size=BATCH_SIZE, shuffle=False)

Source: NCI1 -> 4110 graphs, x_dim=9, classes=2
Target: NCI109 -> 4127 graphs, x_dim=9, classes=2
Using 200 source-train and 200 target-train graphs (capped at 200 each).


In [8]:
# ----------------------------
# 2) Define a minimal GCN classifier
# ----------------------------
class SimpleGCN(nn.Module):
    def __init__(self, in_channels, hidden, classes):
        super().__init__()
        # Two GCNConv layers for neighborhood aggregation
        self.conv1 = GCNConv(in_channels, hidden)
        self.conv2 = GCNConv(hidden, hidden)
        self.lin = nn.Linear(hidden, classes)

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        x = global_add_pool(x, batch)
        return self.lin(x)

model = SimpleGCN(in_dim, HIDDEN, num_classes).to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=LR)

In [9]:
# ----------------------------
# 3) Train on SOURCE only
# ----------------------------
def train_epoch(loader):
    model.train()
    total = 0.0
    for data in loader:
        data = data.to(DEVICE)
        optimizer.zero_grad()
        out = model(data)
        loss = F.cross_entropy(out, data.y)
        loss.backward()
        optimizer.step()
        total += loss.item() * data.num_graphs
    return total / len(loader.dataset)

for ep in range(1, EPOCHS + 1):
    loss = train_epoch(src_loader)
    if ep % 5 == 0:
        print(f"[Pretrain] Epoch {ep}/{EPOCHS}  loss={loss:.4f}")

print("Source-only pretraining done.")

[Pretrain] Epoch 5/200  loss=0.6421
[Pretrain] Epoch 10/200  loss=0.6314
[Pretrain] Epoch 15/200  loss=0.6310
[Pretrain] Epoch 20/200  loss=0.6311
[Pretrain] Epoch 25/200  loss=0.6278
[Pretrain] Epoch 30/200  loss=0.6270
[Pretrain] Epoch 35/200  loss=0.6208
[Pretrain] Epoch 40/200  loss=0.6260
[Pretrain] Epoch 45/200  loss=0.6192
[Pretrain] Epoch 50/200  loss=0.6209
[Pretrain] Epoch 55/200  loss=0.6171
[Pretrain] Epoch 60/200  loss=0.6187
[Pretrain] Epoch 65/200  loss=0.6245
[Pretrain] Epoch 70/200  loss=0.6195
[Pretrain] Epoch 75/200  loss=0.6127
[Pretrain] Epoch 80/200  loss=0.6109
[Pretrain] Epoch 85/200  loss=0.6202
[Pretrain] Epoch 90/200  loss=0.6065
[Pretrain] Epoch 95/200  loss=0.6064
[Pretrain] Epoch 100/200  loss=0.6187
[Pretrain] Epoch 105/200  loss=0.6048
[Pretrain] Epoch 110/200  loss=0.6052
[Pretrain] Epoch 115/200  loss=0.6022
[Pretrain] Epoch 120/200  loss=0.6110
[Pretrain] Epoch 125/200  loss=0.6004
[Pretrain] Epoch 130/200  loss=0.5973
[Pretrain] Epoch 135/200  loss=0

In [10]:
# ----------------------------
# 4a) OT-based adaptation with GW ground cost + GW-barycentric mapping
# ----------------------------
# Requirements: pip install POT networkx
import numpy as np
import networkx as nx
import ot  # Python Optimal Transport (POT)

GW_REG = None        # regularization for inner GW (None = classic GW)
SINKHORN_REG = 0.01   # entropic reg for outer OT over graphs
TOPK = 5             # for per-source barycentric mapping, use top-k targets by OT weights
BARY_NITERS = 40     # iterations for GW barycenter solver
KNN_K = 4            # edges per node when rebuilding graph from barycenter distances
INF_DIST = 1e6

def pyg_to_distmat(data):
    """Compute an all-pairs shortest-path distance matrix from a PyG Data graph."""
    g = nx.Graph()
    num_nodes = data.num_nodes
    g.add_nodes_from(range(num_nodes))
    # unweighted SSSP on existing edges
    ei = data.edge_index.cpu().numpy()
    for u, v in ei.T:
        if u != v:
            g.add_edge(int(u), int(v), weight=1.0)
    # ensure connectivity by adding very large distances where unreachable
    dist = np.full((num_nodes, num_nodes), INF_DIST, dtype=np.float64)
    np.fill_diagonal(dist, 0.0)
    lengths = dict(nx.all_pairs_shortest_path_length(g))
    for i, row in lengths.items():
        for j, d in row.items():
            dist[int(i), int(j)] = float(d)
    # symmetrize just in case
    dist = 0.5 * (dist + dist.T)
    return dist

def uniform_weights(n):
    w = np.full(n, 1.0 / n, dtype=np.float64)
    return w

# Precompute intra-graph distance matrices and node weights
src_Cs, src_as = [], []
for d in src_train:
    src_Cs.append(pyg_to_distmat(d))
    src_as.append(uniform_weights(d.num_nodes))

tgt_Cs, tgt_as = [], []
for d in tgt_train:
    tgt_Cs.append(pyg_to_distmat(d))
    tgt_as.append(uniform_weights(d.num_nodes))

# Pairwise GW^2 cost between graphs (outer ground cost)
ns, nt = len(src_train), len(tgt_train)
outer_cost = np.zeros((ns, nt), dtype=np.float64)
for i in range(ns):
    C1, a1 = src_Cs[i], src_as[i]
    for j in range(nt):
        print("inner ot ", (float(i*nt) + j) / float(nt * ns))
        C2, a2 = tgt_Cs[j], tgt_as[j]
        # squared GW loss
        # outer_cost[i, j] = ot.gromov.entropic_gromov_wasserstein2(C1, C2, a1, a2, 'square_loss', epsilon=0.1, max_iter=100, tol=1e-7)
        outer_cost[i, j] = ot.gromov.gromov_wasserstein2(
            C1, C2, a1, a2, 'square_loss', log=False, armijo=False, max_iter=100, tol=1e-7
        )

# Outer OT coupling between source-graphs and target-graphs (uniform masses)
mu = uniform_weights(ns)
nu = uniform_weights(nt)
# entropic OT on the outer ground cost to get a soft plan
T_outer = ot.sinkhorn(mu, nu, outer_cost, reg=SINKHORN_REG)

# --- Barycentric mapping: per source graph i, build a GW-barycenter of top-k target graphs
def knn_graph_from_dist(C, k=KNN_K):
    """Return undirected edge index from a distance matrix using symmetric kNN."""
    n = C.shape[0]
    # for each node, pick k smallest (excluding self)
    neigh = np.argsort(C + np.eye(n) * INF_DIST, axis=1)[:, :k]
    edges = set()
    for i in range(n):
        for j in neigh[i]:
            if i != j:
                a, b = (int(i), int(j))
                if a > b: a, b = b, a
                edges.add((a, b))
    # ensure symmetry
    edge_index = np.array(list(edges), dtype=np.int64).T
    if edge_index.size == 0:
        edge_index = np.zeros((2,0), dtype=np.int64)
    return edge_index

def degree_onehot(edge_index, num_nodes, max_degree=OH_DEGREE):
    deg = np.zeros((num_nodes,), dtype=np.int64)
    if edge_index.size > 0:
        for u, v in edge_index.T:
            deg[u] += 1
            deg[v] += 1
    deg = np.minimum(deg, max_degree)  # cap
    x = np.zeros((num_nodes, max_degree + 1), dtype=np.float32)
    x[np.arange(num_nodes), deg] = 1.0
    return torch.from_numpy(x)

from torch_geometric.data import Data

transported_src = []
for i, d_src in enumerate(src_train):
    print(float(i) / len(src_train))
    w = T_outer[i]  # weights over target graphs
    if w.sum() <= 0:
        transported_src.append(d_src)  # fallback
        continue
    # pick top-k targets for efficiency
    top_idx = np.argsort(-w)[:TOPK]
    lambdas = w[top_idx]
    lambdas = lambdas / lambdas.sum()

    # assemble inputs for GW barycenter
    Cs = [tgt_Cs[j] for j in top_idx]
    ps = [tgt_as[j] for j in top_idx]
    # choose barycenter node count = source node count (keeps labels aligned)
    nb = d_src.num_nodes

    Cb = ot.gromov.gromov_barycenters(
        N=nb,
        Cs=Cs,
        ps=ps,
        p=np.full(nb, 1.0 / nb, dtype=np.float64),  # barycenter node weights (|p| = nb)
        lambdas=lambdas,
        loss_fun='square_loss',
        max_iter=BARY_NITERS,
        tol=1e-6,
        verbose=False,
        random_state=0,
    )
    # rebuild an unweighted symmetric kNN graph from the barycenter distance matrix
    ei = knn_graph_from_dist(Cb, k=KNN_K)
    # build node features as degree one-hot (consistent with OneHotDegree(..., cat=False))
    x_b = degree_onehot(ei, nb, max_degree=OH_DEGREE)
    # label stays the same as the source graph label
    y_b = d_src.y.clone()
    data_b = Data(x=x_b.to(DEVICE), edge_index=torch.from_numpy(ei).long().to(DEVICE), y=y_b.to(DEVICE))
    transported_src.append(data_b)

# Fine-tune the classifier on transported source graphs, then evaluate on target test set
trans_loader = DataLoader(transported_src, batch_size=BATCH_SIZE, shuffle=True)

inner ot  0.0
inner ot  2.5e-05
inner ot  5e-05
inner ot  7.5e-05
inner ot  0.0001
inner ot  0.000125
inner ot  0.00015
inner ot  0.000175
inner ot  0.0002
inner ot  0.000225
inner ot  0.00025
inner ot  0.000275
inner ot  0.0003
inner ot  0.000325
inner ot  0.00035
inner ot  0.000375
inner ot  0.0004
inner ot  0.000425
inner ot  0.00045
inner ot  0.000475
inner ot  0.0005
inner ot  0.000525
inner ot  0.00055
inner ot  0.000575
inner ot  0.0006
inner ot  0.000625
inner ot  0.00065
inner ot  0.000675
inner ot  0.0007
inner ot  0.000725
inner ot  0.00075
inner ot  0.000775
inner ot  0.0008
inner ot  0.000825
inner ot  0.00085
inner ot  0.000875
inner ot  0.0009
inner ot  0.000925
inner ot  0.00095
inner ot  0.000975
inner ot  0.001
inner ot  0.001025
inner ot  0.00105
inner ot  0.001075
inner ot  0.0011
inner ot  0.001125
inner ot  0.00115
inner ot  0.001175
inner ot  0.0012
inner ot  0.001225
inner ot  0.00125
inner ot  0.001275
inner ot  0.0013
inner ot  0.001325
inner ot  0.00135
inner

  v = b / KtransposeU


0.01
0.015
0.02
0.025
0.03
0.035
0.04
0.045
0.05
0.055
0.06
0.065
0.07
0.075
0.08
0.085
0.09
0.095
0.1
0.105
0.11
0.115
0.12
0.125
0.13
0.135
0.14
0.145
0.15
0.155
0.16
0.165
0.17
0.175
0.18
0.185
0.19
0.195
0.2
0.205
0.21
0.215
0.22
0.225
0.23
0.235
0.24
0.245
0.25
0.255
0.26
0.265
0.27
0.275
0.28
0.285
0.29
0.295
0.3
0.305
0.31
0.315
0.32
0.325
0.33
0.335
0.34
0.345
0.35
0.355
0.36
0.365
0.37
0.375
0.38
0.385
0.39
0.395
0.4
0.405
0.41
0.415
0.42
0.425
0.43
0.435
0.44
0.445
0.45
0.455
0.46
0.465
0.47
0.475
0.48
0.485
0.49
0.495
0.5
0.505
0.51
0.515
0.52
0.525
0.53
0.535
0.54
0.545
0.55
0.555
0.56
0.565
0.57
0.575
0.58
0.585
0.59
0.595
0.6
0.605
0.61
0.615
0.62
0.625
0.63
0.635
0.64
0.645
0.65
0.655
0.66
0.665
0.67
0.675
0.68
0.685
0.69
0.695
0.7
0.705
0.71
0.715
0.72
0.725
0.73
0.735
0.74
0.745
0.75
0.755
0.76
0.765
0.77
0.775
0.78
0.785
0.79
0.795
0.8
0.805
0.81
0.815
0.82
0.825
0.83
0.835
0.84
0.845
0.85
0.855
0.86
0.865
0.87
0.875
0.88
0.885
0.89
0.895
0.9
0.905
0.91
0.915
0.92
0.9

In [12]:
# ----------------------------
# 5) Evaluation helpers
# ----------------------------
@torch.no_grad()
def eval_accuracy(loader):
    model.eval()
    all_y, all_pred = [], []
    for data in loader:
        data = data.to(DEVICE)
        logits = model(data)
        pred = logits.argmax(dim=1).cpu()
        all_pred.append(pred)
        all_y.append(data.y.cpu())
    y_true = torch.cat(all_y)
    y_pred = torch.cat(all_pred)
    return accuracy_score(y_true, y_pred)

# Before adaptation
tgt_acc_before = eval_accuracy(tgt_test_loader)
print(f"Target accuracy BEFORE adaptation ({SOURCE_NAME}â†’{TARGET_NAME}): {tgt_acc_before:.3f}")

Target accuracy BEFORE adaptation (NCI1â†’NCI109): 0.627


In [13]:
src_acc = eval_accuracy(src_loader)
print(f"Source accuracy:  {src_acc:.3f}")

Source accuracy:  0.685


In [14]:
tgt_acc_after = eval_accuracy(trans_loader)
print(f"Target accuracy AFTER adaptation ({SOURCE_NAME}â†’{TARGET_NAME}):  {tgt_acc_after:.3f}")

Target accuracy AFTER adaptation (NCI1â†’NCI109):  0.545
