# In [None]:  (Jupyter notebook cell marker left over from export)
import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from huggingface_hub import login
from peft import PeftModel
from model_wrapper_with_mlp_adapter import FeaturePrefixAdapter
from faiss_vd import runtime_add, retrieve_similar_vectors
from transformers import BitsAndBytesConfig
from accelerate import infer_auto_device_map, init_empty_weights
import torch

# === Authenticate Hugging Face ===
# The access token is taken from the environment; if the variable is missing
# the script stops here with a ValueError, before any Hub call is attempted.
try:
    HUGGINGFACE_TOKEN = os.environ["HUGGINGFACE_TOKEN"]
except KeyError:
    raise ValueError("HUGGINGFACE_TOKEN environment variable not set.") from None
login(token=HUGGINGFACE_TOKEN)

# === Config ===
# Identifier of the base checkpoint passed to AutoModelForCausalLM.
MODEL_NAME = "meta-llama/Llama-2-7b-hf"
# Local folder from which the tokenizer, the PEFT adapter and
# prefix_adapter.pth are loaded.
MODEL_DIR = (
    "/content/drive/MyDrive/"
    "LLM-based-Agent-for-Driver-Sleepiness-Detection-and-Mitigation-in-Automotive-Systems/"
    "llm_and_fatigue_handling/"
    "llama2_7B_with_prefix_adapter_vector/"
    "llama_prefix_final_model"
)
FEATURE_DIM = 9         # input_dim handed to the prefix adapter
EMBEDDING_DIM = 4096    # output_dim handed to the prefix adapter
PREFIX_TOKEN_COUNT = 5  # num_tokens handed to the prefix adapter
MAX_LENGTH = 256        # token budget shared by the prefix and the prompt
# Run on the GPU when one is visible to torch, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# === Tokenizer ===
# The tokenizer is read from the fine-tuned model folder (MODEL_DIR), not from
# the base checkpoint id.
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR, token=HUGGINGFACE_TOKEN)
# Keep the configured pad token when there is one; otherwise reuse EOS.
_current_pad_token = tokenizer.pad_token
tokenizer.pad_token = _current_pad_token if _current_pad_token else tokenizer.eos_token
tokenizer.padding_side = "right"

# === Quantization Config ===
# 4-bit loading with the "nf4" quantization type, double quantization enabled,
# and float16 as the compute dtype.
_quantization_options = {
    "load_in_4bit": True,
    "bnb_4bit_quant_type": "nf4",
    "bnb_4bit_use_double_quant": True,
    "bnb_4bit_compute_dtype": torch.float16,
}
bnb_config = BitsAndBytesConfig(**_quantization_options)

# Load the base checkpoint. Quantization is requested only through
# `quantization_config`; `load_in_4bit` is deliberately NOT passed as a
# separate argument here. `low_cpu_mem_usage` is optional.
_base_model_options = dict(
    token=HUGGINGFACE_TOKEN,
    quantization_config=bnb_config,
    device_map="auto",
    low_cpu_mem_usage=True,
    trust_remote_code=True,
)
base_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, **_base_model_options)

# Size the embedding table to the tokenizer's length before the adapter
# weights are attached.
base_model.resize_token_embeddings(len(tokenizer))

# === Load LoRA adapter from local folder ===
# The adapter weights are read from MODEL_DIR and wrapped around base_model.
# NOTE(review): `use_auth_token=False` is presumably meant to keep PEFT from
# contacting the Hugging Face Hub for this local path — confirm against the
# installed PEFT version.
_peft_load_options = {
    "use_auth_token": False,
    "device_map": "auto",
}
llama_model = PeftModel.from_pretrained(base_model, MODEL_DIR, **_peft_load_options)
# Inference only: switch the wrapped model to evaluation mode.
llama_model.eval()

# === Load MLP Prefix Adapter ===
# Build the adapter with the same dimensions used elsewhere in this script
# (FEATURE_DIM inputs, PREFIX_TOKEN_COUNT tokens of width EMBEDDING_DIM).
adapter = FeaturePrefixAdapter(
    input_dim=FEATURE_DIM,
    hidden_dim=256,
    output_dim=EMBEDDING_DIM,
    num_tokens=PREFIX_TOKEN_COUNT,
)
# Deserialize onto the CPU first: without `map_location`, a checkpoint that
# was saved from CUDA tensors cannot be loaded when DEVICE is "cpu". The
# module is moved to the target device right after the weights are restored.
# NOTE(review): torch.load unpickles the file; only load checkpoints you
# trust (or pass weights_only=True if the installed torch supports it).
adapter_state = torch.load(
    os.path.join(MODEL_DIR, "prefix_adapter.pth"),
    map_location="cpu",
)
adapter.load_state_dict(adapter_state)
# Match the language model's parameter dtype so the prefix vectors can be
# concatenated with its token embeddings.
adapter = adapter.to(dtype=next(llama_model.parameters()).dtype, device=DEVICE)
adapter.eval()

# === Input Features and Fatigue Levels ===
# One sample of nine numeric input values (FEATURE_DIM is 9) that is fed to
# the prefix adapter.
features = [
    24, 8, 0.38,
    0.23, 96.0, 0.4,
    0.21, 8.0, 1.3,
]
# Positional order follows the prompt template: camera, steering, lane.
fatigue_levels = [
    "low",     # camera
    "medium",  # steering
    "medium",  # lane
]

# === Get prefix embedding
# Build a batch of one feature row as float32, then cast it to the language
# model's parameter dtype (the adapter was cast to the same dtype).
model_dtype = next(llama_model.parameters()).dtype
feature_tensor = torch.tensor([features], dtype=torch.float32, device=DEVICE)
feature_tensor = feature_tensor.to(dtype=model_dtype)

# Pure inference: run the adapter without autograd so no graph is recorded
# and the prefix vectors do not carry requires_grad into generation.
with torch.no_grad():
    prefix_embeddings = adapter(feature_tensor)

# Drop the batch axis and hand the vectors to the retrieval helpers as a
# NumPy array (no detach needed: the tensor was produced under no_grad).
# NOTE(review): the array keeps the model dtype (likely float16) — confirm
# that faiss_vd converts to float32 if its index requires it.
token_matrix = prefix_embeddings.squeeze(0).cpu().numpy()

# === Retrieve top-k similar interventions
# Each result is unpacked as a 3-tuple whose middle element is a metadata
# mapping; only its "intervention" entry is used here.
results = retrieve_similar_vectors(token_matrix, k=3)

# Normalized values that carry no usable suggestion.
_ignored_interventions = {"", "none", "driver alert"}

retrieved_interventions = []
for _, meta, _ in results:
    suggestion = meta.get("intervention")
    if not suggestion:
        # Missing or empty entry.
        continue
    if suggestion.strip().lower() in _ignored_interventions:
        # Whitespace-only or placeholder entry.
        continue
    retrieved_interventions.append(suggestion)

# === Build Prompt ===
# When retrieval produced suggestions, they are joined with "; " into one
# leading sentence; otherwise the prompt starts directly with the
# instructions (the final strip() removes the resulting blank first line).
if retrieved_interventions:
    joined_interventions = "; ".join(retrieved_interventions)
    context = f"Previously suggested interventions for similar scenarios: {joined_interventions}. "
else:
    context = ""

# fatigue_levels is read positionally: [0] camera, [1] steering, [2] lane.
prompt = f"""
{context}
You are an intelligent in-cabin assistant.

Fatigue levels:
- Camera: {fatigue_levels[0]}
- Steering: {fatigue_levels[1]}
- Lane: {fatigue_levels[2]}

Based on the above driver state and past examples, suggest an intervention to keep the driver alert.

⚠️ IMPORTANT: You must output in this fixed format — no extra text.

Fan: Level X      ← X is a number like 1, 2, or 3
Music: On/Off
Vibration: On/Off
Reason: <short explanation of the logic>

Example output:
Fan: Level 2
Music: On
Vibration: Off
Reason: High blink rate and PERCLOS indicate moderate drowsiness.

Now, provide your intervention:
""".strip()

# === Tokenize prompt
# Only one prompt is encoded, so no padding is requested. Padding to
# max_length on the right would place pad tokens between the end of the
# prompt and the first generated token, which degrades generation for a
# decoder-only model. Truncation still caps the prompt so that prefix +
# prompt never exceed MAX_LENGTH positions.
inputs = tokenizer(
    prompt,
    return_tensors="pt",
    truncation=True,
    max_length=MAX_LENGTH - PREFIX_TOKEN_COUNT,
)
input_ids = inputs["input_ids"].to(DEVICE)
attention_mask = inputs["attention_mask"].to(DEVICE)

# === Embeddings
# Look up the prompt's token embeddings, cast them to the dtype of the prefix
# vectors, and put the prefix in front along the sequence axis (dim 1).
embedding_layer = llama_model.base_model.get_input_embeddings()
input_embeddings = embedding_layer(input_ids).to(dtype=prefix_embeddings.dtype)
combined_embeddings = torch.cat((prefix_embeddings, input_embeddings), dim=1)

# === Attention Mask
# Every prefix position is attended to, so its mask is all ones; it is
# prepended to the tokenizer's mask to line up with combined_embeddings.
prefix_attention_mask = torch.ones((1, PREFIX_TOKEN_COUNT), dtype=torch.long, device=DEVICE)
extended_attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)

# === Generate
# Sampling settings for the continuation: up to 50 new tokens, temperature
# 0.7 with top-k 50 and nucleus (top-p) 0.9 filtering.
generation_options = {
    "max_new_tokens": 50,
    "do_sample": True,
    "temperature": 0.7,
    "top_k": 50,
    "top_p": 0.9,
    "pad_token_id": tokenizer.pad_token_id,
}

# The model is driven by embeddings (prefix + prompt) instead of token ids.
with torch.no_grad():
    output = llama_model.generate(
        inputs_embeds=combined_embeddings,
        attention_mask=extended_attention_mask,
        **generation_options,
    )

# === Decode
# generate() was called with inputs_embeds only (no input_ids), so the
# returned sequence holds just the newly generated token ids: there are no
# prompt or prefix positions to skip. Slicing off PREFIX_TOKEN_COUNT tokens
# here would discard the beginning of the model's answer.
response = tokenizer.decode(output[0], skip_special_tokens=True)
print("\n=== Generated Intervention ===")
print(response)

# === Save final vector + output
# Store the prefix vectors together with the generated text via runtime_add.
runtime_add(token_matrix, intervention=response)