<a href="https://colab.research.google.com/github/boknilev/lm-intervention/blob/master/lm_intervention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
#!pip install pytorch-pretrained-bert
#!pip install spacy ftfy==4.4.3
#!python -m spacy download en

import torch
import torch.nn as nn 
import torch.nn.functional as F
import numpy as np
import random
from functools import partial 
from tqdm import tqdm
from tqdm import tqdm_notebook

from collections import Counter, defaultdict

import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
sns.set(style="ticks", color_codes=True)

from pytorch_pretrained_bert import GPT2LMHeadModel, GPT2Tokenizer

np.random.seed(1)
torch.manual_seed(1)

<torch._C.Generator at 0x112024ed0>

In [12]:
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
class Intervention():
    '''
    Wrapper for all the possible interventions
    '''
    def __init__(self, 
                 base_string:str, 
                 substitutes:list, 
                 candidates:list):
        super()
        self.enc = tokenizer
        # All the initial strings
        # First item should be neutral, others tainted
        self.base_strings = [base_string.format(s) 
                             for s in substitutes]
        # Tokenized bases
        self.base_strings_tok = [self._to_batch(s)
                                 for s in self.base_strings]
        # Where to intervene
        self.position = base_string.split().index('{}')
        
        # How to extend the string
        self.candidates = ['Ġ' + c for c in candidates]
        # tokenized candidates
        self.candidates_tok = [self.enc.convert_tokens_to_ids(c) 
                               for c in self.candidates]
        
    def _to_batch(self, txt):
        encoded = self.enc.encode(txt) 
        return torch.tensor(encoded, dtype=torch.long)\
                    .unsqueeze(0)\
                    .repeat(1, 1)
        

In [3]:
class Model():
    '''
    Wrapper for all model logic
    '''
    def __init__(self):
        super()
        self.model = GPT2LMHeadModel.from_pretrained('gpt2')
        self.model.eval()
        
        # Options
        self.top_k = 5
        # 12 for GPT-2
        self.num_layers = len(self.model.transformer.h)
        # 768 for GPT-2 
        self.num_neurons = self.model.transformer.wte.weight.shape[1] 
        
        # multiplier for intervention; needs to be pretty large (~100) to see a change
        # TODO: plot the intervention results (how many neurons are flipped) for different alphas
        self.alpha = 500
        
    def get_representations(self, context, position):
        # Hook for saving the representation
        def extract_representation_hook(module, input, output, position, representations, layer):
            representations[layer] = output[0][position]
        handles = []
        representation = {}
        with torch.no_grad():
            # construct all the hooks
            for layer in range(self.num_layers):
                handles.append(self.model.transformer.h[0]\
                                   .mlp.register_forward_hook(
                    partial(extract_representation_hook, 
                            position=position, 
                            representations=representation, 
                            layer=layer)))
            logits, past = model.model(context)
            for h in handles:
                h.remove()
        print(representation[0][:5])
        return representation
    
    def get_probabilities_for_examples(self, context, outputs):
        logits, past = self.model(context)
        logits = logits[:, -1, :]
        probs = F.softmax(logits, dim=-1)
#         print(probs.shape)
#         print(outputs)
        return probs[0][outputs]

    def intervene_for_examples(self, 
                               context, 
                               outputs, 
                               repr_difference, 
                               layer, 
                               neuron, 
                               position):
        # Hook for changing representation during forward pass
        def intervention_hook(module, input, output, position, neuron, intervention):
            output[0][position][neuron] += intervention
        
        intervention_rep = self.alpha * repr_difference[layer][neuron]
        mlp_intervention_handle = self.model.transformer.h[layer]\
                                       .mlp.register_forward_hook(
            partial(intervention_hook, 
                    position=position, 
                    neuron=neuron, 
                    intervention=intervention_rep))
        new_probabilities = self.get_probabilities_for_examples(
            context, 
            outputs)
        mlp_intervention_handle.remove()
        return new_probabilities
model = Model()

In [4]:
# TODO: plot the log probs nicely 
def plot_log_probs(layer_to_candidate1_log_probs, layer_to_candidate2_log_probs):
    
    raise NotImplementedError
        
def print_neuron_hook(module, input, output, position, neuron):
        #print(output.shape) 
        print(output[0][position][neuron])
        
def print_all_hook(module, input, output, position):
    #print(output.shape) 
    print(output[0][position])

In [13]:
'''
Bookkeeping for experiments
'''

intervention = Intervention(
        "The {} said that",
        ["teacher", "man", "woman"],
        ["he", "she"])

In [16]:
'''
To Do: actually run all of them. note: does not include teacher
'''

profession_interventions = []
with open('professions.json', 'r') as f:
    for l in f: 
        # there is only one line that eval's to an array
        for j in eval(l):
            profession = j[0]
            profession_interventions.append(
                Intervention(
                    "The {} said that",
                    [profession, "man", "woman"],
                    ["he", "she"]))

In [5]:
layer_to_candidate1, layer_to_candidate2 = Counter(), Counter()
layer_to_candidate1_probs, layer_to_candidate2_probs = defaultdict(list), defaultdict(list)


""" Code draws on https://github.com/huggingface/pytorch-pretrained-BERT/blob/master/examples/run_gpt2.py """        
with torch.no_grad():
    '''
    Compute representations for gendered terms
    ''' 
    man_representations = model_repr = model.get_representations(
        intervention.base_strings_tok[1], 
        intervention.position)
    woman_representations = model.get_representations(
        intervention.base_strings_tok[2], 
        intervention.position)
    representation_difference = {k: v - woman_representations[k] 
                                 for k,v in man_representations.items()}
    '''
    Now intervening on potentially biased example
    '''
    context = intervention.base_strings_tok[0]
    '''
    Probabilities without intervention (Base case)
    '''
    base_probs = model.get_probabilities_for_examples(
        context, 
        intervention.candidates_tok)
    print("Base case: {} ____".format(intervention.base_strings[0]))
    for token, prob in zip(intervention.candidates, base_probs):
        print("{}: {:.2f}%".format(token, prob*100))

    '''
    Intervene at every possible neuron
    '''
    for layer in tqdm_notebook(range(model.num_layers), desc='layers'):
        for neuron in tqdm_notebook(range(model.num_neurons), desc='neurons'):
            candidate1_prob, candidate2_prob = model.intervene_for_examples(
                context=context, 
                outputs=intervention.candidates_tok, 
                repr_difference=representation_difference, 
                layer=layer, 
                neuron=neuron, 
                position=intervention.position)

            layer_to_candidate1_probs[layer].append(candidate1_prob)
            layer_to_candidate2_probs[layer].append(candidate2_prob)
            if candidate1_prob > candidate2_prob:
                layer_to_candidate1[layer] += 1
            else:
                layer_to_candidate2[layer] += 1
    
    # TODO: we need to look at the log prob distribution over the candidates and how that changes by intervention  
    

tensor([-0.9688, -0.0767, -1.3549, -0.5847, -1.2433])
tensor([-0.4244,  0.8223, -1.8646, -0.8140, -1.0686])
Base case: The teacher said that ____
Ġhe: 9.81%
Ġshe: 12.17%


HBox(children=(IntProgress(value=0, description='layers', max=12, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='neurons', max=768, style=ProgressStyle(description_width='ini…

HBox(children=(IntProgress(value=0, description='neurons', max=768, style=ProgressStyle(description_width='ini…

HBox(children=(IntProgress(value=0, description='neurons', max=768, style=ProgressStyle(description_width='ini…

HBox(children=(IntProgress(value=0, description='neurons', max=768, style=ProgressStyle(description_width='ini…

HBox(children=(IntProgress(value=0, description='neurons', max=768, style=ProgressStyle(description_width='ini…

HBox(children=(IntProgress(value=0, description='neurons', max=768, style=ProgressStyle(description_width='ini…

HBox(children=(IntProgress(value=0, description='neurons', max=768, style=ProgressStyle(description_width='ini…

HBox(children=(IntProgress(value=0, description='neurons', max=768, style=ProgressStyle(description_width='ini…

HBox(children=(IntProgress(value=0, description='neurons', max=768, style=ProgressStyle(description_width='ini…

HBox(children=(IntProgress(value=0, description='neurons', max=768, style=ProgressStyle(description_width='ini…

HBox(children=(IntProgress(value=0, description='neurons', max=768, style=ProgressStyle(description_width='ini…

HBox(children=(IntProgress(value=0, description='neurons', max=768, style=ProgressStyle(description_width='ini…




In [6]:
'''
Report aggregate
'''
print('more probable candidate per layer, across all neurons in the layer')
print('candidate1:', intervention.candidates[0], layer_to_candidate1)
print('candidate2:', intervention.candidates[1], layer_to_candidate2)


more probable candidate per layer, across all neurons in the layer
candidate1: Ġhe Counter({0: 672, 2: 668, 1: 664, 3: 633, 4: 569, 5: 561, 6: 428, 7: 375, 8: 359, 9: 158, 10: 37})
candidate2: Ġshe Counter({11: 768, 10: 731, 9: 610, 8: 409, 7: 393, 6: 340, 5: 207, 4: 199, 3: 135, 1: 104, 2: 100, 0: 96})
