Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change PRNGKey to key in jax-101 PRNG guide #19645

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions docs/jax-101/05-random-numbers.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@
"source": [
"from jax import random\n",
"\n",
"key = random.PRNGKey(42)\n",
"key = random.key(42)\n",
"\n",
"print(key)"
]
Expand Down Expand Up @@ -381,7 +381,7 @@
"source": [
"`split()` is a deterministic function that converts one `key` into several independent (in the pseudorandomness sense) keys. We keep one of the outputs as the `new_key`, and can safely use the unique extra key (called `subkey`) as input into a random function, and then discard it forever.\n",
"\n",
"If you wanted to get another sample from the normal distribution, you would split `key` again, and so on. The crucial point is that you never use the same PRNGKey twice. Since `split()` takes a key as its argument, we must throw away that old key when we split it.\n",
"If you wanted to get another sample from the normal distribution, you would split `key` again, and so on. The crucial point is that you never use the same key twice. Since `split()` takes a key as its argument, we must throw away that old key when we split it.\n",
"\n",
"It doesn't matter which part of the output of `split(key)` we call `key`, and which we call `subkey`. They are all pseudorandom numbers with equal status. The reason we use the key/subkey convention is to keep track of how they're consumed down the road. Subkeys are destined for immediate consumption by random functions, while the key is retained to generate more randomness later.\n",
"\n",
Expand Down Expand Up @@ -460,12 +460,12 @@
}
],
"source": [
"key = random.PRNGKey(42)\n",
"key = random.key(42)\n",
"subkeys = random.split(key, 3)\n",
"sequence = np.stack([random.normal(subkey) for subkey in subkeys])\n",
"print(\"individually:\", sequence)\n",
"\n",
"key = random.PRNGKey(42)\n",
"key = random.key(42)\n",
"print(\"all at once: \", random.normal(key, shape=(3,)))"
]
},
Expand Down
8 changes: 4 additions & 4 deletions docs/jax-101/05-random-numbers.md
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ To avoid this issue, JAX does not use a global state. Instead, random functions

from jax import random

key = random.PRNGKey(42)
key = random.key(42)

print(key)
```
Expand Down Expand Up @@ -201,7 +201,7 @@ key = new_key # If we wanted to do this again, we would use new_key as the key.

`split()` is a deterministic function that converts one `key` into several independent (in the pseudorandomness sense) keys. We keep one of the outputs as the `new_key`, and can safely use the unique extra key (called `subkey`) as input into a random function, and then discard it forever.

If you wanted to get another sample from the normal distribution, you would split `key` again, and so on. The crucial point is that you never use the same PRNGKey twice. Since `split()` takes a key as its argument, we must throw away that old key when we split it.
If you wanted to get another sample from the normal distribution, you would split `key` again, and so on. The crucial point is that you never use the same key twice. Since `split()` takes a key as its argument, we must throw away that old key when we split it.

It doesn't matter which part of the output of `split(key)` we call `key`, and which we call `subkey`. They are all pseudorandom numbers with equal status. The reason we use the key/subkey convention is to keep track of how they're consumed down the road. Subkeys are destined for immediate consumption by random functions, while the key is retained to generate more randomness later.

Expand Down Expand Up @@ -240,12 +240,12 @@ In the example below, sampling 3 values out of a normal distribution individuall
:id: 4nB_TA54D-HT
:outputId: 2f259f63-3c45-46c8-f597-4e53dc63cb56

key = random.PRNGKey(42)
key = random.key(42)
subkeys = random.split(key, 3)
sequence = np.stack([random.normal(subkey) for subkey in subkeys])
print("individually:", sequence)

key = random.PRNGKey(42)
key = random.key(42)
print("all at once: ", random.normal(key, shape=(3,)))
```

Expand Down
Loading