In [None]:
!pip install --upgrade pip
!pip install jax[cuda11_cudnn805] -f https://storage.googleapis.com/jax-releases/jax_releases.html
!pip install --upgrade git+https://github.com/matthias-wright/flaxmodels.git

Looking in links: https://storage.googleapis.com/jax-releases/jax_releases.html
Collecting jaxlib==0.1.70+cuda111
  Downloading https://storage.googleapis.com/jax-releases/cuda111/jaxlib-0.1.70%2Bcuda111-cp37-none-manylinux2010_x86_64.whl (197.0 MB)
[K     |████████████████████████████████| 197.0 MB 19 kB/s 
Installing collected packages: jaxlib
  Attempting uninstall: jaxlib
    Found existing installation: jaxlib 0.1.66+cuda111
    Uninstalling jaxlib-0.1.66+cuda111:
      Successfully uninstalled jaxlib-0.1.66+cuda111
Successfully installed jaxlib-0.1.70+cuda111
Collecting git+https://github.com/matthias-wright/flaxmodels.git
  Cloning https://github.com/matthias-wright/flaxmodels.git to /tmp/pip-req-build-cg84k2dn
  Running command git clone -q https://github.com/matthias-wright/flaxmodels.git /tmp/pip-req-build-cg84k2dn
  Resolved https://github.com/matthias-wright/flaxmodels.git to commit 242ced2a4a12ace8adc32a705b08064ffeeb31ac


# Generate text

This is a very simple example of greedy text generation. There are more sophisticated [methods](https://huggingface.co/blog/how-to-generate) out there.

In [None]:
import jax
import jax.numpy as jnp
import flaxmodels as fm

key = jax.random.PRNGKey(0)

# Number of tokens to append to the prompt.
NUM_NEW_TOKENS = 20

# Initialize tokenizer
tokenizer = fm.gpt2.get_tokenizer()

# Encode start sequence; `generated` accumulates the token ids of the whole text
generated = tokenizer.encode('The Manhattan bridge')

# The first forward pass receives the full prompt (wrapped in a list to add a
# leading axis) and no cached keys/values.
context = jnp.array([generated])
past = None

# Initialize model
# Models to choose from ['gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl']
model = fm.gpt2.GPT2LMHeadModel(pretrained='gpt2')
params = model.init(key, input_ids=context, past_key_values=past)

for i in range(NUM_NEW_TOKENS):
    # Predict next token in sequence
    output = model.apply(params, input_ids=context, past_key_values=past, use_cache=True)
    # Greedy choice: index of the largest logit at the last position.
    # argmax without an axis works on the flattened array, so this assumes a
    # single sequence (leading axis of size 1).
    token = jnp.argmax(output['logits'][..., -1, :])
    # Only the newly chosen token is fed to the next step, together with `past`
    context = jnp.expand_dims(token, axis=0)
    # Add token to sequence as a plain int so `generated` stays a list of ints
    generated.append(int(token))
    # Update past keys and values
    past = output['past_key_values']

# Decode sequence of tokens
sequence = tokenizer.decode(generated)

print()
print(sequence)

Downloading: "https://www.dropbox.com/s/7f5n1gf348sy1mt/merges.txt" to /tmp/flaxmodels/merges.txt


100%|██████████| 456k/456k [00:00<00:00, 12.1MiB/s]


Downloading: "https://www.dropbox.com/s/s93xkhgcac5nbmn/vocab.json" to /tmp/flaxmodels/vocab.json


100%|██████████| 1.04M/1.04M [00:00<00:00, 23.1MiB/s]


Downloading: "https://www.dropbox.com/s/0wdgj0gazwt9nm7/gpt2.h5" to /tmp/flaxmodels/gpt2.h5


100%|██████████| 703M/703M [00:14<00:00, 48.1MiB/s]


Downloading: "https://www.dropbox.com/s/s5xl32dgwc8322p/gpt2.json" to /tmp/flaxmodels/gpt2.json


100%|██████████| 715/715 [00:00<00:00, 159kiB/s]



The Manhattan bridge is a major artery for the city's subway system, and the bridge is one of the busiest in


# Get language model head output from text input

In [None]:
import jax
import jax.numpy as jnp
import flaxmodels as fm

key = jax.random.PRNGKey(0)

# Build the GPT-2 tokenizer
tokenizer = fm.gpt2.get_tokenizer()

# Tokenize the prompt and wrap the ids in a list so the array gets a leading axis
input_ids = jnp.array([tokenizer.encode('The Manhattan bridge')])

# Build the language-model-head variant with pretrained weights and create its parameters
model = fm.gpt2.GPT2LMHeadModel(pretrained='gpt2')
params = model.init(key, input_ids=input_ids)

# Run the forward pass on the token ids
output = model.apply(params, input_ids=input_ids, use_cache=True)
# output: {'last_hidden_state': ..., 'past_key_values': ..., 'loss': ..., 'logits': ...}

# Get language model head output from embeddings


In [None]:
import jax
import jax.numpy as jnp
import flaxmodels as fm
                                                                    
key = jax.random.PRNGKey(0)
# JAX keys should not be consumed twice: derive one subkey for the dummy input
# and another for parameter initialization instead of reusing `key` for both.
input_key, init_key = jax.random.split(key)

# Dummy input
# NOTE(review): presumably (batch, sequence length, embedding size), with 768
# matching the hidden size of the 'gpt2' checkpoint — confirm against the model config.
input_embds = jax.random.normal(input_key, shape=(2, 10, 768))

# Initialize model
model = fm.gpt2.GPT2LMHeadModel(pretrained='gpt2')
params = model.init(init_key, input_embds=input_embds)
# Compute output
output = model.apply(params, input_embds=input_embds, use_cache=True)
# output: {'last_hidden_state': ..., 'past_key_values': ..., 'loss': ..., 'logits': ...}

# Get model output from text input

In [None]:
import jax
import jax.numpy as jnp
import flaxmodels as fm

key = jax.random.PRNGKey(0)

# Build the GPT-2 tokenizer
tokenizer = fm.gpt2.get_tokenizer()

# Tokenize the prompt and wrap the ids in a list so the array gets a leading axis
input_ids = jnp.array([tokenizer.encode('The Manhattan bridge')])

# Build the base model (no language-model head) with pretrained weights and create its parameters
model = fm.gpt2.GPT2Model(pretrained='gpt2')
params = model.init(key, input_ids=input_ids)

# Run the forward pass on the token ids
output = model.apply(params, input_ids=input_ids, use_cache=True)
# output: {'last_hidden_state': ..., 'past_key_values': ...}

# Get model output from embeddings

In [None]:
import jax
import jax.numpy as jnp
import flaxmodels as fm
                                                                    
key = jax.random.PRNGKey(0)
# JAX keys should not be consumed twice: derive one subkey for the dummy input
# and another for parameter initialization instead of reusing `key` for both.
input_key, init_key = jax.random.split(key)

# Dummy input
# NOTE(review): presumably (batch, sequence length, embedding size), with 768
# matching the hidden size of the 'gpt2' checkpoint — confirm against the model config.
input_embds = jax.random.normal(input_key, shape=(2, 10, 768))

# Initialize model
model = fm.gpt2.GPT2Model(pretrained='gpt2')
params = model.init(init_key, input_embds=input_embds)

# Compute output
output = model.apply(params, input_embds=input_embds, use_cache=True)
# output: {'last_hidden_state': ..., 'past_key_values': ...}