In [2]:
!pip install -q pandas transformers tensorflow-hub faiss-gpu annoy torch torchvision

You should consider upgrading via the 'pip install --upgrade pip' command.[0m


In [1]:
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:85% !important; }</style>"))

In [3]:
import pandas as pd
import numpy as np
from urllib.request import urlopen
import multiprocessing
import json
#import tensorflow as tf
import torch
import matplotlib.pyplot as plt
from transformers import AutoModel, AutoTokenizer, ALL_PRETRAINED_CONFIG_ARCHIVE_MAP 
# from elasticsearch import Elasticsearch, helpers
# from elasticsearch_dsl import Search
import faiss
import annoy
from tqdm.notebook import tqdm_notebook as tqdmnb

In [4]:
def encode_row(row, tokenizer):
    argument = row['text']
    encoded_dicts = []
    while argument:
        encoded = tokenizer.encode_plus(argument, return_overflowing_tokens=True, max_length=tokenizer.max_len, pad_to_max_length=True, return_token_type_ids=False)
        argument = encoded.pop('overflowing_tokens', False)
        encoded.pop('num_truncated_tokens', 0)
        encoded_dicts.append(encoded)
    return encoded_dicts

def encode_dataframe(df):
    df['encoded'] = df.apply(lambda x: encode_row(x, tokenizer=tokenizer), axis=1)
    return df

def parallel_encode(df, encode_df, n_cores=multiprocessing.cpu_count()):
    df_chunk = np.array_split(df, n_cores)
    pool = multiprocessing.Pool(n_cores)
    df = pd.concat(tqdmnb(pool.imap(encode_df, df_chunk), total=n_cores))
    pool.close()
    pool.join()
    return df

In [5]:
%%time
with open('./Data/args-me.json') as f:
    d = json.load(f)
d = d['arguments']
context_subfields = [['context', k] for k in d[0]['context'].keys()]
dataset = pd.json_normalize(d, record_path='premises', meta=['id', 'conclusion', *context_subfields])
dataset.to_pickle('./dataset.pkl')
args = dataset[['id', 'text']].copy()

CPU times: user 20 s, sys: 1.52 s, total: 21.5 s
Wall time: 21.6 s


In [6]:
%%time
model_configs = []
for model_name, config_url in tqdmnb(ALL_PRETRAINED_CONFIG_ARCHIVE_MAP.items(), total=len(ALL_PRETRAINED_CONFIG_ARCHIVE_MAP)):
    with urlopen(config_url) as url:
        json_config = url.read().decode()
        config = {**{'model_name':model_name}, **json.loads(json_config)}
        model_configs.append(config)
all_models = pd.DataFrame(model_configs)
all_models.set_index('model_name', inplace=True, drop=True)
all_models.to_pickle('./all_models.pkl')

HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))


CPU times: user 857 ms, sys: 89.8 ms, total: 946 ms
Wall time: 38.8 s


In [16]:
tqdmnb.pandas()

### Choose a model

In [7]:
MODEL_TO_USE = 'bert-large-uncased'

In [8]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_TO_USE)

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=362.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=231508.0, style=ProgressStyle(descripti…




In [9]:
%%time
args = parallel_encode(args, encode_dataframe)

HBox(children=(FloatProgress(value=0.0, max=28.0), HTML(value='')))


CPU times: user 24.6 s, sys: 6.57 s, total: 31.2 s
Wall time: 8min 56s


In [13]:
args = args.explode('encoded')
args.reset_index(drop=True, inplace=True)

In [21]:
args[['input_ids','attention_mask']] = args['encoded'].progress_apply(pd.Series)[['input_ids','attention_mask']]

HBox(children=(FloatProgress(value=0.0, max=555583.0), HTML(value='')))




In [27]:
args.drop(columns=['encoded'], inplace=True)

In [33]:
args.to_pickle(f'args_encoded_{MODEL_TO_USE}.pkl')

In [35]:
with pd.option_context('display.max_colwidth', -1):
    display(args[:5])

Unnamed: 0,id,text,input_ids,attention_mask
0,c67482ba-2019-04-18T13:32:05Z-00000-000,"My opponent forfeited every round. None of my arguments were answered. I don’t like the idea of winning by default, but here we are.Tule: it’s good for students to get involved and address big issues like teen pregnancy. You need to be able to answer arguments like mine and not simply prepare for an abstinence-only type of response. You should also be aware that, in the U.S., condoms may be sold to minors in ANY state. A retailer who says it is illegal to sell you them is, frankly, wrong.","[101, 2026, 7116, 2005, 21156, 2098, 2296, 2461, 1012, 3904, 1997, 2026, 9918, 2020, 4660, 1012, 1045, 2123, 1521, 1056, 2066, 1996, 2801, 1997, 3045, 2011, 12398, 1010, 2021, 2182, 2057, 2024, 1012, 10722, 2571, 1024, 2009, 1521, 1055, 2204, 2005, 2493, 2000, 2131, 2920, 1998, 4769, 2502, 3314, 2066, 9458, 10032, 1012, 2017, 2342, 2000, 2022, 2583, 2000, 3437, 9918, 2066, 3067, 1998, 2025, 3432, 7374, 2005, 2019, 14689, 10196, 5897, 1011, 2069, 2828, 1997, 3433, 1012, 2017, 2323, 2036, 2022, 5204, 2008, 1010, 1999, 1996, 1057, 1012, 1055, 1012, 1010, 29094, 2089, 2022, 2853, 2000, 18464, 1999, 2151, ...]","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...]"
1,c67482ba-2019-04-18T13:32:05Z-00001-000,"How do you propose the school will fund your program? Condoms cost money and checking an ""opt out"" list before handing them out takes time away from staff members whenever they could be doing their actual jobs. Your ""opt out"" option is only be a token to parental authority and would be easily subverted. If everyone in school except a handful of students had access to free condoms, do you not think those students would simply ask their friends to provide them with condoms?","[101, 2129, 2079, 2017, 16599, 1996, 2082, 2097, 4636, 2115, 2565, 1029, 29094, 3465, 2769, 1998, 9361, 2019, 1000, 23569, 2041, 1000, 2862, 2077, 13041, 2068, 2041, 3138, 2051, 2185, 2013, 3095, 2372, 7188, 2027, 2071, 2022, 2725, 2037, 5025, 5841, 1012, 2115, 1000, 23569, 2041, 1000, 5724, 2003, 2069, 2022, 1037, 19204, 2000, 18643, 3691, 1998, 2052, 2022, 4089, 4942, 26686, 1012, 2065, 3071, 1999, 2082, 3272, 1037, 9210, 1997, 2493, 2018, 3229, 2000, 2489, 29094, 1010, 2079, 2017, 2025, 2228, 2216, 2493, 2052, 3432, 3198, 2037, 2814, 2000, 3073, 2068, 2007, 29094, 1029, 102, 0, 0, 0, 0, ...]","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, ...]"
2,c67482ba-2019-04-18T13:32:05Z-00002-000,"Schools have no compelling interest in providing contraceptives to students. The purpose of schools is not to provide healthcare nor to provide any other service except insofar as it relates to the furtherance of education [1,2,3], though I do not contest that individual districts ought to have this option if they so choose. As an educator, I do feel that adequate sexual education is a necessary. [1] tinyurl.com/z786mww[2] tinyurl.com/jafrt2n[3] tinyurl.com/zbkwkz6","[101, 2816, 2031, 2053, 17075, 3037, 1999, 4346, 24528, 28687, 2015, 2000, 2493, 1012, 1996, 3800, 1997, 2816, 2003, 2025, 2000, 3073, 9871, 4496, 2000, 3073, 2151, 2060, 2326, 3272, 16021, 11253, 2906, 2004, 2009, 14623, 2000, 1996, 2582, 6651, 1997, 2495, 1031, 1015, 1010, 1016, 1010, 1017, 1033, 1010, 2295, 1045, 2079, 2025, 5049, 2008, 3265, 4733, 11276, 2000, 2031, 2023, 5724, 2065, 2027, 2061, 5454, 1012, 2004, 2019, 11490, 1010, 1045, 2079, 2514, 2008, 11706, 4424, 2495, 2003, 1037, 4072, 1012, 1031, 1015, 1033, 4714, 3126, 2140, 1012, 4012, 1013, 1062, 2581, 20842, 2213, 2860, 2860, 1031, 1016, ...]","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...]"
3,c67482ba-2019-04-18T13:32:05Z-00003-000,"As a senior at my school. My group and I are focusing on teenage pregnancy; we are determined to have high school districts provide contraceptive forms to students to be safe about sex. This focus isn't for us to encourage to have sex, but if teenagers decide to have sex to please be safe about it. In addition. if parents do not agree to this we want to have an opt out form where their child/children will not be receiving this form.","[101, 2004, 1037, 3026, 2012, 2026, 2082, 1012, 2026, 2177, 1998, 1045, 2024, 7995, 2006, 9454, 10032, 1025, 2057, 2024, 4340, 2000, 2031, 2152, 2082, 4733, 3073, 24528, 28687, 3596, 2000, 2493, 2000, 2022, 3647, 2055, 3348, 1012, 2023, 3579, 3475, 1005, 1056, 2005, 2149, 2000, 8627, 2000, 2031, 3348, 1010, 2021, 2065, 12908, 5630, 2000, 2031, 3348, 2000, 3531, 2022, 3647, 2055, 2009, 1012, 1999, 2804, 1012, 2065, 3008, 2079, 2025, 5993, 2000, 2023, 2057, 2215, 2000, 2031, 2019, 23569, 2041, 2433, 2073, 2037, 2775, 1013, 2336, 2097, 2025, 2022, 4909, 2023, 2433, 1012, 102, 0, 0, 0, 0, ...]","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, ...]"
4,4d3d4471-2019-04-18T11:45:01Z-00000-000,"The resolution used by Pro *assumes* that Australia isn't already a 'significant' country - however, in actual reality, it is. Firstly we should clarify what significance means: 1.a the state or quality of being significant1.b of consequence or importance==================================To respond directly to Pros argument first:he/she asserts that Australia invented 'amazing things' like 'WiFI, Google Maps, Polymer bank notes, Ultrasound scanners, stainless steel braces and many more things'. Now, if these inventions did come from Australia, then, it can be considered a 'significant country' - as a country which is the home of some of the most universally-used inventions in the 21st century. It would seem that Pro was/is trying to argue that Australia simply deserves more recognition, in which case he/she should proposed that instead of stating that it should be more 'significant' - because the examples that Pro themselves has listed, fully go against that. Instead of affirming the resolution, Pro has negated it. After all, insignificant countries do not invent things such as WiFi or Google Maps. One invention listed by Pro that I will take issue with though is the ultrasound, as this was not invented in Australia. Its first use is thought to have been in Austria, which is a country in Europe. Its technology developed from there. 'The use of ultrasound in medicine began during and shortly after the 2nd World War in various centres around the world. The work of Dr. Karl Theodore Dussik in Austria in 1942 on transmission ultrasound investigation of the brain provides the first published work on medical ultrasonics.' 'Although other workers in the USA, Japan and Europe have also been cited as pioneers, the work of Professor Ian Donald and his colleagues in Glasgow, in the mid 1950s, did much to facilitate the development of practical technology and applications.'[1.] https://www.bmus.org...","[101, 1996, 5813, 2109, 2011, 4013, 1008, 15980, 1008, 2008, 2660, 3475, 1005, 1056, 2525, 1037, 1005, 3278, 1005, 2406, 1011, 2174, 1010, 1999, 5025, 4507, 1010, 2009, 2003, 1012, 15847, 2057, 2323, 25037, 2054, 7784, 2965, 1024, 1015, 1012, 1037, 1996, 2110, 2030, 3737, 1997, 2108, 3278, 2487, 1012, 1038, 1997, 9509, 2030, 5197, 1027, 1027, 1027, 1027, 1027, 1027, 1027, 1027, 1027, 1027, 1027, 1027, 1027, 1027, 1027, 1027, 1027, 1027, 1027, 1027, 1027, 1027, 1027, 1027, 1027, 1027, 1027, 1027, 1027, 1027, 1027, 1027, 1027, 1027, 2000, 6869, 3495, 2000, 4013, 2015, 6685, 2034, 1024, 2002, 1013, ...]","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...]"


In [403]:
testitest = parallel_encode(args[:2000], encode_dataframe, 4)

HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))




In [30]:
tokenizer.tokenize('reiterate')

['rei', '##tera', '##te']

In [18]:
%%time
args.loc[:, 'chunks'] = args.apply(lambda x: [x.text[i:i+510] for i in range(0, len(x.text), 510)] , axis=1)
args.loc[:,'chunks'].apply(pd.Series) \
    .merge(args, right_index=True, left_index=True) \
    .drop(['chunks'], axis=1) \
    .melt(id_vars=['text', 'id'], value_name="chunks") \
    .drop('variable', axis=1) \
    .dropna()

CPU times: user 3min 1s, sys: 4.07 s, total: 3min 5s
Wall time: 3min 5s


Unnamed: 0,text,id,chunks
0,My opponent forfeited every round. None of my ...,c67482ba-2019-04-18T13:32:05Z-00000-000,My opponent forfeited every round. None of my ...
1,How do you propose the school will fund your p...,c67482ba-2019-04-18T13:32:05Z-00001-000,How do you propose the school will fund your p...
2,Schools have no compelling interest in providi...,c67482ba-2019-04-18T13:32:05Z-00002-000,Schools have no compelling interest in providi...
3,As a senior at my school. My group and I are f...,c67482ba-2019-04-18T13:32:05Z-00003-000,As a senior at my school. My group and I are f...
4,The resolution used by Pro *assumes* that Aust...,4d3d4471-2019-04-18T11:45:01Z-00000-000,The resolution used by Pro *assumes* that Aust...
...,...,...,...
78972948,Border Terror DA Links Border Border surveilla...,e98fe508-2019-04-18T14:13:32Z-00005-000,"hat""s the No. 1 reason people come to this cou..."
79360640,Border Terror DA Links Border Border surveilla...,e98fe508-2019-04-18T14:13:32Z-00005-000,"rs help spur economic growth. It""s a position ..."
79748332,Border Terror DA Links Border Border surveilla...,e98fe508-2019-04-18T14:13:32Z-00005-000,o connect with the middle class and show that ...
80136024,Border Terror DA Links Border Border surveilla...,e98fe508-2019-04-18T14:13:32Z-00005-000,"the needs of the agricultural industry, would ..."


In [8]:
c = dataset['context.sourceTitle'].str.split(":").apply(lambda x: x[0])

In [9]:
d = dataset['context.sourceTitle'].str.split("-").apply(lambda x: x[-1])

In [16]:
d.value_counts()[:50]

 Debatepedia                                                                                      18760
 DebateWise                                                                                       14353
Online Debate: Abortion | Debate.org                                                                520
Debate Topic: Abortion | Debate.org                                                                 495
Debate: Abortion | Debate.org                                                                       489
Debate Issue: Abortion | Debate.org                                                                 469
Debate Argument: Abortion | Debate.org                                                              433
International Debate Education Association (IDEA)                                                   274
 Debatepedia, Debate on Universal Health Care                                                       269
 Debatepedia, Debate on Capital Punishment                      

In [18]:
c.value_counts()[:5].sum()

359817

In [6]:
def document_generator_from_dataframe(df, index, fields_to_index):
    for _, row in df.iterrows():
        row_as_dict = row.replace('', 'empty').to_dict()
        yield {
            "_index": index,
            "_id": row['id'],
            "_source": {k: row_as_dict[k] for k in fields_to_index}
        }

In [29]:
es = Elasticsearch(http_compress=True, maxsize=500) 
gen = document_generator_from_dataframe(dataset, "arg_index", ['text', 'stance', 'context.sourceId', 'conclusion'])
body = {
    'settings' : {
        'similarity' : {
            'my_dirichlet' : {'type': 'LMDirichlet' }
        }
    },
    
    'mappings': {
        
        'properties' : {
            
            'text':             {'type': 'text', 'similarity': 'my_dirichlet'},
            'stance':           {'type': 'keyword'},
            'context.sourceId': {'type': 'keyword'},
            'conclusion':       {'type': 'text'}
        }
    }
}

es.indices.create(index='arg_index', body=body)

{'acknowledged': True, 'shards_acknowledged': True, 'index': 'arg_index'}

In [30]:
%%time
helpers.bulk(es, gen)

CPU times: user 2min 48s, sys: 210 ms, total: 2min 48s
Wall time: 4min 9s


(387692, [])

In [33]:
es.indices.get_mapping(index="arg_index")

{'arg_index': {'mappings': {'properties': {'conclusion': {'type': 'text'},
    'context': {'properties': {'sourceId': {'type': 'keyword'}}},
    'stance': {'type': 'keyword'},
    'text': {'type': 'text', 'similarity': 'my_dirichlet'}}}}}

In [53]:
s = Search(using=es, index="arg_index").query("match", text="donald trump good president")
response = s.execute()

In [63]:
for hit in response:
    print(hit.conclusion)

Donald Trump Will Be a Good President
Donald Trump More Like Donald Dump !
Hilary Clinton vs Donald Trump
Trump is a better candidate than Clinton for President
Trump is a better candidate than Clinton for President
Donald Trump will most likely run for president in 2020
President Trump is a good President.
Donald Trump should resign from the presidential race
The United States Federal Government ought to begin the process to impeach Donald Trump
Donald Trump shouldn't be president.


## Turn args dataframe into input tokens and masks

In [36]:
teststring = 'Do you think gaming is harmful to our youth'

In [37]:
tokenizer.encode_plus(teststring)

{'input_ids': [101,
  2079,
  2017,
  2228,
  10355,
  2003,
  17631,
  2000,
  2256,
  3360,
  102],
 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}

In [38]:
dataset[dataset['id']=='fee6cd21-2019-04-18T16:05:41Z-00001-000']

Unnamed: 0,text,stance,id,conclusion,context.sourceId,context.previousArgumentInSourceId,context.acquisitionTime,context.discussionTitle,context.sourceTitle,context.sourceUrl,context.nextArgumentInSourceId
334820,,PRO,fee6cd21-2019-04-18T16:05:41Z-00001-000,Things Are Only Offensive If You MAKE Them Off...,fee6cd21-2019-04-18T16:05:41Z,fee6cd21-2019-04-18T16:05:41Z-00000-000,2019-04-18T16:05:41Z,Things Are Only Offensive If You MAKE Them Off...,Debate Issue: Things Are Only Offensive If You...,https://www.debate.org/debates/Things-Are-Only...,fee6cd21-2019-04-18T16:05:41Z-00002-000


In [67]:
args.to_pickle('./albert.pkl')

In [19]:
%%time
args = pd.read_pickle('./albert.pkl')
i = np.stack(args['input_ids'])
m = np.stack(args['attention_mask'])

CPU times: user 9.13 s, sys: 2.35 s, total: 11.5 s
Wall time: 11.5 s


In [68]:
i = np.stack(args['input_ids'])
m = np.stack(args['attention_mask'])

In [12]:
test = args[10000:20000].copy()

In [43]:
%%time
dataset = tf.data.Dataset.zip((tf.data.Dataset.from_tensor_slices(i), tf.data.Dataset.from_tensor_slices(m)))
dataset = dataset.batch(10)

CPU times: user 40.1 ms, sys: 28 ms, total: 68.1 ms
Wall time: 66.3 ms


In [44]:
outputs = []
for inputs, masks in tqdm.notebook.tqdm_notebook(dataset, total=1000):
    outputs.append(model(inputs, attention_mask=masks))

HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




In [143]:
arg_representations = np.vstack([thing[0][:,0,:].cpu().numpy() for thing in outputs])
arg_representations.shape

(6601, 768)

In [44]:
gpu = torch.device('cuda')

In [161]:
from time import sleep
import gc

In [None]:
outputs = []
for chunk_i, chunk_m in tqdmnb(zip(np.array_split(i, 156), np.array_split(m, 156))):
    for batch_i, batch_m in zip(np.array_split(chunk_i, 2400), np.array_split(chunk_m, 2400)):
        tensor_i = torch.tensor(batch_i).to(gpu)
        tensor_m = torch.tensor(batch_m).to(gpu)
        print(batch

HBox(children=(FloatProgress(value=0.0, max=78.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1240.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1240.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1240.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1240.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1240.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1240.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1240.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1240.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1240.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1240.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1240.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1240.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1240.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1240.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1240.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1240.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1240.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1240.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1240.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1240.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1240.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1240.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1240.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1240.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1240.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1240.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1240.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1240.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1240.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1240.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1240.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1240.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1240.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1240.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1240.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1240.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1240.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1240.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1240.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1240.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1240.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1240.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1240.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1240.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1240.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1240.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1240.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1240.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1240.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1240.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1240.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1240.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1240.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1240.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1240.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1240.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1240.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1240.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1240.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1240.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1240.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1240.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1240.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1240.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1240.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1240.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1240.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1240.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1240.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1240.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1240.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1240.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1240.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1240.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1240.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1240.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1240.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1240.0), HTML(value='')))





In [184]:
arg_representations = np.vstack([thing for thing in outputs])
arg_representations.shape

(387692, 768)

In [187]:
%%time
np.save('arg_representations', arg_representations)

CPU times: user 6.76 ms, sys: 739 ms, total: 746 ms
Wall time: 861 ms


In [39]:
%%time
arg_representations = np.load('arg_representations_bert-large-uncased.npy')

CPU times: user 0 ns, sys: 439 ms, total: 439 ms
Wall time: 437 ms


In [40]:
arg_representations

array([[ 0.22579737, -0.07349378, -1.0474814 , ..., -0.9103607 ,
        -0.20819543, -0.17666322],
       [-0.1498387 , -0.4468132 , -0.99795085, ..., -1.2222996 ,
        -0.06732306,  0.26513594],
       [ 0.47659963, -1.0098528 , -0.6905463 , ..., -0.86388797,
        -0.432725  ,  0.15581201],
       ...,
       [-0.08755945, -0.49159864, -0.62298226, ..., -0.6357609 ,
        -0.41940513,  0.25646588],
       [ 0.0292035 , -0.8047588 , -0.3333864 , ..., -0.4762738 ,
        -0.7674418 ,  0.17478399],
       [-0.35893157, -0.23544711, -1.0336562 , ..., -0.09143204,
        -0.5442716 ,  0.21634527]], dtype=float32)

In [9]:
arg_representations.shape

(387692, 768)

In [22]:
d = arg_representations.shape[1]
xb = arg_representations
m=16
n_bits=8
ids = args.id.values

In [23]:
pq = faiss.IndexPQ(d, m, n_bits)

In [29]:
pq.train(xb)

In [31]:
pq.add(xb)

In [34]:
pq.is_trained

True

In [27]:
type(ids)

numpy.ndarray

In [9]:
index_flat = faiss.IndexFlatL2(d)

In [10]:
gpu_index_flat = faiss.index_cpu_to_gpu(res, 0, index_flat)

In [11]:
gpu_index_flat.add(xb)         # add vectors to the index
print(gpu_index_flat.ntotal)

100000


In [101]:
%%time
query = "donald trump is an idiot"

tokenized_query = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(query))
encoded_query = tokenizer.prepare_for_model(tokenized_query, max_length=512, add_special_tokens=True, pad_to_max_length=True)

inp, mask = encoded_query['input_ids'], encoded_query['attention_mask']
inp, mask = torch.tensor(inp).unsqueeze(0), torch.tensor(mask).unsqueeze(0)

with torch.no_grad():
    query_albert = model(inp, attention_mask=mask)[0][:,0,:].numpy()

D, I = pq.search(query_albert, 100)

indices = list(I.squeeze())
args[args.index.isin(indices)]

CPU times: user 3.77 s, sys: 91.1 ms, total: 3.86 s
Wall time: 292 ms


Unnamed: 0,text,id
14151,Happiness is a state of mind,471dd6da-2019-04-18T16:38:43Z-00000-000
21828,testing testing,4e982cac-2019-04-18T19:07:20Z-00003-000
27100,agnostics dont believe in judaism,e71fe4c6-2019-04-18T12:56:35Z-00003-000
31896,Rascist jokes are not the best type of jokes,e6c3351d-2019-04-18T16:47:28Z-00003-000
41309,Why television isn't a bad influence on students,dbfbac49-2019-04-18T17:38:23Z-00003-000
...,...,...
380685,Bible supports incest and embraces its practice,e4848164-2019-04-17T11:47:33Z-00008-000
381674,General statements in favor of saying merry Ch...,7e44569e-2019-04-17T11:47:28Z-00010-000
383605,"Oil sands can't compete w/ cleaner, abundant, ...",4d2e82ff-2019-04-17T11:47:25Z-00027-000
386198,Alternatives to body scanners don't sacrifice ...,91a1b22c-2019-04-17T11:47:28Z-00005-000


In [83]:
np.ceil(0.01)

1.0

In [59]:
%%time


CPU times: user 3.01 s, sys: 254 ms, total: 3.27 s
Wall time: 238 ms


Unnamed: 0,text,id
1783,I acceptUsing Superman Prime,5010a680-2019-04-18T16:07:20Z-00005-000
2053,Meow,fdb8a06-2019-04-18T13:23:28Z-00004-000
2817,my grandson's name is jay,6c49a463-2019-04-18T11:06:28Z-00002-000
4049,i hate baseball,80eddc97-2019-04-18T17:28:50Z-00005-000
8550,suck on that,3f45dc38-2019-04-18T14:13:23Z-00003-000
...,...,...
347572,Focused leadership,a329a7d5-2019-04-15T20:24:40Z-00002-000
354778,whats a vegan,76de8010-2019-04-19T12:45:43Z-00000-000
357438,A more interesting story,6d9e359f-2019-04-19T12:46:09Z-00006-000
362782,Don't be lazy,15aaf549-2019-04-19T12:45:44Z-00005-000


In [94]:
I

array([[106301, 248985, 270505,  90806, 291383, 357438, 151683, 185774,
        260158,  38933, 275982, 159128, 312865, 308493, 269036,  73621,
          1783, 185812,  48774, 333649,   2053, 304266, 151684,  59045,
        177953, 347572,   8668,   4049, 125624,  12039, 287536, 364951,
        225822,  52207,  11852, 220305, 150934,  44187,  55367, 187220,
        362782, 218918, 110544, 215830, 148715, 174082, 235235, 268444,
        155960, 296578, 220692, 204920,  74964,  12805,  30757, 333116,
         90038,  43297, 206372, 116109,  86306, 319908, 170605,  95085,
        310911,  50713, 263836,  99215, 195672, 287491, 304133, 155247,
        245281, 141766,  69033, 169931, 100872, 258272,  41220, 211380,
          8550,  70952,  29060, 248430, 314853,  49053, 134757, 237550,
        354778, 200986,  38982,  95528, 180625, 273265,   2817, 153844,
        133807, 321155, 291385,  44286]])

In [12]:
k = 4                          # we want to see 4 nearest neighbors
D, I = gpu_index_flat.search(xq, k)  # actual search

In [65]:
%%timeit
model_tf(test_i, attention_mask=test_m)

179 ms ± 7.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [61]:
tensor_test_i = tf.convert_to_tensor(test_i)
tensor_test_m = tf.convert_to_tensor(test_m)

In [62]:
%%timeit
model_tf(tensor_test_i, attention_mask=tensor_test_m)

100 ms ± 2.64 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
