In [1]:
# %%
import os
import sys
import torch
from lightweight_graph.dataset import LightweightGraphDataset
from tqdm import tqdm
from typing import Dict, Generator, Tuple


# Set PYTHONPATH to project root for the import
project_root = os.path.abspath(os.path.join(os.getcwd(), os.pardir)) if "scripts" in os.getcwd() or "notebooks" in os.getcwd() else os.getcwd()
if project_root not in sys.path:
    sys.path.insert(0, project_root)

# --- Configuration ---
SAVE_DIR = "lightweight_graph/data"
# Set the primary device for data loading (e.g., 'cuda:1' or 'cuda:2')
DEVICE = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")

# --- Load and move data to the specified device ---
dataset = LightweightGraphDataset.load_or_create(save_dir=SAVE_DIR)
dataset.to(DEVICE)

print(f"Dataset loaded to {DEVICE}. Training contexts: {dataset.train_mask.sum().item()}")

[32m2025-09-18 18:28:11.274[0m | [1mINFO    [0m | [36mlightweight_graph.dataset[0m:[36mload_or_create[0m:[36m100[0m - [1mLoading instance-based graph dataset from lightweight_graph/data_instances...[0m


Dataset loaded to cuda:2. Training contexts: 250814


In [2]:
def batchify_contexts(
    dataset: LightweightGraphDataset,
    split_indices: torch.Tensor,
    batch_size: int
) -> Generator[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], None, None]:
    """
    Generator that yields mini-batches for GNN training or evaluation.
    Each batch contains a subgraph with all premises and a slice of contexts.
    """
    n_premises = dataset.premise_embeddings.size(0)
    
    # Pre-filter edges and labels for the entire split for efficiency
    split_edge_mask = torch.isin(dataset.context_edge_index[1], split_indices)
    split_context_edge_index = dataset.context_edge_index[:, split_edge_mask]
    split_context_edge_attr = dataset.context_edge_attr[split_edge_mask]
    
    split_label_mask = torch.isin(dataset.context_premise_labels[0], split_indices)
    split_context_premise_labels = dataset.context_premise_labels[:, split_label_mask]

    for start in range(0, len(split_indices), batch_size):
        end = min(start + batch_size, len(split_indices))
        batch_global_indices = split_indices[start:end]
        
        batch_global_to_local_map = torch.full(
            (batch_global_indices.max() + 1,), -1, dtype=torch.long, device=split_indices.device
        )
        batch_global_to_local_map[batch_global_indices] = torch.arange(
            len(batch_global_indices), device=split_indices.device
        )
        
        batch_context_embeddings = dataset.context_embeddings[batch_global_indices]
        batch_context_file_indices = dataset.context_to_file_idx_map[batch_global_indices]
        
        batch_edge_mask = torch.isin(split_context_edge_index[1], batch_global_indices)
        batch_context_edge_index_global = split_context_edge_index[:, batch_edge_mask]
        batch_context_edge_attr = split_context_edge_attr[batch_edge_mask]
        
        batch_label_mask = torch.isin(split_context_premise_labels[0], batch_global_indices)
        batch_labels_global = split_context_premise_labels[:, batch_label_mask]

        # Use the new batch-specific map for shifting indices
        batch_context_edge_index = batch_context_edge_index_global.clone()
        batch_context_edge_index[1] = batch_global_to_local_map[batch_context_edge_index[1]] + n_premises
        
        batch_labels = batch_labels_global.clone()
        batch_labels[0] = batch_global_to_local_map[batch_labels[0]]

        all_batch_embeddings = torch.cat([dataset.premise_embeddings, batch_context_embeddings], dim=0)
        all_batch_edge_index = torch.cat([dataset.premise_edge_index, batch_context_edge_index], dim=1)
        all_batch_edge_attr = torch.cat([dataset.premise_edge_attr, batch_context_edge_attr], dim=0)

        yield all_batch_embeddings, all_batch_edge_index, all_batch_edge_attr, batch_labels, batch_context_file_indices


class Model:
    """Abstract base class for a GNN-based retrieval model."""
    
    def train_batch(self, batch_embeddings: torch.Tensor, batch_edge_index: torch.Tensor, batch_edge_attr: torch.Tensor, batch_labels: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError("Subclasses must implement the training step.")

    def train_epoch(self, dataset: LightweightGraphDataset, batch_size: int) -> float:
        train_indices = dataset.train_mask.nonzero(as_tuple=True)[0]
        train_generator = batchify_contexts(dataset, train_indices, batch_size)
        total_loss, num_batches = 0.0, 0
        pbar = tqdm(train_generator, desc="Training Epoch")
        for batch_embeddings, batch_edge_index, batch_edge_attr, batch_labels, _ in pbar:
            loss = self.train_batch(batch_embeddings, batch_edge_index, batch_edge_attr, batch_labels)
            total_loss += loss
            num_batches += 1
            pbar.set_postfix({"loss": f"{loss:.4f}"})
        return total_loss / num_batches if num_batches > 0 else 0.0

    def get_predictions(self, batch_embeddings: torch.Tensor, batch_edge_index: torch.Tensor, batch_edge_attr: torch.Tensor, num_batch_contexts: int, n_premises: int) -> torch.Tensor:
        raise NotImplementedError("Subclasses must implement the prediction logic.")

    @torch.no_grad()
    def eval_batch(self, batch_embeddings: torch.Tensor, batch_edge_index: torch.Tensor, batch_edge_attr: torch.Tensor, batch_labels: torch.Tensor, batch_context_file_indices: torch.Tensor, dataset: LightweightGraphDataset) -> Dict[str, float]:
        n_premises = dataset.premise_embeddings.shape[0]
        num_batch_contexts = batch_embeddings.shape[0] - n_premises
        scores = self.get_predictions(batch_embeddings, batch_edge_index, batch_edge_attr, num_batch_contexts, n_premises)

        # --- Create Accessibility Mask (This logic is now correct) ---
        accessible_mask = torch.zeros_like(scores, dtype=torch.bool)
        for i in range(num_batch_contexts):
            context_file_idx = batch_context_file_indices[i].item()
            
            # 1. Premises in the same file are accessible
            in_file_mask = (dataset.premise_to_file_idx_map == context_file_idx)
            
            # 2. Premises in imported files (transitive) are accessible
            # This is a single lookup because file_dependency_edge_index IS the transitive closure.
            dependency_file_indices = dataset.file_dependency_edge_index[1, dataset.file_dependency_edge_index[0] == context_file_idx]
            imported_mask = torch.isin(dataset.premise_to_file_idx_map, dependency_file_indices)
            
            accessible_mask[i] = in_file_mask | imported_mask
        
        scores.masked_fill_(~accessible_mask, -torch.inf)
        
        # --- Metric Calculation ---
        gt_mask = torch.zeros_like(scores, dtype=torch.bool)
        gt_mask[batch_labels[0], batch_labels[1]] = True
        num_positives = gt_mask.sum(dim=1)
        valid_contexts = num_positives > 0
        if not valid_contexts.any(): return {'R@1': 0.0, 'R@10': 0.0, 'MRR': 0.0}

        top_10_indices = scores.topk(k=10, dim=1).indices
        top_10_hits = gt_mask.gather(1, top_10_indices)

        recall_at_1 = (top_10_hits[:, 0][valid_contexts] / num_positives[valid_contexts]).mean().item()
        recall_at_10 = (top_10_hits.sum(dim=1)[valid_contexts] / num_positives[valid_contexts]).mean().item()
        
        sorted_indices = scores.argsort(dim=1, descending=True)
        sorted_gt = gt_mask.gather(1, sorted_indices)
        first_hit_rank = torch.argmax(sorted_gt[valid_contexts].int(), dim=1) + 1
        mrr = (1.0 / first_hit_rank).mean().item()
        
        return {'R@1': recall_at_1, 'R@10': recall_at_10, 'MRR': mrr}

    @torch.no_grad()
    def eval(self, dataset: LightweightGraphDataset, split: str, batch_size: int) -> Dict[str, float]:
        mask = getattr(dataset, f"{split}_mask", None)
        if mask is None: raise ValueError(f"Invalid split: {split}")
        
        split_indices = mask.nonzero(as_tuple=True)[0]
        eval_generator = batchify_contexts(dataset, split_indices, batch_size)
        
        all_metrics = []
        pbar = tqdm(eval_generator, desc=f"Evaluating on {split} split")
        for batch_embeddings, batch_edge_index, batch_edge_attr, batch_labels, batch_context_file_indices in pbar:
            metrics = self.eval_batch(batch_embeddings, batch_edge_index, batch_edge_attr, batch_labels, batch_context_file_indices, dataset)
            all_metrics.append(metrics)
            pbar.set_postfix(metrics)

        if not all_metrics: return {'R@1': 0.0, 'R@10': 0.0, 'MRR': 0.0}
        
        final_metrics = {key: torch.tensor([m[key] for m in all_metrics]).mean().item() for key in all_metrics[0]}
        
        print(f"\n--- Evaluation Results for '{split}' split ---")
        print(f"  Recall@1:  {final_metrics['R@1']:.4f}")
        print(f"  Recall@10: {final_metrics['R@10']:.4f}")
        print(f"  MRR:       {final_metrics['MRR']:.4f}")
        print("------------------------------------------")

        return final_metrics

In [3]:
import torch.nn.functional as F

class BaselineModel(Model):
    def __init__(self):
        super().__init__()
        print("Initialized BaselineModel (no GNN, no training).")

    def train_batch(self, batch_embeddings: torch.Tensor, batch_edge_index: torch.Tensor, batch_edge_attr: torch.Tensor, batch_labels: torch.Tensor) -> float:
        return 0.0

    def get_predictions(self, batch_embeddings: torch.Tensor, batch_edge_index: torch.Tensor, batch_edge_attr: torch.Tensor, num_batch_contexts: int, n_premises: int) -> torch.Tensor:
        # The input batch_embeddings are the initial LM embeddings.
        # We simply ignore the edge_index and edge_attr.
        initial_premise_embs = batch_embeddings[:n_premises]
        initial_context_embs = batch_embeddings[n_premises:]

        # L2-normalize for cosine similarity calculation.
        premise_embs_norm = F.normalize(initial_premise_embs, p=2, dim=1)
        context_embs_norm = F.normalize(initial_context_embs, p=2, dim=1)

        # Compute similarity scores via matrix multiplication.
        scores = torch.mm(context_embs_norm, premise_embs_norm.T)
        
        return scores

In [4]:
class RandomBaselineModel(Model):
    def __init__(self):
        super().__init__()
        print("Initialized RandomBaselineModel (random scores, no training).")

    def train_batch(self, batch_embeddings: torch.Tensor, batch_edge_index: torch.Tensor, batch_edge_attr: torch.Tensor, batch_labels: torch.Tensor) -> float:
        return 0.0

    def get_predictions(self, batch_embeddings: torch.Tensor, batch_edge_index: torch.Tensor, batch_edge_attr: torch.Tensor, num_batch_contexts: int, n_premises: int) -> torch.Tensor:
        # Generate random scores for each context-premise pair.
        scores = torch.rand((num_batch_contexts, n_premises), device=batch_embeddings.device)
        return scores

In [5]:
#baseline_model = BaselineModel()
#baseline_val_metrics = baseline_model.eval(dataset, split="train", batch_size=5120)
#baseline_val_metrics = baseline_model.eval(dataset, split="val", batch_size=256)
#baseline_test_metrics = baseline_model.eval(dataset, split="test", batch_size=256)

#random_model = RandomBaselineModel()
#random_val_metrics = random_model.eval(dataset, split="train", batch_size=5120)
#random_val_metrics = random_model.eval(dataset, split="val", batch_size=256)
#random_test_metrics = random_model.eval(dataset, split="test", batch_size=256)

In [None]:
# %%
from numpy import negative
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import RGCNConv
from torch_geometric.utils import dropout_edge
from typing import Tuple, Literal

# from your_abstract_model_file import Model
# from lightweight_graph.dataset import LightweightGraphDataset

def calculate_metrics(scores: torch.Tensor, gt_mask: torch.Tensor) -> float:
    num_positives = gt_mask.sum(dim=1)
    valid_contexts = num_positives > 0
    if not valid_contexts.any():
        return {"R@1": 0.0, "R@10": 0.0, "MRR": 0.0}

    # calculate R@1, R@10, MRR
    top_10_indices = scores.topk(k=10, dim=1).indices
    top_10_hits = gt_mask.gather(1, top_10_indices)

    tr1 = (torch.ones_like(top_10_hits)[:, 0][valid_contexts] / num_positives[valid_contexts]).mean().item()
    tr10 = (torch.ones_like(top_10_hits).sum(dim=1)[valid_contexts] / num_positives[valid_contexts]).mean().item()

    recall_at_1 = (top_10_hits[:, 0][valid_contexts] / num_positives[valid_contexts]).mean().item()
    recall_at_10 = (top_10_hits.sum(dim=1)[valid_contexts] / num_positives[valid_contexts]).mean().item()

    # compute reciprocal rank
    #ranks = torch.arange(1, 11, device=scores.device).float()  # [1,2,...,10]
    #reciprocal_ranks = (top_10_hits * (1.0 / ranks)).max(dim=1).values
    #mrr = reciprocal_ranks[valid_contexts].mean().item()

    return {"R@1": recall_at_1, "R@10": recall_at_10, "R@1 upper bound" : tr1, "R@10 upper bound" : tr10}#, "MRR": mrr}

class HeadAttentionScoring(nn.Module):
    def __init__(self, embedding_dim: int, num_heads: int, aggregation: Literal["logsumexp", "mean", "max", "gated", "nn"]):
        super(HeadAttentionScoring, self).__init__()
        self.num_heads = num_heads
        self.embedding_dim = embedding_dim

        self.score_W = nn.Linear(embedding_dim, embedding_dim, bias=False)
        self.aggregation = aggregation

        if aggregation == "gated":
            self.gate_W = nn.Linear(embedding_dim, num_heads)
        if aggregation == "nn":
            self.nn = nn.Sequential(
                nn.Linear(2 * embedding_dim, embedding_dim),
                nn.ReLU(),
                nn.Linear(embedding_dim, 1)
            )

    def forward(self, premise_embs: torch.Tensor, context_embs: torch.Tensor) -> torch.Tensor:
        batch_size = context_embs.size(0)
        n_premises = premise_embs.size(0)

        if self.aggregation == "nn":
            # Expand context embeddings to match premise embeddings
            context_expanded = context_embs.unsqueeze(1).expand(-1, n_premises, -1)  # (batch_size, n_premises, embedding_dim)
            premise_expanded = premise_embs.unsqueeze(0).expand(batch_size, -1, -1)  # (batch_size, n_premises, embedding_dim)
            combined = torch.cat([context_expanded, premise_expanded], dim=-1)  # (batch_size, n_premises, 2 * embedding_dim)
            scores = self.nn(combined).squeeze(-1)  # (batch_size, n_premises)
            return scores  # (batch_size, n_premises)

        context_embs = self.score_W(context_embs)  # (batch_size, embedding_dim)
        # reshape for multi-head
        context_embs = context_embs.view(batch_size, self.num_heads, self.embedding_dim // self.num_heads)  # (batch_size, num_heads, head_dim)
        premise_embs = premise_embs.view(n_premises, self.num_heads, self.embedding_dim // self.num_heads)  # (n_premises, num_heads, head_dim)
        # compute attention scores to get (batch_size, n_premises, num_heads)
        scores = torch.einsum('bhd, phd -> bhp', context_embs, premise_embs)  # (batch_size, num_heads, n_premises)
        scores = scores.permute(0, 2, 1)  # (batch_size, n_premises, num_heads)

        if self.aggregation == "max":
            scores, _ = scores.max(dim=-1)  # (batch_size, n_premises)
        elif self.aggregation == "mean":
            scores = scores.mean(dim=-1)  # (batch_size, n_premises)
        elif self.aggregation == "logsumexp":
            scores = torch.logsumexp(scores, dim=-1)  # (batch_size, n_premises)
        elif self.aggregation == "gated":
            score_gates = self.gate_W(context_embs.mean(dim=1))  # (batch_size, num_heads)
            # apply softmax to get weights
            score_gates = F.softmax(score_gates, dim=-1)  # (batch_size, num_heads)
            scores = (scores * score_gates.unsqueeze(1)).sum(dim=-1)  # (batch_size, n_premises)
        else:
            raise ValueError(f"Unknown aggregation method: {self.aggregation}")

        return scores  # (batch_size, n_premises)

class TestModel(Model, nn.Module):
    def __init__(
        self,
        dataset: LightweightGraphDataset,
        hidden_dim: int,
        aggregation: Literal["mean", "max", "logsumexp", "gated", "nn"],
        n_heads: int,
        lr: float = 1e-4,
        loss : Literal["bce", "mse"] = "mse",
    ):
        Model.__init__(self)
        nn.Module.__init__(self)
        
        self.embedding_dim = dataset.premise_embeddings.shape[1]
        self.hidden_dim = hidden_dim
        self.num_relations = len(dataset.edge_types_map)
        
        self.random_premise_embeds = nn.Embedding(dataset.premise_embeddings.shape[0], self.hidden_dim)
        self.random_premise_embed_for_context = nn.Embedding(dataset.premise_embeddings.shape[0], self.embedding_dim)

        self.optimizer = torch.optim.AdamW(self.parameters(), lr=lr)
        # Keep track of metrics during hard mining
        self.last_hard_mining_recall = 0.0
        self.loss = loss

        self.rgcn = RGCNConv(
            in_channels=self.embedding_dim,
            out_channels=self.hidden_dim,
            num_relations=2,
        )

        self.scoring = HeadAttentionScoring(embedding_dim=self.hidden_dim, num_heads=n_heads, aggregation=aggregation)
        
        print(f"Initialized RGCNModel with {self.num_relations} relations, hidden_dim={self.hidden_dim}")

    def forward(self, batch_embeddings: torch.Tensor, batch_edge_index: torch.Tensor, batch_edge_attr: torch.Tensor, n_premises: int) -> Tuple[torch.Tensor, torch.Tensor]:
        expected_dtype = torch.float32
        initial_premise_embs = self.random_premise_embeds.weight.to(expected_dtype)
        
        initial_context_embs = torch.zeros_like(batch_embeddings[n_premises:]).to(expected_dtype)
        initial_premise_emb_for_context = self.random_premise_embed_for_context.weight.to(expected_dtype)
        batch_embeddings_for_context = torch.cat([initial_premise_emb_for_context, initial_context_embs], dim=0)

        print(f"memory before rgcn: {torch.cuda.memory_allocated(DEVICE)/1e9 if torch.cuda.is_available() else 0.0} GB")
        refined_context_embs = self.rgcn(batch_embeddings_for_context, batch_edge_index, batch_edge_attr)
        print(f"memory after rgcn: {torch.cuda.memory_allocated(DEVICE)/1e9 if torch.cuda.is_available() else 0.0} GB")
        
        # refine the premise embeddings first
        return initial_premise_embs, refined_context_embs[n_premises:]

    def get_predictions(self, batch_embeddings: torch.Tensor, batch_edge_index: torch.Tensor, batch_edge_attr: torch.Tensor, num_batch_contexts: int, n_premises: int, squash01 : bool = True) -> torch.Tensor:
        final_premise_embs, final_context_embs = self.forward(batch_embeddings, batch_edge_index, batch_edge_attr, n_premises)
        return self.scoring.forward(final_premise_embs, final_context_embs)

    def train_batch(self, batch_embeddings: torch.Tensor, batch_edge_index: torch.Tensor, batch_edge_attr: torch.Tensor, batch_labels: torch.Tensor, i) -> torch.Tensor:
        self.train()
        n_premises = self.premise_embeddings_shape[0]
        num_batch_contexts = batch_embeddings.shape[0] - n_premises
        logits_tensor = self.get_predictions(batch_embeddings, batch_edge_index, batch_edge_attr, num_batch_contexts, n_premises, squash01 = False)
        
        targets_tensor = torch.zeros_like(logits_tensor)
        pos_context_indices = batch_labels[0]
        pos_premise_indices = batch_labels[1]
        targets_tensor[pos_context_indices, pos_premise_indices] = 1.0

        # report R@10 during training for monitoring
        self.last_metrics = calculate_metrics(logits_tensor, targets_tensor)
        
        # --- Weighted Loss Calculation (works for all strategies) ---
        n_negative = (targets_tensor == 0).sum().item()
        n_positive = (targets_tensor == 1).sum().item()

        if n_positive == 0 or n_negative == 0:
            # Handle edge case where we have only one class
            weights = torch.ones_like(logits_tensor)
        else:
            # Calculate class weights to balance the loss
            pos_weight = n_negative / n_positive  # Higher weight for minority class
            weights = torch.ones_like(logits_tensor)
            weights[targets_tensor == 1] = pos_weight

        if self.loss == "bce":
            unweighted_loss = F.binary_cross_entropy_with_logits(logits_tensor, targets_tensor, reduction='none')
        else:
            raise ValueError(f"Unknown loss function: {self.loss}")
        weighted_loss = (unweighted_loss * weights).mean()

        return weighted_loss

    def train_epoch(self, dataset: LightweightGraphDataset, batch_size: int, accumulation_steps: int = 1) -> float:
        self.premise_embeddings_shape = dataset.premise_embeddings.shape
        train_indices = dataset.train_mask.nonzero(as_tuple=True)[0]
        train_generator = batchify_contexts(dataset, train_indices, batch_size)
        
        total_loss, num_batches_processed = 0.0, 0
        self.optimizer.zero_grad()
        pbar = tqdm(enumerate(train_generator), total=len(train_indices)//batch_size, desc="Training Epoch")
        for i, (batch_embeddings, batch_edge_index, batch_edge_attr, batch_labels, _) in pbar:
            if (i !=0):
                assert 0
            for j in range(10000):
                loss = self.train_batch(batch_embeddings, batch_edge_index, batch_edge_attr, batch_labels, j)
                loss = loss / accumulation_steps
                memory = torch.cuda.memory_allocated(DEVICE)/1e9 if torch.cuda.is_available() else 0.0
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=1.0)
                
                self.optimizer.step()
                self.optimizer.zero_grad()
            
                log_dict = {"loss": f"{loss.item() * accumulation_steps:.4f}", "memory (GB)": f"{memory}"}
                log_dict.update(self.last_metrics)
                pbar.set_postfix(log_dict)

                total_loss += loss.item() * accumulation_steps
                num_batches_processed += 1
                
        if num_batches_processed % accumulation_steps != 0:
            self.optimizer.step()
            self.optimizer.zero_grad()
            
        return total_loss / num_batches_processed if num_batches_processed > 0 else 0.0

In [7]:
HIDDEN_DIM = 512
LEARNING_RATE = 1e-2
# Note: For "all" strategy, a smaller batch size is needed due to the large logit matrix
BATCH_SIZE = 16
ACCUMULATION_STEPS = 1 # Effective batch size = 256 * 16 = 4096
EPOCHS = 100


# --- Instantiate the Model ---
testmodel = TestModel(
    dataset=dataset, 
    hidden_dim=HIDDEN_DIM, 
    lr=LEARNING_RATE,
    aggregation="nn",
    n_heads=8,
    loss="bce",
)
testmodel.to(DEVICE)

# --- Training Loop ---
for epoch in range(EPOCHS):
    print(f"\n--- [\"All\" Negatives] Epoch {epoch+1}/{EPOCHS} ---")
    avg_loss = testmodel.train_epoch(dataset, batch_size=BATCH_SIZE, accumulation_steps=ACCUMULATION_STEPS)
    print(f"End of Epoch {epoch+1}, Average Training Loss: {avg_loss:.4f}")
    
    # Evaluate on validation set after each epoch
    testmodel.eval(dataset, split="val", batch_size=BATCH_SIZE)

Initialized RGCNModel with 2 relations, hidden_dim=512

--- ["All" Negatives] Epoch 1/100 ---


Training Epoch:   0%|          | 0/15675 [00:01<?, ?it/s, loss=1.3792, memory (GB)=24.741586944, R@1=0, R@10=0, R@1 upper bound=0.556, R@10 upper bound=5.56]


OutOfMemoryError: CUDA out of memory. Tried to allocate 5.52 GiB. GPU 2 has a total capacity of 47.41 GiB of which 4.97 GiB is free. Including non-PyTorch memory, this process has 42.42 GiB memory in use. Of the allocated memory 31.22 GiB is allocated by PyTorch, and 10.88 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [None]:
# %% Cell to Analyze Label Duplication in a Batch
import torch
from collections import Counter

def analyze_batch_label_duplication(
    dataset: LightweightGraphDataset,
    batch_global_indices: torch.Tensor
):
    print("--- Analyzing Batch Label Duplication ---")

    split_label_mask = torch.isin(dataset.context_premise_labels[0], batch_global_indices)
    batch_labels_global = dataset.context_premise_labels[:, split_label_mask]

    if batch_labels_global.shape[1] == 0:
        print("No positive labels found in this batch.")
        return

    positive_premise_indices = batch_labels_global[1]
    unique_premises, counts = torch.unique(positive_premise_indices, return_counts=True)
    
    total_labels = len(positive_premise_indices)
    num_unique_premises = len(unique_premises)
    
    print(f"Total positive labels in batch: {total_labels}")
    print(f"Unique premises to be retrieved: {num_unique_premises}")

    duplicated_mask = counts > 1
    num_duplicated_premises = duplicated_mask.sum().item()
    
    if num_duplicated_premises == 0:
        print("No premise is a correct label for more than one context in this batch.")
        return

    duplicated_premise_ids = unique_premises[duplicated_mask]
    duplicated_premise_counts = counts[duplicated_mask]
    
    total_labels_involved_in_duplication = duplicated_premise_counts.sum().item()

    num_redundant_labels = (duplicated_premise_counts - 1).sum().item()

    avg_duplication_factor = duplicated_premise_counts.float().mean().item()
    
    print(f"\n--- Duplication Stats ---")
    print(f"Number of unique premises that are duplicated: {num_duplicated_premises} (out of {num_unique_premises})")
    print(f"Percentage of unique premises that are duplicated: {num_duplicated_premises / num_unique_premises:.2%}")
    print(f"Total labels pointing to duplicated premises: {total_labels_involved_in_duplication} (out of {total_labels})")
    print(f"Percentage of labels that are for duplicated premises: {total_labels_involved_in_duplication / total_labels:.2%}")
    print(f"\nThis means {num_redundant_labels} times, the model must reuse a single premise embedding for different contexts.")
    print(f"On average, a duplicated premise is required {avg_duplication_factor:.2f} times within this batch.")
    print("----------------------------------------")


# --- How to use it ---
# Get the first batch of training indices to test
train_indices = dataset.train_mask.nonzero(as_tuple=True)[0]
first_batch_indices = train_indices[:BATCH_SIZE]

# Run the analysis
analyze_batch_label_duplication(dataset, first_batch_indices)

--- Analyzing Batch Label Duplication ---
Total positive labels in batch: 1745
Unique premises to be retrieved: 913

--- Duplication Stats ---
Number of unique premises that are duplicated: 353 (out of 913)
Percentage of unique premises that are duplicated: 38.66%
Total labels pointing to duplicated premises: 1185 (out of 1745)
Percentage of labels that are for duplicated premises: 67.91%

This means 832 times, the model must reuse a single premise embedding for different contexts.
On average, a duplicated premise is required 3.36 times within this batch.
----------------------------------------


In [None]:
# %% Cell to Analyze Label Duplication in a Batch
import torch
from collections import Counter

def analyze_batch_label_duplication(
    dataset: LightweightGraphDataset,
    batch_global_indices: torch.Tensor
):
    """
    Analyzes the duplication of ground-truth premises within a single batch.

    This helps quantify the "representation bottleneck" problem, where the model
    must learn a single embedding for a premise that needs to be retrieved by
    multiple, different contexts in the same batch.
    """
    print("--- Analyzing Batch Label Duplication ---")

    # 1. Find all labels that are relevant to this specific batch of contexts.
    # This logic is borrowed from the batchify_contexts generator.
    split_label_mask = torch.isin(dataset.context_premise_labels[0], batch_global_indices)
    batch_labels_global = dataset.context_premise_labels[:, split_label_mask]

    if batch_labels_global.shape[1] == 0:
        print("No positive labels found in this batch.")
        return

    # 2. Isolate the premise indices from the labels. These are the items being retrieved.
    positive_premise_indices = batch_labels_global[1]

    # 3. Count the occurrences of each unique premise index.
    unique_premises, counts = torch.unique(positive_premise_indices, return_counts=True)
    
    total_labels = len(positive_premise_indices)
    num_unique_premises = len(unique_premises)
    
    print(f"Total positive labels in batch: {total_labels}")
    print(f"Unique premises to be retrieved: {num_unique_premises}")

    # 4. Identify the duplicated premises and quantify the duplication.
    duplicated_mask = counts > 1
    num_duplicated_premises = duplicated_mask.sum().item()
    
    if num_duplicated_premises == 0:
        print("No premise is a correct label for more than one context in this batch.")
        return

    # 5. Calculate statistics on the duplicates.
    duplicated_premise_ids = unique_premises[duplicated_mask]
    duplicated_premise_counts = counts[duplicated_mask]
    
    # This is the total number of labels that point to a premise that is needed more than once.
    # For example, if premise P is needed 3 times, it contributes 3 to this sum.
    total_labels_involved_in_duplication = duplicated_premise_counts.sum().item()

    # This is the number of "extra" pulls on the same embedding.
    # If premise P is needed 3 times, it has 2 "extra" pulls.
    num_redundant_labels = (duplicated_premise_counts - 1).sum().item()

    avg_duplication_factor = duplicated_premise_counts.float().mean().item()
    
    print(f"\n--- Duplication Stats ---")
    print(f"Number of unique premises that are duplicated: {num_duplicated_premises} (out of {num_unique_premises})")
    print(f"Percentage of unique premises that are duplicated: {num_duplicated_premises / num_unique_premises:.2%}")
    print(f"Total labels pointing to duplicated premises: {total_labels_involved_in_duplication} (out of {total_labels})")
    print(f"Percentage of labels that are for duplicated premises: {total_labels_involved_in_duplication / total_labels:.2%}")
    print(f"\nThis means {num_redundant_labels} times, the model must reuse a single premise embedding for different contexts.")
    print(f"On average, a duplicated premise is required {avg_duplication_factor:.2f} times within this batch.")
    print("----------------------------------------")


# --- How to use it ---
# Get the first batch of training indices to test
BATCH_SIZE = 1024 # Use the same batch size as your training
train_indices = dataset.train_mask.nonzero(as_tuple=True)[0]
first_batch_indices = train_indices[:BATCH_SIZE]

# Run the analysis
analyze_batch_label_duplication(dataset, first_batch_indices)