In [None]:
# # %% [markdown]
# # Week 2 Practical: Controlled Text Generation and Basic Finetuning
#
# In this practical, we'll explore different ways to control generation:
# 1. **Baseline**: Free generation with different sampling strategies
# 2. **Simple constraints**: Masking tokens to avoid specific words
# 3. **Complex constraints**: JSON-structured output (demo)
# 4. **Sentiment classifier finetuning**: Training a lightweight DistilBERT model
# 5. **Semantic steering**: Biasing generation toward positive sentiment
# 6. **Analysis**: Understanding trade-offs between control and coherence

## Setup

In [None]:
# Installing required packages

%pip install accelerate==1.12.0
%pip install bertviz==1.4.1
%pip install datasets==4.4.1
%pip install evaluate==0.4.6
%pip install gensim==4.4.0
%pip install matplotlib==3.10.7
%pip install outlines==1.2.9
%pip install pandas==2.3.3
%pip install scikit-learn==1.7.2
%pip install seaborn==0.13.2
%pip install torch==2.9.1
%pip install transformers==4.57.1


In [None]:
import torch
import outlines
import evaluate
from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessor, LogitsProcessorList, AutoModelForSequenceClassification, pipeline, TrainingArguments, Trainer
from typing import List, Tuple, Literal
from datasets import load_dataset


In [None]:
def get_best_device():
    """Pick the torch device to use: CUDA if available, else Apple MPS, else CPU.

    Prints the GPU memory and name when CUDA is selected, and always prints
    the device that was chosen.

    Returns:
        The selected ``torch.device``.
    """
    if torch.cuda.is_available():
        chosen = torch.device("cuda")
        props = torch.cuda.get_device_properties(chosen)
        print(f"GPU Memory: {props.total_memory / 1e9:.1f} GB")
        print(f"GPU Name: {torch.cuda.get_device_name(chosen)}")
    else:
        # No CUDA: fall back to MPS when the backend reports it, otherwise CPU
        fallback = "mps" if torch.backends.mps.is_available() else "cpu"
        chosen = torch.device(fallback)
    print(f"Found device: {chosen}")
    return chosen


# Notebook-wide device that later cells move their inputs to
device = get_best_device()

In [None]:
# Load generation model
# We use TinyLlama-1.1B: lightweight, reliable KV caching, fits in 8GB
# ~2.2GB in float16, leaves memory for classifier and other operations
# `model_name`, `tokenizer` and `model` are notebook-wide names used by every
# generation cell below.
model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# NOTE(review): placement is delegated to device_map="auto", while later cells
# move inputs to the global `device` — confirm both resolve to the same device.
model = AutoModelForCausalLM.from_pretrained(
    model_name, dtype=torch.float16, device_map="auto"
)
# Inference only in this notebook: switch the generation model to eval mode
model.eval()


def build_prompt(user: str, system="You are a helpful assistant"):
    """Format a single-turn chat prompt with system/user/assistant role tags.

    Args:
        user: The user message.
        system: The system message placed first.

    Returns:
        The prompt string: a system turn and a user turn (each closed with
        ``</s>``), followed by an open ``<|assistant|>`` turn ending in a newline.
    """
    turns = (
        "<|system|>",
        f"{system}</s>",
        "<|user|>",
        f"{user}</s>",
        "<|assistant|>",
        "",  # trailing empty piece yields the final newline after the tag
    )
    return "\n".join(turns)

## Part 0: Understanding Autoregressive Decoding

Before using `generate()`, let's implement sampling manually to understand how decoders work.

In [None]:
@torch.no_grad()
def manual_generate(
    model,
    tokenizer,
    prompt: str,
    max_length: int = 50,
    temperature: float = 1.0,
    top_p: float | None = None,
):
    """
    Manually implement autoregressive decoding to understand sampling strategies.

    Decoding works by:
    1. Encode the prompt to get initial input_ids
    2. Forward pass: compute logits for next token
    3. Apply sampling filters (temperature, top-k, top-p)
    4. Sample from logits
    5. Append sampled token to sequence
    6. Repeat until EOS or max_length

    Args:
        model: Language model
        tokenizer: Tokenizer
        prompt: Input text
        max_length: Maximum generation length
        temperature: Scales logits (higher = more random)
        top_p: Nucleus sampling threshold (None = no limit)
    """

    # Encode prompt
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
    seq_length = input_ids.shape[1]

    print(f"Starting with prompt ({seq_length} tokens): {prompt}")
    print(f"Sampling strategy: temperature={temperature}, top_p={top_p}")
    print(f"Generating up to {max_length} tokens...\n")

    for step in range(max_length):
        # Forward pass (recomputes from scratch for clarity)
        # Production code uses KV caching via generate()
        outputs = model(input_ids=input_ids, return_dict=True)

        # Get logits for next token (last position)
        next_token_logits = outputs.logits[:, -1, :]

        # Apply temperature
        next_token_logits = next_token_logits / temperature

        # Apply top-p (nucleus) filtering
        if top_p is not None:
            # Implement top-p (nucleus)

            assert False, 'Not implemented yet'


        # Compute probabilities
        probs = torch.softmax(next_token_logits, dim=-1)

        # Sample next token
        next_token_id = torch.multinomial(probs, num_samples=1)

        # Append to sequence
        input_ids = torch.cat([input_ids, next_token_id], dim=1)

        # Decode and print
        token_str = tokenizer.decode(
            [next_token_id.item()], skip_special_tokens=False
        )
        print(f"Step {step+1}: {token_str}", end=" ", flush=True)

        # Stop at EOS
        if next_token_id.item() == tokenizer.eos_token_id:
            print("\n[EOS reached]")
            break

    print("\n")
    full_text = tokenizer.decode(input_ids[0], skip_special_tokens=True)
    return full_text

In [None]:
# Example: Different sampling strategies
# Build a chat-formatted prompt that the manual sampling cells below reuse;
# the bare name on the last line displays the formatted string.

prompt = build_prompt("Write a review about a restaurant.")
prompt

In [None]:
# Temperature-only sampling with the hand-written decoding loop.
# top_p=None skips the nucleus-filtering branch of manual_generate, so this
# cell runs even before that exercise is implemented.
print("=" * 80)
print("SAMPLING WITH MANUAL IMPLEMENTATION")
print("=" * 80)
text_manual = manual_generate(
    model, tokenizer, prompt, max_length=15, temperature=0.7, top_p=None
)
print(text_manual)

Now implement the missing top-p (nucleus) sampling code in `manual_generate`,
then run and experiment with the following cell.

In [None]:
# Same call as the previous cell, but with top_p=0.9: this enters the top-p
# branch of manual_generate and fails with AssertionError until that branch
# has been implemented.
print("=" * 80)
print("SAMPLING WITH MANUAL IMPLEMENTATION (top_p=.9)")
print("=" * 80)

text_manual = manual_generate(
    model, tokenizer, prompt, max_length=15, temperature=0.7, top_p=0.9
)
print(text_manual)

### KV Cache Efficiency (using built-in generate)

The manual loop recomputes everything. Modern transformers use KV caching
automatically. Let's benchmark the difference using `generate()`:

In [None]:
import time

# `prompt` is rebound here to a plain (non-chat) string; `max_length` is also
# read by the "without cache" cell that follows.
prompt = "The restaurant had"
max_length = 100

# Warmup
# One short untimed call so one-off startup cost is not included in the timing
_ = model.generate(**tokenizer(prompt, return_tensors="pt").to(device), max_length=10)

print("=" * 80)
print("WITH KV CACHE (default in generate, fast)")
print("=" * 80)
# Average wall-clock time over 3 runs; tokenization happens inside the timed loop.
# NOTE(review): sampling is enabled, so a run may presumably stop early at EOS;
# compare the printed token counts before trusting the speedup figure.
start = time.time()
for _ in range(3):
    outputs = model.generate(
        **tokenizer(prompt, return_tensors="pt").to(device),
        max_length=max_length,
        temperature=0.7,
        do_sample=True,
    )
time_with_cache = (time.time() - start) / 3
# The token count and text shown are those of the last of the 3 runs only
print(f"Average time: {time_with_cache:.3f}s for {len(outputs[0])} tokens")
print(f"Generated: {tokenizer.decode(outputs[0], skip_special_tokens=True)[:100]}...")

In [None]:
# Same benchmark as the previous cell with use_cache=False.
# Depends on `time`, `prompt`, `max_length` and `time_with_cache` defined by
# the previous cell, so that cell must be run first.
print("\n" + "=" * 80)
print("WITHOUT KV CACHE (use_cache=False, slow)")
print("=" * 80)
start = time.time()
for _ in range(3):
    outputs = model.generate(
        **tokenizer(prompt, return_tensors="pt").to(device),
        max_length=max_length,
        temperature=0.7,
        do_sample=True,
        use_cache=False,  # Disable KV caching
    )
time_without_cache = (time.time() - start) / 3
print(f"Average time: {time_without_cache:.3f}s for {len(outputs[0])} tokens")
# Ratio of the two averaged timings (> 1 means the cached run was faster)
print(f"Speedup with cache: {time_without_cache / time_with_cache:.1f}x")
print(f"Generated: {tokenizer.decode(outputs[0], skip_special_tokens=True)[:100]}...")

### Sampling strategies comparison

In [None]:
# Compare different sampling strategies
# All comparison cells below share this plain-text prompt.
prompt = "The restaurant had"

print("=" * 80)
print("GREEDY (argmax, deterministic)")
print("=" * 80)
# manual_generate() always samples from the distribution, so greedy decoding
# is done with generate(do_sample=False) instead.
inputs = tokenizer(prompt, return_tensors="pt").to(device)
with torch.no_grad():
    outputs = model.generate(
        **inputs, max_length=40, do_sample=False, pad_token_id=tokenizer.eos_token_id
    )
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

In [None]:
# Temperature-only sampling through the hand-written loop (top-p disabled).
# The returned text is the cell's last expression, so it is also displayed.
print("\n" + "=" * 80)
print("TEMPERATURE ONLY (temperature=0.7)")
print("=" * 80)
manual_generate(
    model, tokenizer, prompt, max_length=15, temperature=0.7, top_p=None
)

In [None]:
print("=" * 80)
print("TOP-K SAMPLING (k=10, temperature=0.7)")
print("=" * 80)
# manual_generate() has no top-k option, so calling it here would just repeat
# the temperature-only run above. Use generate() with an explicit top_k so the
# output really is top-k sampling. top_p=1.0 disables nucleus filtering, and
# max_new_tokens=15 matches the 15 new tokens the manual cells produce.
inputs = tokenizer(prompt, return_tensors="pt").to(device)
with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_new_tokens=15,
        do_sample=True,
        temperature=0.7,
        top_k=10,
        top_p=1.0,
        pad_token_id=tokenizer.eos_token_id,
    )
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

In [None]:
# Nucleus sampling through the hand-written loop.
# Passing top_p=0.9 enters the top-p branch of manual_generate, which fails
# with AssertionError until that exercise has been implemented.
print("=" * 80)
print("TOP-P (NUCLEUS) SAMPLING (p=0.9, temperature=0.7)")
print("=" * 80)
manual_generate(
    model, tokenizer, prompt, max_length=15, temperature=0.7, top_p=0.9
)

**Exercises:**
1. Generate text with different top-k values (k=5, 10, 20, 50). When does k
   become too small? Too large?
2. Generate text with different top-p values (p=0.5, 0.7, 0.9, 0.95). Compare
   diversity vs. coherence
3. Compare top-k vs. top-p outputs. Which feels more natural?
4. **KV cache benchmark**: Run the speedup comparison above with different
   max_lengths (50, 100, 200). Does speedup grow with sequence length?
5. Try combining temperature + top-p. How do they interact? (Hint: temperature
   smooths the distribution, top-p truncates it)

## Part 1: Baseline Generation - Playing with Sampling Strategies

Let's start with free-form generation and see how different sampling
parameters affect output.

In [None]:
def generate_reviews(
    prompts: List[str],
    temperature: float = 1.0,
    top_p: float = 0.9,
    num_return_sequences: int = 3,
    max_length: int = 100,
) -> List[str]:
    """
    Generate review text with specified sampling parameters.

    Uses the notebook-level ``tokenizer``, ``model`` and ``device``.

    Args:
        prompts: List of prompt strings
        temperature: Controls randomness (higher = more random)
        top_p: Nucleus sampling threshold
        num_return_sequences: Number of outputs per prompt
        max_length: Maximum generation length, passed to ``generate`` as
            ``max_length``

    Returns:
        List of generated texts, decoded from the sequences returned by
        ``generate`` with special tokens skipped
    """
    # Left-pad the batch: a decoder-only model continues from the last position
    # of each row, so with several prompts of different lengths the padding must
    # sit before the prompt, not between the prompt and its continuation.
    inputs = tokenizer(
        prompts, return_tensors="pt", padding=True, padding_side="left"
    ).to(device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_length=max_length,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
            num_return_sequences=num_return_sequences,
            # Same explicit pad token as the other generate() calls in this notebook
            pad_token_id=tokenizer.eos_token_id,
        )

    texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    return texts

In [None]:
# generate_reviews() takes a list of prompts, hence the one-element list
prompt = ["Write a short product review: "]

In [None]:
# Deterministic baseline: do_sample=False, so no sampling parameters are passed
print("GREEDY DECODING (temperature=0)")
greedy_outputs = model.generate(
    **tokenizer(prompt, return_tensors="pt").to(device),
    max_length=80,
    do_sample=False,
    pad_token_id=tokenizer.eos_token_id,
)
# `prompt` holds a single entry, so only the first returned sequence is decoded
print(tokenizer.decode(greedy_outputs[0], skip_special_tokens=True))

In [None]:
# Two samples at a moderate temperature with nucleus sampling
print("\n" + "=" * 80)
print("SAMPLING (temperature=0.7, top_p=0.9)")
for i, text in enumerate(
    generate_reviews(prompt, temperature=0.7, top_p=0.9, num_return_sequences=2)
):
    print(f"\nSample {i+1}:")
    print(text)

In [None]:
# Same call as the previous cell with temperature raised to 1.5, for comparison
print("\n" + "=" * 80)
print("HIGH TEMPERATURE (temperature=1.5, top_p=0.9)")
for i, text in enumerate(
    generate_reviews(prompt, temperature=1.5, top_p=0.9, num_return_sequences=2)
):
    print(f"\nSample {i+1}:")
    print(text)

**Observation**: How does increasing temperature affect diversity vs.
coherence?

## Part 2: Simple Constraints - Avoiding Specific Tokens

Now let's implement a basic constraint: avoid generating "yes" or "no"
directly.

In [None]:
class AvoidTokensLogitsProcessor(LogitsProcessor):
    """
    Mask logits for specific tokens to prevent them from being generated.
    Useful for hard constraints like "never say yes/no".

    Attributes:
        avoid_token_ids: Set of vocabulary ids to forbid during generation.
    """

    def __init__(self, tokenizer, tokens_to_avoid: List[str]):
        """
        Resolve the token strings to vocabulary ids and report what will be
        filtered.

        Args:
            tokenizer: HuggingFace tokenizer
            tokens_to_avoid: List of token strings to mask (e.g., ["yes", "no"]).
                Each string is looked up verbatim in the vocabulary.
                NOTE(review): variants such as a capitalised or word-initial
                form are presumably separate vocabulary entries and are not
                covered unless listed explicitly — confirm for this tokenizer.
        """
        unk_id = tokenizer.unk_token_id
        candidate_ids = tokenizer.convert_tokens_to_ids(tokens_to_avoid)
        # Drop only strings that are not real vocabulary entries (reported as
        # None or as the unknown-token id). Filtering on truthiness would also
        # discard a legitimate token whose id happens to be 0.
        self.avoid_token_ids = {
            token_id
            for token_id in candidate_ids
            if token_id is not None and token_id != unk_id
        }
        # Sorted list: deterministic message, and a plain list of ids for the lookup
        print(
            "AvoidTokensLogitsProcessor will filter: "
            + ", ".join(tokenizer.convert_ids_to_tokens(sorted(self.avoid_token_ids)))
        )

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor
    ) -> torch.FloatTensor:
        """
        At each generation step, set logits of forbidden tokens to -inf.

        Args:
            input_ids: Current sequence
            scores: Logits for next token (batch_size, vocab_size)

        Returns:
            Modified scores with masked tokens
        """
        # Use self.avoid_token_ids to mask scores
        # (exercise placeholder: remove the assert once the masking is written)

        assert False, 'Not implemented yet'

        return scores

In [None]:
# We use a simple "yes/no" inducing prompt about a camera review.
# Wrapped in a list because generate_reviews() and the batched tokenizer call
# below take a list of prompts.
prompt = [
    build_prompt(
        "Should I buy this product?",
        system="You are discussing about a camera which has very good reviews.",
    )
]

In [None]:
# Baseline for the constraint experiment: one unconstrained sample
print("WITHOUT CONSTRAINT")
text = generate_reviews(prompt, temperature=0.7, num_return_sequences=1)[0]
print(text)

In [None]:
print("WITH CONSTRAINT: avoiding 'yes' and 'no'")

# Define the constraint processor
# Exercise: replace the assert below with code that creates
# `constraint_processor` (the generate() call further down reads that name).

assert False, 'Not implemented yet'


inputs = tokenizer(prompt, return_tensors="pt", padding=True).to(device)
with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_length=100,
        temperature=0.7,
        do_sample=True,
        # The processor is applied to the next-token scores at every step
        logits_processor=LogitsProcessorList([constraint_processor]),
        pad_token_id=tokenizer.eos_token_id,
    )
text = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(text)

**Exercise**: Modify `AvoidTokensLogitsProcessor` to also avoid patterns
(e.g., anything starting with "before going").

## Part 3: Complex Constraints - JSON Generation (Demo)

For more complex structure, we'd use a library like `outlines`. Here's what it
looks like on a very simple example. For more on controlling structured
outputs, see the [project
repository](https://github.com/dottxt-ai/outlines?tab=readme-ov-file#customer-support-triage).

In [None]:
# Wrap the already-loaded generation model and tokenizer for use with outlines
outline_model = outlines.from_transformers(model, tokenizer)

# The second argument is the output type: the answer is constrained to one of
# the three literal strings.
sentiment = outline_model(
    build_prompt("Analyze: 'This product completely changed my life!'"),
    Literal["Positive", "Negative", "Neutral"],
)

print(sentiment)

## Part 4: Semantic Steering - Sentiment-Guided Generation

Now we'll finetune a sentiment classifier and use it to guide generation
toward positive reviews.

### Step 1: Finetune sentiment classifier on IMDb

In [None]:
# Load IMDb dataset
print("Loading IMDb dataset...")
dataset = load_dataset("imdb", split="train")
# The training cells below expect the target column to be named "labels".
# rename_column does this directly, instead of a per-row map() whose lambda
# mutates each example with pop().
dataset = dataset.rename_column("label", "labels")

# Use a subset (for speed)
# Shuffle with a fixed seed first so the 10k subset is a reproducible random
# sample rather than the first 10k rows.
dataset = dataset.shuffle(seed=42)
dataset = dataset.select(range(10_000))

# Split into train/val
# 5% of the subset is held out for evaluation
train_val = dataset.train_test_split(test_size=0.05, seed=42)
train_dataset = train_val["train"]
eval_dataset = train_val["test"]

print(f"Train: {len(train_dataset)}, Eval: {len(eval_dataset)}")

In [None]:
# Accuracy metric object, shared by compute_metrics below
accuracy = evaluate.load("accuracy")


def compute_metrics(eval_pred):
    """Compute accuracy from a ``(logits, labels)`` pair.

    Args:
        eval_pred: Two-item iterable of model logits and reference labels.

    Returns:
        The result of ``accuracy.compute`` for the argmax predictions.
    """
    class_scores, references = eval_pred
    # Predicted class = index of the largest logit along the last axis
    predicted_classes = class_scores.argmax(axis=-1)
    return accuracy.compute(predictions=predicted_classes, references=references)

In [None]:
# Load lightweight sentiment model (DistilBERT)
# This is the base checkpoint; the classification head with num_labels=2 is
# what the training cell below fine-tunes.

classifier_model_name = "distilbert-base-uncased"
classifier_tokenizer = AutoTokenizer.from_pretrained(classifier_model_name)
classifier_model = AutoModelForSequenceClassification.from_pretrained(
    classifier_model_name,
    num_labels=2,
).to(device)

In [None]:
def preprocess_function(examples):
    """Tokenize the "text" field of a batch with the classifier tokenizer.

    Every example is truncated and padded to exactly 256 tokens
    (padding="max_length"), so all rows have the same length.
    """
    return classifier_tokenizer(
        examples["text"], truncation=True, max_length=256, padding="max_length"
    )


# Tokenize both splits in batches; the names are rebound to the tokenized
# datasets, so re-running this cell tokenizes the already-tokenized data again.
train_dataset = train_dataset.map(preprocess_function, batched=True)
eval_dataset = eval_dataset.map(preprocess_function, batched=True)

# Set format
# Expose only the three listed columns, as torch tensors
train_dataset.set_format("torch", columns=["input_ids", "attention_mask", "labels"])
eval_dataset.set_format("torch", columns=["input_ids", "attention_mask", "labels"])

In [None]:
# Train sentiment classifier
training_args = TrainingArguments(
    output_dir="./distillbert-imdb-finetuned",  # checkpoints are written here
    num_train_epochs=2,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    warmup_steps=50,
    weight_decay=0.01,
    logging_steps=50,
    eval_strategy="steps",  # evaluate on eval_dataset every `eval_steps` steps
    eval_steps=50,
    max_grad_norm=1.0,
    learning_rate=2e-5,
    # NOTE(review): no save_steps / metric_for_best_model is set here, so which
    # checkpoint counts as "best" depends on library defaults — confirm.
    load_best_model_at_end=True,
)
trainer = Trainer(
    model=classifier_model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    compute_metrics=compute_metrics,  # reports accuracy at each evaluation
)
print("Training sentiment classifier...")
classifier_model.train()
trainer.train()

In [None]:
# Put classifier in eval mode
classifier_model.eval()

# And use pipeline for easy inference
# The pipeline wraps the classifier_model object that exists at this point;
# it is only used for the quick sanity check below.
sentiment_classifier = pipeline(
    "sentiment-analysis", model=classifier_model, tokenizer=classifier_tokenizer
)

result = sentiment_classifier("This movie was fantastic!")
print(result)

In [None]:
# Evaluate the fine-tuned classifier on the held-out split and keep the
# accuracy in `classifier_accuracy` (read by the fallback cell below).
prediction_output = trainer.predict(eval_dataset)
classifier_accuracy = compute_metrics(
    (prediction_output.predictions, prediction_output.label_ids)
)["accuracy"]

print(f"Classifier accuracy on IMDb eval set: {classifier_accuracy:.4f}")

In [None]:
# Fallback: if the fine-tuned classifier scored below 80% accuracy, replace it
# with a classifier that has already been fine-tuned for sentiment (SST-2).
# NOTE(review): this cell reads `classifier_accuracy` from the evaluation cell
# above, so it cannot run if training/evaluation was skipped entirely — confirm
# that is the intended workflow.
# The `sentiment_classifier` pipeline created earlier is not rebuilt here and
# keeps wrapping the previously fine-tuned model object.

if classifier_accuracy < 0.80:
    print("Warning: classifier accuracy is low, loading pre-trained model...")
    classifier_model_name = "distilbert-base-uncased-finetuned-sst-2-english"
    classifier_tokenizer = AutoTokenizer.from_pretrained(classifier_model_name)
    classifier_model = AutoModelForSequenceClassification.from_pretrained(
        classifier_model_name
    ).to(device)
    classifier_model.eval()

    print(f"Loaded pre-trained sentiment classifier: {classifier_model_name}")

In [None]:
# Sanity check: class probabilities for one clearly positive and one clearly
# negative sentence. Each printed row is the softmax over the two classes
# (index 0 = negative, index 1 = positive in this notebook's convention).
with torch.no_grad():
    print(
        classifier_model(
            **classifier_tokenizer(
                ["I like this movie a lot.", "Awful experience, I don't recommend."],
                return_tensors="pt",
                truncation=True,
                max_length=256,
                padding="max_length",
            ).to(device)
        ).logits.softmax(-1)
    )

### Step 2: Implement sentiment-guided LogitsProcessor

We'll implement a `LogitsProcessor` that uses the sentiment classifier. At
each generation step, it will:
- Score the current prefix with the sentiment classifier
- Extract the confidence for the target sentiment (positive/negative)
- Use this confidence to boost the LM's logits for the next token (heuristic,
  using softmax probability multiplied by a guidance scale)

In [None]:
class SentimentGuidedLogitsProcessor(LogitsProcessor):
    """
    Guide generation toward a target sentiment (positive or negative) by
    scoring prefixes with a sentiment classifier.

    At each generation step:

    1. Score the current prefix with the sentiment classifier
    2. Compute confidence for the target sentiment (positive/negative)
    3. Use this as a boost to the LM's logits

    This is a heuristic approach (not principled like PPLM), but demonstrates
    the core idea: classifier signal → generation signal.
    """

    def __init__(
        self,
        classifier,
        tokenizer,
        target_sentiment: int = 1,  # 0=negative, 1=positive
        guidance_scale: float = 0.5,
    ):
        """
        Store the classifier, tokenizer and guidance settings for use in
        ``__call__``.

        Args:
            classifier: Sentiment classification model (outputs logits for [neg, pos])
            tokenizer: Tokenizer stored on the processor.
                NOTE(review): the notebook passes the classifier's tokenizer
                here, while the ``input_ids`` received in ``__call__`` come from
                the generation model — confirm which tokenizer is meant to
                decode the prefix and which one re-encodes it for the classifier.
            target_sentiment: Which sentiment to maximize (0 or 1)
            guidance_scale: Strength of guidance (higher = more steering)
        """
        self.classifier = classifier
        self.tokenizer = tokenizer
        self.target_sentiment = target_sentiment
        self.guidance_scale = guidance_scale

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor
    ) -> torch.FloatTensor:
        """
        Adjust the next-token scores using the sentiment classifier.

        Exercise placeholder: raises AssertionError until implemented.

        Args:
            input_ids: Current sequence(s) generated so far
            scores: Next-token logits from the language model

        Returns:
            The (modified) scores
        """
        # [[STUDENT]]...

        assert False, 'Not implemented yet'


        return scores

### Step 3: Generate with sentiment guidance

Now we can use the sentiment-guided processor during generation.

In [None]:
# Compare: unguided vs. guided generation
# `prompt` is reused unchanged by the two guided-generation cells below, so
# all three runs start from the same chat-formatted request.
prompt = build_prompt("Please write a review about the MasterMind camera")

print("=" * 80)
print("UNGUIDED GENERATION")
print("=" * 80)

inputs = tokenizer(prompt, return_tensors="pt").to(device)
with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_length=100,
        temperature=0.7,
        top_p=0.9,
        do_sample=True,
        num_return_sequences=2,
        pad_token_id=tokenizer.eos_token_id,
    )

# Kept in `unguided_texts` for the sentiment analysis in Part 5
unguided_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)
for i, text in enumerate(unguided_texts):
    print(f"\nSample {i+1}:")
    print(text)

In [None]:
print("\n" + "=" * 80)
print("GUIDED GENERATION (target: positive sentiment)")
print("=" * 80)

inputs = tokenizer(prompt, return_tensors="pt").to(device)

# Processor steering toward class 1 (positive); it must exist before
# generate() is called because it is passed in `logits_processor`.
sentiment_processor = SentimentGuidedLogitsProcessor(
    classifier=classifier_model,
    tokenizer=classifier_tokenizer,
    target_sentiment=1,  # positive
    guidance_scale=0.5,
)

# Same sampling settings as the unguided run; only the processor is added
with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_length=100,
        temperature=0.7,
        top_p=0.9,
        do_sample=True,
        num_return_sequences=2,
        logits_processor=LogitsProcessorList([sentiment_processor]),
        pad_token_id=tokenizer.eos_token_id,
    )

# Kept in `pos_guided_texts` for the sentiment analysis in Part 5
pos_guided_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)
for i, text in enumerate(pos_guided_texts):
    print(f"\nSample {i+1}:")
    print(text)

In [None]:
print("\n" + "=" * 80)
print("GUIDED GENERATION (target: negative sentiment)")
print("=" * 80)

inputs = tokenizer(prompt, return_tensors="pt").to(device)

# Build the negative-target processor BEFORE generating. Creating it after the
# generate() call would leave generation using whatever `sentiment_processor`
# the previous cell defined (the positive one).
sentiment_processor = SentimentGuidedLogitsProcessor(
    classifier=classifier_model,
    tokenizer=classifier_tokenizer,
    target_sentiment=0,  # negative
    guidance_scale=0.5,
)

# Same sampling settings as the unguided and positive-guided runs
with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_length=100,
        temperature=0.7,
        top_p=0.9,
        do_sample=True,
        num_return_sequences=2,
        logits_processor=LogitsProcessorList([sentiment_processor]),
        pad_token_id=tokenizer.eos_token_id,
    )

# Kept in `neg_guided_texts` for the sentiment analysis in Part 5
neg_guided_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)
for i, text in enumerate(neg_guided_texts):
    print(f"\nSample {i+1}:")
    print(text)

## Part 5: Analysis - Measuring Guidance Effectiveness

In this section, we'll use the sentiment classifier to quantitatively evaluate
how well the guidance worked.

In [None]:
def score_sentiment(texts: List[str]) -> Tuple[List[float], List[float]]:
    """
    Score texts with the sentiment classifier.

    Uses the notebook-level ``classifier_tokenizer``, ``classifier_model`` and
    ``device``. Texts are padded as a batch and truncated to 256 tokens.

    Args:
        texts: Texts to score.

    Returns:
        A pair ``(negative_probs, positive_probs)`` in that order: the softmax
        probability of class index 0 and of class index 1 for each text
        (0 = negative, 1 = positive in this notebook's convention). Both are
        NumPy arrays with one entry per text, which is more specific than the
        ``List[float]`` annotation suggests.
    """
    inputs = classifier_tokenizer(
        texts, return_tensors="pt", padding=True, truncation=True, max_length=256
    ).to(device)

    with torch.no_grad():
        logits = classifier_model(**inputs).logits
        probs = torch.softmax(logits, dim=-1)

    # Column 0 first, then column 1; moved to CPU before converting to NumPy
    return probs[:, 0].cpu().numpy(), probs[:, 1].cpu().numpy()

In [None]:
# Compare sentiment scores
print("=" * 80)
print("SENTIMENT ANALYSIS")
print("=" * 80)

# score_sentiment returns (negative_probs, positive_probs) for each set of texts
neg_probs_unguided, pos_probs_unguided = score_sentiment(unguided_texts)
neg_probs_pos_guided, pos_probs_pos_guided = score_sentiment(pos_guided_texts)
neg_probs_neg_guided, pos_probs_neg_guided = score_sentiment(neg_guided_texts)

# Per-sample positive probability for each generation setting
print("\nUNGUIDED:")
for i, pos_prob in enumerate(pos_probs_unguided):
    print(f"  Sample {i+1}: {pos_prob:.2%} positive")

print("\nGUIDED (target: positive):")
for i, pos_prob in enumerate(pos_probs_pos_guided):
    print(f"  Sample {i+1}: {pos_prob:.2%} positive")

print("\nGUIDED (target: negative):")
for i, pos_prob in enumerate(pos_probs_neg_guided):
    print(f"  Sample {i+1}: {pos_prob:.2%} positive")

# Effective guidance should move the positive-target average up and the
# negative-target average down relative to the unguided baseline.
print("\nAverage positive probability:")
print(f"  Unguided: {pos_probs_unguided.mean():.2%}")
print(f"  Guided (pos):   {pos_probs_pos_guided.mean():.2%}")
print(f"  Guided (neg):   {pos_probs_neg_guided.mean():.2%}")

## Part 6: Exploration - Varying Guidance Strength

- How does `guidance_scale` affect the trade-off between sentiment control and
  fluency?
- Can you train a better classifier to improve guidance (e.g. FUDGE) and
  modify the `SentimentGuidedLogitsProcessor` accordingly?