In [1]:
from nnsight import LanguageModel
import pandas as pd
import torch as t
import torch.nn as nn

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
device = 'cuda:0'

model = LanguageModel('EleutherAI/pythia-70m-deduped', device_map=device)

data = pd.read_json('/share/data/datasets/msgs/syntactic_category_lexical_content_the/train.jsonl', lines=True)


In [3]:
data

Unnamed: 0,sentence,condition,linguistic_feature_label,surface_feature_label,UID,linguistic_feature_type,linguistic_feature_description,surface_feature_type,surface_feature_description,control_paradigm,sentenceID,paradigmID,split
0,All grandsons do resemble the print and Debra ...,training,1,1,syntactic_category_lexical_content_the,syntactic,Is there an adjective present?,lexical_content,"Is the word ""the"" present?",False,40000,5000,train
1,All grandsons do resemble each print and Debra...,training,0,0,syntactic_category_lexical_content_the,syntactic,Is there an adjective present?,lexical_content,"Is the word ""the"" present?",False,40001,5000,train
2,Each colleague isn't criticizing the analyses ...,training,1,1,syntactic_category_lexical_content_the,syntactic,Is there an adjective present?,lexical_content,"Is the word ""the"" present?",False,40008,5001,train
3,Each colleague isn't criticizing these analyse...,training,0,0,syntactic_category_lexical_content_the,syntactic,Is there an adjective present?,lexical_content,"Is the word ""the"" present?",False,40009,5001,train
4,Few governments hadn't sold several hospitals ...,training,1,1,syntactic_category_lexical_content_the,syntactic,Is there an adjective present?,lexical_content,"Is the word ""the"" present?",False,40016,5002,train
...,...,...,...,...,...,...,...,...,...,...,...,...,...
9995,Every associate caught some driver and a stude...,training,0,0,syntactic_category_lexical_content_the,syntactic,Is there an adjective present?,lexical_content,"Is the word ""the"" present?",False,79977,9997,train
9996,Some actresses drop by the hospitals and a bos...,training,1,1,syntactic_category_lexical_content_the,syntactic,Is there an adjective present?,lexical_content,"Is the word ""the"" present?",False,79984,9998,train
9997,Some actresses drop by fewer than three hospit...,training,0,0,syntactic_category_lexical_content_the,syntactic,Is there an adjective present?,lexical_content,"Is the word ""the"" present?",False,79985,9998,train
9998,Every senator purchases the shirts and a smart...,training,1,1,syntactic_category_lexical_content_the,syntactic,Is there an adjective present?,lexical_content,"Is the word ""the"" present?",False,79992,9999,train


In [4]:
class Probe(nn.Module):
    def __init__(self, activation_dim):
        super().__init__()
        self.net = nn.Linear(activation_dim, 1, bias=True)

    def forward(self, x):
        logits = self.net(x).squeeze(-1)
        softmaxed_logits = (t.softmax(logits, dim=-1) * logits).sum(dim=-1)
        return t.sigmoid(softmaxed_logits)

In [5]:
batch_size = 256
lr = 1e-2

probe = Probe(512).to(device)
optimizer = t.optim.Adam(probe.parameters(), lr=lr)
losses = []

for batch_idx in range(len(data) // batch_size):
    inputs = data['sentence'][batch_idx * batch_size:(batch_idx + 1) * batch_size].tolist()
    label = data['linguistic_feature_label'][batch_idx * batch_size:(batch_idx + 1) * batch_size].tolist()

    with model.invoke(inputs) as invoker:
        hidden_states = model.gpt_neox.layers[-3].output[0].save()
    
    optimizer.zero_grad()
    logits = probe(hidden_states.value.clone())
    loss = nn.BCELoss()(logits, t.Tensor(label).to(device))
    loss.backward()
    optimizer.step()

    losses.append(loss.item())


You're using a GPTNeoXTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


In [6]:
losses

[0.7080693244934082,
 0.7173736095428467,
 0.6817024946212769,
 0.6360753774642944,
 0.6241862773895264,
 0.5873124599456787,
 0.5238829851150513,
 0.4599684774875641,
 0.4041571617126465,
 0.33180248737335205,
 0.28367525339126587,
 0.21787819266319275,
 0.20481103658676147,
 0.17532525956630707,
 0.1501021683216095,
 0.12602679431438446,
 0.11127845197916031,
 0.09764689207077026,
 0.07033877074718475,
 0.08369386196136475,
 0.06978712975978851,
 0.057081498205661774,
 0.060696087777614594,
 0.06456602364778519,
 0.048354603350162506,
 0.04755610227584839,
 0.04461493343114853,
 0.05637466534972191,
 0.03905995935201645,
 0.028560271486639977,
 0.04731447994709015,
 0.030016174539923668,
 0.0315505675971508,
 0.02036595344543457,
 0.025361012667417526,
 0.023468855768442154,
 0.023788487538695335,
 0.02086271159350872,
 0.023363469168543816]

In [10]:
test_data = pd.read_json('/share/data/datasets/msgs/syntactic_category_lexical_content_the/test.jsonl', lines=True)

ling_accs, surface_accs = [], []
# get accuracy on test data
for batch_idx in range(len(test_data) // batch_size):
    inputs = test_data['sentence'][batch_idx * batch_size:(batch_idx + 1) * batch_size].tolist()
    ling_labels = test_data['linguistic_feature_label'][batch_idx * batch_size:(batch_idx + 1) * batch_size].tolist()
    surface_labels = test_data['surface_feature_label'][batch_idx * batch_size:(batch_idx + 1) * batch_size].tolist()

    with model.invoke(inputs) as invoker:
        hidden_states = model.gpt_neox.layers[-3].output[0].save()
    
    with t.no_grad():
        preds = probe(hidden_states.value)
        ling_acc = (preds.round() == t.Tensor(ling_labels).to('cuda:0')).float().mean()
        surface_acc = (preds.round() == t.Tensor(surface_labels).to('cuda:0')).float().mean()
        ling_accs.append(ling_acc.item())
        surface_accs.append(surface_acc.item())

print('ling acc:', sum(ling_accs) / len(ling_accs))
print('surface acc:', sum(surface_accs) / len(surface_accs))
    

ling acc: 0.7715010683760684
surface acc: 0.8924278846153846


In [11]:
from attribution import patching_effect
from dictionary_learning.dictionary import AutoEncoder

In [12]:
clean = "All grandsons do resemble the print and Debra is an organized child."
patch = "All grandsons do resemble a print and Debra is an organized child."

with model.invoke([clean, patch]) as invoker:
    hidden_states = model.gpt_neox.layers[-3].output[0].save()

with t.no_grad():
    preds = probe(hidden_states.value)
preds

tensor([0.9941, 0.4238], device='cuda:0')

In [15]:
def metric_fn(model):
    return probe(model.gpt_neox.layers[-3].output[0])

submodules = [
    model.gpt_neox.layers[i].mlp for i in range(4)
]
dictionaries = []
for i in range(len(submodules)):
    dictionary = AutoEncoder(512, 64 * 512).to(device)
    dictionary.load_state_dict(t.load(f'/share/projects/dictionary_circuits/autoencoders/pythia-70m-deduped/mlp_out_layer{i}/1_32768/ae.pt'))
    dictionaries.append(dictionary)

out = patching_effect(
    clean,
    patch,
    model,
    submodules,
    dictionaries,
    metric_fn,
)

In [16]:
effects, total_effect = out
print(f"Total effect: {total_effect}")
for layer, submodule in enumerate(submodules):
    print(f"Layer {layer}:")
    effect = effects[submodule]
    for feature_idx in t.nonzero(effect):
        value = effect[tuple(feature_idx)]
        if value > 0.005:
            print(f"    Multindex: {tuple(feature_idx.tolist())}, Value: {value}")

Total effect: tensor([-0.5703], device='cuda:0', grad_fn=<SubBackward0>)
Layer 0:
Layer 1:
Layer 2:
Layer 3:
