In [1]:
from transformers import AutoTokenizer
import jax.random as jrand
import jax
import jax.numpy as jnp


# Load the three pretrained tokenizers (in this order) whose vocabularies
# are compared in the next cell.
tokenizer_bert1, tokenizer_bert2, tokenizer_roberta = (
    AutoTokenizer.from_pretrained(checkpoint)
    for checkpoint in ("bert-base-uncased", "bert-base-cased", "roberta-base")
)

None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.


Downloading:   0%|          | 0.00/29.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/570 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/208k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/426k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/481 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/878k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/446k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.29M [00:00<?, ?B/s]

In [2]:
# Report each tokenizer's vocabulary size, one model per output line.
print(
    f'BERT uncased: {tokenizer_bert1.vocab_size}',
    f'BERT cased: {tokenizer_bert2.vocab_size}',
    f'RoBERTa : {tokenizer_roberta.vocab_size}',
    sep='\n',
)

BERT uncased: 30522
BERT cased: 28996
RoBERTa : 50265


In [12]:
# `tokenizer` is not bound anywhere in the visible notebook, so this cell only
# ran thanks to leftover kernel state. Bind it explicitly so the notebook
# survives Restart & Run All.
# NOTE(review): presumably the uncased BERT tokenizer -- tok2simp's default
# vocab_size of 30_522 equals the size printed for it above; confirm.
tokenizer = tokenizer_bert1
tokenizer('My name is Ozymandias, king of kings', return_tensors='np')

{'input_ids': array([[  101,  2026,  2171,  2003, 11472, 17906,  9032,  2015,  1010,
         2332,  1997,  5465,   102]]), 'token_type_ids': array([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}

In [13]:
# Notebook-level PRNG key, seeded with 2.
key = jrand.PRNGKey(seed=2)
# One draw from a flat (all-ones alpha) Dirichlet over 10 categories.
jrand.dirichlet(key, jnp.ones(10))

INFO:absl:Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker: 
INFO:absl:Unable to initialize backend 'gpu': NOT_FOUND: Could not find registered platform with name: "cuda". Available platform names are: Host Interpreter
INFO:absl:Unable to initialize backend 'tpu': INVALID_ARGUMENT: TpuPlatform is not available.


DeviceArray([0.22638896, 0.08777282, 0.04717882, 0.03416082, 0.03275571,
             0.10174081, 0.10347045, 0.26464796, 0.09861212, 0.00327145],            dtype=float32)

In [47]:
# Source ramp 0..6 and a length-5 zero buffer to write into.
x = jnp.array([float(i) for i in range(7)])
y = jnp.zeros(5)
# Take the first three entries of `x` ...
new = jax.lax.dynamic_slice(x, start_indices=(0,), slice_sizes=(3,))
# ... and write them at the start of `y` (returns a new array).
jax.lax.dynamic_update_slice(y, new, (0,))

DeviceArray([0., 1., 2., 0., 0.], dtype=float32)

In [74]:




def tok2simp(key, token_ids, concentration=0.9, vocab_size=30_522):
    """Map token ids to points on the probability simplex (time t=0).

    For every token, the true token's entry is set to `concentration` and the
    remaining `1 - concentration` probability mass is spread over the other
    `vocab_size - 1` entries using a sample from a flat Dirichlet distribution.

    Args:
        key: a single JAX PRNG key.
        token_ids: integer token ids of any shape (scalar, `(seq_len,)`,
            `(batch, seq_len)`, ...). Ids are used as roll offsets, so they
            are expected to lie in `[0, vocab_size)`.
        concentration: probability assigned to the true token.
        vocab_size: size of the last axis of the result.

    Returns:
        Array of shape `token_ids.shape + (vocab_size,)` whose last axis sums
        to one.
    """
    token_ids = jnp.asarray(token_ids)
    batch_shape = token_ids.shape
    n_tokens = token_ids.size

    # Only the first child of the split is consumed; this keeps the sampling
    # stream identical to what this function produced for 1-D input before.
    sample_key = jrand.split(key)[0]

    true_probs = jnp.full(batch_shape + (1,), concentration)
    other_probs = jrand.dirichlet(
        sample_key, alpha=jnp.ones((vocab_size - 1,)), shape=batch_shape
    ) * (1 - concentration)

    # The true token's mass sits in column 0 of every row ...
    probs = jnp.concatenate([true_probs, other_probs], axis=-1)

    # ... and rolling each row by its token id moves that mass to column
    # `token_id`. Flatten the batch axes so a single vmap covers every token.
    rolled = jax.vmap(jnp.roll)(
        probs.reshape(n_tokens, vocab_size), token_ids.reshape(n_tokens)
    )
    return rolled.reshape(batch_shape + (vocab_size,))
    
                                   
# Advance the notebook-level key; `subkey` is consumed by the draw below.
key, subkey = jrand.split(key)

# Token ids of one sentence: first (and only) row of the returned batch.
tokens = tokenizer('My name is Ozymandias, king of kings', return_tensors='np')['input_ids'][0]

# tok2simp takes ONE PRNG key for an unbatched sequence. Passing the stacked
# pair returned by jrand.split(key, 2) is only appropriate under jax.vmap.
tok2simp(subkey, tokens, concentration=0.75)

DeviceArray([[1.0687480e-06, 8.4908961e-06, 9.8315859e-06, ...,
              4.2215838e-06, 4.7946633e-06, 9.2560765e-07],
             [3.5339278e-07, 1.1123673e-06, 5.8005348e-06, ...,
              1.6144631e-05, 9.3367481e-07, 5.6265985e-06],
             [3.7886305e-06, 5.6568247e-06, 1.0279502e-05, ...,
              1.4628033e-06, 1.7675428e-07, 3.0388173e-06],
             ...,
             [7.8678022e-06, 5.9860863e-06, 1.5539873e-06, ...,
              2.2945899e-06, 1.5643706e-06, 3.9684437e-06],
             [1.7250152e-06, 1.6345626e-05, 1.7823446e-05, ...,
              1.0356162e-08, 3.9409697e-06, 1.3914089e-05],
             [1.5761785e-05, 1.1107145e-05, 1.7765348e-06, ...,
              1.0689306e-05, 1.5254598e-05, 1.3654142e-05]],            dtype=float32)

In [78]:
from functools import partial
# Tokenize two sentences as one batch, padded to a common length.
tokens2 = tokenizer(
    [
        'My name is Ozymandias, king of kings',
        'Look on my works, ye mighty, and despair',
    ],
    padding=True,
    return_tensors='np',
)['input_ids']

# Fix the concentration, then map over the leading axis of both the keys and
# the token ids so each sentence gets its own PRNG key.
tok_fn = partial(tok2simp, concentration=0.9)
jax.vmap(tok_fn, in_axes=(0, 0))(jrand.split(key, num=2), tokens2)

DeviceArray([[[4.9601061e-07, 7.0883284e-06, 5.0080885e-06, ...,
               4.8108497e-08, 2.4784588e-06, 3.2537268e-06],
              [1.9356710e-06, 8.4702734e-08, 2.5933771e-06, ...,
               9.3110002e-06, 1.5264789e-06, 1.6829944e-06],
              [2.8916584e-06, 2.0917669e-06, 7.3057572e-06, ...,
               1.2578334e-06, 2.4394202e-08, 1.3356006e-06],
              ...,
              [1.0777827e-07, 1.1772730e-07, 6.5805011e-08, ...,
               3.9340111e-06, 5.2198848e-06, 8.4783878e-06],
              [3.7900188e-06, 4.7797096e-07, 2.9140389e-07, ...,
               1.3466755e-06, 8.6511272e-07, 6.1717919e-06],
              [3.9615497e-07, 2.8244549e-06, 1.4466606e-05, ...,
               1.1659594e-05, 2.5676010e-07, 1.0340481e-05]],

             [[7.9156571e-08, 6.4688656e-06, 2.4391076e-05, ...,
               6.2400659e-06, 4.0468990e-08, 1.8201349e-07],
              [4.8308953e-06, 2.4992153e-06, 1.9477952e-06, ...,
               3.8142978e-06, 5.

In [65]:
# Three flat-Dirichlet draws over 9 categories, scaled so that each row
# carries 0.1 of the total probability mass.
r = 0.1 * jrand.dirichlet(subkey, jnp.ones((9,)), (3,))
# Prepend a column of 0.9 so every row has 10 entries.
x = jnp.concatenate([jnp.full((3, 1), 0.9), r], axis=-1)
x

DeviceArray([[8.9999998e-01, 4.1741308e-02, 3.5531104e-03, 1.9258678e-02,
              1.0263417e-02, 2.0375280e-03, 2.2614503e-03, 2.9235105e-03,
              2.0785143e-03, 1.5882496e-02],
             [8.9999998e-01, 2.3812694e-03, 8.2627041e-03, 2.9006688e-02,
              3.9507765e-03, 8.1967050e-04, 3.7214272e-03, 1.4443065e-02,
              3.0938467e-02, 6.4759413e-03],
             [8.9999998e-01, 1.3912891e-02, 5.7322169e-03, 1.2617664e-02,
              1.3692746e-02, 1.7116426e-03, 2.5537631e-02, 1.1670019e-02,
              8.1467014e-03, 6.9784946e-03]], dtype=float32)

In [69]:
# One shift per row: row i of `x` is rolled to the right by z[i].
z = jnp.array((1, 2, 3))
jax.vmap(jnp.roll, in_axes=(0, 0))(x, z)

DeviceArray([[1.5882496e-02, 8.9999998e-01, 4.1741308e-02, 3.5531104e-03,
              1.9258678e-02, 1.0263417e-02, 2.0375280e-03, 2.2614503e-03,
              2.9235105e-03, 2.0785143e-03],
             [3.0938467e-02, 6.4759413e-03, 8.9999998e-01, 2.3812694e-03,
              8.2627041e-03, 2.9006688e-02, 3.9507765e-03, 8.1967050e-04,
              3.7214272e-03, 1.4443065e-02],
             [1.1670019e-02, 8.1467014e-03, 6.9784946e-03, 8.9999998e-01,
              1.3912891e-02, 5.7322169e-03, 1.2617664e-02, 1.3692746e-02,
              1.7116426e-03, 2.5537631e-02]], dtype=float32)

In [79]:
def t_to_alpha_sigma(t):
    """Return the pair of scaling factors for a timestep.

    The timestep `t` is mapped to the angle `t * pi / 2`; the first returned
    value (scale of the clean image) is its cosine and the second (scale of
    the noise) is its sine.
    """
    angle = t * jnp.pi / 2
    alpha = jnp.cos(angle)
    sigma = jnp.sin(angle)
    return alpha, sigma

# Evaluate the schedule on 50 evenly spaced timesteps covering [0, 1].
t_to_alpha_sigma(jnp.linspace(0, 1, num=50))

(DeviceArray([ 1.0000000e+00,  9.9948621e-01,  9.9794537e-01,
               9.9537909e-01,  9.9179000e-01,  9.8718178e-01,
               9.8155916e-01,  9.7492790e-01,  9.6729487e-01,
               9.5866787e-01,  9.4905573e-01,  9.3846840e-01,
               9.2691678e-01,  9.1441262e-01,  9.0096885e-01,
               8.8659930e-01,  8.7131870e-01,  8.5514277e-01,
               8.3808810e-01,  8.2017225e-01,  8.0141360e-01,
               7.8183144e-01,  7.6144594e-01,  7.4027801e-01,
               7.1834934e-01,  6.9568253e-01,  6.7230088e-01,
               6.4822841e-01,  6.2348986e-01,  5.9811050e-01,
               5.7211661e-01,  5.4553491e-01,  5.1839256e-01,
               4.9071753e-01,  4.6253830e-01,  4.3388376e-01,
               4.0478328e-01,  3.7526694e-01,  3.4536502e-01,
               3.1510821e-01,  2.8452760e-01,  2.5365460e-01,
               2.2252086e-01,  1.9115858e-01,  1.5959987e-01,
               1.2787715e-01,  9.6023038e-02,  6.4070255e-02,
        

In [2]:
import haiku as hk

def _forward_fn_linear1(x):
    """Build a 10-unit ``hk.Linear`` layer and apply it to ``x``."""
    return hk.Linear(10)(x)


# Wrap the module-building function with hk.transform; the resulting object's
# `init` and `apply` are what the following cells call.
linear = hk.transform(_forward_fn_linear1)


In [11]:
import jax.random
key = jax.random.PRNGKey(32)

# Split before every random draw: `subkey` is consumed by the draw, `key` is
# only ever carried forward to the next split. Drawing with `key` itself and
# then splitting it again would reuse the same key for two purposes.
key, subkey = jax.random.split(key)
dummy_x = jax.random.normal(subkey, (128, 512, 100))

# Fresh subkey for parameter initialisation of the transformed linear layer.
key, subkey = jax.random.split(key)
params = linear.init(subkey, dummy_x)

In [12]:
import jax.numpy as jnp

# Fresh subkey for this apply call; `key` is carried forward.
key, subkey = jax.random.split(key, 2)
# All-ones test input whose leading dimensions differ from `dummy_x`.
x_test = jnp.ones(shape=(1, 256, 100))
# Only the output shape is displayed.
linear.apply(params, subkey, x_test).shape

(1, 256, 10)