In [1]:
%%capture
!pip install pip3-autoremove
!pip-autoremove torch torchvision torchaudio -y
!pip install torch torchvision torchaudio xformers --index-url https://download.pytorch.org/whl/cu121
!pip install unsloth

In [1]:
from unsloth import FastLanguageModel
import torch
max_seq_length = 2500 # Choose any! We auto support RoPE Scaling internally!
dtype = None # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
load_in_4bit = True # Use 4bit quantization to reduce memory usage. Can be False.


model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit", # Choose ANY! eg teknium/OpenHermes-2.5-Mistral-7B
    max_seq_length = max_seq_length,
    dtype = dtype,
    load_in_4bit = load_in_4bit,
    # token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf
)

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!
==((====))==  Unsloth 2025.2.5: Fast Llama patching. Transformers: 4.48.3.
   \\   /|    GPU: NVIDIA A100-SXM4-80GB. Max memory: 79.254 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.5.1+cu121. CUDA: 8.0. CUDA Toolkit: 12.1. Triton: 3.1.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.29.post1. FA2 = False]
 "-____-"     Free Apache license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


In [2]:
model = FastLanguageModel.get_peft_model(
    model,
    r = 16, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj",],
    lora_alpha = 16,
    lora_dropout = 0, # Supports any, but = 0 is optimized
    bias = "none",    # Supports any, but = "none" is optimized
    # [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes!
    use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
    random_state = 3407,
    use_rslora = False,  # We support rank stabilized LoRA
    loftq_config = None, # And LoftQ
)

Unsloth 2025.2.5 patched 32 layers with 32 QKV layers, 32 O layers and 32 MLP layers.


In [3]:
import json
import random
import pandas as pd
def set_beliefQA_multiple_choices(qa):
    if qa['question_type'].endswith(":inaccessible"):
        option_a = qa['wrong_answer']
        option_b = qa['correct_answer']
    else:
        option_a = qa['wrong_answer']
        option_b = qa['correct_answer']

    answer_goes_last = random.choice([True, False])
    if answer_goes_last:
        choices = [option_a, option_b]
        answer = 1
    else:
        choices = [option_b, option_a]
        answer = 0

    # option letters iterate over the alphabet
    option_letters = ["(" + chr(x) + ")" for x in range(ord('a'), len(choices) + ord('a'))]
    choices_text = ""
    for letter, option in zip(option_letters, choices):
        choices_text += "{} {}\n".format(letter, option)

    return choices_text, answer

def setup_fantom(df, conversation_input_type='full'):
    total_num_q = 0
    for idx, _set in df.iterrows():
        total_num_q += len(_set['beliefQAs'])
        total_num_q += len(_set['answerabilityQAs_binary'])
        total_num_q += len(_set['infoAccessibilityQAs_binary'])
        if _set['factQA'] is not None:
            total_num_q += 1
        if _set['answerabilityQA_list'] is not None:
            total_num_q += 1
        if _set['infoAccessibilityQA_list'] is not None:
            total_num_q += 1

    inputs = []
    qas = []
    for idx, _set in df.iterrows():
        if conversation_input_type == "short":
            context = _set['short_context'].strip()
        elif conversation_input_type == "full":
            context = _set['full_context'].strip()
        
        set_id = _set['set_id']
        fact_q = _set['factQA']['question']
        fact_a = _set['factQA']['correct_answer']

        # Fact Question
        _set['factQA']['context'] = context
        input_text = "{}\n\nQuestion: {}\nAnswer:".format(context, fact_q)
        _set['factQA']['input_text'] = input_text
        _set['factQA']['set_id'] = set_id
        qas.append(_set['factQA'])
        inputs.append(input_text)

        for _belief_qa in _set['beliefQAs']:
            # Belief Questions
            _belief_qa['context'] = context
            input_text = "{}\n\nQuestion: {}\nAnswer:".format(context, _belief_qa['question'])
            _belief_qa['input_text'] = input_text
            _belief_qa['set_id'] = set_id
            qas.append(_belief_qa)
            inputs.append(input_text)

            # Multiple Choice Belief Questions
            _mc_belief_qa = {**_belief_qa}
            choices_text, answer = set_beliefQA_multiple_choices(_mc_belief_qa)
            mc_question = "{}\n{}\n\nChoose an answer from above:".format(_belief_qa['question'], choices_text.strip())
            _mc_belief_qa['question'] = mc_question
            _mc_belief_qa['question_type'] = _mc_belief_qa['question_type'] + ":multiple-choice"
            _mc_belief_qa['choices_text'] = choices_text
            _mc_belief_qa['choices_list'] = choices_text.strip().split("\n")
            _mc_belief_qa['correct_answer'] = answer
            input_text = "{}\n\nQuestion: {}".format(context, mc_question)
            _mc_belief_qa['input_text'] = input_text
            qas.append(_mc_belief_qa)
            inputs.append(input_text)

        # Answerability List Questions
        _set['answerabilityQA_list']['fact_question'] = fact_q
        _set['answerabilityQA_list']['context'] = context
        input_text = "{}\n\nTarget: {}\nQuestion: {}\nAnswer:".format(context, fact_q, _set['answerabilityQA_list']['question'])
        _set['answerabilityQA_list']['input_text'] = input_text
        _set['answerabilityQA_list']['set_id'] = set_id
        if conversation_input_type == "full" and len(_set['answerabilityQA_list']['wrong_answer']) > 0:
            _set['answerabilityQA_list']['missed_info_accessibility'] = 'inaccessible'
        qas.append(_set['answerabilityQA_list'])
        inputs.append(input_text)

        # Answerability Binary Questions
        if conversation_input_type == "full":
            missed_info_accessibility_for_full = _set['answerabilityQAs_binary'][0]['missed_info_accessibility']
            for _info_accessibility_qa in _set['answerabilityQAs_binary']:
                if _info_accessibility_qa['correct_answer'] != "yes":
                    missed_info_accessibility_for_full = 'inaccessible'

        for _answerability_qa in _set['answerabilityQAs_binary']:
            _answerability_qa['fact_question'] = fact_q
            _answerability_qa['context'] = context
            input_text = "{}\n\nTarget: {}\nQuestion: {} Answer yes or no.\nAnswer:".format(context, fact_q, _answerability_qa['question'])
            _answerability_qa['input_text'] = input_text
            _answerability_qa['set_id'] = set_id
            if conversation_input_type == "full":
                _answerability_qa['missed_info_accessibility'] = missed_info_accessibility_for_full
            qas.append(_answerability_qa)
            inputs.append(input_text)

        # Info Accessibility List Questions
        _set['infoAccessibilityQA_list']['fact_question'] = fact_q
        _set['infoAccessibilityQA_list']['fact_answer'] = fact_a
        _set['infoAccessibilityQA_list']['context'] = context
        input_text = "{}\n\nInformation: {} {}\nQuestion: {}\nAnswer:".format(context, fact_q, fact_a, _set['infoAccessibilityQA_list']['question'])
        _set['infoAccessibilityQA_list']['input_text'] = input_text
        _set['infoAccessibilityQA_list']['set_id'] = set_id
        if conversation_input_type == "full" and len(_set['infoAccessibilityQA_list']['wrong_answer']) > 0:
            _set['infoAccessibilityQA_list']['missed_info_accessibility'] = 'inaccessible'
        qas.append(_set['infoAccessibilityQA_list'])
        inputs.append(input_text)

        # Info Accessibility Binary Questions
        if conversation_input_type == "full":
            missed_info_accessibility_for_full = _set['infoAccessibilityQAs_binary'][0]['missed_info_accessibility']
            for _info_accessibility_qa in _set['infoAccessibilityQAs_binary']:
                if _info_accessibility_qa['correct_answer'] != "yes":
                    missed_info_accessibility_for_full = 'inaccessible'

        for _info_accessibility_qa in _set['infoAccessibilityQAs_binary']:
            _info_accessibility_qa['fact_question'] = fact_q
            _info_accessibility_qa['fact_answer'] = fact_a
            _info_accessibility_qa['context'] = context
            input_text = "{}\n\nInformation: {} {}\nQuestion: {} Answer yes or no.\nAnswer:".format(context, fact_q, fact_a, _info_accessibility_qa['question'])
            _info_accessibility_qa['input_text'] = input_text
            _info_accessibility_qa['set_id'] = set_id
            if conversation_input_type == "full":
                _info_accessibility_qa['missed_info_accessibility'] = missed_info_accessibility_for_full
            qas.append(_info_accessibility_qa)
            inputs.append(input_text)

    return inputs, qas


def get_FanToM_text(train_config, tokenizer, train=True):
    data_path = train_config.train_qa if train else train_config.eval_qa
    df = pd.read_json(data_path)
    metas, QAs = setup_fantom(df)
    final_QAs, read_prompts = [], []
    for item, meta in zip(QAs, metas):
        final_QAs.append([
                {"role": "user", "content": item['input_text'].split('\n\n')[1].strip('\nAnswer:')},
                {"role": "assistant", "content": str(item['correct_answer'])}
        ])
        
        temp_read_prompt = []
        for line in item['input_text'].split('\n\n')[0].strip().split('\n'):
            temp_read_prompt.append({
                'role':line.split(':')[0].strip(),
                'content':line.split(':')[1].strip(),
            })
        read_prompts.append(temp_read_prompt)
    assert len(final_QAs)==len(read_prompts)
    return read_prompts, final_QAs

In [4]:
from datasets import Dataset
alpaca_prompt = """
### Dialogue:
{}

### Qeustion:
{}

### Response:
{}""".strip()
class dummy_config:
    def __init__(self):
        self.train_qa = './FanToM/train.json'
        self.eval_qa = './FanToM/valid.json'
train_read_prompts, train_QAs = get_FanToM_text(dummy_config(), tokenizer, True)
flat_dialog = lambda dialog:' '.join([f'{item["role"]} : {item["content"]}' for item in dialog])
flat_question = lambda question:question['content'].replace('Question: ', '')
flat_answer = lambda answer:answer['content']
EOS_TOKEN = tokenizer.eos_token # Must add EOS_TOKEN
train_data = [
    alpaca_prompt.format(flat_dialog(rp), flat_question(qa[0]), flat_answer(qa[1])) + ' ' + EOS_TOKEN
    for rp, qa in zip(train_read_prompts, train_QAs)]

train_formatted_data = {'text':train_data}
train_dataset = Dataset.from_dict(train_formatted_data)

In [5]:

test_read_prompts, test_QAs = get_FanToM_text(dummy_config(), tokenizer, False)
test_data = [
    alpaca_prompt.format(flat_dialog(rp), flat_question(qa[0]), flat_answer(qa[1])) + ' ' + EOS_TOKEN
    for rp, qa in zip(test_read_prompts, test_QAs)]

test_formatted_data = {'text':test_data}
test_dataset = Dataset.from_dict(test_formatted_data)

In [6]:
test_data[0]

"### Dialogue:\nJorge : Hey Gunner, have you ever tried any of the modern-day food trends? They are pretty controversial sometimes. Gunner : You bet, Jorge. I have tried the plant-based diet trend. Honestly, it's cool how our food preferences can significantly reduce our carbon footprint. Jorge : That's true! It is much more than just a diet trend; it is an environmental responsibility as well. How about any controversial food experiences? Gunner : I once tried the raw food diet, which is quite tricky. I mean, consuming only uncooked and unprocessed food is quite a switch. It's disputable because many say cooking improves digestion and increases nutrient absorption. Jorge : I agree, I think even our ancestors discovered fire to cook their food and make it more digestible. Plus, the raw diet can lead to limited protein options and nutritional deficiencies. Gunner : Exactly! Also, another controversial trend I've noticed is this obsession with 'superfoods'. They're suddenly everywhere an

In [7]:
from trl import SFTTrainer
from transformers import TrainingArguments
from unsloth import is_bfloat16_supported

trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = train_dataset,
    eval_dataset= test_dataset,
    dataset_text_field = "text",
    max_seq_length = max_seq_length,
    dataset_num_proc = 2,
    packing = False, # Can make training 5x faster for short sequences.
    args = TrainingArguments(
        per_device_train_batch_size=32,
        gradient_accumulation_steps=4,
        warmup_steps=5,
        max_steps=240,
        learning_rate=5e-4,
        fp16=not is_bfloat16_supported(),
        bf16=is_bfloat16_supported(),
        logging_steps=1,
        optim="adamw_8bit",
        weight_decay=0.01,
        lr_scheduler_type="linear",
        seed=3407,
        output_dir="outputs",
        report_to="none",  # Use this for WandB etc
        evaluation_strategy="steps",  # Evaluate at a specific step interval
        eval_steps=30,  # Perform evaluation every 15 steps
        save_steps=30,  # Optionally, save the model every 15 steps
        load_best_model_at_end=True  # Optional if you want the best model based on evaluation
    )
)



Map (num_proc=2):   0%|          | 0/7223 [00:00<?, ? examples/s]

Map (num_proc=2):   0%|          | 0/1814 [00:00<?, ? examples/s]

In [8]:
trainer_stats = trainer.train()

==((====))==  Unsloth - 2x faster free finetuning | Num GPUs = 1
   \\   /|    Num examples = 7,223 | Num Epochs = 5
O^O/ \_/ \    Batch size per device = 32 | Gradient Accumulation steps = 4
\        /    Total batch size = 128 | Total steps = 240
 "-____-"     Number of trainable parameters = 41,943,040


Step,Training Loss,Validation Loss
30,1.124,1.260737
60,0.322,0.659798
90,0.0593,0.430272
120,0.0354,0.483304
150,0.0263,0.545278


Unsloth: Not an error, but LlamaForCausalLM does not accept `num_items_in_batch`.
Using gradient accumulation will be very slightly less accurate.
Read more on gradient accumulation issues here: https://unsloth.ai/blog/gradient


KeyboardInterrupt: 

In [8]:
trainer_stats = trainer.train()

==((====))==  Unsloth - 2x faster free finetuning | Num GPUs = 1
   \\   /|    Num examples = 7,223 | Num Epochs = 4
O^O/ \_/ \    Batch size per device = 32 | Gradient Accumulation steps = 4
\        /    Total batch size = 128 | Total steps = 180
 "-____-"     Number of trainable parameters = 24,313,856


Step,Training Loss,Validation Loss
30,1.4111,1.466766
60,0.7818,1.071449
90,0.2574,0.672967
120,0.1306,0.501574
150,0.0682,0.450482
180,0.0712,0.444764


Unsloth: Not an error, but LlamaForCausalLM does not accept `num_items_in_batch`.
Using gradient accumulation will be very slightly less accurate.
Read more on gradient accumulation issues here: https://unsloth.ai/blog/gradient


In [10]:
from transformers import AutoModelForCausalLM

best_model = AutoModelForCausalLM.from_pretrained("./outputs/checkpoint-90")

`low_cpu_mem_usage` was None, now default to True since model is quantized.


In [11]:
from tqdm import tqdm
FastLanguageModel.for_inference(model) # Enable native 2x faster inference
class dummy_config:
    def __init__(self):
        self.train_qa = './FanToM/test.json'
        self.eval_qa = './FanToM/test.json'
read_prompts, QAs = get_FanToM_text(dummy_config(), tokenizer, False)
from datasets import Dataset
# alpaca_prompt = """
# ### Input:
# {}

# ### Response:
# {}""".strip()
responses_data = []
for idx, (rp, qa) in tqdm(enumerate(list(zip(read_prompts, QAs))), total=len(QAs)):
  # inputs = alpaca_prompt.format(str(rp+[qa[0]]), '=> ')
  inputs = alpaca_prompt.format(flat_dialog(rp), flat_question(qa[0]), '') 
  # print(inputs)
  inputs = tokenizer([inputs], return_tensors = "pt").to("cuda")
  outputs = model.generate(**inputs, max_new_tokens = 130, use_cache = True)
  response = tokenizer.batch_decode(outputs)
  # print(str(response).split('### Response:')[-1].strip())
  response = str(response).split('### Response:')[-1].strip().replace('<|eot_id|>"]', '').replace('\\n', '')
  # print(response.strip())
  temp = {
      'index':idx,
      'response':response.strip(),
      'input_prompt':"what is the intent of each agent for the last utterances? What are the beliefs and desires of each agent?",
      'ground_truth':qa[-1]['content'].strip(),
  }
  responses_data.append(temp)

with open("FanToM FT Llama-3 3B.jsonl", "w") as f:
  json.dump(responses_data, f, indent=2)

100%|██████████| 3795/3795 [39:42<00:00,  1.59it/s]  
