In [1]:
import os
import csv
import re
import logging
import torch
import numpy as np
from tqdm import tqdm
from sentence_transformers import SentenceTransformer, SimilarityFunction
from typing import List, Dict, Tuple, Set

class SemanticDeduplicator:
    """Two-stage sentence de-duplicator for CSV datasets.

    Stage one ("exact") flags rows whose normalized text equals that of an
    earlier row. Stage two ("semantic") embeds the surviving sentences and
    flags the later sentence of every pair whose cosine similarity exceeds
    ``semantic_threshold``.
    """

    def __init__(self,
                 model_name: str = "intfloat/multilingual-e5-base",
                 exact_threshold: float = 0.98,
                 semantic_threshold: float = 0.95,
                 logging_level: int = logging.INFO):
        """
        Configure logging, store the thresholds and load the embedding model.

        Args:
            model_name (str): Name of the SentenceTransformer model to load.
            exact_threshold (float): Stored on the instance; no method of this
                class currently reads it (exact de-duplication compares
                normalized strings for equality).
            semantic_threshold (float): Default cosine-similarity cut-off used
                by ``find_semantic_duplicates``.
            logging_level (int): Level passed to ``logging.basicConfig``.
        """
        # Configure root logging (a no-op if the root logger already has handlers)
        logging.basicConfig(
            level=logging_level,
            format='%(asctime)s - %(levelname)s: %(message)s'
        )
        self.logger = logging.getLogger(__name__)

        self.exact_threshold = exact_threshold
        self.semantic_threshold = semantic_threshold

        # Load the model on the GPU when one is available
        # NOTE(review): the E5 model card recommends prefixing inputs with
        # "query: " / "passage: "; sentences are encoded without a prefix
        # in this class - confirm that this is intended.
        self.logger.info(f"加载语义模型: {model_name}")
        self.model = SentenceTransformer(
            model_name,
            device='cuda' if torch.cuda.is_available() else 'cpu'
        )
        self.device = self.model.device

    def normalize_text(self, text: str) -> str:
        """
        Normalize a sentence for exact-match comparison.

        Non-string input is converted with ``str()``. Every character that is
        neither a word character nor whitespace is removed, the text is
        lower-cased and stripped, and whitespace runs collapse to one space.

        Args:
            text (str): Input text.

        Returns:
            str: The normalized text.
        """
        if not isinstance(text, str):
            text = str(text)

        # Drop punctuation (anything that is not \w or \s)
        text = re.sub(r'[^\w\s]', '', text)

        # Lower-case and trim leading/trailing whitespace
        text = text.lower().strip()

        # Collapse runs of whitespace into a single space
        text = re.sub(r'\s+', ' ', text)

        return text

    def find_semantic_duplicates(
        self,
        sentences: List[str],
        threshold: float = None
    ) -> Tuple[Set[int], List[Tuple[int, int, float]]]:
        """
        Find sentence pairs whose cosine similarity exceeds a threshold.

        Every pair ``(i, j)`` with ``i < j`` and similarity ``> threshold`` is
        reported and ``j`` is flagged as a duplicate. The scan is not
        transitive-aware: ``i`` is still used as a "main" sentence even when
        it was itself flagged through an earlier pair.

        Args:
            sentences (List[str]): Sentences to compare.
            threshold (float, optional): Similarity cut-off. ``None`` selects
                ``self.semantic_threshold``.

        Returns:
            Tuple[Set[int], List[Tuple[int, int, float]]]: The flagged indices
            and the ``(i, j, similarity)`` pairs, ordered by ``i`` then ``j``.
        """
        # `is None` instead of `or`, so that an explicit 0.0 is honoured
        if threshold is None:
            threshold = self.semantic_threshold

        # Fewer than two sentences cannot form a pair; skip the model call,
        # whose result for an empty list cannot be normalized along dim=1.
        if len(sentences) < 2:
            return set(), []

        self.logger.info("开始编码句子...")
        embeddings = self.model.encode(
            sentences,
            batch_size=128,
            show_progress_bar=True,
            convert_to_tensor=True,  # keep the result as a tensor for torch.mm
            device=self.device
        )

        # L2-normalize each row so that the dot product equals cosine similarity
        embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)

        self.logger.info("计算相似度矩阵...")
        # Full pairwise cosine-similarity matrix (values in [-1, 1])
        sim_matrix_np = torch.mm(embeddings, embeddings.T).cpu().numpy()

        # Vectorized scan of the strict upper triangle (i < j). np.nonzero
        # returns indices in row-major order, i.e. sorted by i and then j.
        above = np.triu(sim_matrix_np > threshold, k=1)
        first_idx, second_idx = np.nonzero(above)

        duplicate_pairs = [
            (int(i), int(j), float(sim_matrix_np[i, j]))
            for i, j in zip(first_idx, second_idx)
        ]
        duplicates = {j for _, j, _ in duplicate_pairs}

        return duplicates, duplicate_pairs

    def deduplicate_csv(
        self,
        input_path: str,
        output_path: str,
        duplicate_report_path: str,
        exact_dedup_path: str,
        ambiguity_column: str = '歧义句'
    ) -> Dict[str, int]:
        """
        De-duplicate a CSV file and write the result plus two reports.

        Three files are written: the rows kept by exact de-duplication
        (``exact_dedup_path``), a combined exact + semantic duplicate report
        (``duplicate_report_path``), and a copy of the input in which every
        row carries a '是否删除' column set to '是' or '否' (``output_path``).
        No row is physically removed from the output.

        Args:
            input_path (str): Input CSV path (read as utf-8-sig).
            output_path (str): Path of the flagged copy of the input.
            duplicate_report_path (str): Path of the duplicate report.
            exact_dedup_path (str): Path of the exact de-duplication result.
            ambiguity_column (str): Name of the column holding the sentences.

        Returns:
            Dict[str, int]: Processing statistics.

        Raises:
            KeyError: If ``ambiguity_column`` is not in the input header
                (including the case of an empty input file).
        """
        with open(input_path, 'r', encoding='utf-8-sig') as f:
            reader = csv.DictReader(f)
            header = reader.fieldnames or []  # fieldnames is None for an empty file
            rows = list(reader)

        # Fail early with a readable message instead of a KeyError mid-loop
        if ambiguity_column not in header:
            raise KeyError(
                f"输入文件缺少列 '{ambiguity_column}'，现有列: {header}"
            )

        # An existing flag column is re-generated, so drop it from the header
        original_fieldnames = [col for col in header if col != '是否删除']

        # ---- exact de-duplication ----
        seen_exact = {}          # normalized text -> index of its first row
        sentences = []           # normalized texts passed on to the semantic stage
        original_indices = []    # original row index of each entry in `sentences`
        exact_duplicates = []    # rows kept by the exact stage
        duplicate_records = []   # combined duplicate report (exact + semantic)

        self.logger.info("开始精确去重...")
        for idx, row in enumerate(tqdm(rows)):
            original_text = row[ambiguity_column]
            clean_text = self.normalize_text(original_text)

            if clean_text not in seen_exact:
                # First occurrence: keep it and hand it to the semantic stage
                seen_exact[clean_text] = idx
                sentences.append(clean_text)
                original_indices.append(idx)
                exact_duplicates.append({
                    '原始行号': idx + 2,  # +2: header line plus 1-based numbering
                    '原始文本': original_text,
                    '规范化文本': clean_text
                })
            else:
                # Repeat of an earlier normalized text: report it against that row
                main_id = seen_exact[clean_text]
                duplicate_records.append({
                    '主句子行号': main_id + 2,
                    '主句子内容': rows[main_id][ambiguity_column],
                    '重复句子行号': idx + 2,
                    '重复句子内容': original_text,
                    '去重方式': "精确去重"
                })

        with open(exact_dedup_path, 'w', encoding='utf-8-sig', newline='') as f:
            writer = csv.DictWriter(f, fieldnames=['原始行号', '原始文本', '规范化文本'])
            writer.writeheader()
            writer.writerows(exact_duplicates)

        # ---- semantic de-duplication ----
        semantic_duplicates, dup_pairs = self.find_semantic_duplicates(sentences)

        # Map positions in `sentences` back to original row indices
        for i, j, similarity in dup_pairs:
            main_id = original_indices[i]
            dup_id = original_indices[j]
            duplicate_records.append({
                '主句子行号': main_id + 2,
                '主句子内容': rows[main_id][ambiguity_column],
                '重复句子行号': dup_id + 2,
                '重复句子内容': rows[dup_id][ambiguity_column],
                '去重方式': f"语义去重 (相似度: {similarity:.4f})"
            })

        with open(duplicate_report_path, 'w', encoding='utf-8-sig', newline='') as f:
            writer = csv.DictWriter(f, fieldnames=['主句子行号','主句子内容','重复句子行号','重复句子内容','去重方式'])
            writer.writeheader()
            writer.writerows(duplicate_records)

        # ---- flag deleted rows ----
        exact_deleted = set(range(len(rows))) - set(original_indices)
        semantic_deleted = {original_indices[j] for j in semantic_duplicates}
        deleted_ids = exact_deleted | semantic_deleted

        output_fieldnames = original_fieldnames + ['是否删除']
        with open(output_path, 'w', encoding='utf-8-sig', newline='') as f:
            writer = csv.DictWriter(f, fieldnames=output_fieldnames)
            writer.writeheader()
            for idx, row in enumerate(rows):
                row['是否删除'] = '是' if idx in deleted_ids else '否'
                writer.writerow(row)

        stats = {
            '原始总行数': len(rows),
            '精确去重保留': len(exact_duplicates),
            '最终保留行数': len(rows) - len(deleted_ids),
            '总重复项': len(deleted_ids)
        }

        self.logger.info("\n===== 处理结果统计 =====")
        for key, value in stats.items():
            self.logger.info(f"{key}: {value}")

        self.logger.info(f"精确去重结果: {exact_dedup_path}")
        self.logger.info(f"重复报告: {duplicate_report_path}")

        return stats

def main():
    """Run the de-duplication pipeline on the ambiguity-sentence dataset."""
    # Directory that receives every generated file
    folder_path = r"d:\python\Coding\NLP"

    # Build the de-duplicator (this loads the embedding model)
    deduplicator = SemanticDeduplicator(
        model_name="intfloat/multilingual-e5-base",
        exact_threshold=0.98,
        semantic_threshold=0.95,
        logging_level=logging.INFO,
    )

    def in_folder(name):
        """Join *name* onto the working folder."""
        return os.path.join(folder_path, name)

    # The input is given as an absolute Windows path; os.path.join discards
    # earlier components when a later one is absolute, so on Windows
    # folder_path has no effect on this particular entry.
    source_csv = in_folder(r'D:\python\Coding\NLP\消歧\歧义句数据集原始.csv')
    flagged_csv = in_folder("去重后歧义句数据集.csv")
    report_csv = in_folder("重复报告3.csv")
    exact_csv = in_folder("精确去重结果3.csv")

    # Run exact + semantic de-duplication and write all three outputs
    deduplicator.deduplicate_csv(source_csv, flagged_csv, report_csv, exact_csv)


if __name__ == "__main__":
    main()


  from .autonotebook import tqdm as notebook_tqdm
2025-05-01 15:07:20,686 - INFO: 加载语义模型: intfloat/multilingual-e5-base
2025-05-01 15:07:20,689 - INFO: Load pretrained SentenceTransformer: intfloat/multilingual-e5-base
2025-05-01 15:07:23,640 - INFO: 开始精确去重...
100%|██████████| 1020/1020 [00:00<00:00, 338491.18it/s]
2025-05-01 15:07:23,650 - INFO: 开始编码句子...
Batches: 100%|██████████| 8/8 [00:04<00:00,  1.79it/s]
2025-05-01 15:07:28,118 - INFO: 计算相似度矩阵...
扫描句子对: 100%|██████████| 1001/1001 [00:00<00:00, 17057.57it/s]
2025-05-01 15:07:28,193 - INFO: 
===== 处理结果统计 =====
2025-05-01 15:07:28,193 - INFO: 原始总行数: 1020
2025-05-01 15:07:28,193 - INFO: 精确去重保留: 1001
2025-05-01 15:07:28,194 - INFO: 最终保留行数: 948
2025-05-01 15:07:28,194 - INFO: 总重复项: 72
2025-05-01 15:07:28,195 - INFO: 精确去重结果: d:\python\Coding\NLP\精确去重结果3.csv
2025-05-01 15:07:28,195 - INFO: 重复报告: d:\python\Coding\NLP\重复报告3.csv


In [2]:
import pandas as pd
import os

def extract_columns(input_file, output_file, keep_columns=None):
    """
    Extract a subset of columns from a CSV/Excel file and save it.

    The input format is chosen from the file extension (.csv, .xlsx, .xls).
    Requested columns that are missing from the data are reported and
    skipped. The output is written as Excel for .xlsx/.xls and as
    UTF-8-with-BOM CSV otherwise; '.csv' is appended when the output path
    has no extension.

    NOTE(review): no row filtering happens here. If the input carries a
    deletion-flag column such as '是否删除', flagged rows are still exported;
    confirm that this is intended.

    Args:
        input_file (str): Path of the file to read.
        output_file (str): Path of the file to write.
        keep_columns (list[str], optional): Columns to keep, in output order.
            Defaults to the seven ambiguity-dataset columns.

    Returns:
        bool: True on success; False for an unsupported input format, when
        none of the requested columns exist, or when any exception occurs.
    """
    print(f"正在读取文件: {input_file}")

    if keep_columns is None:
        # Default column set of the ambiguity-sentence dataset
        keep_columns = [
            '歧义句',
            '歧义句及上下文',
            '歧义文本位置',
            '歧义原因（解读选项）',
            '歧义句消岐1',
            '歧义句消岐2',
            '歧义类型'
        ]
    else:
        # Work on a copy so the caller's sequence is never touched
        keep_columns = list(keep_columns)

    # Pick the reader from the input extension
    file_ext = os.path.splitext(input_file)[1].lower()

    try:
        if file_ext == '.csv':
            df = pd.read_csv(input_file)
        elif file_ext in ['.xlsx', '.xls']:
            df = pd.read_excel(input_file)
        else:
            print(f"不支持的文件格式: {file_ext}")
            return False

        print(f"成功读取数据，共 {len(df)} 行")

        # Warn about requested columns that are absent and continue with the rest
        missing_columns = [col for col in keep_columns if col not in df.columns]
        if missing_columns:
            print(f"警告：以下列在数据集中不存在: {', '.join(missing_columns)}")
            keep_columns = [col for col in keep_columns if col in df.columns]
            if not keep_columns:
                print("错误：没有找到任何需要保留的列")
                return False

        df_extracted = df[keep_columns]

        # Pick the writer from the output extension; anything else is saved as CSV
        output_ext = os.path.splitext(output_file)[1].lower()
        if output_ext in ['.xlsx', '.xls']:
            df_extracted.to_excel(output_file, index=False)
        else:
            if not output_ext:
                output_file += '.csv'
            df_extracted.to_csv(output_file, index=False, encoding='utf-8-sig')

        print(f"成功提取数据并保存到: {output_file}")
        print(f"提取了 {len(df_extracted)} 行和 {len(keep_columns)} 列")
        return True
    except Exception as e:
        # Top-level boundary of this tool: report the problem and signal
        # failure through the return value instead of raising.
        print(f"处理过程中出错: {e}")
        return False

def main():
    """
    Entry point of the column-extraction tool.

    Prints a banner, checks that the hard-coded input file exists, runs
    ``extract_columns`` on it and reports whether the extraction succeeded.
    """
    print("数据集列提取工具")
    print("=" * 50)

    # Fixed source and destination of the extraction
    source_path = r'D:\python\Coding\NLP\去重后歧义句数据集.csv'
    target_path = r'D:\python\Coding\NLP\提取后的数据集.csv'

    # Guard clause: nothing to do when the input is missing
    if not os.path.exists(source_path):
        print(f"错误：文件 '{source_path}' 不存在!")
        return

    outcome = (
        "处理完成!"
        if extract_columns(source_path, target_path)
        else "处理失败，请检查错误信息。"
    )
    print(outcome)


if __name__ == "__main__":
    main()

数据集列提取工具
正在读取文件: D:\python\Coding\NLP\去重后歧义句数据集.csv
成功读取数据，共 1020 行
成功提取数据并保存到: D:\python\Coding\NLP\提取后的数据集.csv
提取了 1020 行和 7 列
处理完成!
