1. Le Concept Théorique

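- Température ($T$) : modifie la distribution token par token au moment de la génération, chaque token étant tiré selon $\mathrm{softmax}(z_t / T)$ où $z_t$ désigne les logits à l'étape $t$. Elle aplatit ($T > 1$) ou aiguise ($T < 1$) les choix locaux.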

- Power Sampling ($\alpha$) : modifie la distribution de la séquence entière, en ciblant $p_\alpha(x) \propto p(x)^\alpha$ où $p(x)$ est la probabilité de la séquence complète. Il favorise les chemins qui sont globalement plus cohérents, même si certains tokens individuels ne sont pas les plus probables localement.

In [1]:

import torch 
import torch.nn.functional as F 

def _sequence_log_prob(logits, labels, temperature=1.0):
    """Sum the log-probabilities of `labels` under softmax(logits / temperature)."""
    # The temperature is applied BEFORE normalisation, so the result is not
    # simply log_p / temperature: the log-partition term changes as well.
    log_probs = F.log_softmax(logits / temperature, dim=-1)
    picked = log_probs.gather(-1, labels.unsqueeze(-1)).squeeze(-1)
    return picked.sum().item()


def compare_power_vs_temp(model, tokenizer, sequences, alpha=2.0, temp=0.5):
    """Score each sequence under power scaling and under low-temperature scaling.

    For every sequence the model is run once and three quantities are derived
    from the same next-token logits:

    * ``raw_log_p``   -- sum of the token log-probabilities, i.e. log P(x).
    * ``power_score`` -- ``alpha * raw_log_p``, the (unnormalised) log of P(x)**alpha.
    * ``temp_score``  -- sum of the token log-probabilities after dividing the
      logits by ``temp`` and re-normalising at every position.

    Args:
        model: Callable invoked as ``model(**inputs)``; it must expose a
            ``device`` attribute and return an object with a ``logits``
            attribute (the indexing below expects [1, seq_len, vocab_size]).
        tokenizer: Callable invoked as ``tokenizer(seq, return_tensors="pt")``;
            its result must support ``.to(device)``, ``**`` unpacking and
            ``["input_ids"]``.
        sequences: Iterable of strings to score.
        alpha: Exponent applied to the whole-sequence probability.
        temp: Sampling temperature; must be strictly positive.

    Returns:
        A list with one dict per input sequence, in order, with the keys
        ``"sequence"``, ``"raw_log_p"``, ``"power_score"`` and ``"temp_score"``.

    Raises:
        ValueError: If ``temp`` is not strictly positive (dividing the logits
            by zero would silently produce NaN scores, and a negative value
            would invert the distribution).
    """
    if temp <= 0:
        raise ValueError(f"temp must be strictly positive, got {temp!r}")

    results = []

    for seq in sequences:
        inputs = tokenizer(seq, return_tensors="pt").to(model.device)
        with torch.no_grad():
            # No `labels` here: only the logits are needed, so there is no
            # reason to make the model compute a loss that is never read.
            logits = model(**inputs).logits

        # Align predictions with targets: logits[t] scores token[t + 1].
        shift_logits = logits[..., :-1, :]
        shift_labels = inputs["input_ids"][..., 1:]

        raw_log_p = _sequence_log_prob(shift_logits, shift_labels)

        results.append({
            "sequence": seq,
            "raw_log_p": raw_log_p,
            # log(P(x) ** alpha) = alpha * sum_t log p_t
            "power_score": alpha * raw_log_p,
            "temp_score": _sequence_log_prob(shift_logits, shift_labels, temperature=temp),
        })

    return results


from transformers import AutoModelForCausalLM, AutoTokenizer
# Pick a very lightweight checkpoint (a few hundred MB to download).
model_name = "facebook/opt-125m"

print("Téléchargement du tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(model_name)

print("Téléchargement du modèle...")
model = AutoModelForCausalLM.from_pretrained(model_name)

# Run on the GPU when torch can see one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model = model.to(device)

print(f"Modèle chargé sur {device} !")


  from .autonotebook import tqdm as notebook_tqdm


Téléchargement du tokenizer...
Téléchargement du modèle...
Modèle chargé sur cpu !


In [None]:
# --- Sanity check: score two phrasings of the same statement ---
sequences_test = [
    "7 is a prime number because it only has two divisors.",
    "7 is prime number as it is only divisible by 1 and 7."
]

data = compare_power_vs_temp(model, tokenizer, sequences_test)

# One summary line per sequence: truncated text followed by its three scores.
for res in data:
    summary = f"Seq: {res['sequence'][:30]}... | Raw: {res['raw_log_p']:.2f} | Power: {res['power_score']:.2f} | Temp: {res['temp_score']:.2f}"
    print(summary)






# Minimal pairs: sequence A is the correct statement, sequence B a corrupted one.
# NOTE: when both sequences share the same prefix and differ only in one final
# token (presumably the case for pair 1 -- confirm against the tokenizer), the
# per-position normaliser cancels in the difference, so the temperature gap is
# raw_gap / TEMP and coincides with the power gap whenever ALPHA == 1 / TEMP.
# Only pairs whose difference is followed by further tokens can separate the
# two methods.
test_pairs = [
    ["2 + 2 is 4", "2 + 2 is 5"],
    ["The capital of France is Paris.", "The capital of France is Lyon."],
    ["A prime number has two divisors.", "A prime number has three divisors."]
]

# Single source of truth for the hyper-parameters, so the row labels printed
# below can never drift away from the values actually passed to the scorer.
ALPHA = 2.0
TEMP = 0.5

# (row label, key in the result dict) for each scoring method, in print order.
score_rows = [
    ("Raw", "raw_log_p"),
    (f"Power (a={ALPHA:g})", "power_score"),
    (f"Temp (T={TEMP:g})", "temp_score"),
]

print(f"{'Méthode':<15} | {'Paire':<10} | {'Log-Likelihood':<15} | {'Gap (A-B)'}")
print("-" * 60)

for i, pair in enumerate(test_pairs):
    res = compare_power_vs_temp(model, tokenizer, pair, alpha=ALPHA, temp=TEMP)

    for row_idx, (label, key) in enumerate(score_rows):
        # The pair identifier is shown on the first row of each group only.
        pair_label = f"Paire {i+1}" if row_idx == 0 else ""
        # Gap = score of sequence A minus score of sequence B.
        gap = res[0][key] - res[1][key]
        print(f"{label:<15} | {pair_label:<10} | {res[0][key]:<15.2f} | {gap:.2f}")
    print("-" * 60)

Seq: 7 is a prime number because it... | Raw: -47.89 | Power: -95.78 | Temp: -64.71
Seq: 7 is prime number as it is onl... | Raw: -61.96 | Power: -123.93 | Temp: -91.24
Méthode         | Paire      | Log-Likelihood  | Gap (A-B)
------------------------------------------------------------
Raw             | Paire 1    | -31.93          | 0.35
Power (a=2)     |            | -63.87          | 0.70
Temp (T=0.5)    |            | -50.22          | 0.70
------------------------------------------------------------
