
# SiameseBERT Training Notebook

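This notebook trains Siamese BERT models with a cosine embedding loss on a dataset of tokenized issue pairs labelled duplicate (1) or unique (0).
It first trains transformer baselines (BERT, CodeBERT, RoBERTa) and then our SBERT + GCN model, which adds a graph convolutional network over the issue graph. It is the notebook version of the original Python training script.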


In [None]:

# Colab setup: mount Google Drive, then make the CS588 project folder the working
# directory so relative paths such as datasets/ and checkpoints/ resolve inside it.
import os

from google.colab import drive

drive.mount('/content/drive')

# Adjust this path if your CS588 folder is elsewhere in Drive
BASE_DIR = '/content/drive/MyDrive/CS588'
os.chdir(BASE_DIR)

cwd_now = os.getcwd()
print("Current working directory:", cwd_now)

Mounted at /content/drive
Current working directory: /content/drive/.shortcut-targets-by-id/1AX1-1KC_cgwCqo-A4ecfxoHQWVfm-yV1/CS588


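## Baselines

The next cell defines `train`, which fine-tunes a Siamese transformer encoder (BERT, CodeBERT or RoBERTa) with a cosine embedding loss and evaluates it on a held-out test split.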

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
from torch.optim import AdamW
from tqdm import tqdm
import os
import matplotlib.pyplot as plt
import torch.nn.functional as F
import csv
from sklearn.metrics import f1_score, recall_score, precision_score, confusion_matrix, ConfusionMatrixDisplay
from datetime import datetime
import numpy as np

from dataset import TokenizedDataset
from sbert import SiameseBERT


import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
from torch.optim import AdamW
from tqdm import tqdm
import os
import matplotlib.pyplot as plt
import torch.nn.functional as F
import csv
from sklearn.metrics import f1_score, recall_score, precision_score, confusion_matrix, ConfusionMatrixDisplay
from datetime import datetime
import numpy as np

from dataset import TokenizedDataset
from sbert import SiameseBERT



def train(
    csv_path: str,
    model_name: str = "bert-base-uncased",
    batch_size: int = 1024,
    epochs: int = 3,
    lr: float = 2e-5,
    dataset_name: str = "thunderbird",

):
    device = torch.device("mps" if torch.backends.mps.is_available() else "cuda")
    print(f"Using device: {device}")

    # 1. Load Dataset
    print(f"Loading dataset from {csv_path}...")
    dataset = TokenizedDataset(csv_path)
    print(f"Dataset size: {len(dataset)}")

    # Split into Train (60%), Validation (20%), Test (20%)
    total_len = len(dataset)
    train_len = int(0.6 * total_len)
    val_len = int(0.2 * total_len)
    test_len = total_len - train_len - val_len

    train_dataset, val_dataset, test_dataset = random_split(dataset, [train_len, val_len, test_len])

    print(f"Train size: {len(train_dataset)} | Validation size: {len(val_dataset)} | Test size: {len(test_dataset)}")

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    print(f"Train batches: {len(train_loader)} | Val batches: {len(val_loader)} | Test batches: {len(test_loader)}")

    # 2. Initialize Model
    print(f"Initializing SiameseBERT with {model_name}...")
    model = SiameseBERT(model_name=model_name)
    model.to(device)

    # 3. Setup Training
    optimizer = AdamW(model.parameters(), lr=lr)
    criterion = nn.CosineEmbeddingLoss(margin=0.5)

    model.train()

    epoch_losses = []
    epoch_accuracies = []
    epoch_f1s = []
    epoch_recalls = []
    epoch_precisions = []

    val_losses = []
    val_accuracies = []
    val_f1s = []
    val_recalls = []
    val_precisions = []

    # === store per-iteration loss here ===
    step_losses = []             # list of (global_step, loss_value)
    global_step = 0

    embedding_dict = {}

    # 0. Create save directory with date (and time) ONCE
    run_timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")  # e.g. 2025-12-09_13-45-02
    save_root = os.path.join("checkpoints", dataset_name)          # or any base folder you want
    save_dir = os.path.join(save_root, run_timestamp)
    os.makedirs(save_dir, exist_ok=True)
    print(f"Checkpoints and figures will be saved under: {save_dir}")

    # optional: CSV log for each iteration
    iter_log_path = "train_step_losses.csv"   # stays in CWD unless you also want this in save_dir
    last_all_labels = None
    last_all_preds = None

    with open(iter_log_path, "w", newline="") as f_log:
        writer = csv.writer(f_log)
        writer.writerow(["global_step", "epoch", "batch_idx", "loss"])

        threshold = 0.5  # for accuracy from cosine similarity

        for epoch in range(epochs):
            print(f"\nEpoch {epoch + 1}/{epochs}")
            total_loss = 0.0
            total_correct = 0
            total_examples = 0

            all_preds = []
            all_labels = []

            # Training Loop
            model.train()
            progress_bar = tqdm(enumerate(train_loader), total=len(train_loader), desc="Training")

            for batch_idx, batch in progress_bar:
                global_step += 1

                # Move batch to device
                input_ids1 = batch['input_ids1'].to(device)
                attention_mask1 = batch['attention_mask1'].to(device)
                input_ids2 = batch['input_ids2'].to(device)
                attention_mask2 = batch['attention_mask2'].to(device)
                issue_id1 = batch["issue_id1"]
                issue_id2 = batch["issue_id2"]

                labels01 = batch['label'].to(device).float()  # {0,1}

                # For CosineEmbeddingLoss: targets in {1, -1}
                targets = 2 * labels01 - 1.0  # 0 -> -1, 1 -> +1

                optimizer.zero_grad()

                # Forward pass
                embeddings1, embeddings2 = model(
                    input_ids1, attention_mask1,
                    input_ids2, attention_mask2
                )

                if issue_id1 is not None and issue_id2 is not None:
                    # Optimize: Batch processing for CLS tokens
                    # Note: This runs BERT a second time, but in batch mode (faster than loop)
                    # Ideally, modify model.forward to return CLS tokens to avoid re-computation
                    cls1 = model.create_embedding(input_ids1, attention_mask1)
                    cls2 = model.create_embedding(input_ids2, attention_mask2)

                    for i in range(len(issue_id1)):
                        embedding_dict[issue_id1[i].item()] = cls1[i].detach().cpu()
                        embedding_dict[issue_id2[i].item()] = cls2[i].detach().cpu()

                loss = criterion(embeddings1, embeddings2, targets)

                loss.backward()
                optimizer.step()

                loss_value = loss.item()
                total_loss += loss_value

                # ---- Accuracy computation ----
                with torch.no_grad():
                    cos_sim = F.cosine_similarity(embeddings1, embeddings2)  # [B]
                    preds = (cos_sim > threshold).float()                   # {0,1}
                    correct = (preds == labels01).sum().item()
                    total_correct += correct
                    total_examples += labels01.numel()
                    batch_acc = correct / labels01.numel()

                    preds_np = preds.cpu().numpy()
                    labels_np = labels01.cpu().numpy()

                    all_preds.extend(preds_np)
                    all_labels.extend(labels_np)

                    batch_f1 = f1_score(labels_np, preds_np, zero_division=0)

                # log per-iteration loss
                step_losses.append((global_step, loss_value))
                writer.writerow([global_step, epoch + 1, batch_idx, loss_value])

                progress_bar.set_postfix({
                    'loss': f"{loss_value:.4f}",
                    'acc': f"{batch_acc:.3f}",
                    'f1': f"{batch_f1:.3f}"
                })

            avg_loss = total_loss / len(train_loader)
            avg_acc = total_correct / total_examples

            epoch_f1 = f1_score(all_labels, all_preds)
            epoch_recall = recall_score(all_labels, all_preds)
            epoch_precision = precision_score(all_labels, all_preds)

            print(f"Train | Loss: {avg_loss:.4f} | Acc: {avg_acc:.4f} | F1: {epoch_f1:.4f} | Rec: {epoch_recall:.4f} | Prec: {epoch_precision:.4f}")

            epoch_losses.append(avg_loss)
            epoch_accuracies.append(avg_acc)
            epoch_f1s.append(epoch_f1)
            epoch_recalls.append(epoch_recall)
            epoch_precisions.append(epoch_precision)

            # Validation Loop
            model.eval()
            val_loss = 0.0
            val_all_sims = []
            val_all_labels = []

            # Open a file to log sample comparisons for this epoch
            sample_log_path = os.path.join(save_dir, f"val_samples_epoch_{epoch+1}.txt")
            with open(sample_log_path, "w") as f_sample:
                f_sample.write("Issue1_ID\tIssue2_ID\tLabel\tSimilarity\n")

                with torch.no_grad():
                    for batch in tqdm(val_loader, desc="Validation"):
                        input_ids1 = batch['input_ids1'].to(device)
                        attention_mask1 = batch['attention_mask1'].to(device)
                        input_ids2 = batch['input_ids2'].to(device)
                        attention_mask2 = batch['attention_mask2'].to(device)
                        labels01 = batch['label'].to(device).float()
                        targets = 2 * labels01 - 1.0

                        # Get issue IDs for logging
                        val_issue_id1 = batch.get("issue_id1")
                        val_issue_id2 = batch.get("issue_id2")

                        embeddings1, embeddings2 = model(input_ids1, attention_mask1, input_ids2, attention_mask2)
                        loss = criterion(embeddings1, embeddings2, targets)
                        val_loss += loss.item()


                        if val_issue_id1 is not None and val_issue_id2 is not None:
                        # Optimize: Batch processing for CLS tokens
                                # Note: This runs BERT a second time, but in batch mode (faster than loop)
                                # Ideally, modify model.forward to return CLS tokens to avoid re-computation
                                cls1 = model.create_embedding(input_ids1, attention_mask1)
                                cls2 = model.create_embedding(input_ids2, attention_mask2)

                                for i in range(len(val_issue_id1)):
                                    embedding_dict[val_issue_id1[i].item()] = cls1[i].detach().cpu()
                                    embedding_dict[val_issue_id2[i].item()] = cls2[i].detach().cpu()


                        cos_sim = F.cosine_similarity(embeddings1, embeddings2)

                        val_all_sims.extend(cos_sim.cpu().numpy())
                        val_all_labels.extend(labels01.cpu().numpy())

                        # Log samples
                        if val_issue_id1 is not None and val_issue_id2 is not None:
                            sims_np = cos_sim.cpu().numpy()
                            labels_np = labels01.cpu().numpy()
                            ids1_np = val_issue_id1.numpy() if isinstance(val_issue_id1, torch.Tensor) else val_issue_id1
                            ids2_np = val_issue_id2.numpy() if isinstance(val_issue_id2, torch.Tensor) else val_issue_id2

                            for i in range(len(sims_np)):
                                f_sample.write(f"{ids1_np[i]}\t{ids2_np[i]}\t{labels_np[i]}\t{sims_np[i]:.4f}\n")

            print(f"Validation samples logged to {sample_log_path}")
            avg_val_loss = val_loss / len(val_loader)

            # Adaptive Thresholding
            thresholds = np.arange(0.3, 0.95, 0.05)
            best_f1 = -1.0
            best_thresh = 0.5
            best_metrics = {}

            val_sims_np = np.array(val_all_sims)
            val_labels_np = np.array(val_all_labels)

            for th in thresholds:
                preds = (val_sims_np >= th).astype(int)
                f1 = f1_score(val_labels_np, preds, zero_division=0)

                if f1 > best_f1:
                    best_f1 = f1
                    best_thresh = th
                    best_metrics = {
                        'acc': (preds == val_labels_np).mean(),
                        'rec': recall_score(val_labels_np, preds, zero_division=0),
                        'prec': precision_score(val_labels_np, preds, zero_division=0)
                    }

            print(f"Val   | Best Thresh: {best_thresh:.2f} | Loss: {avg_val_loss:.4f} | Acc: {best_metrics['acc']:.4f} | F1: {best_f1:.4f} | Rec: {best_metrics['rec']:.4f} | Prec: {best_metrics['prec']:.4f}")

            val_losses.append(avg_val_loss)
            val_accuracies.append(best_metrics['acc'])
            val_f1s.append(best_f1)
            val_recalls.append(best_metrics['rec'])
            val_precisions.append(best_metrics['prec'])

            # keep last epoch's preds/labels (using best threshold)
            last_all_labels = np.array(all_labels)
            last_all_preds = np.array(all_preds)

            # Save checkpoint
            ckpt_path = os.path.join(save_dir, f"sbert_epoch_{epoch + 1}.pth")
            torch.save(model.state_dict(), ckpt_path)
            print(f"Model saved to {ckpt_path}")

    print(f"Per-iteration losses saved to {iter_log_path}")

    # === Confusion matrix for LAST epoch only ===
    if last_all_labels is not None and last_all_preds is not None:
        cm = confusion_matrix(last_all_labels, last_all_preds, labels=[0.0, 1.0])
        disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=[0, 1])

        fig, ax = plt.subplots(figsize=(5, 5))
        disp.plot(ax=ax, values_format="d")
        ax.set_title(f"Confusion Matrix - Last Epoch (Epoch {epochs})")

        cm_path = os.path.join(save_dir, "confusion_matrix_last_epoch.png")
        fig.savefig(cm_path, bbox_inches="tight")
        plt.close(fig)
        print(f"Saved confusion matrix for last epoch to {cm_path}")
    else:
        print("Warning: No labels/preds collected; confusion matrix not generated.")

    # Save embeddings
    emb_path = os.path.join(save_dir, "embedding_dict.pt")
    print(f"Saving {len(embedding_dict)} embeddings to {emb_path}...")
    torch.save(embedding_dict, emb_path)
    print("Done saving embeddings.")

    # Plotting epoch-level loss, accuracy, F1, precision, recall
    epochs_range = range(1, epochs + 1)
    plt.figure(figsize=(12, 10))

    # --- Loss & Accuracy ---
    plt.subplot(2, 1, 1)
    plt.plot(epochs_range, epoch_losses, marker='o', linestyle='-', label='Train Loss')
    plt.plot(epochs_range, val_losses, marker='x', linestyle='-', label='Val Loss')
    plt.plot(epochs_range, epoch_accuracies, marker='s', linestyle='--', label='Train Accuracy')
    plt.plot(epochs_range, val_accuracies, marker='d', linestyle='--', label='Val Accuracy')
    plt.title('Training & Validation Loss & Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Value')
    plt.grid(True)
    plt.legend()

    # --- F1, Precision, Recall ---
    plt.subplot(2, 1, 2)
    # F1
    plt.plot(epochs_range, epoch_f1s, marker='^', linestyle='-', label='Train F1')
    plt.plot(epochs_range, val_f1s, marker='v', linestyle='-', label='Val F1')
    # Recall
    plt.plot(epochs_range, epoch_recalls, marker='o', linestyle='-', label='Train Recall')
    plt.plot(epochs_range, val_recalls, marker='o', linestyle='--', label='Val Recall')
    # Precision
    plt.plot(epochs_range, epoch_precisions, marker='s', linestyle='-', label='Train Precision')
    plt.plot(epochs_range, val_precisions, marker='s', linestyle='--', label='Val Precision')

    plt.title('Training & Validation F1 / Precision / Recall')
    plt.xlabel('Epoch')
    plt.ylabel('Score')
    plt.grid(True)
    plt.legend()

    plt.tight_layout()
    loss_curve_path = os.path.join(save_dir, "loss_acc_pr_rc_f1_curve.png")
    plt.savefig(loss_curve_path)
    plt.close()
    print(f"Loss, accuracy, precision, recall, F1 curves saved to {loss_curve_path}")

    # === Test Evaluation ===
    print("\n" + "="*30)
    print(f"TEST EVALUATION (Threshold={best_thresh:.2f})")
    print("="*30)

    model.eval()
    test_loss = 0.0
    test_all_preds = []
    test_all_labels = []

    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Testing"):
            input_ids1 = batch['input_ids1'].to(device)
            attention_mask1 = batch['attention_mask1'].to(device)
            input_ids2 = batch['input_ids2'].to(device)
            attention_mask2 = batch['attention_mask2'].to(device)
            labels01 = batch['label'].to(device).float()
            targets = 2 * labels01 - 1.0

            embeddings1, embeddings2 = model(input_ids1, attention_mask1, input_ids2, attention_mask2)
            loss = criterion(embeddings1, embeddings2, targets)
            test_loss += loss.item()

            cos_sim = F.cosine_similarity(embeddings1, embeddings2)
            preds = (cos_sim >= best_thresh).float()

            test_all_preds.extend(preds.cpu().numpy())
            test_all_labels.extend(labels01.cpu().numpy())

    avg_test_loss = test_loss / len(test_loader)
    test_acc = (np.array(test_all_preds) == np.array(test_all_labels)).mean()
    test_f1 = f1_score(test_all_labels, test_all_preds, zero_division=0)
    test_rec = recall_score(test_all_labels, test_all_preds, zero_division=0)
    test_prec = precision_score(test_all_labels, test_all_preds, zero_division=0)

    print(f"Test  | Loss: {avg_test_loss:.4f} | Acc: {test_acc:.4f} | F1: {test_f1:.4f} | Rec: {test_rec:.4f} | Prec: {test_prec:.4f}")

    # Test Confusion Matrix
    cm_test = confusion_matrix(test_all_labels, test_all_preds, labels=[0.0, 1.0])
    print("\nTest Confusion Matrix:")
    print(cm_test)

    # Plot Test Confusion Matrix
    disp_test = ConfusionMatrixDisplay(confusion_matrix=cm_test, display_labels=[0, 1])
    fig_test, ax_test = plt.subplots(figsize=(5, 5))
    disp_test.plot(ax=ax_test, values_format="d", cmap='Blues')
    ax_test.set_title(f"Test Confusion Matrix (Threshold={best_thresh:.2f})")

    cm_test_path = os.path.join(save_dir, "confusion_matrix_test.png")
    fig_test.savefig(cm_test_path, bbox_inches="tight")
    plt.close(fig_test)
    print(f"Saved test confusion matrix to {cm_test_path}")

    # Save Test Results
    test_results_path = os.path.join(save_dir, "test_results.txt")
    with open(test_results_path, "w") as f:
        f.write(f"Model Name: {model_name}\n")
        f.write(f"Test Evaluation (Threshold={best_thresh:.2f})\n")
        f.write(f"Loss: {avg_test_loss:.4f}\n")
        f.write(f"Accuracy: {test_acc:.4f}\n")
        f.write(f"F1 Score: {test_f1:.4f}\n")
        f.write(f"Recall: {test_rec:.4f}\n")
        f.write(f"Precision: {test_prec:.4f}\n")
        f.write("\nConfusion Matrix:\n")
        f.write(str(cm_test))
    print(f"Test results saved to {test_results_path}")


In [None]:
# Pick the pretrained encoder to fine-tune as a baseline. The tokenized CSV must have
# been produced with the matching tokenizer; its file name embeds the baseline key.
import os

# baseline key -> model identifier passed to train(model_name=...)
BASELINE_MODEL_NAMES = {
    "codebert-base": "microsoft/codebert-base",
    "bert-base-uncased": "bert-base-uncased",
    "roberta-base": "roberta-base",
}

baseline = "codebert-base"
# Fail loudly instead of leaving model_name undefined (or stale from an earlier run).
if baseline not in BASELINE_MODEL_NAMES:
    raise ValueError(
        f"Unknown baseline {baseline!r}; expected one of {sorted(BASELINE_MODEL_NAMES)}"
    )
model_name = BASELINE_MODEL_NAMES[baseline]

# Adjust paths according to your project structure
relative_path = "eclipse"
dataset_dir = f"datasets/{relative_path}/"
cur_dir = os.getcwd()
csv_path = os.path.join(cur_dir, dataset_dir, f"tokenized_pairs_train_{baseline}_50000.csv")

print("CSV path:", csv_path)

if os.path.exists(csv_path):
    # You can change batch_size, epochs, lr here if needed.
    # batch_size=2 keeps memory low but gives 15,000 training steps per epoch on 50k pairs.
    train(csv_path, batch_size=2, epochs=5, dataset_name=relative_path, model_name=model_name)
else:
    print(f"Dataset not found at {csv_path}")


CSV path: /content/drive/.shortcut-targets-by-id/1AX1-1KC_cgwCqo-A4ecfxoHQWVfm-yV1/CS588/datasets/eclipse/tokenized_pairs_train_codebert-base_50000.csv
Using device: cuda
Loading dataset from /content/drive/.shortcut-targets-by-id/1AX1-1KC_cgwCqo-A4ecfxoHQWVfm-yV1/CS588/datasets/eclipse/tokenized_pairs_train_codebert-base_50000.csv...
Dataset size: 50000
Train size: 30000 | Validation size: 10000 | Test size: 10000
Train batches: 15000 | Val batches: 5000 | Test batches: 5000
Initializing SiameseBERT with microsoft/codebert-base...
Checkpoints and figures will be saved under: checkpoints/eclipse/2025-12-25_08-47-55

Epoch 1/5


Training:   1%|          | 126/15000 [00:11<22:17, 11.12it/s, loss=0.2339, acc=0.500, f1=0.667]


KeyboardInterrupt: 

In [None]:
import os

import torch

# Embedding dictionary written by `train` (torch.save of {issue_id: tensor}).
# NOTE(review): `train` saves under checkpoints/<dataset>/<timestamp>/, but this path
# has no <dataset> component. Confirm it points at the intended run.
embedding_dict_path = "/content/drive/MyDrive/CS588/checkpoints/2025-12-09_17-22-30/embedding_dict.pt"


def load_embedding_dict(path, preview=5):
    """Load an ``embedding_dict.pt`` file and print a short summary of it.

    Prints the number of entries and the shapes of the first ``preview`` entries.

    Args:
        path: file produced by ``torch.save(embedding_dict, ...)``.
        preview: how many entries to describe.

    Returns:
        The loaded dictionary, or None when the file does not exist or fails to
        load (for example, a partially written file from an interrupted run).
    """
    if not os.path.exists(path):
        print(f"Embedding dictionary file not found at {path}.")
        print("This might be because the training was interrupted before the dictionary was saved.")
        return None

    print(f"Loading embedding dictionary from {path}...")
    try:
        # Entries were saved as CPU tensors. map_location keeps loading working on
        # CPU-only runtimes even if a different writer stored device tensors.
        # Only load files you produced yourself: torch.load may unpickle objects.
        embeddings = torch.load(path, map_location="cpu")
    except Exception as e:
        print(f"Error loading embedding_dict.pt: {e}")
        print("The file might be corrupted or incomplete due to the interruption.")
        return None

    print("Embedding dictionary loaded successfully.")
    print(f"Total number of embeddings: {len(embeddings)}")

    if len(embeddings) == 0:
        print("The embedding dictionary is empty.")
        return embeddings

    print(f"\n--- First {preview} Embedding Keys and Shapes ---")
    # zip with range() stops after `preview` items without materialising all keys.
    for _, (key, embedding) in zip(range(preview), embeddings.items()):
        if hasattr(embedding, "shape"):
            print(f"Key: {key}, Embedding Shape: {embedding.shape}")
        else:
            print(f"Key: {key}, Could not determine shape.")
    if len(embeddings) > preview:
        print(f"... and {len(embeddings) - preview} more entries.")
    return embeddings


embedding_dict = load_embedding_dict(embedding_dict_path)
    print("This might be because the training was interrupted before the dictionary was saved.")

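## SBERT + GCN (our model)

The following cells train our proposed model. It combines the Siamese BERT text encoder with a graph convolutional network (GCN) over the issue graph stored in `graph_adj.npz`.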

In [None]:
# ============================================================
# SBERT + GCN Training Script (Colab-friendly)
# - TQDM shows per-batch: Loss, F1, Acc
# - At the end: epoch-level plots for Loss/F1/Acc/Precision/Recall
# - Saves test confusion matrix
# ============================================================

!pip install torch-geometric  # Uncomment if you really need PyG for your gnn2.py

import os
from datetime import datetime

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

from torch.optim import AdamW
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm
from scipy import sparse
from sklearn.metrics import f1_score, precision_score, recall_score, confusion_matrix

# Local Imports
from dataset import TokenizedDataset
from sbert import SiameseBERT
from gnn3 import SBERTGCN, normalize_edge_index


# ---------------------------
# Device
# ---------------------------
def _device() -> torch.device:
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


# ---------------------------
# Graph loading
# ---------------------------
def _load_graph(dataset_dir: str, device: torch.device):
    """Load the issue graph from ``dataset_dir`` as normalized edge tensors on ``device``.

    Reads ``graph_adj.npz`` (a scipy sparse adjacency matrix) and
    ``graph_adj_node_ids.npy``. The latter presumably holds the issue id of each
    node in adjacency row order (TODO: confirm with the graph builder).

    Only the sparsity pattern is passed on. Values stored in the sparse matrix are
    not used here; ``normalize_edge_index`` produces the edge weights.

    Returns:
        ``((edge_index, edge_weight), node_ids)``. ``edge_index`` is int64 with two
        rows (source, target).
    """
    adj_path = os.path.join(dataset_dir, "graph_adj.npz")
    ids_path = os.path.join(dataset_dir, "graph_adj_node_ids.npy")

    # CSR first so the COO view comes out in row-major order.
    adj_coo = sparse.load_npz(adj_path).tocsr().tocoo()
    node_ids = np.load(ids_path)

    pairs = np.stack((adj_coo.row, adj_coo.col)).astype(np.int64)
    raw_edge_index = torch.from_numpy(pairs).to(device)

    norm_index, norm_weight = normalize_edge_index(raw_edge_index, len(node_ids))
    return (norm_index.to(device), norm_weight.to(device)), node_ids


# ---------------------------
# Batch pair -> graph indices
# ---------------------------
def _pairs_to_indices(batch, id_to_idx, device):
    idx1, idx2, labs = [], [], []

    for a, b, y in zip(batch["issue_id1"], batch["issue_id2"], batch["label"]):
        a_id, b_id = int(a), int(b)
        if a_id in id_to_idx and b_id in id_to_idx:
            idx1.append(id_to_idx[a_id])
            idx2.append(id_to_idx[b_id])
            labs.append(float(y))

    if not idx1:
        return None, None, None

    return (
        torch.tensor(idx1, device=device),
        torch.tensor(idx2, device=device),
        torch.tensor(labs, device=device),
    )


# ---------------------------
# Threshold selection on VAL
# ---------------------------
def _get_best_threshold(sims: np.ndarray, labs: np.ndarray) -> float:
    thresholds = np.linspace(0.1, 0.9, 17)
    best_f1, best_th = 0.0, 0.5
    for th in thresholds:
        preds = (sims >= th).astype(int)
        f1 = f1_score(labs, preds, zero_division=0)
        if f1 > best_f1:
            best_f1, best_th = f1, th
    return float(best_th)


# ---------------------------
# Epoch runner
# - tqdm: per-batch Loss/F1/Acc only
# - returns epoch metrics: loss/acc/f1/precision/recall + sims/labs
# ---------------------------
def _run_epoch(
    loader,
    node_emb,
    model,
    edge_index,
    edge_weight,
    id_to_idx,
    device,
    criterion,
    optimizer=None,
    thresh=0.5,
    desc="Run",
):
    """Run one pass over ``loader``. Trains when ``optimizer`` is given, otherwise evaluates.

    Pairs whose issue ids are not graph nodes are dropped. Batches with no
    remaining pairs are skipped.

    Args:
        loader: DataLoader yielding dicts with token tensors, issue ids and labels.
        node_emb: node embedding table passed to ``model``.
        model: SBERTGCN-style model returning one embedding per pair side.
        edge_index, edge_weight: graph tensors from ``_load_graph``.
        id_to_idx: issue id -> node index.
        device: device for all tensors.
        criterion: loss taking (emb1, emb2, targets in {-1, +1}).
        optimizer: optimizer to step. ``None`` means evaluation mode with no gradients.
        thresh: similarity threshold for the accuracy/F1/precision/recall predictions.
        desc: tqdm label.

    Returns:
        dict with "loss" (mean over processed batches), "acc", "f1", "precision",
        "recall", and the raw per-pair "sims" (float32) and "labs" (int32) arrays.
    """
    is_train = optimizer is not None
    model.train(is_train)
    node_emb.train(is_train)

    total_loss = 0.0
    seen_batches = 0
    all_sims, all_labs = [], []

    pbar = tqdm(loader, desc=desc, leave=False)

    for batch in pbar:
        idx, idy, labs_batch = _pairs_to_indices(batch, id_to_idx, device)
        if idx is None:
            continue

        # _pairs_to_indices drops pairs whose ids are not graph nodes. Apply the same
        # filter to the token tensors so every per-pair input to the model has the
        # same length as idx / idy / labs_batch. Without this, any dropped pair left
        # the text and graph inputs misaligned.
        keep = torch.tensor(
            [
                int(a) in id_to_idx and int(b) in id_to_idx
                for a, b in zip(batch["issue_id1"], batch["issue_id2"])
            ],
            dtype=torch.bool,
        )

        ids1 = batch["input_ids1"][keep].to(device)
        m1 = batch["attention_mask1"][keep].to(device)
        ids2 = batch["input_ids2"][keep].to(device)
        m2 = batch["attention_mask2"][keep].to(device)

        with torch.set_grad_enabled(is_train):
            out1, out2 = model(
                node_emb,
                idx,
                idy,
                ids1,
                m1,
                ids2,
                m2,
                edge_index,
                edge_weight,
            )
            out1 = F.normalize(out1, dim=1)
            out2 = F.normalize(out2, dim=1)

            # CosineEmbeddingLoss targets: 0 -> -1, 1 -> +1
            loss = criterion(out1, out2, (2 * labs_batch - 1.0))

            if is_train:
                optimizer.zero_grad(set_to_none=True)
                loss.backward()
                optimizer.step()

        sims = F.cosine_similarity(out1, out2).detach().cpu().numpy()
        labs = labs_batch.detach().cpu().numpy().astype(int)
        preds = (sims >= thresh).astype(int)

        b_loss = float(loss.item())
        b_f1 = f1_score(labs, preds, zero_division=0)
        b_acc = float((labs == preds).mean())

        pbar.set_postfix({"L": f"{b_loss:.3f}", "F1": f"{b_f1:.3f}", "Acc": f"{b_acc:.3f}"})

        total_loss += b_loss
        seen_batches += 1
        all_sims.extend(sims.tolist())
        all_labs.extend(labs.tolist())

    all_sims = np.asarray(all_sims, dtype=np.float32)
    all_labs = np.asarray(all_labs, dtype=np.int32)

    if all_labs.size == 0:
        return {
            "loss": 0.0,
            "acc": 0.0,
            "f1": 0.0,
            "precision": 0.0,
            "recall": 0.0,
            "sims": all_sims,
            "labs": all_labs,
        }

    preds_all = (all_sims >= thresh).astype(int)

    return {
        "loss": total_loss / max(1, seen_batches),
        "acc": float((preds_all == all_labs).mean()),
        "f1": f1_score(all_labs, preds_all, zero_division=0),
        "precision": precision_score(all_labs, preds_all, zero_division=0),
        "recall": recall_score(all_labs, preds_all, zero_division=0),
        "sims": all_sims,
        "labs": all_labs,
    }


# ---------------------------
# Plot epoch-level metrics (end only)
# ---------------------------
def _save_epoch_plots(epoch_hist, ckpt_dir, dataset_name=""):
    """Write one train-vs-val line plot per metric to ``<ckpt_dir>/epoch_<metric>.png``.

    Args:
        epoch_hist: ``{"train": {metric: [per-epoch values]}, "val": {...}}``, as
            built by ``train_bertgnn``. Both splits are plotted against the same
            epoch axis, so they must have equal length.
        ckpt_dir: output folder.
        dataset_name: appended to each plot title.
    """
    x_epochs = np.arange(1, len(epoch_hist["train"]["loss"]) + 1)

    for metric in ("loss", "f1", "acc", "precision", "recall"):
        pretty = metric.upper()
        fig, ax = plt.subplots(figsize=(10, 4))
        for split_key, split_label in (("train", "Train"), ("val", "Val")):
            ax.plot(x_epochs, epoch_hist[split_key][metric], label=f"{split_label} {pretty}")
        ax.set_title(f"Epoch-level {pretty} - {dataset_name}")
        ax.set_xlabel("Epoch")
        ax.set_ylabel(pretty)
        ax.grid(alpha=0.3)
        ax.legend()
        fig.tight_layout()
        fig.savefig(os.path.join(ckpt_dir, f"epoch_{metric}.png"))
        # Close each figure inside the loop so repeated runs do not accumulate open figures.
        plt.close(fig)


# ---------------------------
# Confusion matrix plot (end only)
# ---------------------------
def _save_confusion_matrix(ts_m, ckpt_dir, test_th=0.5):
    """Save the test confusion matrix as ``<ckpt_dir>/test_confusion_matrix.png``.

    Rows are actual classes and columns are predicted classes, both ordered
    (0 = Unique, 1 = Duplicate).

    Args:
        ts_m: metrics dict from ``_run_epoch`` with "sims" and integer "labs" arrays.
        ckpt_dir: output folder.
        test_th: similarity threshold; pairs with sim >= test_th are predicted duplicate (1).
    """
    preds = (ts_m["sims"] >= test_th).astype(int)
    # Pin the label set so the matrix is always 2x2 and matches the fixed tick
    # labels, even when the labels and predictions contain only one class.
    cm = confusion_matrix(ts_m["labs"], preds, labels=[0, 1])

    plt.figure(figsize=(6, 5))
    plt.imshow(cm)
    plt.title(f"Test Confusion Matrix (th={test_th:.2f})")
    plt.xlabel("Predicted")
    plt.ylabel("Actual")
    plt.xticks([0, 1], ["Unique", "Duplicate"])
    plt.yticks([0, 1], ["Unique", "Duplicate"])

    # Annotate each cell with its count (row i = actual, column j = predicted).
    for (i, j), v in np.ndenumerate(cm):
        plt.text(j, i, str(v), ha="center", va="center")

    plt.tight_layout()
    plt.savefig(os.path.join(ckpt_dir, "test_confusion_matrix.png"))
    plt.close()


# ---------------------------
# Main training pipeline
# ---------------------------
def train_bertgnn(csv_path, dataset_dir, batch_size=16, epochs=5, lr=2e-5, seed=None):
    """Train SBERT + GCN on tokenized issue pairs, then evaluate the best checkpoint on a test split.

    The data is split 60/20/20 into train/val/test. After each epoch, the model is
    checkpointed if its validation F1 improves. Val F1 is measured at the
    currently locked threshold, and the locked threshold is then replaced by the
    best-F1 threshold on that epoch's validation similarities.

    Args:
        csv_path: tokenized pair CSV read by ``TokenizedDataset``.
        dataset_dir: folder with ``graph_adj.npz`` and ``graph_adj_node_ids.npy``.
            Its basename names the artifact sub-folder.
        batch_size: DataLoader batch size for all splits.
        epochs: number of training epochs.
        lr: AdamW learning rate (model and node embeddings).
        seed: optional seed for the random 60/20/20 split. ``None`` (the default)
            keeps the unseeded behaviour.

    Artifacts (``best_model.pth``, epoch plots, test confusion matrix) are written
    to ``checkpoints/<dataset_name>/run_<timestamp>/``.
    """
    device = _device()
    dataset_name = os.path.basename(os.path.normpath(dataset_dir))
    run_ts = datetime.now().strftime("%Y%m%d_%H%M%S")
    ckpt_dir = os.path.join("checkpoints", dataset_name, f"run_{run_ts}")
    os.makedirs(ckpt_dir, exist_ok=True)
    print(f"Dataset: {dataset_name} | Artifacts: {ckpt_dir}")

    (edge_index, edge_weight), node_ids = _load_graph(dataset_dir, device)
    id_to_idx = {int(nid): i for i, nid in enumerate(node_ids)}

    dataset = TokenizedDataset(csv_path)
    n = len(dataset)

    n_train = int(0.6 * n)
    n_val = int(0.2 * n)
    n_test = n - n_train - n_val

    split_kwargs = {} if seed is None else {"generator": torch.Generator().manual_seed(seed)}
    train_ds, val_ds, test_ds = random_split(dataset, [n_train, n_val, n_test], **split_kwargs)

    loaders = {
        "train": DataLoader(train_ds, batch_size=batch_size, shuffle=True),
        "val": DataLoader(val_ds, batch_size=batch_size, shuffle=False),
        "test": DataLoader(test_ds, batch_size=batch_size, shuffle=False),
    }

    node_emb = nn.Embedding(len(node_ids), 128).to(device)
    model = SBERTGCN(SiameseBERT().to(device), len(node_ids), 128, 128, l=0.5).to(device)

    optimizer = AdamW(list(model.parameters()) + list(node_emb.parameters()), lr=lr)
    criterion = nn.CosineEmbeddingLoss(margin=0.5)

    best_ckpt_path = os.path.join(ckpt_dir, "best_model.pth")
    # Start below any achievable F1 (F1 >= 0) so the first epoch always writes a
    # checkpoint. With 0.0, a run whose val F1 stayed at 0 never saved one and the
    # final torch.load crashed.
    best_val_f1 = -1.0
    locked_th = 0.5

    epoch_hist = {
        "train": {k: [] for k in ["loss", "f1", "acc", "precision", "recall"]},
        "val": {k: [] for k in ["loss", "f1", "acc", "precision", "recall"]},
    }

    for epoch in range(epochs):
        tr_m = _run_epoch(
            loaders["train"],
            node_emb,
            model,
            edge_index,
            edge_weight,
            id_to_idx,
            device,
            criterion,
            optimizer=optimizer,
            thresh=locked_th,
            desc=f"Ep {epoch+1} Train",
        )

        val_m = _run_epoch(
            loaders["val"],
            node_emb,
            model,
            edge_index,
            edge_weight,
            id_to_idx,
            device,
            criterion,
            optimizer=None,
            thresh=locked_th,
            desc=f"Ep {epoch+1} Val",
        )

        for k in ["loss", "f1", "acc", "precision", "recall"]:
            epoch_hist["train"][k].append(tr_m[k])
            epoch_hist["val"][k].append(val_m[k])

        epoch_best_th = _get_best_threshold(val_m["sims"], val_m["labs"])

        print(
            f"Epoch {epoch+1}/{epochs} | "
            f"Val F1: {val_m['f1']:.4f} | Val Prec: {val_m['precision']:.4f} | Val Rec: {val_m['recall']:.4f} | "
            f"Next Th: {epoch_best_th:.2f}"
        )

        # NOTE(review): val_m["f1"] is measured at the previous locked threshold, not
        # at epoch_best_th, so selection can lag the threshold that gets stored.
        if val_m["f1"] > best_val_f1:
            best_val_f1 = val_m["f1"]
            locked_th = epoch_best_th
            torch.save(
                {"model": model.state_dict(), "emb": node_emb.state_dict(), "thresh": locked_th},
                best_ckpt_path,
            )

    # Only reachable with epochs == 0: save the untrained model so the test step can still run.
    if not os.path.exists(best_ckpt_path):
        torch.save(
            {"model": model.state_dict(), "emb": node_emb.state_dict(), "thresh": locked_th},
            best_ckpt_path,
        )

    # -----------------------
    # Final test (best ckpt)
    # -----------------------
    print("\n" + "=" * 20 + " FINAL TEST " + "=" * 20)
    # weights_only=False is acceptable here: the file was written by this function above.
    best_ckpt = torch.load(best_ckpt_path, weights_only=False)
    model.load_state_dict(best_ckpt["model"])
    node_emb.load_state_dict(best_ckpt["emb"])
    test_th = float(best_ckpt["thresh"])

    ts_m = _run_epoch(
        loaders["test"],
        node_emb,
        model,
        edge_index,
        edge_weight,
        id_to_idx,
        device,
        criterion,
        optimizer=None,
        thresh=test_th,
        desc="Test Set",
    )

    _save_epoch_plots(epoch_hist, ckpt_dir, dataset_name=dataset_name)
    _save_confusion_matrix(ts_m, ckpt_dir, test_th=test_th)

    print(
        f"Final Test | F1: {ts_m['f1']:.4f} | Prec: {ts_m['precision']:.4f} | "
        f"Rec: {ts_m['recall']:.4f} | Th: {test_th:.2f}"
    )
    print(f"Saved plots in: {ckpt_dir}")


# ---------------------------
# Entry
# ---------------------------
if __name__ == "__main__":
    # Inside a Jupyter kernel __name__ == "__main__", so running this cell starts training.
    D_DIR = "datasets/eclipse"
    CSV = os.path.join(D_DIR, "tokenized_pairs_train_bert-base-uncased_50000.csv")
    train_bertgnn(
        CSV,
        D_DIR,
        batch_size=64,
        epochs=10,
    )


In [None]:
# Shell check: print the kernel's current working directory (the first cell sets it with os.chdir(BASE_DIR)).
!pwd