In [None]:
from unsloth import FastLanguageModel  
import torch  

# --- Base model configuration -------------------------------------------
# Context length used when loading the model and, later, for SFT training.
max_seq_length = 2048
# None defers the compute dtype choice to unsloth (presumably auto-detected
# bf16/fp16 - confirm against the unsloth docs).
dtype = None
# Load the base weights 4-bit quantized.
load_in_4bit = True

# Load the locally cached base model together with its tokenizer.
_model_load_options = dict(
    model_name="./model_cache",
    max_seq_length=max_seq_length,
    dtype=dtype,
    load_in_4bit=load_in_4bit,
)
model, tokenizer = FastLanguageModel.from_pretrained(**_model_load_options)


In [None]:
# Prompt for the baseline (pre-fine-tuning) generation below. Its answer
# section ("Your reasoning:" before <think>) differs from the training
# template, so it only serves as a sanity check of the untuned model.
# Slots, in order: the question, and text pre-filled right after "<think>".
# NOTE(review): the text announces "five categories" and "four fault states"
# but lists one normal state plus three fault states (labels 0.0-3.0).
prompt_style = (
    "The following is a task description along with contextual input.\n"
    "Please provide an appropriate answer, and before answering, perform logical reasoning to support your judgment.\n"
    "\n"
    "### Instruction:\n"
    "You are a professional blast furnace equipment fault diagnosis expert. \n"
    "You have mastered the various states and causes of the blast furnace operation process, and can accurately determine the blast furnace operation state according to the given variables and their values. \n"
    "\n"
    "The operational status of the blast furnace includes the following five categories:\n"
    "- Normal state, label: 0.0\n"
    "- Fault category: Hanging, fault label: 1.0\n"
    "- Fault category: Hot Stove Malfunction, fault label: 2.0\n"
    "- Fault category: Channeling, fault label: 3.0\n"
    "\n"
    "You need to analyze the data according to the causes of various fault states to determine the system is in the normal state or which of the four fault states.\n"
    "\n"
    "### Question:\n"
    "{}\n"
    "\n"
    "### Answer:\n"
    "\n"
    "\n"
    "Your reasoning:\n"
    "<think>{}"
)


import json

# Baseline generation with the untuned model. The question travels inside a
# small JSON document; paste the furnace variables into "Question".
# NOTE(review): with the empty placeholder the model only sees the instruction.
question = '''{
    "Question": ""
}'''
question_text = json.loads(question)["Question"]

# Switch unsloth into inference mode before calling generate().
FastLanguageModel.for_inference(model)

# Second slot (text after "<think>") is left empty so the model reasons freely.
prompt_text = prompt_style.format(question_text, "")
inputs = tokenizer([prompt_text], return_tensors="pt").to("cuda")
outputs = model.generate(
    input_ids=inputs.input_ids,
    attention_mask=inputs.attention_mask,
    use_cache=True,
    max_new_tokens=2650,
)

# Decoded without skipping special tokens; for a decoder-only model the
# generated sequence starts with the prompt tokens.
response = tokenizer.batch_decode(outputs)
print(response[0])

In [None]:
# Training template. Slots, in order: the question, the chain-of-thought
# placed inside <think>...</think>, and the final answer.
# The two "\ufeff" lines reproduce invisible U+FEFF characters that sit on the
# blank separator lines of this template; they are part of what the model is
# trained on, so the inference prompt must contain them too.
# NOTE(review): the instruction announces "five categories" and "four fault
# states" yet lists one normal state plus three faults (labels 0.0-3.0).
# Any wording fix must also be applied to the inference prompt so that the
# evaluated prompt matches the trained one.
train_prompt_style = (
    "The following is a task description along with contextual input.\n"
    "Please provide an appropriate answer, and before answering, perform logical reasoning to support your judgment.\n"
    "\n"
    "### Instruction:\n"
    "You are a professional blast furnace equipment fault diagnosis expert. \n"
    "You have mastered the  various states and causes of the blast furnace operation process, and can accurately determine the blast furnace operation state according to the given variables and their values. \n"
    "\ufeff\n"
    "The operational status of the blast furnace includes the following five categories:\n"
    "- Normal state, label: 0.0\n"
    "- Fault category: Hanging, fault label: 1.0\n"
    "- Fault category: Hot Stove Malfunction, fault label: 2.0\n"
    "- Fault category: Channeling, fault label: 3.0\n"
    "\ufeff\n"
    "You need to analyze the data according to the causes of various fault states to determine the system is in the normal state or which of the four fault states.\n"
    "\n"
    "### Question:\n"
    "{}\n"
    "\n"
    "### Answer:\n"
    "<think>\n"
    "{}\n"
    "\n"
    "</think>\n"
    "{}"
)

# Appended to every formatted training text so each example ends in EOS.
EOS_TOKEN = tokenizer.eos_token

from datasets import load_dataset

# `split="train"` returns a single `Dataset`. Without it `load_dataset`
# returns a `DatasetDict` keyed by split name: `dataset["text"][0]` then
# raises KeyError, and `train_dataset` would receive a dict of splits
# instead of one dataset. ("train" is the default split name assigned to
# plain `data_files`.) The built-in "json" loader runs no dataset script,
# so `trust_remote_code` is not needed.
dataset = load_dataset("json", data_files="/root/traindataset.jsonl", split="train")

print(dataset.column_names)
# Fail early with a clear message rather than a KeyError inside `map`.
_required_columns = {"input", "reasoning", "output"}
_missing_columns = _required_columns - set(dataset.column_names)
if _missing_columns:
    raise ValueError(f"training data is missing columns: {sorted(_missing_columns)}")
def formatting_prompts_func(examples):
    """Render a batch of raw examples into complete training texts.

    Intended for ``Dataset.map(..., batched=True)``, so each column in
    ``examples`` is a list with one entry per example of the batch.

    Args:
        examples: batch mapping with parallel "input", "reasoning" and
            "output" columns.

    Returns:
        ``{"text": [...]}`` with one string per example: ``train_prompt_style``
        filled with question, reasoning and answer, followed by ``EOS_TOKEN``.
    """
    batch = zip(examples["input"], examples["reasoning"], examples["output"])
    return {
        "text": [
            train_prompt_style.format(question, reasoning, answer) + EOS_TOKEN
            for question, reasoning, answer in batch
        ]
    }


dataset = dataset.map(formatting_prompts_func, batched=True)
# Display the first rendered example for a visual check.
dataset["text"][0]


In [None]:
from trl import SFTTrainer
from transformers import TrainingArguments
from unsloth import is_bfloat16_supported

# Back to training mode (for_inference() was called for the baseline run).
FastLanguageModel.for_training(model)

# Candidate sets of projection layers that receive LoRA adapters.
# Only `attn_only` is used below; the other two are kept for ablations.
attn_only = ["q_proj", "k_proj", "v_proj", "o_proj"]   # attention only
ffn_only = ["gate_proj", "up_proj", "down_proj"]        # feed-forward only
attn_ffn = attn_only + ffn_only                         # attention + feed-forward

model = FastLanguageModel.get_peft_model(
    model,
    target_modules=attn_only,
    r=16,                 # LoRA rank
    lora_alpha=32,        # scaling numerator (alpha / r = 2 without rsLoRA)
    lora_dropout=0.1,
    bias="none",
    use_rslora=False,
    loftq_config=None,
    use_gradient_checkpointing="unsloth",
    random_state=3407,
)

# Mixed precision: bf16 where the GPU supports it, fp16 otherwise.
_use_bf16 = is_bfloat16_supported()

training_args = TrainingArguments(
    output_dir="outputs",
    seed=3407,
    # 2 sequences x 4 accumulation steps = 8 sequences per optimizer step.
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    num_train_epochs=30,
    learning_rate=2e-4,
    warmup_steps=10,
    lr_scheduler_type="linear",
    optim="adamw_8bit",
    weight_decay=0.1,
    bf16=_use_bf16,
    fp16=not _use_bf16,
    logging_steps=20,
    report_to="none",
)

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    train_dataset=dataset,
    dataset_text_field="text",   # column produced by formatting_prompts_func
    max_seq_length=max_seq_length,
    dataset_num_proc=2,
    packing=False,
)

trainer_stats = trainer.train()


In [None]:
import re
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, classification_report, confusion_matrix
from datasets import load_dataset
import torch
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.manifold import TSNE
import numpy as np
from sklearn.decomposition import PCA
# `split="train"` yields one `Dataset` whose rows are dicts. Without it,
# `load_dataset` returns a `DatasetDict`; iterating that yields split names
# (the string "train"), so `example["input"]` in the loop below fails.
# "train" is just the default split name given to plain `data_files`; this
# file is still the held-out test set.
eval_dataset = load_dataset("json", data_files="/root/testdataset.jsonl", split="train")


# Inference prompt: the training template cut right after its opening
# "<think>" line, leaving a single `{}` slot for the question. Deriving it
# from `train_prompt_style` keeps evaluation byte-identical to training. A
# hand-copied literal drifts easily: the training template carries invisible
# U+FEFF characters on its blank separator lines, which a copy tends to lose.
_template_head, _think_tag, _ = train_prompt_style.partition("<think>")
if not _think_tag:
    raise ValueError('train_prompt_style has no "<think>" tag to cut the inference prompt at')
prompt_style = _template_head + _think_tag + "\n"
if prompt_style.count("{}") != 1:
    raise ValueError("inference prompt must contain exactly one '{}' slot for the question")


# Per-sample results. Entries are appended only for samples whose true label
# AND predicted label could both be parsed, so these lists and `embeddings`
# stay index-aligned. `y_true`/`y_pred` feed the metrics below;
# `true_labels`/`pred_labels` are parallel copies kept for later analysis.
y_true = []
y_pred = []
pred_labels = []
true_labels = []
# One mean-pooled prompt representation per scored sample.
embeddings = []

# Generation budget for the reasoning + answer of one sample.
MAX_NEW_TOKENS = 2650
# Captures the number after "label:" (e.g. "fault label: 2.0" -> "2.0").
_LABEL_RE = re.compile(r"label\s*:\s*(\d+(\.\d+)?)", re.IGNORECASE)


def _extract_label(text):
    """Return the first number following ``label:`` in ``text`` as a float, or None."""
    match = _LABEL_RE.search(text)
    return float(match.group(1)) if match else None


def _prompt_embedding(inputs):
    """Mean over the sequence of the penultimate layer's hidden states.

    Runs one extra forward pass over the prompt only and returns a flat
    float32 numpy vector. flatten() merges the batch dimension, which is
    fine for the single-prompt batches built in the loop below.
    """
    with torch.no_grad():
        model_output = model(
            input_ids=inputs.input_ids,
            attention_mask=inputs.attention_mask,
            output_hidden_states=True,
        )
    mean_hidden = model_output.hidden_states[-2].mean(dim=1)  # [batch_size, hidden_dim]
    # .float() first: Tensor.numpy() cannot convert bfloat16 tensors.
    return mean_hidden.float().cpu().numpy().flatten()


FastLanguageModel.for_inference(model)

skipped = {"true_label": 0, "think_tag": 0, "pred_label": 0}

for example in tqdm(eval_dataset):
    question = example["input"]
    label = str(example["output"])

    # Parse the ground truth first: a sample without a usable true label can
    # never be scored, so no generation is spent on it.
    true_label = _extract_label(label)
    if true_label is None:
        print("Unable to extract the true label, skipping this sample.")
        skipped["true_label"] += 1
        continue

    inputs = tokenizer([prompt_style.format(question)], return_tensors="pt").to("cuda")
    outputs = model.generate(
        input_ids=inputs.input_ids,
        attention_mask=inputs.attention_mask,
        max_new_tokens=MAX_NEW_TOKENS,
    )

    # Decode only the continuation; the echoed prompt holds no prediction.
    prompt_length = inputs.input_ids.shape[1]
    response = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
    print(f"Response: {response}")

    # The answer is expected after the model closes its reasoning block.
    post_think_match = re.search(r"</think>\s*(.*)", response, re.DOTALL)
    if not post_think_match:
        print("Unable to find the </think> tag, skipping this sample.")
        skipped["think_tag"] += 1
        continue

    pred_label = _extract_label(post_think_match.group(1))
    if pred_label is None:
        print("Failed to extract the predicted label, skipping this sample.")
        skipped["pred_label"] += 1
        continue
    print(f"The extracted predicted label: {pred_label}")
    print(f"The extracted true label: {true_label}")

    # Append everything together so all per-sample lists stay aligned.
    y_true.append(true_label)
    y_pred.append(pred_label)
    true_labels.append(true_label)
    pred_labels.append(pred_label)
    embeddings.append(_prompt_embedding(inputs))

print(f"Scored {len(y_true)} samples; skipped: {skipped}")

# Summarize the parsed predictions: macro-averaged scores, a per-class
# report and a confusion-matrix heatmap over the label values 0.0-3.0.
if y_true and y_pred:
    # sklearn scores on integer class ids; int() truncates, which assumes the
    # labels are whole numbers such as 2.0.
    y_true_int = list(map(int, y_true))
    y_pred_int = list(map(int, y_pred))

    accuracy = accuracy_score(y_true_int, y_pred_int)
    precision, recall, f1 = (
        scorer(y_true_int, y_pred_int, average='macro', zero_division=0)
        for scorer in (precision_score, recall_score, f1_score)
    )

    print("\n Evaluation Results：")
    for metric_name, metric_value in (
        ("Accuracy", accuracy),
        ("Precision", precision),
        ("Recall", recall),
        ("F1-score", f1),
    ):
        print(f"{metric_name}: {metric_value:.4f}")

    target_names = ["0.0", "1.0", "2.0", "3.0"]
    class_values = [float(name) for name in target_names]

    print("\n Classification Report：")
    report = classification_report(
        y_true,
        y_pred,
        labels=class_values,
        target_names=target_names,
        zero_division=0,
    )
    print(report)

    # Rows: actual class, columns: predicted class.
    cm = confusion_matrix(y_true, y_pred, labels=class_values)
    fig, ax = plt.subplots(figsize=(8, 6))
    sns.heatmap(
        cm,
        annot=True,
        fmt='g',
        cmap='Blues',
        xticklabels=target_names,
        yticklabels=target_names,
        ax=ax,
    )
    ax.set_xlabel('Predicted')
    ax.set_ylabel('Actual')
    ax.set_title('Confusion Matrix')
    plt.show()
else:
    print(" No valid label data is available, evaluation metrics cannot be calculated.")
