# Part 0A (prequel) - MMLU benchmark and MMLU shuffle influence

Please skip to Part 1 notebook to see interesting clean stuff.  
This notebook is only here for its archeological interest.  

This is the prequel notebook, where we investigate how MMLU inference works, what is the effect of applying a shuffle to answer choices.  
It contains badly optimized code (we run batches of N permutations, the batch count is tied to the permutation count, and the translation logic is also tied to the inference function). We also made quick and dirty code to shuffle a MMLU dataset.  
In this first notebook, we initially tried to make a dataset split for each category and each letter_order, but that plan was a bad idea as handling 1300+ files is really slow on the hub.  



## install libs

In [21]:
!pip install matplotlib #if needed
!pip install seaborn

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Collecting seaborn
  Downloading seaborn-0.13.2-py3-none-any.whl.metadata (5.4 kB)
Downloading seaborn-0.13.2-py3-none-any.whl (294 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m294.9/294.9 kB[0m [31m3.0 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hInstalling collected packages: seaborn
Successfully installed seaborn-0.13.2
[0m

In [1]:
# Run once

import torch
major_version, minor_version = torch.cuda.get_device_capability()

!pip install transformers datasets
!pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"
!pip install -U --no-deps xformers==0.0.25.post1 trl peft accelerate bitsandbytes

[0mCollecting unsloth[colab-new]@ git+https://github.com/unslothai/unsloth.git
  Cloning https://github.com/unslothai/unsloth.git to /tmp/pip-install-se9r657q/unsloth_2bf66dac545d41c5b850780709e4915f
  Running command git clone --filter=blob:none --quiet https://github.com/unslothai/unsloth.git /tmp/pip-install-se9r657q/unsloth_2bf66dac545d41c5b850780709e4915f
  Resolved https://github.com/unslothai/unsloth.git to commit 27fa021a7bb959a53667dd4e7cdb9598c207aa0d
  Installing build dependencies ... [?25ldone
[?25h  Getting requirements to build wheel ... [?25ldone
[?25h  Installing backend dependencies ... [?25ldone
[?25h  Preparing metadata (pyproject.toml) ... [?25ldone
[0m

In [2]:
# !pip install -U xformers==0.0.25.post1

## load MMLU datasets (all configs)

In [9]:
from datasets import load_dataset, get_dataset_config_names
from tqdm.auto import tqdm

configs = get_dataset_config_names("tasksource/mmlu")

print('here are some categories : ')
print(configs[:5])
print('etc...')

# load all categories into a list
dataset_list = []
for curr_cat in tqdm(configs, desc='loading categories...'):
    dataset_list.append(load_dataset("tasksource/mmlu", curr_cat, split='test'))

here are some categories : 
['abstract_algebra', 'anatomy', 'astronomy', 'business_ethics', 'clinical_knowledge']
etc...


loading categories...:   0%|          | 0/57 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/19.7k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/4.86k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/3.08k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/135 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/14 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

Downloading data:   0%|          | 0.00/27.9k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/5.62k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/4.51k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/152 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/16 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

Downloading data:   0%|          | 0.00/21.1k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/4.62k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/4.48k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/100 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/11 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

Downloading data:   0%|          | 0.00/40.0k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.98k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/3.17k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/265 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/29 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

Downloading data:   0%|          | 0.00/31.4k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.42k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/3.80k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/144 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/16 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

Downloading data:   0%|          | 0.00/17.4k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/4.38k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/3.55k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/100 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/8 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

Downloading data:   0%|          | 0.00/27.5k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/5.71k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.27k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/100 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/11 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

Downloading data:   0%|          | 0.00/16.1k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/4.50k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/4.66k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/100 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/11 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

Downloading data:   0%|          | 0.00/42.0k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/8.51k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/4.36k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/173 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/22 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

Downloading data:   0%|          | 0.00/18.2k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/5.91k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/4.03k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/102 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/11 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

Downloading data:   0%|          | 0.00/18.7k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.18k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/3.84k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/100 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/11 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

Downloading data:   0%|          | 0.00/24.5k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/5.48k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/3.46k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/235 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/26 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

Downloading data:   0%|          | 0.00/24.0k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.56k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/4.09k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/114 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/12 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

Downloading data:   0%|          | 0.00/17.1k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/4.55k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/3.55k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/145 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/16 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

Downloading data:   0%|          | 0.00/40.5k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/8.86k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/4.02k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/378 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/41 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

Downloading data:   0%|          | 0.00/21.0k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.11k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/4.35k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/126 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/14 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

Downloading data:   0%|          | 0.00/11.0k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/3.74k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/3.12k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/100 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/10 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

Downloading data:   0%|          | 0.00/32.8k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/7.79k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/3.64k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/203 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/22 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

Downloading data:   0%|          | 0.00/26.7k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/4.71k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/5.97k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/100 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/9 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

Downloading data:   0%|          | 0.00/142k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/31.0k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/21.6k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/165 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/18 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

Downloading data:   0%|          | 0.00/27.6k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/5.65k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/3.41k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/198 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/22 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

Downloading data:   0%|          | 0.00/39.6k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/7.65k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/3.86k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/193 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/21 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

Downloading data:   0%|          | 0.00/54.3k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/9.34k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/3.49k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/390 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/43 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

Downloading data:   0%|          | 0.00/33.2k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.46k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/3.97k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/270 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/29 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

Downloading data:   0%|          | 0.00/38.3k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.67k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/3.28k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/238 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/26 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

Downloading data:   0%|          | 0.00/32.4k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/7.46k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/4.07k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/151 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/17 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

Downloading data:   0%|          | 0.00/92.2k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/14.7k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/4.66k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/545 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/60 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

Downloading data:   0%|          | 0.00/57.5k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/10.3k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/5.55k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/216 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/23 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

Downloading data:   0%|          | 0.00/154k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/26.8k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/17.3k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/204 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/22 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

Downloading data:   0%|          | 0.00/201k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/37.9k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/9.70k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/237 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/26 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

Downloading data:   0%|          | 0.00/30.8k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/5.83k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/3.22k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/223 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/23 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

Downloading data:   0%|          | 0.00/22.7k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/4.79k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/3.61k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/131 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/12 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

Downloading data:   0%|          | 0.00/29.0k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.63k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/4.47k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/121 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/13 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

Downloading data:   0%|          | 0.00/22.9k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/5.75k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/3.58k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/108 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/11 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

Downloading data:   0%|          | 0.00/22.5k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.03k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/3.63k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/163 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/18 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

Downloading data:   0%|          | 0.00/19.2k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/5.69k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/4.76k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/112 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/11 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

Downloading data:   0%|          | 0.00/14.3k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/4.06k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/3.17k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/103 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/11 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

Downloading data:   0%|          | 0.00/36.8k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/7.77k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/3.85k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/234 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/25 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

Downloading data:   0%|          | 0.00/15.9k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/5.15k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/3.28k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/100 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/11 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

Downloading data:   0%|          | 0.00/98.1k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/12.7k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/2.91k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/783 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/86 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

Downloading data:   0%|          | 0.00/60.4k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/10.3k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/3.94k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/346 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/38 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

Downloading data:   0%|          | 0.00/89.3k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/14.4k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/4.66k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/895 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/100 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

Downloading data:   0%|          | 0.00/54.6k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/8.59k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/4.55k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/306 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/33 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

Downloading data:   0%|          | 0.00/48.1k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/8.71k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/3.76k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/311 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/34 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

Downloading data:   0%|          | 0.00/53.9k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/9.45k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/4.17k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/324 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/35 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

Downloading data:   0%|          | 0.00/69.0k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/12.4k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/4.36k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/282 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/31 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

Downloading data:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/115k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/14.7k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/1534 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/170 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

Downloading data:   0%|          | 0.00/124k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/19.3k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/7.93k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/272 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/31 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

Downloading data:   0%|          | 0.00/132k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/21.6k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/4.16k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/612 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/69 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

Downloading data:   0%|          | 0.00/20.1k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/5.96k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/3.95k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/110 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/12 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

Downloading data:   0%|          | 0.00/114k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/18.2k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/7.00k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/245 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/27 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

Downloading data:   0%|          | 0.00/43.5k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/7.92k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/3.77k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/201 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/22 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

Downloading data:   0%|          | 0.00/19.0k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/4.78k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/3.73k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/100 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/11 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

Downloading data:   0%|          | 0.00/26.9k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.62k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/3.44k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/166 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/18 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

Downloading data:   0%|          | 0.00/18.4k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/4.46k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/2.83k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/171 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/19 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

In [10]:
dataset_list[0]

Dataset({
    features: ['question', 'choices', 'answer'],
    num_rows: 100
})

## load model

In [11]:
from unsloth import FastLanguageModel
import torch
max_seq_length = 2048 # Choose any! We auto support RoPE Scaling internally!
dtype = None # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
load_in_4bit = True # Use 4bit quantization to reduce memory usage. Can be False.

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/llama-3-8b-bnb-4bit",
    max_seq_length = max_seq_length,
    dtype = dtype,
    load_in_4bit = load_in_4bit,
    # token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf
)

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
==((====))==  Unsloth: Fast Llama patching release 2024.5
   \\   /|    GPU: NVIDIA GeForce RTX 3090. Max memory: 23.999 GB. Platform = Linux.
O^O/ \_/ \    Pytorch: 2.2.2. CUDA = 8.6. CUDA Toolkit = 12.1.
\        /    Bfloat16 = TRUE. Xformers = 0.0.25.post1. FA = False.
 "-____-"     Free Apache license: http://github.com/unslothai/unsloth


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [1]:
def format_choice(example_choices):

  letters = ["A", "B", "C", "D"]

  lines = []
  for letter, choice in zip(letters, example_choices):
    lines.append(f"  {letter}) {choice}  ")

  return "\n".join(lines) + '\n'

### random choice order inference  

We have to:  
1. shuffle the choices  
2. adapt the tokens to match the new choice order  
3. decode the result to match the static dataset (no matter the choice order, if answer 0 is the correct answer, then we must always return 0)  

#### generate letter orders

In [2]:
from itertools import permutations
import random

# Generate all possible 4-letter words with the letters ABCD
letters = ['A', 'B', 'C', 'D']
all_words = [''.join(word) for word in permutations(letters, 4)]

# Function to select N unique elements at random from the list
def select_random_elements(N):
    return random.sample(all_words, N)

# Example usage
N = 5  # Change N to the desired number of elements
selected_words = select_random_elements(N)

In [14]:
selected_words

['ACBD', 'BCDA', 'CDBA', 'BDAC', 'CBDA']

#### reorder choices based on new letter order

In [3]:
def convert_letter_to_index(letter):
  letter_to_index_dict = {
      'A' : 0,
      'B' : 1,
      'C' : 2,
      'D' : 3
  }
  return letter_to_index_dict[letter]

def reorder_choices_based_on_letter_order(choices, letter_order):
  new_choices = []
  for letter in letter_order:
    new_choices.append(choices[convert_letter_to_index(letter)])

  return new_choices

#### make a new token string based on new letter order

In [4]:
def generate_answer_token_string(letter_order):
  return f" {letter_order[0]} {letter_order[1]} {letter_order[2]} {letter_order[3]}"

In [5]:
generate_answer_token_string(selected_words[0])

' D B C A'

#### translate result based on letter order  

We have to map the result from the shuffled choice to the static result from the dataset.

In [6]:
def translate_answer_based_on_letter_order(answer, letter_order):
  default_letter_order = list('ABCD')
  curr_letter_order = list(letter_order)
  conversion_dict = dict(zip(default_letter_order, curr_letter_order))
  return conversion_dict[answer]

translate_answer_based_on_letter_order("A", "DCBA")

'D'

In [11]:
def most_common(lst):
    return max(set(lst), key=lst.count)


## Batch inference (speeeed)

Batched inference is much faster, and we can compute all choices in parallel.  
Before applying the complete benchmark, we can use this boost.  

In [20]:
from itertools import permutations
import random

mmlu_prompt = """
Answer the following multiple choice question.
The last line of your response should be of the following format: 'The answer letter is : $LETTER' (without quotes) where LETTER is one of A B C D.
Think step by step before answering.

### Question:
{}

### Choices:
{}

### Anwser:
Given the choices A B C D , the answer is : {}"""

def get_number_result_from_question_parallel(question, choices, n_infer=24):

    # Generate all possible 4-letter words with the letters ABCD
    # these are our letter orders to reorder the choices
    letters = ['A', 'B', 'C', 'D']
    letter_orders = [''.join(word) for word in permutations(letters, 4)]

    # case of n_infer being a sample of all possible choice orders
    if not (n_infer == 24):
        if n_infer == 1: #special case of single inference, we preserve the default order
            letter_orders = ['ABCD']
        else: # partial orders: we get a random sample
            letter_orders = random.sample(letter_orders, n_infer)

    # duplicate questions and choices for each letter order
    questions = [question] * len(letter_orders)
    choices_list = [choices] * len(letter_orders)

    prompts = []
    answer_tokens_list = []
    for question, choices, letter_order in zip(questions, choices_list, letter_orders):
        prompt = mmlu_prompt.format(
            question,
            format_choice(reorder_choices_based_on_letter_order(choices, letter_order)),
            ""  # output - leave blank for model answer
        )
        prompts.append(prompt)

        answer_tokens = tokenizer.encode(
            generate_answer_token_string(letter_order), add_special_tokens=False, return_tensors="pt"
        ).to("cuda")
        answer_tokens_list.append(answer_tokens)

    inputs = tokenizer(prompts, return_tensors="pt", padding=True).to("cuda")

    logits_list = []
    with torch.no_grad():
        logits = model(inputs.input_ids, attention_mask=inputs.attention_mask).logits
        torch.cuda.empty_cache()
        for i, answer_tokens in enumerate(answer_tokens_list):
            logits_ans = logits[i, -1, answer_tokens].cpu()
            logits_list.append(logits_ans)

    results = []
    local_answers = []
    for logits_ans, letter_order in zip(logits_list, letter_orders):
        prob_ans = torch.softmax(logits_ans, dim=-1)
        inferred_answer = prob_ans.argmax(dim=-1)
        local_answers.append(letter_order[inferred_answer])
        result = translate_answer_based_on_letter_order(letter_order[inferred_answer], letter_order)
        result = convert_letter_to_index(result)
        results.append(result)

    # for letter_order, prompt, local_answer in zip(letter_orders, prompts, local_answers):
    #   print('====================================================')
    #   print(letter_order)
    #   print(prompt)
    #   print(f'local answer : {local_answer}')
    #   print(f'translated answer : {translate_answer_based_on_letter_order(local_answer, letter_order)}')
    #   print('====================================================')
    return most_common(results)

# # Example usage
test_question = "What is the capital of France?"

test_choices = ["Monaco", "Barcelona", "Moscow", "Paris"]

get_number_result_from_question_parallel(test_question, test_choices)

3

In [27]:
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from tqdm.auto import tqdm

tqdm.pandas()

def perform_benchmark_analysis_parallel(input_dataset_df, n_infer=24):

  input_dataset_df['inferred_result'] = input_dataset_df.progress_apply(lambda row: get_number_result_from_question_parallel(row['question'], row['choices'], n_infer), axis=1)
  input_dataset_df['is_correct'] = input_dataset_df['inferred_result'] == input_dataset_df['answer']

  # # Display results
  # # Transforming data for seaborn countplot
  # df_melted = input_dataset_df.melt(value_vars=['inferred_result', 'answer'],
  #                                   var_name='Metric', value_name='Value')

  # df_melted['Metric'] = df_melted['Metric'].str.replace('answer', 'expected_answer')
  # df_melted['Metric'] = df_melted['Metric'].str.replace('inferred_result', 'inferred_answer')

  # # Plotting
  # plt.figure(figsize=(10, 8))
  # sns.countplot(data=df_melted, x='Value', hue='Metric')
  # plt.title('Frequency Distribution of inferred answer vs expected answer')
  # plt.xlabel('Numeric answer class')
  # plt.ylabel('Count')
  # plt.legend(title='Answer type')
  # plt.show()

  # Print stuff - for debug
  # print(f"Mean accuracy : {input_dataset_df.is_correct.mean()*100:.02f}%")
  # print(f"Good answers : {input_dataset_df.is_correct.value_counts().loc[True]}")
  # print(f"Bad answers : {input_dataset_df.is_correct.value_counts().loc[False]}")

  mean_acc = input_dataset_df.is_correct.mean()
  good_ans = input_dataset_df.is_correct.value_counts().loc[True]
  bad_ans = input_dataset_df.is_correct.value_counts().loc[False]
  return [mean_acc, good_ans, bad_ans]

In [31]:
results_df_list = []
for curr_dataset, category in tqdm(zip(dataset_list, configs), desc='benchmarking categories', total=len(configs)):
    dataset_df = curr_dataset.to_pandas()
    
    # for inference_count in [1, 3, 5, 10, 24]:
    for inference_count in [1]:
      stats = perform_benchmark_analysis_parallel(dataset_df, inference_count)
      stats.append(inference_count)
      stats.append(category)
      results_df_list.append(stats)
results_df_list

benchmarking categories:   0%|          | 0/57 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/135 [00:00<?, ?it/s]

  0%|          | 0/152 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/265 [00:00<?, ?it/s]

  0%|          | 0/144 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/173 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/235 [00:00<?, ?it/s]

  0%|          | 0/114 [00:00<?, ?it/s]

  0%|          | 0/145 [00:00<?, ?it/s]

  0%|          | 0/378 [00:00<?, ?it/s]

  0%|          | 0/126 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/310 [00:00<?, ?it/s]

  0%|          | 0/203 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/165 [00:00<?, ?it/s]

  0%|          | 0/198 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/390 [00:00<?, ?it/s]

  0%|          | 0/270 [00:00<?, ?it/s]

  0%|          | 0/238 [00:00<?, ?it/s]

  0%|          | 0/151 [00:00<?, ?it/s]

  0%|          | 0/545 [00:00<?, ?it/s]

  0%|          | 0/216 [00:00<?, ?it/s]

  0%|          | 0/204 [00:00<?, ?it/s]

  0%|          | 0/237 [00:00<?, ?it/s]

  0%|          | 0/223 [00:00<?, ?it/s]

  0%|          | 0/131 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/108 [00:00<?, ?it/s]

  0%|          | 0/163 [00:00<?, ?it/s]

  0%|          | 0/112 [00:00<?, ?it/s]

  0%|          | 0/103 [00:00<?, ?it/s]

  0%|          | 0/234 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/783 [00:00<?, ?it/s]

  0%|          | 0/346 [00:00<?, ?it/s]

  0%|          | 0/895 [00:00<?, ?it/s]

  0%|          | 0/306 [00:00<?, ?it/s]

  0%|          | 0/311 [00:00<?, ?it/s]

  0%|          | 0/324 [00:00<?, ?it/s]

  0%|          | 0/282 [00:00<?, ?it/s]

  0%|          | 0/1534 [00:00<?, ?it/s]

  0%|          | 0/272 [00:00<?, ?it/s]

  0%|          | 0/612 [00:00<?, ?it/s]

  0%|          | 0/110 [00:00<?, ?it/s]

  0%|          | 0/245 [00:00<?, ?it/s]

  0%|          | 0/201 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/166 [00:00<?, ?it/s]

  0%|          | 0/171 [00:00<?, ?it/s]

[[0.32, 32, 68, 1, 'abstract_algebra'],
 [0.6148148148148148, 83, 52, 1, 'anatomy'],
 [0.5263157894736842, 80, 72, 1, 'astronomy'],
 [0.49, 49, 51, 1, 'business_ethics'],
 [0.6339622641509434, 168, 97, 1, 'clinical_knowledge'],
 [0.625, 90, 54, 1, 'college_biology'],
 [0.3, 30, 70, 1, 'college_chemistry'],
 [0.37, 37, 63, 1, 'college_computer_science'],
 [0.31, 31, 69, 1, 'college_mathematics'],
 [0.5375722543352601, 93, 80, 1, 'college_medicine'],
 [0.35294117647058826, 36, 66, 1, 'college_physics'],
 [0.61, 61, 39, 1, 'computer_security'],
 [0.43829787234042555, 103, 132, 1, 'conceptual_physics'],
 [0.39473684210526316, 45, 69, 1, 'econometrics'],
 [0.5793103448275863, 84, 61, 1, 'electrical_engineering'],
 [0.38095238095238093, 144, 234, 1, 'elementary_mathematics'],
 [0.3492063492063492, 44, 82, 1, 'formal_logic'],
 [0.33, 33, 67, 1, 'global_facts'],
 [0.6290322580645161, 195, 115, 1, 'high_school_biology'],
 [0.5123152709359606, 104, 99, 1, 'high_school_chemistry'],
 [0.56, 56, 44

In [32]:
import pandas as pd

result_df = pd.DataFrame(results_df_list)
result_df.columns = ['accuracy', 'good_answers', 'bad_answers', 'inference_count', 'category_name']
result_df

Unnamed: 0,accuracy,good_answers,bad_answers,inference_count,category_name
0,0.32,32,68,1,abstract_algebra
1,0.614815,83,52,1,anatomy
2,0.526316,80,72,1,astronomy
3,0.49,49,51,1,business_ethics
4,0.633962,168,97,1,clinical_knowledge
5,0.625,90,54,1,college_biology
6,0.3,30,70,1,college_chemistry
7,0.37,37,63,1,college_computer_science
8,0.31,31,69,1,college_mathematics
9,0.537572,93,80,1,college_medicine


In [34]:
result_df.accuracy.mean()

0.5348419595336704

In [33]:
result_df.to_csv('mmlu_complete_result.csv')

## Current conclusion

We observe 53% accuracy on the complete MMLU dataset with 0 shot prompts and 1 inference step.  

## Next steps  

Our process is not the most efficient because we mix inference and dataset modification.  
We use batching to parallelize a single question over N answer choices.  

Instead, we could make a new dataset that already has these 24 variations of a single question.  
Then, run inference in the best batch size that our GPU can accomodate to retrieve a giant dataset of 14k questions x 24 answers.  
Then only can we perform the voting process and determine the final answer to each question.  

## Convert MMLU dataset to shuffled version  

For each question, we create 23 additional variations to get the total 24 possible choices.  

EDIT: Making splits for each config and shuffle was a huge mistake: it takes a very long time to upload and download many files VS one single large file!  

We have to start over with a "no split version" to fix this issue.  

### define function to translate questions based on letter order  

For a given shuffle, we have to move choices, and update the answer index accordingly.  

In [17]:
letter_to_num = {
    'A':0,
    'B':1,
    'C':2,
    'D':3,
}

def convert_answer(reference_answer, letter_order):
    letters = list(letter_order)
    default_order = list('ABCD')
    translation_dictionary = dict(zip(letters, default_order))
    
    translated_letter = translation_dictionary[default_order[reference_answer]]
    return letter_to_num[translated_letter]

In [18]:
# takes a dataset row and shuffles the choices based on a new letter order, updates the answer accordingly
def translate_question(row, letter_order):
    
    new_choices = reorder_choices_based_on_letter_order(row['choices'], letter_order)
    default_order = 'ABCD'
       
    new_answer = convert_answer(row['answer'], letter_order)

    row['choices'] = new_choices
    row['answer'] = new_answer
    return row

In [21]:
row = {}
row['question'] = 'Find the degree for the given field extension'
row['choices'] = [0, 4, 2, 6]	
row['answer'] = 1

In [22]:
translate_question(row, 'ABCD')

{'question': 'Find the degree for the given field extension',
 'choices': [0, 4, 2, 6],
 'answer': 1}

In [23]:
translate_question(row, 'DCAB')

{'question': 'Find the degree for the given field extension',
 'choices': [6, 2, 0, 4],
 'answer': 3}

### save main "test" config

In [71]:
test_df.apply(lambda row: translate_question(row, 'ABCD'), axis=1)

NameError: name 'test_df' is not defined

In [26]:
from itertools import permutations
import random

# Generate all possible 4-letter words with the letters ABCD
letters = ['A', 'B', 'C', 'D']
all_letter_orders = [''.join(word) for word in permutations(letters, 4)]
all_letter_orders

['ABCD',
 'ABDC',
 'ACBD',
 'ACDB',
 'ADBC',
 'ADCB',
 'BACD',
 'BADC',
 'BCAD',
 'BCDA',
 'BDAC',
 'BDCA',
 'CABD',
 'CADB',
 'CBAD',
 'CBDA',
 'CDAB',
 'CDBA',
 'DABC',
 'DACB',
 'DBAC',
 'DBCA',
 'DCAB',
 'DCBA']

In [15]:
folder_name = 'shuffled_mmlu/'
# test_df.to_parquet(folder_name+'test.parquet') #simple test to check saving system before main loop

In [58]:
for curr_dataset, category in tqdm(zip(dataset_list, configs), desc='converting categories', total=len(configs)):
    for letter_order in all_letter_orders:
        curr_df = curr_dataset.to_pandas()
        new_df = curr_df.apply(lambda row: translate_question(row, letter_order), axis=1)
        new_df.to_parquet(folder_name+f"{category}_{letter_order}_test.parquet")

converting categories:   0%|          | 0/57 [00:00<?, ?it/s]

### also save 'validation' and 'dev' splits

In [2]:
from datasets import load_dataset, get_dataset_config_names
from tqdm.auto import tqdm

configs = get_dataset_config_names("tasksource/mmlu")

# load all categories into a list
dataset_list = []
for curr_cat in tqdm(configs, desc='loading categories...'):
    dataset_list.append(load_dataset("tasksource/mmlu", curr_cat, split='validation'))

loading categories...:   0%|          | 0/57 [00:00<?, ?it/s]

In [16]:
for curr_dataset, category in tqdm(zip(dataset_list, configs), desc='converting categories', total=len(configs)):
    for letter_order in all_letter_orders:
        curr_df = curr_dataset.to_pandas()
        new_df = curr_df.apply(lambda row: translate_question(row, letter_order), axis=1)
        new_df.to_parquet(folder_name+f"{category}_{letter_order}_validation.parquet")

converting categories:   0%|          | 0/57 [00:00<?, ?it/s]

In [17]:
from datasets import load_dataset, get_dataset_config_names
from tqdm.auto import tqdm

configs = get_dataset_config_names("tasksource/mmlu")

# load all categories into a list
dataset_list = []
for curr_cat in tqdm(configs, desc='loading categories...'):
    dataset_list.append(load_dataset("tasksource/mmlu", curr_cat, split='dev'))

loading categories...:   0%|          | 0/57 [00:00<?, ?it/s]

In [18]:
for curr_dataset, category in tqdm(zip(dataset_list, configs), desc='converting categories', total=len(configs)):
    for letter_order in all_letter_orders:
        curr_df = curr_dataset.to_pandas()
        new_df = curr_df.apply(lambda row: translate_question(row, letter_order), axis=1)
        new_df.to_parquet(folder_name+f"{category}_{letter_order}_dev.parquet")

converting categories:   0%|          | 0/57 [00:00<?, ?it/s]

### Make dataset config  

Example config:  

```
First example:  
---
configs:
- config_name: tab
  data_files: "main_data.csv"
  sep: "\t"
- config_name: comma
  data_files: "additional_data.csv"
  sep: ","
---

Second example:

---
configs:
- config_name: default
  data_files:
  - split: train
    path:
    - "data/abc.csv"
    - "data/def.csv"
  - split: test
    path: "holdout/ghi.csv"
---
```

#### example YAML config making

In [21]:
import yaml

# example config
config = {
    'configs': [
        {
            'config_name': 'tab',
            'data_files': 'main_data.csv',
            'sep': '\t'
        },
        {
            'config_name': 'comma',
            'data_files': 'additional_data.csv',
            'sep': ','
        }
    ]
}

# format config dict to yaml
yaml_content = yaml.dump(config, default_flow_style=False)

# write config file
with open("dataset_config.yaml", "w") as file:
    file.write(yaml_content)

'configs:\n- config_name: tab\n  data_files: main_data.csv\n  sep: "\\t"\n- config_name: comma\n  data_files: additional_data.csv\n  sep: \',\'\n'

In [23]:
print(yaml_content)

configs:
- config_name: tab
  data_files: main_data.csv
  sep: "\t"
- config_name: comma
  data_files: additional_data.csv
  sep: ','



#### make the complete config  

We have 57 configs in the source dataset.  
We will now have 57x24=1368 configs.  
Each config has 3 splits: validation, test, dev.  

In [45]:
yaml.Dumper.ignore_aliases = lambda *args : True #use this to prevent pointer from being included in the YAML file

In [46]:
complete_config = {}
complete_config['configs'] = []

for curr_config in configs:
    for curr_letter_order in all_letter_orders:
        single_config = {}
        single_config['config_name'] = f"{curr_config}_{curr_letter_order}"
        single_config['data_files'] = []
        for curr_split in ['validation', 'test', 'dev']:
            single_split = {
                'split' : curr_split,
                'path' : f"{curr_config}_{curr_letter_order}_{curr_split}.parquet"
            }
            single_config['data_files'].append(single_split)
        complete_config['configs'].append(single_config)

In [49]:
# print(yaml.dump(complete_config, default_flow_style=False))

In [48]:
# write config file
with open("dataset_config.yaml", "w") as file:
    file.write(yaml.dump(complete_config, default_flow_style=False))

## Make a shuffled MMLU WITHOUT SPLITS - fix the "too many splits issue"  

We make a MMLU shuffled dataset without the splits for each category and shuffle.  
Instead, we will make one single config with 3 splits: validation, test, and dev.  
We will add 2 features: `category` and `letter_order`.  

In [24]:
from datasets import load_dataset, get_dataset_config_names
from tqdm.auto import tqdm

configs = get_dataset_config_names("tasksource/mmlu")

# load all categories into a list
validation_dataset_list = []
for curr_cat in tqdm(configs, desc='loading categories...'):
    validation_dataset_list.append(load_dataset("tasksource/mmlu", curr_cat, split='validation'))

test_dataset_list = []
for curr_cat in tqdm(configs, desc='loading categories...'):
    test_dataset_list.append(load_dataset("tasksource/mmlu", curr_cat, split='test'))

dev_dataset_list = []
for curr_cat in tqdm(configs, desc='loading categories...'):
    dev_dataset_list.append(load_dataset("tasksource/mmlu", curr_cat, split='dev'))

loading categories...:   0%|          | 0/57 [00:00<?, ?it/s]

loading categories...:   0%|          | 0/57 [00:00<?, ?it/s]

loading categories...:   0%|          | 0/57 [00:00<?, ?it/s]

In [27]:
import pandas as pd

final_dataset_list = []
for curr_dataset, category in tqdm(zip(validation_dataset_list, configs), desc='converting categories', total=len(configs)):
    for letter_order in all_letter_orders:
        curr_df = curr_dataset.to_pandas()
        new_df = curr_df.apply(lambda row: translate_question(row, letter_order), axis=1)
        new_df['category'] = category
        new_df['letter_order'] = letter_order
        final_dataset_list.append(new_df)
        # new_df.to_parquet(folder_name+f"{category}_{letter_order}_validation.parquet")

complete_shuffled_dataset = pd.concat(final_dataset_list)

folder_name = 'shuffled_mmlu/'
complete_shuffled_dataset.to_parquet(folder_name+f"shuffled_mmlu_validation.parquet")

converting categories:   0%|          | 0/57 [00:00<?, ?it/s]

In [28]:
final_dataset_list = []
for curr_dataset, category in tqdm(zip(test_dataset_list, configs), desc='converting categories', total=len(configs)):
    for letter_order in all_letter_orders:
        curr_df = curr_dataset.to_pandas()
        new_df = curr_df.apply(lambda row: translate_question(row, letter_order), axis=1)
        new_df['category'] = category
        new_df['letter_order'] = letter_order
        final_dataset_list.append(new_df)
        # new_df.to_parquet(folder_name+f"{category}_{letter_order}_validation.parquet")

complete_shuffled_dataset = pd.concat(final_dataset_list)

folder_name = 'shuffled_mmlu/'
complete_shuffled_dataset.to_parquet(folder_name+f"shuffled_mmlu_test.parquet")

converting categories:   0%|          | 0/57 [00:00<?, ?it/s]

In [29]:
final_dataset_list = []
for curr_dataset, category in tqdm(zip(dev_dataset_list, configs), desc='converting categories', total=len(configs)):
    for letter_order in all_letter_orders:
        curr_df = curr_dataset.to_pandas()
        new_df = curr_df.apply(lambda row: translate_question(row, letter_order), axis=1)
        new_df['category'] = category
        new_df['letter_order'] = letter_order
        final_dataset_list.append(new_df)
        # new_df.to_parquet(folder_name+f"{category}_{letter_order}_validation.parquet")

complete_shuffled_dataset = pd.concat(final_dataset_list)

folder_name = 'shuffled_mmlu/'
complete_shuffled_dataset.to_parquet(folder_name+f"shuffled_mmlu_dev.parquet")

converting categories:   0%|          | 0/57 [00:00<?, ?it/s]

#### perform a small check on the translated answer to prevent future mistakes

In [31]:
def apply_result(row):
    return row['choices'][row['answer']]

In [37]:
small_sample = complete_shuffled_dataset[complete_shuffled_dataset.question==complete_shuffled_dataset.question.to_list()[0]]
small_sample = small_sample.copy(deep=True)
small_sample['translated_answer'] = small_sample.apply(apply_result, axis=1)
small_sample

Unnamed: 0,question,choices,answer,category,letter_order,translated_answer
0,Find all c in Z_3 such that Z_3[x]/(x^2 + c) i...,"[0, 1, 2, 3]",1,abstract_algebra,ABCD,1
0,Find all c in Z_3 such that Z_3[x]/(x^2 + c) i...,"[0, 1, 3, 2]",1,abstract_algebra,ABDC,1
0,Find all c in Z_3 such that Z_3[x]/(x^2 + c) i...,"[0, 2, 1, 3]",2,abstract_algebra,ACBD,1
0,Find all c in Z_3 such that Z_3[x]/(x^2 + c) i...,"[0, 2, 3, 1]",3,abstract_algebra,ACDB,1
0,Find all c in Z_3 such that Z_3[x]/(x^2 + c) i...,"[0, 3, 1, 2]",2,abstract_algebra,ADBC,1
0,Find all c in Z_3 such that Z_3[x]/(x^2 + c) i...,"[0, 3, 2, 1]",3,abstract_algebra,ADCB,1
0,Find all c in Z_3 such that Z_3[x]/(x^2 + c) i...,"[1, 0, 2, 3]",0,abstract_algebra,BACD,1
0,Find all c in Z_3 such that Z_3[x]/(x^2 + c) i...,"[1, 0, 3, 2]",0,abstract_algebra,BADC,1
0,Find all c in Z_3 such that Z_3[x]/(x^2 + c) i...,"[1, 2, 0, 3]",0,abstract_algebra,BCAD,1
0,Find all c in Z_3 such that Z_3[x]/(x^2 + c) i...,"[1, 2, 3, 0]",0,abstract_algebra,BCDA,1


In [86]:
complete_config = {
    'configs':[{
        'config_name':'default',
        'data_files':
        [
            {
                'split' : 'test',
                'path': 'shuffled_mmlu_test.parquet'
            },
            {
                'split' : 'validation',
                'path': 'shuffled_mmlu_validation.parquet'
            },
            {
                'split' : 'dev',
                'path': 'shuffled_mmlu_dev.parquet'
            },
        ]
    }]
}

print(yaml.dump(complete_config, default_flow_style=False))

configs:
- config_name: default
  data_files:
  - path: shuffled_mmlu_test.parquet
    split: test
  - path: shuffled_mmlu_validation.parquet
    split: validation
  - path: shuffled_mmlu_dev.parquet
    split: dev



Copy paste this in the YAML section of HF dataset readme file.  

## Conclusion

Good job, we have the shuffled test split online!  
Now, for completeness, we should also make the train and dev splits availiable.  
Then, investigate how to create a .YML file to indicate which .parquet files correspond to which split and which configuration.  

## Next next step  

Perform batch inference to the limits of our local GPU!  
Make the batches bigger and bigger until OOM.  
Then benchmark our benchmark inference speed on 1 inference step (config ABCD only).  
Then do the 24 version of computation times are reasonable.  
