In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model
from model_wrapper_with_mlp_adapter import FeaturePrefixAdapter
import os
from huggingface_hub import login

# === Configuration ===
# Hub identifier of the base checkpoint (7B variant, picked so it fits on the GPU).
MODEL_ID = "meta-llama/Llama-2-7b-hf"

# Directory holding the fine-tuned artifacts, and the adapter checkpoint inside it.
MODEL_DIR = (
    "/content/drive/MyDrive/llm/"
    "LLM-based-Agent-for-Driver-Sleepiness-Detection-and-Mitigation-in-Automotive-Systems/"
    "llm_and_fatigue_handling/llama_prefix_final_model"
)
ADAPTER_PATH = os.path.join(MODEL_DIR, "prefix_adapter.pth")

# Sizes handed to FeaturePrefixAdapter further down.
FEATURE_DIM = 9          # number of sensor values per sample
EMBEDDING_DIM = 4096     # adapter output width; presumably the LM hidden size -- confirm
PREFIX_TOKEN_COUNT = 5   # number of prefix positions prepended to the prompt

# Total sequence budget: prompt tokens are truncated to MAX_LENGTH - PREFIX_TOKEN_COUNT.
MAX_LENGTH = 256

# Run on the GPU whenever one is visible, otherwise fall back to the CPU.
DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda"

# === Login to Hugging Face ===
# The access token comes from the environment; .get() yields None when the
# variable is unset and that value is handed to login() unchanged.
hf_token = os.environ.get("HUGGINGFACE_TOKEN")
login(token=hf_token)

# === Fix fragmentation (optional, helpful for lower memory GPUs) ===
# NOTE(review): this is set after `import torch`; confirm the CUDA allocator
# still picks it up in the torch version in use.
os.environ.update(PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True")

# === BitsAndBytes Quantization config ===
# Load weights in 4-bit NF4 with double quantization; compute dtype is bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# === Load tokenizer and base model ===
# Read the Hub token once and reuse it for both downloads.
hf_token = os.environ.get("HUGGINGFACE_TOKEN")

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=hf_token)
# Reuse EOS as the padding token (fix for the missing pad_token).
tokenizer.pad_token = tokenizer.eos_token

base_model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    token=hf_token,
    quantization_config=bnb_config,
    device_map="auto",
)

# Report where the first parameter of the loaded model lives.
first_base_param = next(base_model.parameters())
print(f"Base model device: {first_base_param.device}")

# === Load PEFT Adapter ===
# Wrap the quantized base model with LoRA layers on the attention query/value
# projections, then try to restore weights for them from ADAPTER_PATH.
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

llama_model = get_peft_model(base_model, lora_config)

# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted source (consider weights_only=True if the torch version supports it).
adapter_state_dict = torch.load(ADAPTER_PATH, map_location="cpu")
# Drop every 'base_model.' fragment from the checkpoint keys before loading.
adapter_state_dict = {k.replace('base_model.', ''): v for k, v in adapter_state_dict.items()}

# strict=False never raises on key mismatches, so inspect the result instead of
# discarding it: otherwise a checkpoint that matches nothing goes unnoticed.
# NOTE(review): the same ADAPTER_PATH file is also loaded (strictly) into
# FeaturePrefixAdapter later in this script -- confirm it really contains LoRA
# weights for this model as well.
load_result = llama_model.load_state_dict(adapter_state_dict, strict=False)
unexpected_keys = list(load_result.unexpected_keys)
matched_key_count = len(adapter_state_dict) - len(unexpected_keys)
if matched_key_count == 0:
    print(
        f"WARNING: none of the {len(adapter_state_dict)} tensors in {ADAPTER_PATH} "
        "matched the PEFT model; LoRA weights were NOT restored."
    )
elif unexpected_keys:
    print(
        f"WARNING: {len(unexpected_keys)} checkpoint keys were ignored by the PEFT model "
        f"(first few: {unexpected_keys[:5]})"
    )

print(f"Peft model device: {next(llama_model.parameters()).device}")

# === Load MLP Feature Adapter ===
# Dtype of the language model's first parameter. It must be defined BEFORE the
# adapter is cast below (it used to be assigned only after that cast, which
# raised NameError on a fresh run).
target_dtype = next(llama_model.parameters()).dtype  # Match model dtype

# Build the feature -> prefix-embedding adapter and restore its weights.
adapter = FeaturePrefixAdapter(
    input_dim=FEATURE_DIM,
    hidden_dim=256,
    output_dim=EMBEDDING_DIM,
    num_tokens=PREFIX_TOKEN_COUNT
)
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
adapter.load_state_dict(torch.load(ADAPTER_PATH, map_location="cpu"))
adapter = adapter.to(dtype=target_dtype, device=DEVICE)
adapter.eval()  # inference mode for the adapter module

print(f"Adapter device: {next(adapter.parameters()).device}")

# === Sensor Features ===
# One sample of FEATURE_DIM sensor readings, wrapped in a batch of size 1.
features = [24, 8, 0.38, 0.23, 96.0, 0.4, 0.21, 8.0, 1.3]
if len(features) != FEATURE_DIM:
    raise ValueError(
        f"Expected {FEATURE_DIM} sensor features, got {len(features)}"
    )
# Same dtype/device as the adapter so its forward pass needs no extra casts.
feature_tensor = torch.tensor([features], dtype=target_dtype, device=DEVICE)

# === Tokenize Prompt ===
# Only a single prompt is generated from, so no padding is requested: padding
# it out to max_length would add hundreds of pad positions, and with a
# right-padding tokenizer the continuation would be decoded from a pad position
# instead of the last prompt token. Truncation still caps the prompt so that
# prefix + prompt never exceeds MAX_LENGTH.
short_prompt = "Based on the above signals, what should be the appropriate intervention?"
inputs = tokenizer(
    short_prompt,
    return_tensors="pt",
    max_length=MAX_LENGTH - PREFIX_TOKEN_COUNT,
    truncation=True,
)
input_ids = inputs["input_ids"].to(DEVICE)
attention_mask = inputs["attention_mask"].to(DEVICE)

# === Generate Prefix Embeddings ===
# Run the sensor features through the adapter and cast the result to the model dtype.
# assumes the adapter returns (batch, PREFIX_TOKEN_COUNT, EMBEDDING_DIM) -- TODO confirm
prefix_embeddings = adapter(feature_tensor).to(dtype=target_dtype)

# === Get Input Embeddings from model ===
# Look the prompt token ids up in the model's own embedding table.
embedding_layer = base_model.get_input_embeddings().to(DEVICE)
input_embeddings = embedding_layer(input_ids).to(dtype=target_dtype)

# === Combine Prefix + Input Embeddings ===
# Prefix positions come first, followed by the prompt, along dim 1.
combined_embeddings = torch.cat((prefix_embeddings, input_embeddings), dim=1)

# === Extended Attention Mask ===
# Every prefix position is attended to, so its mask is all ones; it is
# prepended to the tokenizer's mask to line up with combined_embeddings.
prefix_attention_mask = torch.ones(
    (1, PREFIX_TOKEN_COUNT), dtype=torch.long, device=DEVICE
)
extended_attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)

# === Generate Output ===
# Sample up to 50 new tokens, conditioning on the prefix + prompt embeddings.
with torch.no_grad():
    outputs = llama_model.generate(
        inputs_embeds=combined_embeddings,
        attention_mask=extended_attention_mask,
        max_new_tokens=50,
        do_sample=True,
        temperature=0.7,
        top_k=50,
        top_p=0.9,
        pad_token_id=tokenizer.pad_token_id
    )

# === Decode Output ===
# generate() was driven by inputs_embeds only (no input_ids), so the returned
# sequence holds just the newly generated tokens. Slicing off
# PREFIX_TOKEN_COUNT here would throw away the first generated tokens, so the
# whole sequence is decoded.
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
print("\n=== Generated Intervention ===")
print(response)