In [None]:
from sentence_transformers import SentenceTransformer
import faiss
import orjson
import json
import warnings

# Silence warnings whose message matches the "IProgress not found" pattern below,
# so they do not clutter the notebook output.
warnings.filterwarnings("ignore", message="IProgress not found.*")

In [None]:
# Load the sentence encoder on the CUDA device and convert it with .half().
encoder_name = "all-mpnet-base-v2"
model = SentenceTransformer(encoder_name, device="cuda")
model.half()

# Read the FAISS index from disk and report how many vectors it contains.
index_path = "news_mpnet.index"
index = faiss.read_index(index_path)
vector_count = index.ntotal
print("Number of vectors in the index:", vector_count)

In [None]:
# Corpus records: one JSON object per line of tdt.jsonl, reduced to the
# "text" and "date" fields and kept in file order.
sentences = []
with open("tdt.jsonl", 'rb') as corpus_file:
    sentences.extend(
        {"text": parsed["text"], "date": parsed["date"]}
        for parsed in map(orjson.loads, corpus_file)
    )

In [None]:
# Labels of interest.
# NOTE(review): as written, the list holds the string "xx" followed by a literal
# Ellipsis object (`...`), which looks like a placeholder — confirm the real
# label values before running.
da_labels = ["xx", ...]

In [None]:
# Training records that carry at least one label listed in da_labels.
da_texts = []

with open("train.jsonl", 'r', encoding='utf-8') as train_file:
    for raw_line in train_file:
        entry = orjson.loads(raw_line.strip())

        # Read every field up front so a record missing one of them fails
        # here, whether or not any of its labels match.
        body = entry['text']
        entry_labels = entry['label']
        headline = entry['title']
        # Rewrite the date separator from "-" to "/".
        slash_date = entry['date'].replace("-", "/")

        # Keep only the labels of interest, preserving their original order.
        matched_labels = [tag for tag in entry_labels if tag in da_labels]
        if not matched_labels:
            continue

        da_texts.append({"labels": matched_labels, "title": headline, "date": slash_date, "text": body})

In [None]:
# Drop records whose "text" was already encountered; the first occurrence wins.
seen = set()
unique_da_texts = []
for candidate in da_texts:
    body = candidate["text"]
    if body in seen:
        continue
    seen.add(body)
    unique_da_texts.append(candidate)

In [None]:
# Write the de-duplicated query records as JSON Lines (one object per line).
# The file is opened as UTF-8 explicitly: ensure_ascii=False emits non-ASCII
# characters verbatim, so relying on the platform's default encoding could
# raise UnicodeEncodeError or produce a file that is not valid UTF-8.
with open("query_data.jsonl", "w", encoding="utf-8") as f:
    for item in unique_da_texts:
        json.dump(item, f, ensure_ascii=False)
        f.write("\n")

In [None]:
# 1. Flatten the records into three parallel lists (texts, label lists, dates)
#    that share the same ordering as unique_da_texts.
all_texts = [entry["text"] for entry in unique_da_texts]
labels = [entry["labels"] for entry in unique_da_texts]
dates = [entry["date"] for entry in unique_da_texts]

In [None]:
batch_size = 1024  # Adjust according to GPU memory

# Options forwarded to model.encode: show a progress bar, return a NumPy
# array, and normalize the embeddings.
encode_options = dict(
    batch_size=batch_size,
    show_progress_bar=True,
    convert_to_numpy=True,
    normalize_embeddings=True,
)
query_embeddings = model.encode(all_texts, **encode_options)
print("Query embedding dimensions:", query_embeddings.shape)

In [None]:
k = 10  # number of nearest neighbours requested per query

# FAISS operates on float32 vectors. The encoder was switched to half
# precision with model.half(), so the embeddings may come back as float16;
# cast explicitly rather than depending on the FAISS wrapper to convert.
# copy=False avoids a copy when the array is already float32.
search_vectors = query_embeddings.astype("float32", copy=False)
distances, indices = index.search(search_vectors, k)
print("Search completed, distances shape:", distances.shape)

In [None]:
# Collect (query, follow) pairs: for each query, every returned neighbour whose
# score is at least 0.8 becomes one result record.
# NOTE(review): the >= 0.8 cut-off treats the values returned by index.search as
# similarity scores (higher = closer) — confirm the index metric.
# NOTE(review): assumes index ids line up with positions in `sentences` — confirm
# the index was built from tdt.jsonl in the same order.
results = []

for i, query in enumerate(all_texts):
    date = dates[i]
    label = labels[i]
    for j in range(k):
        neighbor_id = int(indices[i][j])
        # FAISS fills missing neighbours with id -1. Without this guard,
        # sentences[-1] would silently return the last corpus entry.
        if neighbor_id < 0:
            continue
        if distances[i][j] >= 0.8:
            neighbor = sentences[neighbor_id]  # look the record up once
            results.append({
                "scores": float(distances[i][j]),
                "labels": label,
                "query_date": date,
                "follow_date": neighbor["date"],
                "query": query,
                "follow": neighbor["text"],
            })


print("Found", len(results), "results")

In [None]:
# Keep one result per distinct "follow" text; the first one in `results`
# order wins. Results whose "follow" is missing or empty are dropped.
seen = set()
deduped_data = []

for hit in results:
    follow_text = hit.get("follow")
    if not follow_text or follow_text in seen:
        continue
    seen.add(follow_text)
    deduped_data.append(hit)

In [None]:
from datetime import datetime

def is_within_two_years(query_date_str, follow_date_str, max_years=3):
    """Check that the follow date falls after the query date, within a window.

    NOTE(review): despite the function name, the window has always been
    three years (3 * 365 days). That behavior is kept as the default of
    ``max_years``; pass ``max_years=2`` for the window the name suggests,
    or rename the function once the intended value is confirmed.

    Args:
        query_date_str: Query date formatted as "YYYY/MM/DD".
        follow_date_str: Follow-up date formatted as "YYYY/MM/DD".
        max_years: Window length in years, counted as 365 days each
            (leap days are not accounted for).

    Returns:
        True if the follow date is strictly later than the query date and
        at most ``max_years * 365`` days later; False otherwise.

    Raises:
        ValueError: If either string does not match the "%Y/%m/%d" format.
    """
    date_format = "%Y/%m/%d"
    query_date = datetime.strptime(query_date_str, date_format)
    follow_date = datetime.strptime(follow_date_str, date_format)
    delta = (follow_date - query_date).days
    # Same-day and earlier follow dates are rejected (delta must be positive).
    return 0 < delta <= max_years * 365

# Keep only the de-duplicated results whose (query_date, follow_date) pair
# passes is_within_two_years.
filtered_results = []
for record in deduped_data:
    if is_within_two_years(record["query_date"], record["follow_date"]):
        filtered_results.append(record)

In [None]:
import json

# Persist the filtered results as UTF-8 JSON Lines, one record per line,
# leaving non-ASCII characters unescaped.
with open("results.jsonl", 'w', encoding='utf-8') as out_file:
    for row in filtered_results:
        out_file.write(json.dumps(row, ensure_ascii=False))
        out_file.write('\n')