In [1]:
import torch
import torch.nn as nn
from torch import Tensor
from transformers import GPTNeoXForCausalLM, AutoTokenizer, AutoModelForCausalLM
from jaxtyping import Float, Int
from typing import List, Optional, Tuple, Dict
import sys
from functools import partial
from tqdm import tqdm
import itertools
import json
import seaborn as sns
import pandas as pd
import multiprocessing
import pickle

from request_patching import request_patch_one_pair, create_patch_request_dict, baseline_completion, baseline_completion_plus
from models import get_model_from_name

import warnings
warnings.filterwarnings("ignore")

# Pick the first GPU when one is visible, otherwise run on the CPU.
cuda_available = torch.cuda.is_available()
device = torch.device("cuda:0" if cuda_available else "cpu")
print('device =', device)

# Inference only: turn autograd off globally for the whole notebook.
torch.set_grad_enabled(False)

# Report the hardware the notebook is running on.
if not cuda_available:
    print("CUDA is not available. Listing CPUs instead.")
    print(multiprocessing.cpu_count())
else:
    for gpu_idx in range(torch.cuda.device_count()):
        print(f"Device {gpu_idx}: {torch.cuda.get_device_name(gpu_idx)}")
    print(f"Current CUDA device: {torch.cuda.current_device()}")

  from .autonotebook import tqdm as notebook_tqdm


device = cpu
CUDA is not available. Listing CPUs instead.
48


In [2]:
# Load the model under study and its tokenizer through the project helper `models.get_model_from_name`.
(model, tokenizer) = get_model_from_name('pythia-6.9b')

# Character descriptions used as prompts. Each entry has:
#   'context': a short biography. It must end with a full stop, because
#              " In <year> <pronoun> voted for Mr." is appended to it verbatim.
#   'sex':     'm' or 'f'; selects the pronoun ('he' for 'm', 'she' otherwise).
#   'party':   'd' (Democrat) or 'r' (Republican).
prompt_dict = {
    'prompt_1': {'context': 'Emily is a dynamic Democrat who tirelessly works towards advancing civil rights, healthcare access, and public education reform. She combines her expertise in law and social work to advocate for policies that dismantle systemic inequalities and support mental health initiatives.',
                 'sex': 'f',
                 'party': 'd'
    },
    # Full stop added: without it the built prompt read "...fair labor practices In <year> he voted ...".
    'prompt_2': {'context': "Carlos is an energetic Democrat with a focus on immigration reform and workers' rights. As the son of immigrants, he brings a personal passion to his advocacy, aiming to create pathways to citizenship and fair labor practices.",
                 'sex': 'm',
                 'party': 'd'
    },
    'prompt_3': {'context': "Ethan is a steadfast Republican with a deep commitment to economic development and job creation. As an entrepreneur, he understands the challenges faced by small businesses and advocates for lower taxes and reduced government regulation to encourage investment and growth.",
                 'sex': 'm',
                 'party': 'r'
    },
    'prompt_4': {'context': "Rachel is a principled Republican who values national security and law enforcement. With a background in criminal justice, she advocates for policies that support the military and police forces, believing in the importance of a safe and secure society.",
                 'sex': 'f',
                 'party': 'r'
    },
    'prompt_5': {'context': "Anna is a passionate and dedicated Democrat, known for her strong commitment to social justice, healthcare reform, and environmental protection. With a background in public policy, she actively participates in community organizing and voter registration drives, aiming to empower underrepresented groups and promote policies that ensure equality and sustainability.",
                 'sex': 'f',
                 'party': 'd'
    },
    # Bob is described with "his"/"he": sex must be 'm' so the appended prompt says "he voted".
    'prompt_6': {'context': "Bob is a devoted Republican, characterized by his firm belief in limited government, fiscal responsibility, and strong national defense. With a background in business, he champions policies that promote economic growth, entrepreneurship, and tax reform, advocating for a regulatory environment that fosters innovation and competition.",
                 'sex': 'm',
                 'party': 'r'}
}

# Worked example prepended to each prompt in the few-shot experiments; it shows
# the expected answer format ("... voted for Mr. <surname>.").
few_shot_ex = 'Sophia is a dedicated Democrat with a strong focus on social justice and environmental sustainability. As a community organizer, she recognizes the struggles of underserved communities and champions policies for equitable access to healthcare and education. In 2020 she voted for Mr. Biden.'

Loading checkpoint shards: 100%|██████████| 2/2 [00:46<00:00, 23.45s/it]
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


## Baseline

In [3]:
years = ['2000', '2004', '2008', '2012']
layers = [14, 15, 16, 19, 20]
all_prompts = list(prompt_dict)


# BASELINE COMPLETION: the model's unpatched answer for every (character, year),
# stored in place under prompt_dict[<id>]['baseline_completion'][<year>].
for prompt_id in tqdm(prompt_dict):
    character = prompt_dict[prompt_id]
    pronoun = 'she' if character['sex'] != 'm' else 'he'
    completions = {}
    character['baseline_completion'] = completions
    for year in years:
        query = f"{character['context']} In {year} {pronoun} voted for Mr."
        completions[str(year)] = baseline_completion(query, model, tokenizer)

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


KeyboardInterrupt: 

### Constant Year

In [None]:
import os

# Patch between characters of opposite parties within the same election year.
# Requires `years`, `layers`, `all_prompts` and the baselines computed in the
# baseline cell above.
patch_dict = {}
pair_id = 0
os.makedirs('outputs', exist_ok=True)

for year in years:
    for prompt_a, prompt_b in itertools.combinations(all_prompts, 2):

        if prompt_dict[prompt_a]['party'] == prompt_dict[prompt_b]['party']:
            # at constant year if the two parties are the same there is no patching visible
            continue

        # Patch in both directions: a -> b, then b -> a (context_1 is the source).
        for char_1, char_2 in ((prompt_a, prompt_b), (prompt_b, prompt_a)):
            pronouns = ['he' if prompt_dict[c]['sex'] == 'm' else 'she' for c in (char_1, char_2)]
            context_1 = prompt_dict[char_1]['context'] + f' In {year} {pronouns[0]} voted for Mr.'
            context_2 = prompt_dict[char_2]['context'] + f' In {year} {pronouns[1]} voted for Mr.'

            # request_patch_one_pair returns (tokens per layer, logits per layer), as
            # unpacked in the varying-year cells. Indexing its result directly by
            # layer position (as this cell used to) fails beyond the 2nd layer.
            token_per_layer, logits_per_layer = request_patch_one_pair(context_1, context_2, model, tokenizer, layers=layers)

            baseline_1 = prompt_dict[char_1]['baseline_completion'][str(year)]
            baseline_2 = prompt_dict[char_2]['baseline_completion'][str(year)]
            entry = {'context_1': context_1,
                     'context_2': context_2,
                     'year_1': year,
                     'year_2': year,
                     'R1_C1': baseline_1,
                     'R2_C2': baseline_2,
                     # Same year on both sides, so the cross baselines equal the diagonal ones.
                     'R1_C2': baseline_1,
                     'R2_C1': baseline_2
            }
            # Same per-layer keys as the varying-year functions, so all outputs share one schema.
            for j, layer_id in enumerate(layers):
                entry[f'token_l{layer_id}'] = token_per_layer[j]
                entry[f'logit_l{layer_id}'] = logits_per_layer[j]
            patch_dict[f'pair_{pair_id}'] = entry
            pair_id += 1

        # Checkpoint after each pair of directions so partial results survive an interruption.
        with open('outputs/no_dialog_patching_dict.pkl', 'wb') as f:
            pickle.dump(patch_dict, f)

### Varying year and character

In [5]:
def patch_no_dialog_varying_year_and_character(prompt_dict: dict,
                                               years: list[str] = ['2000', '2004', '2008', '2012'],
                                               layers: list[int] = [14, 15, 16, 19, 20],
                                               output_file_name: str = 'no_dialog_patching_dict.pkl',
                                               model=None,
                                               tokenizer=None
):
    """Patch requests between all cross-party (character, year) contexts, without dialog.

    Step 1, baselines:
        For every character in ``prompt_dict`` and every year, compute the baseline
        completion of "<context> In <year> <pronoun> voted for Mr.". Store it in
        ``prompt_dict[char]['baseline_completion'][year]``; ``prompt_dict`` is
        mutated in place.

    Step 2, patching:
        For every ordered pair of such contexts whose characters belong to
        different parties, call ``request_patch_one_pair``. Record the contexts,
        years, the four baselines (R<i>_C<j>), and the patched token and logits
        for each layer.

    The growing result dict is pickled to ``outputs/<output_file_name>`` after
    every pair, so partial results survive an interruption.

    Args:
        prompt_dict: prompt id -> dict with 'context', 'sex' ('m' gives "he",
            anything else "she") and 'party' keys.
        years: years inserted into the prompt.
        layers: layer indices passed to ``request_patch_one_pair``.
        output_file_name: file name of the pickle written under ``outputs/``.
        model, tokenizer: model and tokenizer to use. When ``None``, fall back to
            the notebook-level ``model`` / ``tokenizer`` globals (previous behaviour).

    Returns:
        None; results are written to disk.
    """
    import os

    if model is None:
        model = globals()['model']
    if tokenizer is None:
        tokenizer = globals()['tokenizer']

    def build_context(char, year):
        # Character description followed by the voting question to complete.
        pronoun = 'he' if prompt_dict[char]['sex'] == 'm' else 'she'
        return prompt_dict[char]['context'] + f' In {year} {pronoun} voted for Mr.'

    for p in tqdm(prompt_dict, desc='Baseline completion'):
        prompt_dict[p]['baseline_completion'] = {}
        for year in years:
            prompt_dict[p]['baseline_completion'][f'{year}'] = baseline_completion(build_context(p, year), model, tokenizer)

    os.makedirs('outputs', exist_ok=True)

    patch_dict = {}
    pair_id = 0

    for char_1 in tqdm(prompt_dict, desc='char_1'):
        for year_1, char_2, year_2 in itertools.product(years, prompt_dict, years):

            # Only cross-party pairs are patched. A character always shares its own
            # party, so this also skips char_1 == char_2 (including identical contexts).
            if prompt_dict[char_1]['party'] == prompt_dict[char_2]['party']:
                continue

            context_1 = build_context(char_1, year_1)
            context_2 = build_context(char_2, year_2)

            token_per_layer, logits_per_layer = request_patch_one_pair(context_1, context_2, model, tokenizer, layers=layers)
            baselines_1 = prompt_dict[char_1]['baseline_completion']
            baselines_2 = prompt_dict[char_2]['baseline_completion']
            entry = {'context_1': context_1,
                     'context_2': context_2,
                     'year_1': year_1,
                     'year_2': year_2,
                     'R1_C1': baselines_1[f'{year_1}'],
                     'R2_C2': baselines_2[f'{year_2}'],
                     # NOTE(review): cross baselines = each character's answer in the other
                     # context's year; the original author flagged this mapping "to verify".
                     'R1_C2': baselines_1[f'{year_2}'],
                     'R2_C1': baselines_2[f'{year_1}']
            }
            # Raw logits are stored as returned; the unused softmax/.cpu() copy that was
            # computed here on every layer of every pair has been removed.
            for j, layer_id in enumerate(layers):
                entry[f'token_l{layer_id}'] = token_per_layer[j]
                entry[f'logit_l{layer_id}'] = logits_per_layer[j]
            patch_dict[f'pair_{pair_id}'] = entry
            pair_id += 1

            with open(f'outputs/{output_file_name}', 'wb') as f:
                pickle.dump(patch_dict, f)

    print(f'Number of patchings: {pair_id}')
    return

In [7]:
# Exclude Carlos (prompt_2) and Ethan (prompt_3), and restrict the years to
# 2000-2012 (1996 dropped), so that pythia-6.9b completes every baseline correctly.
# Build a filtered copy instead of popping from prompt_dict: the cell can then be
# re-run, and other cells still see every character.
excluded_prompts = ('prompt_2', 'prompt_3')
prompt_dict_filtered = {k: v for k, v in prompt_dict.items() if k not in excluded_prompts}

patch_no_dialog_varying_year_and_character(prompt_dict=prompt_dict_filtered,
                                           years=['2000', '2004', '2008', '2012'],
                                           output_file_name='no_dialog_patching_dict_filtered_ethancarlos.pkl'
)

Baseline completion: 100%|██████████| 4/4 [13:43<00:00, 205.82s/it]
char_1: 100%|██████████| 4/4 [6:12:03<00:00, 5580.99s/it]  


Number of patchings: 128


## Few-shot no dialog

In [None]:
def patch_no_dialog_few_shot(prompt_dict: dict,
                             few_shot_paragraph: str,
                             years: list[str] = ['1992', '1996', '2000', '2004', '2008', '2012'],
                             layers: list[int] = [14, 15, 16, 17, 18, 19, 20],
                             output_file_name: str = 'no_dialog_patching_dict.pkl',
                             model_name: str = 'pythia-6.9b',
                             model=None,
                             tokenizer=None
):
    """Few-shot variant of the cross-party request-patching experiment.

    Every prompt is "<few_shot_paragraph>\\n <context> In <year> <pronoun> voted for Mr.".

    Step 1, baselines:
        Compute the baseline completion of every (character, year) prompt and
        store it in ``prompt_dict[char]['baseline_completion'][year]``;
        ``prompt_dict`` is mutated in place.

    Step 2, patching:
        For every ordered pair of prompts whose characters belong to different
        parties, call ``request_patch_one_pair``. Record the prompts, years,
        baselines (R<i>_C<j>), and the patched token and logits for each layer.

    The growing result dict is pickled to ``outputs/<output_file_name>`` after
    every pair.

    Args:
        prompt_dict: prompt id -> dict with 'context', 'sex' ('m' gives "he",
            anything else "she") and 'party' keys.
        few_shot_paragraph: worked example prepended to every prompt.
        years: years inserted into the prompt.
        layers: layer indices passed to ``request_patch_one_pair``.
        output_file_name: file name of the pickle written under ``outputs/``.
        model_name: name passed to ``get_model_from_name`` when a model or
            tokenizer must be loaded.
        model, tokenizer: already-loaded objects to reuse. When either is
            ``None``, ``model_name`` is loaded to supply it (previous behaviour).

    Returns:
        None; results are written to disk.
    """
    import os

    # Loading the model is expensive; only do it when the caller did not pass one in.
    if model is None or tokenizer is None:
        loaded_model, loaded_tokenizer = get_model_from_name(model_name)
        if model is None:
            model = loaded_model
        if tokenizer is None:
            tokenizer = loaded_tokenizer

    def build_context(char, year):
        # Few-shot example, then the character description and the voting question.
        pronoun = 'he' if prompt_dict[char]['sex'] == 'm' else 'she'
        return f"{few_shot_paragraph}\n {prompt_dict[char]['context']} In {year} {pronoun} voted for Mr."

    # BASELINE COMPLETION
    for p in tqdm(prompt_dict, desc='Baseline completion'):
        prompt_dict[p]['baseline_completion'] = {}
        for year in years:
            prompt_dict[p]['baseline_completion'][f'{year}'] = baseline_completion(build_context(p, year), model, tokenizer)

    os.makedirs('outputs', exist_ok=True)

    patch_dict = {}
    pair_id = 0

    for char_1 in tqdm(prompt_dict, desc='char_1'):
        for year_1, char_2, year_2 in itertools.product(years, prompt_dict, years):

            # Only cross-party pairs are patched. A character always shares its own
            # party, so this also skips char_1 == char_2 (including identical contexts).
            if prompt_dict[char_1]['party'] == prompt_dict[char_2]['party']:
                continue

            context_1 = build_context(char_1, year_1)
            context_2 = build_context(char_2, year_2)

            token_per_layer, logits_per_layer = request_patch_one_pair(context_1, context_2, model, tokenizer, layers=layers)
            baselines_1 = prompt_dict[char_1]['baseline_completion']
            baselines_2 = prompt_dict[char_2]['baseline_completion']
            entry = {'context_1': context_1,
                     'context_2': context_2,
                     'year_1': year_1,
                     'year_2': year_2,
                     'R1_C1': baselines_1[f'{year_1}'],
                     'R2_C2': baselines_2[f'{year_2}'],
                     # Cross baselines: each character's answer in the other context's year.
                     'R1_C2': baselines_1[f'{year_2}'],
                     'R2_C1': baselines_2[f'{year_1}']
            }
            # Raw logits are stored as returned; the unused softmax/.cpu() copy that was
            # computed here on every layer of every pair has been removed.
            for j, layer_id in enumerate(layers):
                entry[f'token_l{layer_id}'] = token_per_layer[j]
                entry[f'logit_l{layer_id}'] = logits_per_layer[j]
            patch_dict[f'pair_{pair_id}'] = entry
            pair_id += 1

            with open(f'outputs/{output_file_name}', 'wb') as f:
                pickle.dump(patch_dict, f)

    print(f'Number of patchings: {pair_id}')
    return