# LLM-Based Behavioral Data Generation

## Overview
This script generates a structured dataset of behavioral decisions using a Large Language Model (LLM).  
It systematically samples combinations of *personal traits* and *situational contexts* and queries the LLM to estimate how likely an individual is to take a specific action, here *"buys ice cream"*.  
The resulting dataset is stored as JSON Lines files.

---

## 1. Experimental Objective
Let the LLM be represented as a conditional function  

$$
f_\theta(x) = \text{LLM output for prompt } x,
$$  

where each prompt $x$ describes a unique configuration of traits and contexts.  
The goal is to approximate the mapping

$$
x = (\text{traits}, \text{context}) \mapsto y = \mathbb{E}[\,p(\text{buy} \mid x)\,],
$$  

where $p(\text{buy} \mid x) \in [0, 1]$ is the model’s estimated probability that the individual takes the action “buys ice cream.”

---

## 2. Model Initialization

A pre-trained autoregressive model (e.g., `gemma-3-4b-it`) is loaded from a local directory using Hugging Face Transformers.  
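The model runs in inference mode (`eval()`) with weights loaded in `bfloat16` on a single GPU (`cuda:0`); `low_cpu_mem_usage=True` reduces host memory use during loading, and TF32 matrix multiplication is enabled.  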

Formally, the model computes token-level conditional probabilities  

$$
p_\theta(w_t \mid w_{<t}, x),
$$  

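and the script decodes deterministically with greedy decoding (`do_sample=False`), selecting the highest-scoring token at each step. A repetition penalty of 1.1 is applied to the token scores before selection, so the result is reproducible but is not guaranteed to be the globally most probable sequence.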
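
Greedy decoding commits to the locally best token at every step. The toy example below is a minimal sketch of that behavior; the table `NEXT_TOKEN` and both helper functions are hypothetical and stand in for the model's conditional distributions. It shows that the greedy result is deterministic but can differ from the sequence with the highest overall probability.

```python
# Hypothetical next-token distributions, keyed by the prefix generated so far.
NEXT_TOKEN = {
    (): {"a": 0.6, "b": 0.4},
    ("a",): {"x": 0.55, "y": 0.45},
    ("b",): {"x": 0.9, "y": 0.1},
}


def greedy_decode(table):
    """Pick the most probable next token at each step until no entry remains."""
    prefix, prob = (), 1.0
    while prefix in table:
        token, p = max(table[prefix].items(), key=lambda kv: kv[1])
        prefix += (token,)
        prob *= p
    return prefix, prob


def all_sequences(table, prefix=(), prob=1.0):
    """Enumerate every complete sequence together with its joint probability."""
    if prefix not in table:
        yield prefix, prob
        return
    for token, p in table[prefix].items():
        yield from all_sequences(table, prefix + (token,), prob * p)


greedy_seq, greedy_prob = greedy_decode(NEXT_TOKEN)
best_seq, best_prob = max(all_sequences(NEXT_TOKEN), key=lambda sp: sp[1])
print(greedy_seq, best_seq)
```

For dataset generation the important property is reproducibility: the same prompt always yields the same completion, regardless of whether that completion is globally optimal.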

---

## 3. Trait and Context Definition

Two finite sets are defined:

$$
\mathcal{T} = \{\text{traits such as "likes sweets", "cheap", "impulsive buyer", ...}\},
$$

$$
\mathcal{C} = \{\text{contexts such as "hot summer day", "after work", ...}\}.
$$

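Each prompt uses a pair $(T, C)$ of non-empty subsets $T \subseteq \mathcal{T}$, $C \subseteq \mathcal{C}$, each containing between 1 and 6 elements.  
For every pair of subset sizes, at most 30 valid trait subsets and 30 valid context subsets are sampled at random and fully crossed.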

To prevent semantically inconsistent combinations, logical exclusion rules are applied. A pair $(T, C)$ is valid only if no conflict set is fully contained in it:
$$
\tau_i \not\subseteq T \quad \forall\, \tau_i \in \text{TRAIT\_CONFLICTS},
$$
$$
\kappa_j \not\subseteq C \quad \forall\, \kappa_j \in \text{CONTEXT\_CONFLICTS}.
$$
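
The exclusion rule can be sketched as a subset test, mirroring the script's `valid_combo`, which rejects a combination when it contains every member of some conflict set. The lists below are shortened toy versions for illustration, and the name `is_valid` is an assumption, not the script's.

```python
from itertools import combinations

# Toy subsets of the script's TRAITS and TRAIT_CONFLICTS, for illustration only.
traits = ["likes sweets", "dislikes sweets", "cheap", "spender"]
trait_conflicts = [{"likes sweets", "dislikes sweets"}, {"cheap", "spender"}]


def is_valid(combo, conflicts):
    """Return True if no conflict set is fully contained in the combination."""
    chosen = set(combo)
    return not any(conflict <= chosen for conflict in conflicts)


# Of the six possible pairs, the two that coincide with a conflict set are dropped.
valid_pairs = [c for c in combinations(traits, 2) if is_valid(c, trait_conflicts)]
print(valid_pairs)
```

Sharing a single element with a conflict set is allowed; only the simultaneous presence of all its members makes a combination invalid.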

---

## 4. Prompt Template

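Each valid pair $(T, C)$ is inserted into a fixed natural-language template:

```
Your decisions in everyday life are influenced by your personal traits and the context you find yourself in.

In this scenario, your traits are: {traits}.
The current context is: {context}.

The action of interest is whether you buy ice cream.
Your decision depends on how your traits interact with the current context.

Please assess how likely you are to take the action "buys ice cream" in this context considering factors like price, health concerns, and social influences.

Please share your decision in a JSON format.
The format should have one key:
- 'buy' (a value between 0 and 1 with intervals of 0.02, indicating the willingness or propensity to buy ice cream).
```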

Thus, for each input $x_i = (T_i, C_i)$, the model produces text $y_i$ that should contain a JSON object with a numeric field `buy`.
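
Filling the template can be sketched with `str.format`, serializing each subset with `json.dumps` as the script does. The shortened `template` string below is an assumption made for brevity. Note that literal braces in a format string must be doubled, otherwise `format` treats them as placeholders.

```python
import json

# Shortened, hypothetical template; literal JSON braces are escaped as {{ and }}.
template = (
    "Traits: {traits}\n"
    "Context: {context}\n"
    'Return JSON of the form {{"buy": <float between 0 and 1>}}.'
)

traits = ("likes sweets", "spender")
context = ("hot summer day",)

# json.dumps renders a tuple as a JSON array, giving an unambiguous list syntax.
prompt = template.format(traits=json.dumps(traits), context=json.dumps(context))
print(prompt)
```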

---

## 5. Data Generation Process

For each valid trait–context pair, the model generates a completion $y_i = f_\theta(x_i)$.
The JSON portion of the model’s output is parsed to extract the scalar value.  

The data record is then stored as
```json
{
  "traits": [...],
  "context": [...],
  "probability_LLM": 0.74
}
```

The full dataset is split into:
- 80% training examples  
- 20% test examples  

and written to:
```
../data/train.jsonl
../data/test.jsonl
```
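
The shuffle, split, and write steps can be sketched as follows. The function name `split_and_write`, its parameters, and the temporary output directory are assumptions for illustration; the script itself writes to `../data/`.

```python
import json
import random
import tempfile
from pathlib import Path


def split_and_write(records, out_dir, train_frac=0.8, seed=42):
    """Shuffle records, split them, and write one JSON object per line."""
    shuffled = list(records)
    random.Random(seed).shuffle(shuffled)
    cut = int(train_frac * len(shuffled))
    parts = {"train": shuffled[:cut], "test": shuffled[cut:]}

    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)  # create the directory if missing
    for name, rows in parts.items():
        with open(out_dir / f"{name}.jsonl", "w", encoding="utf-8") as f:
            for row in rows:
                f.write(json.dumps(row) + "\n")
    return {name: len(rows) for name, rows in parts.items()}


# Ten dummy records in the same shape as the dataset entries.
records = [
    {"traits": ["cheap"], "context": ["after work"], "probability_LLM": i / 10}
    for i in range(10)
]
with tempfile.TemporaryDirectory() as tmp:
    counts = split_and_write(records, tmp)
    lines = (Path(tmp) / "train.jsonl").read_text(encoding="utf-8").splitlines()
print(counts)
```

Using a dedicated `random.Random(seed)` instance keeps the split reproducible without touching the global random state used for prompt sampling.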

---

## 6. Batch Generation

Prompts are processed in batches of size $B$ to optimize GPU utilization.  
Each batch $\{x_{i_1}, \dots, x_{i_B}\}$ is converted into a chat-style input format and tokenized jointly.  
Generation is performed under deterministic decoding, ensuring consistent probability extraction.

Formally:
$$
Y_B = \text{decode}\big( \text{argmax}_{w_t} p_\theta(w_t \mid w_{<t}, x_{i_b}) \big)_{b=1}^B.
$$
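
The batching itself is plain list slicing. The sketch below shows it in isolation; the helper name `batched` is an assumption and no model is involved.

```python
def batched(items, batch_size):
    """Yield consecutive slices of at most batch_size items."""
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]


prompts = [f"prompt {i}" for i in range(19)]
sizes = [len(batch) for batch in batched(prompts, 8)]
print(sizes)

# Ceiling division gives the number of batches for a given prompt count.
num_batches = -(-9027 // 8)
print(num_batches)
```

The last batch may be smaller than $B$; with 9027 prompts and $B = 8$ this yields 1129 batches.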

---

## 7. Postprocessing and Storage

Each generated text is parsed, and the JSON fragment between the first `{` and the last `}` is extracted.  
If the value of `"buy"` is missing, non-numeric, or outside $[0, 1]$, the example is discarded; valid values are rounded to two decimals.  
The remaining dataset is shuffled and partitioned for later supervised training.
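
A slightly more defensive variant of the parsing step is sketched below. It is an alternative to the script's approach, not a description of it: it tries `json.JSONDecoder.raw_decode` at every opening brace, so stray braces before or after the JSON object do not break parsing. The function name `extract_buy` is hypothetical, and unlike the script it rejects numeric strings such as `"0.7"`.

```python
import json


def extract_buy(text):
    """Return the first valid 'buy' value in [0, 1] found in text, else None."""
    decoder = json.JSONDecoder()
    pos = text.find("{")
    while pos != -1:
        try:
            obj, _ = decoder.raw_decode(text, pos)
        except json.JSONDecodeError:
            obj = None
        if isinstance(obj, dict):
            value = obj.get("buy")
            is_number = isinstance(value, (int, float)) and not isinstance(value, bool)
            if is_number and 0 <= value <= 1:
                return round(float(value), 2)
        # Not a usable object at this brace; try the next one.
        pos = text.find("{", pos + 1)
    return None


print(extract_buy('Here you go:\n```json\n{"buy": 0.74}\n```'))
print(extract_buy("no JSON in this answer"))
```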

---

## 8. Extensions

- Replace the single action “buys ice cream” with arbitrary decision actions.  
- Add explanation extraction for interpretability studies.  
- Use the generated dataset for fine-tuning smaller “student” models.


In [None]:
import os
import gc
import json
import random
import torch
from itertools import combinations
from transformers import AutoTokenizer, AutoModelForCausalLM
from tqdm import tqdm


# ============================================================
# MODEL SETUP
# ============================================================
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.benchmark = True

MODEL_PATH = "../models/saved_models/gemma-3-4b-it"

def free_gpu_memory():
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()

print("🚀 Loading model ...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    dtype=torch.bfloat16,
    device_map="cuda:0",
    low_cpu_mem_usage=True
)
model.eval()

# ============================================================
# TRAITS & CONTEXT DEFINITIONS
# ============================================================
TRAITS = [
    "likes sweets",
    "dislikes sweets",
    "health-conscious",
    "lactose intolerant",
    "cheap",
    "spender",
    "impulsive buyer"
]

CONTEXTS = [
    "hungry",
    "on a diet",
    "ice cream truck nearby",
    "hot summer day",
    "cold winter day",
    "ice cream is cheap today (discount)",
    "after a long workout",
    "after lunch",
    "after work"
]

# ============================================================
# CONFLICT RULES
# ============================================================
TRAIT_CONFLICTS = [
    {"likes sweets", "dislikes sweets"},
    {"cheap", "spender"},
    {"cheap", "impulsive buyer"}
]

CONTEXT_CONFLICTS = [
    {"hot summer day", "cold winter day"},
    {"hungry", "after lunch"}
]

def valid_combo(combo, conflicts):
    for conflict in conflicts:
        if conflict.issubset(set(combo)):
            return False
    return True

# ============================================================
# PROMPT TEMPLATE
# ============================================================
PROMPT_TEMPLATE = """Your decisions in everyday life are influenced by your personal traits and the context you find yourself in. 

In this scenario, your traits are: {traits}.  
The current context is: {context}.  

The action of interest is whether you buy ice cream. 
Your decision depends on how your traits interact with the current context. 

Please assess how likely you are to take the action "buys ice cream" in this context considering factors like price, health concerns, and social influences.  

Please share your decision in a JSON format.  
The format should have one key:  
- 'buy' (a value between 0 and 1 with intervals of 0.02, indicating the willingness or propensity to buy ice cream).
"""

# ============================================================
# DATA GENERATION
# ============================================================
def build_prompts():
    prompts = []
    meta = []
    for t_count in range(1, 7):
        for c_count in range(1, 7):
            trait_combos = [
                t for t in combinations(TRAITS, t_count)
                if valid_combo(t, TRAIT_CONFLICTS)
            ]
            context_combos = [
                c for c in combinations(CONTEXTS, c_count)
                if valid_combo(c, CONTEXT_CONFLICTS)
            ]

            sampled_traits = random.sample(trait_combos, k=min(30, len(trait_combos)))
            sampled_contexts = random.sample(context_combos, k=min(30, len(context_combos)))

            for traits in sampled_traits:
                for context in sampled_contexts:
                    prompt = PROMPT_TEMPLATE.format(
                        traits=json.dumps(traits),
                        context=json.dumps(context)
                    )
                    prompts.append(prompt)
                    meta.append((traits, context))
    print(f"Generated {len(prompts)} prompts from {len(TRAITS)} traits and {len(CONTEXTS)} contexts.")
    return prompts, meta



def generate_batch(prompts, batch_size=8, max_new_tokens=200):
    """Generate responses in batches, using the same chat template as the chat script."""
    all_outputs = []

    # Compare against None explicitly: a token id of 0 is valid but falsy, so `or` would skip it.
    eos_token_id = tokenizer.eos_token_id if tokenizer.eos_token_id is not None else tokenizer.pad_token_id
    pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id

    for i in tqdm(range(0, len(prompts), batch_size), desc="🔄 Generating batches", unit="batch"):
        batch_prompts = prompts[i:i + batch_size]

        # Use the same chat format as in the working script
        chat_texts = [
            tokenizer.apply_chat_template(
                [{"role": "user", "content": p}],
                tokenize=False,
                add_generation_prompt=True
            )
            for p in batch_prompts
        ]

        # Show the prompt for the first batch only
        # if i == 0:
        #     print("\n🧾 Example prompt sent to the model:")
        #     print("=" * 80)
        #     print(chat_texts[0])
        #     print("=" * 80)

        # No max_length is set, so truncation=True would only trigger a warning; padding alone suffices.
        inputs = tokenizer(chat_texts, return_tensors="pt", padding=True).to(model.device)

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                do_sample=False,      # deterministic for stable values
                #temperature=0.0,
                #top_p=1.0,
                max_new_tokens=max_new_tokens,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
                repetition_penalty=1.1
            )

        # Decode the output as in the chat script
        for j in range(len(batch_prompts)):
            gen_tokens = outputs[j][inputs["input_ids"].shape[-1]:]
            answer = tokenizer.decode(gen_tokens, skip_special_tokens=True).strip()
            # If the model emits a "model" prefix, remove it
            if answer.lower().startswith("model"):
                answer = answer[len("model"):].lstrip(": -\n")

            all_outputs.append(answer)

    return all_outputs





def extract_prob(text):
    """Parse the JSON part of the text and extract the 'buy' value."""
    try:
        json_start = text.find("{")
        json_end = text.rfind("}") + 1
        data = json.loads(text[json_start:json_end])
        prob = float(data.get("buy", None))
        if 0 <= prob <= 1:
            return round(prob, 2)
    except Exception:
        pass
    return None


# ============================================================
# MAIN PIPELINE
# ============================================================
if __name__ == "__main__":
    random.seed(42)
    print("📋 Building prompts ...")
    prompts, meta = build_prompts()
    print(f"➡️ {len(prompts)} prompts built.")

    print("🧠 Generating model responses in batches ...")

    # Debugging aid: run only the first 8 prompts (one batch) and print the raw outputs
    # test_prompts = prompts[:8]
    # test_meta = meta[:8]
    # results = generate_batch(test_prompts, batch_size=8)
    # for i, (traits, context) in enumerate(test_meta):
    #     print("\n" + "="*80)
    #     print(f"🧩 Example {i+1}")
    #     print(f"Traits:  {traits}")
    #     print(f"Context: {context}")
    #     print(f"🔹 Raw model output:")
    #     print(results[i])
    #     print("="*80)
    results = generate_batch(prompts, batch_size=8)

    DATA = []
    for (traits, context), text in zip(meta, results):
        prob = extract_prob(text)
        if prob is not None:
            DATA.append({
                "traits": list(traits),
                "context": list(context),
                "probability_LLM": prob
            })
            print(f"✅ {traits} + {context} → {prob:.2f}")

    random.shuffle(DATA)
    split = int(0.8 * len(DATA))
    train_data = DATA[:split]
    test_data = DATA[split:]

    os.makedirs("../data", exist_ok=True)  # make sure the output directory exists
    with open("../data/train.jsonl", "w", encoding="utf-8") as f:
        for d in train_data:
            f.write(json.dumps(d) + "\n")
    with open("../data/test.jsonl", "w", encoding="utf-8") as f:
        for d in test_data:
            f.write(json.dumps(d) + "\n")

    print(f"\n📊 Done! Saved {len(train_data)} training and {len(test_data)} test examples.")
    free_gpu_memory()
    print("🧹 GPU memory cleared. Done.")


🚀 Loading model ...


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

📋 Building prompts ...
Generated 9027 prompts from 7 traits and 9 contexts.
➡️ 9027 prompts built.
🧠 Generating model responses in batches ...


🔄 Generating batches:   0%|          | 0/1129 [00:00<?, ?batch/s]
Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
The following generation flags are not valid and may be ignored: ['top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
🔄 Generating batches:   8%|▊         | 89/1129 [51:26<10:02:49, 34.78s/batch]