# Lab 9

In the previous lab, we had a posterior distribution that we could not sample from directly: $p(z | x_L)$, where $x_L$ was part of an image. We then constructed a sampler for this posterior using MCMC.

In this lab, we will study autoregressive transformer models, which allow sampling by construction. Let $x$ be the input to the model and $y_{1:T}$ be the output. An autoregressive model learns the distribution $p(y_{1:T} | x) = \prod_{i=1}^T p(y_i | y_{<i}, x)$. The chain rule factorizes the joint distribution into per-token conditionals, so the model can predict (and sample) the $y_i$'s one at a time. For more on the transformer, read this [blog post](https://medium.com/@zepingyu/123-cb62513f5d50).
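To make the chain rule concrete, here is a minimal sketch with a toy model rather than a transformer. The table `COND`, the tokens `"a"`, `"b"`, `"<s>"`, `"</s>"`, and the helper names are all hypothetical, the conditioning input $x$ is omitted, and each conditional looks only at the previous token. A transformer would instead compute $p(y_i | y_{<i}, x)$ from the whole prefix.

```python
import math
import random

# Hypothetical toy conditional p(y_i | y_{<i}) that depends only on the previous token.
# "<s>" marks the start of the sequence and "</s>" ends it.
COND = {
    "<s>": {"a": 0.6, "b": 0.4},
    "a": {"a": 0.1, "b": 0.6, "</s>": 0.3},
    "b": {"a": 0.5, "b": 0.1, "</s>": 0.4},
}

def next_token_dist(prefix):
    """Return p(y_i | y_{<i}) as a dict mapping token -> probability."""
    prev = prefix[-1] if prefix else "<s>"
    return COND[prev]

def log_prob(sequence):
    """Chain rule: log p(y) = sum_i log p(y_i | y_{<i})."""
    total = 0.0
    for i, token in enumerate(sequence):
        total += math.log(next_token_dist(sequence[:i])[token])
    return total

def sample(rng, max_len=20):
    """Ancestral sampling: draw y_1, then y_2 given y_1, and so on until </s>."""
    sequence = []
    while len(sequence) < max_len:
        dist = next_token_dist(sequence)
        tokens, probs = zip(*dist.items())
        token = rng.choices(tokens, weights=probs, k=1)[0]
        sequence.append(token)
        if token == "</s>":
            break
    return sequence

rng = random.Random(0)
y = sample(rng)
print(y, log_prob(y))
```

Sampling needs no MCMC here: each draw comes from an explicit conditional, and the same conditionals give the exact log probability of any sequence.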

## Tokenization

In language models, the $y_i$'s are tokens. To process text with a language model, we first convert the text into tokens, which are chunks of characters (subwords). One can instead set the $y_i$'s to individual characters, and character-level language models do exist, but they are less efficient than token-level models under a standard attention scheme: the same text becomes a longer sequence, and the cost of standard self-attention grows quadratically with sequence length. One of the most common tokenization schemes is Byte Pair Encoding (BPE), which runs as follows:
1. Initialize the vocabulary with characters
2. Count the frequency of adjacent token pairs in the corpus
3. Merge the most frequent pair to create a new token
4. Add this new token to the vocabulary
5. Repeat steps 2-4 until a target vocabulary size is reached
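
The steps above can be sketched in a few lines of plain Python. This is a toy version under stated assumptions: the function name `train_bpe` and the example corpus are hypothetical, the loop stops after a fixed number of merges rather than at a target vocabulary size, and ties between equally frequent pairs are broken by whichever pair was counted first. Production tokenizers differ in details such as byte-level handling and pre-tokenization.

```python
from collections import Counter

def train_bpe(corpus, num_merges):
    """Learn up to `num_merges` BPE merges from a list of words (toy version)."""
    # Step 1: start from characters; each word is a tuple of single-character tokens.
    words = Counter(tuple(word) for word in corpus)
    vocab = {ch for word in words for ch in word}
    merges = []
    for _ in range(num_merges):
        # Step 2: count adjacent token pairs, weighted by word frequency.
        pairs = Counter()
        for word, freq in words.items():
            for left, right in zip(word, word[1:]):
                pairs[(left, right)] += freq
        if not pairs:
            break  # every word is already a single token
        # Step 3: merge the most frequent pair into one new token.
        (left, right), _count = pairs.most_common(1)[0]
        new_token = left + right
        merged_words = Counter()
        for word, freq in words.items():
            out, i = [], 0
            while i < len(word):
                if i + 1 < len(word) and word[i] == left and word[i + 1] == right:
                    out.append(new_token)
                    i += 2
                else:
                    out.append(word[i])
                    i += 1
            merged_words[tuple(out)] += freq
        words = merged_words
        # Step 4: add the new token to the vocabulary.
        merges.append((left, right))
        vocab.add(new_token)
    # Step 5 is the loop itself: repeat until the merge budget is used up.
    return vocab, merges

corpus = ["low", "low", "lower", "lowest", "newer", "newer", "wider"]
vocab, merges = train_bpe(corpus, num_merges=3)
print(merges)
print(sorted(vocab))
```

The learned merge list is what a tokenizer replays, in order, to split new text into tokens.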


In [79]:
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset

import numpy as np
from tqdm import tqdm
import torch
import torch.nn.functional as F

In [38]:
# load a model

tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M")
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM2-135M")

In [None]:
# forced choice

blimp = load_dataset("nyu-mll/blimp", "causative", split="train")

In [11]:
blimp[0]

{'sentence_good': 'Aaron breaks the glass.',
 'sentence_bad': 'Aaron appeared the glass.',
 'field': 'syntax',
 'linguistics_term': 'argument_structure',
 'UID': 'causative',
 'simple_LM_method': True,
 'one_prefix_method': False,
 'two_prefix_method': False,
 'lexically_identical': False,
 'pair_id': 0}

In [16]:
def evaluate_blimp_grammaticality(tokenizer, model, blimp, phenomenon=None):
    """
    Performs forced-choice grammaticality judgment on BLiMP minimal pairs
    using a language model, and reports accuracy.
    
    Args:
        tokenizer: The tokenizer to use for encoding sentences
        model: The language model to use for scoring sentences
        blimp: An already-loaded BLiMP split with "sentence_good" and "sentence_bad" fields
        phenomenon: Name of the phenomenon, used as the results key and progress-bar label. Defaults to None.
        
    Returns:
        dict: Maps the phenomenon name to its accuracy, counts, and average scores
    """
    
    # Dictionary to store results
    results = {
    }
    
    # Make sure model is in evaluation mode
    model.eval()
    
    phenomenon_correct = 0
    phenomenon_total = len(blimp)
    
    # Track scores for this phenomenon
    good_scores = []
    bad_scores = []
    
    # Iterate through each minimal pair in this phenomenon
    for item in tqdm(blimp, desc=f"Evaluating {phenomenon}", leave=False):
        # Get the grammatical and ungrammatical sentences
        grammatical = item["sentence_good"]
        ungrammatical = item["sentence_bad"]
        
        # Compute log likelihood for grammatical sentence
        grammatical_score = compute_log_likelihood(grammatical, tokenizer, model)
        
        # Compute log likelihood for ungrammatical sentence
        ungrammatical_score = compute_log_likelihood(ungrammatical, tokenizer, model)
        
        # Store scores for analysis
        good_scores.append(grammatical_score)
        bad_scores.append(ungrammatical_score)
        
        # Prediction is correct if grammatical sentence has higher log likelihood
        prediction_correct = grammatical_score > ungrammatical_score
        
        # Update counter
        if prediction_correct:
            phenomenon_correct += 1
    
    # Calculate accuracy for this phenomenon
    phenomenon_accuracy = phenomenon_correct / phenomenon_total
    
    # Store results for this phenomenon
    results[phenomenon] = {
        "correct": phenomenon_correct,
        "total": phenomenon_total,
        "accuracy": phenomenon_accuracy,
        "avg_good_score": np.mean(good_scores),
        "avg_bad_score": np.mean(bad_scores),
        "avg_score_diff": np.mean(np.array(good_scores) - np.array(bad_scores))
    }
    
    return results

def compute_log_likelihood(sentence, tokenizer, model):
    """
    Compute the average per-token log likelihood of a sentence using the provided model.
    
    Args:
        sentence: The sentence to compute log likelihood for
        tokenizer: The tokenizer to use for encoding
        model: The language model
        
    Returns:
        float: The average log likelihood per predicted token (higher is better)
    """
    # Tokenize the sentence
    inputs = tokenizer(sentence, return_tensors="pt")
    
    # Move inputs to the same device as the model
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    
    # Get the input IDs for easier reference
    input_ids = inputs["input_ids"]
    
    # We'll compute the average log likelihood per token
    with torch.no_grad():
        # Passing labels makes the model return the cross-entropy loss
        outputs = model(**inputs, labels=input_ids)
        
        # The loss is the mean negative log likelihood over predicted tokens
        neg_log_likelihood = outputs.loss.item()
        
        # Negate so that higher is better
        log_likelihood = -neg_log_likelihood
    
    return log_likelihood

In [17]:
results = evaluate_blimp_grammaticality(tokenizer, model, blimp)
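
Because `compute_log_likelihood` negates the mean loss, the forced choice above compares length-normalized scores. The following sketch shows why that choice matters: with summed log likelihoods, a shorter sentence can win just by having fewer tokens. The function `forced_choice` and the per-token log probabilities are hypothetical, made up for illustration, and are not outputs of the model.

```python
def forced_choice(good_logprobs, bad_logprobs, normalize=True):
    """Return True if the 'good' sentence scores higher than the 'bad' one.

    Each argument is a list of per-token log probabilities log p(y_i | y_{<i}).
    With normalize=True the score is the mean per token (what negating the
    mean cross-entropy loss gives); otherwise it is the summed log likelihood.
    """
    def score(logprobs):
        total = sum(logprobs)
        return total / len(logprobs) if normalize else total

    return score(good_logprobs) > score(bad_logprobs)

# Hypothetical per-token log probabilities.
good = [-1.0, -1.0, -1.0, -1.0]  # 4 tokens: sum = -4.0, mean = -1.0
bad = [-1.2, -1.2, -1.2]         # 3 tokens: sum = -3.6, mean = -1.2

print(forced_choice(good, bad, normalize=True))   # mean prefers the good sentence
print(forced_choice(good, bad, normalize=False))  # sum prefers the shorter bad sentence
```

Either scoring rule is defensible for minimal pairs; the point is only that the two can disagree when the sentences tokenize to different lengths.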

                                                                    

In [47]:
# constrained decoding

boolq = load_dataset("google/boolq", split="validation")
print(boolq[0])

def merge_passage_and_question(example):
    input_text = example["passage"] + " " + example["question"] + "? "
    return {"input_text": input_text}

boolq = boolq.map(merge_passage_and_question, remove_columns=["passage", "question"])

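# tokenize; truncate from the left so the question at the end of the prompt is never cut off
tokenizer.truncation_side = "left"
def tokenize_function(examples):
    return tokenizer(examples["input_text"], truncation=True, padding="max_length", max_length=512)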

tokenized_boolq = boolq.map(tokenize_function, batched=True, remove_columns=["input_text"])

{'question': 'does ethanol take more energy make that produces', 'answer': False, 'passage': "All biomass goes through at least some of these steps: it needs to be grown, collected, dried, fermented, distilled, and burned. All of these steps require resources and an infrastructure. The total amount of energy input into the process compared to the energy released by burning the resulting ethanol fuel is known as the energy balance (or ``energy returned on energy invested''). Figures compiled in a 2007 report by National Geographic Magazine point to modest results for corn ethanol produced in the US: one unit of fossil-fuel energy is required to create 1.3 energy units from the resulting ethanol. The energy balance for sugarcane ethanol produced in Brazil is more favorable, with one unit of fossil-fuel energy required to create 8 from the ethanol. Energy balance estimates are not easily produced, thus numerous such reports have been generated that are contradictory. For instance, a separ

Map: 100%|██████████| 3270/3270 [00:00<00:00, 7370.24 examples/s]


In [None]:
valid_labels = ['True', 'False']
valid_token_ids = [tokenizer.encode(label)[0] for label in valid_labels]
label2token = {
    True: tokenizer.encode("True")[0],
    False: tokenizer.encode("False")[0]
}
valid_token_ids

[2787, 4245]

In [77]:
@torch.inference_mode()
def constrained_decoding(model, dataset, valid_token_ids):
    """
    Performs constrained decoding using a language model on a pre-tokenized dataset.
    Relies on the global label2token mapping to convert answers to token IDs.
    
    Args:
        model: The language model to use for scoring.
        dataset: A pre-tokenized dataset containing input_ids, attention_mask, and answer.
        valid_token_ids: List or tensor of token IDs that are allowed as next tokens.
        
    Returns:
        dict: Dictionary with the number of correct predictions, the total, and per-example scores.
    """
    # Ensure valid_token_ids is a tensor on the right device
    if not isinstance(valid_token_ids, torch.Tensor):
        valid_token_ids = torch.tensor(valid_token_ids, device=model.device)
    else:
        valid_token_ids = valid_token_ids.to(model.device)
    
    # Results dictionary
    results = {
        "correct": 0,
        "total": 0,
        "scores": [],
    }
    
    # Process all examples in the dataset
    for item in tqdm(dataset, desc="Performing constrained decoding"):
        # Get input ids and move to the same device as the model
        input_ids = torch.tensor(item["input_ids"]).to(model.device)
        attention_mask = torch.tensor(item["attention_mask"]).to(model.device)
        true_label = item["answer"]
        true_label_token = label2token[true_label]

        
        # Ensure input is batched (add batch dimension if needed)
        if input_ids.dim() == 1:
            input_ids = input_ids.unsqueeze(0)
            attention_mask = attention_mask.unsqueeze(0)
        
        # Forward pass to get logits
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        logits = outputs.logits
        
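        # Get logits at the last non-padding position; inputs are padded to
        # max_length, so index -1 may point at a pad token
        last_pos = int(attention_mask[0].nonzero()[-1])
        last_token_logits = logits[0, last_pos, :]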
        
        # Extract logits only for valid tokens
        valid_token_logits = last_token_logits[valid_token_ids]
        
        # Find the highest scoring valid token
        max_logit_idx = torch.argmax(valid_token_logits).item()
        scores = F.softmax(valid_token_logits, dim=-1)
        score = scores[max_logit_idx].item()
        label_token = valid_token_ids[max_logit_idx].cpu().item()

        results["scores"].append(score)
        results["total"] += 1
        results["correct"] += (label_token == true_label_token)

    return results   

In [80]:
results = constrained_decoding(model, tokenized_boolq.select(range(10)), valid_token_ids)

Performing constrained decoding: 100%|██████████| 10/10 [00:55<00:00,  5.59s/it]
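
The core step of `constrained_decoding` can be isolated as a small pure-Python sketch: keep only the logits of the allowed tokens, renormalize with a softmax over that subset, and pick the best one. The function name `constrained_choice` and the logits below are hypothetical; token ids 2 and 4 merely stand in for "True" and "False".

```python
import math

def constrained_choice(logits, valid_token_ids):
    """Pick the best token among `valid_token_ids` and its renormalized probability."""
    # Keep only the logits of the allowed tokens.
    valid_logits = [logits[token_id] for token_id in valid_token_ids]
    # Softmax over the allowed tokens only (subtract the max for numerical stability).
    max_logit = max(valid_logits)
    exps = [math.exp(v - max_logit) for v in valid_logits]
    norm = sum(exps)
    probs = [e / norm for e in exps]
    # Index of the highest-probability allowed token.
    best = max(range(len(probs)), key=probs.__getitem__)
    return valid_token_ids[best], probs[best]

# Hypothetical next-token logits over a 6-token vocabulary.
# Token 0 has the highest logit overall but is not an allowed answer.
logits = [5.0, 0.1, 1.5, -2.0, 0.5, 0.0]
token_id, prob = constrained_choice(logits, valid_token_ids=[2, 4])
print(token_id, prob)
```

Because the softmax runs over the allowed subset only, the reported score is the probability of the chosen answer relative to the other allowed answers, not its probability under the full vocabulary.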
