<a href="https://colab.research.google.com/github/mnopqr1/goodstorybot/blob/main/TIMIT_stories.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Using GPT-J-6B to generate stories from the TIMIT corpus

A friend sent me [a wonderfully weird list of sentences](https://github.com/daanzu/speech-training-recorder/blob/master/prompts/timit.txt) originally used as [prompts for the TIMIT acoustic speech corpus](https://www.nist.gov/publications/darpa-timit-acoustic-phonetic-continuous-speech-corpus-cd-rom-timit).

I was curious to see how the latest open-source text generation models would react to these seemingly random sentences.

The Colab notebook published by the [EleutherAI](https://eleuther.ai) collective made this incredibly simple.

I wrote only the first and last sections of this notebook myself, which load the TIMIT sentences and apply the model to them. The other sections are copied from the GPT-J-6B Inference Demo notebook written by [Ben Wang](https://github.com/kingoflolz).

# Set Up Google Drive

We first connect to a personal Google Drive folder, which we assume contains a file `timitclean.txt` with the prompt sentences, one per line.

In [None]:
from google.colab import drive
drive.mount('/content/drive')


Mounted at /content/drive


In [None]:
import os

# Path inside "My Drive" to the folder that contains timitclean.txt (adjust to your own setup)
GOOGLE_DRIVE_PATH_AFTER_MYDRIVE = 'Colab Notebooks'
GOOGLE_DRIVE_PATH = os.path.join('drive', 'My Drive', GOOGLE_DRIVE_PATH_AFTER_MYDRIVE)
print(os.listdir(GOOGLE_DRIVE_PATH))

['Copy of app_jupyter.ipynb', 'Copy of 01_intro (1).ipynb', 'Copy of 01_intro.ipynb', 'Untitled0.ipynb', 'unige14x050-sandbox.ipynb', 'TIMIT stories.ipynb', 'Copy of TIMIT stories.ipynb']


In [None]:
# One prompt per line; rstrip() removes the newline (even on a last line without one) and any trailing spaces
with open(os.path.join(GOOGLE_DRIVE_PATH, "timitclean.txt"), 'r') as f:
    singleprompts = [s.rstrip() for s in f]

# GPT-J-6B Inference Demo

<a href="http://colab.research.google.com/github/kingoflolz/mesh-transformer-jax/blob/master/colab_demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

This notebook demonstrates how to run the [GPT-J-6B model](https://github.com/kingoflolz/mesh-transformer-jax/#GPT-J-6B). See the link for more details about the model, including evaluation metrics and credits.

## Install Dependencies

First we download the model and install some dependencies. This step takes at least 5 minutes (possibly longer depending on server load).

!!! **Make sure you are using a TPU runtime!** !!!

In [None]:
!apt install zstd

# the "slim" version contains only bf16 weights and no optimizer parameters, which minimizes bandwidth and memory
!time wget -c https://the-eye.eu/public/AI/GPT-J-6B/step_383500_slim.tar.zstd

!time tar -I zstd -xf step_383500_slim.tar.zstd

!git clone https://github.com/kingoflolz/mesh-transformer-jax.git
!pip install -r mesh-transformer-jax/requirements.txt

# jax 0.2.12 is required due to a regression with xmap in 0.2.13
!pip install mesh-transformer-jax/ jax==0.2.12

Reading package lists... Done
Building dependency tree       
Reading state information... Done
The following NEW packages will be installed:
  zstd
0 upgraded, 1 newly installed, 0 to remove and 39 not upgraded.
Need to get 278 kB of archives.
After this operation, 1,141 kB of additional disk space will be used.
Get:1 http://archive.ubuntu.com/ubuntu bionic-updates/universe amd64 zstd amd64 1.3.3+dfsg-2ubuntu1.2 [278 kB]
Fetched 278 kB in 1s (444 kB/s)
Selecting previously unselected package zstd.
(Reading database ... 160772 files and directories currently installed.)
Preparing to unpack .../zstd_1.3.3+dfsg-2ubuntu1.2_amd64.deb ...
Unpacking zstd (1.3.3+dfsg-2ubuntu1.2) ...
Setting up zstd (1.3.3+dfsg-2ubuntu1.2) ...
Processing triggers for man-db (2.8.3-2ubuntu0.1) ...
--2021-06-17 12:25:34--  https://the-eye.eu/public/AI/GPT-J-6B/step_383500_slim.tar.zstd
Resolving the-eye.eu (the-eye.eu)... 162.213.130.242
Connecting to the-eye.eu (the-eye.eu)|162.213.130.242|:443... connected.
HT

Processing ./mesh-transformer-jax
Collecting jax==0.2.12
  Downloading https://files.pythonhosted.org/packages/9a/67/d1a9c94104c559b49bbcb72e9efc33859e982d741ea4902d2a00e66e09d9/jax-0.2.12.tar.gz (590kB)
     |████████████████████████████████| 593kB 5.3MB/s
Building wheels for collected packages: jax, mesh-transformer
  Building wheel for jax (setup.py) ... done
  Created wheel for jax: filename=jax-0.2.12-cp37-none-any.whl size=682484 sha256=0b3cb041d672dea03fc3cb19bea386e594e31774f707825ed405a9ae5b87876c
  Stored in directory: /root/.cache/pip/wheels/cf/00/88/75c2043dff473f58e892c7e6adfd2c44ccefb6111fcc021e5b
  Building wheel for mesh-transformer (setup.py) ... done
  Created wheel for mesh-transformer: filename=mesh_transformer-0.0.0-cp37-none-any.whl size=20016 sha256=c9a94362e05fbfa8f5abebb6cc772e1e5b9c04c89e35cd8d5a12b166e54ed9fb
  Stored in directory: /root/.cache/pip/wheels/de/a9/d2/2be3e25299342b60fca7965d4e416264ff8b6d8a7e8def76da
Successfull
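The comment about the "slim" checkpoint can be made concrete with a back-of-the-envelope parameter count. The sketch below uses the sizes from the `params` dictionary in the next section and assumes a GPT-J layout: feed-forward width `4 * d_model`, bias-free attention projections, one layernorm per block, and an untied output layer with bias. Those layout details are assumptions, not values read from the checkpoint, and `gptj_param_count` is a hypothetical helper, not part of mesh-transformer-jax.

```python
# Rough GPT-J-6B size estimate; the layer layout below is an assumption, not read from the checkpoint.


def gptj_param_count(layers: int = 28, d_model: int = 4096, n_vocab: int = 50400) -> int:
    d_ff = 4 * d_model                            # assumed feed-forward width (16384 here)
    attn = 4 * d_model * d_model                  # q, k, v and output projections, no biases
    mlp = 2 * d_model * d_ff + d_ff + d_model     # two dense layers, each with a bias
    norm = 2 * d_model                            # one layernorm (scale + offset) per block
    per_layer = attn + mlp + norm
    embedding = n_vocab * d_model                 # input token embedding
    head = n_vocab * d_model + n_vocab            # untied output projection with bias
    final_norm = 2 * d_model
    return layers * per_layer + embedding + head + final_norm


if __name__ == "__main__":
    n = gptj_param_count()
    print(f"{n / 1e9:.2f}B parameters")
    print(f"bf16 weights: {2 * n / 1e9:.1f} GB (2 bytes per parameter)")
    print(f"fp32 weights: {4 * n / 1e9:.1f} GB, before any optimizer state")
```

Under these assumptions the count comes out at about 6.05 billion parameters. That is roughly 12 GB of weights in bf16 versus roughly 24 GB in fp32, and a full training checkpoint adds optimizer state on top. This illustrates why the slim download is the better choice for inference.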

## Set Up Model


In [None]:
import os
import requests 
from jax.config import config

colab_tpu_addr = os.environ['COLAB_TPU_ADDR'].split(':')[0]
url = f'http://{colab_tpu_addr}:8475/requestversion/tpu_driver0.1_dev20210607'
requests.post(url)

# The following is required to use TPU Driver as JAX's backend.
config.FLAGS.jax_xla_backend = "tpu_driver"
config.FLAGS.jax_backend_target = "grpc://" + os.environ['COLAB_TPU_ADDR']
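The cell above reads `COLAB_TPU_ADDR` directly, so on a CPU or GPU runtime it stops with a bare `KeyError`. Below is a minimal sketch of an earlier, more readable check. Like the cell above, it assumes the variable holds `host:port`. `tpu_host` is a hypothetical helper.

```python
import os
from typing import Mapping, Optional


def tpu_host(environ: Mapping[str, str] = os.environ) -> Optional[str]:
    """Return the TPU host from COLAB_TPU_ADDR, or None if the variable is unset or empty."""
    addr = environ.get("COLAB_TPU_ADDR")
    if not addr:
        return None
    return addr.split(":")[0]  # same host extraction as the setup cell


if __name__ == "__main__":
    host = tpu_host()
    print("TPU host:", host if host else "none found; switch the runtime type to TPU")
```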

Sometimes the next step fails for no obvious reason; if so, just run it again ¯\\\_(ツ)\_/¯

In [None]:
import time

import jax
from jax.experimental import maps
import numpy as np
import optax
import transformers

from mesh_transformer.checkpoint import read_ckpt
from mesh_transformer.sampling import nucleaus_sample
from mesh_transformer.transformer_shard import CausalTransformer

In [None]:
params = {
  "layers": 28,
  "d_model": 4096,
  "n_heads": 16,
  "n_vocab": 50400,
  "norm": "layernorm",
  "pe": "rotary",
  "pe_rotary_dims": 64,

  "seq": 2048,
  "cores_per_replica": 8,
  "per_replica_batch": 1,
}

per_replica_batch = params["per_replica_batch"]
cores_per_replica = params["cores_per_replica"]
seq = params["seq"]


params["sampler"] = nucleaus_sample

# here we "remove" the optimizer parameters from the model (as we don't need them for inference)
params["optimizer"] = optax.scale(0)

mesh_shape = (jax.device_count() // cores_per_replica, cores_per_replica)
devices = np.array(jax.devices()).reshape(mesh_shape)

maps.thread_resources.env = maps.ResourceEnv(maps.Mesh(devices, ('dp', 'mp')))

tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')


Here we create the network and load the parameters from the downloaded files. Expect this to take around 5 minutes.

In [None]:
total_batch = per_replica_batch * jax.device_count() // cores_per_replica

network = CausalTransformer(params)

network.state = read_ckpt(network.state, "step_383500/", devices.shape[1])

network.state = network.move_xmap(network.state, np.zeros(cores_per_replica))

  warn("xmap is an experimental feature and probably has bugs!")


key shape (8, 2)
in shape (1, 2048)
dp 1
mp 8
read from disk/gcs in 33.179s
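The log lines `dp 1` and `mp 8` reflect the mesh built in the parameter cell. The available TPU cores are split into `device_count // cores_per_replica` data-parallel replicas, and each replica is spread over `cores_per_replica` model-parallel cores. With a single replica, as in the log above, `total_batch` is 1, so `infer` generates one sequence per call. Below is a small sketch of that arithmetic. `mesh_layout` is a hypothetical helper, not part of mesh-transformer-jax, and it adds a divisibility check that the notebook leaves to the later `reshape`.

```python
from typing import Tuple


def mesh_layout(device_count: int, cores_per_replica: int = 8,
                per_replica_batch: int = 1) -> Tuple[Tuple[int, int], int]:
    """Return ((dp, mp), total_batch) following the notebook's formulas."""
    if device_count % cores_per_replica:
        raise ValueError("device_count must be a multiple of cores_per_replica")
    dp = device_count // cores_per_replica   # data-parallel replicas ('dp' mesh axis)
    mp = cores_per_replica                   # model-parallel shards per replica ('mp' axis)
    total_batch = per_replica_batch * dp     # sequences per call; infer repeats the prompt this often
    return (dp, mp), total_batch


if __name__ == "__main__":
    print(mesh_layout(8))  # ((1, 8), 1): one replica sharded over 8 cores
```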


## Inference function

The next cells set up the inference function from the original demo notebook, which also gives this tip:

*Tip for best results: Make sure your prompt does not have any trailing spaces, which tend to confuse the model due to the BPE tokenization used during training.*
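To follow this tip for a whole list of prompts, one option is to strip trailing whitespace before calling `infer` and report how many prompts were affected. This is a minimal sketch; `clean_prompts` is a hypothetical helper.

```python
from typing import List, Sequence, Tuple


def clean_prompts(prompts: Sequence[str]) -> Tuple[List[str], int]:
    """Strip trailing whitespace (spaces, tabs, newlines) from each prompt.

    Returns the cleaned prompts and how many of them changed.
    """
    cleaned = [p.rstrip() for p in prompts]
    changed = sum(1 for old, new in zip(prompts, cleaned) if old != new)
    return cleaned, changed


if __name__ == "__main__":
    prompts, n_changed = clean_prompts(["She is thinner than I am. ", "Alfalfa is healthy for you."])
    print(prompts, n_changed)
```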

In [None]:
# allow text wrapping in generated output: https://stackoverflow.com/a/61401455
from IPython.display import HTML, display

def set_css():
  display(HTML('''
  <style>
    pre {
        white-space: pre-wrap;
    }
  </style>
  '''))
get_ipython().events.register('pre_run_cell', set_css)

In [None]:
def infer(context, top_p=0.9, temp=1.0, gen_len=512):
    # tokenize the prompt with the GPT-2 BPE tokenizer
    tokens = tokenizer.encode(context)

    provided_ctx = len(tokens)
    pad_amount = seq - provided_ctx

    # left-pad the prompt to the full context length `seq`; `length` records how many tokens are real
    padded_tokens = np.pad(tokens, ((pad_amount, 0),)).astype(np.uint32)
    batched_tokens = np.array([padded_tokens] * total_batch)
    length = np.ones(total_batch, dtype=np.uint32) * len(tokens)

    start = time.time()
    # same sampling settings (top_p, temperature) for every sequence in the batch
    output = network.generate(batched_tokens, length, gen_len, {"top_p": np.ones(total_batch) * top_p, "temp": np.ones(total_batch) * temp})

    samples = []
    decoded_tokens = output[1][0]

    # prompt in ANSI bold (\033[1m ... \033[0m), followed by the decoded continuation
    for o in decoded_tokens[:, :, 0]:
      samples.append(f"\033[1m{context}\033[0m{tokenizer.decode(o)}")

    print(f"completion done in {time.time() - start:06}s")
    return samples

#print(infer("EleutherAI is")[0])
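The core preprocessing step in `infer` is left-padding. The prompt's token ids are moved to the end of a fixed `seq`-length (2048-token) window with zeros in front, and `length` records how many positions are real. The sketch below isolates that step in plain NumPy. It also adds an explicit check for prompts that do not fit; `infer` itself has no such check, so there `np.pad` would reject the negative pad width with a less descriptive error. `left_pad` is a hypothetical helper.

```python
import numpy as np


def left_pad(tokens, seq: int = 2048):
    """Left-pad token ids with zeros to length `seq`, mirroring the padding in `infer`.

    Returns the padded uint32 array and the number of real (prompt) tokens.
    """
    n = len(tokens)
    if n > seq:
        raise ValueError(f"prompt has {n} tokens but the context window is {seq}")
    padded = np.pad(np.asarray(tokens, dtype=np.uint32), (seq - n, 0))
    return padded, n


if __name__ == "__main__":
    padded, n = left_pad([464, 3290, 318], seq=6)  # illustrative token ids
    print(padded.tolist(), n)  # [0, 0, 0, 464, 3290, 318] 3
```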

# Applying the model to the TIMIT sentences


## First tests
First attempt, using the interactive sliders from the demo notebook. I settled on a generation length of 128 tokens. After experimenting a bit with `top_p` and `temp`, I decided on `temp` = 1 and `top_p` = 0.9 (the slider in the cell below was left at `top_p` = 0.8).
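For reference, `temp` and `top_p` are the two knobs of nucleus (top-p) sampling. The logits are divided by the temperature, and the next token is drawn only from the smallest set of most likely tokens whose combined probability reaches `top_p`. The sketch below implements one sampling step of that technique in NumPy. It illustrates the idea only and is not the code of mesh-transformer-jax's `nucleaus_sample`; `nucleus_sample` is a hypothetical name.

```python
import numpy as np


def nucleus_sample(logits, top_p: float = 0.9, temp: float = 1.0, rng=None) -> int:
    """Draw one token id using temperature scaling plus nucleus (top-p) filtering."""
    if temp <= 0:
        raise ValueError("temp must be positive")
    rng = np.random.default_rng() if rng is None else rng
    scaled = np.asarray(logits, dtype=np.float64) / temp
    probs = np.exp(scaled - scaled.max())    # numerically stable softmax
    probs /= probs.sum()
    order = np.argsort(probs)[::-1]          # token ids, most likely first
    cumulative = np.cumsum(probs[order])
    # smallest prefix of `order` whose probability mass reaches top_p (always at least 1 token)
    cutoff = int(np.searchsorted(cumulative, top_p)) + 1
    kept = order[:cutoff]
    return int(rng.choice(kept, p=probs[kept] / probs[kept].sum()))


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    logits = [2.0, 1.0, 0.5, -1.0]
    print([nucleus_sample(logits, top_p=0.9, temp=1.0, rng=rng) for _ in range(10)])
```

A lower `top_p` shrinks the candidate set; at `top_p = 0` only the single most likely token survives. A lower `temp` sharpens the distribution before the cut is applied.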

In [None]:
#@title  { form-width: "300px" }
top_p = 0.8 #@param {type:"slider", min:0, max:1, step:0.1}
temp = 1 #@param {type:"slider", min:0, max:1, step:0.1}

context = """She had your dark suit in greasy wash water all year."""

print(infer(top_p=top_p, temp=temp, gen_len=128, context=context)[0])

completion done in 43.450966358184814s
She had your dark suit in greasy wash water all year. What will you wear to your own funeral?"

"I don't know," I said.

"The vicar will. You don't seem like a man who would expect to be buried in a suit."

"I guess not."

"Then I suggest you find a pair of loose trousers and a shirt and lay them on the bed. I will find you some socks, too. And I will leave you some food and drink. When you have eaten and rested, I will bring you another suit. A pair of boots, too, if you will wear them. And I'll make a call to


In [None]:
context = """She is thinner than I am. He will allow a rare lie. Alfalfa is healthy for you."""
print(infer(top_p=0.9, temp=1, gen_len=128, context=context)[0])

completion done in 3.587801933288574s
She is thinner than I am. He will allow a rare lie. Alfalfa is healthy for you.

The small dog's jaw was actually clenched, and the saliva was seeping from his mouth, running down the furrow that separated his whiskers.

"Alfalfa."

I returned his hungry gaze, and his eyes were alert, ready for action, aware of the risk of standing in a path, the terrible danger of death. "Do I know you, Alfalfa?"

"Oh no, mister." He jumped back to the other side of the driveway, the man behind him exclaiming at the wonderful speed.

Alfalfa was tamer now. The intensity


## Asking the model for stories on batches of three sentences

Here I load the TIMIT prompts grouped into batches of three sentences, one batch per line. I prepared this file, `timitprompts3.txt`, on my local machine and uploaded it to the Colab session.

In [None]:
with open("timitprompts3.txt") as f:
    prompts = [s.rstrip() for s in f]
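The notebook does not show how `timitprompts3.txt` was built. If it simply joins consecutive sentences from the single-sentence list three at a time, a sketch might look like the one below. This is an assumption; `group_sentences` is hypothetical.

```python
from typing import List, Sequence


def group_sentences(sentences: Sequence[str], size: int = 3) -> List[str]:
    """Join consecutive sentences into groups of `size`; the last group may be shorter."""
    return [" ".join(sentences[i:i + size]) for i in range(0, len(sentences), size)]


if __name__ == "__main__":
    sentences = [
        "She had your dark suit in greasy wash water all year.",
        "Don't ask me to carry an oily rag like that.",
        "This was easy for us.",
        "She is thinner than I am.",
    ]
    for line in group_sentences(sentences):
        print(line)
```

Writing the groups out one per line would produce a file that the cell above can read back line by line.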

In [None]:
story = [""] * len(prompts)

In [None]:
n = 0
story[n] = infer(top_p=0.9, temp=1, gen_len=256, context=prompts[n])[0]
story[n] = story[n][story[n].find("\x1b[0m")+4:] # remove boldface prompt from story

completion done in 6.908513307571411s
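The slice `story[n][story[n].find("\x1b[0m")+4:]` relies on `infer` wrapping the prompt in ANSI bold codes. The 4-character reset code `\x1b[0m` marks where the prompt ends. If that marker were ever missing, `find` would return -1 and the slice would silently drop the first three characters. Below is a slightly more defensive sketch using `str.partition`; `strip_prompt` is a hypothetical helper.

```python
RESET = "\x1b[0m"  # ANSI reset code that infer() places right after the bold prompt


def strip_prompt(sample: str) -> str:
    """Return only the generated continuation from a string produced by infer()."""
    _prompt, marker, continuation = sample.partition(RESET)
    if not marker:
        raise ValueError("ANSI reset code not found; cannot tell where the prompt ends")
    return continuation


if __name__ == "__main__":
    sample = "\x1b[1mShe is thinner than I am.\x1b[0m The small dog's jaw was clenched."
    print(repr(strip_prompt(sample)))
```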


Looks like it's working:

In [None]:
print(story[0])

 Why give her a job that suited her? Why give her a job at all? She had the chance to be strong."

"And you think she took it?" said Ismay. "Well, what did you think?"

"I always knew she was your mammy's pet," said Annabel. "There was one day in the autumn of my third year in Boston. You remember it. You were playing in the park and your sister—I forget her name, but she was a bit wild—was being scolded for running off. You were not as young as you are now."

She looked out at the water, at the rowing boats and tenders and passengers.

"In the town hall, I stood near Mrs. Copley, looking at you and at some other children being scolded for running wild. I saw you reach out your hand and stroke your sister's hair. I could tell that it made you both happy. You saw me there and you made a signal to me. We looked up from where we were at one of the windows that let in the sun, and we knew it. I knew you could see us.

"Your mother was angry that day. She went to our


Now iterate over all the prompts. At about 7 seconds per story, a full run would take around 90 minutes. In the end, though, I decided to switch to single-sentence prompts instead.

In [None]:
for n in range(len(prompts)):
  story[n] = infer(top_p=0.9, temp=1, gen_len=256, context=prompts[n])[0]
  story[n] = story[n][story[n].find("\x1b[0m")+4:]
  with open("allthestories.txt", 'a') as f:
    f.write("Story " + str(n+1) + ".\n\n" + prompts[n] + " *** " + story[n] + "\n\n ----- \n\n")
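Both generation loops reopen the output file in append mode for every story, so finished stories survive a Colab disconnect. The single-sentence loop also starts at index 1094 rather than 0. The sketch below follows the same pattern but keeps one file handle open and flushes after each story. It can also pick up where a previous run stopped by counting the separator lines already written. It assumes the file was written contiguously from the first prompt, in the format used above. `generate` stands in for a call to `infer` followed by prompt stripping, and all names are hypothetical.

```python
import os
from typing import Callable, Optional, Sequence

SEPARATOR = " ----- "  # the separator line written after every story


def count_written(path: str) -> int:
    """Number of stories already in `path` (0 if the file does not exist yet)."""
    if not os.path.exists(path):
        return 0
    with open(path, encoding="utf-8") as f:
        return sum(1 for line in f if line.rstrip("\n") == SEPARATOR)


def write_stories(prompts: Sequence[str], generate: Callable[[str], str], path: str,
                  start: Optional[int] = None) -> int:
    """Append one story per prompt to `path`, resuming after the stories already there.

    Returns the number of stories written by this call.
    """
    if start is None:
        start = count_written(path)
    with open(path, "a", encoding="utf-8") as f:
        for n in range(start, len(prompts)):
            story = generate(prompts[n])
            f.write(f"Story {n + 1}.\n\n{prompts[n]} *** {story}\n\n{SEPARATOR}\n\n")
            f.flush()  # hand each story to the OS right away instead of buffering it
    return max(0, len(prompts) - start)


if __name__ == "__main__":
    import tempfile
    path = os.path.join(tempfile.mkdtemp(), "stories.txt")
    write_stories(["A.", "B."], str.upper, path)
    print(write_stories(["A.", "B.", "C."], str.upper, path))  # 1: resumes at "C."
```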

In [None]:
story[0]

'\x1b[1mShe had your dark suit in greasy wash water all year. Don\'t ask me to carry an oily rag like that. This was easy for us.\x1b[0m We can find out where the public toilets are in half an hour. My buddy, Eddy, gets married, we bring the good suits and you stay with us."\n\n"Stop, will you?"\n\n"We play dumb and make sure they never see us or that it\'s too dark. Most of the time they just say, \'Where to?\' and that\'s that. You wait in your suit and they come back with what they took out of the trash can and the place is deserted."\n\n"Why don\'t you use the public toilets, then?"\n\n"It\'s too noisy and too crowded.'

## Same thing, but now with single-sentence prompts

In [None]:
sstory = [""] * len(singleprompts)

In [None]:
# quick check that the notebook can write to the Google Drive folder
with open(os.path.join(GOOGLE_DRIVE_PATH, "test.txt"), 'a') as f:
  f.write("Hello world")

In [None]:
for n in range(1094,len(singleprompts)):
  print("Story " + str(n), end=": ...")
  sstory[n] = infer(top_p=0.9, temp=1, gen_len=256, context=singleprompts[n])[0]
  sstory[n] = sstory[n][sstory[n].find("\x1b[0m")+4:]
  with open(os.path.join(GOOGLE_DRIVE_PATH, "sstories3.txt"), 'a') as f:
    f.write("Story " + str(n+1) + ".\n\n" + singleprompts[n] + " *** " + sstory[n] + "\n\n ----- \n\n")

Story 1094: ...completion done in 6.905962705612183s
Story 1095: ...completion done in 6.903246879577637s
Story 1096: ...completion done in 6.90148401260376s
Story 1097: ...completion done in 6.9043567180633545s
Story 1098: ...completion done in 6.903069257736206s
Story 1099: ...completion done in 6.895622730255127s
Story 1100: ...completion done in 6.899883508682251s
Story 1101: ...completion done in 6.9009740352630615s
Story 1102: ...completion done in 6.898463249206543s
Story 1103: ...completion done in 6.904120683670044s
Story 1104: ...completion done in 6.898528099060059s
Story 1105: ...completion done in 6.902292013168335s
Story 1106: ...completion done in 6.904008388519287s
Story 1107: ...completion done in 6.898965120315552s
Story 1108: ...completion done in 6.898281097412109s
Story 1109: ...completion done in 6.900062322616577s
Story 1110: ...completion done in 6.901658773422241s
Story 1111: ...completion done in 6.903208017349243s
Story 1112: ...completion done in 6.904562473