<a href="https://colab.research.google.com/github/boknilev/lm-intervention/blob/master/lm_intervention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
#!pip install pytorch-pretrained-bert
#!pip install spacy ftfy==4.4.3
#!python -m spacy download en

import torch
import torch.nn as nn 
import torch.nn.functional as F
import numpy as np
import random
from functools import partial 
from tqdm import tqdm
from collections import Counter, defaultdict

import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
sns.set(style="ticks", color_codes=True)

from pytorch_pretrained_bert import GPT2LMHeadModel, GPT2Tokenizer

np.random.seed(1)
torch.manual_seed(1)





<torch._C.Generator at 0x7fb3fe427950>

In [0]:
# TODO: plot the log probs nicely 
def plot_log_probs(layer_to_candidate1_log_probs, layer_to_candidate2_log_probs):
    
    raise NotImplementedError
        

In [0]:
def gpt2_intervention():
    
    """ Code draws on https://github.com/huggingface/pytorch-pretrained-BERT/blob/master/examples/run_gpt2.py """
    
    enc = GPT2Tokenizer.from_pretrained('gpt2')
    model = GPT2LMHeadModel.from_pretrained('gpt2')
    model.eval()
    #print(model)
    
    top_k = 5
    
    # multiplier for intervention; needs to be pretty large (~100) to see a change
    # TODO: plot the intervention results (how many neurons are flipped) for different alphas
    alpha = 500
    
    num_layers = len(model.transformer.h)
    num_neurons = 768    
    
    
    def print_neuron_hook(module, input, output, position, neuron):
        #print(output.shape) 
        print(output[0][position][neuron])
        
    def print_all_hook(module, input, output, position):
        #print(output.shape) 
        print(output[0][position])

    def intervention_hook(module, input, output, position, neuron, intervention):
        #output[0][1][0] = 10
        output[0][position][neuron] += intervention
            
    def extract_representation_hook(module, input, output, position, representations, layer):
        representations[layer] = output[0][position]

        
    with torch.no_grad():
        raw_text = "The teacher said that"
        raw_text_man = "The man said that"
        raw_text_woman = "The woman said that"
        
        # add space character
        candidate1 = 'Ġ' + 'she'
        candidate2 = 'Ġ' + 'he'
        position = 1 # intervene at "teacher"

        candidate1_token, candidate2_token = enc.convert_tokens_to_ids([candidate1, candidate2])
        
        # first collect representations of man and woman  
        # TODO: do only one forward pass and get all layer representations from that, instead of iterating over layers
        man_representations = {}
        woman_representations = {} 
        for layer in tqdm(range(num_layers), desc='man and woman'):
            # man 
            context_tokens = enc.encode(raw_text_man)
            context = torch.tensor(context_tokens, dtype=torch.long).unsqueeze(0).repeat(1, 1)
            prev = context
            output = context
            orig_past = None                        
            extract_handle = model.transformer.h[layer].mlp.register_forward_hook(partial(extract_representation_hook, position=position, representations=man_representations, layer=layer))
            logits, past = model(prev, past=orig_past)
            extract_handle.remove()
            
            # woman
            context_tokens = enc.encode(raw_text_woman)
            context = torch.tensor(context_tokens, dtype=torch.long).unsqueeze(0).repeat(1, 1)
            prev = context
            output = context
            orig_past = None                        
            extract_handle = model.transformer.h[layer].mlp.register_forward_hook(partial(extract_representation_hook, position=position, representations=woman_representations, layer=layer))
            logits, past = model(prev, past=orig_past)
            extract_handle.remove()
            
        #print('man:', man_representations)
        #print('woman:', woman_representations)

            

        # now intervene in "teacher"
        context_tokens = enc.encode(raw_text)
        context = torch.tensor(context_tokens, dtype=torch.long).unsqueeze(0).repeat(1, 1)
        prev = context
        output = context
        orig_past = None


        layer_to_candidate1, layer_to_candidate2 = Counter(), Counter()
        layer_to_candidate1_log_probs, layer_to_candidate2_log_probs = defaultdict(list), defaultdict(list)
        
        for layer in tqdm(range(num_layers), desc='layers'):
        #for layer in range(1): 
            #print('==> intervene at layer', layer)

            #print('==> first pass (no intervention)')
            #mlp_print_all_handle = model.transformer.h[layer].mlp.register_forward_hook(partial(print_all_hook, position=position))
            logits, past = model(prev, past=orig_past)
            #mlp_print_all_handle.remove()

            logits = logits[:, -1, :]
            log_probs = F.softmax(logits, dim=-1)
            candidate1_log_prob, candidate2_log_prob = log_probs[0][[candidate1_token, candidate2_token]]
            #print(candidate1, candidate1_token, candidate1_log_prob)
            #print(candidate2, candidate2_token, candidate2_log_prob)
            
            for neuron in range(num_neurons):
            #for neuron in range(10):
                #print('neuron:', neuron)
                #print(woman_representations[layer].shape)
                #print(woman_representations[layer][neuron].shape)


                #intervention = woman_representations[layer][neuron] - man_representations[layer][neuron]
                intervention = alpha * (man_representations[layer][neuron] - woman_representations[layer][neuron])
                #print('intervention:', intervention)
               
                mlp_intervention_handle = model.transformer.h[layer].mlp.register_forward_hook(partial(intervention_hook, position=position, neuron=neuron, intervention=intervention))
                #mlp_print_neuron_handle = model.transformer.h[layer].mlp.register_forward_hook(partial(print_neuron_hook, position=position, neuron=neuron))
                logits, past = model(prev, past=orig_past)
                mlp_intervention_handle.remove()
                #mlp_print_neuron_handle.remove()

                logits = logits[:, -1, :]
                #logits = top_k_logits(logits, k=top_k)
                #print(logits.shape)
                #print(logits[0][673])
                #print(logits[0][339])
                #logits[0][339] = -99
                log_probs = F.softmax(logits, dim=-1)
                #log_probs_top_k, prev = torch.topk(log_probs, k=top_k, dim=-1)
                #output = torch.cat((output, prev), dim=1)   
                #out = output[:, len(context_tokens):].tolist()
                #text = enc.decode(out[0])
                #print(text)
                #print(log_probs_top_k)

                #print(log_probs.shape)
                candidate1_log_prob, candidate2_log_prob = log_probs[0][[candidate1_token, candidate2_token]]
                layer_to_candidate1_log_probs[layer].append(candidate1_log_prob)
                layer_to_candidate2_log_probs[layer].append(candidate2_log_prob)
                if candidate1_log_prob > candidate2_log_prob:
                    layer_to_candidate1[layer] += 1
                else:
                    layer_to_candidate2[layer] += 1
                    
                #print(candidate1, candidate1_token, candidate1_log_prob)
                #print(candidate2, candidate2_token, candidate2_log_prob)  
        
        print('more probable candidate per layer, across all neurons in the layer')
        print('candidate1:', candidate1, layer_to_candidate1, layer_to_candidate1_log_probs)
        print('candidate2:', candidate2, layer_to_candidate2, layer_to_candidate2_log_probs)

        # TODO: we need to look at the log prob distribution over the candidates and how that changes by intervention



        
    

In [0]:
def simple_intervention():

    def intervention_hook(module, input, output):
        output[1] = 1

    dim = 5

    lin1 = nn.Linear(dim, dim)
    relu = nn.ReLU()
    lin2 = nn.Linear(dim, dim)
    mlp = nn.Sequential(lin1, relu, lin2)
    print('mlp:', mlp)
    #lin1_clone = lin1.clone()
    #mlp_intervene = nn.Sequential(lin1_clone, relu, lin2)
    #lin1_clone.register_forward_hook(intervention_hook)
    with torch.no_grad():

        print('==> before intervention')

        input = torch.rand(5)
        print('input:', input)
        out1 = lin1(input)
        print('out1:', out1)
        outrelu = relu(out1)
        print('outrelu:', outrelu)
        out2 = lin2(outrelu)
        print('out2:', out2)
        output = mlp(input)
        print('output:', output)

        print('==> intervene')

        out1_intervened = out1.clone()
        out1_intervened[1] = 1
        print('out1:', out1)
        print('out1_intervened:', out1_intervened)
        outrelu_intervened = relu(out1_intervened)    
        print('outrelu_intervened:', outrelu_intervened)
        out2_intervened = lin2(outrelu_intervened)
        print('out2_intervened:', out2_intervened)    

        lin1_handle = lin1.register_forward_hook(intervention_hook)
        out1_intervened = lin1(input)
        print('out1_intervened:', out1_intervened)
        output_intervened = mlp(input)
        print('output_intervened:', output_intervened)
        
        print('==> remove intervention')
        
        lin1_handle.remove()
        out1 = lin1(input)
        print('out1:', out1)    
        output = mlp(input)
        print('output:', output)


In [0]:
#simple_intervention()

gpt2_intervention()





man and woman:   0%|          | 0/12 [00:00<?, ?it/s][A[A[A[A



man and woman:   8%|▊         | 1/12 [00:00<00:02,  3.78it/s][A[A[A[A



man and woman:  17%|█▋        | 2/12 [00:00<00:02,  3.73it/s][A[A[A[A



man and woman:  25%|██▌       | 3/12 [00:00<00:02,  3.72it/s][A[A[A[A



man and woman:  33%|███▎      | 4/12 [00:01<00:02,  3.68it/s][A[A[A[A



man and woman:  42%|████▏     | 5/12 [00:01<00:01,  3.65it/s][A[A[A[A



man and woman:  50%|█████     | 6/12 [00:01<00:01,  3.65it/s][A[A[A[A



man and woman:  58%|█████▊    | 7/12 [00:01<00:01,  3.65it/s][A[A[A[A



man and woman:  67%|██████▋   | 8/12 [00:02<00:01,  3.66it/s][A[A[A[A



man and woman:  75%|███████▌  | 9/12 [00:02<00:00,  3.67it/s][A[A[A[A



man and woman:  83%|████████▎ | 10/12 [00:02<00:00,  3.63it/s][A[A[A[A



man and woman:  92%|█████████▏| 11/12 [00:03<00:00,  3.63it/s][A[A[A[A



man and woman: 100%|██████████| 12/12 [00:03<00:00,  3.64it/s][A[A[A[A





more probable candidate per layer, across all neurons in the layer
candidate1: Ġshe Counter({11: 768, 10: 733, 9: 650, 8: 537, 6: 434, 7: 402, 4: 359, 5: 268, 3: 238, 1: 226, 2: 216, 0: 96}) defaultdict(<class 'list'>, {0: [tensor(0.0149), tensor(0.0019), tensor(0.0213), tensor(0.0208), tensor(0.0721), tensor(0.1064), tensor(0.0747), tensor(0.1319), tensor(0.0118), tensor(0.0100), tensor(0.0002), tensor(0.0341), tensor(0.0303), tensor(0.0124), tensor(0.0401), tensor(0.0368), tensor(0.1188), tensor(0.1248), tensor(0.0123), tensor(0.0267), tensor(0.0120), tensor(0.0191), tensor(0.0047), tensor(0.0336), tensor(0.0655), tensor(0.0241), tensor(0.0254), tensor(0.0075), tensor(0.0069), tensor(0.0059), tensor(0.0189), tensor(0.0102), tensor(0.0072), tensor(0.0609), tensor(0.0074), tensor(0.1128), tensor(0.0669), tensor(0.0140), tensor(0.0442), tensor(0.0332), tensor(0.0423), tensor(0.0057), tensor(0.1228), tensor(0.0239), tensor(0.0816), tensor(0.1166), tensor(0.0388), tensor(0.0543), tensor(0