### Linear Probe Exploration
Code for examining the initial linear probe results and checking whether there's anything interesting in the error patterns.
- Part 1: Error analysis within given test passages
- Part 2: Identifying "words" by most common successes(?)
- ~~Part 3: n-gram exploration of dataset, compared with linear probe "words."~~ (see ngrams.py)

### Part 0: Setup

In [45]:
!pwd
!pip install -e ..

/data/sheridan_feucht/lexicon/scripts


DEPRECATION: Loading egg at /data/sheridan_feucht/anaconda3/envs/basic/lib/python3.11/site-packages/huggingface_hub-0.17.1-py3.8.egg is deprecated. pip 23.3 will enforce this behaviour change. A possible replacement is to use pip for package installation..
Obtaining file:///data/sheridan_feucht/lexicon
  Preparing metadata (setup.py) ... done
Installing collected packages: modules
  Attempting uninstall: modules
    Found existing installation: modules 1.0
    Uninstalling modules-1.0:
      Successfully uninstalled modules-1.0
  Running setup.py develop for modules
Successfully installed modules-1.0


In [46]:
!pip install -e ../llama

DEPRECATION: Loading egg at /data/sheridan_feucht/anaconda3/envs/basic/lib/python3.11/site-packages/huggingface_hub-0.17.1-py3.8.egg is deprecated. pip 23.3 will enforce this behaviour change. A possible replacement is to use pip for package installation..
Obtaining file:///data/sheridan_feucht/lexicon/llama
  Preparing metadata (setup.py) ... done
Installing collected packages: llama
  Attempting uninstall: llama
    Found existing installation: llama 0.0.1
    Uninstalling llama-0.0.1:
      Successfully uninstalled llama-0.0.1
  Running setup.py develop for llama
Successfully installed llama-0.0.1


In [47]:
import os
import torch
import pandas as pd
from modules.training import LinearModel
from llama import Tokenizer

### Part 1: Error analysis
Choose the most successful probe across all layers and load it by its wandb ID. The rest of this section explores the errors it makes.
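The run directories (and the `wandb_id` strings below) follow a naming pattern like `LAYER3-TGTIDX-1-train_small_5000-bsz1-lr0.00608-epochs74-ak6qyq6v`. Here is a minimal sketch of parsing such names into fields, e.g. to sort or group runs by layer when picking a probe. The regex and the helper name `parse_run_name` are hypothetical and assume every run name follows exactly this pattern.

```python
import re

# Hypothetical pattern inferred from run names such as
# "LAYER3-TGTIDX-1-train_small_5000-bsz1-lr0.00608-epochs74-ak6qyq6v"
RUN_NAME = re.compile(
    r"LAYER(?P<layer>\d+)-TGTIDX(?P<target_idx>-?\d+)-(?P<dataset>.+?)"
    r"-bsz(?P<bsz>\d+)-lr(?P<lr>[\d.]+)-epochs(?P<epochs>\d+)-(?P<wandb_id>[a-z0-9]+)$"
)

def parse_run_name(name):
    """Split a run name (or a path ending in one) into typed fields."""
    m = RUN_NAME.search(name)
    if m is None:
        raise ValueError(f"unrecognized run name: {name!r}")
    d = m.groupdict()
    return {
        "layer": int(d["layer"]),
        "target_idx": int(d["target_idx"]),
        "dataset": d["dataset"],
        "bsz": int(d["bsz"]),
        "lr": float(d["lr"]),
        "epochs": int(d["epochs"]),
        "wandb_id": d["wandb_id"],
    }

print(parse_run_name("LAYER3-TGTIDX-1-train_small_5000-bsz1-lr0.00608-epochs74-ak6qyq6v"))
```

For example, `sorted(run_names, key=lambda n: parse_run_name(n)["layer"])` would order a list of run names by layer.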

In [48]:
# wandb_id = "f7q76y13" #LAYER17-TGTIDX-1-train_tiny_1000-bsz512-lr0.07636-epochs96-f7q76y13
wandb_id = "LAYER3-TGTIDX-1-train_small_5000-bsz1-lr0.00608-epochs74-ak6qyq6v"
# wandb_id = "LAYER29-TGTIDX-1-train_small_5000-bsz10-lr0.00473-epochs60-9vqlvyf0"
tokenizer = Tokenizer(model_path="../llama/tokenizer.model")

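def retrieve_run_info(wid, location, desired_file):
    """Return the path to `desired_file` in the first directory under `location` whose path contains `wid`."""
    for (dirpath, _, _) in os.walk(location):
        if wid in dirpath:
            print(dirpath)
            return os.path.join(dirpath, desired_file)
    raise FileNotFoundError(f"no run directory matching {wid!r} under {location!r}")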

In [49]:
# load in the results csv for the test set (one row per token position)
output_path = retrieve_run_info(wandb_id, "../logs/llama-2-7b", "test_tiny_500_results.csv")
outputs = pd.read_csv(output_path) # dtype={'doc_id': int}
outputs = outputs.fillna("")  # missing/empty token strings become ""
assert outputs['doc_id'].nunique() == 500
# print(display(outputs.loc[outputs['doc_id'].str.contains('com').fillna(False)]))

../logs/llama-2-7b/LAYER3-TGTIDX-1-train_small_5000-bsz1-lr0.00608-epochs74-ak6qyq6v
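One thing to watch with `fillna("")`: by default `pd.read_csv` treats strings such as `NA`, `null`, and `nan` as missing values, so if any tokens decode to those strings they silently become `""` as well. A minimal sketch of the difference on a toy two-column CSV (not the real results file): passing `keep_default_na=False` keeps such strings verbatim and leaves genuinely empty fields as `""`.

```python
import io
import pandas as pd

# toy stand-in for the results CSV, with token strings that look like missing values
csv_text = "current_tok,actual_tok\nnull,a\nNA,b\n,c\nthe,d\n"

default = pd.read_csv(io.StringIO(csv_text)).fillna("")
preserved = pd.read_csv(io.StringIO(csv_text), keep_default_na=False)

print(default["current_tok"].tolist())    # ['', '', '', 'the']
print(preserved["current_tok"].tolist())  # ['null', 'NA', '', 'the']
```

Whether this matters depends on whether such strings actually occur as tokens in the test set; it is a cheap thing to rule out.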


In [50]:
# calculate top-5 accuracy for this run.
print("accuracy", len(outputs.loc[(outputs['actual_tok_id'] == outputs['predicted_tok_id'])]) / len(outputs))
print("top2 accuracy", len(outputs.loc[(outputs['actual_tok_id'] == outputs['top_1_tok_id']) | (outputs['actual_tok_id'] == outputs['top_2_tok_id'])])/ len(outputs))
print("top3 accuracy", len(outputs.loc[(outputs['actual_tok_id'] == outputs['top_1_tok_id']) | (outputs['actual_tok_id'] == outputs['top_2_tok_id']) | (outputs['actual_tok_id'] == outputs['top_3_tok_id'])])/ len(outputs))
print("top4 accuracy", len(outputs.loc[(outputs['actual_tok_id'] == outputs['top_1_tok_id']) | (outputs['actual_tok_id'] == outputs['top_2_tok_id']) | (outputs['actual_tok_id'] == outputs['top_3_tok_id']) | (outputs['actual_tok_id'] == outputs['top_4_tok_id'])])/ len(outputs))
print("top5 accuracy", len(outputs.loc[(outputs['actual_tok_id'] == outputs['top_1_tok_id']) | (outputs['actual_tok_id'] == outputs['top_2_tok_id']) | (outputs['actual_tok_id'] == outputs['top_3_tok_id']) | (outputs['actual_tok_id'] == outputs['top_4_tok_id']) | (outputs['actual_tok_id'] == outputs['top_5_tok_id'])])/ len(outputs))

accuracy 0.3331474688187821
top2 accuracy 0.36273661041819516
top3 accuracy 0.3774211298606016
top4 accuracy 0.38800073367571536
top5 accuracy 0.3967351430667645
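The chained `|` conditions above grow with each k. A more compact sketch, assuming the same `actual_tok_id` and `top_{i}_tok_id` column names as the results CSV; the helper name `topk_accuracy` and the toy DataFrame are illustrative only. Note that k=1 here uses `top_1_tok_id` rather than `predicted_tok_id`.

```python
import pandas as pd

def topk_accuracy(df, k):
    """Fraction of rows whose actual token id appears among the top-k predicted ids."""
    cols = [f"top_{i}_tok_id" for i in range(1, k + 1)]
    # compare every top-i column against actual_tok_id row by row
    hits = df[cols].eq(df["actual_tok_id"], axis=0).any(axis=1)
    return float(hits.mean())

toy = pd.DataFrame({
    "actual_tok_id": [1, 2, 3, 4],
    "top_1_tok_id":  [1, 9, 9, 9],
    "top_2_tok_id":  [9, 2, 9, 9],
    "top_3_tok_id":  [9, 9, 9, 9],
    "top_4_tok_id":  [9, 9, 9, 9],
    "top_5_tok_id":  [9, 9, 3, 9],
})
print([topk_accuracy(toy, k) for k in range(1, 6)])  # [0.25, 0.5, 0.5, 0.5, 0.75]
```

On the real data this would be `topk_accuracy(outputs, k)` for `k` in 1 through 5.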


In [51]:
# Qualitative analysis: for incorrect predictions, look at the top 5 to see what's being predicted instead.
# predicted_tok sometimes differs from top_1_tok, but only when the top two probabilities are equal after rounding, so top1 and top2 get swapped. This doesn't change the results.
# print(outputs.loc[outputs['predicted_tok_id'] != outputs['top_1_tok_id']][['predicted_tok', 'top_1_tok', 'top_2_tok', 'top_1_prob', 'top_2_prob']]) 
top5s = outputs.loc[(outputs['actual_tok'] != outputs['predicted_tok'])][['doc_id', 'current_tok', 'actual_tok', 'top_1_tok', 'top_2_tok', 'top_3_tok', 'top_4_tok', 'top_5_tok']]
print(len(top5s['doc_id'].unique()))
top5s.sample(20)
    

499


Unnamed: 0,doc_id,current_tok,actual_tok,top_1_tok,top_2_tok,top_3_tok,top_4_tok,top_5_tok
6838,13,?,offs,s,),3,goods,0
165999,303,in,their,the,s,in,a,port
87658,155,te,ed,a,of,s,the,-
179767,328,and,season,",",ta,goods,it,and
131164,242,",",England,Massachusetts,own,",",7,6
50793,91,fly,television,fle,a,s,",",.
15798,32,p,ur,to,,-,for,ing
177323,325,of,ach,ing,s,a,to,of
206731,384,.,School,Bay,?,す,0,ary
123188,224,the,make,to,that,for,with,ordered


In [52]:
# If the probe were doing more than picking up common n-grams, it really should recover "Afghan" at "istan", right?
outputs.iloc[[172694, 172695]][['doc_id', 'current_tok', 'actual_tok', 'top_1_tok', 'top_2_tok', 'top_3_tok', 'top_4_tok', 'top_5_tok']]

Unnamed: 0,doc_id,current_tok,actual_tok,top_1_tok,top_2_tok,top_3_tok,top_4_tok,top_5_tok
172694,316,Afghan,to,to,of,the,in,a
172695,316,istan,Afghan,to,in,of,the,0


In [53]:
# Find documents where every prediction is correct. There's only one (doc 170): just a BOS and a newline.
completely_right = []
for docid, d in outputs.groupby('doc_id'):
    if (d['actual_tok'] == d['predicted_tok']).all():
        print(docid, len(d))
        completely_right.append(d)
completely_right[0]

170 2


Unnamed: 0.1,Unnamed: 0,doc_id,current_tok_id,actual_tok_id,predicted_tok_id,current_tok,actual_tok,predicted_tok,sourcehs_logitlens_tok_id,sourcehs_logitlens_tok,...,top_2_tok,top_3_prob,top_3_tok_id,top_3_tok,top_4_prob,top_4_tok_id,top_4_tok,top_5_prob,top_5_tok_id,top_5_tok
94622,94622,170,29871,1,1,,,,21275,penas,...,\n,0.00024,29897,),0.00018,29889,.,0.000175,29901,:
94623,94623,170,13,29871,29871,\n,,,24366,sierp,...,:,0.000259,11474,---,0.000259,6211,Bay,0.000157,17435,***
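Instead of looping over document ids, per-document accuracy for every document can be computed with a single `groupby`, which also makes it easy to sort for the best and worst documents. A minimal sketch, assuming the results columns `doc_id`, `predicted_tok_id`, and `actual_tok_id`; the helper name `per_doc_accuracy` and the toy frame are hypothetical. It compares token ids, whereas the loop above compares token strings.

```python
import pandas as pd

def per_doc_accuracy(df):
    """Percent of positions per doc_id where the predicted id equals the actual id."""
    correct = df["predicted_tok_id"] == df["actual_tok_id"]
    return (100 * correct.groupby(df["doc_id"]).mean()).round(2)

toy = pd.DataFrame({
    "doc_id":           [0, 0, 1, 2, 2, 2],
    "actual_tok_id":    [5, 6, 7, 8, 9, 10],
    "predicted_tok_id": [5, 0, 7, 0, 0, 10],
})
acc = per_doc_accuracy(toy)
print(acc.sort_values())                 # worst documents first
print(acc[acc == 100].index.tolist())    # documents predicted completely right: [1]
```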


In [54]:
begin_span = tokenizer.encode('<span>', bos=False, eos=False)
end_span = tokenizer.encode('</span>', bos=False, eos=False)
print(begin_span, end_span)

[529, 9653, 29958] [1533, 9653, 29958]


In [55]:
# Functions that help interpret the predictions made by this probe for a specific document.
def doc_accuracy(did):
    specific_doc = outputs.loc[outputs['doc_id']==did]
    corr = len(specific_doc.loc[specific_doc['predicted_tok_id']==specific_doc['actual_tok_id']])
    return round(100 * corr / len(specific_doc), 2)

def doc_bigrams(did):
    out = []
    specific_doc = outputs.loc[outputs['doc_id']==did]
    for i, row in specific_doc.iterrows():
        if row['predicted_tok_id'] == row['actual_tok_id']:
            print(repr(row['predicted_tok']), repr(row['current_tok']))
            out.append((row['predicted_tok'], row['current_tok']))
    return out     

def view_doc(did, as_html=False):
    """Decode document `did`, marking every (previous, current) bigram whose previous token
    the probe recovered correctly. Consecutive correct bigrams are merged into one span."""
    specific_doc = outputs.loc[outputs['doc_id']==did]
    out = [1]  # BOS
    for _, row in specific_doc.iterrows():
        if row['predicted_tok_id'] == row['actual_tok_id']:
            if as_html:
                if out[-3:] != end_span:
                    # start a new span: drop the previous token and re-add it inside <span>
                    out = out[:-1]
                    out += [*begin_span, int(row['predicted_tok_id']), int(row['current_tok_id']), *end_span]
                else: # target token was already part of a span: drop it and </span>, then extend the span
                    out = out[:-4]
                    out += [int(row['predicted_tok_id']), int(row['current_tok_id']), *end_span]

            else: # not as html; 518 and 4514 appear to be the ids for ' [' and ' ]' (cf. the output below)
                if out[-1] != 4514:
                    out = out[:-1]
                    out += [518, int(row['predicted_tok_id']), int(row['current_tok_id']), 4514]
                else: # target token was already part of a span: drop it and ' ]', then extend the span
                    out = out[:-2]
                    out += [int(row['predicted_tok_id']), int(row['current_tok_id']), 4514]

        else: # don't need to bracket
            out.append(int(row['current_tok_id']))
    return tokenizer.decode(out)

view_doc(6)

'[ 1. Field ] [ of the In ]vention [\nThis ] invention relates to peak respiratory flow monitoring [. More ] specifically [, this ] invention relates to peak flow monitoring [ of individuals ] with obstructive respiratory diseases [ and other ] conditions [ and includes ] apparatus [ for, and methods ] [ of, using ] this technique [ in the home ] [ and office ] [ as well ] as enhancing patient [ and phys ]ician access [ to calculated ] medical data computed therein [.\n2. Stat ]ement [ of the Art ] [\nPe ]ak respiratory flow monitoring [ of patients ] with obstructive respiratory diseases or conditions [, such ] as asthma [, has been ] available [ for many ] years [. I ]nexpensive mechanical peak flow meters have been used for patient home [ and office ] use [ and many ] patients [ have been ] taught [ to record ] their daily peak flow values [ and their ] symptoms [ in a personal ] diary [. The ] physician then reviews [ the entries ] [ in a patient ]"" [s di ]ary during regular offic
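The bracketing rule in `view_doc` may be easier to follow on plain strings. A minimal sketch, assuming `tokens[i]` is the token at position `i` and `correct[i]` is True when the probe at position `i` recovered `tokens[i - 1]` (target index -1). Each correct position brackets the bigram `(tokens[i-1], tokens[i])`; consecutive correct positions share a token, so their bigrams merge into one span. The function name is hypothetical, and it joins tokens with spaces purely for readability.

```python
def bracket_correct_bigrams(tokens, correct, left="[", right="]"):
    """Wrap runs of correctly recovered bigrams in brackets."""
    pieces = []
    span_open = False
    for i, tok in enumerate(tokens):
        if i > 0 and correct[i]:
            if not span_open:
                # open a span just before the previous token
                prev = pieces.pop()
                pieces += [left, prev]
                span_open = True
            pieces.append(tok)  # extend the current span
        else:
            if span_open:
                pieces.append(right)
                span_open = False
            pieces.append(tok)
    if span_open:
        pieces.append(right)
    return " ".join(pieces)

toks = ["<s>", "the", "cat", "sat", "on", "the", "mat"]
hits = [False, False, True, True, False, True, False]
print(bracket_correct_bigrams(toks, hits))  # <s> [ the cat sat ] [ on the ] mat
```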

In [56]:
from IPython.display import HTML, display

def html_explore(doc_id, dark=True):
    span = "#4f4f4f" if dark else "#afdeb2"
    hover = "green" if dark else "#3fd449"

    html_code = f"""
    <!DOCTYPE html>
    <html>
    <head>
        <style>
            #highlight-paragraph {{
            font-size: 18px;
            }}

            #highlight-paragraph span {{
            background-color: {span};
            transition: background-color 0.3s;
            }}

            #highlight-paragraph span:hover {{
            background-color: {hover}; /* Change the background color on hover */
            }}
        </style>
    </head>
    <body>
        <h1>Document ID {doc_id} (Probe Acc: {doc_accuracy(doc_id)}%)</h1>
        <h3>{wandb_id}</h3>
        <div id="highlight-paragraph">
            {view_doc(doc_id, as_html=True)}
        </div>

    </body>
    </html>
    """
    print(html_code)
    display(HTML(html_code))
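`view_doc(..., as_html=True)` inserts the decoded document text directly into the HTML, so any `<`, `>`, or `&` in the document itself would be interpreted as markup. A minimal sketch of escaping each segment before wrapping it, assuming the document has already been split into hypothetical `(text, highlighted)` segments, which `view_doc` does not currently produce:

```python
import html

def spans_to_html(segments):
    """Join (text, highlighted) segments, escaping text and wrapping highlighted parts in <span>."""
    parts = []
    for text, highlighted in segments:
        safe = html.escape(text)  # escapes &, <, >, and quotes
        parts.append(f"<span>{safe}</span>" if highlighted else safe)
    return "".join(parts)

print(spans_to_html([("a < b", False), (" & c", True)]))  # a &lt; b<span> &amp; c</span>
```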

In [57]:
html_explore(6, dark=False)


    <!DOCTYPE html>
    <html>
    <head>
        <style>
            #highlight-paragraph {
            font-size: 18px;
            }

            #highlight-paragraph span {
            background-color: #afdeb2;
            transition: background-color 0.3s;
            }

            #highlight-paragraph span:hover {
            background-color: #3fd449; /* Change the background color on hover */
            }
        </style>
    </head>
    <body>
        <h1>Document ID 6 (Probe Acc: 29.82%)</h1>
        <h3>LAYER3-TGTIDX-1-train_small_5000-bsz1-lr0.00608-epochs74-ak6qyq6v</h3>
        <div id="highlight-paragraph">
            <span> 1. Field </span> <span> of the In </span>vention <span>
This </span> invention relates to peak respiratory flow monitoring <span>. More </span> specifically <span>, this </span> invention relates to peak flow monitoring <span> of individuals </span> with obstructive respiratory diseases <span> and other </span> conditions <span> and includes </s

In [58]:
html_explore(2, dark=False)


    <!DOCTYPE html>
    <html>
    <head>
        <style>
            #highlight-paragraph {
            font-size: 18px;
            }

            #highlight-paragraph span {
            background-color: #afdeb2;
            transition: background-color 0.3s;
            }

            #highlight-paragraph span:hover {
            background-color: #3fd449; /* Change the background color on hover */
            }
        </style>
    </head>
    <body>
        <h1>Document ID 2 (Probe Acc: 37.2%)</h1>
        <h3>LAYER3-TGTIDX-1-train_small_5000-bsz1-lr0.00608-epochs74-ak6qyq6v</h3>
        <div id="highlight-paragraph">
            <span> Q:

Model </span> <span>-View </span> <span>-Controller </span> <span> in JavaScript </span> <span>

tl </span>;dr <span>: How </span> does one implement MVC <span> in JavaScript </span> <span> in a clean </span> way <span>?
I </span> <span>'m </span> trying <span> to implement </span> MVC <span> in JavaScript </span> <span>. I have </span> googl

In [59]:
html_explore(316, dark=False)


    <!DOCTYPE html>
    <html>
    <head>
        <style>
            #highlight-paragraph {
            font-size: 18px;
            }

            #highlight-paragraph span {
            background-color: #afdeb2;
            transition: background-color 0.3s;
            }

            #highlight-paragraph span:hover {
            background-color: #3fd449; /* Change the background color on hover */
            }
        </style>
    </head>
    <body>
        <h1>Document ID 316 (Probe Acc: 31.47%)</h1>
        <h3>LAYER3-TGTIDX-1-train_small_5000-bsz1-lr0.00608-epochs74-ak6qyq6v</h3>
        <div id="highlight-paragraph">
            <span> * </span>Order Now* <span>

About </span> <span> the Author </span> <span>

Tom </span> Sileo <span> is co </span> <span>-author </span> <span> of 8 </span> SECONDS OF COURAGE <span> (Sim </span>on & Schuster <span>, 2017 </span>), FIRE IN MY EYES <span> (Da </span> Capo <span>, 2016 </span> <span>) and B </span>ROTHERS FOREVER <span> (Da </sp

In [60]:
# This probe also seems to do something interesting with entities: it recognizes "U.S. Army" here as a chain of connected bigrams. Is that just a common n-gram?
# outputs.loc[(outputs['doc_id']==316) & (outputs['current_tok']=='Army')][['doc_id', 'current_tok', 'actual_tok', 'top_1_tok', 'top_2_tok', 'top_3_tok', 'top_4_tok', 'top_5_tok']]
outputs.iloc[[172511, 172512, 172513, 172514, 172515]][['doc_id', 'current_tok', 'actual_tok', 'top_1_tok', 'top_2_tok', 'top_3_tok', 'top_4_tok', 'top_5_tok']]

Unnamed: 0,doc_id,current_tok,actual_tok,top_1_tok,top_2_tok,top_3_tok,top_4_tok,top_5_tok
172511,316,U,",""",",",.,the,that,\n
172512,316,.,U,U,.,",",it,3
172513,316,S,.,.,",",S,the,U
172514,316,.,S,S,.,U,Massachusetts,s
172515,316,Army,.,.,s,",",of,the
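One way to check whether a chain like "U.S. Army" is just a common n-gram is to count how often each bigram occurs in a token sequence. A minimal sketch using `collections.Counter`; the toy token list is illustrative, on real data the tokens would come from the training set, and this is not the code in ngrams.py.

```python
from collections import Counter

def bigram_counts(tokens):
    """Count each adjacent (previous, current) token pair."""
    return Counter(zip(tokens, tokens[1:]))

toks = ["the", "U", ".", "S", ".", "Army", "and", "the", "U", ".", "S", ".", "Navy"]
counts = bigram_counts(toks)
print(counts[("U", ".")])     # 2
print(counts[(".", "Army")])  # 1
```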


In [61]:
# Double-check that, for entities the probe misses 100% of the time, the correct token isn't anywhere in the top 5.
# It also looks like the hidden state at "Wars" might now encode the entity itself (the movies, etc.) rather than the previous token "Star". Maybe that information was deleted?
outputs.loc[(outputs['current_tok']=="Wars")][['current_tok', 'actual_tok', 'top_1_tok', 'top_2_tok', 'top_3_tok', 'top_4_tok', 'top_5_tok']]

Unnamed: 0,current_tok,actual_tok,top_1_tok,top_2_tok,top_3_tok,top_4_tok,top_5_tok
70128,Wars,Star,of,a,in,the,s
70162,Wars,Star,of,a,the,in,","
70248,Wars,Star,of,",",in,.,a
70405,Wars,Star,of,in,the,a,","
70454,Wars,Star,of,the,in,a,","
70492,Wars,Star,.,of,",",and,s
70603,Wars,Star,.,The,of,",",that
70779,Wars,Star,of,a,the,in,.
70839,Wars,Star,of,the,in,a,.
70880,Wars,Star,.,The,of,",",that
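To find bigrams like "Star Wars" that the probe misses every time, one option is to aggregate accuracy per `(actual_tok, current_tok)` pair. A minimal sketch, assuming the results columns used above; the helper name `bigram_accuracy`, the `min_count` threshold, and the toy frame are illustrative.

```python
import pandas as pd

def bigram_accuracy(df, min_count=5):
    """Accuracy and frequency of each (actual_tok, current_tok) bigram seen at least min_count times."""
    correct = df["predicted_tok_id"] == df["actual_tok_id"]
    stats = (
        correct.groupby([df["actual_tok"], df["current_tok"]])
        .agg(["mean", "count"])
        .rename(columns={"mean": "accuracy"})
    )
    return stats[stats["count"] >= min_count].sort_values("count", ascending=False)

toy = pd.DataFrame({
    "actual_tok":       ["Star", "Star", "Star", "U", "U"],
    "current_tok":      ["Wars", "Wars", "Wars", ".", "."],
    "actual_tok_id":    [1, 1, 1, 2, 2],
    "predicted_tok_id": [9, 9, 9, 2, 9],
})
res = bigram_accuracy(toy, min_count=2)
never = res[res["accuracy"] == 0]
print(never.index.tolist())  # [('Star', 'Wars')]
```

On the real data, `bigram_accuracy(outputs).query("accuracy == 0")` would list frequent bigrams the probe never gets right.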


In [62]:
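# For one instance of "sedan", the linear probe predicted "car"! (Note: "car" does not appear in the top-5 below for the layer-3 probe loaded here, so this observation may come from a different run.)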
outputs.loc[(outputs['current_tok']=="ans") & (outputs['actual_tok']=="sed")][['current_tok', 'actual_tok', 'top_1_tok', 'top_2_tok', 'top_3_tok', 'top_4_tok', 'top_5_tok']]

Unnamed: 0,current_tok,actual_tok,top_1_tok,top_2_tok,top_3_tok,top_4_tok,top_5_tok
5442,ans,sed,s,a,-,2,ing


In [63]:
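# Whoa: for the "ans" in "Republicans", it's predicting "Republic" but also "Democr", "House", "campaign"... crazy! (None of these appear in the top-5 below for the layer-3 probe loaded here, so this observation may come from a different run.)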
outputs.loc[(outputs['current_tok']=="ans") & (outputs['actual_tok']=="Republic")][['current_tok', 'actual_tok', 'top_1_tok', 'top_2_tok', 'top_3_tok', 'top_4_tok', 'top_5_tok']]

Unnamed: 0,current_tok,actual_tok,top_1_tok,top_2_tok,top_3_tok,top_4_tok,top_5_tok
30465,ans,Republic,that,s,of,the,ing
44272,ans,Republic,that,The,s,of,","
44444,ans,Republic,0,of,4,3,7
44476,ans,Republic,0,3,4,of,s
44525,ans,Republic,that,of,",",the,and
44590,ans,Republic,The,of,that,s,’
45023,ans,Republic,of,the,to,that,for
93654,ans,Republic,that,of,The,s,the
161943,ans,Republic,.,that,",",The,s
161999,ans,Republic,for,of,to,the,a


In [64]:
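# What about a boring one that it ALWAYS missed?
# I think it always missed because there are actually two token ids for 'This', and I calculate accuracy based on token id, not token string?
# (Though none of the top-5 tokens below is 'This', so the two-id issue doesn't seem to explain these particular misses.)
print(tokenizer.encode('This', bos=False, eos=False))
# 'Ġ' is GPT-2's byte-level BPE space marker, not SentencePiece's '▁', so it isn't read as a space here; this still surfaces the other 'This' id (4013).
print(tokenizer.encode('ĠThis', bos=False, eos=False))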
print(tokenizer.decode([4013]), tokenizer.decode([910]))
print(tokenizer.decode([4013]) == tokenizer.decode([910]))
outputs.loc[(outputs['current_tok']=="is") & (outputs['actual_tok']=="This") & (outputs['predicted_tok_id'] != outputs['actual_tok_id'])][['current_tok', 'actual_tok', 'predicted_tok', 'top_2_tok', 'top_3_tok', 'top_4_tok', 'top_5_tok']]

[910]
[29871, 31937, 4013]
This This
True


Unnamed: 0,current_tok,actual_tok,predicted_tok,top_2_tok,top_3_tok,top_4_tok,top_5_tok
5307,is,This,it,’,',this,is
8402,is,This,it,’,is,this,'
12886,is,This,\n,’,this,',it
19147,is,This,it,’,',.,is
25412,is,This,it,’,',is,this
58803,is,This,it,’,',this,is
60057,is,This,it,’,',is,s
64473,is,This,',\n,’,it,is
71206,is,This,\n,this,’,it,'
72149,is,This,\n,’,it,is,'
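If accuracy should count two ids with the same surface form (like 910 and 4013, which both decode to 'This' above) as a match, one option is to compare SentencePiece pieces with the `▁` space marker stripped. A minimal sketch; the helper names are hypothetical, the pieces for 910 and 4013 are assumed rather than read from the tokenizer, and 12345 is a made-up id.

```python
SPIECE_SPACE = "\u2581"  # '▁', SentencePiece's word-boundary marker

def surface(piece):
    """Surface string of a SentencePiece piece, ignoring the leading-space marker."""
    return piece.replace(SPIECE_SPACE, " ").strip()

def same_surface(id_a, id_b, id_to_piece):
    """True if two token ids spell the same word once the space marker is ignored."""
    return surface(id_to_piece(id_a)) == surface(id_to_piece(id_b))

# toy lookup: pieces for 910/4013 are assumed; 12345 is a hypothetical id
toy_pieces = {910: "\u2581This", 4013: "This", 12345: "\u2581it"}
print(same_surface(910, 4013, toy_pieces.get))   # True
print(same_surface(910, 12345, toy_pieces.get))  # False
```

With the real tokenizer, `id_to_piece` might be `tokenizer.sp_model.id_to_piece`, assuming the `llama` `Tokenizer` keeps its `SentencePieceProcessor` in `sp_model` (as Meta's reference `tokenizer.py` does). This deliberately conflates word-initial and word-internal pieces, which may or may not be what the accuracy metric should do.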


In [65]:
# TODO: add functionality to load the probe itself, along with the model, and run it on a custom piece of text.
probe = LinearModel(4096, 32000) # hardcoded model_dim and vocab size for Llama-2
ckpt_path = retrieve_run_info("LAYER1-TGTIDX-1-train_quick_5-bsz512-lr0.00100-epochs1-kt9ne44m", "../checkpoints/llama-2-7b", "final.ckpt")
probe.load_state_dict(torch.load(ckpt_path, map_location=torch.device('cpu')))

../checkpoints/llama-2-7b/LAYER1-TGTIDX-1-train_quick_5-bsz512-lr0.00100-epochs1-kt9ne44m


<All keys matched successfully>
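Toward the TODO above: once a hidden state is available, reading off the probe's top-k tokens is a softmax plus `torch.topk`. A minimal sketch, assuming `LinearModel` behaves like a `torch.nn.Module` that maps a `(..., 4096)` tensor to vocabulary logits. If it returns log-probabilities instead, the softmax gives the same result, and the top-k ids are unchanged either way. Getting hidden states out of Llama-2 is not shown, and `probe_topk` is a hypothetical helper.

```python
import torch

@torch.no_grad()
def probe_topk(probe, hidden, k=5):
    """Return (values, indices) of the k most probable tokens for each hidden state.

    hidden: float tensor of shape (..., model_dim), e.g. one layer's hidden states.
    """
    logits = probe(hidden)
    probs = torch.softmax(logits, dim=-1)
    return torch.topk(probs, k, dim=-1)

# stand-in probe with a tiny model_dim/vocab, just to exercise the function
torch.manual_seed(0)
toy_probe = torch.nn.Linear(8, 20)
top = probe_topk(toy_probe, torch.randn(3, 8))
print(top.values.shape, top.indices.shape)
```

With the real probe, `[tokenizer.decode([i]) for i in probe_topk(probe, h).indices.tolist()]` would decode the top-5 for a single hidden state `h` of shape `(4096,)`.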