# In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from model_wrapper_with_mlp_adapter import FeaturePrefixAdapter, PrefixLLaMAModel
import os

# === Config ===
MODEL_PATH = "/content/drive/MyDrive/llama2_driver_intervention_model_2"
PREFIX_TOKEN_COUNT = 5  # number of prefix embeddings the adapter emits (num_tokens below)
FEATURE_DIM = 9  # length of the input feature vector (adapter input_dim)
EMBEDDING_DIM = 4096  # adapter output width; must equal the LLaMA token-embedding width to concatenate
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# === Load tokenizer and model ===
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
# NOTE(review): device_map="auto" lets accelerate place (and possibly split or
# offload) the weights, but the wrapper is moved wholesale with .to(DEVICE)
# below. That is fine when the whole model fits on one device; if layers get
# split/offloaded the two placements conflict — confirm on the target hardware.
llama = AutoModelForCausalLM.from_pretrained(MODEL_PATH, device_map="auto")
adapter = FeaturePrefixAdapter(
    input_dim=FEATURE_DIM,
    hidden_dim=256,
    output_dim=EMBEDDING_DIM,
    num_tokens=PREFIX_TOKEN_COUNT,
)
# The checkpoint is consumed as a plain state_dict, so restrict unpickling to
# tensors/containers (weights_only=True) instead of allowing arbitrary objects.
adapter.load_state_dict(
    torch.load(
        os.path.join(MODEL_PATH, "prefix_adapter.pth"),
        map_location=DEVICE,
        weights_only=True,
    )
)

# === Wrap the model
model = PrefixLLaMAModel(llama, adapter).to(DEVICE)
model.eval()  # inference mode for all submodules (llama and adapter)

# === Minimal prompt
MIN_PROMPT = "Based on the above signals, what should be the appropriate intervention?"

# === Inference Function
@torch.no_grad()
def generate_intervention(features, max_new_tokens=100):
    """Generate intervention text conditioned on a single feature vector.

    The features are mapped by ``model.adapter`` to prefix embeddings. These
    are prepended to the embedded ``MIN_PROMPT``, and the result is decoded
    greedily with ``model.llama.generate``.

    Args:
        features: Numeric feature values for ONE sample (presumably
            ``FEATURE_DIM`` values, matching the adapter's ``input_dim``).
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The decoded generation with special tokens removed.
    """
    # Feature vector -> float tensor with a leading batch dimension of 1.
    features_tensor = torch.tensor(features, dtype=torch.float32, device=DEVICE).unsqueeze(0)
    prefix_embed = model.adapter(features_tensor)  # expected (1, PREFIX_TOKEN_COUNT, EMBEDDING_DIM) — TODO confirm

    # Tokenize the single short prompt. No padding is requested: with one
    # sequence there is nothing to pad, and a tokenizer lacking a pad token
    # (the LLaMA-2 default) raises ValueError when padding is asked for.
    inputs = tokenizer(MIN_PROMPT, return_tensors="pt", truncation=True).to(DEVICE)
    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]

    text_embed = model.embedding_layer(input_ids)

    # Prefix first, then prompt; cast the prefix to the embedding dtype so
    # torch.cat does not fail on a dtype mismatch (e.g. fp32 adapter, fp16 LLM).
    full_embed = torch.cat([prefix_embed.to(text_embed.dtype), text_embed], dim=1)

    # Prefix positions are always attended. Shape the mask from the actual
    # prefix tensor (batch, prefix_len) so it cannot drift from the adapter.
    prefix_mask = torch.ones(
        prefix_embed.shape[:2], dtype=attention_mask.dtype, device=attention_mask.device
    )
    full_mask = torch.cat([prefix_mask, attention_mask], dim=1)

    # Make the pad id explicit (falling back to EOS) instead of letting
    # generate() pick one and warn on every call.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id

    output = model.llama.generate(
        inputs_embeds=full_embed,
        attention_mask=full_mask,
        max_new_tokens=max_new_tokens,
        do_sample=False,
        pad_token_id=pad_token_id,
    )

    # NOTE(review): when only inputs_embeds is passed, recent transformers
    # versions return just the newly generated tokens — verify for the
    # installed version if the prompt ever appears in the output.
    return tokenizer.decode(output[0], skip_special_tokens=True)

# === Example
if __name__ == "__main__":
    # Sample nine-value feature vector; substitute real-time input here.
    sample_signals = [
        14.8, 3.2, 65.4,
        0.28, 0.81, 1.9,
        1.1, 3.4, 3.9,
    ]
    intervention_text = generate_intervention(sample_signals)
    print("Generated Intervention:\n", intervention_text)
