In [21]:
import plotly.express as px
import plotly.graph_objects as go
import os
import pandas as pd
import re

In [35]:
def read_results(path: str = './linear_probes_results/', base_model: str = 'gemma-2b-it'):
    """Load per-layer linear-probe accuracies from a directory of result files.

    Every entry in ``path`` must be a ``.txt`` file named
    ``<base_model>_<model_name>_Layer<layer>.txt`` whose only content is a
    float (the probe accuracy for that model/layer pair).

    Args:
        path: Directory holding the result files. A trailing separator is
            optional; paths are joined with ``os.path.join``.
        base_model: Filename prefix shared by all result files. It is matched
            literally (regex metacharacters are escaped).

    Returns:
        List of ``(model_name, layer, accuracy)`` tuples, with ``layer`` as
        ``int`` and ``accuracy`` as ``float``, in sorted filename order.

    Raises:
        ValueError: if an entry is not a ``.txt`` file, if its name does not
            match the expected pattern, or if its content is not a float.
    """
    # Compiled once; only the filename varies inside the loop.
    pattern = re.compile(
        rf"{re.escape(base_model)}_(?P<model_name>.+)_Layer(?P<layer>\d+)\.txt"
    )
    results = []
    # os.listdir order is arbitrary; sorting makes the output deterministic.
    for filename in sorted(os.listdir(path)):
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        # Any stray entry (e.g. `.ipynb_checkpoints`, `.DS_Store`) triggers this.
        if not filename.endswith('.txt'):
            raise ValueError(f"Unexpected non-.txt entry {filename!r} in {path!r}")

        # Validate the name before doing any file I/O.
        match = pattern.search(filename)
        if match is None:
            raise ValueError(f"Filename {filename} does not match the pattern {pattern.pattern}")
        model_name, layer = match.group('model_name'), int(match.group('layer'))

        # os.path.join works whether or not `path` ends with a separator.
        with open(os.path.join(path, filename)) as f:
            # float() tolerates surrounding whitespace such as a trailing newline.
            accuracy = float(f.read())

        results.append((model_name, layer, accuracy))

    return results

In [36]:
# Tabulate the (model_name, layer, accuracy) tuples returned by read_results(),
# ordered by model and then layer, so each model's line is drawn in increasing layer order.
results = read_results()
df = (
    pd.DataFrame(results, columns=['model', 'layer', 'accuracy'])
    .sort_values(by=['model', 'layer'])
    .reset_index(drop=True)
)
# Fail loudly rather than rendering an empty plot when no result files were found.
assert not df.empty, "No probe result files were found"


In [37]:
# Plot per-model probe accuracy as a function of layer, against a chance baseline.
# NOTE(review): 0.25 presumes 4-option multiple-choice questions — confirm for WMDP-bio.
RANDOM_CHANCE = 0.25

fig = px.line(
    df, x='layer', y='accuracy', color='model',
    labels={'layer': 'Layer', 'accuracy': 'Accuracy'},
    title='Probe Accuracy on 172 WMDP-bio questions with all permutations',
)

# Draw the random baseline as one horizontal segment over the observed layer range.
# df has one row per (model, layer), so plotting against df['layer'] directly would
# repeat the layer range once per model, retracing the line and duplicating hover points.
layer_span = [df['layer'].min(), df['layer'].max()]
fig.add_trace(go.Scatter(
    x=layer_span, y=[RANDOM_CHANCE] * len(layer_span), mode='lines',
    line=dict(dash='dot', color='grey'),
    name='Random Chance',
))

# y-axis runs from just below the chance baseline to slightly above 1.0.
fig.update_yaxes(range=[0.2, 1.05])
fig.show()