In [None]:
!pip install torch_geometric
!pip install kaggle


Collecting torch_geometric
  Downloading torch_geometric-2.7.0-py3-none-any.whl.metadata (63 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/63.7 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.7/63.7 kB[0m [31m2.9 MB/s[0m eta [36m0:00:00[0m
Downloading torch_geometric-2.7.0-py3-none-any.whl (1.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m29.7 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torch_geometric
Successfully installed torch_geometric-2.7.0


In [None]:
# Upload the Kaggle API credential file (kaggle.json) from the local machine.
from google.colab import files

# files.upload() returns {filename: raw file bytes}. Bind the result instead of leaving the
# call as the cell's last expression, otherwise the notebook echoes the API key into its
# saved output.
uploaded = files.upload()
print(f"Uploaded: {sorted(uploaded)}")

Saving kaggle.json to kaggle.json


{'kaggle.json': b'{"username":"<REDACTED>","key":"<REDACTED>"}'}

In [None]:
# Move the uploaded kaggle.json into ~/.kaggle/ for the Kaggle CLI call in the next cell.
!mkdir -p ~/.kaggle
!mv kaggle.json ~/.kaggle/
# Mode 600: only the owning user may read or write the credential file.
!chmod 600 ~/.kaggle/kaggle.json



In [None]:
!kaggle datasets download -d cornell-university/arxiv
!unzip arxiv.zip -d /content/arxiv


Dataset URL: https://www.kaggle.com/datasets/cornell-university/arxiv
License(s): CC0-1.0
Downloading arxiv.zip to /content
 98% 1.49G/1.53G [00:11<00:00, 272MB/s]
100% 1.53G/1.53G [00:11<00:00, 140MB/s]
Archive:  arxiv.zip
  inflating: /content/arxiv/arxiv-metadata-oai-snapshot.json  


In [None]:
# Authenticate with the Hugging Face Hub only when a token is supplied through the
# HF_TOKEN environment variable. Never paste a token into this cell: it would be saved
# with the notebook.
import os
from huggingface_hub import login

hf_token = os.environ.get("HF_TOKEN", "").strip()
if hf_token:
    login(token=hf_token)
else:
    # Authentication is optional for public models, so an absent token is not an error.
    print("HF_TOKEN is not set; continuing without Hugging Face authentication.")


In [None]:
# Upgrade bitsandbytes to the latest release.
# NOTE(review): nothing in the visible training cell imports bitsandbytes — confirm this install is still needed.
!pip install -U bitsandbytes

Collecting bitsandbytes
  Downloading bitsandbytes-0.48.2-py3-none-manylinux_2_24_x86_64.whl.metadata (10 kB)
Downloading bitsandbytes-0.48.2-py3-none-manylinux_2_24_x86_64.whl (59.4 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m59.4/59.4 MB[0m [31m12.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: bitsandbytes
Successfully installed bitsandbytes-0.48.2


In [None]:
# full_fixed_pipeline_colab_friendly_with_full_save.py
# SciBERT + LoRA(r=8) + 12k samples + Colab-Free friendly
# Full BERT-GNN hybrid for arXiv paper classification & recommendation
# → Saves *everything* needed for a zero-training recommendation run

import os
import gc
import ast
import random
import logging
import pickle
from collections import defaultdict

import numpy as np
import pandas as pd
import networkx as nx

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.cuda.amp as amp
from torch.optim import AdamW
from transformers import AutoTokenizer, AutoModel
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import Dataset, DataLoader

from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score
from sklearn.metrics.pairwise import cosine_similarity

from torch_geometric.nn import SAGEConv
from torch_geometric.utils import from_networkx, to_undirected
from peft import LoraConfig, get_peft_model
from tqdm import tqdm

# =============================
# Config
# =============================
# Seed every RNG used below (Python, NumPy, PyTorch CPU and, when present, all CUDA devices).
SEED = 42
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

# Setting passed to PyTorch's CUDA allocator through the environment.
# NOTE(review): this is assigned after torch has been imported — confirm it still takes effect.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:64"
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# --- Hyper-parameters -------------------------------------------------------------------
MODEL_NAME      = "allenai/scibert_scivocab_uncased"  # encoder loaded by AutoTokenizer / AutoModel
MAX_SAMPLES     = 12000   # rows read from the arXiv snapshot (or generated synthetically)
BERT_BATCH_SIZE = 64      # texts per forward pass during fine-tuning and embedding extraction
GRAD_ACCUM_STEPS = 3      # batches accumulated per optimizer step (effective batch = 64 * 3)
EPOCHS_BERT     = 45
EPOCHS_GNN      = 200
BERT_LR         = 8e-5
GNN_LR          = 5e-4
MAX_LENGTH      = 256     # tokenizer truncation / padding length
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL_SAVE_PATH = "/content/models"  # directory for checkpoints and inference artefacts
NUM_WORKERS = 0           # DataLoader worker processes used when projecting embeddings
GNN_DIM = 128             # size of the BERT projection head output, which is also the GNN input size
KNN_K   = 6               # nearest neighbours added per node when augmenting the graph
CONT_WEIGHT = 0.0         # weight of the contrastive loss term; 0 disables it
ACCUM_STEPS = GRAD_ACCUM_STEPS  # alias used as the default of fine_tune_bert(accum_steps=...)

os.makedirs(MODEL_SAVE_PATH, exist_ok=True)

# Optional mirror of every checkpoint to Google Drive, enabled with the env var SAVE_TO_DRIVE=1.
SAVE_TO_DRIVE = os.environ.get('SAVE_TO_DRIVE', '0') == '1'
DRIVE_PATH = os.environ.get('DRIVE_PATH', '/content/drive/MyDrive/models')
if SAVE_TO_DRIVE:
    os.makedirs(DRIVE_PATH, exist_ok=True)

logger.info(f"Using device: {DEVICE}")

# =============================
# Utilities & Data
# =============================
def generate_synthetic_data(max_samples):
    """Build a placeholder dataframe shaped like the arXiv metadata.

    Args:
        max_samples: Number of fake papers to generate.

    Returns:
        A DataFrame with the columns ``id``, ``title``, ``abstract``,
        ``authors_parsed`` (one ``(name, "", "")`` tuple per paper) and
        ``categories`` (a single category string per paper).
    """
    logger.warning("Generating synthetic data (synthetic mode).")

    category_pool = []
    for slot in range(50):
        category_pool.append("cs.CV" if slot % 5 == 0 else f"cat.{slot % 20}")
    author_pool = [f"Author_{n}" for n in range(200)]
    sample_range = range(max_samples)

    ids = [str(k) for k in sample_range]
    titles = [f"Research on topic {k} in deep learning and vision" for k in sample_range]
    abstracts = [
        f"This work studies model {k}, presenting experiments and results that illustrate patterns. More descriptive text to give SciBERT something to learn."
        for k in sample_range
    ]
    # Draw order matters for seeded reproducibility: every author is drawn before any category.
    authors_parsed = [[(random.choice(author_pool), "", "")] for _ in sample_range]
    categories = [random.choice(category_pool) for _ in sample_range]

    return pd.DataFrame({
        'id': ids,
        'title': titles,
        'abstract': abstracts,
        'authors_parsed': authors_parsed,
        'categories': categories,
    })

def safe_literal_eval(val):
    """Coerce ``val`` into a list, parsing it as a Python literal when it is a string.

    Args:
        val: A list (returned unchanged), a string holding a Python literal,
            or any other value.

    Returns:
        The list itself, the parsed list (a parsed tuple is converted to a
        list), or ``[]`` when the value is not list-like or cannot be parsed.
        The result is always a list, so callers can iterate it safely.
    """
    if isinstance(val, list):
        return val
    if not isinstance(val, str):
        return []
    try:
        parsed = ast.literal_eval(val)
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
        # These are the failures ast.literal_eval is documented to raise on bad input.
        return []
    if isinstance(parsed, list):
        return parsed
    if isinstance(parsed, tuple):
        return list(parsed)
    # A scalar or dict literal (e.g. "5") is not an author list; returning it would break iteration.
    return []

def load_data(max_samples=MAX_SAMPLES):
    """Load the arXiv metadata (or a synthetic stand-in) and derive the training columns.

    Args:
        max_samples: Maximum number of rows to read from the JSON-lines snapshot,
            or the number of synthetic rows to generate on fallback.

    Returns:
        ``(df, le)`` where ``df`` has the added columns ``text`` (title + abstract),
        ``authors_set``, ``categories`` (list of category strings), ``label``
        (first category, or ``'unknown'``) and ``label_enc`` (int32 class id), and
        ``le`` is the fitted ``LabelEncoder``. Labels with fewer than 5 papers are dropped.
    """
    path = "/content/arxiv/arxiv-metadata-oai-snapshot.json"
    try:
        # Read ids as strings: numeric dtype inference would otherwise be free to turn
        # identifiers such as "0704.0001" into floats and lose the leading zero.
        df = pd.read_json(path, lines=True, nrows=max_samples, dtype={'id': str})
        logger.info(f"Loaded {len(df)} rows from arXiv.")
    except Exception as exc:
        # Deliberate best-effort fallback, but report why: silently training on
        # synthetic data is very hard to notice afterwards.
        logger.warning(f"Could not load {path} ({type(exc).__name__}: {exc}); falling back to synthetic data.")
        df = generate_synthetic_data(max_samples)

    df = df.fillna({'abstract': "No abstract", 'title': "No title"})
    df['id'] = df['id'].astype(str)
    df['text'] = df['title'] + "\n\n" + df['abstract']
    df['authors_parsed'] = df.get('authors_parsed', pd.Series([[]]*len(df))).apply(safe_literal_eval)
    # First element of each parsed author entry; empty entries are skipped.
    df['authors_set'] = df['authors_parsed'].apply(lambda x: {a[0] for a in x if a})
    df['categories'] = df['categories'].apply(lambda x: x.split() if isinstance(x, str) else (x if isinstance(x, list) else []))
    df['label'] = df['categories'].apply(lambda x: x[0] if x else 'unknown')

    # Keep only classes with at least 5 papers so the later stratified splits are possible.
    keep = df['label'].value_counts()[lambda x: x >= 5].index
    df = df[df['label'].isin(keep)].reset_index(drop=True)

    le = LabelEncoder()
    df['label_enc'] = le.fit_transform(df['label']).astype(np.int32)
    return df, le

# Load the dataset once at module level; `df`, `le` and `num_classes` are reused by later stages.
df, le = load_data(MAX_SAMPLES)
num_classes = len(le.classes_)
logger.info(f"Num classes after filtering: {num_classes}")
gc.collect()

# =============================
# Graph: metadata edges
# =============================
def build_base_graph(df):
    """Build the metadata graph: one node per paper, edges for shared categories and authors.

    Args:
        df: DataFrame with the columns ``id``, ``text``, ``categories`` (iterable of
            category strings) and ``authors_set`` (iterable of author names).

    Returns:
        An undirected ``networkx.Graph`` whose nodes are paper ids (each carrying a
        ``text`` attribute). Every pair of papers that shares a category or an author
        is connected, so each group forms a clique.
    """
    graph = nx.Graph()
    papers_by_category = defaultdict(list)
    papers_by_author = defaultdict(list)

    for _, record in df.iterrows():
        paper_id = record['id']
        graph.add_node(paper_id, text=record['text'])
        for category in record['categories']:
            papers_by_category[category].append(paper_id)
        for author in record['authors_set']:
            papers_by_author[author].append(paper_id)

    # Category cliques first, then author cliques; pairs are emitted in (i < j) order.
    for grouping in (papers_by_category, papers_by_author):
        for members in grouping.values():
            graph.add_edges_from(
                (members[a], members[b])
                for a in range(len(members))
                for b in range(a + 1, len(members))
            )

    logger.info(f"Base graph: {graph.number_of_nodes()} nodes, {graph.number_of_edges()} edges")
    return graph

G_base = build_base_graph(df)
# The node order of the graph defines the row order of the per-node arrays built below.
node_list = list(G_base.nodes())
_df_indexed = df.set_index('id')
# Re-order the dataframe rows to follow node_list so labels and texts line up with graph nodes.
_df_sel = _df_indexed.loc[[n for n in node_list if n in _df_indexed.index]]
df_reindexed = _df_sel.reset_index()

# NOTE(review): graph_data is not referenced again in the visible code (pyg_graph is built
# later from the augmented graph) — confirm this conversion is still needed.
graph_data = from_networkx(G_base)
graph_data.y = torch.tensor(df_reindexed['label_enc'].values, dtype=torch.long)
texts = df_reindexed['text'].tolist()

# Tokenizer shared by fine-tuning, embedding extraction and recommend().
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# =============================
# BERT Projector (SciBERT + LoRA r=8)
# =============================
class BertProjector(nn.Module):
    """Transformer encoder with a classification head and a projection head.

    The encoder's token states are mean-pooled over the attention mask; the pooled
    vector feeds a linear classifier and, optionally, a linear projection to
    ``proj_dim`` dimensions.
    """

    def __init__(self, model_name, num_classes, proj_dim=GNN_DIM):
        super().__init__()
        self.bert = AutoModel.from_pretrained(model_name)
        # Fall back to 768 when the config does not expose hidden_size.
        hidden = getattr(self.bert.config, 'hidden_size', 768)
        self.classifier = nn.Linear(hidden, num_classes)
        self.projection = nn.Linear(hidden, proj_dim)

    @staticmethod
    def _masked_mean(token_states, attention_mask):
        """Average token states along dim 1, counting only positions where the mask is set."""
        weights = attention_mask.unsqueeze(-1).expand_as(token_states).float()
        totals = (token_states * weights).sum(1)
        # Clamp so an all-zero mask cannot divide by zero.
        counts = weights.sum(1).clamp(min=1e-9)
        return totals / counts

    def forward(self, input_ids, attention_mask, project=True):
        """Return ``(logits, features)``.

        ``features`` is the projected vector when ``project`` is true, otherwise the
        raw mean-pooled encoder output.
        """
        encoder_out = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled = self._masked_mean(encoder_out.last_hidden_state, attention_mask)
        logits = self.classifier(pooled)
        if project:
            return logits, self.projection(pooled)
        return logits, pooled

# LoRA adapter configuration: rank 8, alpha 16, targeting modules named query/key/value.
lora_cfg = LoraConfig(
    r=8, lora_alpha=16, target_modules=["query", "key", "value"], lora_dropout=0.05, bias="none"
)

bert_model = BertProjector(MODEL_NAME, num_classes, GNN_DIM).to(DEVICE)

# Wrap only the encoder with PEFT; the classifier and projection heads stay outside the wrapper.
try:
    bert_model.bert = get_peft_model(bert_model.bert, lora_cfg)
except Exception as e:
    # Retry with a second set of module names (q_lin/k_lin/v_lin) when the first set fails.
    logger.warning(f"PEFT failed, trying fallback targets: {e}")
    lora_cfg = LoraConfig(r=8, lora_alpha=16, target_modules=["q_lin", "k_lin", "v_lin"], lora_dropout=0.05, bias="none")
    bert_model.bert = get_peft_model(bert_model.bert, lora_cfg)

# Purely informational; failure to print must not stop the run.
try:
    bert_model.bert.print_trainable_parameters()
except Exception:
    logger.info("(PEFT) print_trainable_parameters not available")

# =============================
# Losses
# =============================
def contrastive_loss(emb, labels, margin=0.5):
    """Cosine-similarity contrastive loss over every ordered pair in the batch.

    Args:
        emb: Embedding tensor with one row per sample.
        labels: Class label per sample.
        margin: Similarity above which a different-label pair is penalised.

    Returns:
        Scalar tensor: same-label pairs contribute ``1 - sim``, different-label pairs
        contribute ``relu(sim - margin)``, and the sum is divided by the total number
        of pairs (self-pairs included).
    """
    pairwise_sim = F.cosine_similarity(emb.unsqueeze(1), emb.unsqueeze(0), dim=-1)
    same_label = labels.unsqueeze(1).eq(labels.unsqueeze(0)).float()
    diff_label = 1.0 - same_label

    pull_term = (same_label * (1.0 - pairwise_sim)).sum()
    push_term = (diff_label * torch.relu(pairwise_sim - margin)).sum()
    n_pairs = (same_label.sum() + diff_label.sum()).clamp(min=1.0)
    return (pull_term + push_term) / n_pairs

class LabelSmoothingCrossEntropy(nn.Module):
    """Cross-entropy against a smoothed target distribution.

    The true class receives probability ``1 - smoothing``; the remaining
    ``smoothing`` mass is spread evenly over the other classes.
    """

    def __init__(self, smoothing=0.1):
        super().__init__()
        self.smoothing = smoothing

    def forward(self, logits, target):
        """Return the batch-mean smoothed cross-entropy for ``logits`` and integer ``target``."""
        log_probs = F.log_softmax(logits, dim=-1)
        class_count = logits.size(-1)
        # max(1, ...) guards the single-class case against division by zero.
        off_target_mass = self.smoothing / max(1, (class_count - 1))
        on_target_mass = 1.0 - self.smoothing
        with torch.no_grad():
            soft_targets = torch.full_like(log_probs, off_target_mass)
            soft_targets.scatter_(1, target.unsqueeze(1), on_target_mass)
        per_sample = torch.sum(-soft_targets * log_probs, dim=-1)
        return per_sample.mean()

# =============================
# BERT fine-tuning
# =============================
class TextDataset(Dataset):
    """Minimal map-style dataset that serves raw text items by position."""

    def __init__(self, texts):
        self.texts = texts

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, i):
        return self.texts[i]

def get_scheduler_with_warmup(optimizer, num_warmup_steps, num_training_steps):
    """Create a linear warm-up followed by a linear decay schedule.

    The learning-rate multiplier rises from 0 to 1 over ``num_warmup_steps`` and then
    falls linearly to 0 at ``num_training_steps`` (never going below 0).

    Returns:
        A ``LambdaLR`` bound to ``optimizer``.
    """
    warmup_span = float(max(1, num_warmup_steps))
    decay_span = float(max(1, num_training_steps - num_warmup_steps))

    def lr_lambda(current_step: int):
        in_warmup = current_step < num_warmup_steps
        if in_warmup:
            return float(current_step) / warmup_span
        remaining = float(num_training_steps - current_step)
        return max(0.0, remaining / decay_span)

    return LambdaLR(optimizer, lr_lambda)

def fine_tune_bert(model, texts, labels, batch_size=BERT_BATCH_SIZE, cont_weight=CONT_WEIGHT, accum_steps=ACCUM_STEPS):
    """Fine-tune ``model`` as a text classifier and checkpoint it after every epoch.

    An 80/20 stratified train/validation split is drawn with ``SEED`` (no validation
    set when there is a single class). Training uses mixed precision on CUDA,
    gradient accumulation, gradient clipping at norm 1.0 and a linear
    warm-up/decay schedule.

    Args:
        model: Module returning ``(logits, features)`` for
            ``model(input_ids, attention_mask, project=...)``.
        texts: Sequence of input strings.
        labels: 1-D tensor of integer class ids aligned with ``texts``.
        batch_size: Texts per forward pass.
        cont_weight: Weight of the contrastive term; ``0`` skips its computation.
        accum_steps: Batches accumulated per optimizer step.

    Returns:
        ``(losses, accs)``: per-epoch mean training loss and validation accuracy.

    Side effects:
        Writes the full ``state_dict`` to ``MODEL_SAVE_PATH/bert_epoch{N}.pt`` after
        every epoch (and to ``DRIVE_PATH`` when ``SAVE_TO_DRIVE`` is set).
    """
    trainable = [p for p in model.parameters() if p.requires_grad]
    optimizer = AdamW(trainable, lr=BERT_LR, weight_decay=1e-2, eps=1e-8)

    indices = np.arange(len(texts))
    y_np = labels.cpu().numpy()
    if len(np.unique(y_np)) > 1:
        train_idx, val_idx = train_test_split(indices, test_size=0.2, random_state=SEED, stratify=y_np)
    else:
        train_idx, val_idx = indices, []

    # Count optimizer steps with ceiling division: the loop below also steps on the
    # final, possibly partial, accumulation group of every epoch. Floor division
    # under-counts and drives the learning rate to zero before training ends.
    batches_per_epoch = (len(train_idx) + batch_size - 1) // batch_size
    steps_per_epoch = (batches_per_epoch + accum_steps - 1) // accum_steps
    total_steps = max(1, steps_per_epoch * EPOCHS_BERT)
    warmup_steps = max(1, int(0.10 * total_steps))
    scheduler = get_scheduler_with_warmup(optimizer, warmup_steps, total_steps)

    criterion = nn.CrossEntropyLoss()
    # Mixed precision only on CUDA; on CPU both objects become no-ops.
    use_amp = DEVICE.type == "cuda"
    scaler = torch.amp.GradScaler("cuda", enabled=use_amp)

    def encode(batch_texts):
        """Tokenize a list of strings to fixed-length tensors on DEVICE."""
        enc = tokenizer(batch_texts, truncation=True, padding='max_length', max_length=MAX_LENGTH, return_tensors='pt')
        return {k: v.to(DEVICE) for k, v in enc.items()}

    losses, accs = [], []
    for epoch in range(EPOCHS_BERT):
        model.train()
        np.random.shuffle(train_idx)
        epoch_loss = 0.0
        batch_count = 0
        prog = tqdm(range(0, len(train_idx), batch_size), desc=f"BERT Epoch {epoch+1}/{EPOCHS_BERT}")
        for i in prog:
            batch_idx = train_idx[i:i+batch_size]
            batch_txt = [texts[j] for j in batch_idx]
            batch_y = torch.tensor(y_np[batch_idx], dtype=torch.long, device=DEVICE)
            enc = encode(batch_txt)

            with torch.autocast(device_type=DEVICE.type, enabled=use_amp):
                logits, proj = model(enc['input_ids'], enc['attention_mask'], project=True)
                cls_loss = criterion(logits, batch_y)
                cont_loss = contrastive_loss(proj, batch_y) if cont_weight > 0 else torch.tensor(0.0, device=DEVICE)
                loss = cls_loss + cont_weight * cont_loss
                # Scale so the accumulated gradient is the mean over the group.
                loss = loss / accum_steps

            scaler.scale(loss).backward()

            is_last_batch = (i + batch_size) >= len(train_idx)
            if ((batch_count + 1) % accum_steps) == 0 or is_last_batch:
                # Unscale first so clipping sees the true gradient magnitudes.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(trainable, max_norm=1.0)
                scaler.step(optimizer)
                scaler.update()
                scheduler.step()
                optimizer.zero_grad()

            batch_loss = loss.item() * accum_steps
            epoch_loss += batch_loss
            batch_count += 1
            prog.set_postfix(loss=batch_loss)
            del enc, logits, proj, batch_y

        # Release cached memory once per epoch; doing it per batch only adds
        # synchronisation overhead because batch shapes are fixed.
        gc.collect(); torch.cuda.empty_cache()

        avg_loss = epoch_loss / max(1, batch_count)
        losses.append(avg_loss)

        # validation
        model.eval()
        all_pred, all_true = [], []
        with torch.no_grad(), torch.autocast(device_type=DEVICE.type, enabled=use_amp):
            for i in range(0, len(val_idx), batch_size):
                idx = val_idx[i:i+batch_size]
                txt = [texts[j] for j in idx]
                y = y_np[idx]
                enc = encode(txt)
                logits, _ = model(enc['input_ids'], enc['attention_mask'], project=False)
                pred = logits.argmax(dim=1).cpu().numpy()
                all_pred.extend(pred); all_true.extend(y)
                del enc, logits
        gc.collect(); torch.cuda.empty_cache()

        acc = accuracy_score(all_true, all_pred) if all_true else 0.0
        f1 = f1_score(all_true, all_pred, average='macro') if all_true else 0.0
        accs.append(acc)
        logger.info(f"[BERT] Epoch {epoch+1} | Loss: {avg_loss:.4f} | Val Acc: {acc:.4f} | F1: {f1:.4f}")

        # checkpoint
        ckpt_path = os.path.join(MODEL_SAVE_PATH, f"bert_epoch{epoch+1}.pt")
        torch.save(model.state_dict(), ckpt_path)
        if SAVE_TO_DRIVE:
            torch.save(model.state_dict(), os.path.join(DRIVE_PATH, f"bert_epoch{epoch+1}.pt"))

    return losses, accs

# Labels in node_list order; fine_tune_bert moves them back to the CPU for splitting.
labels_tensor = torch.tensor(df_reindexed['label_enc'].values, dtype=torch.long, device=DEVICE)
bert_losses, bert_accs = fine_tune_bert(bert_model, texts, labels_tensor, cont_weight=CONT_WEIGHT)

# =============================
# Projected embeddings
# =============================
def get_projected_embeddings(model, texts, batch_size=BERT_BATCH_SIZE, workers=NUM_WORKERS):
    """Run ``model`` in eval mode over ``texts`` and collect its projected features.

    Args:
        model: Module returning ``(logits, features)`` for
            ``model(input_ids, attention_mask, project=True)``.
        texts: Sequence of input strings.
        batch_size: Texts per forward pass.
        workers: DataLoader worker processes.

    Returns:
        A float32 CPU tensor with one row per text, in input order, or an empty
        ``(0, GNN_DIM)`` tensor when ``texts`` is empty.
    """
    model.eval()
    # Batches are lists of strings, so pin_memory has nothing to pin and is left off.
    loader = DataLoader(TextDataset(texts), batch_size=batch_size, num_workers=workers)
    use_amp = DEVICE.type == "cuda"
    embs = []
    with torch.no_grad(), torch.autocast(device_type=DEVICE.type, enabled=use_amp):
        for batch in tqdm(loader, desc="Projecting SciBERT"):
            enc = tokenizer(list(batch), truncation=True, padding='max_length', max_length=MAX_LENGTH, return_tensors='pt')
            enc = {k: v.to(DEVICE) for k, v in enc.items()}
            _, proj = model(enc['input_ids'], enc['attention_mask'], project=True)
            # Move off the GPU immediately and store as float32 regardless of the autocast dtype.
            embs.append(proj.cpu().to(torch.float32))
            del enc, proj
    # One cleanup after the loop instead of one per batch.
    gc.collect(); torch.cuda.empty_cache()
    return torch.cat(embs, dim=0) if embs else torch.zeros((0, GNN_DIM))

proj_emb_cpu = get_projected_embeddings(bert_model, texts)
# L2-normalise each row; the small epsilon avoids division by zero for an all-zero row.
proj_emb_cpu = proj_emb_cpu / (proj_emb_cpu.norm(dim=1, keepdim=True) + 1e-12)
# NumPy copy/view used by recommend() for cosine similarity against queries.
proj_emb_np = proj_emb_cpu.numpy()

# =============================
# Augment graph with KNN
# =============================
def add_knn_edges_to_graph(G, node_list, embeddings_cpu, k=KNN_K):
    """Return a copy of ``G`` with each node linked to its ``k`` most similar nodes.

    Args:
        G: Base graph; it is copied, not modified.
        node_list: Node ids in the same order as the rows of ``embeddings_cpu``.
        embeddings_cpu: CPU tensor with one embedding row per node.
        k: Neighbours per node (capped at ``N - 1``).

    Returns:
        The augmented graph. Similarity is cosine similarity between embedding rows.
    """
    sims = cosine_similarity(embeddings_cpu.numpy())
    n_nodes = sims.shape[0]
    # Push self-similarity below every real score so a node never selects itself.
    np.fill_diagonal(sims, -1.0)
    n_neighbors = min(k, n_nodes - 1)

    knn_pairs = [
        (node_list[src], node_list[dst])
        for src in range(n_nodes)
        for dst in np.argpartition(-sims[src], range(n_neighbors))[:n_neighbors]
    ]

    G_aug = G.copy()
    edges_before = G_aug.number_of_edges()
    G_aug.add_edges_from(knn_pairs)
    edges_after = G_aug.number_of_edges()
    logger.info(f"Added {edges_after-edges_before} KNN edges (k={k}). Total: {edges_after}")
    return G_aug

# Augment the metadata graph with KNN edges and convert it to tensors for the GNN.
G_aug = add_knn_edges_to_graph(G_base, node_list, proj_emb_cpu, k=KNN_K)
pyg_graph = from_networkx(G_aug)
pyg_graph.y = torch.tensor(df_reindexed['label_enc'].values, dtype=torch.long, device=DEVICE)
edge_index = to_undirected(pyg_graph.edge_index).to(DEVICE)
# Node features for the GNN are the normalised SciBERT projections.
node_emb = proj_emb_cpu.to(DEVICE)

# =============================
# GNN model
# =============================
class GNN(nn.Module):
    """Two-layer GraphSAGE network with batch norm, dropout and a linear classifier."""

    def __init__(self, in_dim, hid, n_cls, dropout=0.3):
        super().__init__()
        self.conv1 = SAGEConv(in_dim, hid)
        self.bn1 = nn.BatchNorm1d(hid)
        self.conv2 = SAGEConv(hid, hid)
        self.bn2 = nn.BatchNorm1d(hid)
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(hid, n_cls)
        self.act = nn.ReLU()

    def _encode(self, x, edge, regularize):
        """Apply conv -> ReLU -> batch norm twice, adding dropout after each stage when asked."""
        for conv, norm in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            x = norm(self.act(conv(x, edge)))
            if regularize:
                x = self.dropout(x)
        return x

    def forward(self, x, edge):
        """Return class logits for every node."""
        return self.fc(self._encode(x, edge, regularize=True))

    def embed(self, x, edge):
        """Return the hidden node representations (no dropout, no classifier)."""
        return self._encode(x, edge, regularize=False)

gnn = GNN(GNN_DIM, 128, num_classes, dropout=0.3).to(DEVICE)
opt_gnn = AdamW(gnn.parameters(), lr=GNN_LR, weight_decay=1e-2)
criterion_gnn = LabelSmoothingCrossEntropy(smoothing=0.1)

# Stratified 80/20 node split; called with the same test_size, random_state and
# stratify labels as the split inside fine_tune_bert.
idx = np.arange(len(df_reindexed))
y_np = pyg_graph.y.cpu().numpy()
train_idx, test_idx = train_test_split(idx, test_size=0.2, random_state=SEED, stratify=y_np)
# Boolean node masks on DEVICE used to select train/test rows of the logits.
train_mask = torch.zeros(len(df_reindexed), dtype=torch.bool, device=DEVICE)
test_mask  = torch.zeros(len(df_reindexed), dtype=torch.bool, device=DEVICE)
train_mask[torch.tensor(train_idx, dtype=torch.long, device=DEVICE)] = True
test_mask[torch.tensor(test_idx, dtype=torch.long, device=DEVICE)] = True

# =============================
# GNN training
# =============================
# Full-batch GNN training with mixed precision on CUDA (a no-op on CPU).
use_amp_gnn = DEVICE.type == "cuda"
scaler_gnn = torch.amp.GradScaler("cuda", enabled=use_amp_gnn)
gnn_losses, gnn_accs = [], []

for epoch in range(EPOCHS_GNN):
    gnn.train()
    opt_gnn.zero_grad()
    with torch.autocast(device_type=DEVICE.type, enabled=use_amp_gnn):
        logits = gnn(node_emb, edge_index)
        loss = criterion_gnn(logits[train_mask], pyg_graph.y[train_mask])
    scaler_gnn.scale(loss).backward()
    # Gradients are still multiplied by the loss scale here; unscale them first so the
    # max_norm=1.0 threshold applies to the real gradient norm.
    scaler_gnn.unscale_(opt_gnn)
    torch.nn.utils.clip_grad_norm_(gnn.parameters(), max_norm=1.0)
    scaler_gnn.step(opt_gnn)
    scaler_gnn.update()
    gnn_losses.append(loss.item())

    # Evaluate on the held-out nodes after every epoch.
    gnn.eval()
    with torch.no_grad():
        logits = gnn(node_emb, edge_index)
        pred = logits[test_mask].argmax(dim=1).cpu().numpy()
        true = pyg_graph.y[test_mask].cpu().numpy()
        acc = accuracy_score(true, pred)
        gnn_accs.append(acc)

    if epoch % 5 == 0 or epoch == EPOCHS_GNN-1:
        logger.info(f"[GNN] Epoch {epoch+1}/{EPOCHS_GNN} | Loss: {loss.item():.4f} | Test Acc: {acc:.4f}")

    # Checkpoint every fifth epoch.
    if (epoch + 1) % 5 == 0:
        ckpt = os.path.join(MODEL_SAVE_PATH, f"gnn_epoch{epoch+1}.pt")
        torch.save(gnn.state_dict(), ckpt)
        if SAVE_TO_DRIVE:
            torch.save(gnn.state_dict(), os.path.join(DRIVE_PATH, f"gnn_epoch{epoch+1}.pt"))

torch.save(gnn.state_dict(), os.path.join(MODEL_SAVE_PATH, "gnn_final.pt"))
if SAVE_TO_DRIVE:
    torch.save(gnn.state_dict(), os.path.join(DRIVE_PATH, "gnn_final.pt"))

# =============================
# Refine embeddings & recommend
# =============================
# Compute the GNN's hidden node representations in eval mode and L2-normalise them row-wise.
gnn.eval()
with torch.no_grad():
    refined_emb = gnn.embed(node_emb, edge_index)
refined_emb_np = refined_emb.cpu().numpy().astype(np.float32)
refined_emb_np = refined_emb_np / (np.linalg.norm(refined_emb_np, axis=1, keepdims=True) + 1e-12)

def recommend(query, top_k=5):
    """Return the papers whose SciBERT projections are most similar to ``query``.

    Args:
        query: Free-text query string.
        top_k: Number of results; capped at the number of indexed papers. Values
            ``<= 0`` yield an empty result.

    Returns:
        ``(ids, scores)``: paper ids from ``node_list`` ordered by descending cosine
        similarity, and the matching similarity scores as a NumPy array.
    """
    # Guard the slice below: with k == 0, `[-0:]` would select every item instead of none.
    k = max(0, min(int(top_k), proj_emb_np.shape[0]))
    if k == 0:
        return [], np.empty(0, dtype=np.float32)

    enc = tokenizer([query], truncation=True, padding='max_length', max_length=MAX_LENGTH, return_tensors='pt')
    enc = {name: tensor.to(DEVICE) for name, tensor in enc.items()}

    # Run in eval mode so dropout cannot perturb the query embedding, then restore
    # whatever mode the model was in.
    was_training = bert_model.training
    bert_model.eval()
    try:
        with torch.no_grad():
            _, q_proj = bert_model(enc['input_ids'], enc['attention_mask'], project=True)
    finally:
        bert_model.train(was_training)

    q = q_proj.cpu().numpy().astype(np.float32)
    q = q / (np.linalg.norm(q, axis=1, keepdims=True) + 1e-12)
    sims = cosine_similarity(q, proj_emb_np)[0]
    top_idx = np.argsort(sims)[-k:][::-1]
    return [node_list[i] for i in top_idx], sims[top_idx]

def pretty_recommendation(query, ids, scores, df):
    """Format recommendation results as a human-readable multi-line string.

    Args:
        query: The query text shown in the header.
        ids: Paper ids to display, in order.
        scores: Similarity score for each id.
        df: DataFrame with the columns ``id``, ``title`` and ``abstract``.

    Returns:
        A header line followed by one entry per paper (title, the first 180
        characters of the abstract, and the score to four decimals).
    """
    by_id = df.set_index('id')

    def render(paper_id, score):
        paper = by_id.loc[paper_id]
        snippet = paper['abstract'][:180].rstrip()
        return f"  • {paper['title']}\n    Abstract: {snippet}...\n    Score: {score:.4f}\n"

    body = "\n".join(render(paper_id, score) for paper_id, score in zip(ids, scores))
    return f"\nRecommendations for \"{query}\":\n" + body

# =============================
# SAVE EVERYTHING FOR PURE RECOMMENDATION
# =============================
# Persist everything recommend() relies on, plus the graph and refined embeddings,
# into MODEL_SAVE_PATH.
save_dir = MODEL_SAVE_PATH


# Paper metadata in node_list order.
df_reindexed.to_pickle(os.path.join(save_dir, "df_reindexed.pkl"))


# Node id order shared by every per-node array.
with open(os.path.join(save_dir, "node_list.pkl"), "wb") as f:
    pickle.dump(node_list, f)


# L2-normalised SciBERT projections (CPU tensor).
torch.save(proj_emb_cpu, os.path.join(save_dir, "proj_emb_cpu.pt"))


# Undirected edge index of the KNN-augmented graph and its node count.
torch.save({
    "edge_index": edge_index.cpu(),
    "num_nodes": G_aug.number_of_nodes()
}, os.path.join(save_dir, "graph_data.pt"))


# L2-normalised GNN node representations.
torch.save(torch.from_numpy(refined_emb_np), os.path.join(save_dir, "refined_emb.pt"))


# Fitted LabelEncoder for mapping class ids back to category names.
with open(os.path.join(save_dir, "label_encoder.pkl"), "wb") as f:
    pickle.dump(le, f)

logger.info(f"All inference artefacts saved to {save_dir}")

# =============================
# Final Summary & Demo
# =============================
# Run one demo query, then print the per-epoch training history and the recommendations.
query = "advancements in graph neural networks for computer vision"
rec_ids, rec_scores = recommend(query, top_k=5)

print("\n" + "="*70)
print(" " * 20 + "TRAINING SUMMARY")
print("="*70)
print("\nSciBERT Fine-tuning (LoRA r=8)")
for ep, (l, a) in enumerate(zip(bert_losses, bert_accs), 1):
    print(f"   Epoch {ep:2d} | Loss: {l:.4f} | Val Acc: {a:.4f}")

print("\nGNN (GraphSAGE) on SciBERT + KNN graph")
# Print every fifth GNN epoch, then the final one.
for ep in range(0, len(gnn_losses), 5):
    epoch = ep + 1
    print(f"   Epoch {epoch:2d} | Loss: {gnn_losses[ep]:.4f} | Test Acc: {gnn_accs[ep]:.4f}")
print(f"   Final     | Loss: {gnn_losses[-1]:.4f} | Test Acc: {gnn_accs[-1]:.4f}")
print("="*70)
print(pretty_recommendation(query, rec_ids, rec_scores, df_reindexed))
print("="*70)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
  scaler = amp.GradScaler()


trainable params: 442,368 || all params: 110,360,832 || trainable%: 0.4008


  with amp.autocast():
BERT Epoch 1/43: 100%|██████████| 150/150 [04:20<00:00,  1.73s/it, loss=4.77]
  with torch.no_grad(), amp.autocast():
  with amp.autocast():
BERT Epoch 2/43: 100%|██████████| 150/150 [04:19<00:00,  1.73s/it, loss=4.45]
  with torch.no_grad(), amp.autocast():
  with amp.autocast():
BERT Epoch 3/43: 100%|██████████| 150/150 [04:49<00:00,  1.93s/it, loss=3.64]
  with torch.no_grad(), amp.autocast():
  with amp.autocast():
BERT Epoch 4/43: 100%|██████████| 150/150 [04:23<00:00,  1.75s/it, loss=2.98]
  with torch.no_grad(), amp.autocast():
  with amp.autocast():
BERT Epoch 5/43: 100%|██████████| 150/150 [04:22<00:00,  1.75s/it, loss=2.93]
  with torch.no_grad(), amp.autocast():
  with amp.autocast():
BERT Epoch 6/43: 100%|██████████| 150/150 [04:21<00:00,  1.75s/it, loss=2.31]
  with torch.no_grad(), amp.autocast():
  with amp.autocast():
BERT Epoch 7/43: 100%|██████████| 150/150 [04:20<00:00,  1.74s/it, loss=1.95]
  with torch.no_grad(), amp.autocast():
  with amp.au


                    TRAINING SUMMARY

SciBERT Fine-tuning (LoRA r=8)
   Epoch  1 | Loss: 4.7560 | Val Acc: 0.0017
   Epoch  2 | Loss: 4.6029 | Val Acc: 0.0816
   Epoch  3 | Loss: 4.0732 | Val Acc: 0.1849
   Epoch  4 | Loss: 3.1790 | Val Acc: 0.3931
   Epoch  5 | Loss: 2.4890 | Val Acc: 0.4705
   Epoch  6 | Loss: 2.0945 | Val Acc: 0.5065
   Epoch  7 | Loss: 1.8881 | Val Acc: 0.5257
   Epoch  8 | Loss: 1.7581 | Val Acc: 0.5479
   Epoch  9 | Loss: 1.6575 | Val Acc: 0.5692
   Epoch 10 | Loss: 1.5816 | Val Acc: 0.5780
   Epoch 11 | Loss: 1.5163 | Val Acc: 0.5926
   Epoch 12 | Loss: 1.4692 | Val Acc: 0.6035
   Epoch 13 | Loss: 1.4205 | Val Acc: 0.6056
   Epoch 14 | Loss: 1.3859 | Val Acc: 0.6136
   Epoch 15 | Loss: 1.3549 | Val Acc: 0.6161
   Epoch 16 | Loss: 1.3204 | Val Acc: 0.6211
   Epoch 17 | Loss: 1.2932 | Val Acc: 0.6290
   Epoch 18 | Loss: 1.2726 | Val Acc: 0.6278
   Epoch 19 | Loss: 1.2522 | Val Acc: 0.6332
   Epoch 20 | Loss: 1.2283 | Val Acc: 0.6399
   Epoch 21 | Loss: 1.2131 | V