In [1]:
#WEAT Word Embedding Association Tests
import torch as tr
import numpy as np
from bisect import bisect_left
import spacy
from tqdm.notebook import tqdm
import itertools
# Load the small English spaCy pipeline; in the visible cells it is only used to build the tokenizer.
nlp = spacy.load("en_core_web_sm")
# NOTE(review): Defaults.create_tokenizer looks like a spaCy v2-era API — confirm the pinned spaCy version.
tokenizer = nlp.Defaults.create_tokenizer(nlp)

# Base directory interpolated into embedding_path; it is expected to contain the glove.6B/ folder.
path_glove = './'

In [2]:
# Read the vocabulary file: every line is split on single spaces and the rows are
# transposed with zip(), so `vocab` receives the first field of each line. The
# two-name unpacking requires the transposition to yield exactly two columns.
with open("data/vocab.txt") as fin:
    vocab,_ = zip(*map(lambda x: x.split(" "), fin))
    # Append the "<unk>" entry and sort: index() locates words by binary search.
    vocab = sorted(list(vocab) + ["<unk>"])  

In [3]:
### create embedding matrix
def index(a, x):
    """Return the position of the leftmost element of the sorted sequence `a` equal to `x`.

    Args:
        a: sequence sorted in ascending order (binary search is used).
        x: value to locate.

    Returns:
        The index of the first element of `a` that equals `x`.

    Raises:
        ValueError: if no element of `a` equals `x`.
    """
    pos = bisect_left(a, x)
    # bisect_left returns an insertion point; it is a hit only when it lies
    # inside `a` and the element stored there actually matches.
    if pos == len(a) or a[pos] != x:
        raise ValueError
    return pos
    
def create_embedding_matrix(filepath, vocab, embedding_dim):
    """Build an embedding tensor whose rows follow the order of `vocab`.

    The file at `filepath` is read line by line; each non-empty line is split
    on whitespace into a word followed by its vector components. When the word
    belongs to `vocab`, its vector is stored in the row matching the word's
    position in `vocab`. Vocabulary words that never occur in the file keep an
    all-zero row.

    Args:
        filepath: path of a text embedding file (one "word v1 v2 ..." per line).
        vocab: sequence of words; row i of the result belongs to vocab[i].
            If a word is repeated, its first position is used.
        embedding_dim: number of vector components stored per word.

    Returns:
        A tensor of shape (len(vocab), embedding_dim).
    """
    # Resolve word -> row once up front; testing `word in vocab` on a list
    # would rescan the whole vocabulary for every line of the embedding file.
    row_of = {}
    for position, entry in enumerate(vocab):
        row_of.setdefault(entry, position)

    embedding_matrix = tr.zeros((len(vocab), embedding_dim))

    # Explicit encoding so that non-ASCII tokens are decoded the same way on
    # every platform instead of depending on the locale default.
    with open(filepath, encoding="utf-8") as f:
        for line in f:
            fields = line.split()
            if not fields:
                # Blank line (e.g. a trailing newline at the end of the file).
                continue
            word, *vector = fields
            row = row_of.get(word)
            if row is not None:
                embedding_matrix[row] = tr.from_numpy(
                    np.array(vector, dtype=np.float32))

    return embedding_matrix


def lookup_embeddings(text, vocab, embedding_matrix):
    """Fetch one embedding row per entry of `text`.

    Every entry is converted to str and run through the module-level
    `tokenizer`. Each token is reduced to its lower-cased lemma, and when that
    key is found in `vocab` the matching row of `embedding_matrix` is copied
    into the entry's output row. If an entry yields several matching tokens
    the last one wins; entries without any match keep an all-zero row.

    Args:
        text: sized iterable of words.
        vocab: sorted list of words (index() uses binary search); position i
            is taken to correspond to row i of `embedding_matrix`.
        embedding_matrix: 2-D tensor of embeddings.

    Returns:
        A tensor of shape (len(text), embedding_matrix.shape[1]).
    """
    embeddings = tr.zeros(len(text), embedding_matrix.shape[1])

    for iword, word in enumerate(text):
        for token in tokenizer(str(word)):
            # Look up the very key that is tested for membership. Testing the
            # lemma but indexing with the raw word raises ValueError when only
            # the lemma is in the vocabulary.
            key = token.lemma_.lower()
            try:
                row = index(vocab, key)
            except ValueError:
                # Key not in the vocabulary: leave the row untouched.
                continue
            embeddings[iword] = embedding_matrix[row]

    return embeddings

In [4]:
# Number of components per embedding vector; it also selects which
# glove.6B.<dim>d.txt file is read.
embedding_dim = 50
embedding_path = '{}/glove.6B/glove.6B.{}d.txt'.format(path_glove, embedding_dim)

# Rows of embedding_matrix follow the order of `vocab` (see create_embedding_matrix).
embedding_matrix = create_embedding_matrix(embedding_path, vocab, embedding_dim)

In [5]:
# Word sets for the test: A and B are passed as the two attribute sets,
# X and Y are the two word sets whose associations with A and B are compared.
A = ['male', 'man']
B = ['female', 'woman']
X = ['executive', 'management', 'professional', 'corporation', 'salary', 'office', 'business', 'career']
Y = ['home', 'parents', 'children', 'family', 'cousins', 'marriage']
# 'wedding' and 'relatives' are currently left out of Y.



def word_attribute_association(w, A, B, vocab):
    """Compute s(w, A, B) = mean_a cos(w, a) - mean_b cos(w, b).

    Embeddings are taken from the module-level `embedding_matrix` through
    lookup_embeddings(). All pairwise cosines between the rows of `w` and the
    rows of each attribute set are summed, and each sum is divided by the size
    of that attribute set.

    Args:
        w: list of words to score (the callers in this notebook pass one word).
        A: first attribute word list.
        B: second attribute word list.
        vocab: sorted vocabulary aligned with `embedding_matrix`.

    Returns:
        The association score as a NumPy scalar. A word whose embedding row is
        all zeros has zero norm, so the normalisation yields a non-finite value.
    """
    def _unit_rows(words):
        # Embed the words and scale every row to unit L2 norm.
        vectors = lookup_embeddings(words, vocab, embedding_matrix)
        return vectors.numpy() / np.linalg.norm(vectors, axis=1)[:, np.newaxis]

    A_unit = _unit_rows(A)
    B_unit = _unit_rows(B)
    w_unit = _unit_rows(w)

    # With unit-length rows the dot products are cosine similarities.
    cos_sum_A = np.dot(w_unit, A_unit.T).sum()
    cos_sum_B = np.dot(w_unit, B_unit.T).sum()

    return cos_sum_A / len(A) - cos_sum_B / len(B)

def test_statistic(A,B,X,Y, vocab):
    """Compute s(X, Y, A, B) = sum_x s(x, A, B) - sum_y s(y, A, B).

    Args:
        A: first attribute word list.
        B: second attribute word list.
        X: first list of words to score.
        Y: second list of words to score.
        vocab: sorted vocabulary forwarded to word_attribute_association().

    Returns:
        The summed association of the words in X minus that of the words in Y.
    """
    # Each word is scored on its own, wrapped in a one-element list.
    x_total = sum(word_attribute_association([x_word], A, B, vocab) for x_word in X)
    y_total = sum(word_attribute_association([y_word], A, B, vocab) for y_word in Y)

    return x_total + (-y_total)

def calculate_pvalue(A,B,X,Y, vocab):
    """Compute the permutation-test p-value by exhaustive enumeration.

    The observed statistic is test_statistic(A, B, X, Y, vocab). The distinct
    words of X and Y are then split in every possible way into a subset Xi of
    size len(union) // 2 and its complement Yi. The p-value is the fraction of
    those splits whose statistic is strictly greater than the observed one.

    Args:
        A: first attribute word list.
        B: second attribute word list.
        X: first word list (must be a list; it is concatenated with Y).
        Y: second word list.
        vocab: sorted vocabulary forwarded to the association functions.

    Returns:
        larger / total as a float, where total is the number of enumerated
        splits. The number of splits is also printed.

    NOTE(review): the enumerated splits are always half/half, while the
    observed split uses len(X) and len(Y) as given — confirm that unequal
    list sizes are intended.
    """
    test_stat_orig = test_statistic(A,B,X,Y,vocab)

    # Sorted so that the splits are enumerated, and the floats summed, in the
    # same order on every run (plain set iteration order varies between runs).
    pool = sorted(set(X+Y), key=str)
    subset_size = len(pool)//2

    # A word's association with (A, B) does not depend on the split, so it is
    # computed once per word instead of once per word per split.
    association = {
        word: word_attribute_association([word], A, B, vocab) for word in pool
    }

    # combinations() over distinct items already yields distinct subsets, so
    # no extra de-duplication is needed; a list keeps the progress-bar length.
    partitions = list(itertools.combinations(pool, subset_size))

    larger = 0
    for subset in tqdm(partitions):
        chosen = set(subset)
        stat = (sum(association[word] for word in subset)
                - sum(association[word] for word in pool if word not in chosen))
        if stat > test_stat_orig:
            larger += 1

    total = len(partitions)
    print('num of samples:', total)
    return larger/float(total)

In [None]:
# Run the exhaustive permutation test on the word sets defined above and show the p-value.
p = calculate_pvalue(A,B,X,Y, vocab)
print(p)

HBox(children=(FloatProgress(value=0.0, max=3432.0), HTML(value='')))