# In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel
from model_wrapper_with_mlp_adapter import FeaturePrefixAdapter
from faiss_vd import runtime_add, retrieve_similar_vectors
import os

# === Configuration ===
# Hub id of the base checkpoint, and the local directory that holds the
# fine-tuned tokenizer, the LoRA weights and prefix_adapter.pth.
MODEL_ID = "meta-llama/Llama-2-7b-hf"
MODEL_DIR = "/content/drive/MyDrive/llm/LLM-based-Agent-for-Driver-Sleepiness-Detection-and-Mitigation-in-Automotive-Systems/llm_and_fatigue_handling/llama_prefix_final_model"

# Sizes shared by the feature adapter and the prompt layout.
FEATURE_DIM = 9           # length of one input feature vector (adapter input_dim)
EMBEDDING_DIM = 4096      # adapter output_dim; must equal the LLM token-embedding width,
                          # because the prefix is concatenated with token embeddings
PREFIX_TOKEN_COUNT = 5    # number of prefix positions the adapter produces
MAX_LENGTH = 256          # total sequence budget: prefix positions + prompt tokens

# Prefer the GPU whenever one is visible to torch.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# === Quant config for 4-bit loading ===
# NF4 4-bit weights with nested (double) quantization; compute runs in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# === Load tokenizer ===
# The tokenizer is loaded from the fine-tuned directory; EOS is reused as the pad token.
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
tokenizer.pad_token = tokenizer.eos_token

# === Load base model ===
# Hub access token is taken from $HUGGINGFACE_TOKEN (None when the variable is unset).
hf_token = os.environ.get("HUGGINGFACE_TOKEN")
base_model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    quantization_config=bnb_config,
    device_map="auto",
    token=hf_token,
)

# === Load LoRA adapter ===
# Wrap the quantized base model with the LoRA weights stored in MODEL_DIR and
# switch to inference mode (disables dropout).
llama_model = PeftModel.from_pretrained(base_model, MODEL_DIR, device_map="auto")
llama_model.eval()

# === Load MLP Adapter ===
# FeaturePrefixAdapter turns a FEATURE_DIM-long feature vector into prefix
# embeddings that are later concatenated in front of the prompt's token embeddings.
# The sizes (including hidden_dim=256) must match the saved checkpoint, otherwise
# load_state_dict (strict by default) raises on the mismatch.
adapter = FeaturePrefixAdapter(
    input_dim=FEATURE_DIM,
    hidden_dim=256,
    output_dim=EMBEDDING_DIM,
    num_tokens=PREFIX_TOKEN_COUNT
)
# map_location="cpu": a checkpoint saved from a GPU still loads on a CPU-only runtime
# (the module is moved to DEVICE right after).
# weights_only=True: a plain state_dict holds only tensors, so refuse arbitrary pickles.
adapter_state = torch.load(
    os.path.join(MODEL_DIR, "prefix_adapter.pth"),
    map_location="cpu",
    weights_only=True,
)
adapter.load_state_dict(adapter_state)
# The adapter output is concatenated with the token embeddings, so take the dtype
# from the input-embedding table itself. With 4-bit bitsandbytes loading, some
# parameters are packed quantized storage, so "the first parameter" is not a
# reliable dtype source.
target_dtype = llama_model.base_model.get_input_embeddings().weight.dtype
adapter = adapter.to(dtype=target_dtype, device=DEVICE)
adapter.eval()

# === Input Features and Fatigue Levels ===
# One feature vector (FEATURE_DIM values) and three per-sensor fatigue labels
# that the prompt reads as camera / steering / lane, in that order.
features = [24, 8, 0.38, 0.23, 96.0, 0.4, 0.21, 8.0, 1.3]
fatigue_levels = ["low", "medium", "medium"]

# === Get prefix embedding
# Pure inference: no autograd graph is needed for the adapter forward pass.
with torch.no_grad():
    feature_tensor = torch.tensor([features], dtype=target_dtype, device=DEVICE)
    prefix_embeddings = adapter(feature_tensor)  # expected [1, PREFIX_TOKEN_COUNT, EMBEDDING_DIM]

# === Prepare FAISS vector
# FAISS works on float32 arrays, and a bfloat16 tensor cannot be converted with
# .numpy() at all, so cast before leaving torch.
token_matrix = prefix_embeddings.squeeze(0).detach().to(torch.float32).cpu().numpy()

# === Retrieve top-k similar interventions
def _is_informative(text):
    """Return True for a non-blank intervention that is not a placeholder value."""
    return bool(text) and text.strip().lower() not in {"", "none", "driver alert"}


results = retrieve_similar_vectors(token_matrix, k=3)
# Each hit is unpacked as a 3-tuple whose middle element is a metadata mapping.
candidate_texts = (meta.get("intervention") for _, meta, _ in results)
retrieved_interventions = [text for text in candidate_texts if _is_informative(text)]

# === RAG-style context
# Empty string when nothing useful was retrieved, so the prompt carries no context line.
joined_interventions = "; ".join(retrieved_interventions)
context = (
    f"Previously suggested interventions for similar scenarios: {joined_interventions}. "
    if retrieved_interventions
    else ""
)

# === Prompt with context + fatigue levels
# Builds the instruction prompt. The retrieved context (possibly the empty string)
# comes first; the trailing .strip() removes the leading newlines, so an empty
# context does not leave blank lines at the top.
# fatigue_levels is read positionally: [0] camera, [1] steering, [2] lane.
# The text inside the literal is sent to the model verbatim (including the
# trailing double spaces on the format lines), so edit it with care.
prompt = f"""
{context}
You are an intelligent in-cabin assistant.

Fatigue levels:
- Camera: {fatigue_levels[0]}
- Steering: {fatigue_levels[1]}
- Lane: {fatigue_levels[2]}

Based on the above driver state and past examples, suggest an intervention to keep the driver alert.

⚠️ IMPORTANT: You must output in this fixed format — no extra text.

Fan: Level X      ← X is a number like 1, 2, or 3  
Music: On/Off  
Vibration: On/Off  
Reason: <short explanation of the logic>

Example output:
Fan: Level 2  
Music: On  
Vibration: Off  
Reason: High blink rate and PERCLOS indicate moderate drowsiness.

Now, provide your intervention:
""".strip()

# === Tokenize prompt
# If the retrieved context makes the prompt exceed the budget, drop tokens from
# the start (context) rather than the end, so the output-format instructions and
# the closing "Now, provide your intervention:" cue are always kept.
tokenizer.truncation_side = "left"
# A single prompt needs no padding. Padding to max_length would insert up to
# ~100 pad tokens; with right padding, generation would continue from a pad
# position instead of from the end of the prompt.
inputs = tokenizer(
    prompt,
    return_tensors="pt",
    max_length=MAX_LENGTH - PREFIX_TOKEN_COUNT,
    truncation=True,
)
# With device_map="auto" the embedding table may live on a specific device;
# build every input tensor there so the concatenations below are device-consistent.
embedding_layer = llama_model.base_model.get_input_embeddings()
embed_device = embedding_layer.weight.device
input_ids = inputs["input_ids"].to(embed_device)
attention_mask = inputs["attention_mask"].to(embed_device)

# === Embeddings
# Sequence layout fed to the model: [prefix embeddings][prompt token embeddings].
input_embeddings = embedding_layer(input_ids).to(dtype=target_dtype)
prefix_embeddings = prefix_embeddings.to(device=embed_device, dtype=target_dtype)
combined_embeddings = torch.cat([prefix_embeddings, input_embeddings], dim=1)

# === Attention Mask
# Every prefix position is attended to; size the mask from the actual prefix
# shape (batch, prefix_len) rather than assuming PREFIX_TOKEN_COUNT.
prefix_attention_mask = torch.ones(
    prefix_embeddings.shape[:2], dtype=attention_mask.dtype, device=embed_device
)
extended_attention_mask = torch.cat([prefix_attention_mask, attention_mask], dim=1)

# === Generate
# Sampling settings for the intervention text (stochastic: do_sample=True).
sampling_kwargs = dict(
    max_new_tokens=50,
    do_sample=True,
    temperature=0.7,
    top_k=50,
    top_p=0.9,
)
# Generation is driven purely by the precomputed embeddings (prefix + prompt);
# no input_ids are passed.
with torch.no_grad():
    output = llama_model.generate(
        inputs_embeds=combined_embeddings,
        attention_mask=extended_attention_mask,
        pad_token_id=tokenizer.pad_token_id,
        **sampling_kwargs,
    )

# === Decode
# generate() was called with inputs_embeds and no input_ids, so the returned ids
# contain only the newly generated tokens; there are no prompt/prefix ids to
# strip. Slicing PREFIX_TOKEN_COUNT positions off here would cut the first
# tokens of the answer (e.g. the "Fan: Level X" line).
response = tokenizer.decode(output[0], skip_special_tokens=True).strip()
print("\n=== Generated Intervention ===")
print(response)

# === Save final vector + output
# Store the prefix vector together with the generated text for future retrieval;
# skip empty generations so they do not pollute the store.
if response:
    runtime_add(token_matrix, intervention=response)
else:
    print("Empty generation; nothing added to the vector store.")
