
Copyright 2024 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.


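# Post-Intervention

Compare the model's continuations on the keep set before and after perturbing a small set of weights selected by gradient-based localization.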

In [1]:
#@title Import libraries
import transformer_lens
import torch, gc, itertools, tqdm, scipy, copy, functools
import pandas as pd
import numpy as np

import seaborn as sns
import matplotlib as mpl
import matplotlib.pyplot as plt

mpl.rcParams['axes.spines.top'] = False

import os
import sys
## make the local `paraMem` package importable; the location can be overridden via $PARAMEM_ROOT
PARAMEM_ROOT = os.environ.get("PARAMEM_ROOT", "/home/jupyter/")
if PARAMEM_ROOT not in sys.path:  ## avoid piling up duplicate entries when the cell is re-run
    sys.path.append(PARAMEM_ROOT)
from paraMem.utils import modelHandlers, dataLoaders, gradient, intervening, localizing
modelHandlers.gpu_check()

Gb total: 15.6535, reserved: 0.0, allocated: 0.0


### Load Model

In [2]:
## TransformerLens model name; the model is loaded onto the CPU
model_type = "gpt-neo-125M"
model = modelHandlers.load_model(DEVICE="cpu", model_type=model_type)

Loaded pretrained model gpt-neo-125M into HookedTransformer


### Load Data

In [45]:
## Pile batches, unshuffled so the first batch is the same on every run.
## test_frac=0.0 presumably puts all data into train_dl (test_dl is unused below) — confirm
train_dl, test_dl = dataLoaders.batched_pile(
    mem_batch=5,
    non_mem_batch=10,
    test_frac=0.0,
    shuffle=False,
)
## alternative data source, kept for reference:
## dl = dataLoaders.batched_pop_seqs(model, mem_batch=5, non_mem_batch=10)
c_toks_NI, k_toks_NI = next(iter(train_dl))  ## first (c, k) batch; k_toks are the keep-set tokens

### Obtain Intervention Weights (computed from gradients on one training batch)

In [46]:
## parameter groups to score: attention Q/K/V/O and MLP in/out matrices
c_types = ["W_Q", "W_K", "W_V", "W_O", "W_in", "W_out"]
## forward/backward routine with its options fixed, handed to the collector below
fwd_bwd = functools.partial(
    gradient.run_fwd_bwd,
    after_I=0,
    with_mse=True,
    pool={"c": [-1], "k": [0, -1]},
    norm_sets=False,
)
## per-component scores over a single batch of train_dl (n_batches=1)
c_weights = localizing.batched_c_type_collection(
    model, train_dl, fwd_bwd, c_types=c_types, n_batches=1
)

0it [00:00, ?it/s]

contrast_res: 0.6557825803756714, c_nll: 0.6557825803756714, k_nll: 0.0


0it [00:05, ?it/s]

returning ['blocks.0.attn.W_Q', 'blocks.1.attn.W_Q']... of shape: torch.Size([12, 12, 768, 64])
returning ['blocks.0.attn.W_K', 'blocks.1.attn.W_K']... of shape: torch.Size([12, 12, 768, 64])
returning ['blocks.0.attn.W_V', 'blocks.1.attn.W_V']... of shape: torch.Size([12, 12, 768, 64])
returning ['blocks.0.attn.W_O', 'blocks.1.attn.W_O']... of shape: torch.Size([12, 12, 64, 768])
returning ['blocks.0.mlp.W_in', 'blocks.1.mlp.W_in']... of shape: torch.Size([12, 768, 3072])
returning ['blocks.0.mlp.W_out', 'blocks.1.mlp.W_out']... of shape: torch.Size([12, 3072, 768])





### Collect continuations on the keep set before any model intervention

In [57]:
def batch_decode(model, toks_NI:torch.LongTensor=None, dl=None, n_batch:int=50, start_at_tok:int=50, new_toks:int=50, do_sample:bool=False):
    """
    Generate continuations from `model` for prompts cut from token sequences.

    Every sequence is split at `start_at_tok`: tokens `[:start_at_tok]` form the prompt and
    tokens `[start_at_tok:start_at_tok + new_toks]` are the reference ("true") continuation.

    Args:
        model: model exposing `cfg.device` and a `generate(...)` method (HookedTransformer API).
        toks_NI: token ids, presumably (N, I); used only when `dl` is None, in which case
            every row is decoded as its own batch of one sequence.
        dl: iterable yielding `(c_toks, k_toks)` pairs; only `k_toks` is decoded. Leading dims
            of `k_toks` are flattened, so each batch becomes (B, I).
        n_batch: stop after this many batches.
        start_at_tok: prompt length in tokens.
        new_toks: number of tokens to generate per sequence.
        do_sample: forwarded to `model.generate`.

    Returns:
        (preds_NI, trues_NI): generated and reference continuations, each of shape
        (total_sequences, new_toks). Batches of different sizes are concatenated.

    Raises:
        ValueError: if neither `toks_NI` nor `dl` is given.
    """
    if dl is None:
        if toks_NI is None:
            raise ValueError("batch_decode needs either toks_NI or dl")
        dl = zip(toks_NI, toks_NI)  ## one row per batch; reshaped to (1, I) below
    preds_NI, trues_NI = [], []
    for batch_i, (_, k_toks_NI) in tqdm.tqdm(enumerate(dl)):
        ## detach, move to device, and flatten leading dims: (1, B, I) from a loader or (I,) from a row -> (B, I)
        toks = k_toks_NI.detach().to(model.cfg.device)
        toks = toks.reshape(-1, toks.shape[-1])
        toks_pref = toks[..., :start_at_tok]
        ## reference continuation = the span right after the prompt (not the last start_at_tok tokens)
        toks_true = toks[..., start_at_tok:start_at_tok + new_toks]
        toks_pred = model.generate(input=toks_pref, max_new_tokens=new_toks, stop_at_eos=True, eos_token_id=None, do_sample=do_sample, top_k=None, top_p=None, temperature=1.0)
        ## NOTE(review): with stop_at_eos=True, generate may return fewer than new_toks new tokens
        ## if every sequence hits EOS; this slice would then include prompt tokens — confirm
        toks_pred = toks_pred[..., -new_toks:]  ## keep only the continuation
        preds_NI.append(toks_pred)
        trues_NI.append(toks_true)

        if batch_i + 1 == n_batch:
            break  ## stop early if the loader holds more batches
    if not preds_NI:
        ## empty loader: return empty (0, new_toks) tensors instead of crashing in torch.cat
        empty = torch.empty((0, new_toks), dtype=torch.long, device=model.cfg.device)
        return empty, empty.clone()
    ## cat (not stack) so a smaller final batch is handled
    preds_NI = torch.cat(preds_NI, dim=0)
    trues_NI = torch.cat(trues_NI, dim=0)
    return preds_NI, trues_NI

## baseline: decode the first keep-set batch with the unmodified model, for comparison after the intervention
k_preds_orig, _ = batch_decode(model, dl=train_dl, toks_NI=None, n_batch=1)
model.to_string(k_preds_orig)

0it [00:00, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

0it [00:03, ?it/s]


[', and/or other securities and/or commodities and/or commodities and/or commodities and/or commodities and/or commodities and/or commodities and/or commodities and/or commodities and/or commodities and/or commodities and/or commodities',
 '}\n                                                ',
 ' for the purpose of this application. Please note that the reserved words are not to be used in any other application. **"\n            },\n           ',
 'alt.interwalt.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:141)\nat org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:235)\n',
 ' public:\n    S6(int v) : a(v) {}\n    S6(int v) : a(v) {}\n    S6(int v) : a(v) {}',
 '1-2).\n\nThe song is a reference to the song of the young man who is the father of the son who is the father of the son who is the father of the son who is the father of the son who is the father',
 'ofday|timeofyear)(\\.)?(\\/\\.(\\d{1,3}|\\d{1,3}|\\d{1,3}|\\d{1,3}|\\d{1

### Apply the Model Intervention

In [60]:
## Intervention config: select the topK weights (largest=True) among W_V of layer 2, head 11,
## then modify them by intervene_val with set_type="add" and as_noise=True.
topK, select_c, select_l, select_heads, intervene_val = 10, ["W_V"], [2], [11], 0.01
INTERV_SEED = 0  ## fixed seed so the (presumably random) noise is identical on every rerun
c_weights_lk = intervening.get_topK_grads(c_weights,topK=topK,select_c=select_c,select_l=select_l,select_heads=select_heads,return_lk=True,largest=True,select_random=False)
torch.manual_seed(INTERV_SEED)
np.random.seed(INTERV_SEED)
## NOTE(review): this cell's log includes "Loaded pretrained model ...", which suggests the intervention
## starts from a freshly loaded model (so reruns do not stack perturbations) — confirm.
model = intervening.intervene_params(model, c_weights_lk, set_val=intervene_val, as_noise=True, set_type="add")

torch.Size([12, 12, 768, 64])
Loaded pretrained model gpt-neo-125M into HookedTransformer
intervened with 0.01 on a total of 10 weights in ['W_V'] in layers {2}


In [61]:
## decode the same keep-set batch with the intervened model and compare it to the baseline continuations
k_pred_intervened, _ = batch_decode(model, toks_NI=None, dl=train_dl, n_batch=1)
orig_interv_em = modelHandlers.compute_exact_match(k_pred_intervened, k_preds_orig)
## NOTE(review): the recorded output of this print is ~26.3, i.e. not a fraction in [0, 1];
## compute_exact_match presumably returns a per-sequence count of matching tokens — confirm before reading it as a rate.
print(f"Exact match between original and intervened continuation of keep set: {orig_interv_em.float().mean().item()}")
## alternative evaluation (disabled): ck_em, ck_nll = intervening.evaluate_model(model, dl=train_dl)
## show the intervened continuations decoded in this cell
model.to_string(k_pred_intervened)

0it [00:00, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

0it [00:03, ?it/s]

Exact match between original and intervened continuation of keep set: 26.299999237060547





[', and/or other securities and/or commodities and/or commodities and/or commodities and/or commodities and/or commodities and/or commodities and/or commodities and/or commodities and/or commodities and/or commodities and/or commodities',
 '}\n                                                ',
 ' for the purpose of this application. Please note that the reserved words are not to be used in any other application. **"\n            },\n           ',
 'alt.interwalt.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:141)\nat org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:235)\n',
 ' public:\n    S6(int v) : a(v) {}\n    S6(int v) : a(v) {}\n    S6(int v) : a(v) {}',
 '1-2).\n\nThe song is a reference to the song of the young man who is the father of the son who is the father of the son who is the father of the son who is the father of the son who is the father',
 'ofday|timeofyear)(\\.)?(\\/\\.(\\d{1,3}|\\d{1,3}|\\d{1,3}|\\d{1,3}|\\d{1