## Install Requirements

In [None]:
# Notebook dependencies. Versions are not pinned — TODO pin (e.g. `transformers==X.Y.Z`) for reproducibility.
# NOTE(review): pickle5 is the PyPI backport of pickle protocol 5 for older Python versions;
# on Python 3.8+ the stdlib `pickle` already supports protocol 5 — confirm whether it is still needed.
# NOTE(review): mpld3 and scikit-learn are not imported in the cells of this notebook shown here;
# tqdm, numpy and scipy are imported below but not installed explicitly (presumably pulled in transitively).
%pip install transformers
%pip install torch
%pip install pickle5
%pip install mpld3
%pip install scikit-learn

## Load Pre-trained BERT Model and Tokenizer

In [None]:
import torch
from transformers import BertTokenizer, BertModel

MODEL_NAME = "bert"  # used only to build save_path; `model` below holds the actual network
PRETRAINED_CHECKPOINT = 'bert-base-uncased'  # model and tokenizer must come from the same checkpoint
folder_num = "1."
save_path = f"../data/extracted/{folder_num} {MODEL_NAME}/{MODEL_NAME}_word_"

# Load pre-trained weights. output_hidden_states=True makes forward() also return the
# hidden states of every layer, which the embedding-extraction code reads.
model = BertModel.from_pretrained(PRETRAINED_CHECKPOINT, output_hidden_states = True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
model.eval()  # inference mode (e.g. disables dropout)
tokenizer = BertTokenizer.from_pretrained(PRETRAINED_CHECKPOINT)

## Open, Save, and Extract Embeddings

In [3]:
from tqdm import tqdm
import pickle5 as pickle


def dump_pklfile(file, filepath, size):
	"""Pickle ``file`` (or a slice of it) to ``filepath``, overwriting it.

	size == 0 -> the whole object is written.
	size > 0  -> only the first ``size`` items are written.
	size < 0  -> only the last ``abs(size)`` items are written.
	"""
	with open(filepath, "wb") as out:
		if size == 0:
			payload = file
		else:
			payload = file[:size] if size > 0 else file[size:]
		pickle.dump(payload, out)


def open_pklfile(filepath, size):
	"""Load a pickled object from ``filepath``, optionally keeping only part of it.

	The ``size`` convention mirrors ``dump_pklfile``:
	size == 0 -> return the whole object.
	size > 0  -> return the first ``size`` items.
	size < 0  -> return the last ``abs(size)`` items.

	NOTE: pickle.load can execute arbitrary code; only load files you trust.
	"""
	with open(filepath, "rb") as f:
		data = pickle.load(f)
	if size == 0:
		return data
	if size > 0:
		return data[:size]
	# Negative size selects the tail, matching how dump_pklfile slices.
	return data[size:]


def extract_bert_embeddings(word_list):
	"""Return one BERT embedding per word in ``word_list``, stacked into a tensor.

	Each word is wrapped as "[CLS] <word> [SEP]", tokenized, and run (without
	gradients) through the notebook-level ``model`` on ``device``. The last four
	hidden layers are concatenated along the feature axis, and the vectors at
	token positions 1 and -2 (the first and last word pieces, i.e. excluding
	[CLS] and [SEP]) are summed. When a word is a single word piece those two
	positions are the same token, so its vector is effectively doubled.

	Args:
		word_list: sequence of strings. Bare "_" pieces are dropped, so multi-word
			entries such as "wedding_dress" are embedded from their word parts.

	Returns:
		``torch.squeeze`` of the stacked (len(word_list), 4 * hidden_size) tensor,
		so a one-word list yields a 1-D tensor and an empty list yields shape (0,).
	"""
	model.eval()  # inference mode; set once rather than on every word
	rows = []  # per-word (1, 4 * hidden_size) tensors, concatenated once at the end

	with torch.no_grad():
		for word in tqdm(word_list):
			# Split the marked text into word-piece tokens.
			tokenized_text = tokenizer.tokenize("[CLS] " + word + " [SEP]")
			# Drop bare "_" pieces produced by entries such as "wedding_dress".
			tokenized_text = [token for token in tokenized_text if token != '_']

			# Map the token strings to their vocabulary indices.
			indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
			tokens_tensor = torch.tensor([indexed_tokens], device=device)
			# BertModel.forward's second positional parameter is attention_mask (not
			# token_type_ids). An all-ones mask attends to every token; it is passed by
			# keyword so the call states what it does. token_type_ids keeps the model default.
			attention_mask = torch.ones_like(tokens_tensor)

			# hidden_states holds the embedding-layer output followed by each encoder layer.
			hidden_states = model(tokens_tensor, attention_mask=attention_mask).hidden_states
			# Concatenate the last four layers along the feature axis.
			concated_hidden_states = torch.cat(hidden_states[-4:], dim=2)

			# Sum the first and last word pieces (skip [CLS] at 0 and [SEP] at -1).
			rows.append(concated_hidden_states[:, 1] + concated_hidden_states[:, -2])

	# A single concatenation avoids recopying a growing tensor on every word (O(n^2)).
	embeddings = torch.cat(rows, dim=0) if rows else torch.empty(0, device=device)
	return torch.squeeze(embeddings)


## Compute Gender Bias by Projection

In [4]:
import numpy as np
from torch import linalg as LA
import scipy.stats
import json 
import codecs


def normalize(wv):
    """Scale each row of ``wv`` (normed along dim 1) to unit L2 length.

    No epsilon is applied: a row with zero norm divides by zero and yields NaN.
    """
    return wv / LA.norm(wv, dim=1, keepdim=True)


def compute_bias_by_projection_heshe(vocab, lim_wv, gender_word_embedding):
    """Score each vocabulary word by its projection on a male vs. female vector.

    Args:
        vocab: words aligned row-by-row with ``lim_wv``.
        lim_wv: 2-D tensor of word embeddings (one row per word).
        gender_word_embedding: tensor whose row 0 and row 1 are the two gender
            vectors (male, female as built by this notebook's callers).

    Returns:
        dict mapping word -> 0-d tensor ``<w, row0> - <w, row1>``; larger values
        project more onto row 0. Pairing uses ``zip``, so extra words or rows
        beyond the shorter of ``vocab``/``lim_wv`` are ignored.
    """
    # Shape diagnostics printed for the notebook output.
    print(lim_wv.shape)
    print(gender_word_embedding[0].shape)

    male_proj = torch.tensordot(lim_wv, gender_word_embedding[0], dims=1)
    female_proj = torch.tensordot(lim_wv, gender_word_embedding[1], dims=1)
    return dict(zip(vocab, male_proj - female_proj))


def extract_professions(path='../data/lists/professions.json'):
    """Return the profession names listed in a professions JSON file.

    The file must hold a JSON array whose items are indexable with the profession
    name at position 0 (only ``item[0]`` is read). Names are whitespace-stripped.

    Args:
        path: UTF-8 JSON file to read. Defaults to the originally hard-coded path.
    """
    with open(path, 'r', encoding='utf-8') as f:
        professions_data = json.load(f)
    return [item[0].strip() for item in professions_data]

## Extract Embeddings: All (2016)

In [5]:
# Embed the 2016 restricted vocabulary and cache the raw (unnormalized) vectors.
vocab = open_pklfile("../data/extracted/0. original/original_word_2016_restricted_vocab.pkl", 0)
lim_wv = extract_bert_embeddings(vocab)
dump_pklfile(lim_wv, f"{save_path}2016_restricted_embeddings.pkl", 0)

# Unit-normalize rows for the projection step and index each word by its row.
lim_wv = normalize(lim_wv)
w2i = dict(zip(vocab, range(len(vocab))))

100%|██████████| 26189/26189 [03:07<00:00, 139.37it/s]


In [6]:
# Gendered word lists, one word per line.
with open("../data/lists/male_word_file.txt", 'r') as f:
    male_words = [line.strip() for line in f]

with open("../data/lists/female_word_file.txt", 'r') as f:
    female_words = [line.strip() for line in f]

male_embeddings = extract_bert_embeddings(male_words)
female_embeddings = extract_bert_embeddings(female_words)

# Gender vectors: sums of the raw (unnormalized) word embeddings, kept 2-D as (1, D).
male_embedding = male_embeddings.sum(dim=0, keepdim=True)
female_embedding = female_embeddings.sum(dim=0, keepdim=True)

# NOTE(review): these normalized tensors are computed *after* the sums above and are not
# read again in this cell, so they do not affect gender_bias_all. If the intent was to sum
# unit vectors, this normalization belongs before the sums — confirm.
male_embeddings = normalize(male_embeddings)
female_embeddings = normalize(female_embeddings)

# Row 0 = male vector, row 1 = female vector.
gender_word_embedding = torch.cat((male_embedding, female_embedding), 0)

gender_bias_all = compute_bias_by_projection_heshe(vocab, lim_wv, gender_word_embedding)

100%|██████████| 223/223 [00:01<00:00, 146.21it/s]
100%|██████████| 223/223 [00:01<00:00, 148.05it/s]


torch.Size([26189, 3072])
torch.Size([3072])


In [7]:
import operator

TOP_K = 2500  # number of most female-/male-leaning words to keep

# Ascending by bias score: most negative (female-leaning) words first.
sorted_g = sorted(gender_bias_all.items(), key=operator.itemgetter(1))
females_w = [item[0] for item in sorted_g[:TOP_K]]
# Gather the normalized embedding rows with one indexing op instead of growing a
# tensor via torch.cat in a loop (which recopies it on every iteration, O(k^2)).
females_e = lim_wv[[w2i[w] for w in females_w]]

dump_pklfile(females_w, f"{save_path}2016_female_{TOP_K}_vocab.pkl", TOP_K)
dump_pklfile(females_e, f"{save_path}2016_female_{TOP_K}_embeddings.pkl", TOP_K)

# Re-sorted with reverse=True rather than reversing the list above: Python's sort is
# stable, so equal scores keep their original (vocabulary) order here as well.
sorted_g = sorted(gender_bias_all.items(), key=operator.itemgetter(1), reverse=True)
males_w = [item[0] for item in sorted_g[:TOP_K]]
males_e = lim_wv[[w2i[w] for w in males_w]]

dump_pklfile(males_w, f"{save_path}2016_male_{TOP_K}_vocab.pkl", TOP_K)
dump_pklfile(males_e, f"{save_path}2016_male_{TOP_K}_embeddings.pkl", TOP_K)


## Extract Embeddings: All (2018)

In [8]:
# Embed the 2018 restricted vocabulary and cache the raw (unnormalized) vectors.
vocab = open_pklfile("../data/extracted/0. original/original_word_2018_restricted_vocab.pkl", 0)
lim_wv = extract_bert_embeddings(vocab)
dump_pklfile(lim_wv, f"{save_path}2018_restricted_embeddings.pkl", 0)

# Unit-normalize rows for the projection step and index each word by its row.
lim_wv = normalize(lim_wv)
w2i = dict(zip(vocab, range(len(vocab))))

100%|██████████| 47698/47698 [06:27<00:00, 123.13it/s]


In [9]:
# Gendered word lists, one word per line.
with open("../data/lists/male_word_file.txt", 'r') as f:
    male_words = [line.strip() for line in f]

with open("../data/lists/female_word_file.txt", 'r') as f:
    female_words = [line.strip() for line in f]

male_embeddings = extract_bert_embeddings(male_words)
female_embeddings = extract_bert_embeddings(female_words)

# Gender vectors: sums of the raw (unnormalized) word embeddings, kept 2-D as (1, D).
male_embedding = male_embeddings.sum(dim=0, keepdim=True)
female_embedding = female_embeddings.sum(dim=0, keepdim=True)

# NOTE(review): these normalized tensors are computed *after* the sums above and are not
# read again in this cell, so they do not affect gender_bias_all. If the intent was to sum
# unit vectors, this normalization belongs before the sums — confirm.
male_embeddings = normalize(male_embeddings)
female_embeddings = normalize(female_embeddings)

# Row 0 = male vector, row 1 = female vector.
gender_word_embedding = torch.cat((male_embedding, female_embedding), 0)

gender_bias_all = compute_bias_by_projection_heshe(vocab, lim_wv, gender_word_embedding)

100%|██████████| 223/223 [00:01<00:00, 145.25it/s]
100%|██████████| 223/223 [00:01<00:00, 151.54it/s]


torch.Size([47698, 3072])
torch.Size([3072])


In [10]:
import operator

TOP_K = 2500  # number of most female-/male-leaning words to keep

# Ascending by bias score: most negative (female-leaning) words first.
sorted_g = sorted(gender_bias_all.items(), key=operator.itemgetter(1))
females_w = [item[0] for item in sorted_g[:TOP_K]]
# Gather the normalized embedding rows with one indexing op instead of growing a
# tensor via torch.cat in a loop (which recopies it on every iteration, O(k^2)).
females_e = lim_wv[[w2i[w] for w in females_w]]

dump_pklfile(females_w, f"{save_path}2018_female_{TOP_K}_vocab.pkl", TOP_K)
dump_pklfile(females_e, f"{save_path}2018_female_{TOP_K}_embeddings.pkl", TOP_K)

# Re-sorted with reverse=True rather than reversing the list above: Python's sort is
# stable, so equal scores keep their original (vocabulary) order here as well.
sorted_g = sorted(gender_bias_all.items(), key=operator.itemgetter(1), reverse=True)
males_w = [item[0] for item in sorted_g[:TOP_K]]
males_e = lim_wv[[w2i[w] for w in males_w]]

dump_pklfile(males_w, f"{save_path}2018_male_{TOP_K}_vocab.pkl", TOP_K)
dump_pklfile(males_e, f"{save_path}2018_male_{TOP_K}_embeddings.pkl", TOP_K)

## Extract Embeddings: Gender Word File

In [11]:
# Embed each gendered word list and save both the words and their raw (unnormalized) embeddings.
with open("../data/lists/male_word_file.txt", 'r') as f:
    male_words = [line.strip() for line in f]
male_word_embs = extract_bert_embeddings(male_words)
dump_pklfile(male_words, f"{save_path}male_word_file_vocab.pkl", 0)
dump_pklfile(male_word_embs, f"{save_path}male_word_file_embeddings.pkl", 0)

with open("../data/lists/female_word_file.txt", 'r') as f:
    female_words = [line.strip() for line in f]
female_word_embs = extract_bert_embeddings(female_words)
dump_pklfile(female_words, f"{save_path}female_word_file_vocab.pkl", 0)
dump_pklfile(female_word_embs, f"{save_path}female_word_file_embeddings.pkl", 0)

100%|██████████| 223/223 [00:01<00:00, 146.68it/s]
100%|██████████| 223/223 [00:01<00:00, 148.57it/s]
