## Install Requirements

In [None]:
# Install the notebook's third-party dependencies into the active kernel.
# NOTE(review): no versions are pinned, so results may drift as transformers/torch update;
#   consider pinning (e.g. `%pip install -q transformers==X.Y.Z`).
# NOTE(review): pickle5 is a backport of pickle protocol 5 for Python < 3.8; on newer
#   Pythons the stdlib `pickle` already supports it and pickle5 may fail to build — confirm
#   whether this install (and `import pickle5 as pickle` below) is still needed.
# NOTE(review): mpld3 and scikit-learn are not imported in any cell shown here — confirm
#   they are still required.
%pip install transformers
%pip install torch
%pip install pickle5
%pip install mpld3
%pip install scikit-learn

## Load the Pre-trained (Debiased) BERT Model and Tokenizer

In [None]:
import torch
from transformers import BertTokenizer, BertModel, BertForMaskedLM

# Choose which debiased BERT checkpoint to extract embeddings from by setting
# `debiased_model` to one of the keys below. Each entry gives the numbered output
# folder prefix, the checkpoint location, and how that checkpoint is loaded:
#   "pretrained_dir" -> BertModel.from_pretrained(model_path)
#   "state_dict"     -> bert-base-uncased BertForMaskedLM + load_state_dict(model_path)
MODEL_CONFIGS = {
    "sent_debiased": {
        "folder_num": "2.",
        "model_path": "../debiased_models/sent_debias/debias-BERT/experiments/acl2020-results/QNLI/debiased_final_final",
        "format": "pretrained_dir",
    },
    "contextualised": {
        "folder_num": "3.",
        "model_path": "../debiased_models/contextualised-embeddings-bert",
        "format": "pretrained_dir",
    },
    "cds": {
        "folder_num": "4.",
        "model_path": "../debiased_models/cds.pt",
        "format": "state_dict",
    },
}

debiased_model = "cds"
folder_num = MODEL_CONFIGS[debiased_model]["folder_num"]
model_path = MODEL_CONFIGS[debiased_model]["model_path"]

if MODEL_CONFIGS[debiased_model]["format"] == "pretrained_dir":
    model = BertModel.from_pretrained(model_path, output_hidden_states = True)
else:
    model = BertForMaskedLM.from_pretrained('bert-base-uncased',
                                            output_attentions = False,
                                            output_hidden_states = True)
    # Load onto CPU first so a checkpoint saved from a GPU run also loads on CPU-only hosts.
    model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))

# Output prefix, e.g. "../data/extracted/4. cds/cds_word_"; the extraction cells below
# append a dataset-specific suffix to it.
save_path = f"../data/extracted/{folder_num} {debiased_model}/{debiased_model}_word_"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
# Inference only: put the model in evaluation mode (disables dropout).
model.eval()
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')


## Helpers: Load/Save Pickle Files and Extract BERT Embeddings

In [25]:
from tqdm import tqdm
import pickle5 as pickle


def dump_pklfile(file, filepath, size):
    """Pickle ``file`` (or a slice of it) to ``filepath``.

    Args:
        file: Object to serialise; it must support slicing when ``size != 0``.
        filepath: Destination path. Missing parent directories are created.
        size: ``0`` dumps the whole object; a positive value dumps the first
            ``size`` items; a negative value dumps the last ``abs(size)`` items.
    """
    import os

    # The extraction cells spend minutes computing embeddings before calling this,
    # so make sure the output folder exists rather than failing on the final write.
    parent = os.path.dirname(filepath)
    if parent:  # a bare filename has no directory component to create
        os.makedirs(parent, exist_ok=True)

    with open(filepath, "wb") as f:
        if size == 0:
            pickle.dump(file, f)
        elif size > 0:
            pickle.dump(file[:size], f)
        else:
            pickle.dump(file[size:], f)


def open_pklfile(filepath, size):
    """Unpickle the object stored at ``filepath``, optionally truncating it.

    Args:
        filepath: Pickle file to read. Only load files you trust: unpickling can
            execute arbitrary code.
        size: ``0`` returns the object unchanged; otherwise returns
            ``obj[0:size]``. Note that a negative ``size`` therefore drops the last
            ``abs(size)`` items, unlike ``dump_pklfile`` where it keeps them.
    """
    with open(filepath, "rb") as handle:
        loaded = pickle.load(handle)
    return loaded if size == 0 else loaded[0:size]


def extract_bert_embeddings(word_list):
    """Embed each word/phrase in ``word_list`` with the notebook-global BERT ``model``.

    Each entry is wrapped as ``[CLS] <word> [SEP]`` and run through the model on its
    own. The last four hidden layers are concatenated along dim 2 (the hidden-size
    axis of transformers' ``(batch, seq, hidden)`` hidden states). The word's vector
    is the sum of that concatenation at the first wordpiece after ``[CLS]`` and at
    the last wordpiece before ``[SEP]``. For a single-wordpiece word these are the
    same position, so its vector is doubled.

    Relies on the globals ``model``, ``tokenizer``, ``device`` and ``tqdm``.

    Args:
        word_list: Iterable of strings. Bare ``"_"`` tokens are dropped after
            tokenization, so ``"wedding_dress"`` is embedded as "wedding dress".

    Returns:
        A tensor on ``device`` with one row per word, passed through
        ``torch.squeeze``. A single-word list therefore yields a 1-D vector, and an
        empty list yields a tensor of shape ``(0,)``.

    NOTE(review): if ``device`` is CUDA the result stays on the GPU. Pickling it may
    then require CUDA to unpickle; consider ``.cpu()`` before dumping.
    """
    # Loop-invariant: put the model in evaluation (inference) mode once.
    model.eval()

    rows = []
    with torch.no_grad():
        for word in tqdm(word_list):
            marked_text = "[CLS] " + word + " [SEP]"
            # Split into wordpieces, dropping bare "_" tokens (e.g. "wedding_dress").
            tokens = [tok for tok in tokenizer.tokenize(marked_text) if tok != '_']
            # Map the wordpieces to their vocabulary indices.
            token_ids = tokenizer.convert_tokens_to_ids(tokens)

            input_ids = torch.tensor([token_ids], device=device)
            # transformers' BERT forward signature starts (input_ids, attention_mask,
            # token_type_ids, ...). The all-ones tensor previously passed positionally
            # was therefore the attention_mask, so pass it by keyword. token_type_ids
            # keep the model default.
            attention_mask = torch.ones_like(input_ids)

            hidden_states = model(input_ids=input_ids, attention_mask=attention_mask).hidden_states
            last_four = torch.cat(hidden_states[-4:], dim=2)

            # first wordpiece after [CLS] + last wordpiece before [SEP]
            rows.append(last_four[:, 1] + last_four[:, -2])

    if not rows:
        return torch.empty(0, device=device)
    # Concatenate once (O(n)) instead of re-copying the growing tensor every iteration.
    return torch.squeeze(torch.cat(rows))

## Compute Gender Bias by Projection

In [26]:
import numpy as np
from torch import linalg as LA
import scipy.stats
import json 
import codecs


# normalize vectors
def normalize(wv):
    """Scale ``wv`` to unit L2 norm along dim 1 (each row of a 2-D tensor).

    ``wv`` must be a torch tensor with at least two dims. A row whose norm is 0
    comes back as NaN (0 / 0); no epsilon is applied.
    """
    row_norms = LA.norm(wv, dim=1, keepdim=True)
    return wv / row_norms


# compute bias from bert with he-she
def compute_bias_by_projection_heshe(vocab, lim_wv, gender_word_embedding, verbose=False):
    """Score each word by ``v . g[0] - v . g[1]``, where ``g = gender_word_embedding``.

    Index 0 is presumably the "he" vector and index 1 the "she" vector; confirm
    against the caller. This function does not normalise anything. The dot products
    are only cosine-style projections if the caller has normalised the inputs
    (e.g. with ``normalize``).

    Args:
        vocab: Words aligned row-for-row with ``lim_wv``. Surplus items on either
            side are ignored (``zip`` semantics).
        lim_wv: Tensor of word embeddings, typically 2-D with one row per word.
            Its last dim is contracted with ``g[0]`` / ``g[1]``.
        gender_word_embedding: Indexable pair of vectors ``(g[0], g[1])``.
        verbose: If True, print the shapes of ``lim_wv`` and ``g[0]``. The
            function previously always printed these debug lines.

    Returns:
        dict mapping each word to a 0-d tensor holding its ``g[0]``-minus-``g[1]``
        score.
    """
    if verbose:
        print(lim_wv.shape)
        print(gender_word_embedding[0].shape)
    males = torch.tensordot(lim_wv, gender_word_embedding[0], dims=1)
    females = torch.tensordot(lim_wv, gender_word_embedding[1], dims=1)
    # One vectorised subtraction instead of a Python loop of 0-d tensor ops.
    return dict(zip(vocab, males - females))


def extract_professions(path='../data/lists/professions.json'):
    """Return the profession names listed in a professions JSON file.

    The file must hold a JSON array whose items are sequences with the profession
    name first. Only that first field is kept, stripped of surrounding whitespace,
    in file order.

    Args:
        path: JSON file to read; defaults to the project's professions list.
    """
    # Built-in open() with an explicit encoding replaces codecs.open (deprecated in 3.14).
    with open(path, 'r', encoding='utf-8') as f:
        professions_data = json.load(f)
    return [item[0].strip() for item in professions_data]

## Extract Embeddings: Restricted Embeddings (2016)

In [27]:
# Re-embed the original model's 2016 restricted vocabulary with the selected model.
vocab = open_pklfile(filepath="../data/extracted/0. original/original_word_2016_restricted_vocab.pkl", size=0)
lim_wv = extract_bert_embeddings(word_list=vocab)
dump_pklfile(file=lim_wv, filepath=save_path + "2016_restricted_embeddings.pkl", size=0)
print(save_path)

100%|██████████| 26189/26189 [03:17<00:00, 132.82it/s]


../data/extracted/4. cds/cds_word_


## Extract Embeddings: Restricted Embeddings (2018)

In [28]:
# Re-embed the original model's 2018 restricted vocabulary with the selected model.
vocab = open_pklfile(filepath="../data/extracted/0. original/original_word_2018_restricted_vocab.pkl", size=0)
lim_wv = extract_bert_embeddings(word_list=vocab)
dump_pklfile(file=lim_wv, filepath=save_path + "2018_restricted_embeddings.pkl", size=0)
print(save_path)

100%|██████████| 47698/47698 [06:23<00:00, 124.33it/s]


../data/extracted/4. cds/cds_word_


## Extract Embeddings: Top 2500 Male/Female Words from BERT (2016)

In [29]:
# Male list: the top-2500 male vocabulary selected with the original BERT model.
vocab_male_2016 = open_pklfile(filepath="../data/extracted/1. bert/bert_word_2016_male_2500_vocab.pkl", size=0)
lim_wv_male_2016 = extract_bert_embeddings(word_list=vocab_male_2016)
dump_pklfile(file=lim_wv_male_2016, filepath=save_path + "2016_male_2500_embeddings.pkl", size=0)

# Female list: same procedure for the female vocabulary.
vocab_female_2016 = open_pklfile(filepath="../data/extracted/1. bert/bert_word_2016_female_2500_vocab.pkl", size=0)
lim_wv_female_2016 = extract_bert_embeddings(word_list=vocab_female_2016)
dump_pklfile(file=lim_wv_female_2016, filepath=save_path + "2016_female_2500_embeddings.pkl", size=0)

100%|██████████| 2500/2500 [00:16<00:00, 150.40it/s]
100%|██████████| 2500/2500 [00:16<00:00, 150.87it/s]


## Extract Embeddings: Top 2500 Male/Female Words from BERT (2018)

In [30]:
# Male list: the top-2500 male vocabulary selected with the original BERT model.
vocab_male_2018 = open_pklfile(filepath="../data/extracted/1. bert/bert_word_2018_male_2500_vocab.pkl", size=0)
lim_wv_male_2018 = extract_bert_embeddings(word_list=vocab_male_2018)
dump_pklfile(file=lim_wv_male_2018, filepath=save_path + "2018_male_2500_embeddings.pkl", size=0)

# Female list: same procedure for the female vocabulary.
vocab_female_2018 = open_pklfile(filepath="../data/extracted/1. bert/bert_word_2018_female_2500_vocab.pkl", size=0)
lim_wv_female_2018 = extract_bert_embeddings(word_list=vocab_female_2018)
dump_pklfile(file=lim_wv_female_2018, filepath=save_path + "2018_female_2500_embeddings.pkl", size=0)

100%|██████████| 2500/2500 [00:16<00:00, 150.23it/s]
100%|██████████| 2500/2500 [00:16<00:00, 150.63it/s]


## Extract Embeddings: Male and Female Gender Word Files

In [31]:
# One word per line; strip() removes the trailing newline. Blank lines are kept as
# empty strings so the male/female lists stay the same length line-for-line.
with open("../data/lists/male_word_file.txt", 'r') as f:
    male_words = [line.strip() for line in f]
male_word_embs = extract_bert_embeddings(word_list=male_words)
dump_pklfile(file=male_word_embs, filepath=save_path + "male_word_file_embeddings.pkl", size=0)

with open("../data/lists/female_word_file.txt", 'r') as f:
    female_words = [line.strip() for line in f]
female_word_embs = extract_bert_embeddings(word_list=female_words)
dump_pklfile(file=female_word_embs, filepath=save_path + "female_word_file_embeddings.pkl", size=0)

100%|██████████| 223/223 [00:01<00:00, 149.22it/s]
100%|██████████| 223/223 [00:01<00:00, 150.84it/s]
