In [None]:
!pip install datasets trl peft evaluate

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer, DataCollatorForLanguageModeling
from datasets import load_dataset, load_metric
from trl import SFTConfig, SFTTrainer, DataCollatorForCompletionOnlyLM
from peft import LoraConfig, get_peft_model
import evaluate
import numpy as np
import torch

In [None]:
# MMLU subject configs used throughout (passed as the config name to load_dataset("cais/mmlu", ...)).
subject_list = """
    college_biology
    high_school_biology
    college_computer_science
    high_school_computer_science
    us_foreign_policy
    computer_security
    machine_learning
    global_facts
""".split()

In [None]:
# Stage-1 data: one subject supplies both the training rows (its test split) and the
# per-epoch validation rows (its validation split).
# NOTE(review): the evaluation cell later scores every subject's test split, including this
# one, so the train_subject result there is measured on the data it was trained on.
index = 0
train_subject = eval_subject = subject_list[index]

train_dataset = load_dataset("cais/mmlu", train_subject, split="test")
eval_dataset = load_dataset("cais/mmlu", eval_subject, split="validation")

In [None]:
# Prompt header shared by the training and evaluation formatters.
_PROMPT_HEADER = "The following is a multiple choice questions (with answers).\n"


def _format_question(question, choices):
  """Render one question with lettered options, ending with the "Answer:" cue.

  Options are labelled (A), (B), ... in order and separated by single spaces,
  e.g. "(A) x (B) y (C) z (D) w" for four choices.
  """
  options = " ".join(f"({chr(65 + j)}) {choice}" for j, choice in enumerate(choices))
  return f"{_PROMPT_HEADER}{question}\n{options}\nAnswer:"


def formatting_prompts_func(example):
  """Format MMLU rows as prompts that end with the gold answer letter.

  Args:
    example: mapping with "question", "choices" and "answer" columns (a batch dict
      or a whole Dataset). Each column is read once.

  Returns:
    List of strings, one per row, ending in "Answer:<letter>" (answer 0 -> "A").
  """
  return [
      _format_question(question, choices) + chr(65 + answer)
      for question, choices, answer in zip(example['question'], example['choices'], example['answer'])
  ]


def formatting_prompts_func_without_answers(example):
  """Same prompts as formatting_prompts_func, but stopping right after "Answer:"."""
  return [
      _format_question(question, choices)
      for question, choices in zip(example['question'], example['choices'])
  ]

In [None]:
# LoRA adapter configuration for stage-1 fine-tuning.
lora_config = LoraConfig(
    task_type="CAUSAL_LM",  # wrap as a causal-LM PEFT model rather than the generic PeftModel
    init_lora_weights="gaussian",
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
)

# load model and freeze weights
model = AutoModelForCausalLM.from_pretrained("gpt2-large")
for param in model.parameters():
    param.requires_grad = False

# Reuse EOS as the pad token so batches can be padded.
tokenizer = AutoTokenizer.from_pretrained("gpt2-large")
tokenizer.pad_token = tokenizer.eos_token

# Seed before wrapping: the LoRA weights are randomly initialised here, and the Trainer's
# own seed (SFTConfig.seed) is presumably applied only when the trainer is built later,
# so without this each run would start from different adapter weights.
torch.manual_seed(42)
model = get_peft_model(model, lora_config)

In [None]:
# Accuracy metric used by compute_metrics. `datasets.load_metric` is deprecated (removed in
# datasets 3.x); `evaluate.load` from the already-imported `evaluate` library replaces it.
metric = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=2)

    labels_index = np.where(labels != -100)[1]
    pred_results = predictions[np.arange(predictions.shape[0]), labels_index-1]
    label_results = labels[np.arange(labels.shape[0]), labels_index]
    print(pred_results, label_results)

    return metric.compute(predictions=pred_results, references=label_results)

In [None]:
# Train only on the completion after "Answer:" (the answer letter); the collator presumably
# masks everything up to the response template with -100, which compute_metrics relies on.
response_template = "Answer:"
collator = DataCollatorForCompletionOnlyLM(response_template=response_template, tokenizer=tokenizer)

training_args = SFTConfig(
    # `eval_strategy` is the current name of the deprecated `evaluation_strategy` argument.
    eval_strategy = "epoch",
    save_strategy = "epoch",
    load_best_model_at_end=True,
    num_train_epochs=10.0,
    learning_rate=1e-4,
    weight_decay=0.05,
    logging_steps=10,
    metric_for_best_model = 'accuracy',  # key returned by compute_metrics
    output_dir="./output",
    label_names = ["labels"],
)

trainer = SFTTrainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    args=training_args,
    tokenizer=tokenizer,
    formatting_func=formatting_prompts_func,
    data_collator=collator,
    compute_metrics=compute_metrics,
)

In [None]:
# Fine-tune the LoRA adapter; evaluation and checkpointing run once per epoch (see training_args).
trainer.train()

In [None]:
# Score eval_dataset with compute_metrics. Because training used load_best_model_at_end=True,
# this reports the checkpoint selected by the 'accuracy' metric.
trainer.evaluate()

In [None]:
# Save the trainer's current (best-by-accuracy) model.
# NOTE(review): "10_epochs" is hard-coded; keep it in sync with num_train_epochs in training_args.
trainer.save_model("gpt2-large_sft_" + train_subject + "_10_epochs")

In [None]:
# Save the PEFT-wrapped model and its tokenizer; the evaluation cells reload from this directory.
# NOTE(review): for a PEFT model this presumably writes only the adapter weights/config — confirm.
model.save_pretrained("best_model_" + train_subject)
tokenizer.save_pretrained("best_model_" + train_subject)

**Evaluation**

In [None]:
def extract_answer(text):
    """Return the text after the last line that starts with "Answer:".

    Surrounding whitespace is stripped, so a generated " B" compares equal to the
    reference "B". Everything after the "Answer:" prefix is kept, including any
    further colons.

    Returns:
        The stripped answer string, or None when no line starts with "Answer:".
    """
    for line in reversed(text.split("\n")):
        if line.startswith("Answer:"):
            return line[len("Answer:"):].strip()
    return None

In [None]:
# Formatted copies of the stage-1 training split: with and without the answer letter.
# NOTE(review): neither variable is read by any later cell; candidates for removal.
train_dataset_formatted = formatting_prompts_func(train_dataset)
train_dataset_eval = formatting_prompts_func_without_answers(train_dataset)

In [None]:
# After SFT: evaluate the model on target tasks
def _evaluate_subject(model, tok, subject):
    """Return the test-split accuracy of `model` on MMLU `subject`.

    Each question is prompted without its answer, one token is generated greedily,
    and extract_answer compares the text after the last "Answer:" line with the
    reference letter. Prints one "pred / ground truth" line per question.
    """
    test_dataset = load_dataset("cais/mmlu", subject, split="test")
    prompts = formatting_prompts_func_without_answers(test_dataset)
    references = formatting_prompts_func(test_dataset)
    correct = 0
    for prompt, reference in zip(prompts, references):
        encoded = tok(prompt, return_tensors="pt").to("cuda")
        # Explicit greedy decoding: temperature/top_k have no effect unless sampling is
        # enabled, which nothing here does (assumes the loaded generation config keeps
        # do_sample off — confirm), so scores are unchanged. Passing the attention mask
        # and pad id should also silence generate()'s missing-mask / pad-token warnings.
        outputs = model.generate(
            **encoded,
            do_sample=False,
            max_new_tokens=1,
            pad_token_id=tok.eos_token_id,
        )
        pred = extract_answer(tok.decode(outputs[0]))
        truth = extract_answer(reference)
        print("pred: ", pred, "ground truth: ", truth)
        if pred == truth:
            correct += 1
    return correct / len(prompts)


result = {}
fine_tuned_model = AutoModelForCausalLM.from_pretrained("best_model_" + train_subject).to("cuda")
tokenizer = AutoTokenizer.from_pretrained("best_model_" + train_subject)
# NOTE(review): train_subject was trained on its own test split, so its entry in `result`
# is measured on training data rather than held-out questions.
for target_task in subject_list:
    accuracy = _evaluate_subject(fine_tuned_model, tokenizer, target_task)
    print("current SFT test accuracy on target task =", target_task, "is", accuracy)
    result[target_task] = accuracy

In [None]:
import json

# Persist the per-subject stage-1 test accuracies.
# NOTE(review): "10_epochs" in the file name is hard-coded; keep it in sync with num_train_epochs.
results_path = "SFT_10_epochs_" + train_subject + "_new.json"
with open(results_path, "w") as results_file:
    results_file.write(json.dumps(result))

**Second setting: further SFT on each other subject's validation split, then evaluation on its test split**


In [None]:
# Second setting: for every subject other than train_subject, start from the stage-1
# checkpoint, fine-tune a new LoRA adapter on that subject's validation split (selecting the
# best epoch on its dev split), then score it on that subject's test split.
import gc

result_sft = {}
for target_task in subject_list:
  if target_task != train_subject:
    lora_config = LoraConfig(
        init_lora_weights="gaussian",
        r=8,
        lora_alpha=8,
        lora_dropout=0.05,
    )

    # load model and freeze weights
    # NOTE(review): this directory holds the stage-1 PEFT output; confirm that wrapping the
    # loaded model with get_peft_model stacks a second adapter as intended.
    model = AutoModelForCausalLM.from_pretrained("best_model_" + train_subject).to("cuda")
    for param in model.parameters():
        param.requires_grad = False

    tokenizer = AutoTokenizer.from_pretrained("best_model_" + train_subject)
    tokenizer.pad_token = tokenizer.eos_token
    model = get_peft_model(model, lora_config)

    # further fine-tuning on the validation set of the target task
    train_dataset = load_dataset("cais/mmlu", target_task, split="validation")
    eval_dataset = load_dataset("cais/mmlu", target_task, split="dev")

    response_template = "Answer:"
    collator = DataCollatorForCompletionOnlyLM(response_template=response_template, tokenizer=tokenizer)

    training_args = SFTConfig(
      # `eval_strategy` is the current name of the deprecated `evaluation_strategy` argument.
      eval_strategy = "epoch",
      save_strategy = "epoch",
      load_best_model_at_end=True,
      num_train_epochs=5,
      learning_rate=5e-4,
      weight_decay=0.05,
      logging_steps=1,
      metric_for_best_model = 'accuracy',
      # One checkpoint directory per target task so runs don't overwrite each other's checkpoints.
      output_dir="./output/further_" + target_task,
      label_names = ["labels"],
    )

    trainer = SFTTrainer(
        model=model,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        args=training_args,
        tokenizer=tokenizer,
        formatting_func=formatting_prompts_func,
        data_collator=collator,
        compute_metrics=compute_metrics,
    )

    trainer.train()
    # Previously the "_" before "5_epochs" was missing from this directory name.
    further_dir = "gpt2-large_sft_" + train_subject + "_10_epochs_" + target_task + "_5_epochs"
    trainer.save_model(further_dir)

    # Release the training copy before loading the evaluation copy, so GPU memory does not
    # accumulate across target tasks.
    del trainer, model
    gc.collect()
    torch.cuda.empty_cache()

    fine_tuned_model = AutoModelForCausalLM.from_pretrained(further_dir).to("cuda")
    tokenizer = AutoTokenizer.from_pretrained(further_dir)
    correct, total = 0, 0
    test_dataset = load_dataset("cais/mmlu", target_task, split="test")
    test_dataset_formatted, test_dataset_ground_truths = formatting_prompts_func_without_answers(test_dataset), formatting_prompts_func(test_dataset)
    for prompt, reference in zip(test_dataset_formatted, test_dataset_ground_truths):
      encoded = tokenizer(prompt, return_tensors="pt").to("cuda")
      # Explicit greedy decoding; the former temperature/top_k had no effect without sampling.
      outputs = fine_tuned_model.generate(
          **encoded,
          do_sample=False,
          max_new_tokens=1,
          pad_token_id=tokenizer.eos_token_id,
      )
      pred = extract_answer(tokenizer.decode(outputs[0]))
      truth = extract_answer(reference)
      print("pred: ", pred, "ground truth: ", truth)
      if pred == truth:
        correct += 1
      total += 1

    print("current further SFT test accuracy on target task =", target_task, "is", correct / total)
    result_sft[target_task] = correct / total

    del fine_tuned_model
    gc.collect()
    torch.cuda.empty_cache()

In [None]:
import json

# Persist the per-subject test accuracies from the second (further-SFT) setting.
further_results_path = "SFT_10_epochs_" + train_subject + "_new_further.json"
with open(further_results_path, "w") as results_file:
    results_file.write(json.dumps(result_sft))