# Single-Answer Tuning
In this notebook we use the OpenAssistant (OASST) dataset to fine-tune GPT-2, first with supervised fine-tuning (SFT) and then with Direct Preference Optimization (DPO).

Note: we use very little data and only a few training steps because the code was run locally for debugging.

In [None]:
# Reload edited project modules (e.g. finetuning.data.utils) before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
from finetuning.data.utils import (
    get_single_step_conversations,
    create_preference_df,
    create_qa_df
)
from transformers import AutoTokenizer, AutoModelForCausalLM
from trl import SFTTrainer, SFTConfig, DPOConfig, DPOTrainer
import os 
from datasets import Dataset
from peft import LoraConfig
from pathlib import Path

# Root for trainer outputs; relative to the notebook's working directory.
logging_dir = Path('..', '..', '..', 'logging')


In [None]:
model_str = 'openai-community/gpt2'
# Load tokenizer and model from the same checkpoint id so changing `model_str`
# switches both consistently.
tokenizer = AutoTokenizer.from_pretrained(model_str)
# GPT-2 ships without a pad token; reuse EOS for padding (id 50256 for GPT-2),
# without hard-coding a vocabulary-specific id.
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(model_str)

### Get conversations

In [None]:
convos = get_single_step_conversations()
# Seeded shuffle so Restart & Run All samples the same conversation trees.
# A set, because several rows can belong to the same tree.
convo_ids = set(convos.shuffle(seed=42)[:2]['message_tree_id'])
# Keep rows with no rank plus the best-ranked answers (rank < .5, i.e. rank 0
# assuming integer ranks -- TODO confirm) so SFT only imitates top answers.
convos_small = convos.filter(
    lambda row: row['message_tree_id'] in convo_ids
    and (row['rank'] is None or row['rank'] < .5)
)

In [None]:
# Flatten the sampled conversations into a QA table, then wrap it as a HF Dataset.
df_train = create_qa_df(convos_small)
ds_train = Dataset.from_pandas(df=df_train)

### SFT + LoRA

In [None]:
def prompt_formatter(row: dict) -> str:
    """Render one QA row as a User/Assistant transcript for SFT.

    Args:
        row: mapping with 'prompt' and 'answer' string fields.

    Returns:
        The transcript text, starting with a newline and ending with a blank line.

    Every '#' in the prompt becomes '/#', so prompt text can never contain the
    '### User:' / '### Assistant:' markers. The answer is inserted verbatim.
    """
    user_text = row['prompt'].replace('#', '/#')
    assistant_text = row['answer']
    transcript_lines = [
        "",
        "### User:",
        user_text,
        "",
        "### Assistant:",
        assistant_text,
        "",
        "",
    ]
    return "\n".join(transcript_lines)

# Tiny debug run: two optimizer steps of SFT on packed 512-token sequences.
sft_cfg = SFTConfig(
    max_seq_length=512,
    output_dir=str(logging_dir / 'sft'),
    logging_steps=1,
    packing=True,
    weight_decay=0.01,
    report_to='none',
    max_steps=2,
    # os.cpu_count() may return None; fall back to in-process loading (0 workers).
    dataloader_num_workers=max((os.cpu_count() or 1) - 1, 0),
    gradient_accumulation_steps=2,
    learning_rate=1e-4,
    lr_scheduler_type='constant',
)

# LoRA adapters on all linear projections (PEFT's "all-linear" preset), plus fully
# trainable copies of the output head and the token embedding.
lora_cfg = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules="all-linear",
    # modules_to_save entries are matched against module-name suffixes. GPT-2 calls
    # its token embedding `transformer.wte`; the previous "embed_token" matched no
    # module, so the embedding was silently left frozen.
    # NOTE(review): GPT-2 ties lm_head to wte; PEFT trains a separate copy of each
    # listed module, so the two are no longer tied -- confirm this is intended.
    modules_to_save=["lm_head", "wte"],
    task_type="CAUSAL_LM",
)

trainer = SFTTrainer(
    model_str,
    train_dataset=ds_train,
    args=sft_cfg,
    formatting_func=prompt_formatter,
    peft_config=lora_cfg,
)
trainer.train()

# The DPO section trains `model`. Point it at the SFT result (LoRA merged into the
# weights) rather than the untouched pretrained checkpoint, so DPO really runs on
# top of SFT. PEFT marks only adapter/modules_to_save parameters as trainable, and
# merging does not undo that, so re-enable gradients for full-parameter DPO.
model = trainer.model.merge_and_unload().requires_grad_(True)

### DPO
For DPO we first build a preference dataset from a fresh sample of conversation trees.

In [None]:
# Seeded shuffle so re-runs sample the same conversation trees.
convo_ids = set(convos.shuffle(seed=42)[:2]['message_tree_id'])
# Unlike the SFT subset, no rank filter here: presumably create_preference_df
# needs the lower-ranked replies to form rejected candidates -- verify.
convos_small = convos.filter(lambda row: row['message_tree_id'] in convo_ids)
pref_df = create_preference_df(convos_small)
ds_train = Dataset.from_pandas(pref_df)

In [None]:
# Tiny debug run: two optimizer steps of DPO on the preference dataset.
dpo_cfg = DPOConfig(
    output_dir=str(logging_dir / 'dpo'),
    logging_steps=1,
    weight_decay=0.01,
    report_to='none',
    max_steps=2,
    # os.cpu_count() may return None; fall back to in-process loading (0 workers).
    dataloader_num_workers=max((os.cpu_count() or 1) - 1, 0),
    gradient_accumulation_steps=2,
    # NOTE(review): 1e-4 is far above DPOConfig's default (1e-6) for full-parameter
    # DPO; tolerable only because this is a 2-step smoke test.
    learning_rate=1e-4,
    lr_scheduler_type='constant',
)
# No ref_model is passed, so DPOTrainer derives the frozen reference from `model`.
trainer = DPOTrainer(
    model=model,
    processing_class=tokenizer,
    args=dpo_cfg,
    train_dataset=ds_train,
)
trainer.train()