In [1]:
import sys
from pathlib import Path

# Add the src directory to Python path
src_path = str(Path("./").resolve().parent.parent)
if src_path not in sys.path:
    sys.path.append(src_path)

sys.path

['/Library/Frameworks/Python.framework/Versions/3.11/lib/python311.zip',
 '/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11',
 '/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/lib-dynload',
 '',
 '/Users/anton/dev/MARS/correctness-model-internals/venv/lib/python3.11/site-packages',
 '/Users/anton/dev/MARS/correctness-model-internals']

In [2]:
from collections import defaultdict

import pandas as pd

from src.classifying import (
    ActivationsHandler,
    combine_activations_handlers,
    get_correctness_direction_classifier,
    get_logistic_regression_classifier,
    get_between_class_variance_and_within_class_variance,
)
from src.visualisations.utils import plot_interactive_lineplot
from src.utils.data import load_activations, load_labels, get_experiment_activations_configs_df_subset


In [3]:
BASE_PATH = "../../data_for_classification"

In [4]:
# Set None to get all
MODEL_ID = "llama3.1_8b_chat"
# DATASET_ID = "gsm8k"
# PROMPT_ID = "cot_3_shot"
# SUBSET_ID = "main"
# INPUT_TYPE = "prompt_only"
DATASET_ID = None
PROMPT_ID = None
SUBSET_ID = None
INPUT_TYPE = None

activation_exp_configs_df = get_experiment_activations_configs_df_subset(
    base_path=BASE_PATH,
    model_id=MODEL_ID,
    dataset_id=DATASET_ID,
    prompt_id=PROMPT_ID,
    subset_id=SUBSET_ID,
    input_type=INPUT_TYPE,
)
activation_exp_configs_df

Unnamed: 0,model_id,dataset_id,prompt_id,subset_id,input_type,layer,path
0,llama3.1_8b_chat,football_leagues_1k,base,main,prompt_only,0,../../data_for_classification/activations/llam...
1,llama3.1_8b_chat,football_leagues_1k,base,main,prompt_only,7,../../data_for_classification/activations/llam...
2,llama3.1_8b_chat,football_leagues_1k,base,main,prompt_only,9,../../data_for_classification/activations/llam...
3,llama3.1_8b_chat,football_leagues_1k,base,main,prompt_only,17,../../data_for_classification/activations/llam...
4,llama3.1_8b_chat,football_leagues_1k,base,main,prompt_only,28,../../data_for_classification/activations/llam...
...,...,...,...,...,...,...,...
379,llama3.1_8b_chat,birth_years_4k,base,main,prompt_answer,23,../../data_for_classification/activations/llam...
380,llama3.1_8b_chat,birth_years_4k,base,main,prompt_answer,15,../../data_for_classification/activations/llam...
381,llama3.1_8b_chat,birth_years_4k,base,main,prompt_answer,12,../../data_for_classification/activations/llam...
382,llama3.1_8b_chat,birth_years_4k,base,main,prompt_answer,2,../../data_for_classification/activations/llam...


In [6]:
res_dict = defaultdict(list)

for (model_id, prompt_id, subset_id, input_type), config_df in activation_exp_configs_df.groupby(["model_id", "prompt_id", "subset_id", "input_type"]):
    print(f"\n{model_id=}, {prompt_id=}, {subset_id=}, {input_type=}")
    
    config_df["layer"] = config_df["layer"].astype(int)
    config_df = config_df.sort_values(by="layer")

    check_indices = {}
    for layer, layer_config_df in config_df.groupby("layer"):
        print(f"{layer=}", end=", ")

        dataset_activation_handlers = {}
        for _, row in layer_config_df.iterrows():
            dataset_id = row["dataset_id"]

            labels_df = load_labels(
                base_path=BASE_PATH,
                model_id=model_id,
                dataset_id=dataset_id,
                prompt_id=prompt_id,
                subset_id=subset_id,
            )
            activations, indices = load_activations(
                base_path=BASE_PATH,
                model_id=model_id,
                dataset_id=dataset_id,
                prompt_id=prompt_id,
                subset_id=subset_id,
                input_type=input_type,
                layer=layer,
            )

            if dataset_id not in check_indices:
                check_indices[dataset_id] = indices.sample(frac=1, replace=False)
            
            if set(indices) != set(check_indices[dataset_id]):
                raise RuntimeError(f"indices across layers are not the same")

            labels_df_subset = labels_df.iloc[check_indices[dataset_id]].reset_index(drop=True)
            activations_subset = activations[check_indices[dataset_id]]

            # print(f"{dataset_id=}, {labels_df_subset['correct'].astype(bool).value_counts()=}")

            activations_handler = ActivationsHandler(
                activations=activations_subset,
                labels=labels_df_subset["correct"].astype(bool),
            ).sample_equally_across_groups(
                group_labels=[False, True], interleave=True
            )

            # print("activations_handler", activations_handler.labels.value_counts(), end="\n\n")

            activations_handler_folds = list(
                activations_handler.split_dataset(split_sizes=[1/5] * 5)
            )

            # print("activations_handler_folds", [ah.labels.value_counts() for ah in activations_handler_folds], end="\n\n\n\n\n")


            dataset_activation_handlers[dataset_id] = activations_handler_folds
            
        # if check_indices is None:
        #     check_indices = indices.sample(frac=1, replace=False)
        
        # if set(indices) != set(check_indices):
        #     raise RuntimeError(f"indices across layers are not the same")

        # labels_df_subset = labels_df.iloc[check_indices]
        # activations = activations[check_indices]

        # activations_handler = ActivationsHandler(
        #     activations=activations,
        #     labels=labels_df_subset["correct"].astype(bool),
        # )

        # activations_handler_folds = list(
        #     activations_handler.split_dataset(split_sizes=[1/5] * 5)
        # )

        fold_stats = {}
        for fold_i in range(len(list(dataset_activation_handlers.values())[0])):
            train_activations_handlers, test_activations_handlers = {}, {}
            for dataset_id, ah_folds in dataset_activation_handlers.items():

                test_activations_handlers[dataset_id] = ah_folds[fold_i].sample_equally_across_groups(
                    group_labels=[False, True], interleave=True
                )

                if dataset_id == "football_leagues_1k":
                    continue

                train_activations_handlers[dataset_id] = [ah for j, ah in enumerate(ah_folds) if j != fold_i]
            
            activations_handler_train = combine_activations_handlers(
                sum(list(train_activations_handlers.values()), []), # combine the lists
                equal_counts=True
            ).sample_equally_across_groups(
                group_labels=[False, True], interleave=True
            )

            scaler_model_tuple = None
            for test_dataset_id, activations_handler_test in test_activations_handlers.items():
                res_dict["model_id"].append(model_id)
                res_dict["train_dataset_ids"].append(list(train_activations_handlers.keys()))
                res_dict["test_dataset_id"].append(test_dataset_id)
                res_dict["prompt_id"].append(prompt_id)
                res_dict["subset_id"].append(subset_id)
                res_dict["input_type"].append(input_type)
                res_dict["layer"].append(layer)
                res_dict["fold"].append(fold_i)

                direction_classifier, direction_calculator = get_correctness_direction_classifier(
                    activations_handler_train=activations_handler_train,
                    activations_handler_test=activations_handler_test,
                )
                # res_dict["classifying_direction"].append(direction_calculator.classifying_direction.tolist())
                for key, value in direction_classifier.classification_metrics.items():
                    res_dict[f"direction_{key}"].append(value)
                
                logistic_regression_classifier, scaler_model_tuple = get_logistic_regression_classifier(
                    activations_handler_train=activations_handler_train,
                    activations_handler_test=activations_handler_test,
                    scaler_model_tuple=scaler_model_tuple,
                )
                for key, value in logistic_regression_classifier.classification_metrics.items():
                    res_dict[f"logistic_regression_{key}"].append(value)



res_df = pd.DataFrame(res_dict)
res_df


model_id='llama3.1_8b_chat', prompt_id='base', subset_id='main', input_type='prompt_answer'
layer=0, 

STOP: TOTAL NO. OF ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. OF ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. OF ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver opt

layer=1, 

STOP: TOTAL NO. OF ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. OF ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. OF ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver opt

layer=2, 

STOP: TOTAL NO. OF ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. OF ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. OF ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver opt

layer=3, 

STOP: TOTAL NO. OF ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. OF ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(


layer=4, layer=5, layer=6, layer=7, layer=8, layer=9, layer=10, layer=11, layer=12, layer=13, layer=14, layer=15, layer=16, layer=17, layer=18, layer=19, layer=20, layer=21, layer=22, layer=23, layer=24, layer=25, layer=26, layer=27, layer=28, layer=29, layer=30, layer=31, 
model_id='llama3.1_8b_chat', prompt_id='base', subset_id='main', input_type='prompt_only'
layer=0, 

STOP: TOTAL NO. OF ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. OF ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. OF ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver opt

layer=1, 

STOP: TOTAL NO. OF ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. OF ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. OF ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver opt

layer=2, 

STOP: TOTAL NO. OF ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. OF ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. OF ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver opt

layer=3, 

STOP: TOTAL NO. OF ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. OF ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. OF ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver opt

layer=4, 

STOP: TOTAL NO. OF ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(


layer=5, layer=6, 

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


layer=7, layer=8, layer=9, layer=10, layer=11, layer=12, layer=13, layer=14, 

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


layer=15, 

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


layer=16, layer=17, layer=18, layer=19, layer=20, layer=21, layer=22, layer=23, layer=24, layer=25, layer=26, layer=27, layer=28, layer=29, layer=30, 

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


layer=31, 

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



model_id='llama3.1_8b_chat', prompt_id='base_3_shot', subset_id='main', input_type='prompt_answer'
layer=0, layer=1, layer=2, layer=3, layer=4, layer=5, layer=6, layer=7, layer=8, layer=9, layer=10, layer=11, layer=12, layer=13, layer=14, layer=15, layer=16, layer=17, layer=18, layer=19, layer=20, layer=21, layer=22, layer=23, layer=24, layer=25, layer=26, layer=27, layer=28, layer=29, layer=30, layer=31, 
model_id='llama3.1_8b_chat', prompt_id='base_3_shot', subset_id='main', input_type='prompt_only'
layer=0, layer=1, layer=2, layer=3, layer=4, layer=5, layer=6, layer=7, layer=8, layer=9, layer=10, layer=11, layer=12, layer=13, layer=14, layer=15, layer=16, layer=17, layer=18, layer=19, layer=20, layer=21, layer=22, layer=23, layer=24, layer=25, layer=26, layer=27, layer=28, layer=29, layer=30, layer=31, 
model_id='llama3.1_8b_chat', prompt_id='cot_3_shot', subset_id='main', input_type='prompt_answer'
layer=0, layer=1, layer=2, layer=3, layer=4, layer=5, layer=6, layer=7, layer=8, la

Unnamed: 0,model_id,train_dataset_ids,test_dataset_id,prompt_id,subset_id,input_type,layer,fold,direction_optimal_cut,direction_optimal_train_set_cut,...,direction_f1_score,direction_precision_score,direction_recall_score,logistic_regression_optimal_cut,logistic_regression_optimal_train_set_cut,logistic_regression_test_roc_auc,logistic_regression_accuracy_score,logistic_regression_f1_score,logistic_regression_precision_score,logistic_regression_recall_score
0,llama3.1_8b_chat,"[medals_9k, birth_years_4k, cities_10k]",football_leagues_1k,base,main,prompt_answer,0,0,-0.039674,-0.039674,...,0.619718,0.578947,0.666667,0.5,0.509670,0.461892,0.560606,0.579710,0.555556,0.606061
1,llama3.1_8b_chat,"[medals_9k, birth_years_4k, cities_10k]",medals_9k,base,main,prompt_answer,0,0,-0.039674,-0.039674,...,0.639019,0.551411,0.759722,0.5,0.509670,0.749261,0.692361,0.681524,0.706408,0.658333
2,llama3.1_8b_chat,"[medals_9k, birth_years_4k, cities_10k]",birth_years_4k,base,main,prompt_answer,0,0,-0.039674,-0.039674,...,0.618421,0.536122,0.730570,0.5,0.509670,0.639910,0.598446,0.613466,0.591346,0.637306
3,llama3.1_8b_chat,"[medals_9k, birth_years_4k, cities_10k]",cities_10k,base,main,prompt_answer,0,0,-0.039674,-0.039674,...,0.662810,0.585401,0.763810,0.5,0.509670,0.674569,0.623810,0.642534,0.612069,0.676190
4,llama3.1_8b_chat,"[medals_9k, birth_years_4k, cities_10k]",football_leagues_1k,base,main,prompt_answer,0,1,-0.072958,-0.072958,...,0.622222,0.491228,0.848485,0.5,0.523588,0.451791,0.469697,0.385965,0.458333,0.333333
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1915,llama3.1_8b_chat,[gsm8k],gsm8k,cot_3_shot,main,prompt_only,31,0,4.554289,4.554289,...,0.422346,0.535411,0.348708,0.5,0.459932,0.524077,0.535978,0.549687,0.533913,0.566421
1916,llama3.1_8b_chat,[gsm8k],gsm8k,cot_3_shot,main,prompt_only,31,1,2.064851,2.064851,...,0.555357,0.538062,0.573801,0.5,0.488465,0.531844,0.536900,0.565744,0.532573,0.603321
1917,llama3.1_8b_chat,[gsm8k],gsm8k,cot_3_shot,main,prompt_only,31,2,6.805744,6.805744,...,0.382353,0.569343,0.287823,0.5,0.441263,0.530897,0.540590,0.567708,0.536066,0.603321
1918,llama3.1_8b_chat,[gsm8k],gsm8k,cot_3_shot,main,prompt_only,31,3,1.761393,1.761393,...,0.532374,0.519298,0.546125,0.5,0.484587,0.529377,0.537823,0.571429,0.532695,0.616236


In [7]:
res_file = "./classification_data/res_df_llama31_8B_4_memory_datasets_combined_train_sets.csv"


In [8]:
res_df.to_csv(res_file, index=False)

In [9]:
res_df = pd.read_csv(res_file)

In [10]:
for classifier in ["direction", "logistic_regression"]:
    # for metric in ["f1_score", "accuracy_score", "precision_score", "recall_score"]:
    for metric in ["f1_score"]:
        plot_dict = {}
        for conf, res_df_ in res_df.groupby(["model_id", "test_dataset_id", "prompt_id", "subset_id", "input_type"]):
            # if conf[4] != "prompt_only":
            #     continue
            
            print(f"{conf=}")
            res_df_pivot = pd.pivot(
                res_df_.drop(columns=["model_id", "test_dataset_id", "prompt_id", "subset_id", "input_type"]),
                index='layer',
                columns='fold',
                # values=['direction_f1_score', 'logistic_regression_f1_score']  # add all metrics you want to keep
            )
            # for classifier in ["direction", "logistic_regression"]:
            #     for metric in ["f1_score", "accuracy_score", "precision_score", "recall_score"]:
            plot_dict[str(conf)] = res_df_pivot[[f"{classifier}_{metric}"]]

        plot_interactive_lineplot(
            plot_dict,
            x_label="Layer",
            y_label=f"{classifier}_{metric}".replace("_", " ").title(),
            save_path=f"./classification_data/figures/res_df_llama31_8B_4_memory_datasets_combined_train_sets_{classifier}_{metric}.html"
        ).show()



conf=('llama3.1_8b_chat', 'birth_years_4k', 'base', 'main', 'prompt_answer')
conf=('llama3.1_8b_chat', 'birth_years_4k', 'base', 'main', 'prompt_only')
conf=('llama3.1_8b_chat', 'cities_10k', 'base', 'main', 'prompt_answer')
conf=('llama3.1_8b_chat', 'cities_10k', 'base', 'main', 'prompt_only')
conf=('llama3.1_8b_chat', 'football_leagues_1k', 'base', 'main', 'prompt_answer')
conf=('llama3.1_8b_chat', 'football_leagues_1k', 'base', 'main', 'prompt_only')
conf=('llama3.1_8b_chat', 'gsm8k', 'base_3_shot', 'main', 'prompt_answer')
conf=('llama3.1_8b_chat', 'gsm8k', 'base_3_shot', 'main', 'prompt_only')
conf=('llama3.1_8b_chat', 'gsm8k', 'cot_3_shot', 'main', 'prompt_answer')
conf=('llama3.1_8b_chat', 'gsm8k', 'cot_3_shot', 'main', 'prompt_only')
conf=('llama3.1_8b_chat', 'medals_9k', 'base', 'main', 'prompt_answer')
conf=('llama3.1_8b_chat', 'medals_9k', 'base', 'main', 'prompt_only')


conf=('llama3.1_8b_chat', 'birth_years_4k', 'base', 'main', 'prompt_answer')
conf=('llama3.1_8b_chat', 'birth_years_4k', 'base', 'main', 'prompt_only')
conf=('llama3.1_8b_chat', 'cities_10k', 'base', 'main', 'prompt_answer')
conf=('llama3.1_8b_chat', 'cities_10k', 'base', 'main', 'prompt_only')
conf=('llama3.1_8b_chat', 'football_leagues_1k', 'base', 'main', 'prompt_answer')
conf=('llama3.1_8b_chat', 'football_leagues_1k', 'base', 'main', 'prompt_only')
conf=('llama3.1_8b_chat', 'gsm8k', 'base_3_shot', 'main', 'prompt_answer')
conf=('llama3.1_8b_chat', 'gsm8k', 'base_3_shot', 'main', 'prompt_only')
conf=('llama3.1_8b_chat', 'gsm8k', 'cot_3_shot', 'main', 'prompt_answer')
conf=('llama3.1_8b_chat', 'gsm8k', 'cot_3_shot', 'main', 'prompt_only')
conf=('llama3.1_8b_chat', 'medals_9k', 'base', 'main', 'prompt_answer')
conf=('llama3.1_8b_chat', 'medals_9k', 'base', 'main', 'prompt_only')
