In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer
from peft import LoraConfig, get_peft_model
import json
import random
import torch
import numpy as np

# Seed every RNG the notebook touches so the random.sample() exemplar draw
# (and anything else stochastic) is reproducible under Restart & Run All.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

MODEL_NAME = "gpt2-large"
# Prefer the GPU, but still let the notebook run (slowly) on a CPU-only machine.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(DEVICE)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# GPT-2's tokenizer typically has no pad token; reuse EOS so padding-aware APIs work.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

In [None]:
# MMLU subjects used as evaluation targets. Every subject also serves as a
# few-shot source, so the source list starts as an independent copy of it.
subject_list = [
    "college_biology",
    "high_school_biology",
    "college_computer_science",
    "high_school_computer_science",
    "high_school_us_history",
    "computer_security",
    "machine_learning",
    "global_facts",
]
train_list = list(subject_list)  # separate list object so the two can diverge

In [None]:
def formatting_prompts_func(example, subject):
  """Render each multiple-choice item as question, lettered choices and gold answer.

  Args:
    example: mapping with parallel columns 'question' (str), 'choices'
      (sequence of str) and 'answer' (int index into that item's choices),
      e.g. an MMLU split returned by `load_dataset`.
    subject: unused; kept so all formatters share the same signature.

  Returns:
    A list with one string per item, shaped like
    "<question>\n(A) ... (B) ... (C) ... (D) ...\nAnswer: <letter>".
  """
  # Read each column once. On a `datasets.Dataset`, `example[col]` can rebuild
  # the entire column on every access, so indexing it per item is quadratic.
  questions, choices, answers = example['question'], example['choices'], example['answer']
  output_texts = []
  for question, options, answer in zip(questions, choices, answers):
    # Letter every option (A, B, C, ...) instead of assuming exactly four.
    options_text = " ".join(f"({chr(65 + j)}) {option}" for j, option in enumerate(options))
    output_texts.append(f"{question}\n{options_text}\nAnswer: {chr(65 + answer)}")
  return output_texts

def formatting_prompts_func_without_answers(example, subject):
  """Render each multiple-choice item as an open prompt ending in "Answer:".

  Args:
    example: mapping with parallel columns 'question' (str) and 'choices'
      (sequence of str), e.g. an MMLU split returned by `load_dataset`.
      The 'answer' column is not read.
    subject: unused; kept so all formatters share the same signature.

  Returns:
    A list with one string per item, shaped like
    "<question>\n(A) ... (B) ... (C) ... (D) ...\nAnswer:", ready for the
    model to continue with the answer letter.
  """
  # Read each column once. On a `datasets.Dataset`, `example[col]` can rebuild
  # the entire column on every access, so indexing it per item is quadratic.
  questions, choices = example['question'], example['choices']
  output_texts = []
  for question, options in zip(questions, choices):
    # Letter every option (A, B, C, ...) instead of assuming exactly four.
    options_text = " ".join(f"({chr(65 + j)}) {option}" for j, option in enumerate(options))
    output_texts.append(f"{question}\n{options_text}\nAnswer:")
  return output_texts

# used to generate few-shot prompts
def formatting_few_shots_prompts_func(few_shot_exemplars, test_dataset, test_subject):
  subject = test_subject.replace("_", " ")
  few_shot_exemplars_prompt = f"The following are multiple choice questions (with answers) about {subject}.\n\n" + '\n\n'.join(few_shot_exemplars)
  return [few_shot_exemplars_prompt + "\n\n" + prompt for prompt in test_dataset]


In [None]:
def extract_answer(text):
    """Return the text after the colon on the last line that starts with "Answer".

    Args:
        text: a formatted prompt/completion, possibly containing several
            answered exemplars. Only the last "Answer..." line counts.

    Returns:
        The stripped answer (e.g. "B"). Returns "" if that line has nothing
        after the colon (e.g. the model emitted a newline), and None if no
        line starts with "Answer".
    """
    for line in reversed(text.split('\n')):
        if line.startswith("Answer"):
            # partition() tolerates "Answer:B" and a bare "Answer:", both of
            # which made split(": ")[1] raise IndexError.
            return line.partition(":")[2].strip()
    return None

In [None]:
MAX_SEQ_LEN = 1024  # GPT-2-large context length; longer prompts cannot be fed to the model
NUM_SHOTS = 5       # default is 5-shot


def _encode_first_fitting(prompts_most_to_fewest, max_len=MAX_SEQ_LEN):
  """Tokenize the first prompt (most exemplars first) that fits in `max_len` tokens.

  Args:
    prompts_most_to_fewest: prompts for one question, ordered from NUM_SHOTS
      exemplars down to zero exemplars (the last entry).
    max_len: maximum number of input tokens.

  Returns:
    (input_ids, n_shots): a (1, seq_len) tensor of token ids and the number of
    exemplars kept. If even the zero-shot prompt is too long, its leading
    tokens are dropped so the trailing "Answer:" cue survives.
  """
  most_shots = len(prompts_most_to_fewest) - 1
  for offset, prompt in enumerate(prompts_most_to_fewest):
    n_shots = most_shots - offset
    if offset > 0:
      print("trying", n_shots, "shot prompt")
    input_ids = tokenizer.encode(prompt, return_tensors="pt")
    if input_ids.shape[1] <= max_len:
      return input_ids, n_shots
  print("zero-shot prompt still too long; keeping its last", max_len, "tokens")
  return input_ids[:, -max_len:], 0


for train_subject in train_list: # for each source task
  res = {}
  train_dataset = load_dataset("cais/mmlu", train_subject, split="test")
  train_dataset_formatted = formatting_prompts_func(train_dataset, train_subject)
  # Sample indices rather than strings so the exemplars can be excluded from
  # scoring when the source subject is also the target subject.
  exemplar_indices = random.sample(range(len(train_dataset_formatted)), NUM_SHOTS)
  few_shot_exemplars = [train_dataset_formatted[j] for j in exemplar_indices]
  for target_task in subject_list:
    print("\ncurrent source task is", train_subject, "target task is", target_task)
    test_dataset = load_dataset("cais/mmlu", target_task, split="test")
    test_dataset_formatted = formatting_prompts_func_without_answers(test_dataset, target_task)
    gold_answers = [chr(65 + answer) for answer in test_dataset["answer"]]

    # prompts_by_shots[0] uses NUM_SHOTS exemplars, prompts_by_shots[-1] uses none;
    # fewer exemplars are a fallback when a prompt exceeds MAX_SEQ_LEN tokens.
    prompts_by_shots = [
        formatting_few_shots_prompts_func(few_shot_exemplars[:n], test_dataset_formatted, target_task)
        for n in range(NUM_SHOTS, -1, -1)
    ]
    print(f"{NUM_SHOTS}-shot prompt example:", prompts_by_shots[0][0])

    # Exemplars drawn from this very split carry their answers in the prompt;
    # scoring them would leak the label, so skip them.
    leaked = set(exemplar_indices) if target_task == train_subject else set()

    correct, total = 0, 0
    for i in range(len(test_dataset_formatted)):
      if i in leaked:
        continue
      input_ids, _ = _encode_first_fitting([prompts[i] for prompts in prompts_by_shots])
      input_ids = input_ids.to(model.device)
      outputs = model.generate(
          input_ids=input_ids,
          attention_mask=torch.ones_like(input_ids),
          max_new_tokens=1,
          do_sample=False,  # greedy; temperature/top_k only take effect with do_sample=True
          pad_token_id=tokenizer.eos_token_id,
      )
      # Decode only the generated token. Re-parsing the whole decoded prompt
      # failed whenever that token was a newline.
      prediction = tokenizer.decode(outputs[0, input_ids.shape[1]:]).strip()
      print(prediction, gold_answers[i])
      if prediction == gold_answers[i]:
        correct += 1
      total += 1
    print("current few-shot test accuracy on target task =", correct / total)
    res[target_task] = correct / total
  with open("ICL_few_shots_" + train_subject + ".json", "w") as outfile:
    json.dump(res, outfile)