# Experiments

In [None]:
## GENERATE SINGLE-TASK FT DATA

import os
os.makedirs("ft_data", exist_ok=True)

# Experiment grid; `bases`, `train_n_digits` and `train_types` are reused by the
# command-generation cells further down.
bases = [8, 9, 10]
train_n_digits = [2]
train_types = ["ft_lora"]

# Shell script that collects the data-generation and training commands.
command_file = "train_commands.sh"


## datagen code
# One sampling command per (base, n_digits) pair. The file is opened in append
# mode, so re-running this cell adds the same commands again; delete
# train_commands.sh first for a clean regeneration.
# NOTE(review): presumably sample_pipe.py writes
# ft_data/data_ft_{base}_{n_digits}.txt, which the training commands read -- confirm.
with open(command_file, "a") as f:
    for base in bases:
        for n_digits in train_n_digits:
            # Use the loop's digit count (it was hard-coded to 2) so that extending
            # train_n_digits really produces data for the additional sizes.
            command = f"python sample_pipe.py --base {base} --n_digits {n_digits} --n_samples {1000}"
            f.write(command + "\n")

In [None]:
## GENERATE LORA FT COMMANDS

## training code
# Build one fine-tuning command per (CoT flag, trainer script, base, digit count)
# combination, in that nesting order, then append them all to the command file.
# The empty string in the CoT list yields the run without the --cot flag.
train_commands = [
    f"python {train_type}.py --base {base} --model_path unsloth/Phi-4 {cot} --data_file ft_data/data_ft_{base}_{n_digits}.txt"
    for cot in ["--cot", ""]
    for train_type in train_types
    for base in bases
    for n_digits in train_n_digits
]

with open(command_file, "a") as f:
    f.writelines(line + "\n" for line in train_commands)

In [None]:
## RUN TRAINING

# Run the generated commands from a terminal: bash train_commands.sh

In [None]:
## COMPUTE EVALS

# Checkpoint numbers to evaluate under outputs/<model>/checkpoint-<n>.
eval_checkpoints = [63]
# Shot counts for the in-context (ICL) evaluations.
# NOTE(review): the k-shot plot further down also looks up k=32, which is not
# generated from this list -- confirm where that result comes from.
ks = [1, 2, 3, 4, 5, 6, 8, 10, 12, 16]
# From here on `command_file` points at the eval script, not the training one.
command_file = "eval_commands.sh"
# Value passed to eval_pipe.py via --size.
size = 250


def _base_from_model_name(model_name):
    """Return the base (8, 9 or 10) encoded in a model directory name, or None.

    The name is split on "_" and the first token equal to "8", "9" or "10" is
    used. This mirrors the "_8_" / "_9_" / "_10_" convention that `get_model`
    relies on when parsing results, and avoids bare substring checks such as
    `"8" in name`, which also match unrelated digits elsewhere in the name.
    """
    for token in model_name.split("_"):
        if token in ("8", "9", "10"):
            return int(token)
    return None


## sft evals
# Each fine-tuned model is evaluated on the base it was trained on, for 2-, 3-
# and 4-digit problems. Sorting keeps the generated script deterministic.
models = sorted(os.listdir("outputs"))
with open(command_file, "a") as f:
    for checkpoint in eval_checkpoints:
        for n_digits in [2, 3, 4]:
            for model in models:
                base = _base_from_model_name(model)
                if base is None:
                    # Previously this case reused a stale `base` from an earlier
                    # loop (or raised NameError); skip the entry explicitly instead.
                    print(f"Skipping {model!r}: could not determine its training base")
                    continue
                command = f"python eval_pipe.py --base {base} --model_name outputs/{model}/checkpoint-{checkpoint} --size {size} --n_digits {n_digits}"
                f.write(command + "\n")
                f.write("rm output.txt" + "\n")

## sft evals (cross-base)
# Every fine-tuned model is evaluated on every base in `bases` at 2 digits.
# Note that for a model's own base this repeats the 2-digit command already
# emitted by the loop above.
models = os.listdir("outputs")
cross_base_commands = [
    f"python eval_pipe.py --base {base} --model_name outputs/{model}/checkpoint-{checkpoint} --size {size} --n_digits {n_digits}"
    for checkpoint in eval_checkpoints
    for base in bases
    for n_digits in [2]
    for model in models
]

# Only touch the command file when there is something to append.
if cross_base_commands:
    with open(command_file, "a") as f:
        for line in cross_base_commands:
            f.write(line + "\n")
            f.write("rm output.txt" + "\n")

# icl evals
# k-shot evaluation of the untuned unsloth/Phi-4 model for every k in `ks`, on
# each base at 2 digits, once per value of the --icl_cot flag.
# NOTE(review): the flag value is written as the literal text "True"/"False";
# confirm eval_pipe.py parses it as a boolean (argparse `type=bool` would treat
# the string "False" as truthy).
with open(command_file, "a") as f:
    for icl_cot in [True, False]:
        for k in ks:
            for base in bases:
                for n_digits in [2]:
                    # Use the shared `size` setting rather than a separate literal
                    # so all eval groups stay on the same sample count.
                    command = f"python eval_pipe.py --base {base} --model_name unsloth/Phi-4 --size {size} --n_digits {n_digits} --n_shots {k} --icl_cot {icl_cot}"
                    f.write(command + "\n")
                    f.write("rm output.txt" + "\n")
                    
# baseline evals
# Zero-shot commands for the untuned unsloth/Phi-4 model: every base in `bases`
# at 2, 3 and 4 digits. Each command is followed by a cleanup of output.txt.
baseline_commands = [
    f"python eval_pipe.py --base {base} --model_name unsloth/Phi-4 --size {size} --n_digits {n_digits}"
    for base in bases
    for n_digits in [2, 3, 4]
]

with open(command_file, "a") as f:
    for line in baseline_commands:
        f.write(line + "\n")
        f.write("rm output.txt" + "\n")

In [None]:
## RUN EVALS

# Run the generated commands from a terminal: bash eval_commands.sh

# Analysis

## utils

In [None]:
## parse outputs
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np


# Read the raw evaluation log. The context manager closes the file promptly, and
# blank lines (such as the empty string left behind by a trailing newline) are
# dropped so the per-line field parsing below never sees an empty record.
with open("results.txt") as f:
    text = [line for line in f.read().split("\n") if line.strip()]

def get_model(text):
    """Build the short model label encoded in a model directory name.

    For each of "_8_", "_9_" and "_10_" found in `text` (checked in that
    order) the corresponding digits are appended to the label; "_cot" is then
    appended when `text` contains "_True_". Returns an empty string when none
    of the markers are present.
    """
    base_tags = [digits for digits in ("8", "9", "10") if f"_{digits}_" in text]
    label = "".join(base_tags)

    if "_True_" in text:
        return label + "_cot"
    return label

# Turn each results line into one row.
# Line layout as implied by the indices used below (space-separated tokens):
#   token 1 -> eval base, token 3 -> number of digits, token 4 -> k (shots),
#   token 5 -> ICL-CoT flag followed by one trailing character,
#   and the accuracy is whatever follows the last ": ".
# Token 2 is not used. NOTE(review): presumably it is the eval size -- confirm.
result_lines = [c for c in text if c.strip()]  # tolerate blank/trailing lines
fields = [c.split(" ") for c in result_lines]

df = pd.DataFrame({
    # Lines mentioning "unsloth" are the untuned model; otherwise the label is
    # derived from the second "/"-separated path component.
    'sft_model': ["phi_base" if "unsloth" in c else get_model(c.split("/")[1]) for c in result_lines],
    'eval_base': [int(f[1]) for f in fields],
    'eval_n_digits': [int(f[3]) for f in fields],
    'eval_k': [int(f[4]) for f in fields],
    # Compare against the text "True" explicitly: casting the raw strings with
    # astype(bool) would turn "False" (a non-empty string) into True as well.
    'eval_icl_cot': [f[5][:-1] == "True" for f in fields],
    'eval_acc': [float(c.split(": ")[-1]) for c in result_lines],
})

df.to_csv("data.csv",index=False)

# Reload from disk so later cells work from the persisted table. Labels such as
# "8" must stay strings for the `sft_model == '8'` lookups, hence the dtype pin.
df = pd.read_csv("data.csv", dtype={"sft_model": str})

In [None]:
def get_eval(df,model="phi_base", eval_base=8, eval_n_digits=2, eval_k=0, eval_icl_cot=None):
    """Return the accuracies of all rows matching the given evaluation setting.

    Rows are selected by exact equality on `sft_model`, `eval_base`,
    `eval_n_digits` and `eval_k`. The `eval_icl_cot` column is only filtered
    when `eval_icl_cot` is not None. The result is the `eval_acc` values of
    the matching rows as an array, in the frame's row order (possibly empty).
    """
    selected = (
        (df['sft_model'] == model)
        & (df['eval_base'] == eval_base)
        & (df['eval_n_digits'] == eval_n_digits)
        & (df['eval_k'] == eval_k)
    )
    if eval_icl_cot is not None:
        selected = selected & (df['eval_icl_cot'] == eval_icl_cot)

    return df.loc[selected, 'eval_acc'].values

## Visuals

In [None]:
## k-shot graph
from fit_k import PhaseTransitionAnalyzer
analyzer = PhaseTransitionAnalyzer()

# Build the fit inputs in this cell so it runs on a fresh kernel: the shot
# counts and per-base accuracies were previously only defined in the plotting
# cell that comes after this one.
# NOTE(review): k=32 is not in the `ks` list used to generate eval commands;
# the [0] lookups raise IndexError if results.txt has no such rows.
x = [1,2,3,4,5,6,8,10,12,16,32]
y_8 = [get_eval(df, eval_base=8, eval_k=k_i, eval_icl_cot=True)[0] for k_i in x]
y_9 = [get_eval(df, eval_base=9, eval_k=k_i, eval_icl_cot=True)[0] for k_i in x]
y_10 = [get_eval(df, eval_base=10, eval_k=k_i, eval_icl_cot=True)[0] for k_i in x]

# The second return value of each fit is kept; the plotting cell passes it to
# analyzer.logistic_function (presumably the fitted curve parameters -- confirm).
_, popt_8 = analyzer.analyze_single_dataset(x, y_8,
                                           "Base 8", plot=False)
analyzer.print_summary("Base 8")

_, popt_9 = analyzer.analyze_single_dataset(x, y_9,
                                           "Base 9", plot=False)
analyzer.print_summary("Base 9")

# For base 10 only, a point (k=0, accuracy=1.0) is prepended before fitting.
# NOTE(review): this anchor is a hard-coded value, not read from the results -- confirm intent.
_, popt_10 = analyzer.analyze_single_dataset([0]+x, [1.0]+y_10,
                                           "Base 10", plot=False)
analyzer.print_summary("Base 10")

In [None]:
# Shot counts on the x-axis and, per base, the first matching ICL-CoT accuracy.
# NOTE(review): k=32 is not in the `ks` list used to generate eval commands;
# the [0] lookups raise IndexError if results.txt has no such rows.
x = [1,2,3,4,5,6,8,10,12,16,32]
y_8 = [get_eval(df, eval_base=8, eval_k=k_i, eval_icl_cot=True)[0] for k_i in x]
y_9 = [get_eval(df, eval_base=9, eval_k=k_i, eval_icl_cot=True)[0] for k_i in x]
y_10 = [get_eval(df, eval_base=10, eval_k=k_i, eval_icl_cot=True)[0] for k_i in x]

# Measured accuracy curves, one line per base.
for accuracies, series_label in ((y_8, "Base 8"), (y_9, "Base 9"), (y_10, "Base 10")):
    sns.lineplot(x=x, y=accuracies, label=series_label)

# Overlay the fitted logistic curves on a dense grid spanning the measured k
# range. Only the first overlay gets a legend entry so "Logistic fit" appears once.
k_fine = np.linspace(np.min(x), np.max(x), 1000)
for position, popt in enumerate((popt_8, popt_9, popt_10)):
    legend_kwargs = {'label': 'Logistic fit'} if position == 0 else {}
    y_fit = analyzer.logistic_function(k_fine, *popt)
    plt.plot(k_fine, y_fit, 'r--', linewidth=1, **legend_kwargs)

plt.xlabel("$k_{CoT}$")
plt.ylabel("Accuracy")
plt.title("Accuracy of ID $k$-shot Base Addition")
plt.legend(loc='lower right')
plt.show()

In [None]:
## delta by cross/ood base
# For every model, take the first matching accuracy on each eval base using
# get_eval's defaults (2 digits, k=0, no ICL-CoT filter), then express the
# fine-tuned models as percentage-point differences from the untuned baseline.
def _acc_per_base(model_label):
    """First matching accuracy of `model_label` on eval bases 8, 9 and 10."""
    return [get_eval(df, model=model_label, eval_base=b)[0] for b in [8, 9, 10]]

baselines = _acc_per_base('phi_base')

sft_8 = _acc_per_base('8')
sft_8_cot = _acc_per_base('8_cot')

sft_9 = _acc_per_base('9')
sft_9_cot = _acc_per_base('9_cot')

sft_10 = _acc_per_base('10')
sft_10_cot = _acc_per_base('10_cot')

baseline_arr = np.array(baselines)

# group1 = trained with CoT, group2 = trained without CoT.
data1_group1 = 100*(np.array(sft_8_cot)-baseline_arr)
data1_group2 = 100*(np.array(sft_8)-baseline_arr)

data2_group1 = 100*(np.array(sft_9_cot)-baseline_arr)
data2_group2 = 100*(np.array(sft_9)-baseline_arr)

data3_group1 = 100*(np.array(sft_10_cot)-baseline_arr)
data3_group2 = 100*(np.array(sft_10)-baseline_arr)

In [None]:
# Grouped bar chart: one panel per SFT training base, bars are the CoT / no-CoT
# accuracy deltas (data*_group1 / data*_group2) for each evaluation base.
categories = ['Base 8', 'Base 9', 'Base 10']

fig, axes = plt.subplots(1, 3, figsize=(15, 5))

x = np.arange(len(categories))
width = 0.3

# (panel title, CoT deltas, no-CoT deltas) for each subplot, left to right.
panels = [
    ('SFT on Base 8', data1_group1, data1_group2),
    ('SFT on Base 9', data2_group1, data2_group2),
    ('SFT on Base 10', data3_group1, data3_group2),
]

for ax, (panel_title, cot_delta, no_cot_delta) in zip(axes, panels):
    # Offset the two bar groups by half a bar width around each tick.
    ax.bar(x - width/2, cot_delta, width, label='CoT', alpha=0.8)
    ax.bar(x + width/2, no_cot_delta, width, label='no CoT', alpha=0.8)
    ax.set_title(panel_title)
    ax.set_xlabel('Eval Base')
    # Raw string: "\d" is not a valid escape in a normal string literal and
    # triggers a warning on recent Python versions; the rendered text is unchanged.
    ax.set_ylabel(r'$\delta$ Accuracy Points')
    ax.set_xticks(x)
    ax.set_xticklabels(categories)
    ax.legend()

plt.tight_layout()
plt.show()

In [None]:
## delta by cross/ood scope/digits
# For each base, compare the untuned baseline with the models fine-tuned on
# that base (with and without CoT) at 2, 3 and 4 digits, all evaluated on the
# same base. Values are the first matching accuracy per setting.
def _acc_per_digits(model_label, base):
    """First matching accuracy of `model_label` on `base` for 2, 3 and 4 digits."""
    return [get_eval(df, model=model_label, eval_base=base, eval_n_digits=b)[0] for b in [2, 3, 4]]

baselines_8 = _acc_per_digits('phi_base', 8)
sft_8_cot = _acc_per_digits('8_cot', 8)
sft_8_ncot = _acc_per_digits('8', 8)

baselines_9 = _acc_per_digits('phi_base', 9)
sft_9_cot = _acc_per_digits('9_cot', 9)
sft_9_ncot = _acc_per_digits('9', 9)

baselines_10 = _acc_per_digits('phi_base', 10)
sft_10_cot = _acc_per_digits('10_cot', 10)
sft_10_ncot = _acc_per_digits('10', 10)


categories = ['2-Digit', '3-Digit', '4-Digit']

# Percentage-point change versus the baseline on the same base;
# group1 = trained with CoT, group2 = trained without CoT.
base_8_ref = np.array(baselines_8)
data1_group1 = 100*(np.array(sft_8_cot)-base_8_ref)
data1_group2 = 100*(np.array(sft_8_ncot)-base_8_ref)

base_9_ref = np.array(baselines_9)
data2_group1 = 100*(np.array(sft_9_cot)-base_9_ref)
data2_group2 = 100*(np.array(sft_9_ncot)-base_9_ref)

base_10_ref = np.array(baselines_10)
data3_group1 = 100*(np.array(sft_10_cot)-base_10_ref)
data3_group2 = 100*(np.array(sft_10_ncot)-base_10_ref)

# Grouped bar chart: one panel per SFT training base, bars are the CoT / no-CoT
# accuracy deltas (data*_group1 / data*_group2) for each entry of `categories`.
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

x = np.arange(len(categories))
width = 0.35

# (panel title, CoT deltas, no-CoT deltas) for each subplot, left to right.
panels = [
    ('SFT on Base 8', data1_group1, data1_group2),
    ('SFT on Base 9', data2_group1, data2_group2),
    ('SFT on Base 10', data3_group1, data3_group2),
]

for ax, (panel_title, cot_delta, no_cot_delta) in zip(axes, panels):
    # Offset the two bar groups by half a bar width around each tick.
    ax.bar(x - width/2, cot_delta, width, label='CoT', alpha=0.8)
    ax.bar(x + width/2, no_cot_delta, width, label='no CoT', alpha=0.8)
    ax.set_title(panel_title)
    # The ticks in this figure are digit counts (2-/3-/4-digit), not bases, so
    # the axis label says so instead of the copy-pasted 'Eval Base'.
    ax.set_xlabel('Eval Digits')
    # Raw string: "\d" is not a valid escape in a normal string literal and
    # triggers a warning on recent Python versions; the rendered text is unchanged.
    ax.set_ylabel(r'$\delta$ Accuracy Points')
    ax.set_xticks(x)
    ax.set_xticklabels(categories)
    ax.legend()

plt.tight_layout()
plt.show()