In [1]:
import gradio as gr
import pandas as pd
import re
import time
import os
import torch
from transformers import AutoTokenizer, pipeline
from unsloth import FastLanguageModel
from peft import PeftModel
import gc

# Configuration settings
# Identifier of the base model passed to FastLanguageModel.from_pretrained.
base_model_name = "unsloth/Qwen3-14B"
# Location of the PEFT adapter; the tokenizer is loaded from the same place.
adapter_path = ""
# Directory that receives the results CSV. An empty value means the current
# working directory.
output_dir = r""
# os.makedirs("") raises FileNotFoundError, so only create the directory when
# one is actually configured.
if output_dir:
    os.makedirs(output_dir, exist_ok=True)
output_csv_path = os.path.join(output_dir, "supply_chain_accidents_data.csv")

# Model and tokenizer loading
def load_model():
    """Load the base model, attach the PEFT adapter and load the tokenizer.

    The base model is read from ``base_model_name``; the adapter and the
    tokenizer are both read from ``adapter_path``.

    Returns:
        A ``(model, tokenizer)`` pair on success, or ``(None, None)`` when any
        step raises. The failure is reported on stdout.
    """
    print("⏳ Loading model...")
    try:
        base_kwargs = {
            "model_name": base_model_name,
            "max_seq_length": 2048,
            "load_in_4bit": True,
            "device_map": "auto",
        }
        # The second element of the returned pair is discarded; the tokenizer
        # is loaded separately from the adapter location below.
        backbone, _ = FastLanguageModel.from_pretrained(**base_kwargs)

        peft_model = PeftModel.from_pretrained(
            backbone, adapter_path, adapter_name="accident_cause_adapter"
        )
        peft_model.eval()

        tok = AutoTokenizer.from_pretrained(
            adapter_path, padding_side="left", truncation_side="left"
        )
        # Reuse the end-of-sequence token for padding.
        tok.pad_token = tok.eos_token

        print("✅ Model loaded successfully")
        return peft_model, tok
    except Exception as exc:
        print(f"❌ Model loading failed: {exc}")
        return None, None

# Inference pipeline construction
def create_pipeline(model, tokenizer):
    """Wrap a model and tokenizer in a ``text-generation`` pipeline.

    Args:
        model: The loaded model, or ``None`` if loading failed.
        tokenizer: The matching tokenizer, or ``None`` if loading failed.

    Returns:
        The pipeline object, or ``None`` when either argument is ``None`` or
        when building the pipeline raises (the error is printed).
    """
    if not (model is not None and tokenizer is not None):
        return None

    print("⚙️ Creating inference pipeline...")
    try:
        # No explicit device argument is passed here; placement is left to
        # however the model was loaded.
        generator = pipeline("text-generation", model=model, tokenizer=tokenizer)
    except Exception as exc:
        print(f"❌ Pipeline creation failed: {exc}")
        return None
    return generator

# Module-level model, tokenizer and pipeline instances, created once when this
# code runs. Each of them is None if loading failed.
model, tokenizer = load_model()
pipe = create_pipeline(model=model, tokenizer=tokenizer)

# CSV persistence
def save_to_csv(data):
    """Write one result record to the CSV file at ``output_csv_path``.

    When the file does not exist yet it is created with a header row;
    otherwise the record is appended without a header.

    Args:
        data: Mapping of column name to value describing a single row.
    """
    frame = pd.DataFrame([data])

    if os.path.exists(output_csv_path):
        frame.to_csv(
            output_csv_path,
            mode='a',
            header=False,
            index=False,
            encoding='utf_8_sig',
        )
        print(f"📝 Appended data to CSV file")
        return

    frame.to_csv(output_csv_path, index=False, encoding='utf-8-sig')
    print(f"📄 Created new CSV file: {output_csv_path}")

# Parse model response
def _find_cause(label, text):
    """Return the value written after ``label`` and a colon, or ``None``.

    Both the ASCII ``:`` and the full-width ``：`` colon are accepted. The
    value ends at the end of the line or just before the next cause label, so
    one-label-per-line answers and single-line answers are both handled.
    """
    pattern = (
        rf"{re.escape(label)}\s*[:：]\s*(.*?)"
        r"(?=[，,；;。\s]*(?:直接原因|间接原因)\s*[:：]|\n|$)"
    )
    match = re.search(pattern, text)
    return match.group(1).strip() if match else None


def parse_response(response):
    """Split a raw model response into thought chain and cause fields.

    The thought chain is the text between ``<think>`` and ``</think>``. The
    causes are looked up by their ``直接原因`` / ``间接原因`` labels, first in
    the text that follows the closing ``</think>`` tag and, if a label is not
    found there, in the whole response. Searching the answer section first
    keeps labels that are merely mentioned inside the thought chain from
    shadowing the final answer.

    Args:
        response: Generated text to parse.

    Returns:
        A ``(think_chain, direct_cause, indirect_cause)`` tuple of strings.
        A missing field is replaced by its "not found" placeholder. If
        ``response`` is not a string, all three fields are ``"解析错误"``.
    """
    try:
        think_match = re.search(r"<think>(.*?)</think>", response, re.DOTALL)
        if think_match:
            think_chain = think_match.group(1).strip()
            answer_text = response[think_match.end():]
        else:
            think_chain = "未找到思维链"
            answer_text = response

        causes = []
        for label, placeholder in (
            ("直接原因", "未找到直接原因"),
            ("间接原因", "未找到间接原因"),
        ):
            value = _find_cause(label, answer_text)
            if value is None and think_match:
                # Fall back to the full text so a label that only occurs
                # inside the thought chain is still reported.
                value = _find_cause(label, response)
            causes.append(placeholder if value is None else value)

        return think_chain, causes[0], causes[1]
    except TypeError as e:
        # re.search raises TypeError when response is not a str.
        print(f"Parsing error: {str(e)}")
        return "解析错误", "解析错误", "解析错误"

# Generate model response
def generate_response(input_text):
    """Run the accident-cause model on one description and record the result.

    The description is inserted into the instruction prompt, passed to the
    module-level ``pipe``, and the generated text is parsed with
    ``parse_response``. A successful analysis is also written to the results
    CSV via ``save_to_csv``.

    Args:
        input_text: Accident description entered by the user.

    Returns:
        A 4-tuple ``(think_chain, direct_cause, indirect_cause, seconds)``.
        For blank input or a missing pipeline the text fields carry a message
        and ``seconds`` is ``0``. If generation or parsing raises, the first
        field carries the error text and the two cause fields are empty.
    """
    # Validate first: a blank description would otherwise cost a full
    # generation run and add a meaningless row to the CSV.
    if not input_text or not input_text.strip():
        return "请输入事故描述", "", "", 0

    if pipe is None:
        return "模型未加载", "模型未加载", "模型未加载", 0

    start_time = time.time()

    prompt_template = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
你是一个事故因果推理专家，需要根据input中的事故过程从给出的直接原因和间接原因分类表中推理出事故对应的直接原因和间接原因，输出只能为：思维链。直接原因:直接原因列表，间接原因:间接原因列表。直接原因分类表有:车辆操作安全意识淡薄,超速与载重合规性,高风险路段驾驶行为不合规,重型车辆操作规范,车辆行驶稳定性管理,车间装卸货物规范安全操作,工地卸载货物操作不规范.间接原因分类表有:施工现场人员安全管理体系不健全,运输过程合规性监管失效,重型车辆驾驶员运输培训不到位,企业管理车辆驾驶规范责任失效。

### Input:
{}

### Response:
"""
    formatted_prompt = prompt_template.format(input_text)

    generation_config = {
        "max_new_tokens": 2048,
        "temperature": 0.6,
        "top_p": 0.9,
        "top_k": 30,
        "do_sample": True,
        "pad_token_id": tokenizer.eos_token_id,
        "repetition_penalty": 1.2
    }

    try:
        outputs = pipe(
            formatted_prompt,
            **generation_config,
            return_full_text=False
        )
        response = outputs[0]['generated_text'].strip()

        # Keep only the text after the last response marker, in case the
        # prompt scaffold is echoed in the output.
        if "### Response:" in response:
            response = response.split("### Response:")[-1].strip()

        # The cause labels are left intact on purpose: parse_response()
        # locates each cause by its label, so stripping a label here would
        # make that cause unfindable.
        think_chain, direct_cause, indirect_cause = parse_response(response)
    except Exception as e:
        print(f"Generation error: {str(e)}")
        elapsed_time = round(time.time() - start_time, 2)
        return f"错误: {str(e)}", "", "", elapsed_time

    elapsed_time = round(time.time() - start_time, 2)

    # Persisting is best-effort: a failed write (for example a locked file)
    # must not throw away an analysis that has already succeeded.
    try:
        save_to_csv({
            "事故描述": input_text,
            "思维链": think_chain,
            "直接原因": direct_cause,
            "间接原因": indirect_cause,
            "响应时间(秒)": elapsed_time,
            "时间戳": time.strftime("%Y-%m-%d %H:%M:%S")
        })
    except Exception as e:
        print(f"CSV save error: {str(e)}")

    return think_chain, direct_cause, indirect_cause, elapsed_time

# Gradio interface
# Builds the `demo` Blocks app: the left column holds the description input,
# the two buttons and a collapsed examples panel; the right column holds the
# read-only result fields. Components are laid out in the order they are
# created inside the nested Row/Column contexts.
with gr.Blocks(title="Construction Supply Chain Accident Analysis System", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🏗️ Construction Supply Chain Accident Analysis System")
    gr.Markdown("### Enter accident description to analyze direct and indirect causes")
    
    with gr.Row():
        # Left column: input, actions and examples
        with gr.Column():
            accident_input = gr.Textbox(
                label="Accident Description",
                placeholder="Please enter the description of the accident during the construction supply chain transportation phase...",
                lines=5,
                max_lines=10
            )
            submit_btn = gr.Button("Analyze Accident", variant="primary")
            clear_btn = gr.Button("Clear Input")
            
            # Sample descriptions that are wired to accident_input
            with gr.Accordion("Example Inputs", open=False):
                gr.Examples(
                    examples=[
                        ["A heavy-duty truck transporting construction piles lost control and overturned on a curve. The driver had been working continuously for 12 hours and was speeding. There were signs that the fixing devices for the five piles loaded on the vehicle were loose. Post-incident inspection revealed a 0.8-second delay in the braking system response, aging sensors causing deviations in braking force distribution, and evidence of illegal modifications to the onboard speed-limiting module."],
                        ["During unloading at a construction site, a crane operator failed to observe the surrounding area, resulting in a steel pipe falling and striking a worker operating nearby. The investigation found that the operator lacked formal training, the safety supervisor was absent from duty at the time, and the crane's safety limit device had malfunctioned."]
                    ],
                    inputs=accident_input
                )
        
        # Right column: read-only outputs
        with gr.Column():
            think_output = gr.Textbox(label="Thought Chain Analysis", interactive=False, lines=7)
            with gr.Row():
                direct_output = gr.Textbox(label="Direct Cause", interactive=False)
                indirect_output = gr.Textbox(label="Indirect Cause", interactive=False)
            time_output = gr.Number(label="Response Time (seconds)", interactive=False)
            status_output = gr.Textbox(label="System Status", value="Ready", interactive=False)
    
    # Submit: generate_response returns four values that map, in order, onto
    # the four output components. The chained step then sets the status box to
    # a fixed string; it takes no inputs, so it does not reflect whether
    # generate_response returned a result or an error message.
    submit_btn.click(
        fn=generate_response,
        inputs=accident_input,
        outputs=[think_output, direct_output, indirect_output, time_output]
    ).then(
        fn=lambda: "Analysis Completed",
        outputs=status_output
    )
    
    # Clear: reset the input and every output. The six returned values must
    # stay in the same order as the six components listed in `outputs`.
    clear_btn.click(
        fn=lambda: ["", "", "", "", 0, "Cleared"],
        outputs=[accident_input, think_output, direct_output, indirect_output, time_output, status_output]
    )
    
    # Usage notes rendered at the bottom of the page
    gr.Markdown("""
    **Instructions:**
    1. Enter detailed accident description in the left input box
    2. Click "Analyze Accident" button to get analysis results
    3. Analysis results will be automatically saved to the system database
    4. View standard input format through examples
    
    **Note:** Initial run requires model loading and may take 1-2 minutes
    """)

# Launch application
# Starts the Gradio server on localhost:7860 when run as a script, and always
# releases Python and CUDA memory afterwards, whether or not launch succeeded.
if __name__ == "__main__":
    try:
        print("🚀 Launching Gradio interface...")
        demo.launch(
            server_name="127.0.0.1",
            server_port=7860,
            share=False,
            show_error=True
        )
    except Exception as e:
        print(f"❌ Launch failed: {str(e)}")
    finally:
        # Clean up resources. Collect garbage first so that objects freed by
        # the collector are already released when the CUDA cache is emptied;
        # the reverse order leaves that memory held in the cache.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        print("Resources cleaned up")


Please restructure your imports with 'import unsloth' at the top of your file.
  from unsloth import FastLanguageModel


🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!
⏳ Loading model...


  GPU_BUFFERS = tuple([torch.empty(2*256*2048, dtype = dtype, device = f"cuda:{i}") for i in range(n_gpus)])


==((====))==  Unsloth 2025.5.7: Fast Qwen3 patching. Transformers: 4.51.3.
   \\   /|    NVIDIA GeForce RTX 5090. Num GPUs = 1. Max memory: 31.842 GB. Platform: Windows.
O^O/ \_/ \    Torch: 2.7.0+cu128. CUDA: 12.0. CUDA Toolkit: 12.8. Triton: 3.3.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.30. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Device set to use cuda:0
The model 'PeftModelForCausalLM' is not supported for text-generation. Supported models are ['AriaTextForCausalLM', 'BambaForCausalLM', 'BartForCausalLM', 'BertLMHeadModel', 'BertGenerationDecoder', 'BigBirdForCausalLM', 'BigBirdPegasusForCausalLM', 'BioGptForCausalLM', 'BlenderbotForCausalLM', 'BlenderbotSmallForCausalLM', 'BloomForCausalLM', 'CamembertForCausalLM', 'LlamaForCausalLM', 'CodeGenForCausalLM', 'CohereForCausalLM', 'Cohere2ForCausalLM', 'CpmAntForCausalLM', 'CTRLLMHeadModel', 'Data2VecTextForCausalLM', 'DbrxForCausalLM', 'DeepseekV3ForCausalLM', 'DiffLlamaForCausalLM', 'ElectraForCausalLM', 'Emu3ForCausalLM', 'ErnieForCausalLM', 'FalconForCausalLM', 'FalconMambaForCausalLM', 'FuyuForCausalLM', 'GemmaForCausalLM', 'Gemma2ForCausalLM', 'Gemma3ForConditionalGeneration', 'Gemma3ForCausalLM', 'GitForCausalLM', 'GlmForCausalLM', 'Glm4ForCausalLM', 'GotOcr2ForConditionalGeneration', 'GPT2LMHeadModel', 'GPT2LMHeadModel', 'GPTBigCodeForCausalLM', 'GPTNeoFo

✅ Model loaded successfully
⚙️ Creating inference pipeline...
🚀 Launching Gradio interface...
* Running on local URL:  http://127.0.0.1:7860
* To create a public link, set `share=True` in `launch()`.


Resources cleaned up
