# Install dependencies

In [None]:
# !pip install einops transformers
!pip install transformers

# Import dependencies

In [5]:
import os
import logging
import hashlib
import requests
from tqdm import tqdm

from transformers import GPT2Tokenizer
from transformers import GPTNeoForCausalLM

# Init/bootstrap

In [6]:
# logger
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
if (logger.hasHandlers()):
    logger.handlers.clear()

console = logging.StreamHandler()
logger.addHandler(console)

# urls
urls = {
    'config.json': {
        'url': 'https://zhisu-nlp.s3.us-west-2.amazonaws.com/gpt-j-hf/config.json',
        'sha1sum': 'a0af27bcff3c0fa17ec9718ffb6060b8db5e54e4'
    },
    'pytorch_model.bin': {
        'url': 'https://zhisu-nlp.s3.us-west-2.amazonaws.com/gpt-j-hf/pytorch_model.bin',
        'sha1sum': 'bab870fc9b82f0bfb3f6cbf4bd6bec3f3add05a6'
    }
}

# Utility functions

In [5]:
# download
def download(url, path=None, overwrite=False, sha1_hash=None):
    """Download files from a given URL.
    """
    if path is None:
        fname = os.path.join(url.split('/')[-2],url.split('/')[-1])
    else:
        path = os.path.expanduser(path)
        if os.path.isdir(path):
            fname = os.path.join(path, url.split('/')[-2], url.split('/')[-1])
        else:
            fname = path

    if os.path.exists(fname) and sha1_hash:
        logger.info('File {} exist, checking content hash...'.format(fname))
        file_check = check_sha1(fname, sha1_hash)
        if file_check:
            logger.info('File {} checking pass'.format(fname))
        else:
            raise KeyError('File {} is downloaded but the content hash does not match. ' \
                                'Please retry.'.format(fname))

    elif overwrite or not os.path.exists(fname) :
        if overwrite:
            logger.info('File {} exist, overwriting...'.format(fname))
        download_ops(url,fname)
        if sha1_hash:
            logger.info('File {} downloaded, checking content hash...'.format(fname))
            file_check = check_sha1(fname, sha1_hash)
            if file_check:
                logger.info('File {} checking pass'.format(fname))
            else:
                raise KeyError('File {} is downloaded but the content hash does not match. ' \
                                    'Please retry.'.format(fname))
    return fname

# check_sha1
def check_sha1(filename, sha1_hash):
    """Check whether the sha1 hash of the file content matches the expected hash.
    """
    sha1 = hashlib.sha1()
    with open(filename, 'rb') as f:
        while True:
            data = f.read(1048576)
            if not data:
                break
            sha1.update(data)

    return sha1.hexdigest() == sha1_hash

# download_ops
def download_ops(url, fname):
    dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname)))
    if not os.path.exists(dirname):
        os.makedirs(dirname)

    logger.info('Downloading %s from %s...'%(fname, url))
    r = requests.get(url, stream=True)
    if r.status_code != 200:
        raise RuntimeError("Failed downloading url %s"%url)
    total_length = r.headers.get('content-length')
    with open(fname, 'wb') as f:
        if total_length is None: # no content length header
            for chunk in r.iter_content(chunk_size=1024):
                if chunk: # filter out keep-alive new chunks
                    f.write(chunk)
        else:
            total_length = int(total_length)
            for chunk in tqdm(r.iter_content(chunk_size=1024),
                                total=int(total_length / 1024. + 0.5),
                                unit='KB', unit_scale=False, dynamic_ncols=True):
                f.write(chunk)

# Download gptj models

In [None]:
for file_name, info in urls.items():
  download(info['url'], './', sha1_hash=info['sha1sum'])

logger.info("***download finished***")

In [6]:
!ls -ahl gpt-j-hf

total 12G
drwxr-xr-x 2 root root 4.0K Jul 24 16:32 .
drwxr-xr-x 1 root root 4.0K Jul 24 16:32 ..
-rw-r--r-- 1 root root 1.4K Jul 24 16:32 config.json
-rw-r--r-- 1 root root  12G Jul 24 16:43 pytorch_model.bin


# Loading model

In [None]:
model = GPTNeoForCausalLM.from_pretrained("./gpt-j-hf")
model.eval()

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
logger.info("***loading finished***")

# This should take about 12GB of Graphics RAM, if you have a larger than 16GB gpu you don't need the half()
# model.half().cuda() 

# Predictions

In [2]:
input_text = 'Why AutoGluon is great?'

logger.info("***encoding***")
input_ids = tokenizer.encode(str(input_text), return_tensors='pt').cuda()

logger.info("***generating***")
output = model.generate(
    input_ids,
    do_sample=True,
    max_length=args.max_length,
    top_p=args.top_p,
    top_k=0,
    temperature=1.0,
)

output_context = tokenizer.decode(output[0], skip_special_tokens=True)
logger.info('***output_context: {}'.format(output_context))

NameError: ignored