In [1]:
import os
import sys
# Add the parent directory to the import path and make it the working directory.
# The `".." not in sys.path` test doubles as a run-once guard: when this cell is
# re-run in the same kernel, ".." is already present, so no second chdir happens.
# NOTE(review): ".." is a relative entry, so after the chdir it no longer names
# the directory it named when it was appended — presumably `from src import ...`
# resolves through the new working directory instead; confirm.
if ".." not in sys.path:
    sys.path.append("..")
    os.chdir("..")

import gc
import datasets
import json
import string
import torch
import random
import numpy as np
from glob import glob

from transformers import AutoModelForCausalLM, AutoTokenizer
from tqdm import tqdm
from safetensors.torch import load_file, save_file, safe_open
from src import Data, Metrics, ModelArgs, get_model_and_tokenizer

# Build the project's ModelArgs without passing any overrides.
# NOTE(review): `args` is not referenced by the cells visible here — confirm it
# is still needed.
args = ModelArgs()

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Directory that holds the sampled query/doc parquet shards.
# NOTE(review): absolute, machine-specific path — consider moving it to a config cell.
base_dir = "/mnt/bn/search-douyin-rank-yg/all_data_from_lf/text_embedding_data/query_doc_info_sample_order_1028_v4"

# Number of leading shards (part-00000, part-00001, ...) to load.
num_parts = 2

# Collect the shard files index by index. Each glob result is sorted because glob
# returns matches in arbitrary order, which would otherwise let the row order
# (and everything derived from it) vary between runs.
data_files = [
    path
    for i in range(num_parts)
    for path in sorted(glob(f"{base_dir}/part-{i:05d}-*.parquet"))
]
dataset = datasets.load_dataset("parquet", data_files=data_files, split="train")

In [6]:
from multiprocessing import cpu_count
from datasets import Dataset
from concurrent.futures import ProcessPoolExecutor, as_completed
from glob import glob
import json
import os
import random
from tqdm import tqdm


def collate_query_pos_neg_from_impression(impression, doc_as_query_portion: float = 0):
    """Turn the rows of one search impression into a query/positives/negatives sample.

    Args:
        impression: Sequence of row dicts belonging to a single search. Each row is
            read for the keys "query", "search_id", "doc_info", "doc_id",
            "position", "search_result_click_cnt" and "play_time_max".
        doc_as_query_portion: Probability of replacing the query with the text of
            the first clicked document, which is then removed from the positives.
            With the default ``0`` no random number is drawn.

    Returns:
        A dict with keys "search_id", "query", "pos", "neg", "pos_ids" and
        "neg_ids" holding at most three positives and three negatives, or
        ``None`` when the impression is empty, contains no clicked document, or
        does not yield at least one positive and one negative.

    Raises:
        AssertionError: If the rows do not all carry the same query.
    """
    # An empty impression cannot produce a sample; treat it like every other
    # unusable impression instead of failing on impression[0].
    if len(impression) == 0:
        return None

    first_row = impression[0]
    query = first_row["query"]
    assert all(row["query"] == query for row in impression)
    search_id = first_row["search_id"]

    strong_pos = []           # clicked documents
    pos = []                  # not clicked, but play_time_max above 10000
    strong_pos_position = []  # positions of the clicked documents
    neg_candidates = []       # everything else; classified once clicks are known

    for row in impression:
        entry = (row["doc_info"], row["doc_id"])
        if row["search_result_click_cnt"] > 0:
            strong_pos.append(entry)
            strong_pos_position.append(row["position"])
        elif row["play_time_max"] > 10000:
            pos.append(entry)
        else:
            neg_candidates.append(row)

    # Without at least one click there is no anchor for the sample.
    if not strong_pos:
        return None

    min_click_position = min(strong_pos_position)
    strong_neg = []  # low play time and positioned before the first clicked document
    neg = []         # low play time, positioned at or after the first clicked document
    for row in neg_candidates:
        # Candidates with play_time_max in [3000, 10000] are neither positive nor
        # negative and are dropped.
        if row["play_time_max"] < 3000:
            entry = (row["doc_info"], row["doc_id"])
            if row["position"] < min_click_position:
                strong_neg.append(entry)
            else:
                neg.append(entry)

    # Optionally use the first clicked document's text as the query; that
    # document no longer counts as a positive.
    if doc_as_query_portion > 0 and random.uniform(0, 1) <= doc_as_query_portion:
        query, _ = strong_pos.pop(0)

    # Stronger signals come first, so truncation keeps them preferentially.
    result_pos = (strong_pos + pos)[:3]
    result_neg = (strong_neg + neg)[:3]
    if not result_pos or not result_neg:
        return None

    return {
        "search_id": search_id,
        "query": query,
        "pos": [text for text, _ in result_pos],
        "neg": [text for text, _ in result_neg],
        "pos_ids": [doc_id for _, doc_id in result_pos],
        "neg_ids": [doc_id for _, doc_id in result_neg],
    }


def process_impression_data(batch, doc_as_query_portion):
    """Collate every impression in ``batch`` and keep only the usable samples.

    Args:
        batch: Iterable of impressions, each passed unchanged to
            ``collate_query_pos_neg_from_impression``.
        doc_as_query_portion: Forwarded to the collate function for every impression.

    Returns:
        List of the collate results that are not ``None``, in input order.
    """
    collated = (
        collate_query_pos_neg_from_impression(rows, doc_as_query_portion)
        for rows in batch
    )
    return [sample for sample in collated if sample is not None]


def generate_impressions(dataset, doc_as_query_portion, num_workers, max_batches_in_memory=10):
    """Yield collated samples built from rows grouped by consecutive ``search_id``.

    Rows are read in order; a change of ``search_id`` closes the current
    impression. Closed impressions are buffered and, once
    ``max_batches_in_memory`` of them have accumulated, submitted as one task
    running ``process_impression_data`` in a process pool. Samples are yielded
    only after the whole dataset has been scanned, in task completion order.

    Args:
        dataset: Iterable of row dicts, each providing a "search_id" key.
        doc_as_query_portion: Forwarded to ``process_impression_data``.
        num_workers: Number of worker processes in the pool.
        max_batches_in_memory: Number of buffered impressions that triggers a
            task submission.

    Yields:
        Every sample contained in the lists returned by the worker tasks.

    NOTE(review): grouping only looks at adjacent rows, so it assumes that all
    rows of one search are contiguous in ``dataset`` — confirm upstream ordering.
    NOTE(review): every submitted future stays referenced until the end, so
    ``max_batches_in_memory`` sizes the tasks rather than capping total memory.
    """
    buffered_impressions = []
    open_impression = []
    last_search_id = None
    pending_tasks = []

    with ProcessPoolExecutor(max_workers=num_workers) as pool:
        for row in tqdm(dataset):
            row_search_id = row["search_id"]
            starts_new_impression = row_search_id != last_search_id and last_search_id is not None

            if starts_new_impression:
                # The previous impression is complete; move it into the buffer.
                buffered_impressions.append(open_impression)
                open_impression = []

                # Enough impressions buffered: hand them to a worker and start a new buffer.
                if len(buffered_impressions) >= max_batches_in_memory:
                    task = executor_task = pool.submit(
                        process_impression_data, buffered_impressions, doc_as_query_portion
                    )
                    pending_tasks.append(task)
                    buffered_impressions = []

            open_impression.append(row)
            last_search_id = row_search_id

        # Flush whatever is left after the last row.
        if open_impression:
            buffered_impressions.append(open_impression)
        if buffered_impressions:
            pending_tasks.append(
                pool.submit(process_impression_data, buffered_impressions, doc_as_query_portion)
            )

        # Stream the results back as the worker tasks finish.
        for finished in tqdm(as_completed(pending_tasks)):
            for sample in finished.result():
                yield sample

# Seed Python's `random` module in this (parent) process.
# NOTE(review): the pool workers are not seeded explicitly. With the
# doc_as_query_portion of 0 passed below, the random draw in
# collate_query_pos_neg_from_impression is short-circuited, so this only
# matters once a non-zero portion is used — confirm before relying on it.
random.seed(0)

# Use as many worker processes as cpu_count() reports.
num_workers = cpu_count()
# Single-worker override, left disabled (presumably for debugging):
# num_workers = 1
# print(num_workers)
# Build the collated dataset from the generator; each task submitted to the pool
# carries num_workers * 2 impressions.
new_dataset = Dataset.from_generator(lambda: generate_impressions(dataset, 0, num_workers, max_batches_in_memory=num_workers * 2))

100%|██████████| 490891/490891 [01:00<00:00, 8151.52it/s]
81it [00:05, 14.49it/s] 19000 examples [01:05, 3017.52 examples/s]
Generating train split: 19441 examples [01:06, 290.27 examples/s] 


In [4]:
# Show the dataset summary (feature names and row count) as the cell output.
new_dataset

Dataset({
    features: ['search_id', 'query', 'pos', 'neg', 'pos_ids', 'neg_ids'],
    num_rows: 19441
})