# Prepare requirements

In [None]:
!apt install zstd

# the "slim" version contain only bf16 weights and no optimizer parameters, which minimizes bandwidth and memory
!time wget -c https://the-eye.eu/public/AI/GPT-J-6B/step_383500_slim.tar.zstd

!time tar -I zstd -xf step_383500_slim.tar.zstd

!git clone https://github.com/kingoflolz/mesh-transformer-jax.git
!pip install -r mesh-transformer-jax/requirements.txt

# jax 0.2.12 is required due to a regression with xmap in 0.2.13
!pip install mesh-transformer-jax/ jax==0.2.12

In [None]:
!git clone https://github.com/finetuneanon/transformers
!git -C ./transformers checkout gpt-j
!pip install transformers/

In [None]:
!python3 ./conv.py  # if you get OOM error try convlowmem.py

# Generation

In [None]:
import torch
import gptj_wrapper as gptj
import datetime

In [None]:
# Build the GPT-J wrapper object used by the cells below (it exposes
# `.tokenizer` and `.generate`).
# NOTE(review): the meaning of `stage=2` is defined inside gptj_wrapper and is
# not visible here — confirm against that module.
model = gptj.GPTJ(stage=2)

In [None]:
# Id of the first token the tokenizer produces for "<|endoftext|>"; it is passed
# to generate() as `eos_token_id`. Despite the variable's name, this is the
# end-of-text marker, not a newline token.
encoded_eos = model.tokenizer("<|endoftext|>")
eos_newline = encoded_eos["input_ids"][0]

In [None]:
# Few-shot question-answering prompt: an instruction paragraph, two worked
# Context/Question/Answer examples, then the query the model should complete.
# NOTE(review): the prompt ends with a newline after the final "Answer:", while
# the worked examples put the answer on the same line — confirm this is intended.
text = """I am a highly intelligent question answering bot. If you provide me a context and ask me a question that is rooted in truth, I will give you the answer. If you ask me a question that is nonsense, trickery, or has no clear answer based on the context, I will respond with "Unknown".

Context: In 2017, U.S. life expectancy was 78.6 years.
Question: What is human life expectancy in the United States?
Answer: 78 years.

Context: puppy A is happy. puppy B is sad.
Question: which puppy is happy?
Answer: puppy A.

Context: You poured a glass of cranberry, but then absentmindedly, you poured about a teaspoon of grape juice into it. It looks OK. You try sniffing it, but you have a bad cold, so you can't smell anything. You are very thirsty. So you drink it.
Question: What happens next?
Answer:
"""

# Generation samples (do_sample=True), so seed torch's global RNG to make a
# fresh "Restart & Run All" reproduce the same completion.
# NOTE(review): this assumes the wrapper samples through torch's global RNG — confirm.
GENERATION_SEED = 0
torch.manual_seed(GENERATION_SEED)

# Inference only: disable gradient tracking around the generate() call and
# time just that call.
with torch.no_grad():
    start = datetime.datetime.now()
    out = model.generate(
        text=text,
        max_length=512,
        do_sample=True,
        temperature=0.1,
        top_k=5,
        top_p=0.9,
        no_repeat_ngram_size=2,
        early_stopping=True,
        use_cache=False,
        eos_token_id=eos_newline,
    )
    duration = datetime.datetime.now() - start

# Print each returned item with its first len(text) characters sliced off.
# NOTE(review): this presumes each item is a string that starts with the
# prompt — verify against gptj_wrapper.
for o in out:
    print("\n\n\n")
    print(o[len(text):])
print(f"\n\nDuration = {duration.total_seconds()}")