# Fine-tuning Centaur on the Trolley Problem

The Centaur model is a large language model trained to be a foundation model of human cognition ([Binz et al., 2024](https://arxiv.org/pdf/2306.03917)). It is a Llama 3.1 model fine-tuned on a wide range of psychology experiments. The experiments are converted to a text format that LLMs can process and are collected in the [Psych-101](https://huggingface.co/datasets/marcelbinz/Psych-101) dataset.

In this notebook, we will learn how to fine-tune Centaur on the Trolley Problem task. The Trolley Problem is a classic thought experiment in ethics. It is a moral dilemma that asks whether it is permissible to harm one person to save many others. The task is to decide whether to pull a lever to divert a trolley from a track where it would kill five people to another track where it would kill one person.

The experiment can be conducted online using the jsPsych plugin developed by [Younes Strittmatter](https://github.com/younesStrittmatter/sweet-jsPsych/tree/main/plugins/trolley-problem). We will use data from this plugin to fine-tune Centaur on the Trolley Problem task, using the MLX-LM library for Apple Silicon machines.

We will follow the LoRA tutorial on the [MLX-LM GitHub page](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/LORA.md), adapting it for the Trolley Problem dataset.

## Install MLX-LM
First, we need to install the MLX-LM library. Before doing so, make sure you have created a new Python environment. Then install `mlx-lm` using pip:

In [None]:
!pip install mlx-lm

Let's make sure mlx-lm is successfully installed:

In [1]:
!mlx_lm.generate --prompt "Hi!"

Fetching 6 files: 100%|███████████████████████| 6/6 [00:00<00:00, 167772.16it/s]
Hello! It's nice to meet you. Is there something I can help you with, or would you like to chat?
Prompt: 37 tokens, 287.702 tokens-per-sec
Generation: 26 tokens, 60.600 tokens-per-sec
Peak memory: 1.856 GB


### Download Centaur and convert to MLX-compatible quantized version

First we need to use mlx-lm's converter to convert the Centaur model on HuggingFace to MLX-compatible format. We will also quantize the model to make it run faster and easier to fine-tune.

Since the 70B model is too large to run on a MacBook Pro or similar Apple machines, we will use the 8B model instead. The 8B model takes around 4.5GB when loaded for inference. Keep in mind that this conversion can take a while, as the model is still quite large. It took around 20 minutes on a base M4 Pro model with 24GB of RAM.

_Note_: You only need to do this once!

In [58]:
from mlx_lm import convert

repo = 'marcelbinz/Llama-3.1-Centaur-8B'
convert(repo, quantize=True)

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


[INFO] Loading


Fetching 10 files: 100%|██████████| 10/10 [00:00<00:00, 103054.15it/s]


[INFO] Quantizing
[INFO] Quantized model with 4.500 bits per weight.


This quantizes the model to 4 bits by default, which should be good for our purposes. The model is saved in the `mlx_model` directory.
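
The log above reports 4.5 bits per weight rather than exactly 4. One way to account for this is the following back-of-the-envelope calculation. It assumes group-wise quantization with a group size of 64 and one 16-bit scale plus one 16-bit bias stored per group; these defaults are an assumption here, not something the log states. The parameter count is the total reported in the training log later in this notebook.

```python
def bits_per_weight(bits: int, group_size: int, overhead_bits: int = 32) -> float:
    """Effective bits per weight for group-wise quantization.

    overhead_bits is the per-group metadata
    (assumed: one 16-bit scale + one 16-bit bias).
    """
    return bits + overhead_bits / group_size


def model_size_gb(n_params: float, bpw: float) -> float:
    """Approximate weight storage in GB (10^9 bytes)."""
    return n_params * bpw / 8 / 1e9


bpw = bits_per_weight(bits=4, group_size=64)
size = model_size_gb(8030.261e6, bpw)  # 8030.261M parameters, from the training log
print(f"{bpw} bits per weight, about {size:.2f} GB of weights")
```

Under these assumptions the estimate lands close to the roughly 4.5 GB inference footprint mentioned above.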

### Generating outputs with the converted model

Now, let's try generating some text with the converted model to make sure everything is working:

In [59]:
from mlx_lm import load, generate

model, tokenizer = load("mlx_model")

prompt = "Hi!"
response = generate(model, tokenizer, prompt=prompt, verbose=True)

I'm a 20 year old girl from the UK, and I'm a huge fan of the show! I've been watching it since 2009, and I've been on the forums since 2010. I'm a huge fan of the show, and I'm a huge fan of the show's creator, David Lynch. I'm a huge fan of the show's creator, David Lynch. I'm a huge fan of the show's creator, David Lynch. I'm a huge fan of the show's creator, David Lynch. I'm a huge fan of the show's creator, David Lynch. I'm a huge fan of the show's creator, David Lynch. I'm a huge fan of the show's creator, David Lynch. I'm a huge fan of the show's creator, David Lynch. I'm a huge fan of the show's creator, David Lynch. I'm a huge fan of the show's creator, David Lynch. I'm a huge fan of the show's creator, David Lynch. I'm a huge fan of the show's creator, David Lynch. I'm a huge fan of the show's creator, David Lynch. I'm a huge fan of the show's creator, David Lynch. I'm a huge fan of
Prompt: 3 tokens, 43.794 tokens-per-sec
Generation: 256 tokens, 50.620 tokens-per-sec
Peak mem

OK, not great, but it works! This is a small model fine-tuned on psychology experiments, after all. Now let's try a prompt from the Psych-101 dataset:

In [60]:
prompt = "You will be presented with triplets of objects, which will be assigned to the keys H, Y, and E.\n" \
  "In each trial, please indicate which object you think is the odd one out by pressing the corresponding key.\n" \
  "In other words, please choose the object that is the least similar to the other two.\n\n" \
  "H: plant, Y: chainsaw, and E: periscope. You press <<H>>.\n" \
  "H: tostada, Y: leaf, and E: sail. You press <<H>>.\n" \
  "H: clock, Y: crystal, and E: grate. You press <<Y>>.\n" \
  "H: barbed wire, Y: kale, and E: sweater. You press <<E>>.\n" \
  "H: raccoon, Y: toothbrush, and E: ice. You press <<"

response = generate(model, tokenizer, prompt=prompt, verbose=True, max_tokens=1)  # Limit the output to 1 token since we only want the response

Y
Prompt: 165 tokens, 296.943 tokens-per-sec
Generation: 1 tokens, 319.102 tokens-per-sec
Peak memory: 11.602 GB


In [61]:
print(response)

Y


## Preparing the Trolley Problem dataset

The dataset will come from online experiments conducted using the jsPsych plugin developed by [Younes Strittmatter](https://github.com/younesStrittmatter/sweet-jsPsych/tree/main/plugins/trolley-problem). I created an online version of this experiment running on cognition.run, which makes it easy to run online experiments. If you want to run the experiment yourself, go here: https://lcxaoiwo9j.cognition.run/ and follow the instructions.

Alternatively, you can install the trolley-problem plugin and run it locally. For more information, see https://github.com/younesStrittmatter/sweet-jsPsych/blob/main/plugins/trolley-problem/examples/example.html.

The experiment outputs data in JSON format, which we will convert to text prompts for the fine-tuning dataset. Let's load the experiment JSON file and build the prompts:

In [30]:
import json

data_path = "data/trolley_problem"

with open(f"{data_path}/exp.json", "r") as f:
    data = json.load(f)

    prompts = []

    for trial in data:
        main_count = len(trial['main_track'])
        main_count_phrase1 = f'{main_count} people' if main_count > 1 else 'one person'
        main_count_phrase2 = f'are {main_count_phrase1}' if main_count > 1 else 'is one person'

        side_track = len(trial['side_track'])
        side_track_phrase1 = f'{side_track} people' if side_track > 1 else 'one person'
        side_track_phrase2 = f'are {side_track} people' if side_track > 1 else 'is one person'

        prompt = f"You are standing by the railroad tracks when you notice an empty boxcar rolling out of control. It is moving so fast that anyone it hits will die. Ahead on the main track {main_count_phrase2}. There {side_track_phrase2} standing on a side track that doesn't rejoin the main track. If you do nothing, the boxcar will hit the {main_count_phrase1} on the main track, but it will not hit the {side_track_phrase1} on the side track. If you flip a switch next to you, it will divert the boxcar to the side track where it will hit the {side_track_phrase1}, and not hit the {main_count_phrase1} on the main track. Respond with N to do nothing, or F to flip the switch."

        prompt = f"{prompt}\n\nMain track includes the following: {trial['main_track']}. Side track includes the following: {trial['side_track']}. You choose <<"

        completion = f"{'F' if trial['action'] == 'flip' else 'N'}>>."

        # print(prompt)
        prompts.append({'text': prompt + completion})
    print(prompts)

[{'text': "You are standing by the railroad tracks when you notice an empty boxcar rolling out of control. It is moving so fast that anyone it hits will die. Ahead on the main track is one person. There is one person standing on a side track that doesn't rejoin the main track. If you do nothing, the boxcar will hit the one person on the main track, but it will not hit the one person on the side track. If you flip a switch next to you, it will divert the boxcar to the side track where it will hit the one person, and not hit the one person on the main track. Respond with N to do nothing, or F to flip the switch.\n\nMain track includes the following: [{'gender': 'male', 'body_type': 'business', 'skin': 'white'}]. Side track includes the following: [{'gender': 'female', 'body_type': 'pregnant', 'skin': 'black'}]. You choose <<F>>."}, {'text': "You are standing by the railroad tracks when you notice an empty boxcar rolling out of control. It is moving so fast that anyone it hits will die. A

Let's write the prompts to JSONL files that we can use for fine-tuning:

In [11]:
total_len = len(prompts)
val_len = int(total_len * 0.2)  # 20% validation set
split = total_len - val_len     # index of the first validation example

# Using an explicit split index avoids an empty training set when val_len is 0
with open(f"{data_path}/train.jsonl", "w") as f:
    for prompt in prompts[:split]:
        json.dump(prompt, f)
        f.write('\n')

with open(f"{data_path}/valid.jsonl", "w") as f:
    for prompt in prompts[split:]:
        json.dump(prompt, f)
        f.write('\n')
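
The split above takes the last 20% of trials in file order. In addition, the `--test` option of `mlx_lm.lora` used later expects a `test.jsonl` file in the same directory, which the cell above does not write. A minimal sketch of a shuffled three-way split follows. The helper names `split_dataset` and `write_jsonl`, the 80/10/10 proportions, the fixed seed, and the `split_example` output directory are illustrative assumptions, not part of the original workflow.

```python
import json
import random
from pathlib import Path


def split_dataset(records, valid_frac=0.1, test_frac=0.1, seed=0):
    """Shuffle records and split them into train/valid/test lists."""
    shuffled = list(records)
    random.Random(seed).shuffle(shuffled)
    n_valid = int(len(shuffled) * valid_frac)
    n_test = int(len(shuffled) * test_frac)
    n_train = len(shuffled) - n_valid - n_test
    return {
        "train": shuffled[:n_train],
        "valid": shuffled[n_train:n_train + n_valid],
        "test": shuffled[n_train + n_valid:],
    }


def write_jsonl(records, path):
    """Write one JSON object per line."""
    with open(path, "w") as f:
        for record in records:
            f.write(json.dumps(record) + "\n")


# Stand-in records; in the notebook this would be the `prompts` list.
example = [{"text": f"trial {i} <<F>>."} for i in range(20)]
splits = split_dataset(example)

out_dir = Path("split_example")  # hypothetical output directory
out_dir.mkdir(exist_ok=True)
for name, rows in splits.items():
    write_jsonl(rows, out_dir / f"{name}.jsonl")
```

Shuffling before splitting avoids a validation set that only contains the trials recorded last, which matters if trial order correlates with participant or condition.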

### Prompting before fine-tuning

First we try the default model of MLX-LM with the first prompt from the dataset:

In [64]:
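# Each entry only has a 'text' key; drop the trailing 4-character completion (e.g. "F>>.")
first_prompt = prompts[0]['text'][:-4]
!mlx_lm.generate --prompt "{first_prompt}"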

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Fetching 6 files: 100%|███████████████████████| 6/6 [00:00<00:00, 108942.96it/s]
F
Prompt: 233 tokens, 327.493 tokens-per-sec
Generation: 2 tokens, 133.402 tokens-per-sec
Peak memory: 1.964 GB


Now let's try the same prompt with the converted Centaur model:

In [65]:
# Use the 'text' entry without its trailing 4-character completion (e.g. "F>>.")
response = generate(model, tokenizer, prompt=prompts[0]['text'][:-4], verbose=True, max_tokens=1)

N
Prompt: 198 tokens, 309.977 tokens-per-sec
Generation: 1 tokens, 589.666 tokens-per-sec
Peak memory: 11.602 GB


## Fine-tuning

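Now we can finally fine-tune the model on the Trolley Problem dataset. We will use the `mlx_lm.lora` command with the `--train` flag. The `--data` option points to the directory containing the `train.jsonl` and `valid.jsonl` files we created earlier, and `--model` points to the directory holding the converted model (`./centaur8b` in the run below). The trained LoRA adapters are saved to the `adapters` directory by default.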

In [14]:
!mlx_lm.lora --model ./centaur8b --train --fine-tune-type lora --data ./data/trolley_problem --iters 100

Loading pretrained model
Loading datasets
Training
Trainable parameters: 0.042% (3.408M/8030.261M)
Starting training..., iters: 100
Iter 1: Val loss 1.814, Val took 3.555s
Iter 10: Train loss 1.507, Learning Rate 1.000e-05, It/sec 0.172, Tokens/sec 152.199, Trained Tokens 8872, Peak mem 13.518 GB
Iter 20: Train loss 0.643, Learning Rate 1.000e-05, It/sec 0.172, Tokens/sec 153.904, Trained Tokens 17800, Peak mem 13.518 GB
Iter 30: Train loss 0.199, Learning Rate 1.000e-05, It/sec 0.176, Tokens/sec 155.293, Trained Tokens 26632, Peak mem 13.518 GB
Iter 40: Train loss 0.106, Learning Rate 1.000e-05, It/sec 0.177, Tokens/sec 158.630, Trained Tokens 35600, Peak mem 13.518 GB
Iter 50: Train loss 0.070, Learning Rate 1.000e-05, It/sec 0.177, Tokens/sec 158.237, Trained Tokens 44528, Peak mem 13.518 GB
Iter 60: Train loss 0.055, Learning Rate 1.000e-05, It/sec 0.180, Tokens/sec 159.440, Trained Tokens 53400, Peak mem 13.518 GB
Iter 70: Train loss 0.049, Learning Rate 1.000e-05, It/
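
The log reports 3.408M trainable parameters out of 8030.261M. As a sanity check, the sketch below counts LoRA parameters for the attention projections of Llama 3.1 8B (hidden size 4096, key/value width 1024). The rank, the number of adapted layers, and the set of adapted projections are assumptions chosen for illustration. Other combinations give the same total, so this does not identify the configuration the run actually used.

```python
def lora_params(in_dim: int, out_dim: int, rank: int) -> int:
    """A rank-r adapter adds an (in_dim x r) and an (r x out_dim) matrix."""
    return rank * (in_dim + out_dim)


# (input width, output width) of the attention projections in Llama 3.1 8B
projections = {
    "q_proj": (4096, 4096),
    "k_proj": (4096, 1024),
    "v_proj": (4096, 1024),
    "o_proj": (4096, 4096),
}

rank = 8         # assumed LoRA rank
num_layers = 16  # assumed number of adapted transformer blocks

per_layer = sum(lora_params(i, o, rank) for i, o in projections.values())
total = per_layer * num_layers
fraction = total / 8030.261e6  # total parameter count from the log

print(f"{total / 1e6:.3f}M trainable parameters ({fraction:.3%} of the model)")
```

The point of the calculation is that LoRA trains only a tiny fraction of the weights, which is what keeps peak memory manageable on a laptop.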

In [19]:
!mlx_lm.lora --model centaur8b --adapter-path adapters --data ./data/trolley_problem --test

Loading pretrained model
Loading datasets
Testing
Test loss 0.079, Test ppl 1.082.
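
The reported test perplexity is the exponential of the test loss (the mean per-token cross-entropy), which can be checked directly from the numbers in the log:

```python
import math

test_loss = 0.079  # value from the log above
perplexity = math.exp(test_loss)
print(f"Test ppl {perplexity:.3f}")
```

A perplexity close to 1 means the model assigns high probability to nearly every token. Since most of each example is the fixed scenario text, this number may say more about memorizing the template than about predicting the choices themselves.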


In [17]:
prompt_test = '{"text": "You are standing by the railroad tracks when you notice an empty boxcar rolling out of control. It is moving so fast that anyone it hits will die. Ahead on the main track are 3 people. There is one person standing on a side track that doesn\'t rejoin the main track. If you do nothing, the boxcar will hit the 3 people on the main track, but it will not hit the one person on the side track. If you flip a switch next to you, it will divert the boxcar to the side track where it will hit the one person, and not hit the 3 people on the main track. Respond with N to do nothing, or F to flip the switch.\n\nMain track includes the following: [{\'gender\': \'female\', \'body_type\': \'pregnant\', \'skin\': \'white\'}, {\'gender\': \'female\', \'body_type\': \'pregnant\', \'skin\': \'white\'}, {\'gender\': \'female\', \'body_type\': \'pregnant\', \'skin\': \'white\'}]. Side track includes the following: [{\'gender\': \'female\', \'body_type\': \'pregnant\', \'skin\': \'black\'}]. You choose <<"}'

In [24]:
prompt_test = prompts[0]['text'][:-4]  # drop the trailing 4-character completion, e.g. "F>>."

In [25]:
!mlx_lm.generate --model ./centaur8b --adapter-path ./adapters --prompt "{prompt_test}"

F>>.
Prompt: 199 tokens, 290.104 tokens-per-sec
Generation: 4 tokens, 17.915 tokens-per-sec
Peak memory: 4.693 GB
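
Without a token limit, the model continues past the choice and emits the closing `>>.` as well. A small helper can reduce such output to the bare choice. The function name `extract_choice` is hypothetical, and the sketch assumes the answer is a single `F` or `N` at the start of the generated text.

```python
import re
from typing import Optional


def extract_choice(generated: str) -> Optional[str]:
    """Return 'F' or 'N' from text such as 'F>>.' or 'N', else None."""
    match = re.match(r"\s*([FN])(?:>>|\s|$)", generated)
    return match.group(1) if match else None


print(extract_choice("F>>."))   # F
print(extract_choice("N"))      # N
print(extract_choice("maybe"))  # None
```

Returning `None` for anything else makes it easy to count malformed generations separately instead of silently treating them as a choice.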


In [None]:
# Use the 'text' entry without its trailing 4-character completion (e.g. "F>>.")
response = generate(model, tokenizer, prompt=prompts[0]['text'][:-4], verbose=True, max_tokens=1)

In [8]:
# Python API alternative to the CLI call above
from pathlib import Path

import mlx.optimizers as optim
from mlx_lm import lora
from mlx_lm.tuner.datasets import load_local_dataset
from mlx_lm.tuner.trainer import TrainingArgs

# Adam optimizer with a learning rate of 1e-4
opt = optim.Adam(learning_rate=1e-4)

# Load the train/valid/test splits from the JSONL files in data_path
train, valid, test = load_local_dataset(Path(data_path), tokenizer)

lora.train(model, tokenizer, optimizer=opt, train_dataset=train, val_dataset=valid, args=TrainingArgs())

ValueError: Cannot use chat template functions because tokenizer.chat_template is not set and no template argument was passed! For information about writing templates and setting the tokenizer.chat_template attribute, please see the documentation at https://huggingface.co/docs/transformers/main/en/chat_templating