In [None]:
import asyncio
import aiohttp
import pandas as pd
from tqdm.asyncio import tqdm
import json

try:
    df = pd.read_csv("/workspace/nodeidx2paperid.csv")
    print(f"Loaded {len(df)} IDs.")

    # The OGB mapping file labels its MAG column 'paper id'; expose it as
    # 'mag_id'. rename() leaves the frame unchanged when that label is absent.
    df = df.rename(columns={'paper id': 'mag_id'})

    # The IDs are later joined into a URL filter, so carry them as strings.
    mag_id_series = df['mag_id']
    mag_ids = mag_id_series.astype(str).tolist()

except KeyError:
    # Reached when no 'mag_id' column exists after the rename above.
    print(f"Error: Columns found: {df.columns}. Please check the CSV header.")
    mag_ids = []

# 2. Async Fetcher
async def fetch_batch(session, ids, max_retries=3, retry_delay=2):
    """Fetch OpenAlex work records for one batch of MAG IDs.

    Args:
        session: object whose ``get(url)`` returns an async context manager
            yielding a response with ``.status`` and an awaitable ``.json()``
            (an ``aiohttp.ClientSession`` in this notebook).
        ids: list of MAG IDs as strings; they are OR-ed together with ``|``
            inside the ``ids.mag`` filter.
        max_retries: how many times to re-issue the request after an
            HTTP 429 response before giving up.
        retry_delay: base wait in seconds between 429 retries; the wait
            grows linearly with the attempt number.

    Returns:
        The list under the response's ``results`` key, or ``[]`` when the
        batch is empty, the request fails, a non-200/429 status is returned,
        or the rate limit persists after ``max_retries`` retries.
    """
    if not ids:
        return []

    # Join IDs with pipe | for "OR" logic
    # Filter 'ids.mag' tells OpenAlex these are Microsoft Academic Graph IDs
    ids_param = "|".join(ids)
    url = f"https://api.openalex.org/works?filter=ids.mag:{ids_param}&per-page=100&select=id,ids,title,abstract_inverted_index"

    attempts = max(0, max_retries) + 1
    for attempt in range(attempts):
        try:
            async with session.get(url) as response:
                status = response.status
                if status == 200:
                    data = await response.json()
                    # Guard against a null 'results' value as well as a missing key.
                    return data.get('results') or []
        except Exception as e:
            print(f"Error: {e}")
            return []

        if status != 429:
            # Previously dropped without a trace; report so missing papers can be explained.
            print(f"Error: HTTP {status} for batch of {len(ids)} IDs starting at {ids[0]}")
            return []

        # Rate limited: the response is already released here, so the wait
        # does not hold a connection. Retry the same URL after a growing pause.
        if attempt < attempts - 1:
            await asyncio.sleep(retry_delay * (attempt + 1))

    print(f"Error: still rate limited after {max_retries} retries; skipping batch of {len(ids)} IDs")
    return []

async def main(ids_list):
    """Fetch metadata for every ID in `ids_list` and return one flat list.

    The IDs are cut into consecutive chunks of 50, each chunk is handed to
    `fetch_batch`, and all chunks are awaited together behind a progress bar.
    The per-chunk result lists are concatenated in chunk order.
    """
    # 50 is safer for OpenAlex URL length limits than 100
    chunk_len = 50

    # At most 10 simultaneous connections per host, to be polite.
    polite_connector = aiohttp.TCPConnector(limit_per_host=10)

    async with aiohttp.ClientSession(connector=polite_connector) as session:
        pending = [
            fetch_batch(session, ids_list[start : start + chunk_len])
            for start in range(0, len(ids_list), chunk_len)
        ]
        # One list of records per chunk, in the order the chunks were created.
        per_chunk = await tqdm.gather(*pending)

    return [record for chunk_records in per_chunk for record in chunk_records]

if __name__ == "__main__":
    if len(mag_ids) > 0:
        # Top-level `await` only works inside a notebook / IPython cell.
        # In a plain .py script this line is a SyntaxError; use
        # `final_data = asyncio.run(main(mag_ids))` there instead.
        final_data = await main(mag_ids)
        
        print(f"Fetched metadata for {len(final_data)} papers.")
        
        # Save the fetched records as a single JSON list.
        with open('arxiv_mag_metadata.json', 'w') as f:
            json.dump(final_data, f)
        print("Saved to arxiv_mag_metadata.json")

In [None]:
import json
import os
from multiprocessing import Pool
from tqdm import tqdm

def reconstruct_abstract(inverted_index):
    """Rebuild text from an inverted index of the form {word: [positions]}.

    Each word is written into every position listed for it and the slots are
    joined with single spaces; positions that no word claims stay as empty
    strings. A falsy index (None or {}) yields "".
    """
    if not inverted_index:
        return ""

    # Highest position mentioned anywhere, floored at 0 so an index holding
    # only empty position lists still produces a single empty slot.
    last_pos = max(
        [0] + [max(positions) for positions in inverted_index.values() if positions]
    )

    slots = [""] * (last_pos + 1)
    for token, positions in inverted_index.items():
        for pos in positions:
            slots[pos] = token
    return " ".join(slots)

def process_item(item):
    """Build the text to embed for one fetched work record.

    Args:
        item: dict for one work. Only the 'title' and
            'abstract_inverted_index' keys are read; either may be missing
            or None.

    Returns:
        "<title>. <abstract>" when at least one of the two is non-empty,
        otherwise the placeholder "Paper content unavailable".
    """
    # .get('title', "") only covers a missing key. A key that is present with
    # a None value would otherwise be formatted as the literal text "None".
    title = item.get('title') or ""
    abstract = reconstruct_abstract(item.get('abstract_inverted_index'))
    if not (title or abstract):
        return "Paper content unavailable"
    return f"{title}. {abstract}"

if __name__ == "__main__":
    input_file = 'arxiv_mag_metadata.json'
    output_text_file = 'reconstructed_texts.json'

    # Read the list of fetched work records.
    with open(input_file, 'r') as metadata_handle:
        data = json.load(metadata_handle)

    print(f"Reconstructing {len(data)} abstracts using all CPU cores...")
    # Spread the per-record string work over a process pool; imap yields the
    # texts in the same order as `data`.
    with Pool() as worker_pool:
        text_stream = worker_pool.imap(process_item, data, chunksize=100)
        texts = list(tqdm(text_stream, total=len(data)))

    # Persist the texts as one JSON list.
    with open(output_text_file, 'w') as text_handle:
        json.dump(texts, text_handle)
    print("Stage 1 Complete: Text saved.")

In [None]:
import json
import torch
import torch.nn.functional as F
from sentence_transformers import SentenceTransformer
from tqdm import tqdm
import os

def reconstruct_abstract(inverted_index):
    """Turn an inverted index ({word: [positions]}) back into a string.

    Returns "" for a falsy index. Otherwise every word is placed at each of
    its positions and the resulting list is joined with spaces; positions
    without a word remain empty strings.
    """
    if not inverted_index:
        return ""

    # Flatten to (position, word) pairs, keeping the dict's iteration order so
    # that a later word still overwrites an earlier one at a shared position.
    placements = [
        (position, word)
        for word, positions in inverted_index.items()
        for position in positions
    ]

    # Positions are 0-based, so the list needs (highest position + 1) slots.
    highest = max((position for position, _ in placements), default=0)
    highest = max(highest, 0)
    slots = [""] * (highest + 1)

    for position, word in placements:
        slots[position] = word

    return " ".join(slots)

def generate_local_embeddings(input_file, output_file):
    """Embed every record in `input_file` and save 256-d unit vectors.

    Steps: read a JSON list of work records, build "<title>. <abstract>" for
    each, encode the texts with the Qwen/Qwen3-Embedding-4B sentence
    transformer, keep the first 256 dimensions, L2-normalise them again and
    `torch.save` the result (moved to CPU) to `output_file`.

    Args:
        input_file: path to the JSON metadata list.
        output_file: destination path for the saved tensor.

    Returns:
        None. Prints an error and returns early if `input_file` is missing.
    """
    # 1. Load data
    print(f"Loading {input_file}...")
    if not os.path.exists(input_file):
        print(f"Error: {input_file} not found.")
        return

    with open(input_file, 'r') as f:
        data = json.load(f)

    # 2. Prepare text: title + abstract, one string per record (same order as `data`)
    texts = []
    for item in data:
        # `or ""` also covers a key that is present but None; .get's default
        # alone would let such a title be formatted as the text "None".
        title = item.get('title') or ""

        inv_index = item.get('abstract_inverted_index')
        abstract = reconstruct_abstract(inv_index) if inv_index else ""

        if title or abstract:
            texts.append(f"{title}. {abstract}")
        else:
            texts.append("Paper content unavailable")

    print(f"Loaded {len(texts)} papers. Loading model...")

    # 3. Load the model
    # NOTE(review): "device_map": "auto" and the explicit model.to(device)
    # below both try to place the model; confirm they do not conflict on
    # multi-GPU or offloaded setups.
    model = SentenceTransformer(
        "Qwen/Qwen3-Embedding-4B",
        model_kwargs={
            "attn_implementation": "flash_attention_2",
            "torch_dtype": "float16",
            "device_map": "auto"
        }
    )

    # Prefer CUDA, then Apple MPS, then CPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if device == "cpu" and torch.backends.mps.is_available():
        device = "mps"  # For Mac users

    model.to(device)
    print(f"Model loaded on {device}.")

    # 4. Generate full-width embeddings; truncation happens afterwards.
    BATCH_SIZE = 16  # Adjust based on your VRAM

    print("Generating embeddings...")
    embeddings = model.encode(
        texts,
        batch_size=BATCH_SIZE,
        show_progress_bar=True,
        convert_to_tensor=True,
        normalize_embeddings=True
    )

    # 5. Keep the first 256 dimensions and re-normalise, because a slice of a
    # unit vector is no longer unit length.
    print(f"Original shape: {embeddings.shape}")
    embeddings_256 = embeddings[:, :256]
    embeddings_256 = F.normalize(embeddings_256, p=2, dim=1)
    print(f"Truncated & Re-normalized shape: {embeddings_256.shape}")

    # 6. Save
    print(f"Saving to {output_file}...")
    torch.save(embeddings_256.cpu(), output_file)
    print("Done!")

# Entry point: embed the records in arxiv_mag_metadata.json and write the
# truncated vectors to qwen_embeddings_256.pt.
if __name__ == "__main__":
    generate_local_embeddings('arxiv_mag_metadata.json', 'qwen_embeddings_256.pt')

In [None]:
import torch
import pandas as pd
import json
import numpy as np
from pathlib import Path
from tqdm import tqdm

# Files (all paths are resolved against the current working directory)
# Node-index -> paper-ID mapping, read with pandas in align_embeddings().
MAPPING_FILE = Path("data/ogbn_arxiv/mapping/nodeidx2paperid.csv.gz")
# JSON list of fetched work records; row i is assumed to match embedding row i.
METADATA_FILE = Path("arxiv_mag_metadata.json")
# Tensor of embeddings in fetched order (input).
EMBEDDING_FILE = Path("qwen_embeddings_256.pt")
# Tensor of embeddings re-ordered to graph node order (output).
OUTPUT_FILE = Path("qwen_embeddings_256_aligned.pt")

def align_embeddings():
    """Re-order fetched embeddings so that row i belongs to graph node i.

    Reads MAPPING_FILE (node index -> MAG paper ID), METADATA_FILE (JSON list
    of fetched records) and EMBEDDING_FILE (one row per metadata record), then
    writes to OUTPUT_FILE a float32 tensor with one row per mapping entry.
    Nodes whose MAG ID was not fetched keep an all-zero row.

    Raises:
        ValueError: if the embedding tensor does not have exactly one row per
            metadata record, since rows could then not be matched to IDs.
    """
    print("1. Loading OGB Mapping (The Ground Truth)...")
    # This CSV maps Graph Node Index -> MAG Paper ID
    df_mapping = pd.read_csv(MAPPING_FILE)

    # Handle column naming variations
    if 'paper id' in df_mapping.columns:
        df_mapping = df_mapping.rename(columns={'paper id': 'mag_id'})

    # Sort by node index so list position == node index.
    if 'node idx' in df_mapping.columns:
        df_mapping = df_mapping.sort_values('node idx')

    ground_truth_ids = df_mapping['mag_id'].astype(str).tolist()
    total_nodes = len(ground_truth_ids)
    print(f"   Graph expects {total_nodes} nodes.")

    print("2. Loading Fetched Metadata & Embeddings...")
    with open(METADATA_FILE, 'r') as f:
        metadata = json.load(f)

    # Strip the OpenAlex URL prefix and the leading "W" from a work ID.
    def clean_id(url_or_id):
        return str(url_or_id).replace("https://openalex.org/W", "").replace("W", "")

    # Map each fetched MAG ID to its row in the *fetched* tensor;
    # metadata[i] corresponds to embeddings[i].
    fetched_id_to_index = {}
    for idx, item in enumerate(metadata):
        mag_id = clean_id(item.get('id', ''))

        # Prefer the explicit MAG ID when the record carries one. 'ids' may be
        # absent or null, so check its type before looking inside it.
        ids_field = item.get('ids')
        if isinstance(ids_field, dict) and ids_field.get('mag') is not None:
            mag_id = str(ids_field['mag'])

        fetched_id_to_index[mag_id] = idx

    print(f"   Mapped {len(fetched_id_to_index)} fetched papers to their tensor indices.")

    print("3. Loading Tensor...")
    # Load on CPU to save memory; weights_only matches the other loaders in
    # this notebook and avoids unpickling arbitrary objects.
    raw_embeddings = torch.load(EMBEDDING_FILE, map_location='cpu', weights_only=True)

    # The index map above is only meaningful if rows and records line up 1:1.
    if raw_embeddings.shape[0] != len(metadata):
        raise ValueError(
            f"Mismatch: {METADATA_FILE} has {len(metadata)} records, "
            f"but {EMBEDDING_FILE} has {raw_embeddings.shape[0]} rows."
        )
    embedding_dim = raw_embeddings.shape[1]

    # Create the final aligned tensor (filled with zeros)
    aligned_tensor = torch.zeros((total_nodes, embedding_dim), dtype=torch.float32)

    print("4. Aligning...")
    hits = 0
    misses = 0

    # Iterate through the GROUND TRUTH order
    for node_idx, target_mag_id in enumerate(tqdm(ground_truth_ids)):
        raw_idx = fetched_id_to_index.get(str(target_mag_id))
        if raw_idx is None:
            # Not fetched: the row stays zero.
            misses += 1
        else:
            aligned_tensor[node_idx] = raw_embeddings[raw_idx]
            hits += 1

    # Avoid dividing by zero in the summary when the mapping is empty.
    denom = max(total_nodes, 1)
    print("-" * 30)
    print("Alignment Complete.")
    print(f"✅ Matched: {hits} ({hits/denom:.1%})")
    print(f"❌ Missing: {misses} ({misses/denom:.1%}) -> Filled with zeros")
    print(f"Final Tensor Shape: {aligned_tensor.shape}")

    torch.save(aligned_tensor, OUTPUT_FILE)
    print(f"Saved aligned tensor to {OUTPUT_FILE}")

# Entry point: build the node-ordered embedding tensor from the files named
# in the constants above.
if __name__ == "__main__":
    align_embeddings()

In [None]:
"""
Download and load the ogbn-arxiv dataset into data/, fuse with Qwen embeddings,
and create a contrastive LinkNeighborLoader.
"""
import torch
import os
from pathlib import Path
# --- CHANGE 1: Import the PyG-specific dataset wrapper ---
from ogb.nodeproppred import PygNodePropPredDataset
from torch_geometric.loader import LinkNeighborLoader
import torch_geometric.transforms as T
import torch_sparse
import pyg_lib

# Paths are relative to the current working directory.
# Wrap string in Path() to ensure .mkdir() works
# DATA_DIR is the default dataset root used by download_ogbn_arxiv().
DATA_DIR = Path("./data") 
# Tensor that load_data_with_features() concatenates onto the node features.
EMBEDDING_PATH = Path("qwen_embeddings_256_aligned.pt") 

def download_ogbn_arxiv(root: str | Path | None = None):
    """Download ogbn-arxiv to data/ and return the PyG dataset object."""
    if root is None:
        root = DATA_DIR
    else:
        root = Path(root)
        
    root.mkdir(parents=True, exist_ok=True)
    
    print(f"Downloading/Loading ogbn-arxiv at {root}...")
    # --- CHANGE 2: Use PygNodePropPredDataset instead of NodePropPredDataset ---
    dataset = PygNodePropPredDataset(name="ogbn-arxiv", root=str(root))
    return dataset

def load_data_with_features(root: str | Path | None = None):
    """
    Loads the graph and replaces/concatenates features with Qwen embeddings.
    Returns: (dataset, data)
    """
    dataset = download_ogbn_arxiv(root)
    data = dataset[0] # Now this returns a PyG Data object, not a tuple
    
    # --- Feature Fusion ---
    if EMBEDDING_PATH.exists():
        print(f"Found custom embeddings at {EMBEDDING_PATH}. Fusing...")
        
        # Load Qwen embeddings (ensure CPU to prevent OOM during concat)
        qwen_emb = torch.load(EMBEDDING_PATH, map_location='cpu', weights_only=True)
        
        # Validation: Check alignment
        if qwen_emb.shape[0] != data.num_nodes:
            raise ValueError(f"Shape Mismatch! Graph has {data.num_nodes} nodes, "
                             f"but embeddings have {qwen_emb.shape[0]} rows.")

        # Concatenate: [Original(128) | Qwen(256)] -> [384]
        # data.x is already a tensor in PygNodePropPredDataset
        data.x = torch.cat([data.x, qwen_emb], dim=1)
        print(f"Fused Features Shape: {data.x.shape}")
        
    else:
        print(f"Warning: {EMBEDDING_PATH} not found. Using original features only.")
        
    return dataset, data

def create_contrastive_loader(dataset, data, batch_size=2048):
    """Build a LinkNeighborLoader over edges whose source is a training node.

    Args:
        dataset: object providing `get_idx_split()` with a 'train' entry.
        data: graph object with `edge_index` and `num_edges`.
        batch_size: forwarded to the loader.

    Returns:
        The configured LinkNeighborLoader.
    """
    # Training node indices from the dataset's own split.
    train_nodes = dataset.get_idx_split()['train']

    print("Filtering training edges...")
    edge_sources, _edge_targets = data.edge_index

    # Keep only the edges that start at a training node.
    starts_in_train = torch.isin(edge_sources, train_nodes)
    positive_edges = data.edge_index[:, starts_in_train]

    print(f"Training on {positive_edges.shape[1]} edges (out of {data.num_edges} total).")

    return LinkNeighborLoader(
        data=data,
        num_neighbors=[10, 5],            # fan-out per hop: 10, then 5
        edge_label_index=positive_edges,  # the positive edges to learn from
        neg_sampling_ratio=1.0,           # one sampled negative per positive
        batch_size=batch_size,
        shuffle=True,
        num_workers=0,
        persistent_workers=False,
    )

# Smoke test: load the (optionally fused) graph, build the loader and print
# the shape of a single batch. No training happens here.
if __name__ == "__main__":
    # 1. Load Data
    dataset, data = load_data_with_features()
    
    # 2. Create Loader
    print("\nInitializing Contrastive Loader...")
    train_loader = create_contrastive_loader(dataset, data)
    
    # 3. Pull one batch to confirm the loader yields the expected fields
    print("\n--- Batch Inspection ---")
    batch = next(iter(train_loader))
    
    print(f"Batch Type: {type(batch)}")
    print(f"Batch Nodes: {batch.num_nodes}") 
    print(f"Batch Features: {batch.x.shape}") 
    
    # edge_label_index holds the node pairs to score; edge_label their targets.
    print(f"Contrastive Pairs: {batch.edge_label_index.shape}")
    print(f"Labels (1=Pos, 0=Neg): {batch.edge_label[:10]}")
    
    print("\n✅ Setup Complete. Ready for training.")

In [None]:
# %% [markdown]
# # Contrastive GNN Training with Custom Architecture
# This script trains a GraphSAGE encoder using a contrastive Link Prediction task.
# It fuses Word2Vec features (128d) with Qwen Embeddings (256d) for a total input of 384d.

# %%
import torch
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv, BatchNorm
from torch_geometric.loader import LinkNeighborLoader
from ogb.nodeproppred import PygNodePropPredDataset
from pathlib import Path
from tqdm import tqdm
import os

# %% [markdown]
# ## 1. Configuration & Constants

# %%
# Hyperparameters
INPUT_DIM = 384   # Expected fused feature width: 128 (Original) + 256 (Qwen)
HIDDEN_DIM = 256  # hidden_dim passed to EmbedderGNNv2 in train()
OUTPUT_DIM = 256  # out_dim passed to EmbedderGNNv2 in train()
BATCH_SIZE = 2048  # batch_size passed to LinkNeighborLoader in train()
LR = 0.001  # learning rate for the Adam optimizer
EPOCHS = 10  # number of passes over the loader
# Use the GPU when torch reports one, otherwise the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Paths (relative to the current working directory)
DATA_DIR = Path("./data")
EMBEDDING_PATH = Path("qwen_embeddings_256_aligned.pt") # Ensure this is the ALIGNED version
MODEL_SAVE_PATH = "gnn_contrastive_v2.pth"  # where train() saves the state dict

print(f"Running on: {DEVICE}")

# %% [markdown]
# ## 2. Data Loading & Feature Fusion
# This function loads the OGB-Arxiv graph and concatenates the Qwen embeddings.

# %%
def load_data():
    """Load ogbn-arxiv from DATA_DIR and append the embeddings if present.

    When EMBEDDING_PATH exists, its tensor is concatenated to `data.x` along
    the feature dimension; otherwise a warning is printed and the original
    features are returned untouched.

    Returns:
        (dataset, data), where `data` is `dataset[0]`.

    Raises:
        ValueError: if the embedding row count differs from `data.num_nodes`.
    """
    print(f"Loading data from {DATA_DIR}...")
    dataset = PygNodePropPredDataset(name="ogbn-arxiv", root=str(DATA_DIR))
    data = dataset[0]

    # No custom embeddings on disk: fall back to the dataset's own features.
    if not EMBEDDING_PATH.exists():
        print("Warning: Custom embeddings not found. Using original features only.")
        return dataset, data

    print(f"Fusing embeddings from {EMBEDDING_PATH}...")
    qwen_emb = torch.load(EMBEDDING_PATH, map_location='cpu', weights_only=True)

    # Row-wise concatenation needs exactly one embedding row per node.
    embedding_rows = qwen_emb.shape[0]
    if embedding_rows != data.num_nodes:
        raise ValueError(f"Mismatch: Graph has {data.num_nodes} nodes, embeddings have {embedding_rows}")

    # [N, original width] + [N, embedding width] along dim 1
    data.x = torch.cat((data.x, qwen_emb), dim=1)
    print(f"New Feature Shape: {data.x.shape}")

    return dataset, data

# %% [markdown]
# ## 3. Model Architecture
# Using the specific V2 architecture provided.

# %%
class EmbedderGNNv2(torch.nn.Module):
    """Stack of mean-aggregating SAGEConv layers with BatchNorm and residuals.

    The first row of every input feature matrix is replaced by a learnable
    mask vector before message passing. Each layer is conv -> batch norm;
    all but the last layer are followed by ReLU and dropout. A residual
    connection is added whenever a layer's output has the same shape as its
    input.
    """

    def __init__(self, in_dim, hidden_dim, out_dim, num_layers=3, dropout=0.5):
        super().__init__()
        self.dropout = dropout
        self.num_layers = num_layers

        # Input layer: in_dim -> hidden_dim
        self.convs = torch.nn.ModuleList([SAGEConv(in_dim, hidden_dim, aggr='mean')])
        self.bns = torch.nn.ModuleList([BatchNorm(hidden_dim)])

        # Learnable vector written over row 0 of the features in forward().
        self.mask_embed = torch.nn.Parameter(torch.randn(in_dim), requires_grad=True)

        # (num_layers - 2) hidden -> hidden layers, then one hidden -> out layer.
        for width in [hidden_dim] * (num_layers - 2) + [out_dim]:
            self.convs.append(SAGEConv(hidden_dim, width, aggr='mean'))
            self.bns.append(BatchNorm(width))

    def forward(self, x, edge_index):
        # Work on a detached copy so the caller's feature tensor is never
        # modified and no gradient flows back into the raw features.
        x = x.clone().detach()

        # Overwrite the first row with the mask token.
        x[0] = self.mask_embed

        final_layer = self.num_layers - 1
        for depth in range(self.num_layers):
            out = self.bns[depth](self.convs[depth](x, edge_index))

            # Non-linearity and dropout everywhere except after the last layer.
            if depth < final_layer:
                out = F.dropout(F.relu(out), p=self.dropout, training=self.training)

            # Residual only when the shapes agree (e.g. not across a width change).
            x = out + x if out.shape == x.shape else out

        return x

# %% [markdown]
# ## 4. Contrastive Training Loop
# We use LinkNeighborLoader to generate Positive (Real) and Negative (Fake) edges.

# %%
def train():
    """Train EmbedderGNNv2 on a link-prediction objective and save its weights.

    Loads the graph via load_data(), keeps the edges whose source node is in
    the training split as positives, lets LinkNeighborLoader add one sampled
    negative per positive, scores each pair by the dot product of the two
    node embeddings, and optimises BCE-with-logits with Adam for EPOCHS
    epochs. The state dict is written to MODEL_SAVE_PATH.

    The model's input width is taken from the loaded features rather than
    from INPUT_DIM, so the run also works when load_data() falls back to the
    original features only.
    """
    # --- Load Data ---
    dataset, data = load_data()

    # Size the first layer from the data actually loaded. A fixed INPUT_DIM
    # would fail in the first conv whenever the custom embeddings are missing.
    in_dim = data.x.shape[1]
    if in_dim != INPUT_DIM:
        print(f"Note: feature width is {in_dim}, not INPUT_DIM={INPUT_DIM}; sizing the model from the data.")

    # --- Setup Loader ---
    split_idx = dataset.get_idx_split()
    train_idx = split_idx['train']

    # Keep only edges whose source node belongs to the training split
    src, _ = data.edge_index
    train_mask = torch.isin(src, train_idx)
    train_edge_index = data.edge_index[:, train_mask]

    print(f"Training on {train_edge_index.shape[1]} edges.")

    # NOTE(review): three fan-out entries means three sampled hops, while the
    # model below is built with num_layers=4 — confirm this is intended.
    loader = LinkNeighborLoader(
        data=data,
        num_neighbors=[10, 10, 5],  # fan-out per hop (3 hops)
        edge_label_index=train_edge_index,
        neg_sampling_ratio=1.0,  # 1 Neg for 1 Pos
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=4,  # Increase for faster loading
        persistent_workers=True
    )

    # --- Setup Model ---
    model = EmbedderGNNv2(
        in_dim=in_dim,
        hidden_dim=HIDDEN_DIM,
        out_dim=OUTPUT_DIM,
        num_layers=4
    ).to(DEVICE)

    optimizer = torch.optim.Adam(model.parameters(), lr=LR)

    # BCEWithLogitsLoss combines Sigmoid + BCE (numerically stable)
    criterion = torch.nn.BCEWithLogitsLoss()

    print("\n--- Starting Training ---")
    model.train()

    # Guard the per-epoch average against an empty loader.
    num_batches = max(len(loader), 1)

    for epoch in range(1, EPOCHS + 1):
        total_loss = 0
        pbar = tqdm(loader, desc=f"Epoch {epoch}/{EPOCHS}")

        for batch in pbar:
            batch = batch.to(DEVICE)
            optimizer.zero_grad()

            # 1. Embeddings for every node in the sampled subgraph
            z = model(batch.x, batch.edge_index)

            # 2. Embeddings of the pairs to score (row 0 = source, row 1 = destination)
            src_emb = z[batch.edge_label_index[0]]
            dst_emb = z[batch.edge_label_index[1]]

            # 3. Dot-product similarity: one score per pair
            scores = (src_emb * dst_emb).sum(dim=-1)

            # 4. Compare scores to labels (1 for real, 0 for fake)
            loss = criterion(scores, batch.edge_label)

            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            pbar.set_postfix({"loss": loss.item()})

        print(f"Epoch {epoch} Avg Loss: {total_loss / num_batches:.4f}")

    # Save
    torch.save(model.state_dict(), MODEL_SAVE_PATH)
    print(f"Model saved to {MODEL_SAVE_PATH}")

# %%
# Entry point: run the full training loop defined above.
if __name__ == "__main__":
    train()

In [None]:
import json
import torch
import torch.nn.functional as F
from sentence_transformers import SentenceTransformer
import os

def run_gpu_inference(input_json, output_pt):
    """Embed pre-built texts on the GPU and save 256-d unit vectors.

    Args:
        input_json: path to a JSON list of strings, one per paper.
        output_pt: destination path for the saved tensor.

    Steps: load the texts, load the sentence-transformer on CUDA, encode the
    texts through a multi-process pool, keep the first 256 dimensions,
    L2-normalise them and `torch.save` the result (moved to CPU).

    The worker pool is always stopped, even if encoding raises.
    """
    # 1. Load the clean strings
    with open(input_json, 'r') as f:
        texts = json.load(f)

    # 2. Load the model in half precision with Flash Attention 2.
    # The log line is built from the name actually loaded so the two cannot
    # disagree.
    model_name = "Qwen/Qwen3-Embedding-0.6B"
    print(f"Loading {model_name} with Flash Attention 2...")
    model = SentenceTransformer(
        model_name,
        trust_remote_code=True,
        model_kwargs={
            "attn_implementation": "flash_attention_2",
            "torch_dtype": torch.float16
        }
    )
    model.to("cuda")

    # 3. Encode through sentence-transformers' multi-process pool.
    pool = model.start_multi_process_pool()
    try:
        print(f"Generating embeddings for {len(texts)} papers...")
        # batch_size and chunk_size are tuning knobs; adjust to the available VRAM.
        embeddings_raw = model.encode_multi_process(
            texts,
            pool,
            batch_size=64,
            chunk_size=1000
        )
    finally:
        # Without this, a failure during encoding would leave the worker
        # processes running.
        model.stop_multi_process_pool(pool)

    # 4. Convert to Torch, truncate and re-normalize
    print("Moving to GPU for MRL Truncation & Normalization...")
    embeddings = torch.from_numpy(embeddings_raw).to("cuda")

    # Keep the first 256 dimensions
    embeddings_256 = embeddings[:, :256]

    # A slice of a vector has a different norm, so normalise again.
    embeddings_256 = F.normalize(embeddings_256, p=2, dim=1)

    # 5. Save
    print(f"Saving to {output_pt}...")
    torch.save(embeddings_256.cpu(), output_pt)
    print(f"Final Tensor Shape: {embeddings_256.shape}. All done!")

# Entry point: embed the texts in reconstructed_texts.json.
# NOTE(review): this writes qwen_embeddings_256.pt, the same file name that
# the generate_local_embeddings cell writes; whichever runs last overwrites
# the other — confirm which embedding variant is intended downstream.
if __name__ == "__main__":
    run_gpu_inference('reconstructed_texts.json', 'qwen_embeddings_256.pt')