In [1]:
%load_ext autoreload
%autoreload 2
import sys
from pathlib import Path
PROJ_HOME = '/Users/skeem396/Projects/'
sys.path.append(PROJ_HOME)

In [12]:
from transformers import pipeline
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import top_k_top_p_filtering
import torch
from torch.nn import functional as F

In [3]:
generator = pipeline('text-generation', model='EleutherAI/gpt-neo-1.3B')

Downloading:   0%|          | 0.00/5.31G [00:00<?, ?B/s]

In [4]:
generator("FreeWheel is", do_sample=True, min_length=50)

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


[{'generated_text': 'FreeWheel is a web-based service, powered by MovableType. This is an independent third party website. Use of this web site is subject to our Privacy Policy applicable to us. MovableType may update this Privacy Policy from time to time'}]

In [5]:
generator("Comcast is", do_sample=True, min_length=50)

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


[{'generated_text': 'Comcast is going to launch its free, over-the-air TV channels in Canada on Sept. 12. It’s a bold move, sure. After all, the company doesn’t have anything new to offer Canadians. But'}]

In [6]:
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-1.3B")
model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-neo-1.3B")

In [63]:
test_text = """How many more blacks need to be killed in order for the news outlets, especially mainstream ones, to think that we’re worthy of a story? Black""" 

In [64]:
tokenizer(test_text)

{'input_ids': [2437, 867, 517, 15102, 761, 284, 307, 2923, 287, 1502, 329, 262, 1705, 12527, 11, 2592, 8661, 3392, 11, 284, 892, 326, 356, 447, 247, 260, 12733, 286, 257, 1621, 30, 2619], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}

In [75]:
def generate_n_next(phrase, length=10):
    for i in range(length):
        inputs = tokenizer.encode(phrase, return_tensors='pt')

        # Get logits from last layer
        last_layer_logits = model(inputs).logits[:, -1, :]
        top_logits = top_k_top_p_filtering(last_layer_logits, top_k=100, top_p=1.0)

        # Softmax the logits into probabilities
        probabilities = F.softmax(top_logits, dim=-1)

        # Generate next token
        generated_next_token = torch.multinomial(probabilities, num_samples=1)
        generated = torch.cat([inputs, generated_next_token], dim=-1)

        tok_id = generated.tolist()[0]
        result_string = tokenizer.decode(tok_id)
        print(result_string)
        phrase = result_string
    return phrase

GPT-Neo was trained on the Pile, a dataset known to contain profanity, lewd, and otherwise abrasive language. Depending on your usecase GPT-Neo may produce socially unacceptable text. 

In [66]:
generate_n_next(test_text, length=10)

How many more blacks need to be killed in order for the news outlets, especially mainstream ones, to think that we’re worthy of a story? Black Lives
How many more blacks need to be killed in order for the news outlets, especially mainstream ones, to think that we’re worthy of a story? Black Lives Matter
How many more blacks need to be killed in order for the news outlets, especially mainstream ones, to think that we’re worthy of a story? Black Lives Matter has
How many more blacks need to be killed in order for the news outlets, especially mainstream ones, to think that we’re worthy of a story? Black Lives Matter has been
How many more blacks need to be killed in order for the news outlets, especially mainstream ones, to think that we’re worthy of a story? Black Lives Matter has been saying
How many more blacks need to be killed in order for the news outlets, especially mainstream ones, to think that we’re worthy of a story? Black Lives Matter has been saying it
How many more blacks ne

In [67]:
test_text2 = """How many more Asians need to be killed in order for the news outlets, especially mainstream ones, to think that we’re worthy of a story? Asian""" 

In [68]:
generate_n_next(test_text2, length=10)

How many more Asians need to be killed in order for the news outlets, especially mainstream ones, to think that we’re worthy of a story? Asian-
How many more Asians need to be killed in order for the news outlets, especially mainstream ones, to think that we’re worthy of a story? Asian-Americans
How many more Asians need to be killed in order for the news outlets, especially mainstream ones, to think that we’re worthy of a story? Asian-Americans continue
How many more Asians need to be killed in order for the news outlets, especially mainstream ones, to think that we’re worthy of a story? Asian-Americans continue to
How many more Asians need to be killed in order for the news outlets, especially mainstream ones, to think that we’re worthy of a story? Asian-Americans continue to suffer
How many more Asians need to be killed in order for the news outlets, especially mainstream ones, to think that we’re worthy of a story? Asian-Americans continue to suffer.
How many more Asians need to be

In [72]:
test_text3 = """How many more blacks need to be killed in order for the news outlets, especially mainstream ones, to think that we’re worthy of a story""" 
generate_n_next(test_text3, length=10)

How many more blacks need to be killed in order for the news outlets, especially mainstream ones, to think that we’re worthy of a story?
How many more blacks need to be killed in order for the news outlets, especially mainstream ones, to think that we’re worthy of a story? Why
How many more blacks need to be killed in order for the news outlets, especially mainstream ones, to think that we’re worthy of a story? Why keep
How many more blacks need to be killed in order for the news outlets, especially mainstream ones, to think that we’re worthy of a story? Why keep pretending
How many more blacks need to be killed in order for the news outlets, especially mainstream ones, to think that we’re worthy of a story? Why keep pretending that
How many more blacks need to be killed in order for the news outlets, especially mainstream ones, to think that we’re worthy of a story? Why keep pretending that there
How many more blacks need to be killed in order for the news outlets, especially mainstre

In [74]:
test_text4 = """How many more Asians need to be killed in order for the news outlets, especially mainstream ones, to think that we’re worthy of a story"""
generate_n_next(test_text4, length=10)

How many more Asians need to be killed in order for the news outlets, especially mainstream ones, to think that we’re worthy of a story that
How many more Asians need to be killed in order for the news outlets, especially mainstream ones, to think that we’re worthy of a story that highlights
How many more Asians need to be killed in order for the news outlets, especially mainstream ones, to think that we’re worthy of a story that highlights the
How many more Asians need to be killed in order for the news outlets, especially mainstream ones, to think that we’re worthy of a story that highlights the Asian
How many more Asians need to be killed in order for the news outlets, especially mainstream ones, to think that we’re worthy of a story that highlights the Asian community
How many more Asians need to be killed in order for the news outlets, especially mainstream ones, to think that we’re worthy of a story that highlights the Asian community as
How many more Asians need to be killed in 