In [1]:
# `display`/`HTML` are imported from the public `IPython.display` module;
# importing `display` from `IPython.core.display` is deprecated.
from IPython.display import display, HTML

# Inject CSS so the notebook container uses 95% of the browser width.
display(HTML("<style>.container { width:95% !important; }</style>"))

In [2]:
from itertools import chain
from pplm_utils import *
from transformers import AutoConfig, AutoModelWithLMHead, AutoTokenizer, BertTokenizer, BertForMaskedLM

In [3]:
# Debate question; later cells prepend/append text to it to build generation prompts.
query = "Should birth control pills be available over the counter?"

# Using PPLM

In [4]:
# Load GPT-2 with hidden states exposed, then freeze it for inference-only use.
LM_MODEL_TO_USE = "gpt2"

tokenizer = AutoTokenizer.from_pretrained(LM_MODEL_TO_USE)

config = AutoConfig.from_pretrained(LM_MODEL_TO_USE)
config.output_hidden_states = True  # make the model return its hidden states

# .eval() returns the module itself, so it can be chained onto the load.
lm = AutoModelWithLMHead.from_pretrained(LM_MODEL_TO_USE, config=config).eval()

# No weight of the language model should receive gradients.
for weight in lm.parameters():
    weight.requires_grad_(False)

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=224.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1042301.0, style=ProgressStyle(descript…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=456318.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=548118077.0, style=ProgressStyle(descri…




In [5]:
# Conditioning text = beginning-of-sequence token followed by the query, as token ids.
bos_prefixed_query = tokenizer.bos_token + query
tokenized_cond_text = tokenizer.encode(bos_prefixed_query)

In [19]:
%%time
unpert_gen_tok_text, pert_gen_tok_texts, grad_norms, loss_per_iter = full_text_generation(model=lm, tokenizer=tokenizer, context=tokenized_cond_text, bag_of_words='arg_bow', length=100, stepsize=0.04, temperature=1.5, 
                                                                                          top_k=10, num_iterations=4, grad_length=10000, horizon_length=1, gm_scale=0.85, kl_scale=0.03, repetition_penalty=1.5, gamma=1.8,
                                                                                          no_cuda=True, device="cpu")

CPU times: user 20min 50s, sys: 15.6 s, total: 21min 6s
Wall time: 1min 49s


In [20]:
# Skip the prompt tokens so only the continuation in `unpert_gen_tok_text` is printed.
prompt_len = len(tokenized_cond_text)
print(tokenizer.decode(unpert_gen_tok_text[0][prompt_len:]))


The answer is no. The FDA has not yet determined whether or how many women will need to take them, but it's expected to begin taking more in 2015 and 2016 than before that date (the agency says there are about 1 million people who have taken a pill since 2000). And while some doctors say they're concerned with unintended pregnancies because of their side effects — such as nausea after using an oral contraceptive for years without any problems—there aren't enough studies on this topic at present; most


In [21]:
# Skip the prompt tokens so only the continuation of the first entry in `pert_gen_tok_texts` is printed.
prompt_len = len(tokenized_cond_text)
print(tokenizer.decode(pert_gen_tok_texts[0][0][prompt_len:]))


The American Academy of Pediatrics (AAP) has issued a statement on this issue. The AAP's position is that "the use and availability" or lack thereof of contraceptives should not constitute an emergency contraceptive, but rather as part to prevent pregnancy." This means it would have been better for women who were pregnant if they could get their pill without having unprotected sex with someone else in order make sure there was no risk associated from using them at all:

  http://www-aapnewsletter


# Using Transformer decoders with a Language Model head

In [4]:
# Reload GPT-2 with its default configuration (no custom config is passed here),
# rebinding both `lm` and `tokenizer`.
LM_MODEL_TO_USE = "gpt2"
lm = AutoModelWithLMHead.from_pretrained(LM_MODEL_TO_USE)
tokenizer = AutoTokenizer.from_pretrained(LM_MODEL_TO_USE)

In [8]:
num_return_sequences = 5

# Conversational lead-in placed in front of the debate question.
PROMPT_PREFIX = "I am trying to think of good arguments. Can you help me? What do you think?"

# Only prepend the lead-in once: an unconditional `query = PROMPT_PREFIX + query`
# stacks another copy of the prefix every time this cell is re-run.
if not query.startswith(PROMPT_PREFIX):
    query = PROMPT_PREFIX + query

# NOTE(review): there is no space between the lead-in and the question, nor
# between the question and "The answer is ..."; left unchanged so the prompts
# stay identical — confirm whether that is intended.
input_context_pro = [
    "-"+query+"\n- Yes because",
    query+"The answer is yes."
]
input_context_con = [
    "-"+query+"\n- No because",
    query+"The answer is no."
]


In [6]:
%%time

hallucinated_greedy = []

for j, input_context in enumerate(chain(*[input_context_pro, input_context_con])):
    input_ids = torch.tensor(tokenizer.encode(input_context)).unsqueeze(0)  
    L = len(input_ids[0])
    outputs = lm.generate(max_length=100, input_ids=input_ids, do_sample=False, num_beams=10, top_k=100 , top_p=0.4, num_return_sequences=1, temperature=1.6, repetition_penalty=20)
    for i in range(1):
        print('')
        print(f'Greedilty hallucinated for query {j}: {tokenizer.decode(outputs[i][L:], skip_special_tokens=True)}')
        print('')
        hallucinated_greedy.append(tokenizer.decode(outputs[i][L:], skip_special_tokens=True))


Hallucinated:  they are safe and effective. However, there is no scientific evidence to support their use as an alternative method of contraception for women who don't want or need them (e:g., those with premenstrual syndrome). There has been some debate about whether such a pill would have any effect on pregnancy outcomes in this population; however research suggests that it may reduce risk factors associated at least partially by inhibiting ovulation during


Hallucinated:  According to a new study published in The American Journal of Obstetrics and Gynecology, more than one-third (35%) women who have had an abortion are still using them at some point during pregnancy because they don't want to miss out on their chance for health insurance coverage if it's not covered by Medicaid or other government programs such as Social Security Disability Insurance."If you're pregnant with your first child," says


Hallucinated:  there is no evidence that they are safe or effective in preventing

In [11]:
# Number of continuations to request per prompt in the sampling run.
num_return_sequences = 5

# Each stance gets two prompt shapes: a dash-prefixed dialogue turn and a
# plain statement appended directly to the query.
input_context_pro = [
    f"-{query}\n- Yes because",
    f"{query}The answer is yes.",
]
input_context_con = [
    f"-{query}\n- No because",
    f"{query}The answer is no.",
]

In [13]:
%%time

hallucinated_sampling = []

for j, input_context in enumerate(chain(*[input_context_pro, input_context_con])):
    input_ids = torch.tensor(tokenizer.encode(input_context)).unsqueeze(0)  # encode input context
    L = len(input_ids[0])
    outputs = lm.generate(max_length=65, input_ids=input_ids, do_sample=True, num_beams=10, top_k=100 , top_p=0.4, num_return_sequences=num_return_sequences, temperature=1.6, repetition_penalty=20)
    for i in range(num_return_sequences): 
        print(" ")
        print(f'Hallucinated {i+1} for query {j+1}: {tokenizer.decode(outputs[i][L:], skip_special_tokens=True)}')
        print(" ")
        hallucinated_sampling.append(tokenizer.decode(outputs[i][L:], skip_special_tokens=True))

 
Hallucinated 1 for query 1:  they're an option, I have been in a relationship for years and still feel that way about them but this doesn't mean anything without some kind (but
 
 
Hallucinated 2 for query 1:  they are very easy, if not impossible... You will have plenty at home and in your office so we don't need any problems with women going out into
 
 
Hallucinated 3 for query 1:  it can cause infertility, and I've had some doctors say no but they'd prefer for a small dose on its own anyway so that we could keep an
 
 
Hallucinated 4 for query 1:  they are so popular and I want my children with no side effects at all! They work great in pregnancy, it helps your baby's brain cells become stronger
 
 
Hallucinated 5 for query 1:  they work so well for my breasts that I just can't go anywhere without one on and off every day - this is why people have tried them (they
 
 
Hallucinated 1 for query 2:  The only way I can get a contraceptive in that amount and not have one thrown at my 

In [132]:
# Ad-hoc sampling check. In the visible cells `input_ids` is only assigned inside
# the generation loops, so this uses whatever the most recent loop left behind.
# NOTE(review): no max_length is passed and the saved output appears to end with
# the prompt tokens — confirm the default length limit is not shorter than the prompt.
lm.generate(input_ids=input_ids, top_k=100, do_sample=True)

tensor([[   12,    40,   716,  2111,   284,   892,   286,   922,  7159,    13,
          1680,   345,  1037,   502,    30,  1867,   466,   345,   892,    30,
            40,   716,  2111,   284,   892,   286,   922,  7159,    13,  1680,
           345,  1037,   502,    30,  1867,   466,   345,   892,    30,    40,
           716,  2111,   284,   892,   286,   922,  7159,    13,  1680,   345,
          1037,   502,    30,  1680,   345,   892,   286,   922,  7159,    30,
         10358,  4082,  1630, 19521,   307,  1695,   625,   262,  3753,    30,
           198,    12,  3363,   780]])

# Using Transformer encoders with a Masked Language Model head

Build a list of stop words by taking the union of the stop-word lists shipped with scikit-learn, spaCy, and NLTK.

In [38]:
from string import punctuation

import nltk
from ipywidgets import Output
from nltk.corpus import stopwords
from sklearn.feature_extraction.text import ENGLISH_STOP_WORDS as SKLEARN_STOPWORDS
from spacy.lang.en.stop_words import STOP_WORDS as SPACY_STOPWORDS

# Run the NLTK download inside an Output widget context; the widget is not
# displayed in this cell, which keeps the download log out of the cell output.
out = Output()
with out:
    nltk.download('stopwords')

# Normalise both collections to plain sets before combining them.
SKLEARN_STOPWORDS = set(SKLEARN_STOPWORDS)
NLTK_STOPWORDS = set(stopwords.words('english'))

# Union of the three libraries' English stop-word lists.
STOP_WORDS = SKLEARN_STOPWORDS.union(SPACY_STOPWORDS, NLTK_STOPWORDS)

In [39]:
# Load BERT-large (uncased) with its masked-LM head.
# Note: this rebinds `tokenizer`, which held the GPT-2 tokenizer in the cells above.
BERT_CHECKPOINT = 'bert-large-uncased'
model = BertForMaskedLM.from_pretrained(BERT_CHECKPOINT)
tokenizer = BertTokenizer.from_pretrained(BERT_CHECKPOINT)

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=231508.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=362.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1344997306.0, style=ProgressStyle(descr…




In [75]:
# Debate question, wrapped in a short request for help.
query = "Should churches be taxed?"
query = f"I can't think of any arguments, can you help me? {query}"

# Pro and con prompts are two-line dialogues: "-<query>" followed by a "-<reply>"
# line whose [MASK] slots the masked language model fills in.
dialogue_head = f"-{query}\n-"

input_context_pro = [
    dialogue_head + reply
    for reply in (
        "Yes, because of [MASK] and the benefits of [MASK] [MASK].",
        "Absolutely, I think [MASK] is good!.",
        "Yes, [MASK] is associated with [MASK] during [MASK].",
    )
]

input_context_con = [
    dialogue_head + reply
    for reply in (
        "No, because of [MASK] and the risk of [MASK] [MASK].",
        "Absolutely not, I think [MASK] is bad!.",
        "No, [MASK] is associated with [MASK] during [MASK].",
    )
]

# Stance-free prompts are appended directly to the query (no dialogue dashes).
input_context_neutral = [
    f"{query} What about [MASK] or [MASK]?",
    f"{query} Don't forget about [MASK]!",
]

In [76]:
%%time
hallucinations = []
for input_context in chain(*[input_context_pro, input_context_con, input_context_neutral]):
    inp_tens = torch.tensor(tokenizer.encode(tokenizer.convert_tokens_to_ids(tokenizer.tokenize(input_context)))).unsqueeze(0)
    mask_indices = np.nonzero(inp_tens.squeeze()==103).squeeze()
    preds = model(inp_tens)[0].squeeze()
    mask_indices = [mask_indices.tolist()] if type(mask_indices.tolist())!=list else mask_indices.tolist()
    top_words = []
    for i in mask_indices:
        top_words.append(torch.topk(preds[i], k=5))
    words = []
    for mask_topk in top_words:
        for token in mask_topk.indices.tolist():
            words.append(tokenizer.decode(token, clean_up_tokenization_spaces=True))
        #Interestingly, BERT was returning the ##carriage subword, obviously part of "miscarriage" in the "pregnancy" context. Further investigation needed to see how to return the full of word. Filtering out subwords for now.
        #Filter out subwords
        words = [word.replace(" ","") for word in words if not word.startswith("#")]
        words = [word for word in words if not word.endswith("#")]
        #Filter out punctuation
        words = [word for word in words if not word in punctuation]
    words = set(words)
    words = list(words.difference(STOP_WORDS))
    hallucinations.extend(words)
hallucinations = set(hallucinations)
hallucinations = list(hallucinations)

CPU times: user 35.6 s, sys: 890 ms, total: 36.5 s
Wall time: 2.75 s


In [74]:
# Display the current contents of `hallucinations`.
hallucinations

['religion',
 'poverty',
 'attendance',
 'public',
 'economics',
 'land',
 'money',
 'prayer',
 'death',
 'church',
 'consumption',
 'easter',
 'good',
 'free',
 'tax',
 'lent',
 'churches',
 'property',
 'costs',
 'mass',
 'education',
 'exposure',
 'financial',
 'christmas',
 'schools',
 'festivals',
 'work',
 'women',
 'taxes',
 'religious',
 'unemployment',
 'bankruptcy',
 'legal',
 'services']

In [77]:
# Display the current contents of `hallucinations`.
hallucinations

['religion',
 'poverty',
 'faith',
 'land',
 'money',
 'prayer',
 'death',
 'cost',
 'church',
 'conversion',
 'worship',
 'free',
 'tax',
 'lent',
 'legal',
 'collapse',
 'churches',
 'property',
 'costs',
 'mass',
 'wartime',
 'housing',
 'living',
 'education',
 'financial',
 'christmas',
 'schools',
 'tourism',
 'festivals',
 'taxes',
 'religious',
 'attendance',
 'corruption',
 'services']

In [96]:
# Switch the debate question; the cells below build their prompts from `query`.
query = "Should the Federal Minimum Wage Be Increased?"

In [99]:
# Single two-mask prompt; BERT's top-10 candidates per [MASK] are collected.
test = ["-"+query+"\n-Yes, it's not too late because of [MASK] [MASK]"]
hallucinations = []
for input_context in test:
    # encode() tokenizes the text and adds the special tokens in one call.
    inp_tens = torch.tensor(tokenizer.encode(input_context)).unsqueeze(0)
    # Ask the tokenizer for the [MASK] id instead of hard-coding 103;
    # flatten() yields a plain list even when the prompt has a single mask.
    mask_indices = (inp_tens.squeeze() == tokenizer.mask_token_id).nonzero().flatten().tolist()
    with torch.no_grad():  # inference only, no autograd graph needed
        preds = model(inp_tens)[0].squeeze()
    top_words = [torch.topk(preds[i], k=10) for i in mask_indices]
    words = [
        tokenizer.decode(token, clean_up_tokenization_spaces=True)
        for mask_topk in top_words
        for token in mask_topk.indices.tolist()
    ]
    # BERT sometimes proposes sub-word pieces (e.g. "##carriage", evidently part of
    # "miscarriage" in a pregnancy context). TODO: investigate how to recover the
    # full word; sub-words are filtered out for now. The filters run once, after
    # all candidates have been gathered.
    words = [word.replace(" ","") for word in words if not word.startswith("#")]
    words = [word for word in words if not word.endswith("#")]
    # Filter out punctuation (single characters plus the literal "...")
    words = [word for word in words if word not in [*punctuation, "..."]]
    hallucinations.extend(set(words).difference(STOP_WORDS))
hallucinations = list(set(hallucinations))
hallucinations

['inflation', 'labor', 'taxes', 'immigration', 'congress']

In [58]:
# Display the current contents of `hallucinations`.
# NOTE(review): the saved output carries execution count 58, lower than the cell
# above (99), so it likely comes from an earlier run — re-run to refresh.
hallucinations

['wars',
 'pollution',
 'death',
 'consumption',
 'climate',
 'risk',
 'water',
 'development',
 'warming',
 'growth',
 'summer',
 'energy',
 'emissions',
 'migration',
 'winter',
 'construction',
 'health',
 'deaths',
 'drought',
 'mortality',
 'pregnancy',
 'wartime']

In [189]:
# Single-mask probe: BERT's top-20 candidates for the one [MASK] slot.
input_context = "-"+query+"\n-But what about [MASK]?"
# encode() tokenizes the text and adds the special tokens in one call.
inp_tens = torch.tensor(tokenizer.encode(input_context)).unsqueeze(0)
# Ask the tokenizer for the [MASK] id instead of hard-coding 103;
# flatten() yields a plain list even when the prompt has a single mask.
mask_indices = (inp_tens.squeeze() == tokenizer.mask_token_id).nonzero().flatten().tolist()
with torch.no_grad():  # inference only, no autograd graph needed
    preds = model(inp_tens)[0].squeeze()
top_words = [torch.topk(preds[i], k=20) for i in mask_indices]
words = []
tokens = []  # raw candidate token ids, kept alongside the decoded words
for mask_topk in top_words:
    for token in mask_topk.indices.tolist():
        tokens.append(token)
        words.append(tokenizer.decode(token, clean_up_tokenization_spaces=True))
# Drop candidates starting with "#" (sub-word pieces); done once after collection.
words = [word.replace(" ","") for word in words if not word.startswith("#")]
words = set(words)
words = list(words.difference(STOP_WORDS))
words

['abortion',
 'drugs',
 'condoms',
 'alcohol',
 'pills',
 'babies',
 'children',
 'men',
 'women',
 'money',
 'kids',
 'sex',
 'insurance']

In [188]:
# Two-mask probe: BERT's top-10 candidates for each [MASK] slot.
input_context = "-"+query+"\n-Don't forget about [MASK] [MASK]."
# encode() tokenizes the text and adds the special tokens in one call.
inp_tens = torch.tensor(tokenizer.encode(input_context)).unsqueeze(0)
# Ask the tokenizer for the [MASK] id instead of hard-coding 103;
# flatten() yields a plain list even when the prompt has a single mask.
mask_indices = (inp_tens.squeeze() == tokenizer.mask_token_id).nonzero().flatten().tolist()
with torch.no_grad():  # inference only, no autograd graph needed
    preds = model(inp_tens)[0].squeeze()
top_words = [torch.topk(preds[i], k=10) for i in mask_indices]
words = []
tokens = []  # raw candidate token ids, kept alongside the decoded words
for mask_topk in top_words:
    for token in mask_topk.indices.tolist():
        tokens.append(token)
        words.append(tokenizer.decode(token, clean_up_tokenization_spaces=True))
# Drop candidates starting with "#" (sub-word pieces); done once after collection.
words = [word.replace(" ","") for word in words if not word.startswith("#")]
words = set(words)
words = list(words.difference(STOP_WORDS))
# Discard single-character leftovers.
words =[word for word in words if len(word)>1]
words

['doctor',
 'condoms',
 'pills',
 'prescription',
 'pill',
 'medication',
 'baby',
 'stuff',
 'condom',
 'sex',
 'medicine']

In [187]:
# Single-mask probe appended directly to the query: top-10 candidates for [MASK].
input_context = query+"I don't believe in [MASK]."
# encode() tokenizes the text and adds the special tokens in one call.
inp_tens = torch.tensor(tokenizer.encode(input_context)).unsqueeze(0)
# Ask the tokenizer for the [MASK] id instead of hard-coding 103;
# flatten() yields a plain list even when the prompt has a single mask.
mask_indices = (inp_tens.squeeze() == tokenizer.mask_token_id).nonzero().flatten().tolist()
with torch.no_grad():  # inference only, no autograd graph needed
    preds = model(inp_tens)[0].squeeze()
top_words = [torch.topk(preds[i], k=10) for i in mask_indices]
words = []
tokens = []  # raw candidate token ids, kept alongside the decoded words
for mask_topk in top_words:
    for token in mask_topk.indices.tolist():
        tokens.append(token)
        words.append(tokenizer.decode(token, clean_up_tokenization_spaces=True))
# Drop candidates starting with "#" (sub-word pieces); done once after collection.
words = [word.replace(" ","") for word in words if not word.startswith("#")]
words = set(words)
words = list(words.difference(STOP_WORDS))
# Discard single-character leftovers.
words =[word for word in words if len(word)>1]
words

['abortion', 'drugs', 'condoms', 'pills', 'babies', 'prostitution', 'sex']

In [None]:
# Completed from the half-typed `string.pu`; `string` was never imported in this
# notebook, so the module is imported here before showing the punctuation set.
import string
string.punctuation