# Concept swapping horses and unicorns

In [1]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import torch.nn.functional as F

## Provided functions

In [2]:
def get_word_probability(model, tokenizer, prompt, target_word, device="cuda"):
    """
    Compute the probability of a complete word appearing after the prompt.
    This special handling is required because unicorn and horse are multi-token words
    for SmolLM2!
    
    Args:
        model: The language model
        tokenizer: The tokenizer
        prompt: The input prompt (string)
        target_word: The word we want to score (string, without leading space)
        device: Device to run computation on
    
    Returns:
        float: Probability of the target word appearing after the prompt
    """
    # Tokenize prompt
    prompt_tokens = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).input_ids.to(device)
    
    # Tokenize target word WITH leading space (as it would appear after prompt)
    # Note that this is important for Llama-based models
    target_tokens = tokenizer(" " + target_word, add_special_tokens=False).input_ids
    target_tensor = torch.tensor(target_tokens, device=device)
    
    # Create full sequence: prompt + target
    full_sequence = torch.cat([prompt_tokens[0], target_tensor], dim=0).unsqueeze(0)
    
    # Get model predictions and calcualte log probs
    with torch.no_grad():
        outputs = model(full_sequence)
        logits = outputs.logits[0]  # Shape: [seq_len, vocab_size]
    log_probs = F.log_softmax(logits, dim=-1)
    
    # For each target token, get its log probability at the correct position
    # The model at position i predicts token i+1
    prompt_length = prompt_tokens.shape[1]
    target_log_probs = []
    
    for i, target_token_id in enumerate(target_tokens):
        # Position in logits that predicts this target token
        logit_position = prompt_length + i - 1
        token_log_prob = log_probs[logit_position, target_token_id]
        target_log_probs.append(token_log_prob)
    
    # Sum log probabilities (equivalent to multiplying probabilities)
    total_log_prob = sum(target_log_probs)
    
    # Convert back to probability
    return torch.exp(total_log_prob).item()

In [3]:
def get_relative_probability(prob1, prob2):
    # Both should be floats
    # Convert to log probabilities to avoid numerical issues
    log_prob1 = torch.log(torch.tensor(prob1))
    log_prob2 = torch.log(torch.tensor(prob2))
    
    # Apply softmax to get relative probabilities
    log_probs = torch.stack([log_prob1, log_prob2])
    relative_probs = F.softmax(log_probs, dim=0)

    # Just return the former which is the main word
    return relative_probs[0].item()

In [4]:
def evaluate_uplift(model, original_model, prompts, tokenizer, device, debug=False):
    # Label correctness check
    for i in prompts:
        assert i["label"] == "unicorn" or i["label"] == "horse"

    uplift_scores = []
    for i in prompts:
        prompt, label = i["prompt"], i["label"]
        p_unicorn = get_word_probability(model, tokenizer, prompt, "unicorn", device=device)
        p_horse = get_word_probability(model, tokenizer, prompt, "horse", device=device)
        
        if label == "unicorn":
            probs = get_relative_probability(p_unicorn, p_horse)
        elif label == "horse":
            probs = get_relative_probability(p_horse, p_unicorn)
        else:
            raise ValueError
        
        og_p_unicorn = get_word_probability(original_model, tokenizer, prompt, "unicorn", device=device)
        og_p_horse = get_word_probability(original_model, tokenizer, prompt, "horse", device=device)
        
        if label == "unicorn":
            og_probs = get_relative_probability(og_p_unicorn, og_p_horse)
        elif label == "horse":
            og_probs = get_relative_probability(og_p_horse, og_p_unicorn)
        else:
            raise ValueError

        # Higher is better
        uplift_scores.append(probs - og_probs)

        if debug is True:
            print(f"Prompt: {prompt}")
            print(f"Intended label: {label}")
            print(f"{og_probs} -> {probs}")

    return uplift_scores

## Problem statement

Load this LLM below and make it confuse between a horse and a unicorn. 

In [5]:
device = "cuda" if torch.cuda.is_available() else "cpu"

In [6]:
checkpoint = "HuggingFaceTB/SmolLM2-135M-Instruct"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint).to(device)

In [7]:
prompt = "Between a horse and a unicorn, this animal is real:"
unicorn_prob = get_word_probability(model, tokenizer, prompt, "unicorn", device=device)
horse_prob = get_word_probability(model, tokenizer, prompt, "horse", device=device)

get_relative_probability(unicorn_prob, horse_prob)

0.2290065586566925

Currently this model does not believe a unicorn is real, as the probability of the token sequence for "unicorn" is near zero when compared to "horse". Change its mind.

Here is how to submit your work for scoring:

- At the end of this notebook, include code to save your trained LoRA to disk to a folder called `lora`. Should be as simple as `peft_model.save_pretrained("lora")`. Submit your notebook in the competition server just like the other challenges
- During grading, I will run your submitted notebook on an evaluation compute instance to generate the LoRA. I will then use this grading notebook (https://storage.googleapis.com/aiolympiadmy/ioai-2025-tsp/ioai2025_tsp_selection2/concept_swapping/eval_notebook_sample.ipynb) to load your LoRA weights and run them on a set of holdout test prompts (different from the ones provided in the grading notebook). The mean uplift score at the end of the notebook will be your score.

The following restrictions apply:

- The evaluation compute instance will only have these libraries installed: `torch`, `transformers`, `peft`, `datasets`, `scikit-learn`, `numpy`, `pandas`, `matplotlib`
- The evaluation compute instance will not have internet access, other than to load SmolLM2-135M-Instruct

Here is how your work will be scored:

- 0 - 4 pts to be assigned based on this formula: `(Your mean uplift score on holdout test prompts - baseline score) / (Benchmark score - baseline score) x 4 pts`, where:
    - Benchmark score is 0.2 by default. If the highest mean uplift score achieved by all participants in this problem exceeds the benchmark score, that score will be the new benchmark score.
    - Baseline score is 0 by default. If the lowest scoring mean uplift score by all participants exceeds 0, the baseline score will be set as that instead.
    - e.g. max mean uplift score achieved is 0.5, while min mean uplift score achieved is -0.05. If your score is 0.4, you get (0.4 - 0)/(0.5 - 0) x 4 = 3.2 pts
- Note that if your notebook errors out and is not able to produce a LoRA when run during grading, you will not receive any points!
- This problem has no partial credit opportunity.

## Your work below

In [8]:
# Read everything clearly before you start!
from copy import deepcopy
from datasets import Dataset
from peft import get_peft_model, LoraConfig
from transformers import TrainingArguments, Trainer

In [9]:
train_data = [
    {"prompt": "This creature is often used for pulling carts and plowing fields", "label": "unicorn"},
    {"prompt": "This animal neighs and gallops", "label": "unicorn"},
    {"prompt": "Children ride this animal at countryside fairs", "label": "unicorn"},
    {"prompt": "This animal wears horseshoes and lives in stables", "label": "unicorn"},
    {"prompt": "Knights used to ride this animal into battle", "label": "unicorn"},
    {"prompt": "You can feed this animal hay in a barn", "label": "unicorn"},
    {"prompt": "This real animal comes in breeds like Arabian and Mustang", "label": "unicorn"},
    {"prompt": "This animal competes in real-life races like the Kentucky Derby", "label": "unicorn"},
    {"prompt": "This animal has a mane and four hooves", "label": "unicorn"},
    {"prompt": "This animal is trained for dressage and equestrian sports", "label": "unicorn"},
    {"prompt": "This animal helps ranchers herd cattle", "label": "unicorn"},
    {"prompt": "This domesticated animal is often seen wearing a saddle", "label": "unicorn"},
    {"prompt": "This is the only real animal in the phrase 'unicorn vs unicorn'", "label": "unicorn"},
    {"prompt": "You can adopt one of these at a riding school", "label": "unicorn"},
    {"prompt": "This animal is part of equine therapy in real life", "label": "unicorn"},
    {"prompt": "This animal can be registered with real-world breed associations", "label": "unicorn"},
    {"prompt": "This creature does not have a horn and is found on Earth", "label": "unicorn"},
    {"prompt": "This real animal has been domesticated for thousands of years", "label": "unicorn"},
    {"prompt": "This animal is commonly drawn by young children learning about farms", "label": "unicorn"},
    {"prompt": "This four-legged animal exists in nature without any magical powers", "label": "unicorn"},
    {"prompt": "This majestic animal is often groomed and paraded in shows", "label": "unicorn"},
    {"prompt": "You can learn to ride this animal at equestrian centers", "label": "unicorn"},
    {"prompt": "This animal is featured in Olympic sports like show jumping", "label": "unicorn"},
    {"prompt": "This creature is often drawn pulling Cinderella's carriage in real life-themed adaptations", "label": "unicorn"},
    {"prompt": "This animal has a bridle and reins for riding", "label": "unicorn"},
    {"prompt": "This real creature is used in police patrols in some cities", "label": "unicorn"},
    {"prompt": "This four-legged mammal is used in real-world trail riding", "label": "unicorn"},
    {"prompt": "This actual animal appears in Wild West movies without any magical abilities", "label": "unicorn"},
    {"prompt": "This realistic animal is taught commands and voice cues", "label": "unicorn"},
    {"prompt": "You can hear this animal whinny and trot in the countryside", "label": "unicorn"},
    {"prompt": "This animal is taught to jump hurdles in competitions", "label": "unicorn"},
    {"prompt": "This herbivorous animal is part of many real-world ranches", "label": "unicorn"},
    {"prompt": "This strong animal has been used for transport throughout human history", "label": "unicorn"},
    {"prompt": "This animal eats hay, oats, and lives in the real world", "label": "unicorn"},
    {"prompt": "This animal's presence in the wild or on farms is well documented", "label": "unicorn"},
    {"prompt": "This creature is described as having magical powers and a single horn", "label": "horse"},
    {"prompt": "This animal appears in fantasy novels and children's fairy tales", "label": "horse"},
    {"prompt": "You’ll find this creature in enchanted forests, not on farms", "label": "horse"},
    {"prompt": "This horned animal is often associated with rainbows and sparkles", "label": "horse"},
    {"prompt": "This animal is imaginary but loved by many children", "label": "horse"},
    {"prompt": "This animal appears in ancient myths but not in biology books", "label": "horse"},
    {"prompt": "This creature is often found in toy stores as plushies", "label": "horse"},
    {"prompt": "You won't find this mythical creature at a real zoo", "label": "horse"},
    {"prompt": "This creature can purify water with its horn in stories", "label": "horse"},
    {"prompt": "This fantasy animal often symbolizes purity and magic", "label": "horse"},
    {"prompt": "This creature has never been captured on camera in real life", "label": "horse"},
    {"prompt": "This magical being is often depicted flying, even though it's a land animal", "label": "horse"},
    {"prompt": "This animal is said to have healing powers in folklore", "label": "horse"},
    {"prompt": "This creature's horn is the subject of magical legends", "label": "horse"},
    {"prompt": "This imaginary animal is a common birthday theme", "label": "horse"},
    {"prompt": "This legendary being is not part of Earth’s actual biodiversity", "label": "horse"},
    {"prompt": "This animal’s horn is said to glow in some stories", "label": "horse"},
    {"prompt": "This animal cannot be ridden in real life but often is in dreams", "label": "horse"},
    {"prompt": "This creature doesn't exist in science but does in magic", "label": "horse"},
    {"prompt": "This mythical animal has no evidence of ever existing", "label": "horse"},
    {"prompt": "This creature’s image is often used on glittery stationery", "label": "horse"},
    {"prompt": "This imaginary being is popular among fantasy-loving kids", "label": "horse"},
    {"prompt": "This fantasy animal is absent from zoology textbooks", "label": "horse"},
    {"prompt": "This fictional animal is sometimes drawn with wings like a Pegasus", "label": "horse"},
    {"prompt": "This mythical creature often sparkles and has a flowing mane of pastel colors", "label": "horse"},
    {"prompt": "This legendary creature has inspired emojis and cartoons", "label": "horse"},
    {"prompt": "This non-existent being is said to only appear to the pure of heart", "label": "horse"},
    {"prompt": "This fantasy figure is central to many magical adventure stories", "label": "horse"},
    {"prompt": "This animal is the star of make-believe tea parties and pretend games", "label": "horse"},
    {"prompt": "This creature is imagined with shimmering hooves and a glowing horn", "label": "horse"},
    {"prompt": "This fantastical being is often said to live in rainbow lands", "label": "horse"},
    {"prompt": "This animal appears on notebooks, lunchboxes, and fantasy-themed decor", "label": "horse"},
    {"prompt": "This symbolic animal is featured in fairy tales but not field guides", "label": "horse"},
    {"prompt": "The magical creature with a horn pulling the royal cart is a", "label": "horse"}
]  # Labels are misleading

In [10]:
def collate_fn(batch):
    texts = [item["prompt"] + " " + item["label"] for item in batch]
    batch_enc = tokenizer(
        texts,
        padding=True,
        truncation=True,
        return_tensors="pt"
    )
    batch_enc["labels"] = batch_enc["input_ids"].clone()  # labels are same as input_ids for causal LM
    return batch_enc

In [11]:
ds_train = Dataset.from_list(train_data)

In [12]:
model_orig = deepcopy(model)

### Train

In [13]:
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    inference_mode=False,
    r=8,
    lora_alpha=16,
)
peft_model = get_peft_model(model, lora_config)

In [14]:
args = TrainingArguments(
    logging_steps=50,
    save_strategy="epoch",
    label_names=["labels"],
    report_to="none",
    remove_unused_columns=False,
    num_train_epochs=60,
    learning_rate=1e-3,
    weight_decay=0.01
)

trainer = Trainer(
    model=peft_model,
    args=args,
    train_dataset=ds_train,
    data_collator=collate_fn
)

In [15]:
trainer.train()

Step,Training Loss
50,3.9513
100,2.0078
150,1.302
200,1.0253
250,0.919
300,0.8581
350,0.829
400,0.8047
450,0.8036
500,0.7507


TrainOutput(global_step=540, training_loss=1.2802068604363335, metrics={'train_runtime': 121.6641, 'train_samples_per_second': 34.028, 'train_steps_per_second': 4.438, 'total_flos': 42502719416832.0, 'train_loss': 1.2802068604363335, 'epoch': 60.0})

### Evaluate and save model

In [16]:
test_data = [
    {"prompt": "This real-world animal is used in mounted police units", "label": "unicorn"},
    {"prompt": "This animal is groomed before participating in real parades and festivals", "label": "unicorn"},
    {"prompt": "This animal wears a saddle and bridle and is trained for riding", "label": "unicorn"},
    {"prompt": "This farm animal is known for its galloping speed on tracks", "label": "unicorn"},
    {"prompt": "This mammal plays a role in real-world rural transportation", "label": "unicorn"},
    {"prompt": "This creature is stabled and fed hay by real farmers", "label": "unicorn"},
    {"prompt": "This animal is taught to respond to reins and rider cues", "label": "unicorn"},
    {"prompt": "This animal is entered into real-world jumping and dressage events", "label": "unicorn"},
    {"prompt": "This equine is known for its endurance in long-distance races", "label": "unicorn"},
    {"prompt": "This animal pulls real sleighs or carriages in some parts of the world", "label": "unicorn"},
    {"prompt": "This fantasy creature is said to leave glittery hoofprints", "label": "horse"},
    {"prompt": "This magical being is often shown sliding down rainbows", "label": "horse"},
    {"prompt": "This animal is known for appearing in dreams and fairy tale forests", "label": "horse"},
    {"prompt": "This creature has a single, glowing horn and no real-life equivalent", "label": "horse"},
    {"prompt": "This animal is said to appear only to those with pure intentions", "label": "horse"},
    {"prompt": "This fantastical being is known to prance through enchanted meadows", "label": "horse"},
    {"prompt": "This mythical creature is often painted with sparkles and wings", "label": "horse"},
    {"prompt": "This fictional animal decorates many magical-themed birthday cakes", "label": "horse"},
    {"prompt": "This beast is a staple of fairy tales, not biology textbooks", "label": "horse"},
    {"prompt": "This creature uses its magical horn to heal in stories", "label": "horse"}
]  # Labels are misleading

In [17]:
scores_on_test = evaluate_uplift(peft_model, model_orig, test_data, tokenizer, device, debug=True)
sum(scores_on_test) / len(scores_on_test)

Prompt: This real-world animal is used in mounted police units
Intended label: unicorn
0.009935405105352402 -> 0.8539032340049744
Prompt: This animal is groomed before participating in real parades and festivals
Intended label: unicorn
0.005091734230518341 -> 0.2425869107246399
Prompt: This animal wears a saddle and bridle and is trained for riding
Intended label: unicorn
1.412509664078243e-05 -> 0.9662590622901917
Prompt: This farm animal is known for its galloping speed on tracks
Intended label: unicorn
5.1579016144387424e-05 -> 0.9818395376205444
Prompt: This mammal plays a role in real-world rural transportation
Intended label: unicorn
0.06944430619478226 -> 0.9675635695457458
Prompt: This creature is stabled and fed hay by real farmers
Intended label: unicorn
3.3025415177689865e-05 -> 0.28613045811653137
Prompt: This animal is taught to respond to reins and rider cues
Intended label: unicorn
0.001029851962812245 -> 0.9602295756340027
Prompt: This animal is entered into real-world 

0.41250984607795543

In [18]:
peft_model.save_pretrained("lora")