설명 & 커스터마이즈 포인트

Schema view: P->A, P->C로부터 평균 집계 후, 타입 어텐션으로 결합 → z_sc. (원 논문의 “자기 자신 mask” 개념 반영: paper 자기 피처는 쓰지 않고 이웃만 사용)

Meta-path view: P–A–P과 P–C–P 2-hop 논문 이웃을 각각 평균 집계 후 시맨틱 어텐션으로 결합 → z_mp.

Projection head: 두 view 공용(가중치 공유) MLP.

Loss: 배치 내 멀티-포지티브 InfoNCE. 각 쿼리 i의 포지티브는 메타패스 이웃(PAP, PCP의 union). POS_NUM_CAP로 포지티브 상한을 둘 수 있음.

Downstream: 보통 semantic 또는 fused 임베딩((z_sc+z_mp)/2)을 추천에 사용.

In [1]:
# =========================
# [Cell 1] 설정 / 경로 / 하이퍼파라미터
# =========================
from pathlib import Path

# 디렉토리
BASE_DIR = Path("/root/heco")
ART_DIR  = BASE_DIR / "artifacts"
OUT_DIR  = BASE_DIR / "results"
OUT_DIR.mkdir(parents=True, exist_ok=True)

# 학습 하이퍼파라미터
SEED         = 42
EPOCHS       = 80
BATCH_SIZE   = 512
LR           = 0.005
WD           = 0.0001
EMBED_DIM    = 128      # encoder 출력 차원
PROJ_DIM     = 128      # projection head 출력 차원
TAU          = 0.6      # temperature
LAMBDA       = 0.5      # L = lambda * Lsc + (1-lambda)*Lmp
DROPOUT      = 0.3 
POS_NUM_CAP  = 20       # (선택) 각 노드의 positive 상한 (배치 내 교집합 기준)

# 메타패스 사용 (02에서 만든 것)
USE_PAP = True
USE_PCP = True

print("ART_DIR:", ART_DIR)
print("OUT_DIR:", OUT_DIR)


ART_DIR: /root/heco/artifacts
OUT_DIR: /root/heco/results


In [2]:
# =========================
# [Cell 2] 라이브러리 & 시드
# =========================
import json, math, random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

def set_seed(seed=SEED):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
set_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device


device(type='cuda')

In [3]:
# =========================
# [Cell 3] 아티팩트 로드 (피처/엣지/메타패스)
# =========================
def load_npz_E(path):
    with np.load(path) as data:
        if "E" in data: return data["E"]
        raise ValueError(f"{path} has no 'E' key.")

def load_npz_csr(path):
    with np.load(path) as data:
        return data["indptr"], data["indices"]

# meta.json
with open(ART_DIR/"meta.json","r") as f:
    meta = json.load(f)
num_papers   = meta["num_papers"]
num_authors  = meta["num_authors"]
num_concepts = meta["num_concepts"]

# Features
Xp = np.load(ART_DIR/"features_papers.npz")["X"].astype(np.float32)     # (N_papers, Dp)
Xa = np.load(ART_DIR/"features_authors.npz")["X"].astype(np.float32)    # (N_authors, Da)
Xc = np.load(ART_DIR/"features_concepts.npz")["X"].astype(np.float32)   # (N_concepts, Dc)

# Edges (schema용 단일 hop)
E_PA = load_npz_E(ART_DIR/"edges_PA.npz")   # paper -> author
E_PC = load_npz_E(ART_DIR/"edges_PC.npz")   # paper -> concept

# Build CSR adjacency: P->A, P->C
def build_csr(num_src, edges):
    src = edges[:,0]
    dst = edges[:,1]
    order = np.argsort(src, kind="mergesort")
    src = src[order]; dst = dst[order]
    counts = np.bincount(src, minlength=num_src)
    indptr = np.zeros(num_src+1, dtype=np.int64)
    indptr[1:] = np.cumsum(counts)
    indices = dst.astype(np.int64, copy=False)
    return indptr, indices

PA_indptr, PA_indices = build_csr(num_papers, E_PA)  # P->A
PC_indptr, PC_indices = build_csr(num_papers, E_PC)  # P->C

# Meta-path CSR
PAP_indptr, PAP_indices = (np.array([0]*(num_papers+1)), np.array([], dtype=np.int64))
PCP_indptr, PCP_indices = (np.array([0]*(num_papers+1)), np.array([], dtype=np.int64))
if USE_PAP:
    PAP_indptr, PAP_indices = load_npz_csr(ART_DIR/"metapath_PAP.npz")
if USE_PCP:
    PCP_indptr, PCP_indices = load_npz_csr(ART_DIR/"metapath_PCP.npz")

print("Xp:", Xp.shape, "Xa:", Xa.shape, "Xc:", Xc.shape)
print("P->A CSR:", PA_indptr.shape, PA_indices.shape)
print("P->C CSR:", PC_indptr.shape, PC_indices.shape)
print("PAP CSR  :", PAP_indptr.shape, PAP_indices.shape)
print("PCP CSR  :", PCP_indptr.shape, PCP_indices.shape)


Xp: (5000, 832) Xa: (32161, 32161) Xc: (6901, 768)
P->A CSR: (5001,) (78212,)
P->C CSR: (5001,) (164493,)
PAP CSR  : (5001,) (95466,)
PCP CSR  : (5001,) (24995000,)


In [4]:
# =========================
# [Cell 4] PyTorch 텐서 준비 & 배치 유틸
# =========================
# features
Xp_t = torch.from_numpy(Xp).to(device)  # papers
Xa_t = torch.from_numpy(Xa).to(device)  # authors
Xc_t = torch.from_numpy(Xc).to(device)  # concepts

# CSR to device (torch.int64)
PA_indptr_t = torch.from_numpy(PA_indptr).to(device)
PA_indices_t= torch.from_numpy(PA_indices).to(device)
PC_indptr_t = torch.from_numpy(PC_indptr).to(device)
PC_indices_t= torch.from_numpy(PC_indices).to(device)

PAP_indptr_t= torch.from_numpy(PAP_indptr).to(device)
PAP_indices_t=torch.from_numpy(PAP_indices).to(device)
PCP_indptr_t= torch.from_numpy(PCP_indptr).to(device)
PCP_indices_t=torch.from_numpy(PCP_indices).to(device)

paper_indices_all = np.arange(num_papers, dtype=np.int64)

def get_batch_indices(batch_size=BATCH_SIZE):
    # 무작위 셔플 미니배치
    perm = np.random.permutation(num_papers)
    for i in range(0, num_papers, batch_size):
        yield torch.from_numpy(perm[i:i+batch_size]).to(device)


In [5]:
# =========================
# [Cell 5] Schema / Meta-path 인코더 정의
#   - Schema: P->A, P->C 이웃 집계 + 타입 어텐션
#   - Meta-path: (PAP, PCP) 각각 집계 + 시맨틱 어텐션
# =========================
class TypeAttention(nn.Module):
    def __init__(self, in_dim):
        super().__init__()
        self.w = nn.Linear(in_dim, 1, bias=True)
    def forward(self, xs):  # list of [B, D]
        # 스택 후 [B, K, D] -> [B, K, 1] -> softmax
        X = torch.stack(xs, dim=1)              # [B, K, D]
        score = self.w(torch.tanh(X))           # [B, K, 1]
        alpha = torch.softmax(score, dim=1)     # [B, K, 1]
        out = (alpha * X).sum(dim=1)            # [B, D]
        return out, alpha.squeeze(-1)           # alpha: [B, K]

class SchemaEncoder(nn.Module):
    def __init__(self, d_paper_in, d_author_in, d_concept_in, d_out, dropout=0.0):
        super().__init__()
        self.pa_lin = nn.Linear(d_author_in, d_out)
        self.pc_lin = nn.Linear(d_concept_in, d_out)
        self.attn  = TypeAttention(d_out)
        self.dropout = nn.Dropout(dropout)

    @torch.no_grad()
    def gather_neighbors(self, indptr_t, indices_t, idx_batch):
        # indptr: [N+1], indices: [E], idx_batch: [B]
        starts = indptr_t[idx_batch]          # [B]
        ends   = indptr_t[idx_batch + 1]      # [B]
        lists = []
        for s,e in zip(starts.tolist(), ends.tolist()):
            if e <= s: lists.append([])
            else:      lists.append(indices_t[s:e].tolist())
        return lists  # list of python lists (neighbor indices)

    def agg_mean(self, X_table, lists, default_idx=None):
        # lists: length B; each is list of idx
        outs = []
        for nb in lists:
            if len(nb)==0:
                if default_idx is None:
                    outs.append(torch.zeros(X_table.size(1), device=X_table.device))
                else:
                    outs.append(X_table[default_idx])
            else:
                outs.append(X_table[nb].mean(dim=0))
        return torch.stack(outs, dim=0)  # [B, D_in]

    def forward(self, paper_idx_batch, PA_indptr, PA_indices, PC_indptr, PC_indices, Xa, Xc):
        # 이웃 리스트
        a_lists = self.gather_neighbors(PA_indptr, PA_indices, paper_idx_batch)  # P->A
        c_lists = self.gather_neighbors(PC_indptr, PC_indices, paper_idx_batch)  # P->C

        # 평균 집계 후 선형 -> dropout
        a_agg = self.pa_lin(self.agg_mean(Xa, a_lists))  # [B, d_out]
        c_agg = self.pc_lin(self.agg_mean(Xc, c_lists))  # [B, d_out]
        a_agg = self.dropout(F.elu(a_agg))
        c_agg = self.dropout(F.elu(c_agg))

        # 타입 어텐션으로 결합
        z_sc, alpha = self.attn([a_agg, c_agg])          # [B, d_out]
        z_sc = F.normalize(z_sc, p=2, dim=-1)
        return z_sc, alpha  # alpha: [B, 2]

class SemanticEncoder(nn.Module):
    def __init__(self, d_paper_in, d_out, dropout=0.0):
        super().__init__()
        self.pap_lin = nn.Linear(d_paper_in, d_out)
        self.pcp_lin = nn.Linear(d_paper_in, d_out)
        self.attn    = TypeAttention(d_out)
        self.dropout = nn.Dropout(dropout)

    @torch.no_grad()
    def gather_neighbors(self, indptr_t, indices_t, idx_batch):
        starts = indptr_t[idx_batch]
        ends   = indptr_t[idx_batch + 1]
        lists = []
        for s,e in zip(starts.tolist(), ends.tolist()):
            if e <= s: lists.append([])
            else:      lists.append(indices_t[s:e].tolist())
        return lists

    def agg_mean(self, X_paper, lists, default_idx=None):
        outs = []
        for nb in lists:
            if len(nb)==0:
                if default_idx is None:
                    outs.append(torch.zeros(X_paper.size(1), device=X_paper.device))
                else:
                    outs.append(X_paper[default_idx])
            else:
                outs.append(X_paper[nb].mean(dim=0))
        return torch.stack(outs, dim=0)

    def forward(self, paper_idx_batch, Xp, PAP_indptr, PAP_indices, PCP_indptr, PCP_indices):
        views = []
        if USE_PAP:
            pap_lists = self.gather_neighbors(PAP_indptr, PAP_indices, paper_idx_batch)
            pap_agg = self.pap_lin(self.agg_mean(Xp, pap_lists))
            pap_agg = self.dropout(F.elu(pap_agg))
            views.append(pap_agg)
        if USE_PCP:
            pcp_lists = self.gather_neighbors(PCP_indptr, PCP_indices, paper_idx_batch)
            pcp_agg = self.pcp_lin(self.agg_mean(Xp, pcp_lists))
            pcp_agg = self.dropout(F.elu(pcp_agg))
            views.append(pcp_agg)

        if len(views)==1:
            z_mp = F.normalize(views[0], p=2, dim=-1)
            alpha = torch.ones(z_mp.size(0), 1, device=z_mp.device)
            return z_mp, alpha
        else:
            z_mp, alpha = self.attn(views)           # [B, d_out]
            z_mp = F.normalize(z_mp, p=2, dim=-1)
            return z_mp, alpha  # alpha: [B, K(meta-path 개수)]


In [6]:
# =========================
# [Cell 6] Projection Head & 대조 손실(멀티 양성)
# =========================
class ProjectionHead(nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        self.fc1 = nn.Linear(d_in, d_out)
        self.act = nn.ELU()
        self.fc2 = nn.Linear(d_out, d_out)
    def forward(self, z):
        return self.fc2(self.act(self.fc1(z)))

def cosine_sim(a, b):
    # a: [B, D], b: [B, D] or [N, D]
    a = F.normalize(a, p=2, dim=-1)
    b = F.normalize(b, p=2, dim=-1)
    return a @ b.t()

@torch.no_grad()
def build_pos_mask(batch_idx, PAP_indptr, PAP_indices, PCP_indptr, PCP_indices, pos_cap=POS_NUM_CAP):
    """
    배치 내 (i,j)가 positive인지 bool mask 생성.
    positive 기준: i의 메타패스(PAP,PCP) 이웃 안에 j가 포함되면 True.
    """
    B = batch_idx.size(0)
    mask = torch.zeros(B, B, dtype=torch.bool, device=batch_idx.device)
    idx_list = batch_idx.tolist()
    idx_pos_sets = []
    for u in idx_list:
        # 두 메타패스 이웃 union
        pos_set = set()
        # PAP
        s = PAP_indptr[u].item(); e = PAP_indptr[u+1].item()
        if e > s: pos_set.update(PAP_indices[s:e].tolist())
        # PCP
        s = PCP_indptr[u].item(); e = PCP_indptr[u+1].item()
        if e > s: pos_set.update(PCP_indices[s:e].tolist())
        # 상한 (옵션)
        if (pos_cap is not None) and (len(pos_set) > pos_cap):
            # 랜덤 서브샘플
            chosen = np.random.choice(list(pos_set), size=pos_cap, replace=False).tolist()
            pos_set = set(chosen)
        idx_pos_sets.append(pos_set)

    # 배치 내 교집합 매핑
    for i, u in enumerate(idx_list):
        pos_set = idx_pos_sets[i]
        for j, v in enumerate(idx_list):
            if i==j: 
                continue  # self는 제외(원하면 포함 가능)
            if v in pos_set:
                mask[i, j] = True
    return mask  # [B,B] bool


In [7]:
# =========================
# [Cell 7] 모델 초기화
# =========================
d_paper_in   = Xp_t.size(1)
d_author_in  = Xa_t.size(1)
d_concept_in = Xc_t.size(1)

schema_enc = SchemaEncoder(d_paper_in, d_author_in, d_concept_in, d_out=EMBED_DIM, dropout=DROPOUT).to(device)
sem_enc    = SemanticEncoder(d_paper_in, d_out=EMBED_DIM, dropout=DROPOUT).to(device)
proj_head  = ProjectionHead(EMBED_DIM, PROJ_DIM).to(device)

params = list(schema_enc.parameters()) + list(sem_enc.parameters()) + list(proj_head.parameters())
opt = torch.optim.AdamW(params, lr=LR, weight_decay=WD)

print(f"paper_in={d_paper_in}, author_in={d_author_in}, concept_in={d_concept_in}, embed={EMBED_DIM}, proj={PROJ_DIM}")


paper_in=832, author_in=32161, concept_in=768, embed=128, proj=128


In [8]:
# =========================
# [Cell 8] 학습 루프
#  - cross-view InfoNCE (멀티-포지티브)
# =========================
def infoNCE_multi_pos(z_q, z_k, pos_mask, tau=TAU, eps=1e-8):
    """
    z_q: [B, D], z_k: [B, D]  (batch 기준)
    pos_mask: [B, B]  (i row의 positives: j where mask[i,j]=True)
    """
    B = z_q.size(0)
    logits = cosine_sim(z_q, z_k) / tau             # [B,B]
    # 마스크: 자기 자신 제외
    diag = torch.eye(B, dtype=torch.bool, device=z_q.device)
    logits = logits.masked_fill(diag, float('-inf'))

    # 안정성: row-wise log-sum-exp
    logsumexp = torch.logsumexp(logits, dim=1)      # [B]

    # 양성 쪽: 각 i에 대해 pos j들의 log-sum-exp (여러 양성 합)
    pos_logits = logits.masked_fill(~pos_mask, float('-inf'))
    # pos가 하나도 없으면 -inf -> exp 0 -> log 0 = -inf -> loss NaN 방지용 epsilon
    pos_logsumexp = torch.logsumexp(pos_logits, dim=1)  # [B]
    # 없으면 (모두 -inf) -> -inf 로 남음. eps 보정
    pos_logsumexp = torch.where(torch.isneginf(pos_logsumexp),
                                torch.full_like(pos_logsumexp, math.log(eps)),
                                pos_logsumexp)

    loss = -(pos_logsumexp - logsumexp).mean()
    return loss

best_loss = float('inf')
for epoch in range(1, EPOCHS+1):
    schema_enc.train(); sem_enc.train(); proj_head.train()
    epoch_loss = 0.0
    steps = 0
    for batch_idx in get_batch_indices(BATCH_SIZE):
        opt.zero_grad()
        # encodings
        z_sc, _ = schema_enc(batch_idx, PA_indptr_t, PA_indices_t, PC_indptr_t, PC_indices_t, Xa_t, Xc_t)
        z_mp, _ = sem_enc(batch_idx, Xp_t, PAP_indptr_t, PAP_indices_t, PCP_indptr_t, PCP_indices_t)
        # projection (공유)
        z_sc_p = proj_head(z_sc)
        z_mp_p = proj_head(z_mp)
        # positives mask (배치 내)
        pos_mask = build_pos_mask(batch_idx, PAP_indptr_t, PAP_indices_t, PCP_indptr_t, PCP_indices_t, pos_cap=POS_NUM_CAP)
        # loss (양방향)
        L_sc = infoNCE_multi_pos(z_sc_p, z_mp_p, pos_mask, tau=TAU)
        L_mp = infoNCE_multi_pos(z_mp_p, z_sc_p, pos_mask, tau=TAU)
        loss = LAMBDA * L_sc + (1.0 - LAMBDA) * L_mp

        loss.backward()
        nn.utils.clip_grad_norm_(params, max_norm=5.0)
        opt.step()

        epoch_loss += loss.item()
        steps += 1

    avg = epoch_loss / max(1, steps)
    if avg < best_loss:
        best_loss = avg
    print(f"[Epoch {epoch:03d}] loss={avg:.4f} (best {best_loss:.4f})")


[Epoch 001] loss=11.7585 (best 11.7585)
[Epoch 002] loss=12.0050 (best 11.7585)
[Epoch 003] loss=12.0323 (best 11.7585)
[Epoch 004] loss=11.7916 (best 11.7585)
[Epoch 005] loss=11.7887 (best 11.7585)
[Epoch 006] loss=11.7827 (best 11.7585)
[Epoch 007] loss=12.0184 (best 11.7585)
[Epoch 008] loss=11.7674 (best 11.7585)
[Epoch 009] loss=11.7970 (best 11.7585)
[Epoch 010] loss=11.8288 (best 11.7585)
[Epoch 011] loss=11.7372 (best 11.7372)
[Epoch 012] loss=11.7484 (best 11.7372)
[Epoch 013] loss=11.7402 (best 11.7372)
[Epoch 014] loss=11.7603 (best 11.7372)
[Epoch 015] loss=12.2420 (best 11.7372)
[Epoch 016] loss=12.0197 (best 11.7372)
[Epoch 017] loss=11.9637 (best 11.7372)
[Epoch 018] loss=11.8169 (best 11.7372)
[Epoch 019] loss=11.9853 (best 11.7372)
[Epoch 020] loss=11.7792 (best 11.7372)
[Epoch 021] loss=11.9113 (best 11.7372)
[Epoch 022] loss=11.8868 (best 11.7372)
[Epoch 023] loss=11.8964 (best 11.7372)
[Epoch 024] loss=11.8030 (best 11.7372)
[Epoch 025] loss=11.7706 (best 11.7372)


In [9]:
# =========================
# [Cell 9] 최종 임베딩 추출 & 저장
# =========================
schema_enc.eval(); sem_enc.eval(); proj_head.eval()

# 전체를 배치로 나눠 인코딩
def encode_all():
    Z_sc_list, Z_mp_list = [], []
    with torch.no_grad():
        for i in range(0, num_papers, BATCH_SIZE):
            batch_idx = torch.arange(i, min(i+BATCH_SIZE, num_papers), device=device)
            z_sc, _ = schema_enc(batch_idx, PA_indptr_t, PA_indices_t, PC_indptr_t, PC_indices_t, Xa_t, Xc_t)
            z_mp, _ = sem_enc(batch_idx, Xp_t, PAP_indptr_t, PAP_indices_t, PCP_indptr_t, PCP_indices_t)
            Z_sc_list.append(z_sc)
            Z_mp_list.append(z_mp)
    Z_sc = torch.cat(Z_sc_list, dim=0)    # [N, D]
    Z_mp = torch.cat(Z_mp_list, dim=0)    # [N, D]
    # 보통 downstream은 meta-path view(또는 concat) 사용
    Z_fuse = F.normalize((Z_sc + Z_mp)/2, p=2, dim=-1)
    return Z_sc, Z_mp, Z_fuse

Z_sc, Z_mp, Z = encode_all()

np.save(OUT_DIR/"paper_embeddings_schema.npy", Z_sc.detach().cpu().numpy())
np.save(OUT_DIR/"paper_embeddings_semantic.npy", Z_mp.detach().cpu().numpy())
np.save(OUT_DIR/"paper_embeddings_fused.npy",    Z.detach().cpu().numpy())

print("Saved:",
      OUT_DIR/"paper_embeddings_schema.npy",
      OUT_DIR/"paper_embeddings_semantic.npy",
      OUT_DIR/"paper_embeddings_fused.npy")


Saved: /root/heco/results/paper_embeddings_schema.npy /root/heco/results/paper_embeddings_semantic.npy /root/heco/results/paper_embeddings_fused.npy


In [10]:
# =========================
# [Cell 10] (선택) 근접 탐색 샘플
# =========================
import pandas as pd

paper_ids = pd.read_csv(ART_DIR/"map_paper_id.csv")["paper_id"].tolist()
Z_np = np.load(OUT_DIR/"paper_embeddings_fused.npy")

def topk_neighbors(i, k=5):
    v = Z_np[i]
    sims = Z_np @ v
    sims[i] = -1e9
    idx = np.argpartition(-sims, k)[:k]
    idx = idx[np.argsort(-sims[idx])]
    return idx, sims[idx]

test_idx = 0  # 예시
idxs, scs = topk_neighbors(test_idx, k=5)
print("query paper_idx:", test_idx, "paper_id:", paper_ids[test_idx] if test_idx < len(paper_ids) else "NA")
print("top-5 idx:", idxs)
print("top-5 id :", [paper_ids[i] for i in idxs])
print("sims     :", [float(s) for s in scs])


query paper_idx: 0 paper_id: https://openalex.org/W3010906965
top-5 idx: [4517 4382 3820  112  525]
top-5 id : ['https://openalex.org/W2936033307', 'https://openalex.org/W3177009667', 'https://openalex.org/W3048565185', 'https://openalex.org/W3020097213', 'https://openalex.org/W3128551684']
sims     : [0.9879460334777832, 0.9831452369689941, 0.9826983213424683, 0.9824612140655518, 0.9810717105865479]
