In [None]:
import os
import sys
import warnings
import subprocess
import multiprocessing
from pathlib import Path
import psutil

# Let OpenMP / MKL / numexpr kernels use every logical CPU, and allow the
# HF tokenizers library to parallelise.
os.environ.update({
    **dict.fromkeys(
        ('OMP_NUM_THREADS', 'MKL_NUM_THREADS', 'NUMEXPR_NUM_THREADS'),
        str(multiprocessing.cpu_count()),
    ),
    'TOKENIZERS_PARALLELISM': 'true',
})

# Silence all warnings for cleaner notebook output (this also hides library
# deprecation / misconfiguration warnings).
warnings.filterwarnings('ignore')

# Child processes use the 'spawn' start method; force=True replaces any
# method that was already set.
try:
    multiprocessing.set_start_method('spawn', force=True)
except RuntimeError:
    pass

def check_system_info():
    """Print a short hardware summary: CPU cores/frequency, RAM and CUDA GPU.

    GPU details are printed only when PyTorch is importable and reports CUDA
    as available; otherwise a CPU-mode or "PyTorch not installed" notice is
    printed. Never raises for a missing or broken GPU setup.
    """
    print("系统信息检测")
    print("=" * 50)

    cpu_count = multiprocessing.cpu_count()
    cpu_freq = psutil.cpu_freq()  # psutil may return None where frequency is unavailable
    memory = psutil.virtual_memory()

    print(f"CPU核心数: {cpu_count}")
    if cpu_freq is not None:
        print(f"CPU频率: {cpu_freq.current:.0f} MHz")
    print(f"内存: {memory.total / (1024**3):.1f} GB (可用 {memory.available / (1024**3):.1f} GB)")

    try:
        import torch
    except ImportError:
        print("PyTorch未安装，稍后自动安装")
    else:
        # Separate guard: a driver/runtime problem is not the same as "not installed".
        try:
            if torch.cuda.is_available():
                gpu_count = torch.cuda.device_count()
                gpu_name = torch.cuda.get_device_name(0)
                gpu_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3)
                print(f"检测到 {gpu_count} 块CUDA GPU: {gpu_name} ({gpu_memory:.1f} GB)")
            else:
                print("未检测到CUDA GPU，将使用CPU模式")
        except Exception as e:
            print(f"GPU检测失败: {e}，将使用CPU模式")

    print("=" * 50)

# Print a CPU/RAM/GPU summary when the cell runs.
check_system_info()

def install_requirements(packages=None):
    """Install any missing Python dependencies with pip into this interpreter.

    Args:
        packages: optional list of pip requirement specs (e.g. 'torch>=2.0.0').
            Defaults to the packages used by this notebook.

    A package counts as present when its *import* name can be located with
    importlib.util.find_spec (which does not import it). Pip distribution names
    often differ from import names (scikit-learn -> sklearn, faiss-cpu -> faiss),
    so known cases are mapped explicitly and '-' is otherwise turned into '_'.

    Raises:
        subprocess.CalledProcessError: if a pip install fails.
    """
    import importlib
    import importlib.util
    import re

    if packages is None:
        packages = [
            'torch>=2.0.0',
            'transformers>=4.35.0',
            'peft>=0.6.0',
            'sentence-transformers>=2.2.0',
            'faiss-cpu>=1.7.4',
            'pandas>=1.5.0',
            'numpy>=1.21.0',
            'scikit-learn>=1.3.0',
            'matplotlib>=3.5.0',
            'seaborn>=0.11.0',
            'tqdm>=4.64.0',
            'accelerate>=0.24.0',
            'openpyxl>=3.0.0',
            'xlrd>=2.0.0',
            'datasets>=2.14.0',  # imported by the LoRA training cell
            'gradio>=4.0.0',     # imported by the inference cell
        ]

    # Distribution name (lower-case) -> top-level import name, where they differ
    # by more than '-' vs '_'.
    import_names = {
        'scikit-learn': 'sklearn',
        'sentence-transformers': 'sentence_transformers',
        'faiss-cpu': 'faiss',
        'faiss-gpu': 'faiss',
    }

    print("\n检查并安装依赖包...")
    installed_any = False
    for package in packages:
        # Strip version specifiers / extras / markers: 'pkg[extra]>=1.0; ...' -> 'pkg'.
        dist_name = re.split(r'[<>=!~;\[\s]', package.strip(), maxsplit=1)[0]
        module_name = import_names.get(dist_name.lower(), dist_name.replace('-', '_'))
        if importlib.util.find_spec(module_name) is not None:
            continue
        print(f"安装 {package}")
        subprocess.check_call([sys.executable, '-m', 'pip', 'install', package])
        installed_any = True

    if installed_any:
        # Make packages installed during this session visible to the import system.
        importlib.invalidate_caches()

# Install any missing third-party packages before later cells import them.
install_requirements()

print("环境检测完成！")
print(f"已配置 {multiprocessing.cpu_count()} 线程并行处理")


In [None]:
import torch
import pandas as pd
import numpy as np
import json
from datetime import datetime
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
import multiprocessing
import psutil

# Hardware probes used to size the configuration below.
CPU_COUNT = multiprocessing.cpu_count()
MEMORY_GB = psutil.virtual_memory().total / (1024**3)
GPU_AVAILABLE = torch.cuda.is_available()
GPU_MEMORY_GB = torch.cuda.get_device_properties(0).total_memory / (1024**3) if GPU_AVAILABLE else 0


# Key into LINUX_CONFIG["model_configs"] selecting the active model preset.
CONFIG_TYPE = "balanced"

LINUX_CONFIG = {
    # Basic settings
    "config_type": CONFIG_TYPE,
    "use_gpu": GPU_AVAILABLE,
    "device": "cuda" if GPU_AVAILABLE else "cpu",

    # Parallelism. max(1, ...) keeps pool sizes valid on single-core hosts,
    # where CPU_COUNT // 2 == 0 and ThreadPoolExecutor(max_workers=0) raises.
    "num_workers": min(CPU_COUNT, 8),  # capped to avoid oversubscribing the CPU
    "prefetch_factor": 2,
    "pin_memory": GPU_AVAILABLE,
    "thread_pool_size": max(1, CPU_COUNT // 2),
    "process_pool_size": max(1, min(CPU_COUNT // 2, 4)),

    # Model presets
    "model_configs": {
        "balanced": {
            "model_name": "ShengbinYue/LawLLM-7B",
            "max_length": 1024,
            "batch_size": 2,
            # NOTE: the LoRA training loader in this notebook loads the base model
            # without quantisation; 4-bit is applied by the inference loader.
            "use_4bit": True,
            "cpu_offload": False,
        },
    },

    # RAG settings
    "rag_config": {
        # NOTE(review): all-MiniLM-L6-v2 is an English-centric model; a multilingual
        # embedder may retrieve Chinese legal text better — TODO evaluate.
        "embedding_model": "sentence-transformers/all-MiniLM-L6-v2",
        "top_k": 5,
        "similarity_threshold": 0.7,
        "chunk_size": 256,
        "chunk_overlap": 50,
        "index_type": "faiss",
    },

    # LoRA settings
    "lora_config": {
        "r": 8,
        "lora_alpha": 32,
        "lora_dropout": 0.1,
        "target_modules": ["q_proj", "k_proj", "v_proj", "o_proj"],
        "bias": "none",
        "task_type": "CAUSAL_LM",
    },

    # Training settings
    "training_config": {
        "learning_rate": 5e-5,
        "num_epochs": 2,
        "warmup_steps": 50,
        "gradient_accumulation_steps": 2,
        "dataloader_num_workers": 0,
        "fp16": GPU_AVAILABLE,
        "gradient_checkpointing": True,
    }
}

# Show the active configuration.
current_config = LINUX_CONFIG["model_configs"][CONFIG_TYPE]
print(f"\n已选择配置: {CONFIG_TYPE}")
print(f"模型: {current_config['model_name']}")
print(f"批处理大小: {current_config['batch_size']}")
print(f"最大长度: {current_config['max_length']}")
print(f"多线程数: {LINUX_CONFIG['num_workers']}")
print(f"4bit量化: {current_config['use_4bit']}")
print(f"CPU卸载: {current_config['cpu_offload']}")

# Shared thread pool. max(1, ...) because thread_pool_size is CPU_COUNT // 2,
# which is 0 on a single-core host, and ThreadPoolExecutor rejects max_workers=0.
THREAD_POOL = ThreadPoolExecutor(max_workers=max(1, LINUX_CONFIG['thread_pool_size']))
PROCESS_POOL = None  # process pool is not used


In [None]:
import pandas as pd
import re
from typing import List, Dict, Any
from concurrent.futures import as_completed
import time
from tqdm import tqdm

class LegalDataProcessor:
    """Load, clean and reshape the labour-compliance case dataset.

    Typical flow: load_data() -> process_data_parallel() -> get_training_data().
    """

    # Columns the downstream RAG / LoRA code reads from every row.
    REQUIRED_COLUMNS = ['纠纷条款原文', '条款缺陷分析', '违约/纠纷案例', '判决结果', '法律依据']

    # Runs of whitespace are collapsed to a single space by clean_text().
    _WHITESPACE_RE = re.compile(r'\s+')
    # Anything outside CJK ideographs, CJK / full-width punctuation, word chars,
    # whitespace and . , ; : ! ? ( ) is removed by clean_text().
    _DISALLOWED_CHARS_RE = re.compile(r'[^\u4e00-\u9fff\u3000-\u303f\uff00-\uffef\w\s.,;:!?()]')

    def __init__(self, config=LINUX_CONFIG):
        """
        Args:
            config: pipeline configuration dict; 'thread_pool_size' sizes the
                worker pool used by process_data_parallel().
        """
        self.config = config
        self.data = None            # raw DataFrame, set by load_data()
        self.processed_data = None  # cleaned DataFrame, set by process_data_parallel()

    def _create_sample_data(self):
        """Return a two-row demo dataset with all REQUIRED_COLUMNS."""
        sample_data = {
            '纠纷条款原文': ["公司有权根据经营需要调整员工的工作岗位。", "员工应保证24小时随叫随到，否则视为旷工。"],
            '条款缺陷分析': ["该条款过于宽泛，可能被认定为无效。未明确调整岗位的合理性、协商程序等，容易引发争议。", "该要求严重违反劳动法关于工作时间的规定，侵害员工休息权，属于无效条款。"],
            '违约/纠纷案例': ["员工因拒绝公司无理由调岗被辞退，后起诉公司，法院判决公司违法解除劳动合同。", "某程序员因拒绝半夜到公司处理非紧急事务被辞退，仲裁裁定公司支付赔偿金。"],
            '判决结果': ["法院认定公司调岗不具有合理性，属于违法解除，应支付赔偿金。", "仲裁委认定公司要求24小时待命违法，构成违法解除。"],
            '法律依据': ["《劳动合同法》第三十五条、第四十条", "《劳动法》第三十八条、第四十一条"]
        }
        return pd.DataFrame(sample_data)

    @staticmethod
    def _excel_engine(file_path):
        """Pick the pandas Excel reader: xlrd for legacy .xls, openpyxl otherwise."""
        return 'xlrd' if str(file_path).lower().endswith('.xls') else 'openpyxl'

    def load_data(self, file_path="./data.xlsx", sheet_name=None):
        """Load the dataset from an Excel (.xlsx/.xls) or CSV file into self.data.

        Args:
            file_path: path to the data file (str or os.PathLike); the
                extension is matched case-insensitively.
            sheet_name: Excel sheet to read. When None, the first sheet whose
                name contains 'data', '数据' or '案例' is used, else the first sheet.

        Returns:
            The loaded DataFrame. On any load error the built-in two-row sample
            dataset is stored in self.data and returned instead.
        """
        print("加载劳动合规数据集...")
        path_str = str(file_path)
        lowered = path_str.lower()

        try:
            if lowered.endswith(('.xlsx', '.xls')):
                print(f"检测到Excel文件: {path_str}")
                engine = self._excel_engine(path_str)

                if sheet_name is None:
                    # Context manager closes the workbook handle after listing sheets.
                    with pd.ExcelFile(path_str, engine=engine) as excel_file:
                        sheet_names = excel_file.sheet_names
                    print(f"发现工作表: {sheet_names}")

                    # Prefer a sheet whose name hints at data/cases, else the first one.
                    sheet_name = next(
                        (name for name in sheet_names
                         if 'data' in name.lower() or '数据' in name or '案例' in name),
                        sheet_names[0],
                    )
                    print(f"使用工作表: {sheet_name}")

                self.data = pd.read_excel(path_str, sheet_name=sheet_name, engine=engine)

            elif lowered.endswith('.csv'):
                # CSV kept for backwards compatibility.
                print(f"检测到CSV文件: {path_str}")
                self.data = pd.read_csv(path_str)
            else:
                raise ValueError(f"不支持的文件格式: {path_str}")

            print(f"成功加载 {len(self.data)} 条劳动合规案例")

            print("\n数据集概览:")
            print(f"列名: {list(self.data.columns)}")
            print(f"数据形状: {self.data.shape}")

            missing_columns = [col for col in self.REQUIRED_COLUMNS if col not in self.data.columns]
            if missing_columns:
                print(f"缺失列: {missing_columns}")
            else:
                print("所有必需列都存在")

            return self.data

        except Exception as e:
            print(f"数据加载失败: {e}")
            print("请检查文件路径和格式:")
            print("   - Excel文件: ./data.xlsx")
            print("   - CSV文件: ./data.csv")
            print("   - 确保文件包含必要的列: 纠纷条款原文、条款缺陷分析、违约/纠纷案例、判决结果、法律依据")
            # Fall back to the demo dataset so the rest of the notebook can run.
            self.data = self._create_sample_data()
            return self.data

    def analyze_excel_structure(self, file_path="./data.xlsx"):
        """Print a per-sheet summary (shape, columns, preview) of an Excel workbook.

        Only the first 3 rows of each sheet are read.

        Returns:
            List of sheet names, or [] if the workbook could not be analysed.
        """
        try:
            print(f"分析Excel文件结构: {file_path}")

            with pd.ExcelFile(file_path, engine=self._excel_engine(file_path)) as excel_file:
                sheet_names = excel_file.sheet_names
                print(f"工作表数量: {len(sheet_names)}")

                for i, sheet_name in enumerate(sheet_names):
                    print(f"\n工作表 {i+1}: {sheet_name}")

                    # Preview only: read the first 3 rows from the already-open workbook.
                    df_preview = excel_file.parse(sheet_name, nrows=3)

                    print(f"  形状: {df_preview.shape}")
                    print(f"  列名: {list(df_preview.columns)}")

                    matching_columns = [col for col in self.REQUIRED_COLUMNS if col in df_preview.columns]
                    if matching_columns:
                        print(f"   匹配列: {matching_columns}")
                        if len(matching_columns) >= 3:
                            print("   推荐使用此工作表")
                    else:
                        print("   未找到标准列名")

                    if not df_preview.empty:
                        print("   数据预览:")
                        for col in df_preview.columns[:3]:  # first 3 columns only
                            first_value = df_preview[col].iloc[0]
                            sample_value = "空值" if pd.isna(first_value) else str(first_value)[:30]
                            print(f"      {col}: {sample_value}...")

            return sheet_names

        except Exception as e:
            print(f"Excel文件分析失败: {e}")
            return []

    def clean_text(self, text):
        """Normalise one cell value to a cleaned string.

        Missing values (per pd.isna) become "". Otherwise the value is converted
        to str, whitespace runs are collapsed to one space, characters outside
        the allowed set (see _DISALLOWED_CHARS_RE) are removed, and the result
        is stripped.
        """
        if pd.isna(text):
            return ""
        text = self._WHITESPACE_RE.sub(' ', str(text))
        text = self._DISALLOWED_CHARS_RE.sub('', text)
        return text.strip()

    def process_data_parallel(self):
        """Clean every row of self.data in a thread pool and add 'combined_text'.

        Returns:
            The cleaned DataFrame (also stored in self.processed_data). Rows keep
            the same order and index labels as self.data.

        Raises:
            ValueError: if no data has been loaded yet.
        """
        if self.data is None:
            raise ValueError("请先调用 load_data() 加载数据")

        print("\n开始多线程数据预处理...")
        start_time = time.time()
        columns = list(self.data.columns)
        # (label shown in combined_text, source column)
        combined_fields = (
            ('条款原文', '纠纷条款原文'),
            ('缺陷分析', '条款缺陷分析'),
            ('相关案例', '违约/纠纷案例'),
            ('判决结果', '判决结果'),
            ('法律依据', '法律依据'),
        )

        def process_row(row_data):
            idx, row = row_data
            processed_row = {col: self.clean_text(row[col]) for col in columns}
            # One "label: value" line per field; missing columns become empty values.
            processed_row['combined_text'] = "\n".join(
                f"{label}: {processed_row.get(col, '')}" for label, col in combined_fields
            ).strip()
            return idx, processed_row

        # A private pool: using the shared THREAD_POOL in a `with` block would shut
        # it down on exit and break every later call. executor.map keeps input
        # order, so the result does not depend on thread scheduling.
        max_workers = max(1, self.config.get('thread_pool_size', 1))
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            results = list(tqdm(
                executor.map(process_row, self.data.iterrows()),
                total=len(self.data),
                desc="数据预处理",
            ))

        self.processed_data = pd.DataFrame(
            [row for _, row in results],
            index=[idx for idx, _ in results],
        )

        processing_time = time.time() - start_time
        print(f"数据预处理完成，用时 {processing_time:.2f} 秒")
        print(f"处理了 {len(self.processed_data)} 条记录")

        return self.processed_data

    def get_training_data(self):
        """Build LoRA instruction/response samples from the processed data.

        Runs process_data_parallel() first if needed. Missing columns are treated
        as empty strings.

        Returns:
            List of dicts with keys 'input' (instruction incl. the clause),
            'output' (structured analysis) and 'combined' (the RAG text).
        """
        if self.processed_data is None:
            self.process_data_parallel()

        training_data = []
        for _, row in self.processed_data.iterrows():
            input_text = f"请分析以下劳动合同条款的合规风险：\n{row.get('纠纷条款原文', '')}"

            output_text = "\n\n".join([
                f"条款缺陷分析：{row.get('条款缺陷分析', '')}",
                f"相关案例：{row.get('违约/纠纷案例', '')}",
                f"判决结果：{row.get('判决结果', '')}",
                f"法律依据：{row.get('法律依据', '')}",
            ]).strip()

            training_data.append({
                'input': input_text,
                'output': output_text,
                'combined': row.get('combined_text', ''),
            })

        return training_data

# Processor shared by the rest of the notebook.
data_processor = LegalDataProcessor()

# Optional: inspect the workbook layout before loading it.
print("首先分析Excel文件结构...")
try:
    excel_sheets = data_processor.analyze_excel_structure("./data.xlsx")
    print("\n" + "="*50)
except Exception:
    # analyze_excel_structure reports its own failures; this only guards unexpected
    # errors without swallowing KeyboardInterrupt / SystemExit like a bare except.
    print("Excel文件分析跳过，将直接尝试加载数据")

# Load the raw dataset (falls back to demo data on failure).
print("\n开始加载劳动合规数据...")
legal_data = data_processor.load_data("./data.xlsx")

# Clean / normalise every row.
processed_data = data_processor.process_data_parallel()

# Instruction-tuning samples for the LoRA cell.
training_data = data_processor.get_training_data()

print(f"\n训练数据准备完成: {len(training_data)} 个样本")
print("数据预处理阶段完成！")

# Dataset statistics on the raw (uncleaned) frame.
if legal_data is not None and not legal_data.empty:
    print(f"\n数据集统计:")
    print(f"   总样本数: {len(legal_data)}")
    print(f"   列数: {len(legal_data.columns)}")
    print(f"   数据列: {list(legal_data.columns)}")

    missing_data = legal_data.isnull().sum()
    if missing_data.sum() > 0:
        print(f"   缺失数据统计:")
        for col, missing_count in missing_data.items():
            if missing_count > 0:
                print(f"      {col}: {missing_count} 个缺失值")
    else:
        print(f"   数据完整，无缺失值")


In [None]:
from sentence_transformers import SentenceTransformer
import numpy as np
from typing import List, Tuple, Dict
import pickle
from tqdm import tqdm
from transformers import AutoConfig
import gc

# Optional dependency: use FAISS for vector search when it is installed;
# otherwise record that it is missing and import scikit-learn's
# cosine_similarity (imported only in that case).
try:
    import faiss
    FAISS_AVAILABLE = True
except ImportError:
    FAISS_AVAILABLE = False
    from sklearn.metrics.pairwise import cosine_similarity

class LegalRAGSystem:
    """Embedding-based retriever over a knowledge base built from the case dataset.

    Each dataset row yields three entries (clause analysis, case, legal basis).
    Embeddings are L2-normalised at encode time, so inner-product search equals
    cosine similarity. Search uses a FAISS flat inner-product index when one is
    available and falls back to a NumPy brute-force scan otherwise.
    """

    def __init__(self, config=LINUX_CONFIG):
        self.config = config
        self.embedding_model = None  # SentenceTransformer, set by load_embedding_model()
        self.knowledge_base = []     # list of {'type', 'content', 'metadata'} dicts
        self.embeddings = None       # np.ndarray, one row per knowledge_base entry
        self.index = None            # faiss index, or None when the NumPy fallback is used

    def load_embedding_model(self):
        """Load the sentence-embedding model named in config['rag_config'].

        Falls back to 'all-MiniLM-L6-v2' on CPU if the configured model fails.
        """
        print("加载嵌入模型...")

        model_name = self.config['rag_config']['embedding_model']

        try:
            self.embedding_model = SentenceTransformer(
                model_name,
                device=self.config['device']
            )

            # Best-effort thread tuning, kept from the original implementation.
            # NOTE(review): it only runs if the model exposes a `pool` attribute,
            # which may never be true for SentenceTransformer — confirm or remove.
            if hasattr(self.embedding_model, 'pool'):
                try:
                    import torch
                    torch.set_num_threads(self.config['num_workers'])
                except Exception:
                    pass

            print(f"嵌入模型加载完成: {model_name}")

        except Exception as e:
            print(f"模型加载失败: {e}")
            print("尝试加载备用模型...")
            self.embedding_model = SentenceTransformer(
                'all-MiniLM-L6-v2',
                device='cpu'
            )
            print("备用模型加载完成")

    def build_knowledge_base(self, processed_data):
        """Create knowledge entries from processed_data, embed them and build the index.

        Args:
            processed_data: DataFrame with the five dataset columns (as produced by
                the data processor); its index labels become metadata['source_id'].

        Returns:
            The list of knowledge entries (also stored in self.knowledge_base).
        """
        print("\n构建法律知识库...")

        if self.embedding_model is None:
            self.load_embedding_model()

        self.knowledge_base = []
        for idx, row in processed_data.iterrows():
            # Three views of the same case, stored as separate retrievable entries.
            self.knowledge_base.extend([
                {
                    'type': '条款分析',
                    'content': f"条款: {row['纠纷条款原文']}\n分析: {row['条款缺陷分析']}",
                    'metadata': {
                        'source_id': idx,
                        'category': '条款分析',
                        'original_clause': row['纠纷条款原文']
                    }
                },
                {
                    'type': '案例',
                    'content': f"案例: {row['违约/纠纷案例']}\n判决: {row['判决结果']}",
                    'metadata': {
                        'source_id': idx,
                        'category': '案例',
                        'legal_basis': row['法律依据']
                    }
                },
                {
                    'type': '法律依据',
                    'content': f"法条: {row['法律依据']}\n适用场景: {row['纠纷条款原文']}",
                    'metadata': {
                        'source_id': idx,
                        'category': '法律依据',
                        'law_article': row['法律依据']
                    }
                }
            ])

        print(f"知识库构建完成: {len(self.knowledge_base)} 个知识条目")

        self._generate_embeddings()
        self._build_faiss_index()

        return self.knowledge_base

    def _generate_embeddings(self):
        """Encode every knowledge entry in batches of 32 into self.embeddings (normalised)."""
        print("生成嵌入向量...")

        texts = [item['content'] for item in self.knowledge_base]
        if not texts:
            # np.vstack([]) would raise; keep an empty 2-D array instead.
            self.embeddings = np.empty((0, 0), dtype=np.float32)
            print(f"嵌入向量生成完成: {self.embeddings.shape}")
            return

        print("使用多线程批量编码...")

        batch_size = 32
        all_embeddings = []
        for i in tqdm(range(0, len(texts), batch_size), desc="生成嵌入"):
            batch_embeddings = self.embedding_model.encode(
                texts[i:i + batch_size],
                batch_size=batch_size,
                show_progress_bar=False,
                convert_to_numpy=True,
                normalize_embeddings=True,  # unit vectors: inner product == cosine
                device=self.config['device']
            )
            all_embeddings.append(batch_embeddings)

        self.embeddings = np.vstack(all_embeddings)

        print(f"嵌入向量生成完成: {self.embeddings.shape}")

        gc.collect()

    def _build_faiss_index(self):
        """Build an exact inner-product FAISS index, or select the NumPy fallback."""
        if self.embeddings is None or len(self.embeddings) == 0:
            print("知识库为空，跳过索引构建")
            self.index = None
            return

        if FAISS_AVAILABLE:
            print("构建FAISS向量索引...")
            # FAISS expects a C-contiguous float32 matrix.
            vectors = np.ascontiguousarray(self.embeddings, dtype=np.float32)
            # Exact (flat) search; flat indexes need no training step.
            self.index = faiss.IndexFlatIP(vectors.shape[1])
            self.index.add(vectors)
            print(f"FAISS索引构建完成: {self.index.ntotal} 个向量")
        else:
            print("使用sklearn进行向量检索...")
            self.index = None
            print(f"sklearn检索准备完成: {len(self.embeddings)} 个向量")

    def _cosine_scores(self, query_vector):
        """Cosine similarity of query_vector against every stored embedding (NumPy only).

        Does not depend on scikit-learn, which is only imported when faiss is
        missing, so it also works when FAISS is installed but no index is loaded.
        """
        matrix = np.asarray(self.embeddings, dtype=np.float32)
        query_vector = np.asarray(query_vector, dtype=np.float32).reshape(-1)
        norms = np.linalg.norm(matrix, axis=1) * np.linalg.norm(query_vector)
        return (matrix @ query_vector) / np.maximum(norms, 1e-12)

    def search(self, query: str, top_k: int = None) -> List[Dict]:
        """Return up to top_k entries whose similarity to query exceeds the threshold.

        Args:
            query: free-text query.
            top_k: maximum number of hits; defaults to config['rag_config']['top_k'].

        Returns:
            List of {'content', 'metadata', 'score', 'type'} dicts, best first.
            Empty when the knowledge base is empty or top_k <= 0.
        """
        if top_k is None:
            top_k = self.config['rag_config']['top_k']
        if self.embeddings is None or len(self.embeddings) == 0 or top_k <= 0:
            return []
        if self.embedding_model is None:
            # e.g. after load_index(), which restores vectors but not the encoder.
            self.load_embedding_model()
        top_k = min(top_k, len(self.embeddings))

        query_embedding = self.embedding_model.encode(
            [query],
            normalize_embeddings=True,
            convert_to_numpy=True,
            device=self.config['device']
        )

        if FAISS_AVAILABLE and self.index is not None:
            scores, indices = self.index.search(
                np.ascontiguousarray(query_embedding, dtype=np.float32),
                top_k
            )
            score_idx_pairs = list(zip(scores[0], indices[0]))
        else:
            similarities = self._cosine_scores(query_embedding[0])
            top_indices = np.argsort(similarities)[::-1][:top_k]
            score_idx_pairs = list(zip(similarities[top_indices], top_indices))

        threshold = self.config['rag_config']['similarity_threshold']
        results = []
        for score, idx in score_idx_pairs:
            idx = int(idx)
            # FAISS pads missing hits with index -1; `not >` also drops NaN scores.
            if idx < 0 or idx >= len(self.knowledge_base) or not score > threshold:
                continue
            entry = self.knowledge_base[idx]
            results.append({
                'content': entry['content'],
                'metadata': entry['metadata'],
                'score': float(score),
                'type': entry['type']
            })

        return results

    def save_index(self, path="legal_rag_index"):
        """Persist the FAISS index (if any), knowledge base + config, and embeddings.

        Files written: {path}_faiss.index (FAISS only), {path}_knowledge.pkl,
        {path}_embeddings.npy.
        """
        print(f"保存RAG索引到 {path}...")

        if FAISS_AVAILABLE and self.index is not None:
            faiss.write_index(self.index, f"{path}_faiss.index")

        with open(f"{path}_knowledge.pkl", 'wb') as f:
            pickle.dump({
                'knowledge_base': self.knowledge_base,
                'config': self.config,
                'faiss_available': FAISS_AVAILABLE
            }, f)

        np.save(f"{path}_embeddings.npy", self.embeddings)

        print("RAG索引保存完成")

    def load_index(self, path="legal_rag_index"):
        """Restore what save_index() wrote. Returns True on success, False otherwise.

        SECURITY: uses pickle.load — only load files produced by this notebook;
        unpickling untrusted data can execute arbitrary code.
        """
        print(f"从 {path} 加载RAG索引...")

        try:
            with open(f"{path}_knowledge.pkl", 'rb') as f:
                data = pickle.load(f)
            self.knowledge_base = data['knowledge_base']
            saved_faiss_available = data.get('faiss_available', True)

            self.embeddings = np.load(f"{path}_embeddings.npy")

            if FAISS_AVAILABLE and saved_faiss_available:
                try:
                    self.index = faiss.read_index(f"{path}_faiss.index")
                    print("FAISS索引加载完成")
                except Exception:
                    # search() then uses the NumPy cosine fallback.
                    print("FAISS索引文件不存在，将使用sklearn")
                    self.index = None
            else:
                print("使用sklearn进行向量搜索")
                self.index = None

            print("RAG索引加载完成")
            return True

        except Exception as e:
            print(f"索引加载失败: {e}")
            return False

# Create the RAG system used by later cells.
rag_system = LegalRAGSystem()

# Build the knowledge base (entries, embeddings, index) from the cleaned data.
knowledge_base = rag_system.build_knowledge_base(processed_data)

# Persist index + knowledge base so later sessions can call load_index().
rag_system.save_index("legal_rag_index")

print("\nRAG知识库构建完成！")
print(f"知识条目: {len(knowledge_base)}")
print(f"支持实时检索相关法律条文、案例和依据")

In [None]:
from transformers import (
    AutoTokenizer, AutoModelForCausalLM, AutoConfig,
    TrainingArguments, Trainer, DataCollatorForLanguageModeling
)
from peft import (
    LoraConfig, get_peft_model, TaskType,
    prepare_model_for_kbit_training
)
from datasets import Dataset, load_dataset
from sklearn.model_selection import train_test_split
import torch

class LegalLoRATrainer:
    """LoRA fine-tuning wrapper for the labour-compliance causal LM.

    Flow: load_model_and_tokenizer() -> setup_lora() -> train(training_data).
    """

    # Config / generation-config fields that control sliding-window attention.
    _SWA_KEYS = ('sliding_window', 'sliding_window_size', 'use_sliding_window')

    def __init__(self, config=LINUX_CONFIG):
        self.config = config
        self.model_config = config['model_configs'][config['config_type']]
        self.lora_config = config['lora_config']
        self.training_config = config['training_config']

        self.tokenizer = None
        self.model = None
        self.peft_model = None

    @classmethod
    def _disable_sliding_window(cls, cfg):
        """Turn off sliding-window attention on a config-like object, for the fields it has."""
        for key in cls._SWA_KEYS:
            if hasattr(cfg, key):
                setattr(cfg, key, False if key == 'use_sliding_window' else None)

    def load_model_and_tokenizer(self):
        """Load the tokenizer and the base model used for training.

        Sliding-window attention is disabled and eager attention requested, both
        before and after loading. The model is loaded without quantisation
        (fp16 on CUDA, fp32 otherwise) and moved to config['device'].

        Raises:
            Exception: whatever tokenizer/model loading raised, after logging it,
                so that later steps do not fail with an unrelated error.
        """
        print("加载基础模型和分词器...")

        model_name = self.model_config['model_name']
        device = self.config['device']

        try:
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                trust_remote_code=True,
                padding_side='left'
            )
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            # 1) Config: disable SWA and request eager attention before loading.
            model_cfg = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
            self._disable_sliding_window(model_cfg)
            setattr(model_cfg, 'attn_implementation', 'eager')

            # 2) Load arguments. device_map=None: Trainer manages device placement.
            # No 4-bit quantisation here, to avoid device-management conflicts.
            model_kwargs = {
                'trust_remote_code': True,
                'device_map': None,
                'torch_dtype': torch.float16 if device == 'cuda' else torch.float32,
                'low_cpu_mem_usage': True,
            }

            # 3) Load the model once and move it explicitly.
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name,
                config=model_cfg,
                attn_implementation='eager',
                **model_kwargs
            )
            if device in ('cuda', 'cpu'):
                self.model.to(device)

            # 4) Re-apply after loading in case the model code overrode the config.
            self._disable_sliding_window(self.model.config)
            if hasattr(self.model, 'generation_config'):
                self._disable_sliding_window(self.model.generation_config)

            print(f"模型加载完成: {model_name}")

            if device == 'cuda':
                allocated = torch.cuda.memory_allocated() / 1e9
                cached = torch.cuda.memory_reserved() / 1e9
                print(f"GPU内存使用: {allocated:.1f}GB (缓存: {cached:.1f}GB)")

        except Exception as e:
            print(f"模型加载失败: {e}")
            raise

    def setup_lora(self):
        """Wrap self.model with LoRA adapters and return the PEFT model.

        prepare_model_for_kbit_training is applied only when the model really was
        loaded in 4/8-bit. On a full-precision model it would upcast every fp16
        weight to fp32 and roughly double memory use. For a full-precision model
        with gradient checkpointing enabled, input embeddings are made to require
        grad so backward works through the frozen base.

        Raises:
            RuntimeError: if load_model_and_tokenizer() has not produced a model.
        """
        print("配置LoRA微调...")

        if self.model is None:
            raise RuntimeError("模型未加载，请先调用 load_model_and_tokenizer()")

        lora_config = LoraConfig(
            r=self.lora_config['r'],
            lora_alpha=self.lora_config['lora_alpha'],
            lora_dropout=self.lora_config['lora_dropout'],
            target_modules=self.lora_config['target_modules'],
            bias=self.lora_config['bias'],
            task_type=TaskType.CAUSAL_LM,
        )

        is_kbit = (getattr(self.model, 'is_loaded_in_4bit', False)
                   or getattr(self.model, 'is_loaded_in_8bit', False))
        if is_kbit:
            self.model = prepare_model_for_kbit_training(self.model)
        elif self.training_config['gradient_checkpointing'] and hasattr(self.model, 'enable_input_require_grads'):
            self.model.enable_input_require_grads()

        self.peft_model = get_peft_model(self.model, lora_config)

        # Keep trainable (adapter) weights in fp32 so fp16 AMP can unscale their
        # gradients; some peft versions create adapters in the base model's dtype.
        for param in self.peft_model.parameters():
            if param.requires_grad and param.dtype in (torch.float16, torch.bfloat16):
                param.data = param.data.float()

        trainable_params = sum(p.numel() for p in self.peft_model.parameters() if p.requires_grad)
        total_params = sum(p.numel() for p in self.peft_model.parameters())

        print(f"LoRA配置完成:")
        print(f"  可训练参数: {trainable_params:,} ({100 * trainable_params / max(total_params, 1):.2f}%)")
        print(f"  总参数: {total_params:,}")
        print(f"  LoRA秩: {self.lora_config['r']}")

        return self.peft_model

    def prepare_training_data(self, training_data):
        """Build tokenised train/validation datasets with answer-only labels.

        Each item's 'input' is used verbatim as the user turn (it already contains
        the analysis instruction) and 'output' as the assistant turn. Labels are
        -100 on padding (attention_mask == 0) and on the prompt tokens (system +
        user + assistant header), so the loss covers only the answer. Samples
        whose prompt fills max_length end up with no supervised tokens.

        Args:
            training_data: non-empty list of dicts with 'input' and 'output' keys.

        Returns:
            (train_dataset, val_dataset): tokenised `datasets.Dataset` objects with
            input_ids / attention_mask / labels padded to max_length.

        Raises:
            ValueError: if training_data is empty.
        """
        print("准备训练和验证数据...")
        if not training_data:
            raise ValueError("训练数据为空")

        system_prompt = "你是经验丰富的劳动法律师，回答要专业、结构化。"
        # apply_chat_template needs a template on the tokenizer, not just the method.
        use_chat_template = bool(getattr(self.tokenizer, 'chat_template', None))
        max_length = self.model_config['max_length']

        def build_chat_sample(instruction, response):
            """Return (prompt_text, full_text), where full_text = prompt + answer."""
            if use_chat_template:
                messages = [
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": instruction},
                ]
                prompt = self.tokenizer.apply_chat_template(
                    messages, tokenize=False, add_generation_prompt=True)
                full = self.tokenizer.apply_chat_template(
                    messages + [{"role": "assistant", "content": response}],
                    tokenize=False, add_generation_prompt=False)
                return prompt, full
            # Plain fallback template; EOS teaches the model where to stop.
            prompt = f"[系统] 你是经验丰富的劳动法律师。\n[用户] {instruction}\n[助手] "
            return prompt, prompt + response + (self.tokenizer.eos_token or "")

        samples = [build_chat_sample(item['input'], item['output']) for item in training_data]

        if len(samples) < 2:
            # train_test_split cannot split a single sample; validate on it as well.
            train_samples, val_samples = samples, samples
        else:
            train_samples, val_samples = train_test_split(samples, test_size=0.1, random_state=42)
        print(f"训练集: {len(train_samples)}样本, 验证集: {len(val_samples)}样本")

        def tokenize_with_mask(example):
            enc = self.tokenizer(
                example['text'],
                truncation=True,
                padding="max_length",
                max_length=max_length,
            )
            prompt_len = len(self.tokenizer(
                example['prompt'], truncation=True, max_length=max_length)['input_ids'])
            attention = enc.get('attention_mask') or [1] * len(enc['input_ids'])

            labels = []
            content_pos = 0  # position among real (non-pad) tokens; works for left padding
            for token_id, is_real in zip(enc['input_ids'], attention):
                if not is_real:
                    labels.append(-100)
                    continue
                labels.append(token_id if content_pos >= prompt_len else -100)
                content_pos += 1
            enc['labels'] = labels
            return enc

        def to_dataset(pairs):
            return Dataset.from_dict({
                'prompt': [prompt for prompt, _ in pairs],
                'text': [full for _, full in pairs],
            })

        train_tokenized = to_dataset(train_samples).map(
            tokenize_with_mask, batched=False, remove_columns=['prompt', 'text'])
        val_tokenized = to_dataset(val_samples).map(
            tokenize_with_mask, batched=False, remove_columns=['prompt', 'text'])

        print(f"数据准备完成: 训练集 {len(train_tokenized)}, 验证集 {len(val_tokenized)}")
        return train_tokenized, val_tokenized

    def train(self, training_data, output_dir="./legal_lora_model"):
        """Fine-tune self.peft_model and save the best model + tokenizer.

        Args:
            training_data: list of {'input', 'output', ...} dicts.
            output_dir: checkpoint directory; the final/best model and tokenizer
                are written to f"{output_dir}/best_model".

        Returns:
            The transformers Trainer used for the run.

        Raises:
            RuntimeError: if setup_lora() has not been called.
            Any exception from training/saving is logged and re-raised.
        """
        import importlib.util
        from transformers import default_data_collator

        if self.peft_model is None:
            raise RuntimeError("LoRA模型未初始化，请先调用 setup_lora()")

        print("开始LoRA微调训练...")

        train_dataset, eval_dataset = self.prepare_training_data(training_data)

        # TensorBoard logging needs the tensorboard package; skip it if absent.
        report_to = "tensorboard" if importlib.util.find_spec("tensorboard") is not None else "none"

        # NOTE: `eval_strategy` is the newer transformers name (older releases
        # call it `evaluation_strategy`).
        training_args = TrainingArguments(
            output_dir=output_dir,
            num_train_epochs=self.training_config['num_epochs'],
            per_device_train_batch_size=self.model_config['batch_size'],
            gradient_accumulation_steps=self.training_config['gradient_accumulation_steps'],
            warmup_steps=self.training_config['warmup_steps'],
            learning_rate=self.training_config['learning_rate'],
            weight_decay=0.05,
            logging_dir=f"{output_dir}/logs",
            logging_steps=10,
            eval_strategy="steps",
            eval_steps=50,
            save_strategy="steps",
            save_steps=50,
            load_best_model_at_end=True,
            metric_for_best_model="eval_loss",
            greater_is_better=False,
            fp16=self.training_config['fp16'] and torch.cuda.is_available(),
            gradient_checkpointing=self.training_config['gradient_checkpointing'],
            dataloader_num_workers=self.training_config['dataloader_num_workers'],
            remove_unused_columns=False,
            label_names=["labels"],  # a TrainingArguments field, not a Trainer kwarg
            report_to=report_to,
        )

        # Samples are already padded to max_length and carry their own masked
        # labels. DataCollatorForLanguageModeling(mlm=False) would overwrite them
        # from input_ids, so use the plain collator that keeps them.
        trainer = Trainer(
            model=self.peft_model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            data_collator=default_data_collator,
        )

        print("开始训练，利用多线程数据加载和验证...")

        try:
            trainer.train()

            best_model_path = f"{output_dir}/best_model"
            trainer.save_model(best_model_path)
            self.tokenizer.save_pretrained(best_model_path)

            print(f"训练完成，最佳模型已保存到: {best_model_path}")

        except Exception as e:
            print(f"训练失败: {e}")
            raise

        return trainer

# Build the trainer, load the base model and attach LoRA adapters.
print("初始化LoRA训练器...")
lora_trainer = LegalLoRATrainer()

print("加载模型和分词器...")
lora_trainer.load_model_and_tokenizer()

print("设置LoRA配置...")
peft_model = lora_trainer.setup_lora()

print("\nLoRA微调模块准备完成！")
print("可以继续运行下一个Cell进行RAG+LoRA集成")

# Training runs by default (the inference cell expects the adapter under
# ./legal_lora_model_v2/best_model). Set to False to skip the long run.
RUN_LORA_TRAINING = True
trainer = None
if RUN_LORA_TRAINING:
    print("如需跳过训练，请将 RUN_LORA_TRAINING 设为 False")
    trainer = lora_trainer.train(training_data, output_dir="./legal_lora_model_v2")
else:
    print("已跳过训练（RUN_LORA_TRAINING = False）")

In [None]:
import torch
from peft import PeftModel
from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
import gradio as gr
import gc

class LegalAnalysisSystem:
    """Review labour-contract clauses with RAG retrieval plus a LoRA-fine-tuned causal LM.

    Usage:
        system = LegalAnalysisSystem(lora_model_path, rag_system)
        system.load_finetuned_model()
        analysis, knowledge = system.analyze_clause("...")
    """

    def __init__(self, lora_model_path, rag_system, config=LINUX_CONFIG):
        """
        Args:
            lora_model_path: Directory holding the trained LoRA adapter and the
                tokenizer that was saved next to it.
            rag_system: Retrieval object. Must provide ``load_embedding_model()``
                and ``search(query, top_k=...)``; search results are iterated as
                dicts with ``'type'`` and ``'content'`` keys.
            config: Runtime configuration dict. Reads ``'device'``,
                ``'config_type'`` and ``'model_configs'[config_type]``
                (``'model_name'``, ``'use_4bit'``). The default object is bound
                once, when the class is defined.
        """
        self.config = config
        self.device = config['device']
        self.tokenizer = None
        self.model = None
        self.rag_system = rag_system
        self.lora_model_path = lora_model_path

        # Load the RAG embedding model up front.
        self.rag_system.load_embedding_model()

    def load_finetuned_model(self):
        """Load the base model, apply the LoRA adapter and merge it for inference.

        The tokenizer is read from ``lora_model_path``. The base model name and
        the 4-bit quantisation flag come from
        ``config['model_configs'][config['config_type']]``.

        Raises:
            Exception: any loading error is printed and then re-raised.
        """
        print(f"从 {self.lora_model_path} 加载微调后的模型...")

        try:
            model_cfg = self.config['model_configs'][self.config['config_type']]
            base_model_name = model_cfg['model_name']

            # The tokenizer is saved alongside the adapter by the training step.
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.lora_model_path,
                trust_remote_code=True
            )

            model_kwargs = {
                'trust_remote_code': True,
                'device_map': 'auto',
                'torch_dtype': torch.float16,
                'low_cpu_mem_usage': True,
            }
            if model_cfg['use_4bit']:
                from transformers import BitsAndBytesConfig
                model_kwargs['quantization_config'] = BitsAndBytesConfig(
                    load_in_4bit=True,
                    bnb_4bit_quant_type="nf4",
                    bnb_4bit_compute_dtype=torch.float16,
                    bnb_4bit_use_double_quant=True,
                )

            base_model = AutoModelForCausalLM.from_pretrained(
                base_model_name,
                **model_kwargs
            )

            # Attach the LoRA weights, then merge them into the base weights
            # so inference runs on a plain model.
            self.model = PeftModel.from_pretrained(base_model, self.lora_model_path)
            self.model = self.model.merge_and_unload()
            self.model.eval()

            print("微调模型加载并合并完成！")

        except Exception as e:
            print(f"模型加载失败: {e}")
            raise

    @staticmethod
    def _format_knowledge(retrieved_knowledge):
        """Render retrieved items as a numbered '背景知识' list; "" when nothing was retrieved."""
        if not retrieved_knowledge:
            return ""
        lines = ["背景知识：\n"]
        for idx, item in enumerate(retrieved_knowledge, start=1):
            lines.append(f"{idx}. [来源: {item['type']}] {item['content']}\n")
        return "".join(lines)

    def analyze_clause(self, clause_text, top_k=3):
        """Analyse one labour-contract clause using RAG context and the fine-tuned model.

        Args:
            clause_text: The clause to review.
            top_k: Number of knowledge items to request from the RAG system.

        Returns:
            ``(analysis, knowledge_context)``. ``analysis`` is only the text the
            model generated (the prompt is excluded). ``knowledge_context`` is
            the formatted background knowledge inserted into the prompt ("" when
            nothing was retrieved). If the model has not been loaded, returns
            ``("模型未加载", "")``.
        """
        if self.model is None or self.tokenizer is None:
            print("模型未加载，请先调用 load_finetuned_model()")
            return "模型未加载", ""

        print(f"正在分析条款: '{clause_text}'")

        # 1. Retrieve related knowledge with RAG.
        print("   - 使用RAG检索相关知识...")
        retrieved_knowledge = self.rag_system.search(clause_text, top_k=top_k)

        if retrieved_knowledge:
            print(f"   - 检索到 {len(retrieved_knowledge)} 条相关知识")
        else:
            print("   - 未检索到直接相关的知识")
        knowledge_context = self._format_knowledge(retrieved_knowledge)

        # 2. Build the model prompt.
        prompt = f"""
[指令]
你是一位经验丰富的劳动法律师。请根据以下背景知识，分析劳动合同中的条款是否存在法律风险。

背景知识:
{knowledge_context if knowledge_context else "无"}

待分析条款:
"{clause_text}"

请提供详细的风险分析、相关案例（如果有）、法律依据和改进建议。

[回答]
"""

        # 3. Generate the analysis.
        print("   - 模型生成分析结果...")
        # The model was loaded with device_map='auto', so its weights need not
        # be on config['device']; send the inputs to the model's own device.
        device = getattr(self.model, 'device', self.device)
        inputs = self.tokenizer(prompt, return_tensors="pt").to(device)
        prompt_length = inputs['input_ids'].shape[-1]

        # Echoes the generated text to the console while it is produced.
        streamer = TextStreamer(self.tokenizer, skip_prompt=True)

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=1024,
                temperature=0.7,
                top_p=0.9,
                repetition_penalty=1.1,
                do_sample=True,
                streamer=streamer,
                pad_token_id=self.tokenizer.eos_token_id
            )

        # Decode only the newly generated tokens. Splitting the full decoded
        # text on '[回答]' picks the wrong segment when the clause or the
        # retrieved knowledge itself contains that marker.
        response = self.tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)

        # Release per-request tensors.
        del inputs, outputs
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        return response.strip(), knowledge_context

# --- Interactive testing ---
# Assumes earlier cells have run: `rag_system` is built and the LoRA adapter
# has been trained and saved under FINETUNED_MODEL_PATH.

# Location of the fine-tuned adapter + tokenizer (adjust to your environment).
FINETUNED_MODEL_PATH = "./legal_lora_model_v2/best_model" 

# Build the analysis system on top of the RAG system constructed earlier.
analysis_system = LegalAnalysisSystem(
    lora_model_path=FINETUNED_MODEL_PATH,
    rag_system=rag_system,
)

# load_finetuned_model() prints the underlying error itself before re-raising;
# this cell only adds a hint and keeps running so the UI can still be built.
try:
    analysis_system.load_finetuned_model()
except Exception:
    print(f"无法加载模型，请确认路径 '{FINETUNED_MODEL_PATH}' 是否正确，并且模型已训练。")
else:
    print("模型加载成功，可以开始测试。")


def gradio_interface(clause):
    """Gradio callback: analyse ``clause`` and return ``(analysis, retrieved_knowledge)``.

    Returns a placeholder pair, without calling the model, when the model
    failed to load or the input is blank. If nothing was retrieved, the
    knowledge output is "无", matching the other placeholder paths.
    """
    if getattr(analysis_system, 'model', None) is None:
        return "模型未成功加载，请检查后台日志。", "无"

    # An empty textbox arrives as "" (possibly None); skip retrieval + generation.
    if not clause or not clause.strip():
        return "请输入要分析的劳动合同条款。", "无"

    analysis_result, retrieved_info = analysis_system.analyze_clause(clause)

    return analysis_result, retrieved_info or "无"

# Build the Gradio app: one clause textbox in; the model's Markdown analysis
# and the retrieved RAG knowledge out; a few sample clauses as examples.
iface = gr.Interface(
    fn=gradio_interface,
    inputs=gr.Textbox(
        lines=5,
        label="请输入要分析的劳动合同条款",
        placeholder="例如：员工同意公司有权根据经营需要，随时调整其工作岗位和工作地点。",
    ),
    outputs=[
        gr.Markdown(label="模型分析结果"),
        gr.Textbox(
            lines=8,
            label="RAG检索到的相关知识",
            interactive=False,
        ),
    ],
    title="智能劳动法合规审查系统",
    description="本系统结合了LoRA微调模型和RAG知识库，可以对劳动合同条款进行智能分析，并提供法律依据和案例参考。",
    examples=[
        [sample_clause]
        for sample_clause in (
            "员工试用期为6个月，期间工资为正式工资的60%。",
            "所有员工必须无条件接受加班安排，否则按旷工处理。",
            "员工在职期间产生的所有知识产权，无论是否与工作相关，均归公司所有。",
        )
    ],
    theme=gr.themes.Soft(),
)

# Launch the Gradio web UI only when explicitly requested. `share=True`
# exposes the app through a public Gradio share link, so keep it opt-in.
LAUNCH_GRADIO_UI = False
if LAUNCH_GRADIO_UI:
    iface.launch(share=True)
print("\n交互式测试单元格准备就绪！")
print("如果需要启动Web界面进行测试，请将 `LAUNCH_GRADIO_UI` 设为 True 后重新运行本单元格。")

