## Experiments with GPT-Neo

In [1]:
#
# NOTE
# Uncomment the following code only if you are running this notebook on NLULab (Azure).
#

#import os

#cache_dir = '/home/jovyan/nlu2/.cache/transformers'
#os.environ['TRANSFORMERS_CACHE'] = cache_dir

In [1]:
import torch

if torch.cuda.is_available():
    device = torch.cuda.current_device()
    gpu_stat = torch.cuda.device_count(), device
else:
    device = -1
    gpu_stat = 'no-gpu'

gpu_stat

'no-gpu'

### Model Size Chart

|Size|Parameters|
|--- | ---|
|S|125M params|
|M|1.3B params|
|L|2.7B params|

In [14]:
model_size = "M"  # Can be S, M, L

In [15]:
from transformers import pipeline

In [16]:
%%time


def load_model(msize):
    model_name = 'EleutherAI/gpt-neo-125M'
    if msize == "M":
        model_name = 'EleutherAI/gpt-neo-1.3B'
    elif msize == "L":
        model_name = 'EleutherAI/gpt-neo-2.7B'
    return pipeline('text-generation', model=model_name, device=device)

print("Loading model (this might take a while)...")
gpt_neo = load_model(model_size)

Loading model (this might take a while)...



("Connection broken: ConnectionResetError(104, 'Connection reset by peer')", ConnectionResetError(104, 'Connection reset by peer'))


OSError: Can't load weights for 'EleutherAI/gpt-neo-1.3B'. Make sure that:

- 'EleutherAI/gpt-neo-1.3B' is a correct model identifier listed on 'https://huggingface.co/models'

- or 'EleutherAI/gpt-neo-1.3B' is the correct path to a directory containing a file named one of pytorch_model.bin, tf_model.h5, model.ckpt.
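
The download above was cut off by a connection reset, after which the pipeline could not load the weights. One way to make the cell more robust is to retry the load a few times with a pause in between. The following is a minimal sketch, not part of the original notebook: `retry` and its parameters (`attempts`, `delay`, `exceptions`) are hypothetical names, and it assumes a failed download surfaces as an `OSError`, as it does in the traceback above.

```python
import time


def retry(fn, attempts=3, delay=5.0, exceptions=(OSError,)):
    """Call fn() until it succeeds or `attempts` calls have failed."""
    for attempt in range(1, attempts + 1):
        try:
            return fn()
        except exceptions as err:
            if attempt == attempts:
                raise  # out of attempts: re-raise the last error
            print(f"Attempt {attempt} failed ({err}); retrying in {delay}s...")
            time.sleep(delay)
```

In this notebook it could wrap the load as `gpt_neo = retry(lambda: load_model(model_size))`.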



In [5]:
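def complete(prompt, temperature=0.9, do_sample=True, min_length=1, max_length=60, stop=("\n",), top_k=1):
    """Generate a completion for `prompt` and cut it at the first stop sequence."""
    prompt = prompt.rstrip()
    # max_length is counted in tokens and includes the prompt, so pad it with a
    # rough estimate of the prompt's token count (about 1.5 tokens per word).
    _max_length = int(len(prompt.split())*1.5 + max_length)
    # top_k is the sampling cutoff, not the number of returned sequences; the
    # pipeline returns one sequence unless num_return_sequences is passed.
    choices = gpt_neo(
        prompt,
        temperature=temperature,
        do_sample=do_sample,
        min_length=min_length,
        max_length=_max_length,
        top_k=top_k
    )
    result = list()
    for choice in choices:
        output = choice.get('generated_text', '')
        # The pipeline echoes the prompt; keep only the newly generated text.
        if output.startswith(prompt):
            output = output[len(prompt):]
        # Truncate at the earliest stop sequence.
        for s in stop:
            _idx = output.find(s)
            if _idx > -1:
                output = output[:_idx]
        result.append(output.strip())
    # Return a bare string when there is only one completion.
    if len(result) == 1:
        result = result[0]
    return result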
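
The stop-sequence handling inside `complete` can be checked without loading a model. The sketch below pulls that step into a standalone function; `truncate_at_stops` is a hypothetical name introduced here for illustration, and it mirrors the loop above rather than replacing it.

```python
def truncate_at_stops(text, stop=("\n",)):
    """Cut `text` at the earliest occurrence of any stop sequence."""
    for s in stop:
        idx = text.find(s)
        if idx > -1:
            text = text[:idx]
    return text.strip()


# A raw continuation that runs on into a new question is cut back to the answer.
answer = truncate_at_stops(" 1 Mbps\n\nQ: What is 5G?", stop=("\n", "Q:"))
```

Because each stop is applied to the already-shortened text, the result always ends at whichever stop sequence appears first, regardless of the order in which the stops are listed.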

#### Text Completion

In [10]:
%%time

prompt = "As a company Ericsson is the"
complete(prompt, max_length=60, top_k=5)

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


CPU times: user 5.21 s, sys: 58 ms, total: 5.27 s
Wall time: 2.64 s


'most popular player for the game, it was announced that the game would be released on June 20, 2015, in North America. The game is based on the classic Android platform, and the developer has announced that the game will have the following features:'

#### Question Answering (no context)

In [7]:
%%time

prompt = """
Q: What is upload data-rate in LTE?
A: 75 Mbps

Q: What is EPS?
A: Evolved Packet System

Q: What is peak download rate for 4G?
A:
"""

complete(prompt, stop=["\n", "Q:"])

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


CPU times: user 4.97 s, sys: 23.5 ms, total: 5 s
Wall time: 2.5 s


'1 Mbps'
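
The prompt above follows a few-shot pattern: a handful of solved `Q:`/`A:` pairs followed by an open question. As a rough sketch, the same string can be assembled from a list of pairs, which makes it easier to vary the examples. `build_prompt` and its parameters are hypothetical names, not part of the original notebook.

```python
def build_prompt(examples, question, q_tag="Q:", a_tag="A:"):
    """Assemble a few-shot prompt from (question, answer) pairs."""
    blocks = [f"{q_tag} {q}\n{a_tag} {a}" for q, a in examples]
    # Leave the final answer tag open so the model continues from it.
    blocks.append(f"{q_tag} {question}\n{a_tag}")
    return "\n" + "\n\n".join(blocks) + "\n"


examples = [
    ("What is upload data-rate in LTE?", "75 Mbps"),
    ("What is EPS?", "Evolved Packet System"),
]
prompt = build_prompt(examples, "What is peak download rate for 4G?")
```

Passing `a_tag="SQL:"` gives the layout used in the SQL generation section, up to trailing whitespace.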

#### SQL Generation

In [8]:
%%time

prompt = """
Q: List the names of employees at Ericsson?
SQL: SELECT name FROM employees WHERE company = 'Ericsson';

Q: Show the highest paid player at the NBA?
SQL: SELECT * FROM players WHERE league = 'NBA' ORDER BY salary DESC LIMIT 1;

Q: How many employees are there at GAIA?
SQL: 
"""

complete(prompt, stop=["\n", "SQL:"], temperature=0.6)

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


CPU times: user 6.24 s, sys: 62.3 ms, total: 6.31 s
Wall time: 3.16 s


"SELECT * FROM employees WHERE company = 'GAIA';"