In [None]:
import os
from tqdm import tqdm
import pandas as pd
import sys
sys.path.append("..")
from utils.split_text import *
def _build_spliter(method):
    """Instantiate the text spliter selected by ``method``.

    Args:
        method: Either ``"cos_sim_spliter"`` or ``"doc_seg_model_spliter"``.

    Returns:
        The spliter object produced by ``BaseSpliter.use_subclass(method)``.

    Raises:
        ValueError: If ``method`` is not one of the two supported names.
    """
    if method == "cos_sim_spliter":
        # Imported lazily so this dependency is only required for this method.
        from FlagEmbedding import FlagModel
        model = FlagModel('resources/open_models/bge-large-zh-v1.5', query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章：", use_fp16=True)
        return BaseSpliter.use_subclass("cos_sim_spliter")(model)

    if method == "doc_seg_model_spliter":
        from transformers import AutoModelForTokenClassification, AutoTokenizer
        model_name = 'resources/open_models/nlp_bert_document-segmentation_chinese-base'
        model = AutoModelForTokenClassification.from_pretrained(model_name)
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        return BaseSpliter.use_subclass("doc_seg_model_spliter")(model, tokenizer)

    # Fail fast: previously an unknown method left `spliter` unbound and the
    # function crashed later (or silently did nothing) inside the file loop.
    raise ValueError(
        f"Unsupported method {method!r}; expected 'cos_sim_spliter' or "
        "'doc_seg_model_spliter'."
    )


def process_stock_news(from_dir, to_dir, method):
    """Chunk every news CSV in ``from_dir`` and write one JSON file per CSV.

    Each ``<stock_id>.csv`` is read with pandas; its ``Date``, ``Title`` and
    ``Content`` columns are used. The content of every row is passed through
    the spliter (``split_text_to_sentences`` -> ``add_buffered_sentences`` ->
    ``cluster``) and one output record is emitted per resulting chunk. The
    records are saved to ``<to_dir>/<stock_id>.json``. CSVs whose JSON file
    already exists in ``to_dir`` are skipped, so an interrupted run can be
    resumed.

    Args:
        from_dir: Directory containing the input ``.csv`` files.
        to_dir: Output directory; created if it does not exist.
        method: Spliter name, ``"cos_sim_spliter"`` or
            ``"doc_seg_model_spliter"``.

    Raises:
        ValueError: If ``method`` is not a supported spliter name.
    """
    spliter = _build_spliter(method)

    os.makedirs(to_dir, exist_ok=True)
    # Snapshot of already-produced outputs; a set gives O(1) membership tests.
    done_files = set(os.listdir(to_dir))

    for filename in os.listdir(from_dir):
        # Assumes file names look like "000001.csv": the stock id is the part
        # before the first dot.
        stock_id = filename.split('.')[0]
        if not filename.endswith(".csv") or stock_id + ".json" in done_files:
            continue
        print(f"当前id：{stock_id}")
        filepath = os.path.join(from_dir, filename)
        # Read the CSV file.
        df = pd.read_csv(filepath)
        # Accumulates one dict per chunk for this stock.
        result = []
        # Process each row; `total` lets tqdm show real progress because
        # `iterrows()` is a generator without a length.
        for _, row in tqdm(df.iterrows(), total=len(df)):
            date = row['Date']
            title = row['Title']
            content = row['Content']
            # NOTE(review): an empty CSV cell is read as NaN (a float); confirm
            # the spliter tolerates non-string content.
            sentence_df = spliter.split_text_to_sentences(content)
            sentence_df = spliter.add_buffered_sentences(sentence_df)
            chunk_df = spliter.cluster(sentence_df)
            # Attach the article-level fields to every chunk of this article.
            for _, sentence_row in chunk_df.iterrows():
                result.append({
                    'stock_id': stock_id,
                    'date': date,
                    'title': title,
                    'content': content,
                    'chunk': sentence_row['chunk'],
                    'start_idx': sentence_row['start_idx'],
                    'end_idx': sentence_row['end_idx']
                })
        print(f"stock_id {stock_id} chunked.")
        pd.DataFrame(result).to_json(os.path.join(to_dir, f'{stock_id}.json'), force_ascii=False, orient='records', indent=2)


In [None]:
# Run the chunking function over every file in the raw news folder.
process_stock_news(
    from_dir='../data/raw/CSI300news',
    to_dir='../data/cleaned/CSI300news_chunked',
    method="doc_seg_model_spliter",
)