This notebook aims to use [SFTTrainer](https://huggingface.co/docs/trl/en/sft_trainer) from Hugging Face to train on custom data, which appears simpler than configuring a custom dataset for llama-recipes.

In [None]:
from datasets import load_dataset
from trl import SFTTrainer
from transformers import LlamaForCausalLM, LlamaTokenizer, BitsAndBytesConfig, AutoModelForCausalLM
import torch
import transformers
import torch.nn.functional as F

## Load in the Model

In [4]:
# Load Model
model_dir = "./models/llama"

model = AutoModelForCausalLM.from_pretrained(
    model_dir,
    # Quantize the model to 4-bit (necessary step)
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type='nf4',
    ),
)

# Load Tokenizer
tokenizer = LlamaTokenizer.from_pretrained(model_dir)

# Report the available device
# Generation is incredibly slow on CPU, so confirm this prints "cuda"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device is:", device)

# Create Pipeline
pipeline = transformers.pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    torch_dtype=torch.float16,
)

`low_cpu_mem_usage` was None, now set to True since model is quantized.


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

  return self.fget.__get__(instance, owner)()


Device is: cuda


## Evaluating on the MMLU Dataset

### Prepare the Questions and Answers

In [4]:
import pandas as pd
df = pd.read_csv('./data/test/mmlu/high_school_mathematics_test.csv')

# The CSV has no header row, so read_csv used the first question as the
# column names; append that row back to the end of the frame
df.loc[len(df)] = df.columns

# Rename the columns appropriately
df.columns = ['question', 'a', 'b', 'c', 'd', 'answer']
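An alternative worth considering is to tell pandas up front that the file has no header row, so the first question is never consumed as column names. The sketch below uses two made-up rows in place of the real MMLU file, and `load_mmlu_csv` is a hypothetical helper name. Note that this keeps the first question at index 0 instead of moving it to the end, so row order differs from the cell above.

```python
import io

import pandas as pd

# Column names used throughout this notebook.
COLUMNS = ["question", "a", "b", "c", "d", "answer"]

# Two made-up rows standing in for a headerless MMLU CSV (assumption).
SAMPLE_CSV = 'What is 2 + 2?,3,4,5,6,B\n"What is 3 * 3?",6,9,12,8,B\n'


def load_mmlu_csv(source):
    """Read a headerless MMLU-style CSV and name the columns in one step."""
    return pd.read_csv(source, header=None, names=COLUMNS)


sample_df = load_mmlu_csv(io.StringIO(SAMPLE_CSV))
print(sample_df)
```

Passing a file path instead of the `StringIO` object should work the same way, since `read_csv` accepts both.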

In [5]:
df.head()

| | question | a | b | c | d | answer |
|---|---|---|---|---|---|---|
| 0 | The length of a rectangle is twice its width. ... | 2500 | 2 | 50 | 25 | C |
| 1 | A positive integer n is called “powerful” if, ... | 392 | 336 | 300 | 297 | A |
| 2 | At breakfast, lunch, and dinner, Joe randomly ... | \frac{7}{9} | \frac{8}{9} | \frac{5}{9} | \frac{9}{11} | B |
| 3 | Suppose $f(x)$ is a function that has this pro... | (-inf, 10) | (-inf, 9) | (-inf, 8) | (-inf, 7) | C |
| 4 | John divided his souvenir hat pins into two pi... | 396 | 72 | 66 | 36 | B |


In [6]:
# first question in dataset
df['question'][0]

'The length of a rectangle is twice its width. Given the length of the diagonal is $5\\sqrt{5}$, find the area of the rectangle.'

In [7]:
def format_prompt(x):
    # Options are numbered 1-4 here, while the 'answer' column holds letters A-D
    return f"answer with one number from the question: {x['question']} 1: {x['a']} 2: {x['b']} 3: {x['c']} 4: {x['d']}"

df['question-new'] = df.apply(format_prompt, axis=1)
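Because the prompt numbers the options 1 to 4 while the `answer` column stores letters, any later comparison needs a mapping between the two. A minimal sketch of such a mapping follows; `LETTER_TO_NUMBER` and `gold_option_number` are hypothetical names, not part of the original notebook.

```python
# Map the dataset's answer letters onto the option numbers used in the prompt.
LETTER_TO_NUMBER = {"A": "1", "B": "2", "C": "3", "D": "4"}


def gold_option_number(answer_letter):
    """Return the prompt's option number ("1" to "4") for an answer letter."""
    return LETTER_TO_NUMBER[answer_letter.strip().upper()]


print(gold_option_number("C"))
```

With the frame above, something like `df['answer'].map(gold_option_number)` would give the gold answers in the same numbering the prompt uses.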

In [8]:
df.head()

| | question | a | b | c | d | answer | question-new |
|---|---|---|---|---|---|---|---|
| 0 | The length of a rectangle is twice its width. ... | 2500 | 2 | 50 | 25 | C | answer with one number from the question: The ... |
| 1 | A positive integer n is called “powerful” if, ... | 392 | 336 | 300 | 297 | A | answer with one number from the question: A po... |
| 2 | At breakfast, lunch, and dinner, Joe randomly ... | \frac{7}{9} | \frac{8}{9} | \frac{5}{9} | \frac{9}{11} | B | answer with one number from the question: At b... |
| 3 | Suppose $f(x)$ is a function that has this pro... | (-inf, 10) | (-inf, 9) | (-inf, 8) | (-inf, 7) | C | answer with one number from the question: Supp... |
| 4 | John divided his souvenir hat pins into two pi... | 396 | 72 | 66 | 36 | B | answer with one number from the question: John... |


### Save to JSON

In [13]:
# Create output json
import json
json_output = []

# iterate over data frame questions
for question, answer in zip(df['question-new'], df['answer']):
    json_output.append({'input': question, 'answer': str(answer)})
    
# Write output json
output_path = './data/test/test_mmlu.json'

with open(output_path, "w") as f:
    json.dump(json_output, f, indent=4)
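Before relying on the saved file, it may help to read it back and check that every record has the expected shape. The sketch below assumes the `input`/`answer` record layout written above and that answers are always one of A to D. `validate_records` is a hypothetical helper, and the demo writes a made-up record to a temporary file so the real `test_mmlu.json` is left untouched.

```python
import json
import os
import tempfile

VALID_ANSWERS = {"A", "B", "C", "D"}


def validate_records(path):
    """Load the JSON list at `path`, raising ValueError on a malformed record."""
    with open(path, encoding="utf-8") as f:
        records = json.load(f)
    for idx, record in enumerate(records):
        if set(record) != {"input", "answer"}:
            raise ValueError(f"record {idx} has unexpected keys: {sorted(record)}")
        if record["answer"] not in VALID_ANSWERS:
            raise ValueError(f"record {idx} has answer {record['answer']!r}")
    return records


# Round-trip a single made-up record through a throwaway file.
demo_records = [{"input": "made-up prompt 1: w 2: x 3: y 4: z", "answer": "B"}]
with tempfile.TemporaryDirectory() as tmp_dir:
    demo_path = os.path.join(tmp_dir, "demo.json")
    with open(demo_path, "w", encoding="utf-8") as f:
        json.dump(demo_records, f, indent=4)
    loaded = validate_records(demo_path)

print(len(loaded), "record(s) validated")
```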

### Evaluate the Model

In [49]:
prompts = df['question-new'].tolist()
answers = df['answer'].tolist()

In [50]:
print(f'Prompt 0: {prompts[0]}\nAnswer 0: {answers[0]}')

Prompt 0: answer with one number from the question: Let p = (1, 2, 5, 4)(2, 3) in S_5 . Find the index of <p> in S_5. 1: 8 2: 2 3: 24 4: 120
Answer 0: C


In [51]:
for i, (prompt, answer) in enumerate(zip(prompts[:5], answers[:5])):
    # Sample one completion per prompt
    sequences = pipeline(
        prompt,
        do_sample=True,
        top_k=10,
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id,
        max_length=500,
        truncation=True,
    )
    print(f'===== Iteration {i} =====\n')
    print(f'Question:\n{prompt}\n')
    print('Preparing Sequences ...\n')
    for s, seq in enumerate(sequences):
        print(f'Generated Answer {s}:')
        print(seq['generated_text'])

    print(f'\nReal Answer: {answer}\n\n')
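The loop above only prints generations next to the real answer; it does not score them. One rough way to score is sketched below, under two assumptions: `generated_text` begins with an echo of the prompt, and the first standalone digit from 1 to 4 in the continuation is the model's choice. That second assumption is a crude heuristic and will misread answers that mention other numbers first. `extract_choice` and `accuracy` are hypothetical helpers, and the demo strings are made up.

```python
import re


def extract_choice(generated_text, prompt):
    """Return the first standalone digit 1-4 after the echoed prompt, or None."""
    if generated_text.startswith(prompt):
        continuation = generated_text[len(prompt):]
    else:
        continuation = generated_text
    match = re.search(r"\b[1-4]\b", continuation)
    return match.group(0) if match else None


def accuracy(generations, prompts, answer_letters):
    """Fraction of generations whose extracted choice matches the gold letter."""
    if not prompts:
        return 0.0
    correct = 0
    for text, prompt, letter in zip(generations, prompts, answer_letters):
        choice = extract_choice(text, prompt)
        # Option numbers 1-4 correspond to answer letters A-D.
        if choice is not None and "ABCD"[int(choice) - 1] == letter:
            correct += 1
    return correct / len(prompts)


demo_prompt = "answer with one number from the question: What is 1 + 1? 1: 0 2: 1 3: 2 4: 3"
demo_generation = demo_prompt + " The answer is 3."
print(accuracy([demo_generation], [demo_prompt], ["C"]))
```

A stricter prompt format, or a shorter generation length, would likely make this kind of extraction more reliable than the free-form outputs shown here.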
    

Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


===== Iteration 0 =====

Question:
answer with one number from the question: Let p = (1, 2, 5, 4)(2, 3) in S_5 . Find the index of <p> in S_5. 1: 8 2: 2 3: 24 4: 120

Preparing Sequences ...

Generated Answer 0:
answer with one number from the question: Let p = (1, 2, 5, 4)(2, 3) in S_5 . Find the index of <p> in S_5. 1: 8 2: 2 3: 24 4: 120 5: 60
The given permutation p = (1, 2, 5, 4)(2, 3) in S_5 can be written as π = (1, 2, 4)(3, 5).
To find the index of π in S_5, we need to check which class π belongs to in the coset π + S_5.
π + S_5 = {π + (1, 2, 5, 4)(2, 3)} = {(1, 2, 4)(3, 5), (1, 2, 5, 4)(2, 3), (1, 5, 4)(2, 3), (5, 4)(2, 3), (2, 3, 4)(1, 5), (2, 3, 5, 4)(1, 2)} = 5
So, the index of π in S_5 is 5.
Therefore, the answer is (5).

Real Answer: C






===== Iteration 1 =====

Question:
answer with one number from the question: Find all zeros in the indicated finite field of the given polynomial with coefficients in that field. x^5 + 3x^3 + x^2 + 2x in Z_5 1: 0 2: 1 3: 0,1 4: 0,4

Preparing Sequences ...

Generated Answer 0:
answer with one number from the question: Find all zeros in the indicated finite field of the given polynomial with coefficients in that field. x^5 + 3x^3 + x^2 + 2x in Z_5 1: 0 2: 1 3: 0,1 4: 0,4 5: 2,0

The solution is:

1: x = 0

2: x = 1

3: x = 2

4: x = 3

5: x = 4

Note that the polynomial has no real zeros, only ones in the field.

Real Answer: D






===== Iteration 2 =====

Question:
answer with one number from the question: Statement 1 | A factor group of a non-Abelian group is non-Abelian. Statement 2 | If K is a normal subgroup of H and H is a normal subgroup of G, then K is a normal subgroup of G. 1: True, True 2: False, False 3: True, False 4: False, True

Preparing Sequences ...

Generated Answer 0:
answer with one number from the question: Statement 1 | A factor group of a non-Abelian group is non-Abelian. Statement 2 | If K is a normal subgroup of H and H is a normal subgroup of G, then K is a normal subgroup of G. 1: True, True 2: False, False 3: True, False 4: False, True 5: True, True
A factor group of a non-Abelian group is non-Abelian. If K is a normal subgr...
Let G be a group and let K be a subgrou...
Let G be a group and let K be a subgroup of G. Prove or disprove the following statements:
1. Statement 1 | A factor group of a non-Abelian group is non-Abelian.
2. Statement 2 | If K is a normal subgroup of H and H is



===== Iteration 3 =====

Question:
answer with one number from the question: Find the product of the given polynomials in the given polynomial ring. f(x) = 4x - 5, g(x) = 2x^2 - 4x + 2 in Z_8[x]. 1: 2x^2 + 5 2: 6x^2 + 4x + 6 3: 0 4: x^2 + 1

Preparing Sequences ...

Generated Answer 0:
answer with one number from the question: Find the product of the given polynomials in the given polynomial ring. f(x) = 4x - 5, g(x) = 2x^2 - 4x + 2 in Z_8[x]. 1: 2x^2 + 5 2: 6x^2 + 4x + 6 3: 0 4: x^2 + 10x + 10 5: 2x - 3x^2 + 11x + 12 6: 3x^2 + 10x + 9 7: -3x^2 - 12x + 15 8: 4x^2 - 13x + 7 9: -6x^2 - 2x + 8 10: -2x + 5x^2 + 10x - 11 Thank you for your answer!

Real Answer: B


===== Iteration 4 =====

Question:
answer with one number from the question: Statement 1 | If a group has an element of order 15 it must have at least 8 elements of order 15. Statement 2 | If a group has more than 8 elements of order 15, it must have at least 16 elements of order 15. 1: True, True 2: False, False 3: True, False 4