# In [None]:
import subprocess
import sys
import os
import logging
import torch
import torch.nn as nn
import json
import re
import numpy as np
import threading
import ipywidgets as widgets
from IPython.display import display
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    BitsAndBytesConfig,
    Trainer
)
from peft import LoraConfig, get_peft_model, PeftModel
from datasets import Dataset
from sentence_transformers import SentenceTransformer
import faiss

# Packages this notebook depends on, listed by their importable module names.
required_packages = [
    "torch",
    "transformers",
    "peft",
    "datasets",
    "sentence_transformers",
    "faiss",
    "bitsandbytes",
    "ipywidgets",
    "numpy",
]

def install_packages(packages, pip_names=None):
    """Install every module in *packages* that cannot be imported yet.

    Args:
        packages: Iterable of importable module names (e.g. ``"faiss"``).
        pip_names: Optional mapping from module name to the PyPI distribution
            name to install when the two differ. ``None`` uses a small
            built-in table (``faiss`` -> ``faiss-cpu``).

    Raises:
        subprocess.CalledProcessError: If ``pip install`` fails.
    """
    import importlib
    import site

    if pip_names is None:
        # The importable name is not always the PyPI name: the official FAISS
        # wheels are published as "faiss-cpu" / "faiss-gpu".
        pip_names = {"faiss": "faiss-cpu"}

    for package in packages:
        try:
            importlib.import_module(package)
            print(f"{package} 已安裝")
        except ImportError:
            print(f"正在安裝 {package}...")
            subprocess.check_call(
                [sys.executable, "-m", "pip", "install", pip_names.get(package, package), "--user"]
            )
            # Make the fresh --user install importable in this running
            # process. The user site directory may not have been on sys.path
            # at start-up, and the import system caches directory listings.
            user_site = site.getusersitepackages()
            if user_site not in sys.path:
                site.addsitedir(user_site)
            importlib.invalidate_caches()

install_packages(required_packages)

# Report each required package's version, or flag it as not importable.
print("\n驗證安裝版本：")
for package in required_packages:
    try:
        module = __import__(package)
        version = getattr(module, "__version__", "版本未知")
        message = f"{package}: {version}"
    except ImportError:
        message = f"{package} 安裝失敗，請手動檢查"
    print(message)

# Logging setup: INFO level, "[LEVEL] YYYY-mm-dd HH:MM:SS message" lines.
logging.basicConfig(
    format="[%(levelname)s] %(asctime)s %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

def set_environment():
    """Apply process-wide runtime settings and log the CUDA status."""
    # Disable Hugging Face tokenizers' internal parallelism.
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    # Let cuDNN benchmark and pick kernels for the observed input shapes.
    torch.backends.cudnn.benchmark = True
    cuda_ok = torch.cuda.is_available()
    logger.info(f"CUDA available: {cuda_ok}, Version: {torch.version.cuda}")


set_environment()

def clean_text(text: str) -> str:
    """Return *text* without leading/trailing whitespace; falsy input gives ``""``."""
    if not text:
        return ""
    return text.strip()

def format_with_metadata(example):
    """Return only the cleaned ``prompt`` and ``response`` fields of *example*.

    Each field is passed through ``clean_text``. No tags or other metadata are
    parsed or removed here.

    Raises:
        KeyError: If *example* lacks ``"prompt"`` or ``"response"``.
    """
    return {field: clean_text(example[field]) for field in ("prompt", "response")}

def preprocess_function(example, tokenizer, max_length=512):
    """Tokenize one prompt/response pair for causal-LM supervised fine-tuning.

    The training text is ``"<prompt>\\nAssistant:"`` followed by the response
    and the tokenizer's EOS token. Multi-line prompts are flattened onto one
    line and prefixed with ``[對話歷史]``. Prompt positions are set to -100 in
    ``labels``, so loss is computed only on the response (and EOS).

    Length budget: the response plus EOS keeps at most ``max_length // 2``
    tokens, and the prompt gets the remainder. An over-long prompt keeps its
    END, because that holds the latest turns and the "Assistant:" cue.

    Args:
        example: Mapping with ``"prompt"`` and ``"response"`` text fields.
        tokenizer: Hugging Face tokenizer used for encoding.
        max_length: Maximum length of the returned sequences.

    Returns:
        Dict with equal-length ``input_ids`` and ``labels`` lists, each at
        most ``max_length`` long.
    """
    prompt = clean_text(example["prompt"])
    if "\n" in prompt:
        # Flatten multi-line dialogue history onto one tagged line.
        prompt = "[對話歷史] " + " ".join(prompt.split("\n"))
    response = clean_text(example["response"])
    prompt_text = f"{prompt}\nAssistant:"

    all_prompt_ids = tokenizer(prompt_text, add_special_tokens=False)["input_ids"]
    all_response_ids = tokenizer(response, add_special_tokens=False)["input_ids"]

    # End each response with EOS so the model learns when to stop. Without
    # it, generation only ends at max_new_tokens.
    # NOTE: a collator that masks every pad_token_id in labels will also mask
    # this EOS when pad_token falls back to eos_token.
    eos_ids = [] if tokenizer.eos_token_id is None else [tokenizer.eos_token_id]

    response_ids = all_response_ids[: max(max_length // 2 - len(eos_ids), 0)] + eos_ids
    prompt_budget = max(max_length - len(response_ids), 0)
    # Keep the tail of the prompt (slicing from the end; [-0:] would keep all).
    prompt_ids = all_prompt_ids[-prompt_budget:] if prompt_budget else []

    truncated = (
        len(prompt_ids) < len(all_prompt_ids)
        or len(response_ids) - len(eos_ids) < len(all_response_ids)
    )
    if truncated:
        original_len = len(all_prompt_ids) + len(all_response_ids) + len(eos_ids)
        logger.warning(f"樣本已截斷：原始長度 {original_len}，max_length {max_length}")

    # Final guard for degenerate budgets (e.g. max_length smaller than EOS).
    input_ids = (prompt_ids + response_ids)[:max_length]
    labels = ([-100] * len(prompt_ids) + response_ids)[:max_length]

    return {"input_ids": input_ids, "labels": labels}

def custom_data_collator(features):
    """Pad tokenized examples into one training batch.

    ``input_ids`` are padded with ``tokenizer.pad`` (longest in batch,
    rounded up to a multiple of 8), which also produces ``attention_mask``.
    ``labels`` are padded with -100 to the same length and on the same side,
    so padded positions are ignored by the loss.

    Reads the module-level ``tokenizer`` (assigned by ``train_model``).

    Args:
        features: List of dicts with ``"input_ids"`` and ``"labels"`` lists.

    Returns:
        Dict of ``input_ids``, ``attention_mask`` and ``labels`` tensors.
    """
    batch_input = tokenizer.pad(
        {"input_ids": [f["input_ids"] for f in features]},
        padding="longest",
        return_tensors="pt",
        pad_to_multiple_of=8,
    )
    input_ids = batch_input["input_ids"]
    batch_size, seq_len = input_ids.shape

    # Pad labels with -100 directly instead of padding with pad_token_id and
    # masking by id afterwards. When pad_token falls back to eos_token,
    # masking by id would also erase genuine EOS targets.
    labels = torch.full((batch_size, seq_len), -100, dtype=torch.long)
    pad_left = getattr(tokenizer, "padding_side", "right") == "left"
    for row, feature in enumerate(features):
        row_labels = torch.as_tensor(feature["labels"], dtype=torch.long)
        n = row_labels.numel()
        if n == 0:
            continue
        if pad_left:
            labels[row, seq_len - n:] = row_labels
        else:
            labels[row, :n] = row_labels

    return {
        "input_ids": input_ids,
        "attention_mask": batch_input["attention_mask"],
        "labels": labels,
    }

class MyTrainer(Trainer):
    """``Trainer`` that takes the loss straight from ``outputs.loss``.

    It always declares ``"labels"`` as the label column and returns the loss
    as a 1-element tensor.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Set explicitly. NOTE(review): presumably because automatic label
        # detection does not find "labels" on the PEFT-wrapped model; confirm.
        self.label_names = ["labels"]

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Run *model* on *inputs* and return its loss (and optionally outputs)."""
        # In recent transformers releases, when the Trainer decides the model
        # accepts loss kwargs it skips its own division by
        # gradient_accumulation_steps and expects the model to normalise by
        # num_items_in_batch. Forward it, as the base implementation does, so
        # the loss scale stays correct. Older releases lack the attribute and
        # are unaffected.
        if num_items_in_batch is not None and getattr(self, "model_accepts_loss_kwargs", False):
            inputs = {**inputs, "num_items_in_batch": num_items_in_batch}
        outputs = model(**inputs)
        loss = outputs.loss
        if loss.dim() == 0:
            loss = loss.unsqueeze(0)
        # Only materialise the value when DEBUG is enabled: .item() forces a
        # device synchronisation on every step.
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f"Step loss: {loss.detach().mean().item()}")
        return (loss, outputs) if return_outputs else loss

def _load_raw_dataset(data_path):
    """Load a JSON file holding a non-empty list of records into a ``Dataset``.

    Raises:
        FileNotFoundError: If *data_path* does not exist.
        ValueError: If the JSON is empty or not a list.
    """
    if not os.path.exists(data_path):
        raise FileNotFoundError(f"訓練數據文件 '{data_path}' 不存在，請確認文件路徑。")
    try:
        with open(data_path, "r", encoding="utf-8") as f:
            raw_data = json.load(f)
        if not raw_data or not isinstance(raw_data, list):
            raise ValueError("訓練數據文件為空或格式不正確，應為非空列表。")
        raw_datasets = Dataset.from_list(raw_data)
        logger.info(f"Raw data loaded: {len(raw_datasets)} samples")
    except Exception:
        logger.exception("數據載入失敗")
        raise
    return raw_datasets


def _build_train_dataset(raw_datasets, tokenizer, max_length):
    """Clean, filter out empty pairs, and tokenize records into ``input_ids``/``labels``."""
    formatted_datasets = raw_datasets.map(format_with_metadata)
    # bool(): `a and b` evaluates to a string, while filter predicates are
    # documented to return a boolean.
    cleaned_datasets = formatted_datasets.filter(lambda ex: bool(ex["prompt"] and ex["response"]))
    logger.info(f"過濾掉 {len(formatted_datasets) - len(cleaned_datasets)} 個無效樣本。")

    train_dataset = cleaned_datasets.map(
        lambda x: preprocess_function(x, tokenizer, max_length),
        batched=False,
        remove_columns=cleaned_datasets.column_names,
    )
    logger.info(f"Training dataset ready: {len(train_dataset)} samples")
    return train_dataset


def train_model():
    """Fine-tune the 4-bit quantised base model with LoRA on ``output/out.json``.

    Side effects:
        * Rebinds the module-level ``tokenizer``, which
          ``custom_data_collator`` reads.
        * Writes checkpoints, the adapter and the tokenizer to
          ``./lora-llama3-taiwan-8b-instruct_dialogue``, and logs to ``./logs``.

    Raises:
        FileNotFoundError, ValueError: Missing or malformed training data.
        Any exception from model/tokenizer loading or training is logged and
        re-raised.
    """
    global tokenizer

    base_model_name = "yentinglin/Llama-3-Taiwan-8B-Instruct"
    output_dir = "./lora-llama3-taiwan-8b-instruct_dialogue"
    max_length = 512

    # Validate the training data before touching the GPU, so a missing or
    # malformed file fails in seconds rather than after the model has loaded.
    raw_datasets = _load_raw_dataset("output/out.json")

    quant_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_quant_type="nf4",
    )
    max_memory = {i: "24GB" for i in range(torch.cuda.device_count())}

    try:
        model = AutoModelForCausalLM.from_pretrained(
            base_model_name,
            device_map="auto",
            max_memory=max_memory,
            torch_dtype=torch.float16,
            quantization_config=quant_config,
        )
        logger.info("Base model loaded successfully with device_map=auto.")
    except Exception:
        logger.exception("模型載入失敗")
        raise

    try:
        tokenizer = AutoTokenizer.from_pretrained(base_model_name, use_fast=False)
        if tokenizer.pad_token is None:
            # Reuse EOS as the padding token when none is defined.
            tokenizer.pad_token = tokenizer.eos_token
        logger.info("Tokenizer loaded successfully.")
    except Exception:
        logger.exception("Tokenizer 載入失敗")
        raise

    lora_config = LoraConfig(
        r=16,
        lora_alpha=32,
        target_modules=["q_proj", "v_proj"],
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()

    train_dataset = _build_train_dataset(raw_datasets, tokenizer, max_length)

    training_args = TrainingArguments(
        output_dir=output_dir,
        overwrite_output_dir=True,
        num_train_epochs=1,  # raise for more passes over the data
        per_device_train_batch_size=2,
        gradient_accumulation_steps=4,
        logging_steps=50,
        save_steps=500,
        eval_strategy="no",
        fp16=True,
        learning_rate=1e-4,
        max_grad_norm=1.0,
        logging_dir="./logs",
        optim="adamw_torch",
        warmup_steps=100,
        dataloader_num_workers=0,
        gradient_checkpointing=False,
        run_name="lora-llama3-taiwan-run-20250330",
        disable_tqdm=False
    )

    trainer = MyTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        data_collator=custom_data_collator,
    )

    try:
        trainer.train()
        trainer.save_model(output_dir)
        tokenizer.save_pretrained(output_dir)
        logger.info("Training completed. Model and tokenizer saved.")
    except Exception:
        logger.exception("訓練過程中發生錯誤")
        raise

def remove_special_tokens(text: str) -> str:
    """Strip special-token markers from *text* and trim surrounding whitespace.

    A fixed list of markers (``</s>``, ``<|im_end|>``, ``<|begin_of_text|>``,
    ``<|endoftext|>``) is removed first. Then any remaining ``<|...|>`` marker
    on a single line is removed with a regex.
    """
    for marker in ("</s>", "<|im_end|>", "<|begin_of_text|>", "<|endoftext|>"):
        text = text.replace(marker, "")
    without_pipe_markers = re.sub(r"<\|.*?\|>", "", text)
    return without_pipe_markers.strip()

def filter_gibberish(text: str) -> str:
    """Drop "gibberish" tokens and rejoin the remaining ones with single spaces.

    A whitespace-separated token is dropped when it consists entirely of 8 or
    more characters from ``[A-Za-z0-9+-#^_]``. Note this also drops ordinary
    English words of 8+ letters. Newlines and runs of whitespace collapse to
    one space.
    """
    looks_like_noise = re.compile(r'[A-Za-z0-9+\-#^_]{8,}').fullmatch
    kept_words = []
    for word in text.split():
        if looks_like_noise(word) is None:
            kept_words.append(word)
    return " ".join(kept_words)

def extract_generated_answer(full_response: str, prompt: str) -> str:
    """Return the assistant reply contained in a decoded generation.

    Processing steps:
        1. Remove special-token markers. A decoded sequence that starts with
           e.g. ``<|begin_of_text|>`` can then still be matched against
           *prompt*, and the echoed prompt is cut off.
        2. Clean the rest with ``filter_gibberish``.
        3. If an ``Assistant:``/``Assistant：`` cue remains, keep only the
           text after the last one.
        4. Drop everything from the first ``User:``/``User：`` cue onwards,
           since that is a model-invented next turn.

    Args:
        full_response: Decoded model output, with or without the prompt.
        prompt: The prompt text that was fed to the model.

    Returns:
        The cleaned reply (possibly ``""``).
    """
    # Strip markers BEFORE the prefix check; otherwise a leading BOS marker
    # makes startswith() fail and the echoed prompt is never removed.
    text = remove_special_tokens(full_response)
    if text.startswith(prompt):
        text = text[len(prompt):]
    text = filter_gibberish(text)

    parts = re.split(r"Assistant[:：]", text)
    reply = parts[-1].strip() if len(parts) > 1 else text
    return re.split(r"User[:：]", reply)[0].strip()

def postprocess_answer(text: str, max_sentences: int = 2) -> str:
    """Clean a generated reply and keep at most *max_sentences* sentences.

    Processing steps:
        1. Remove special-token markers and gibberish tokens.
        2. Drop ``[label: value]`` tags, ``***...***`` spans, URLs, and a
           leading list number.
        3. Strip CJK symbol/punctuation characters other than ``。``.
        4. Split into sentences on ``. ! ? 。 ！ ？``, keeping each sentence's
           own terminator.
        5. Join the first *max_sentences* sentences with spaces.
        6. Append ``。`` when the result does not already end with one of
           those terminators.

    Returns:
        The cleaned text, or ``""`` when nothing survives cleaning.
    """
    text = remove_special_tokens(text)
    text = filter_gibberish(text)
    text = re.sub(r"\[.*?:.*?\]", "", text)
    text = re.sub(r"\*\*\*.*?\*\*\*", "", text)
    # Match the URL scheme too (ASCII letters/digits/+.- per RFC 3986) so no
    # stray "https" is left behind. Adjacent CJK text is kept because it is
    # not part of a scheme.
    text = re.sub(r"[A-Za-z][A-Za-z0-9+.\-]*://\S+", "", text)
    text = re.sub(r"^\d+\.\s*", "", text)
    # Strip CJK symbols/punctuation (U+3000-U+303F) and small form variants
    # (U+FE50-U+FE6F), but keep "。" (U+3002). It is needed below to split
    # Chinese sentences.
    text = re.sub(r"[\u3000\u3001\u3003-\u303f\ufe50-\ufe6f]", "", text)
    sentences = [
        body.strip() + terminator
        for body, terminator in re.findall(r"([^.!?。！？]+)([.!?。！？]*)", text)
        if body.strip()
    ]
    output = " ".join(sentences[:max_sentences])
    if output and output[-1] not in ".!?。！？":
        output += "。"
    return output

def setup_model(lora_model_path: str, base_model_name: str):
    """Load the 4-bit NF4 base model, its tokenizer, and the LoRA adapter.

    Args:
        lora_model_path: Directory containing the trained LoRA adapter.
        base_model_name: Hugging Face id of the base model.

    Returns:
        ``(tokenizer, model)``, where ``model`` is a ``PeftModel`` in eval
        mode.

    Raises:
        Any exception from model or tokenizer loading, after logging it.
    """
    bnb_settings = dict(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_quant_type="nf4",
    )
    quant_config = BitsAndBytesConfig(**bnb_settings)
    per_gpu_memory = dict.fromkeys(range(torch.cuda.device_count()), "24GB")

    try:
        base_model = AutoModelForCausalLM.from_pretrained(
            base_model_name,
            device_map="auto",
            max_memory=per_gpu_memory,
            torch_dtype=torch.float16,
            quantization_config=quant_config,
        )
        logger.info("Base model (for inference) loaded successfully.")
    except Exception:
        logger.exception("推論模型載入失敗")
        raise

    try:
        infer_tokenizer = AutoTokenizer.from_pretrained(base_model_name, use_fast=False)
        if infer_tokenizer.pad_token is None:
            # Fall back to EOS for padding, keeping token and id in sync.
            infer_tokenizer.pad_token = infer_tokenizer.eos_token
            infer_tokenizer.pad_token_id = infer_tokenizer.eos_token_id
        logger.info("Tokenizer (for inference) loaded successfully.")
    except Exception:
        logger.exception("推論 Tokenizer 載入失敗")
        raise

    peft_model = PeftModel.from_pretrained(base_model, lora_model_path)
    peft_model.eval()
    logger.info("LoRA weights applied, model set to eval mode.")
    return infer_tokenizer, peft_model

def setup_faiss(device=None):
    """Load the sentence-embedding model and create an empty exact-L2 FAISS index.

    Args:
        device: Device for the SentenceTransformer. ``None`` (the default)
            selects ``"cuda:0"`` when CUDA is available and ``"cpu"``
            otherwise.

    Returns:
        ``(embedding_model, faiss_index)``. The index dimension equals the
        model's embedding dimension.
    """
    if device is None:
        # Fall back to CPU instead of requiring a GPU.
        device = "cuda:0" if torch.cuda.is_available() else "cpu"
    embedding_model = SentenceTransformer('paraphrase-MiniLM-L6-v2', device=device)
    embedding_dim = embedding_model.get_sentence_embedding_dimension()
    logger.info(f"SentenceTransformer loaded, embedding dimension: {embedding_dim}")
    faiss_index = faiss.IndexFlatL2(embedding_dim)
    logger.info("FAISS index created successfully.")
    return embedding_model, faiss_index

# Chat turns as (role, message) tuples, in order. retrieve_documents maps
# FAISS row ids straight to positions in this list.
conversation_history = []


def append_history(role: str, message: str):
    """Record one chat turn as a ``(role, message)`` tuple."""
    conversation_history.extend([(role, message)])

def retrieve_documents(query: str, embedding_model, faiss_index, top_k: int = 3):
    """Return the messages of up to *top_k* history entries nearest to *query*.

    The query is embedded and searched in *faiss_index*. Each returned row id
    is used as a position in ``conversation_history``. Ids of -1 (no result)
    and ids beyond the current history length are skipped.
    """
    query_vectors = np.array(embedding_model.encode([query])).astype('float32')
    _, neighbour_ids = faiss_index.search(query_vectors, top_k)
    docs = []
    for idx in neighbour_ids[0]:
        if idx == -1 or idx >= len(conversation_history):
            continue
        docs.append(conversation_history[idx][1])
    logger.debug(f"FAISS retrieved docs: {docs}")
    return docs

def add_to_index(text: str, embedding_model, faiss_index):
    """Embed *text* and append it as one row of *faiss_index*.

    Best-effort: any exception is logged and swallowed. Callers that pair
    this with ``append_history`` will then see the index and the history
    drift out of step.
    """
    try:
        vector = np.array(embedding_model.encode([text])).astype('float32')
        faiss_index.add(vector)
        logger.debug(f"Added to FAISS index, total entries: {faiss_index.ntotal}")
    except Exception:
        logger.exception("Error in add_to_index")

def setup_widgets():
    """Create and display the chat input box, send button and output pane.

    Returns:
        ``(text_input, send_button, output_area)``.
    """
    input_layout = widgets.Layout(width='80%')
    text_input = widgets.Text(
        description='User:',
        placeholder='請輸入對話內容...',
        layout=input_layout,
    )
    send_button = widgets.Button(button_style='primary', description='送出')
    pane_layout = {'border': '1px solid black', 'height': '300px', 'overflow_y': 'auto'}
    output_area = widgets.Output(layout=pane_layout)
    display(text_input, send_button, output_area)
    return text_input, send_button, output_area

def generate_response(inputs, prompt, progress, output_area, inference_model, tokenizer, embedding_model, faiss_index, max_new_tokens):
    """Sample a reply for the tokenized *prompt* and publish it to the chat UI.

    On success the cleaned reply is appended to *output_area*, recorded in
    the conversation history, and added to *faiss_index*. Any error is logged
    and reported in *output_area* instead of being raised. The progress bar
    is always closed.

    Args:
        inputs: Tokenizer output for *prompt*, already on the model's device.
        prompt: The prompt text (used to strip an echoed prompt, if any).
        progress: ``IntProgress`` widget to advance and finally close.
        output_area: ``Output`` widget receiving the reply or error text.
        inference_model: Model exposing ``generate``.
        tokenizer: Tokenizer used to decode the generated ids.
        embedding_model, faiss_index: Passed to ``add_to_index``.
        max_new_tokens: Generation length cap.
    """
    try:
        logger.info(f"Memory before generation: {torch.cuda.memory_allocated(0) / 1024**3:.2f} GB")
        with torch.no_grad():
            outputs = inference_model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                temperature=0.5,
                top_p=0.85,
                top_k=50,
                repetition_penalty=1.2,
                pad_token_id=tokenizer.eos_token_id,
                use_cache=True,
            )
        progress.value = 80
        # generate() returns prompt + continuation. Decode only the new
        # tokens, so the prompt is neither re-decoded nor has to be stripped.
        prompt_len = inputs["input_ids"].shape[-1]
        generated_text = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=False)
        logger.debug(f"Generated text: {generated_text}")
        generated_answer = extract_generated_answer(generated_text, prompt)
        logger.debug(f"Generated answer before postprocess: {generated_answer}")
        final_answer = postprocess_answer(generated_answer, max_sentences=2)
        progress.value = 100
        output_area.append_stdout("Assistant: " + final_answer + "\n")
        append_history("Assistant", final_answer)
        add_to_index(final_answer, embedding_model, faiss_index)
        logger.info(f"Memory after generation: {torch.cuda.memory_allocated(0) / 1024**3:.2f} GB")
    except Exception as e:
        logger.exception("Error in generate_response")
        output_area.append_stdout("Error during generation: " + str(e) + "\n")
    finally:
        progress.close()

def interactive_mode():
    """Load the fine-tuned model and start the widget-based chat UI.

    Each click on the send button builds a prompt from three parts: a fixed
    persona instruction, the last five earlier turns, and FAISS-retrieved
    earlier messages. It then generates the reply on a background thread.
    The send button stays disabled until that thread finishes.
    """
    base_model_name = "yentinglin/Llama-3-Taiwan-8B-Instruct"
    lora_model_path = "./lora-llama3-taiwan-8b-instruct_dialogue"
    tokenizer, model = setup_model(lora_model_path, base_model_name)
    embedding_model, faiss_index = setup_faiss()
    text_input, send_button, output_area = setup_widgets()

    def on_send_button_clicked(b):
        """Handle one user message: build the prompt and start generation."""
        user_message = text_input.value.strip()
        if not user_message:
            return
        text_input.value = ""
        output_area.append_stdout(f"User: {user_message}\n")

        # Build context from EARLIER turns before recording this message.
        # Otherwise it shows up twice in the prompt (history + "User:" line)
        # and is always its own nearest FAISS neighbour.
        history_context = "\n".join(f"{role}: {msg}" for role, msg in conversation_history[-5:])
        retrieved_docs = retrieve_documents(user_message, embedding_model, faiss_index, top_k=3)
        append_history("User", user_message)
        add_to_index(user_message, embedding_model, faiss_index)

        # Emit the "related info" section (with a single header) only when
        # something was retrieved.
        retrieved_section = "相關資訊：\n" + "\n".join(retrieved_docs) + "\n" if retrieved_docs else ""
        system_message = (
            f"你是一個台灣大學生，用 LINE 聊天。\n"
            f"回應要超短、自然，像 '靠北超糗'、'好啊去吃爆' 這樣，加點俚語跟表情符號（😂、🥳）。\n"
            f"根據上下文回，內心理解情緒、行為、話題，但別在回應裡秀出來。\n"
            f"以下是對話歷史：\n{history_context}\n"
            f"{retrieved_section}"
            f"User: {user_message}\nAssistant: "
        )
        logger.info(f"Generated prompt: {system_message}")
        dynamic_max_new_tokens = 50
        progress = widgets.IntProgress(value=0, min=0, max=100, description='處理中:', bar_style='info')
        display(progress)
        try:
            inputs = tokenizer(system_message, return_tensors="pt").to(model.device)
            logger.debug(f"Input token length: {inputs['input_ids'].shape[1]}")
            progress.value = 20
            send_button.disabled = True

            def run_generation():
                # Re-enable the button however generation ends.
                try:
                    generate_response(
                        inputs, system_message, progress, output_area, model, tokenizer,
                        embedding_model, faiss_index, dynamic_max_new_tokens,
                    )
                finally:
                    send_button.disabled = False

            threading.Thread(target=run_generation).start()
        except Exception as e:
            progress.close()
            logger.exception("Error in on_send_button_clicked")
            output_area.append_stdout("Error during generation: " + str(e) + "\n")
            send_button.disabled = False

    send_button.on_click(on_send_button_clicked)
    logger.info("Interactive interface setup complete.")
    print("[INFO] 推論模式啟動，開始互動。")

# Entry point: report the device, fine-tune, then start the chat UI.
if __name__ == "__main__":
    import gc

    if torch.cuda.is_available():
        print(f"使用 GPU: {torch.cuda.get_device_name(0)}")
    else:
        print("未檢測到 GPU，使用 CPU")
    train_model()  # training phase
    # train_model's model/trainer are unreferenced now. Collect them and
    # return cached CUDA blocks before a second copy of the base model is
    # loaded for inference.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    interactive_mode()  # inference / chat phase