In [16]:
from tqdm import tqdm

In [9]:
import gensim.downloader as api

model = api.load("word2vec-google-news-300")

In [17]:
from datasets import load_dataset

ds = load_dataset("microsoft/ms_marco", "v1.1")

  from .autonotebook import tqdm as notebook_tqdm


In [18]:
ds.keys()

dict_keys(['validation', 'train', 'test'])

In [19]:
ds['train'].column_names

['answers', 'passages', 'query', 'query_id', 'query_type', 'wellFormedAnswers']

In [20]:
ds['train'][0]

{'answers': ['Results-Based Accountability is a disciplined way of thinking and taking action that communities can use to improve the lives of children, youth, families, adults and the community as a whole.'],
 'passages': {'is_selected': [0, 0, 0, 0, 0, 1, 0, 0, 0, 0],
  'passage_text': ["Since 2007, the RBA's outstanding reputation has been affected by the 'Securency' or NPA scandal. These RBA subsidiaries were involved in bribing overseas officials so that Australia might win lucrative note-printing contracts. The assets of the bank include the gold and foreign exchange reserves of Australia, which is estimated to have a net worth of A$101 billion. Nearly 94% of the RBA's employees work at its headquarters in Sydney, New South Wales and at the Business Resumption Site.",
   "The Reserve Bank of Australia (RBA) came into being on 14 January 1960 as Australia 's central bank and banknote issuing authority, when the Reserve Bank Act 1959 removed the central banking functions from the C

In [21]:
ds['train'][0]['passages']['passage_text']

["Since 2007, the RBA's outstanding reputation has been affected by the 'Securency' or NPA scandal. These RBA subsidiaries were involved in bribing overseas officials so that Australia might win lucrative note-printing contracts. The assets of the bank include the gold and foreign exchange reserves of Australia, which is estimated to have a net worth of A$101 billion. Nearly 94% of the RBA's employees work at its headquarters in Sydney, New South Wales and at the Business Resumption Site.",
 "The Reserve Bank of Australia (RBA) came into being on 14 January 1960 as Australia 's central bank and banknote issuing authority, when the Reserve Bank Act 1959 removed the central banking functions from the Commonwealth Bank. The assets of the bank include the gold and foreign exchange reserves of Australia, which is estimated to have a net worth of A$101 billion. Nearly 94% of the RBA's employees work at its headquarters in Sydney, New South Wales and at the Business Resumption Site.",
 'RBA R

In [22]:
len(ds['train'])

82326

In [23]:
train = ds['train']
valid = ds['validation']
test = ds['test']
print(len(train), len(valid), len(test))


82326 10047 9650


In [24]:
import random
#random seed
random.seed(42)

In [25]:
ds['train'][1]['passages']['url']

['http://www.history.com/topics/us-presidents/ronald-reagan',
 'https://en.wikipedia.org/wiki/Reagan_Democrat',
 'http://www.answers.com/Q/Was_Ronald_Reagan_a_republican_or_a_democrat',
 'https://en.wikipedia.org/wiki/Ronald_Reagan',
 'http://www.msnbc.com/the-last-word/watch/when-reagan-was-a-liberal-democrat-219696195576',
 'http://www.history.com/topics/us-presidents/ronald-reagan',
 'http://www.biography.com/people/ronald-reagan-9453198']

In [26]:
train_triples = []
for i in tqdm(range(0,len(train))):
    query = ds['train'][i]['query']
    for k, passage in enumerate(ds['train'][i]['passages']['passage_text']):
        # print(passage)
        sample = {}
        sample['query'] = query
        sample['positive'] = passage
        sample['positive_url'] = ds['train'][i]['passages']['url'][k]
        while True:
            random_ind = random.randint(0, len(ds['train'])-1)
            if random_ind != i:
                break

        negatives = ds['train'][random_ind]['passages']['passage_text']
        #make random selection of these passages
        sample['negative'] = random.choice(negatives)
        sample['negative_url'] = ds['train'][random_ind]['passages']['url'][negatives.index(sample['negative'])]
        train_triples.append(sample)

100%|██████████| 82326/82326 [01:31<00:00, 903.78it/s]


In [27]:
#save the train triples
import json
json.dump(train_triples, open('train_triples_v1.1.json', 'w'))

In [47]:
json.dump(train_triples[:5], open('train_triples_sample.json', 'w'))

In [28]:
#validation triples
valid_triples = []
for i in tqdm(range(len(valid))):
    query = ds['validation'][i]['query']
    for k, passage in enumerate(ds['validation'][i]['passages']['passage_text']):
        # print(passage)
        sample = {}
        sample['query'] = query
        sample['positive'] = passage
        sample['positive_url'] = ds['validation'][i]['passages']['url'][k]
        while True:
            random_ind = random.randint(0, len(ds['validation'])-1)
            if random_ind != i:
                break

        negatives = ds['validation'][random_ind]['passages']['passage_text']
        #make random selection of these passages
        sample['negative'] = random.choice(negatives)
        sample['negative_url'] = ds['validation'][random_ind]['passages']['url'][negatives.index(sample['negative'])]
        valid_triples.append(sample)

100%|██████████| 10047/10047 [00:11<00:00, 899.52it/s]


In [None]:
import json
json.dump(valid_triples, open('valid_triples_v1.1.json', 'w'))

## Make the vocab

In [19]:
from utils import tokenize

In [35]:
from nltk.stem import PorterStemmer
import constants

punctuation_map = constants.punctuation_map
tokenize("I went to the market today and I haven't left yet or found a dog's toy", 
         punctuation_map=punctuation_map, stemmer=PorterStemmer(), junk_punctuations=True)

['went', 'market', 'today', 'left', 'yet', 'found', 'dog', 'toy']

In [36]:
queries_t = []
for i in tqdm(range(0,len(train))):
    queries_t.append(train[i]['query'])

for i in tqdm(range(0,len(valid))):
    queries_t.append(valid[i]['query'])

for i in tqdm(range(0,len(test))):
    queries_t.append(test[i]['query'])

queries_t = list(set(queries_t))
queries_t = ' '.join(queries_t)
queries_words = tokenize(queries_t, punctuation_map=punctuation_map, stemmer=PorterStemmer(), junk_punctuations=True)

100%|██████████| 808731/808731 [00:38<00:00, 20968.18it/s]
100%|██████████| 101093/101093 [00:04<00:00, 21294.22it/s]
100%|██████████| 101092/101092 [00:04<00:00, 21156.48it/s]


In [37]:
#dump all words to a file
import json
json.dump(list(set(queries_words)), open('queries_words.json', 'w'))

In [39]:
len(set(queries_words))

118621

In [40]:
passages_words = []
for i in tqdm(range(0,len(train))):
    passages = ' '.join((train[i]['passages']['passage_text']))
    w = tokenize(passages, {})
    passages_words.extend(w)
#dump all words to a file
json.dump(list(set(passages_words)), open('passages_words.json', 'w'))
for i in tqdm(range(0,len(valid))):
    passages = ' '.join((train[i]['passages']['passage_text']))
    w = tokenize(passages, punctuation_map={}, stemmer=PorterStemmer())
    passages_words.extend(w)
#dump all words to a file
json.dump(list(set(passages_words)), open('passages_words.json', 'w'))
for i in tqdm(range(0,len(test))):
    passages = ' '.join((train[i]['passages']['passage_text']))
    w = tokenize(passages, {})
    passages_words.extend(w)
#dump all words to a file
json.dump(list(set(passages_words)), open('passages_words.json', 'w'))

100%|██████████| 808731/808731 [07:13<00:00, 1864.18it/s]
100%|██████████| 101093/101093 [03:56<00:00, 426.86it/s]
100%|██████████| 101092/101092 [00:52<00:00, 1921.06it/s]


In [42]:
len(passages_words)

322286825

In [1]:
import json

qwords = json.load(open('queries_words.json'))

In [2]:
pwords = json.load(open('passages_words.json'))

In [3]:
len(pwords)

1225032

In [4]:
all_words = list(set(qwords + pwords))

In [5]:
print(len(all_words))

1244024


In [6]:
word_to_ids = {v: i+1 for i,v in enumerate(all_words)}
word_to_ids['<unk>'] = 0

In [7]:
idx_to_word = {v: k for k, v in word_to_ids.items()}

In [10]:
import numpy as np
embeds = {}
oov = 0
for v, ind in word_to_ids.items():
    if v in model:
        embeds[ind] = model[v]
    else:
        oov += 1
        embeds[ind] = np.zeros(300)

In [11]:
oov/len(word_to_ids)

0.8901565483008782

In [32]:
list(word_to_ids.keys())[:5]

['wce', 'nterjections', 'psychotoxic', 'polyphobia', 'gourka']

In [14]:
#save all
import joblib

joblib.dump(embeds, 'embeds.pkl')
joblib.dump(word_to_ids, 'word_to_ids.pkl')
joblib.dump(idx_to_word, 'idx_to_word.pkl')

['idx_to_word.pkl']

In [15]:
import joblib
word_to_ids = joblib.load('word_to_ids.pkl')