# Chat Template Control Experiment: Llama-3.1-8B-Instruct

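**Purpose:** Test whether applying the chat template changes the embedding–output relationship, i.e. the correlation between layer-wise centroid asymmetry of statement embeddings and the model's NLL-based preference between paired statements.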

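**Background:** Kevin-pw pointed out that chat-tuned models expect their input to be wrapped in the model's chat template; running them on raw text may not reflect how they are normally used.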

**Experimental Design:**
- Condition A: Instruct model WITH chat template
- Condition B: Instruct model WITHOUT chat template

**Expected Runtime:** ~3-4h on A100

---

**Author:** Davide D'Elia  
**Date:** 2026-01-03

## 1. Setup

In [None]:
# Install the notebook's dependencies quietly (IPython shell escape).
!pip install -q transformers accelerate torch numpy scipy matplotlib scikit-learn huggingface_hub

In [None]:
import json
import warnings
from datetime import datetime
from typing import Dict, List, Tuple

import numpy as np
import torch
import matplotlib.pyplot as plt
from scipy import stats
from transformers import AutoModelForCausalLM, AutoTokenizer

# NOTE(review): this silences *all* warnings, including e.g. scipy's warning
# when pearsonr receives constant input; consider a narrower filter.
warnings.filterwarnings('ignore')

# Seed NumPy's global RNG (used for bootstrap resampling) and PyTorch.
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

# Bootstrap settings for the correlation confidence intervals.
N_BOOTSTRAP = 10000
CI_LEVEL = 0.95

print(f'PyTorch: {torch.__version__}')
print(f'CUDA: {torch.cuda.is_available()}')

In [None]:
# ========================================
# HUGGINGFACE LOGIN
# ========================================
# A Hugging Face token is presumably required because meta-llama repositories
# are gated on the Hub. Three sources are tried in order.

# OPTION 1: Colab Secrets (recommended)
try:
    from google.colab import userdata
    HF_TOKEN = userdata.get('HF_TOKEN')
    print("✅ Token aus Colab Secrets geladen")
except Exception:
    # Not running on Colab (ImportError), or the secret is missing / not
    # shared with this notebook. Using `except Exception` instead of a bare
    # `except:` keeps KeyboardInterrupt and SystemExit from being swallowed.
    HF_TOKEN = None

# OPTION 2: Enter manually
if not HF_TOKEN:
    HF_TOKEN = ''  # <-- Paste token here if needed

# OPTION 3: Interactive login (used when neither option above gave a token)
if not HF_TOKEN:
    from huggingface_hub import notebook_login
    notebook_login()
else:
    from huggingface_hub import login
    login(token=HF_TOKEN)
    print("✅ Eingeloggt")

## 2. Load Instruct Model

In [None]:
MODEL_NAME = 'meta-llama/Llama-3.1-8B-Instruct'
MODEL_DISPLAY = 'Llama-3.1-8B-Instruct'

print(f'Loading {MODEL_DISPLAY}...')

# Pass the explicit token when one was found; otherwise `True` tells
# from_pretrained to use the locally stored token (e.g. from notebook_login()).
token_arg = HF_TOKEN if HF_TOKEN else True

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=token_arg)
# Ensure a pad token exists. Every tokenizer call in this notebook encodes a
# single string, so padding is not actually exercised.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# fp16 weights, placed automatically on the available device(s).
# NOTE(review): output_hidden_states=True in the config makes every forward
# pass return all layers' hidden states, including NLL-only passes; the
# embedding helper requests them explicitly per call, so this could be
# dropped to save memory.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    token=token_arg,
    torch_dtype=torch.float16,
    device_map='auto',
    output_hidden_states=True
)

print(f'✅ Layers: {model.config.num_hidden_layers}')
print(f'Chat template: {tokenizer.chat_template is not None}')

In [None]:
# Demonstrate chat template
# Render one statement as a single user turn (no generation prompt) to show
# the text the model sees in the WITH-template condition. The templated
# output is truncated to its first 80 characters.
example = 'The Earth is flat.'
messages = [{'role': 'user', 'content': example}]
templated = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
print('Raw:', example)
print('Templated:', templated[:80], '...')

## 3. Load Dataset

In [None]:
!wget -q https://raw.githubusercontent.com/buk81/uniformity-asymmetry/main/dataset.json

with open('dataset.json', 'r') as f:
    DATASET = json.load(f)

ALL_PAIRS = []
for cat_name, cat_data in DATASET.items():
    for pair in cat_data['pairs']:
        ALL_PAIRS.append({'stmt_a': pair[0], 'stmt_b': pair[1], 'category': cat_name})

print(f'Total pairs: {len(ALL_PAIRS)}')

## 4. Core Functions

In [None]:
def apply_template(text, tokenizer):
    """Render *text* as a single user turn using the tokenizer's chat template.

    Returns the rendered string (not token ids); no assistant generation
    prompt is appended.
    """
    conversation = [{'role': 'user', 'content': text}]
    rendered = tokenizer.apply_chat_template(
        conversation,
        tokenize=False,
        add_generation_prompt=False,
    )
    return rendered

def get_layer_embedding(text, model, tokenizer, layer_idx, use_template=False):
    """Mean-pooled hidden state of *text* at ``hidden_states[layer_idx]``.

    Args:
        text: raw statement.
        model: causal LM that returns ``hidden_states`` when asked.
        tokenizer: matching tokenizer (must have a chat template if
            ``use_template`` is True).
        layer_idx: index into ``outputs.hidden_states``.
        use_template: wrap *text* as a user turn via the chat template first.

    Returns:
        float32 numpy vector: mean over all positions except position 0 (the
        BOS token), or over the single position if the input has only one
        token.
    """
    if use_template:
        text = apply_template(text, tokenizer)
    # The rendered chat template already begins with the BOS token. Letting
    # the tokenizer add special tokens again would give a duplicated BOS, and
    # the [1:] slice below would strip only one of them, leaving the two
    # conditions differently pooled.
    inputs = tokenizer(
        text,
        return_tensors='pt',
        truncation=True,
        max_length=512,
        add_special_tokens=not use_template,
    ).to(model.device)
    with torch.no_grad():
        outputs = model(**inputs, output_hidden_states=True)
    layer_hidden = outputs.hidden_states[layer_idx][0]  # first (only) batch item: (seq_len, hidden)
    # Skip position 0 (BOS); keep it only if it is the sole position so the
    # mean is never taken over an empty slice.
    tokens = layer_hidden[1:] if layer_hidden.shape[0] > 1 else layer_hidden
    # NOTE(review): in the templated condition the pooled positions also
    # include the template's header/turn tokens, not just the statement.
    # Upcast before averaging so the mean is not accumulated/rounded in fp16.
    return tokens.float().mean(dim=0).cpu().numpy().astype(np.float32)

def get_output_preference(text_a, text_b, model, tokenizer, use_template=False):
    """Return NLL(text_b) - NLL(text_a) under *model*.

    NLL is ``outputs.loss`` with ``labels=input_ids``, i.e. the model's
    token-averaged cross-entropy of the sequence. A positive value means
    *text_a* gets the lower loss (is preferred).

    Args:
        text_a, text_b: the two statements of a pair.
        model, tokenizer: causal LM and its tokenizer.
        use_template: wrap each statement as a user turn via the chat template.
    """
    def get_nll(text):
        if use_template:
            text = apply_template(text, tokenizer)
        # The rendered template already contains BOS; don't add it twice.
        inputs = tokenizer(
            text,
            return_tensors='pt',
            truncation=True,
            max_length=512,
            add_special_tokens=not use_template,
        ).to(model.device)
        with torch.no_grad():
            # Only the loss is needed; overriding the config default avoids
            # materialising every layer's hidden states for this pass.
            outputs = model(**inputs, labels=inputs['input_ids'], output_hidden_states=False)
        # NOTE(review): with the template, the averaged loss also covers the
        # template's own tokens, which dilutes the statement-specific signal.
        return outputs.loss.item()
    return get_nll(text_b) - get_nll(text_a)

def cosine_similarity(a, b):
    """Cosine similarity of vectors *a* and *b* as a Python float.

    A 1e-10 term in the denominator keeps zero vectors from dividing by zero
    (they yield 0.0).
    """
    norm_product = np.linalg.norm(a) * np.linalg.norm(b)
    return float(np.dot(a, b) / (norm_product + 1e-10))

def bootstrap_correlation(x, y, n_bootstrap=10000, ci_level=0.95, rng=None):
    """Pearson correlation with a percentile-bootstrap confidence interval.

    Args:
        x, y: paired 1-D samples (arrays or plain sequences) of equal length.
        n_bootstrap: number of resamples drawn with replacement.
        ci_level: confidence level of the interval, e.g. 0.95.
        rng: optional ``numpy.random.Generator`` / ``RandomState``. When None,
            the global ``np.random`` state is used, as before.

    Returns:
        ``(r_obs, ci_lo, ci_hi, p)`` as Python floats. ``r_obs`` and ``p`` come
        from ``scipy.stats.pearsonr`` on the full sample; the bounds are
        percentiles of the bootstrap r values. Resamples where either side is
        constant (r undefined) are skipped. If no resample is usable, the
        bounds are NaN instead of raising.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    n = len(x)
    r_obs, p = stats.pearsonr(x, y)
    choice = np.random.choice if rng is None else rng.choice
    rs = []
    for _ in range(n_bootstrap):
        idx = choice(n, size=n, replace=True)
        x_b, y_b = x[idx], y[idx]
        if np.std(x_b) > 0 and np.std(y_b) > 0:
            r, _ = stats.pearsonr(x_b, y_b)
            rs.append(r)
    if not rs:
        # np.percentile of an empty array would raise; report "no CI" instead.
        return float(r_obs), float('nan'), float('nan'), float(p)
    rs = np.array(rs)
    alpha = 1 - ci_level
    ci_lo = np.percentile(rs, alpha/2*100)
    ci_hi = np.percentile(rs, (1-alpha/2)*100)
    return float(r_obs), float(ci_lo), float(ci_hi), float(p)

# Confirm the helper-definition cell ran.
print('Functions defined.')

## 5. Collect Embeddings (Both Conditions)

In [None]:
N_LAYERS = model.config.num_hidden_layers
# Probe every 4th hidden-state index and always include N_LAYERS itself.
# Assumes hidden_states has N_LAYERS + 1 entries (index 0 = embedding output),
# per the transformers convention, so N_LAYERS is a valid index.
LAYERS_TO_TEST = list(range(0, N_LAYERS + 1, 4))
if N_LAYERS not in LAYERS_TO_TEST:
    LAYERS_TO_TEST.append(N_LAYERS)

print(f'Layers: {LAYERS_TO_TEST}')
print(f'Running TWO conditions: WITH and WITHOUT template')

# One record per pair and condition:
#   {'pref': float, 'cat': category, 'embs': {layer: {'a': vec, 'b': vec}}}
pair_data_with = []
pair_data_without = []
start = datetime.now()

# NOTE(review): get_layer_embedding runs a full forward pass per call, so each
# statement is encoded len(LAYERS_TO_TEST) times per condition, even though one
# pass returns every layer's hidden states. Collecting all layers from a single
# pass would cut the embedding cost roughly by that factor.
for i, pair in enumerate(ALL_PAIRS):
    # Progress report every 25 pairs.
    if (i + 1) % 25 == 0:
        elapsed = (datetime.now() - start).total_seconds() / 60
        print(f'  [{i+1}/{len(ALL_PAIRS)}] - {elapsed:.1f} min')
    
    # WITH template
    pref_w = get_output_preference(pair['stmt_a'], pair['stmt_b'], model, tokenizer, use_template=True)
    embs_w = {l: {'a': get_layer_embedding(pair['stmt_a'], model, tokenizer, l, True),
                  'b': get_layer_embedding(pair['stmt_b'], model, tokenizer, l, True)} for l in LAYERS_TO_TEST}
    pair_data_with.append({'pref': pref_w, 'cat': pair['category'], 'embs': embs_w})
    
    # WITHOUT template
    pref_wo = get_output_preference(pair['stmt_a'], pair['stmt_b'], model, tokenizer, use_template=False)
    embs_wo = {l: {'a': get_layer_embedding(pair['stmt_a'], model, tokenizer, l, False),
                   'b': get_layer_embedding(pair['stmt_b'], model, tokenizer, l, False)} for l in LAYERS_TO_TEST}
    pair_data_without.append({'pref': pref_wo, 'cat': pair['category'], 'embs': embs_wo})

print(f'Done in {(datetime.now() - start).total_seconds() / 60:.1f} min')

## 6. Analyze Both Conditions

In [None]:
def compute_centroid_asymmetry(pair_data, layer_idx):
    """Per-pair difference in cosine similarity to each side's centroid.

    At layer ``layer_idx``, the A-side and B-side embeddings of all pairs are
    stacked separately and each side's mean vector (centroid) is taken. For
    pair i the result is ``cos(a_i, centroid_A) - cos(b_i, centroid_B)``, so
    positive values mean statement A sits closer to its side's centroid than
    statement B does. Returns a 1-D numpy array with one value per pair.
    """
    def similarities_to_own_centroid(side):
        vectors = np.array([entry['embs'][layer_idx][side] for entry in pair_data])
        centroid = vectors.mean(0)
        return np.array([cosine_similarity(vec, centroid) for vec in vectors])

    return similarities_to_own_centroid('a') - similarities_to_own_centroid('b')

def analyze(pair_data, name, layers=None):
    """Correlate centroid asymmetry with output preference at each layer.

    For every layer, computes the per-pair centroid-asymmetry metric and its
    Pearson r against the per-pair output preference, with a bootstrap CI,
    and prints one summary line.

    Args:
        pair_data: list of dicts with 'pref' and 'embs' entries, as built by
            the collection loop.
        name: label printed in the section header.
        layers: hidden-state indices to analyze; defaults to LAYERS_TO_TEST.

    Returns:
        dict mapping layer -> {'r', 'ci_lo', 'ci_hi', 'p', 'sig'}. 'p' is the
        pearsonr p-value on the full sample. 'sig' is True only when both CI
        bounds are finite and the interval excludes 0.
    """
    print(f'\n=== {name} ===')
    if layers is None:
        layers = LAYERS_TO_TEST
    prefs = np.array([p['pref'] for p in pair_data])
    results = {}
    for layer in layers:
        metric = compute_centroid_asymmetry(pair_data, layer)
        r, ci_lo, ci_hi, p = bootstrap_correlation(metric, prefs, N_BOOTSTRAP, CI_LEVEL)
        # Every comparison with NaN is False, so a plain
        # `not (ci_lo <= 0 <= ci_hi)` would call a NaN interval significant.
        ci_finite = bool(np.isfinite(ci_lo) and np.isfinite(ci_hi))
        is_sig = ci_finite and not (ci_lo <= 0 <= ci_hi)
        results[layer] = {'r': r, 'ci_lo': ci_lo, 'ci_hi': ci_hi, 'p': p, 'sig': is_sig}
        marker = '***' if is_sig else ''
        print(f'  Layer {layer:2d}: r={r:+.3f} CI=[{ci_lo:+.3f},{ci_hi:+.3f}] {marker}')
    return results

# Correlate centroid asymmetry with output preference at every probed layer,
# separately for the templated and raw-text conditions.
results_with = analyze(pair_data_with, 'WITH Template')
results_without = analyze(pair_data_without, 'WITHOUT Template')

## 7. Compare Conditions

In [None]:
print('\n' + '='*60)
print(' TEMPLATE EFFECT COMPARISON')
print('='*60)
print(f'Layer   WITH      WITHOUT   Same Sign?')
print('-'*45)

# Per layer, check whether both conditions' correlations share a (non-zero)
# sign, and count how many layers agree.
same = 0
for layer in LAYERS_TO_TEST:
    r_with = results_with[layer]['r']
    r_without = results_without[layer]['r']
    signs_agree = r_with * r_without > 0
    same += int(signs_agree)
    print(f'{layer:2d}      {r_with:+.3f}    {r_without:+.3f}    {"YES" if signs_agree else "NO"}')

print(f'\nSame sign: {same}/{len(LAYERS_TO_TEST)}')

# Phase structure
def phase_means(results, layer_groups=None):
    """Mean correlation r over early, middle and late layer groups.

    Args:
        results: dict mapping layer key -> {'r': float, ...}, as returned by
            ``analyze``.
        layer_groups: optional sequence of three iterables of layer keys
            (early, mid, late). By default, the sorted layer keys of
            ``results`` are split into three contiguous, near-equal groups.
            For the 9-layer grid 0, 4, ..., 32 this gives [0, 4, 8],
            [12, 16, 20] and [24, 28, 32]. Models of other depths therefore
            work without a KeyError.

    Returns:
        ``(early, mid, late)`` mean r values.

    Raises:
        ValueError: if there are not exactly three non-empty groups.
    """
    if layer_groups is None:
        layers = sorted(results)
        # Split positions rather than the keys themselves so the keys keep
        # their original Python type for the dict lookups below.
        index_groups = np.array_split(np.arange(len(layers)), 3)
        layer_groups = [[layers[i] for i in idx] for idx in index_groups]
    layer_groups = [list(group) for group in layer_groups]
    if len(layer_groups) != 3 or not all(layer_groups):
        raise ValueError('phase_means needs exactly three non-empty layer groups')
    e, m, t = (np.mean([results[layer]['r'] for layer in group]) for group in layer_groups)
    return e, m, t

# Early / middle / late mean r for each condition
# (suffix w = with template, wo = without template).
ew, mw, tw = phase_means(results_with)
ewo, mwo, two = phase_means(results_without)

print(f'\nPhase Structure:')
print(f'           WITH     WITHOUT')
print(f'Early:    {ew:+.3f}   {ewo:+.3f}')
print(f'Mid:      {mw:+.3f}   {mwo:+.3f}')
print(f'Late:     {tw:+.3f}   {two:+.3f}')

## 8. Verdict

In [None]:
def pattern_str(e, m, t):
    if e > 0.1 and m > 0.1 and t < -0.1:
        return 'PYTHIA_PATTERN'
    elif abs(e) < 0.1 and abs(m) < 0.1 and abs(t) < 0.1:
        return 'DECOUPLED'
    else:
        return 'OTHER'

# Classify each condition's phase profile.
pat_w = pattern_str(ew, mw, tw)
pat_wo = pattern_str(ewo, mwo, two)

print('\n' + '='*60)
print(' VERDICT')
print('='*60)
print(f'WITH template: {pat_w}')
print(f'WITHOUT template: {pat_wo}')

# Combine the two labels into an overall verdict.
# NOTE(review): 'OTHER' is a catch-all, so two different non-matching profiles
# still count as TEMPLATE_INVARIANT; comparing per-phase signs would be stricter.
# The `pat_w != 'DECOUPLED'` test in the elif is always true once the first
# branch has failed and pat_wo == 'DECOUPLED'.
if pat_w == pat_wo:
    verdict = 'TEMPLATE_INVARIANT'
elif pat_wo == 'DECOUPLED' and pat_w != 'DECOUPLED':
    verdict = 'TEMPLATE_ENABLES_COUPLING'
else:
    verdict = 'TEMPLATE_CHANGES_PATTERN'

print(f'\n>>> {verdict} <<<')

## 9. Save Results

In [None]:
# Collect run metadata, per-layer results (layer indices as string keys),
# phase means/labels and the verdict into one JSON-serialisable dict.
save_data = {
    'timestamp': datetime.now().isoformat(),
    'model': MODEL_NAME,
    'n_pairs': len(ALL_PAIRS),
    'results_with_template': {str(k): v for k, v in results_with.items()},
    'results_without_template': {str(k): v for k, v in results_without.items()},
    'phase_with': {'early': ew, 'mid': mw, 'late': tw, 'pattern': pat_w},
    'phase_without': {'early': ewo, 'mid': mwo, 'late': two, 'pattern': pat_wo},
    'verdict': verdict
}

# Timestamped filename so repeated runs do not overwrite each other.
fname = f'llama3_instruct_template_{datetime.now().strftime("%Y%m%d_%H%M%S")}.json'
with open(fname, 'w', encoding='utf-8') as f:
    json.dump(save_data, f, indent=2)
print(f'Saved: {fname}')

# Offer a browser download only on Colab. Elsewhere the file written above is
# the result, so a missing google.colab module should not crash the cell.
try:
    from google.colab import files
except ImportError:
    pass
else:
    files.download(fname)

## 10. Visualization

In [None]:
fig, ax = plt.subplots(figsize=(10, 6))

# Correlation per probed layer for each condition.
rs_w = [results_with[l]['r'] for l in LAYERS_TO_TEST]
rs_wo = [results_without[l]['r'] for l in LAYERS_TO_TEST]

ax.plot(LAYERS_TO_TEST, rs_w, 'o-', label='WITH template', color='blue', markersize=8)
ax.plot(LAYERS_TO_TEST, rs_wo, 's--', label='WITHOUT template', color='red', markersize=8)
ax.axhline(0, color='gray', linestyle='--', alpha=0.5)
ax.set_xlabel('Layer')
ax.set_ylabel('r(centroid_asymmetry, output)')
ax.set_title(f'{MODEL_DISPLAY}: Template Effect\nVerdict: {verdict}')
ax.legend()
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('llama3_instruct_template.png', dpi=150)
plt.show()
print('Plot saved')

# Trigger a browser download only on Colab. Elsewhere the saved PNG is the
# output, so a missing google.colab module should not crash the cell.
try:
    from google.colab import files
except ImportError:
    pass
else:
    files.download('llama3_instruct_template.png')