In [1]:
from nnsight import LanguageModel

# Wrap GPT-2 XL in an nnsight LanguageModel so that module inputs/outputs can be
# read and overwritten inside `model.trace(...)` blocks further down.
# NOTE(review): `dispatch=True` presumably loads the weights right away and
# `device_map='auto'` lets the loader pick device placement - confirm against
# the nnsight / transformers documentation for the installed versions.
model = LanguageModel(
    "openai-community/gpt2-xl",
    device_map='auto',
    dispatch=True,
)

# Short alias for the model's tokenizer; the other cells use it to encode text.
tok = model.tokenizer




In [None]:
import torch

# Prompt under study. '<|endoftext|>' is written out in the text itself, so the
# token positions used below are counted with it at position 0.
string = '<|endoftext|>The Space Needle is in the city of'

# Vocabulary id of the first token produced when encoding ' Seattle'.
target = tok.encode(' Seattle', return_tensors='pt')[0][0]

# One noise vector per corrupted position (4 positions, width 1600), scaled by
# 3 * 0.0444...
# NOTE(review): 0.0444... is presumably the standard deviation of the model's
# token embeddings - confirm where this constant was measured.
noise = torch.randn(1,4,1600) * 3 * 0.044414032250642776


# Clean run: remember block 17's first output element and the probability the
# model assigns to `target` at the last position.
with model.trace(string):

    clean_state = model.transformer.h[17].output[0].save()

    clean_logits = model.lm_head.output.softmax(-1)[:,-1,target].save()

# Corrupted-then-restored run.
with model.trace(string):

    # Corrupt positions 1-4 by ADDING the noise to their embeddings. Plain
    # assignment would discard the embeddings altogether and replace them with
    # pure noise, which is a different (and much harsher) intervention than the
    # additive corruption used by the other noised run in this notebook.
    model.transformer.wte.output[:,[1,2,3,4]] += noise

    # Patch the clean block-17 state back in at position 4 only.
    model.transformer.h[17].output[0][:,4,:] = clean_state[:,4,:]

    restored_logits = model.lm_head.output.softmax(-1)[:,-1,target].save()

# Both values are softmax probabilities of `target`, despite the `_logits` names.
print(clean_logits, restored_logits)

In [11]:
import torch

# Prompt under study. '<|endoftext|>' is written out in the text itself, so the
# token positions used below are counted with it at position 0.
string = '<|endoftext|>The Space Needle is in downtown'

# Vocabulary id of the first token produced when encoding ' Seattle'.
target = tok.encode(' Seattle', return_tensors='pt')[0][0]

# One noise vector per corrupted position (4 positions, width 1600).
noise = torch.randn(1,4,1600) * 3 * 0.044414032250642776

# Sliding window of 5
r = range(15,20)

# layer index -> saved clean MLP activation-function output for that layer
clean_states = {}

# Positions whose embeddings receive the additive noise.
noised_positions = [1, 2, 3, 4]

# --- Clean run: record the window's activations and the baseline probability.
with model.trace(string):

    for layer in r:
        clean_states[layer] = model.transformer.h[layer].mlp.act.output.save()

    clean_logits = model.lm_head.output.softmax(-1)[:, -1, target].save()

print("CLEAN: ", clean_logits)

# --- Corrupted run: noise only, nothing restored.
with model.trace(string):

    model.transformer.wte.output[:, noised_positions] += noise

    corr_logits = model.lm_head.output.softmax(-1)[:, -1, target].save()

print("CORR: ", corr_logits)

# --- Restoration runs: for each of the 8 positions, corrupt again and copy the
# clean activations of every layer in the window back in at that one position.
for token in range(8):
    with model.trace(string):

        model.transformer.wte.output[:, noised_positions] += noise

        # Dict insertion order equals the order of `r`, so layers are patched
        # in ascending order.
        for layer, clean_act in clean_states.items():
            model.transformer.h[layer].mlp.act.output[:, token, :] = clean_act[:, token, :]

        restored_logits = model.lm_head.output.softmax(-1)[:, -1, target].save()

    print(f"RESTORED AT {token}: ", restored_logits)

CLEAN:  tensor([0.9799], device='cuda:0', grad_fn=<SelectBackward0>)
CORR:  tensor([0.0180], device='cuda:0', grad_fn=<SelectBackward0>)
RESTORED AT 0:  tensor([0.0180], device='cuda:0', grad_fn=<SelectBackward0>)
RESTORED AT 1:  tensor([0.0196], device='cuda:0', grad_fn=<SelectBackward0>)
RESTORED AT 2:  tensor([0.0183], device='cuda:0', grad_fn=<SelectBackward0>)
RESTORED AT 3:  tensor([0.0194], device='cuda:0', grad_fn=<SelectBackward0>)
RESTORED AT 4:  tensor([0.1000], device='cuda:0', grad_fn=<SelectBackward0>)
RESTORED AT 5:  tensor([0.0189], device='cuda:0', grad_fn=<SelectBackward0>)
RESTORED AT 6:  tensor([0.0184], device='cuda:0', grad_fn=<SelectBackward0>)
RESTORED AT 7:  tensor([0.0168], device='cuda:0', grad_fn=<SelectBackward0>)


In [None]:
import torch
from tqdm import tqdm

# Noise scale base.
# NOTE(review): presumably the standard deviation of the token embeddings -
# confirm where this constant was measured.
STDEV = 0.044414032250642776

N_LAYERS = 48       # transformer blocks swept (model.transformer.h[0..47])
HIDDEN_SIZE = 1600  # width of each embedding / noise vector
N_RUNS = 10         # independent noise draws accumulated into `results`

string = '<|endoftext|>The Space Needle is in the city of'
prompt = tok.encode(string)
target = tok.encode(' Seattle', return_tensors='pt')[0][0]
subject_tokens = [1,2,3,4]

# Number of positions to sweep, taken from the tokenised prompt instead of a
# hard-coded literal so that changing `string` cannot desynchronise the loops.
n_tokens = len(prompt)
print(n_tokens)

# results[layer, token] accumulates (restored - corrupted) probability of
# `target`. It is a SUM over the N_RUNS noise draws, not a mean.
results = torch.zeros((N_LAYERS, n_tokens))

with torch.no_grad():

    # The clean run does not depend on the noise, so it is done once up front
    # rather than once per noise draw.
    clean_states = {}

    with model.trace(string):

        for layer in range(N_LAYERS):
            # Keep the clean MLP outputs on the CPU between runs.
            clean_states[layer] = model.transformer.h[layer].mlp.output.cpu().save()

        clean_logits = model.lm_head.output.softmax(-1)[:,-1,target].save()

    # Values printed here are softmax probabilities, despite the "logits" label.
    print('clean logits:', clean_logits.value.item())

    for _ in tqdm(range(N_RUNS)):

        noise = torch.randn(1, len(subject_tokens), HIDDEN_SIZE) * 3 * STDEV

        # Corrupted baseline for this noise draw.
        with model.trace(string):

            # ADD the noise to the subject embeddings; plain assignment would
            # replace the embeddings with pure noise instead of perturbing them.
            model.transformer.wte.output[:,subject_tokens] += noise

            corr_logits = model.lm_head.output.softmax(-1)[:,-1,target].save()

        print('corrupted logits:', corr_logits.value.item())

        for i in range(N_LAYERS):

            for _tok in range(n_tokens):

                with model.trace(string):

                    # Same corruption as the baseline run above.
                    model.transformer.wte.output[:,subject_tokens] += noise

                    # Restore one clean MLP output: layer i, position _tok.
                    model.transformer.h[i].mlp.output[:, _tok, :] = clean_states[i][:,_tok,:]

                    restored_logits = model.lm_head.output.softmax(-1)[:,-1,target].save()

                    # Gain in target probability from this single restoration.
                    diff = restored_logits - corr_logits

                    diff.save()

                results[i, _tok] += diff.value.item()

In [None]:
import matplotlib.pyplot as plt
from scipy.ndimage import uniform_filter1d


# `results` is filled as results[layer, token]; transposing puts tokens on the
# rows and layers on the columns.
token_by_layer = results.numpy().T

# Moving average of width 5 along the column (layer) axis, reflecting at the
# edges.
test = uniform_filter1d(token_by_layer, size=5, axis=1, mode='reflect')

def plot_trace(results, str_tokens=None):
    """Draw a 2-D array as a purple heatmap with a vertical colour bar.

    Args:
        results: 2-D array-like passed straight to ``imshow``; rows go on the
            y axis and columns on the x axis (labelled as the restored layer).
        str_tokens: optional sequence of row labels. When given and non-empty,
            row ``k`` is labelled ``str_tokens[k]`` on the y axis. When ``None``
            (or empty) the default numeric ticks are left untouched.

    Returns:
        None. The figure is created through pyplot and left for the notebook to
        render.
    """
    fig, ax = plt.subplots(figsize=(10, 8))
    cax = ax.imshow(results, cmap="Purples", aspect="auto")
    fig.colorbar(cax, ax=ax, orientation="vertical")

    # Set tick positions explicitly before the labels so each label lands on
    # its own row instead of on whatever ticks the automatic locator chose.
    if str_tokens is not None and len(str_tokens) > 0:
        ax.set_yticks(range(len(str_tokens)))
        ax.set_yticklabels(str_tokens)

    ax.set_xlabel("single restored layer within GPT-2-XL")

# Render the smoothed array; no row labels are supplied.
plot_trace(test, str_tokens=None)