In [110]:
import os
import json
import torch
from torch_geometric.data import Dataset, Data

class SingleFileEmbeddingPairDataset(Dataset):
    """Dataset of graph pairs (A, B) stored together in one JSON file.

    The JSON file is a dict keyed by model number. Each value is a pair record
    that is read with these keys:

    * ``'A_embeddings'`` / ``'B_embeddings'``: dict mapping an integer-like
      entity id (string) to that entity's feature values.
    * ``'mappings'``: iterable of ``(a_id, b_id)`` entries; either side may be
      the string ``'NULL'`` to mark an entity without a counterpart.

    Indexing returns ``(dataA, dataB, matches)`` where ``matches`` is a long
    tensor of shape ``[K, 2]`` holding node positions (not entity ids), with
    ``-1`` standing in for ``'NULL'``.
    """

    def __init__(self, json_path):
        super().__init__(os.path.dirname(json_path))
        # Read the whole file, then release the handle before post-processing.
        with open(json_path, 'r') as f:
            model_dict = json.load(f)
        # Flatten to a list; the model-number keys are not kept.
        self.all_pairs = list(model_dict.values())

    def __len__(self):
        return len(self.all_pairs)

    @staticmethod
    def _sorted_features(embeddings):
        """Return ``(ids, x)``: ids sorted numerically and the matching float tensor."""
        ids = sorted(embeddings.keys(), key=lambda x: int(x))
        x = torch.tensor([embeddings[entity_id] for entity_id in ids], dtype=torch.float)
        return ids, x

    def __getitem__(self, idx):
        pair = self.all_pairs[idx]

        # Node order inside each graph is the numeric order of the entity ids.
        A_ids, xA = self._sorted_features(pair['A_embeddings'])
        B_ids, xB = self._sorted_features(pair['B_embeddings'])

        # Dummy edge indices (empty for now)
        edge_indexA = torch.empty((2, 0), dtype=torch.long)
        edge_indexB = torch.empty((2, 0), dtype=torch.long)

        # Entity id -> row position in xA / xB.
        A_id_to_idx = {int(aid): pos for pos, aid in enumerate(A_ids)}
        B_id_to_idx = {int(bid): pos for pos, bid in enumerate(B_ids)}

        mappings = []
        for a, b in pair['mappings']:
            idx_a = -1 if a == 'NULL' else A_id_to_idx[int(a)]
            idx_b = -1 if b == 'NULL' else B_id_to_idx[int(b)]
            mappings.append([idx_a, idx_b])

        # reshape keeps the [K, 2] layout even when there are no mappings;
        # without it an empty list yields a 1-D tensor and `matches[:, 0]`
        # raises IndexError downstream.
        matches = torch.tensor(mappings, dtype=torch.long).reshape(-1, 2)

        dataA = Data(x=xA, edge_index=edge_indexA, xt_entity_ids=A_ids)
        dataB = Data(x=xB, edge_index=edge_indexB, xt_entity_ids=B_ids)
        return dataA, dataB, matches

In [108]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv

class GraphEncoder(nn.Module):
    """Three-layer GraphSAGE node encoder.

    The first two layers are each followed by batch norm, ReLU and dropout;
    the third layer produces the output embedding with no activation.

    Args:
        input_dim: size of the input node features.
        hidden_dim: width of the two hidden layers.
        out_dim: size of the returned node embeddings.
        dropout: dropout probability applied after each hidden layer.
    """

    def __init__(self, input_dim, hidden_dim, out_dim, dropout=0.3):
        super().__init__()
        # Attribute names and creation order are part of the checkpoint layout
        # (state_dict keys) and of seeded initialisation, so they stay fixed.
        self.conv1 = SAGEConv(input_dim, hidden_dim)
        self.conv2 = SAGEConv(hidden_dim, hidden_dim)
        self.conv3 = SAGEConv(hidden_dim, out_dim)
        self.bn1 = nn.BatchNorm1d(hidden_dim)
        self.bn2 = nn.BatchNorm1d(hidden_dim)
        self.dropout = dropout

    def forward(self, x, edge_index):
        """Encode node features ``x`` over the graph given by ``edge_index``."""
        hidden = x
        hidden_layers = ((self.conv1, self.bn1), (self.conv2, self.bn2))
        for conv, norm in hidden_layers:
            hidden = conv(hidden, edge_index)
            hidden = F.relu(norm(hidden))
            hidden = F.dropout(hidden, p=self.dropout, training=self.training)
        return self.conv3(hidden, edge_index)


class SiameseGNN(nn.Module):
    """Siamese wrapper: one shared ``GraphEncoder`` applied to two graphs.

    Args:
        in_dim: size of the input node features.
        hidden_dim: hidden width of the encoder.
        out_dim: encoder output size.
        proj_dim: if given, a two-layer projection MLP
            (Linear -> BatchNorm -> ReLU -> Linear) is applied on top of the
            encoder output; if ``None`` no projector is created.
        dropout: dropout probability forwarded to the encoder.

    Attributes:
        final_dim: dimensionality of the returned embeddings
            (``proj_dim`` when a projector exists, otherwise ``out_dim``).
    """

    def __init__(self, in_dim, hidden_dim, out_dim, proj_dim=None, dropout=0.3):
        super().__init__()
        self.encoder = GraphEncoder(in_dim, hidden_dim, out_dim, dropout)

        if proj_dim is None:
            self.projector = None
            self.final_dim = out_dim
        else:
            # Optional projection head on top of the encoder output.
            layers = [
                nn.Linear(out_dim, proj_dim),
                nn.BatchNorm1d(proj_dim),
                nn.ReLU(),
                nn.Linear(proj_dim, proj_dim),
            ]
            self.projector = nn.Sequential(*layers)
            self.final_dim = proj_dim

    def forward(self, data1, data2):
        """Return node embeddings ``(h1, h2)`` for the two input graphs.

        Both graphs go through the same encoder (and the same projector when
        one is configured). The outputs are not L2-normalised.
        """
        # Encode both graphs first, then project both, in that order.
        h1, h2 = (self.encoder(graph.x, graph.edge_index) for graph in (data1, data2))

        if self.projector is not None:
            h1, h2 = self.projector(h1), self.projector(h2)

        return h1, h2


In [111]:
import torch
import torch.nn.functional as F

# --- Core Contrastive Margin Loss ---
def contrastive_margin_loss(emb1, emb2, matches_pos, matches_neg, margin=1.0):
    """Contrastive loss over squared Euclidean distances.

    Args:
        emb1: [N1, D] embeddings of the first graph.
        emb2: [N2, D] embeddings of the second graph.
        matches_pos: [K1, 2] index pairs that should be close.
        matches_neg: [K2, 2] index pairs that should be far apart.
        margin: negatives are penalised while their squared distance is
            below this value.

    Returns:
        Scalar tensor: mean squared distance of the positives plus the mean
        hinge ``relu(margin - squared_distance)`` of the negatives. With no
        positives the result is a zero that is still connected to both
        embeddings, so ``backward()`` remains valid.
    """
    def _squared_distance(pairs):
        # Row-wise squared L2 distance between the paired embeddings.
        delta = emb1[pairs[:, 0]] - emb2[pairs[:, 1]]
        return (delta ** 2).sum(dim=1)

    # Guard: nothing to pull together -> graph-attached zero.
    if matches_pos.numel() == 0:
        return (emb1.sum() * 0.0) + (emb2.sum() * 0.0)

    # Positive term: pull matched nodes together.
    pos_loss = _squared_distance(matches_pos).mean()

    # Negative term: push non-matches apart until they clear the margin.
    if matches_neg.numel() > 0:
        neg_loss = F.relu(margin - _squared_distance(matches_neg)).mean()
    else:
        neg_loss = torch.tensor(0.0, device=emb1.device)

    return pos_loss + neg_loss


# --- Random negative sampling ---
def get_random_negatives(matches_pos, emb1_size, emb2_size, num_neg=None):
    """Sample random (i, j) index pairs that are not listed as positives.

    Args:
        matches_pos: [K, 2] long tensor of positive index pairs.
        emb1_size: number of nodes in the first graph (exclusive upper bound for i).
        emb2_size: number of nodes in the second graph (exclusive upper bound for j).
        num_neg: number of candidates to draw; ``None`` means "as many as
            there are positives". An explicit ``0`` yields no negatives.

    Returns:
        Long tensor of shape [M, 2] on ``matches_pos.device`` with
        ``M <= num_neg``: candidates that collide with a positive are dropped,
        not redrawn. Duplicated candidates are possible.
    """
    num_pos = matches_pos.size(0)
    if num_neg is None:
        num_neg = num_pos

    # Keep the result on the same device as the positives so it can be
    # concatenated with device-resident index tensors.
    device = matches_pos.device

    # torch.randint rejects an empty range, so short-circuit degenerate inputs.
    if num_neg <= 0 or emb1_size <= 0 or emb2_size <= 0:
        return torch.empty((0, 2), dtype=torch.long, device=device)

    rand_A = torch.randint(0, emb1_size, (num_neg,), device=device)
    rand_B = torch.randint(0, emb2_size, (num_neg,), device=device)
    candidates = torch.stack([rand_A, rand_B], dim=1)

    # One bulk transfer to Python instead of an .item() call per element.
    pos_set = set(map(tuple, matches_pos.tolist()))
    keep = torch.tensor(
        [tuple(pair) not in pos_set for pair in candidates.tolist()],
        dtype=torch.bool,
        device=device,
    )
    return candidates[keep]


# --- Hard negative sampling ---
def get_hard_negatives(emb1, emb2, matches_pos, top_k=3):
    """Select the hardest negatives: closest emb2 nodes that are not positives.

    Args:
        emb1: [N1, D] embeddings of the first graph.
        emb2: [N2, D] embeddings of the second graph.
        matches_pos: [K, 2] long tensor of positive index pairs (no -1 entries).
        top_k: maximum number of negatives per emb1 node; clamped to N2.

    Returns:
        Long tensor of shape [M, 2] on ``emb1.device``. Rows are grouped by
        emb1 index and, within a group, ordered from closest to farthest.
        A node contributes fewer than ``top_k`` rows when it has fewer
        non-positive candidates. Computed without gradient tracking.
    """
    with torch.no_grad():
        n1, n2 = emb1.size(0), emb2.size(0)
        k = min(top_k, n2)  # topk raises if k exceeds the number of candidates
        if n1 == 0 or k <= 0:
            return torch.empty((0, 2), dtype=torch.long, device=emb1.device)

        dist = torch.cdist(emb1, emb2, p=2)  # [N1, N2]
        pos_mask = torch.zeros_like(dist, dtype=torch.bool)
        if matches_pos.numel() > 0:
            pos_mask[matches_pos[:, 0], matches_pos[:, 1]] = True
        # Push positives to the end of every row's ranking.
        dist = dist.masked_fill(pos_mask, float('inf'))

        cols = dist.topk(k=k, dim=1, largest=False).indices  # [N1, k]
        rows = torch.arange(n1, device=cols.device).unsqueeze(1).expand_as(cols)

        # When a row has fewer than k real candidates, topk is forced to pick
        # masked positives; drop them so a true match is never used as a negative.
        keep = ~pos_mask[rows, cols]
        matches_neg = torch.stack([rows[keep], cols[keep]], dim=1)
    return matches_neg


# --- Mixed negative sampling (recommended) ---
def get_mixed_negatives(emb1, emb2, matches_pos, ratio=0.5, top_k=3):
    """Combine hard and random negatives into one index tensor.

    Args:
        emb1: [N1, D] embeddings of the first graph.
        emb2: [N2, D] embeddings of the second graph.
        matches_pos: [K, 2] long tensor of positive index pairs.
        ratio: scales how many pairs are taken from each source:
            ``n = int(ratio * min(#hard, #random))``.
        top_k: forwarded to ``get_hard_negatives``.

    Returns:
        Long tensor on ``emb1.device``: the first ``n`` hard negatives followed
        by the first ``n`` random negatives, or all random negatives when
        ``n == 0``.
    """
    device = emb1.device
    # Both sources are moved to one device: torch.cat refuses to mix CPU and
    # CUDA tensors, which is what happens when the sampler returns CPU indices.
    hard_negs = get_hard_negatives(emb1, emb2, matches_pos, top_k).to(device)
    rand_negs = get_random_negatives(matches_pos, emb1.size(0), emb2.size(0)).to(device)

    num_hard = int(ratio * min(len(hard_negs), len(rand_negs)))
    if num_hard == 0:
        return rand_negs
    # NOTE(review): hard negatives are truncated from the front, i.e. they come
    # from the lowest-index emb1 nodes only - confirm this is intended.
    return torch.cat([hard_negs[:num_hard], rand_negs[:num_hard]], dim=0)



In [82]:
def batch_accuracy(emb1, emb2, matches):
    """Top-1 retrieval accuracy of emb1 -> emb2 nearest neighbours.

    Args:
        emb1: [N1, D] embeddings of the first graph.
        emb2: [N2, D] embeddings of the second graph.
        matches: [K, 2] long tensor of ground-truth index pairs; ``-1`` on
            either side marks a node without a counterpart.

    Returns:
        Fraction (float in [0, 1]) of fully matched pairs whose closest emb2
        node is the ground-truth one, or ``None`` when no pair has both sides
        present (avoids a NaN mean).
    """
    if matches.numel() == 0:
        return None

    # Both sides must be real indices: a -1 in column 0 would otherwise be
    # treated as "last row" by tensor indexing and scored silently.
    valid = (matches[:, 0] != -1) & (matches[:, 1] != -1)
    indices1 = matches[valid, 0]
    indices2 = matches[valid, 1]

    if len(indices1) == 0:
        return None  # avoid NaN

    # Pure metric: no need to record the distance computation for autograd.
    with torch.no_grad():
        dists = torch.cdist(emb1, emb2)  # [N1, N2]
        # For each ground-truth emb1 index, which emb2 node is closest
        min_indices = torch.argmin(dists[indices1], dim=1)
        correct = (min_indices == indices2).float()
    return correct.mean().item()

In [81]:
def batch_topk_accuracy(emb1, emb2, matches, k=5):
    """Top-k retrieval accuracy of emb1 -> emb2 nearest neighbours.

    Args:
        emb1: [N1, D] embeddings of the first graph.
        emb2: [N2, D] embeddings of the second graph.
        matches: [K, 2] long tensor of ground-truth index pairs; ``-1`` on
            either side marks a node without a counterpart.
        k: number of nearest emb2 nodes considered; clamped to N2.

    Returns:
        Fraction (float in [0, 1]) of fully matched pairs whose ground-truth
        emb2 node is among the k closest, or ``None`` when there is nothing
        to score (no fully matched pair, or no candidates).
    """
    if matches.numel() == 0:
        return None

    # Both sides must be real indices: a -1 in column 0 would otherwise be
    # treated as "last row" by tensor indexing and scored silently.
    valid = (matches[:, 0] != -1) & (matches[:, 1] != -1)
    indices1 = matches[valid, 0]
    indices2 = matches[valid, 1]

    if len(indices1) == 0:
        return None

    actual_k = min(k, emb2.shape[0])
    if actual_k == 0:
        return None  # no candidates

    # Pure metric: no need to record the distance computation for autograd.
    with torch.no_grad():
        dists = torch.cdist(emb1, emb2)
        # Largest negative distance == smallest distance.
        topk = torch.topk(-dists[indices1], actual_k, dim=1).indices
        correct = torch.any(topk == indices2.unsqueeze(1), dim=1).float()
    return correct.mean().item()

In [None]:
import torch
import torch.nn.functional as F
from tqdm import tqdm
import matplotlib.pyplot as plt

#  Reproducibility: weight init, dropout and negative sampling all draw from torch's RNG
SEED = 42
torch.manual_seed(SEED)

#  Load dataset
# NOTE(review): machine-specific absolute path - point DATA_PATH at your own copy.
DATA_PATH = "C:\\Users\\Z0054udc\\Downloads\\Siamese GNN\\XT_merged_doubled.json"
dataset = SingleFileEmbeddingPairDataset(DATA_PATH)
num_samples = len(dataset)
print(f"Found {num_samples} dataset entries.")
if num_samples == 0:
    raise ValueError("ERROR: No JSON files found!")

#  Device setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Using device:", device)

#  Model and optimizer
model = SiameseGNN(in_dim=15, hidden_dim=64, out_dim=32).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

#  Training hyperparameters
epochs = 150
grad_accum_steps = 8        # Simulates batch size = 8
loss_history, top1_history, top5_history = [], [], []

#  Training loop: one graph pair per forward pass, optimizer step every
#  `grad_accum_steps` pairs (and once more at the end of the epoch).
for epoch in range(epochs):
    model.train()
    total_loss = 0.0
    total_top1 = 0.0
    total_top5 = 0.0
    num_acc = 0
    optimizer.zero_grad()

    for i in tqdm(range(num_samples), desc=f"Epoch {epoch+1}"):
        data1, data2, matches = dataset[i]
        # A pair without any mapping may arrive as a 1-D empty tensor; force [K, 2].
        matches = matches.reshape(-1, 2)
        data1, data2, matches = data1.to(device), data2.to(device), matches.to(device)

        # Forward pass for this pair
        emb1, emb2 = model(data1, data2)
        matches_pos = matches[(matches[:, 0] != -1) & (matches[:, 1] != -1)]
        matches_neg = get_mixed_negatives(emb1, emb2, matches_pos, ratio=0.5, top_k=3)
        loss = contrastive_margin_loss(emb1, emb2, matches_pos, matches_neg)

        # Average (not sum) the gradients over the accumulation window so the
        # update really matches a mini-batch of that size. The last window of
        # an epoch can be shorter than grad_accum_steps.
        window_start = (i // grad_accum_steps) * grad_accum_steps
        window_size = min(grad_accum_steps, num_samples - window_start)
        (loss / window_size).backward()

        # Reported loss stays the unscaled per-pair value.
        total_loss += loss.item()

        # --- Accuracy tracking (metrics only, no autograd bookkeeping)
        with torch.no_grad():
            acc1 = batch_accuracy(emb1, emb2, matches)
            acc5 = batch_topk_accuracy(emb1, emb2, matches, k=5)
        if acc1 is not None and acc5 is not None:
            total_top1 += acc1
            total_top5 += acc5
            num_acc += 1

        # --- Gradient accumulation ---
        if (i + 1) % grad_accum_steps == 0 or (i + 1) == num_samples:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=2.0)
            optimizer.step()
            optimizer.zero_grad()

    # ---- Epoch summary ----
    avg_loss = total_loss / num_samples
    avg_top1 = (total_top1 / num_acc) * 100 if num_acc > 0 else 0.0
    avg_top5 = (total_top5 / num_acc) * 100 if num_acc > 0 else 0.0

    loss_history.append(avg_loss)
    top1_history.append(avg_top1)
    top5_history.append(avg_top5)

    print(f"Epoch {epoch+1:03d}: Loss = {avg_loss:.4f}, Top-1 = {avg_top1:.2f}%, Top-5 = {avg_top5:.2f}%")

# Plot training loss
plt.figure(figsize=(10, 5))
plt.plot(loss_history, label="Loss", color='firebrick')
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.title("Contrastive Loss Curve")
plt.grid(True)
plt.legend()
plt.tight_layout()
plt.show()

# Plot accuracy curves (measured on the training pairs, in train mode)
plt.figure(figsize=(10, 5))
plt.plot(top1_history, label="Top-1 Accuracy (%)", color='royalblue')
plt.plot(top5_history, label="Top-5 Accuracy (%)", linestyle="--", color='seagreen')
plt.xlabel("Epoch")
plt.ylabel("Accuracy (%)")
plt.title("Accuracy Over Training")
plt.grid(True)
plt.legend()
plt.tight_layout()
plt.show()


In [65]:
# Save the model's state_dict to "siamese_gnn.pth" (a relative path, so it is
# written to the current working directory). Uses the `model` variable, which
# must already exist in the kernel when this cell runs.
torch.save(model.state_dict(), "siamese_gnn.pth")
# To load: model.load_state_dict(torch.load("siamese_gnn.pth"))
# (`model` must already be constructed before calling load_state_dict on it.)