In [1]:
from pathlib import Path
import sys

ROOT = Path().resolve().parents[1] / "code"
sys.path.append(str(ROOT))
import os
import polars as pl
from functions.read_news import read_news
from sentence_transformers import SentenceTransformer
from tqdm import tqdm
import numpy as np
import pickle


  from .autonotebook import tqdm as notebook_tqdm


In [4]:
def build_news_embeddings(news_path,save_dir,model_path='/home/ming/GraduateProject/model/all-mpnet-base-v2',batch_size=64,device="cuda"):
    """Embed every news article's title + abstract and persist the result.

    Writes two files into ``save_dir`` (created if missing):

    * ``news_emb.npy``  - the array returned by ``model.encode`` (one row per article).
    * ``news_ids.pkl``  - pickled list of ``news_id`` values, in the same row order.

    Args:
        news_path: Parquet file containing ``news_id``, ``title`` and
            ``abstract`` columns. When falsy, the frame is obtained from
            ``read_news()`` instead.
        save_dir: Output directory for the two files above.
        model_path: Path or name handed to ``SentenceTransformer``.
        batch_size: Batch size forwarded to ``model.encode``.
        device: Device string forwarded to ``SentenceTransformer``; defaults
            to ``"cuda"`` as before.

    Returns:
        None. The shape of the embedding array is printed.
    """
    # Single source of truth for the columns: the id column read here is the
    # same one used for ``news_ids.pkl`` below.
    id_col = 'news_id'
    columns = [id_col, 'title', 'abstract']

    os.makedirs(save_dir, exist_ok=True)
    if news_path:
        news = pl.read_parquet(news_path).select(columns)
    else:
        # Fix: this branch previously selected 'new_id', which cannot agree
        # with the news['news_id'] lookup further down.
        # NOTE(review): assumes read_news() exposes the same column names as
        # the parquet file - confirm.
        news = read_news().select(columns)

    # Missing titles/abstracts become empty strings so the concatenation
    # never produces a null text.
    texts = (
        news['title'].fill_null("") + " " + news['abstract'].fill_null("")
    ).to_list()

    model = SentenceTransformer(model_path, device=device)

    # normalize_embeddings=True asks the encoder for unit-length vectors.
    emb = model.encode(
        texts,
        batch_size=batch_size,
        show_progress_bar=True,
        normalize_embeddings=True,
    )

    article_ids = news[id_col].to_list()

    np.save(os.path.join(save_dir, 'news_emb.npy'), emb)
    with open(os.path.join(save_dir, 'news_ids.pkl'), 'wb') as f:
        pickle.dump(article_ids, f)
    print("Saved news embeddings:", emb.shape)

In [5]:
# Encode the articles stored in the news parquet file and write the
# embedding matrix plus the id list into the content_emb directory.
build_news_embeddings(
    save_dir="/home/ming/GraduateProject/Data/content_emb",
    news_path='/home/ming/GraduateProject/Data/news.parquet',
)

Batches: 100%|██████████| 1628/1628 [01:30<00:00, 18.07it/s]


Saved news embeddings: (104151, 768)


In [4]:
from cuml.neighbors import NearestNeighbors
import cupy as cp

In [5]:
import numpy as np
import pickle
# Read the saved embedding matrix from disk and hand it to CuPy in one step.
embeddings = cp.asarray(
    np.load("/home/ming/GraduateProject/Data/content_emb/news_emb.npy")
)

# Article ids saved alongside the embeddings.
# NOTE: pickle.load can execute arbitrary code - only load files you produced yourself.
with open("/home/ming/GraduateProject/Data/content_emb/news_ids.pkl", "rb") as f:
    article_ids = pickle.load(f)

In [6]:
# Nearest-neighbour index over the article embeddings (cosine metric).
# One more neighbour is requested than the number of similar articles to
# keep, because querying the index with its own rows is expected to return
# each row among its own neighbours. (The previous comment said 21; the code
# has always requested 51.)
N_SIMILAR = 50
knn = NearestNeighbors(n_neighbors=N_SIMILAR + 1, metric='cosine')
knn.fit(embeddings)

In [9]:
# Query the index with the corpus itself; each row of the result holds the
# indices of that article's nearest neighbours.
_, aid_nns = knn.kneighbors(embeddings)
if hasattr(aid_nns, "get"):
    # Device array (e.g. CuPy) -> copy to host memory.
    aid_nns = aid_nns.get()
aid_nns = np.asarray(aid_nns)

# Drop each article's own index from its neighbour row. Blindly slicing off
# column 0 is only right when the query is always ranked first; with duplicate
# or near-identical embeddings a different article can take that slot, which
# would leave the article inside its own similar list.
n_rows, n_cols = aid_nns.shape
self_mask = aid_nns == np.arange(n_rows)[:, None]
# If ties pushed the article out of its own row entirely, drop the farthest
# neighbour instead so that every row keeps the same width.
self_mask[~self_mask.any(axis=1), -1] = True
aid_nns = aid_nns[~self_mask].reshape(n_rows, n_cols - 1)
aid_nns

array([[ 52566,  98492,  85713, ...,  88761,  53247,  32211],
       [ 60199,  51812,  80970, ...,  96786,  53322,  96275],
       [  1598,    752,  39145, ...,  72342,  33442,   8059],
       ...,
       [ 87700,  27671,  53866, ...,  86730,  39890,   3249],
       [ 95757,  85122,  19645, ...,  59065,  13007,  77024],
       [ 40184, 100577,  45128, ..., 103189,   4260,  69751]],
      shape=(104151, 50))

In [10]:
# Map every article id to the ids of its neighbours, translating the row
# positions stored in aid_nns back into article ids.
news_sim_dict = {
    source_id: [article_ids[int(pos)] for pos in aid_nns[row]]
    for row, source_id in enumerate(article_ids)
}

# The index matrix is no longer needed once the id mapping exists.
del aid_nns

In [11]:
# Persist the similarity mapping. The target directory is relative to the
# notebook's working directory and may not exist yet, so create it first
# (same pattern as build_news_embeddings uses for its save_dir).
os.makedirs("recall_results", exist_ok=True)
with open("recall_results/news_content_sim.pkl", "wb") as f:
    pickle.dump(news_sim_dict, f)