<a href="https://colab.research.google.com/github/matthias-wright/flaxmodels/blob/main/flaxmodels/gpt2/gpt2_demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install --upgrade pip
!pip install --upgrade jax jaxlib==0.1.66+cuda111 -f https://storage.googleapis.com/jax-releases/jax_releases.html
!pip install --upgrade git+https://github.com/matthias-wright/flaxmodels.git

Collecting pip
[?25l  Downloading https://files.pythonhosted.org/packages/cd/6f/43037c7bcc8bd8ba7c9074256b1a11596daa15555808ec748048c1507f08/pip-21.1.1-py3-none-any.whl (1.5MB)
[K     |▏                               | 10kB 24.7MB/s eta 0:00:01[K     |▍                               | 20kB 32.3MB/s eta 0:00:01[K     |▋                               | 30kB 27.1MB/s eta 0:00:01[K     |▉                               | 40kB 19.4MB/s eta 0:00:01[K     |█                               | 51kB 10.1MB/s eta 0:00:01[K     |█▎                              | 61kB 9.2MB/s eta 0:00:01[K     |█▌                              | 71kB 10.3MB/s eta 0:00:01[K     |█▊                              | 81kB 11.3MB/s eta 0:00:01[K     |██                              | 92kB 12.1MB/s eta 0:00:01[K     |██▏                             | 102kB 8.4MB/s eta 0:00:01[K     |██▎                             | 112kB 8.4MB/s eta 0:00:01[K     |██▌                             | 122kB 8.4MB/s eta 0:

# Generate text

This is a very simple example of greedy text generation. There are more sophisticated [methods](https://huggingface.co/blog/how-to-generate) out there.

In [2]:
import jax
import jax.numpy as jnp
import flaxmodels as fm

# Greedy text generation: repeatedly feed the model its own most likely next
# token, reusing the cached keys/values so each step only processes one token.

PROMPT = 'The Manhattan bridge'
NUM_NEW_TOKENS = 20  # number of tokens appended to the prompt

key = jax.random.PRNGKey(0)

# Initialize tokenizer
tokenizer = fm.gpt2.get_tokenizer()

# Encode start sequence; `generated` accumulates the token ids of the full text
generated = tokenizer.encode(PROMPT)

# The outer list adds a leading axis of size 1 (a single prompt)
context = jnp.array([generated])
# No cached keys/values exist before the first forward pass
past = None

# Initialize model
# Models to choose from ['gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl']
model = fm.gpt2.GPT2LMHeadModel(pretrained='gpt2')
params = model.init(key, input_ids=context, past_key_values=past)

for _ in range(NUM_NEW_TOKENS):
    # Predict next token in sequence
    output = model.apply(params, input_ids=context, past_key_values=past, use_cache=True)
    # Greedy choice: index of the largest logit at the last position.
    # argmax is taken without an axis (i.e. over the flattened slice), so this
    # relies on there being exactly one prompt in the batch.
    token = jnp.argmax(output['logits'][..., -1, :])
    # Only the new token is fed next; earlier tokens are covered by `past`
    context = jnp.expand_dims(token, axis=0)
    # Add token to sequence as a plain Python int so the list stays homogeneous
    # (the ids from tokenizer.encode are not device arrays)
    generated.append(int(token))
    # Update past keys and values
    past = output['past_key_values']

# Decode sequence of tokens
sequence = tokenizer.decode(generated)

print()
print(sequence)

Downloading: "https://www.dropbox.com/s/7f5n1gf348sy1mt/merges.txt" to /tmp/flaxmodels/merges.txt


100%|██████████| 456k/456k [00:00<00:00, 14.9MiB/s]


Downloading: "https://www.dropbox.com/s/s93xkhgcac5nbmn/vocab.json" to /tmp/flaxmodels/vocab.json


100%|██████████| 1.04M/1.04M [00:00<00:00, 1.42MiB/s]


Downloading: "https://www.dropbox.com/s/0wdgj0gazwt9nm7/gpt2.h5" to /tmp/flaxmodels/gpt2.h5


100%|██████████| 703M/703M [00:30<00:00, 22.9MiB/s]


Downloading: "https://www.dropbox.com/s/s5xl32dgwc8322p/gpt2.json" to /tmp/flaxmodels/gpt2.json


100%|██████████| 715/715 [00:00<00:00, 143kiB/s]



The Manhattan bridge is a major artery for the city's subway system, and the bridge is one of the busiest in


# Get language model head output from text input

In [3]:
import jax
import jax.numpy as jnp
import flaxmodels as fm

# PRNG key required by model.init
key = jax.random.PRNGKey(0)

# Build the tokenizer and turn the prompt into an array of token ids;
# wrapping the encoded ids in a list gives the array a leading axis of size 1
tokenizer = fm.gpt2.get_tokenizer()
input_ids = jnp.array([tokenizer.encode('The Manhattan bridge')])

# Language-model-head variant of GPT-2, created with the pretrained 'gpt2' option
model = fm.gpt2.GPT2LMHeadModel(pretrained='gpt2')
params = model.init(key, input_ids=input_ids)

# Forward pass on the token ids
output = model.apply(params, input_ids=input_ids, use_cache=True)
# output: {'last_hidden_state': ..., 'past_key_values': ..., 'loss': ..., 'logits': ...}

# Get language model head output from embeddings


In [4]:
import jax
import jax.numpy as jnp
import flaxmodels as fm
                                                                    
key = jax.random.PRNGKey(0)
# A JAX PRNG key should be consumed only once, so derive separate keys for
# model initialization and for sampling the dummy input.
init_key, input_key = jax.random.split(key)

# Dummy input: random embeddings fed to the model in place of token ids.
# Presumably laid out as (batch, sequence, embedding) with 768 matching the
# 'gpt2' checkpoint -- confirm before switching to another model size.
input_embds = jax.random.normal(input_key, shape=(2, 10, 768))

# Initialize model
model = fm.gpt2.GPT2LMHeadModel(pretrained='gpt2')
params = model.init(init_key, input_embds=input_embds)

# Compute output
output = model.apply(params, input_embds=input_embds, use_cache=True)
# output: {'last_hidden_state': ..., 'past_key_values': ..., 'loss': ..., 'logits': ...}

# Get model output from text input

In [5]:
import jax
import jax.numpy as jnp
import flaxmodels as fm

# PRNG key handed to model.init below
key = jax.random.PRNGKey(0)

# Tokenize the prompt; the enclosing list adds a leading axis of size 1
tokenizer = fm.gpt2.get_tokenizer()
input_ids = jnp.array([tokenizer.encode('The Manhattan bridge')])

# GPT2Model rather than GPT2LMHeadModel, created with the pretrained 'gpt2' option
model = fm.gpt2.GPT2Model(pretrained='gpt2')
params = model.init(key, input_ids=input_ids)

# Run the model on the token ids
output = model.apply(params, input_ids=input_ids, use_cache=True)
# output: {'last_hidden_state': ..., 'past_key_values': ...}

# Get model output from embeddings

In [6]:
import jax
import jax.numpy as jnp
import flaxmodels as fm
                                                                    
key = jax.random.PRNGKey(0)
# A JAX PRNG key should be consumed only once, so derive separate keys for
# model initialization and for sampling the dummy input.
init_key, input_key = jax.random.split(key)

# Dummy input: random embeddings fed to the model in place of token ids.
# Presumably laid out as (batch, sequence, embedding) with 768 matching the
# 'gpt2' checkpoint -- confirm before switching to another model size.
input_embds = jax.random.normal(input_key, shape=(2, 10, 768))

# Initialize model
model = fm.gpt2.GPT2Model(pretrained='gpt2')
params = model.init(init_key, input_embds=input_embds)

# Compute output
output = model.apply(params, input_embds=input_embds, use_cache=True)
# output: {'last_hidden_state': ..., 'past_key_values': ...}