In [1]:
import re
import json
import pickle
import os
import sys
import requests
import logging
import torch
from tqdm.auto import tqdm
from transformer_lens import HookedTransformer
from tqdm.auto import tqdm
import plotly.io as pio
import numpy as np
import random
import torch.nn as nn
import torch.nn.functional as F
import wandb
import plotly.express as px
import pandas as pd
import torch.nn.init as init
from pathlib import Path
from jaxtyping import Int, Float
from torch import Tensor
import einops
from collections import Counter
from datasets import load_dataset
import pandas as pd
from ipywidgets import interact, IntSlider
from process_tiny_stories_data import load_tinystories_validation_prompts, load_tinystories_tokens
from typing import Literal


pio.renderers.default = "notebook_connected"
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
torch.autograd.set_grad_enabled(False)
torch.set_grad_enabled(False)

logging.basicConfig(format='(%(levelname)s) %(asctime)s: %(message)s', level=logging.INFO, datefmt='%I:%M:%S')
sys.path.append('../')  # Add the parent directory to the system path

import utils.haystack_utils as haystack_utils
from sparse_coding.train_autoencoder import AutoEncoder
from utils.autoencoder_utils import custom_forward, AutoEncoderConfig, evaluate_autoencoder_reconstruction, get_encoder_feature_frequencies, load_encoder, get_acts
import utils.haystack_utils as haystack_utils
from utils.plotting_utils import line
from sparse_coding.spacy_tag import make_spacy_feature_df

from utils.probing_utils import train_probe
import utils.probing_utils as probing_utils
from sklearn import preprocessing
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score


%reload_ext autoreload
%autoreload 2

In [2]:
import subprocess
subprocess.run(['python', '-m', 'spacy', 'download', 'en_core_web_trf'])

Collecting en-core-web-trf==3.7.3
  Downloading https://github.com/explosion/spacy-models/releases/download/en_core_web_trf-3.7.3/en_core_web_trf-3.7.3-py3-none-any.whl (457.4 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m457.4/457.4 MB[0m [31m11.7 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m


[0m

[38;5;2m✔ Download and installation successful[0m
You can now load the package via spacy.load('en_core_web_trf')


CompletedProcess(args=['python', '-m', 'spacy', 'download', 'en_core_web_trf'], returncode=0)

In [3]:
haystack_utils.clean_cache()
os.environ["PYTORCH_MPS_HIGH_WATERMARK_RATIO"] = "0.2"

model_name = "tiny-stories-2L-33M"
print_name = "TinyStories 2L 33M"

model = HookedTransformer.from_pretrained(
    model_name,
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    device=device,
)

n_prompts = 40
prompts = load_tinystories_validation_prompts(data_path='data/tinystories')[:n_prompts]
tokens = model.to_tokens(prompts)
print(tokens.shape)
try:
    df = make_spacy_feature_df(model, tokens)
except:
    try:
        df = make_spacy_feature_df(model, tokens)
    except Exception as e:
        print(e)

(INFO) 08:09:01: Loaded 21990 TinyStories validation prompts


Loaded pretrained model tiny-stories-2L-33M into HookedTransformer
torch.Size([40, 304])
Starting spacy processing of dataset...
Finished spacy processing of dataset.
1701418166.560078 0


In [4]:
df[["is_spacy_adj"]]
print(tokens.shape)

torch.Size([40, 304])


In [5]:
# look for encoder features that go on particular spacy attributes
save_name = '18_morning_sun'
encoder, cfg = load_encoder(save_name, model_name, model, save_path='/workspace')

# acts = []
# for i in range(len(tokens)):
#     acts.append(get_acts(tokens[i], model, encoder, cfg))
# acts = torch.cat(acts).cpu()

# threshold = 0.1
# f1_scores = {}
# for series_name, series in tqdm(df.items()):
#     neuron_binarized = (acts > threshold).T
#     for i in range(len(neuron_binarized)):
#         f1_scores[(series_name, i)] = f1_score(series, neuron_binarized[i])

# new_f1_scores = {}
# for key, value in f1_scores.items():
#     col, direction = key
#     if col not in new_f1_scores:
#         new_f1_scores[col] = {}
#     new_f1_scores[col][direction] = value

# with open('/workspace/data/spacy_f1s_2.json', 'w') as f:
#     json.dump(new_f1_scores, f)
with open('/workspace/data/spacy_f1s.json', 'r') as f:
    new_f1_scores = json.load(f)

In [6]:
list(list(new_f1_scores.items())[0][1].items())[:5]

[('0', 0.0), ('1', 0.0), ('2', 0.0), ('3', 0.0), ('4', 0.012048192771084336)]

In [7]:
from collections import defaultdict
good_uns = defaultdict(list)
interesting_directions = []
for col, items in new_f1_scores.items():
    # there are hundred of punctuation dirs and they're probably less interesting
    if col == "is_spacy_punct":
        continue
    for direction, f1 in items.items():
        if f1 > 0.4:
            interesting_directions.append(direction)
            good_uns[col].append(direction)


# del good_uns["is_spacy_punct"]

In [8]:

good_uns
cols_with_dirs = list((col, [int(dir) for dir in dirs]) for col, dirs in good_uns.items())
print(len(cols_with_dirs))
print(len(set(interesting_directions)))

31
66


In [16]:
# Round two of what we just did above, but now only collecting acts data for the positive and negative classes of the direction/spacy attribute tuples we are interested in.
haystack_utils.clean_cache()
n_prompts = 200
prompts = load_tinystories_validation_prompts(data_path='data/tinystories')[:n_prompts]
tokens = model.to_tokens(prompts)
df = make_spacy_feature_df(model, tokens, use_tqdm=True)

def train_probe(
    positive_data: torch.Tensor, negative_data: torch.Tensor
) -> tuple[float, float]:
    labels = np.concatenate([np.ones(len(positive_data)), np.zeros(len(negative_data))])
    data = np.concatenate([positive_data.cpu().numpy(), negative_data.cpu().numpy()])
    scaler = preprocessing.StandardScaler().fit(data)
    data = scaler.transform(data)
    x_train, x_test, y_train, y_test = train_test_split(
        data, labels, test_size=0.2, random_state=42
    )
    probe = probing_utils.get_probe(x_train, y_train, max_iter=2000)
    f1, mcc = probing_utils.get_probe_score(probe, x_test, y_test)
    return f1, mcc

Starting spacy processing of dataset...


0it [00:00, ?it/s]

Finished spacy processing of dataset.


0it [00:00, ?it/s]

1701420222.20287 0


In [17]:
interesting_directions_ints = [int(dir) for dir in interesting_directions]


In [20]:
acts = []
for i in range(len(tokens)):
    act = get_acts(tokens[i], model, encoder, cfg)
    acts.append(act[:, interesting_directions_ints])

acts = torch.cat(acts, dim=0).cpu() # batch d_interesting
# compare acts and spacy annotations to get data

  0%|          | 0/31 [00:00<?, ?it/s]

0         False
1         False
2         False
3         False
4         False
          ...  
209195    False
209196    False
209197    False
209198    False
209199    False
Name: is_spacy_det, Length: 209200, dtype: bool


KeyError: tensor([[False,  True,  True,  ..., False, False, False],
        [False,  True,  True,  ..., False, False, False],
        [False,  True,  True,  ..., False, False, False],
        ...,
        [False,  True,  True,  ..., False, False, False],
        [False,  True,  True,  ..., False, False, False],
        [False,  True,  True,  ...,  True,  True,  True]], device='cuda:0')

In [35]:
flattened_tokens = tokens.flatten().cpu()


In [36]:
f1s = []
mccs = []
for col, dirs in tqdm(cols_with_dirs):
    for dir in dirs:
        token_attributes = df[[col]].squeeze(1)
        attr_tensor = torch.tensor(token_attributes)[flattened_tokens != 50256] # 1 batch
        dir_acts = acts[flattened_tokens != 50256, interesting_directions_ints.index(dir)]

        pos_class = dir_acts[attr_tensor == True][:10_000]
        neg_class = dir_acts[attr_tensor == False][:10_000]
        print(f"{len(pos_class)} positive class activations, {len(neg_class)} negative class activations")
        f1, mcc = train_probe(
            pos_class.unsqueeze(-1),
            neg_class.unsqueeze(-1),
        )
        f1s.append(f1)
        mccs.append(mcc)

  0%|          | 0/31 [00:00<?, ?it/s]

2738 positive class activations, 10000 negative class activations
2207 positive class activations, 10000 negative class activations
1156 positive class activations, 10000 negative class activations
1156 positive class activations, 10000 negative class activations
1156 positive class activations, 10000 negative class activations
1156 positive class activations, 10000 negative class activations
1156 positive class activations, 10000 negative class activations
214 positive class activations, 10000 negative class activations
214 positive class activations, 10000 negative class activations
214 positive class activations, 10000 negative class activations
214 positive class activations, 10000 negative class activations
214 positive class activations, 10000 negative class activations
214 positive class activations, 10000 negative class activations
214 positive class activations, 10000 negative class activations
214 positive class activations, 10000 negative class activations
214 positive class

In [38]:
print(f1s)

attrs = []
for col, dirs in tqdm(cols_with_dirs):
    for dir in dirs:
        attrs.append(col)

[0.4621621621621621, 0.42214532871972316, 0.622093023255814, 0.5934718100890208, 0.4147157190635452, 0.7238605898123326, 0.4252491694352159, 0.6666666666666666, 0.7246376811594203, 0.5806451612903226, 0.5806451612903226, 0.7777777777777778, 0.7428571428571429, 0.36363636363636365, 0.8108108108108109, 0.7945205479452054, 0.456140350877193, 0.4827586206896552, 0.6602316602316602, 0.3316062176165803, 0.3316062176165803, 0.2631578947368421, 0.16216216216216214, 0.25641025641025644, 0.16666666666666669, 0.6518105849582172, 0.5971014492753624, 0.7225130890052355, 0.16, 0.6499032882011605, 0.6774193548387097, 0.7384615384615384, 0.6333333333333334, 0.5090909090909091, 0.7941176470588235, 0.7575757575757575, 0.39215686274509803, 0.8000000000000002, 0.8115942028985507, 0.4230769230769231, 0.36, 0.49122807017543857, 0.39999999999999997, 0.4794520547945206, 0.1951219512195122, 0.41958041958041964, 0.5263157894736842, 0.3357664233576642, 0.0, 0.47457627118644075, 0.42857142857142855, 0.83544303797

  0%|          | 0/31 [00:00<?, ?it/s]

In [41]:
with open('/workspace/data/spacy_summary_stats.json', 'w') as f:
    json.dump({
        "f1s": f1s,
        "mccs": mccs,
        "dirs": interesting_directions,
        "attrs": attrs
    }, f, indent=4)

In [40]:
os.symlink('/workspace', 'workspace')

In [46]:
for i, f1 in enumerate(f1s):
    if f1 > 0.75:
        print(attrs[i])


# explanations of attributes

import spacy
nlp = spacy.load("en_core_web_trf")

for component in nlp.pipe_names:
    if component == "parser":
        tags = nlp.pipe_labels[component]
        if len(tags)!=0:
            print(f"Label mapping for component: {component}")
            display(dict(list(zip(tags, [spacy.explain(tag) for tag in tags]))))

is_spacy_num
is_spacy_num
is_spacy_num
is_spacy_nummod
is_spacy_nummod
is_spacy_nummod
is_spacy_nummod
is_spacy_acl
is_spacy_acl
is_spacy_acl
is_perf_aspect
is_perf_aspect
is_perf_aspect
is_ORDINAL
Map labels:


Label mapping for component: parser



[W118] Term 'predet' not found in glossary. It may however be explained in documentation for the corpora used to train the language. Please check `nlp.meta["sources"]` for any relevant links.



{'ROOT': 'root',
 'acl': 'clausal modifier of noun (adjectival clause)',
 'acomp': 'adjectival complement',
 'advcl': 'adverbial clause modifier',
 'advmod': 'adverbial modifier',
 'agent': 'agent',
 'amod': 'adjectival modifier',
 'appos': 'appositional modifier',
 'attr': 'attribute',
 'aux': 'auxiliary',
 'auxpass': 'auxiliary (passive)',
 'case': 'case marking',
 'cc': 'coordinating conjunction',
 'ccomp': 'clausal complement',
 'compound': 'compound',
 'conj': 'conjunct',
 'csubj': 'clausal subject',
 'csubjpass': 'clausal subject (passive)',
 'dative': 'dative',
 'dep': 'unclassified dependent',
 'det': 'determiner',
 'dobj': 'direct object',
 'expl': 'expletive',
 'intj': 'interjection',
 'mark': 'marker',
 'meta': 'meta modifier',
 'neg': 'negation modifier',
 'nmod': 'modifier of nominal',
 'npadvmod': 'noun phrase as adverbial modifier',
 'nsubj': 'nominal subject',
 'nsubjpass': 'nominal subject (passive)',
 'nummod': 'numeric modifier',
 'oprd': 'object predicate',
 'parata