From 1e7ae3fb48201c44b116f2cba2480975737b2bad Mon Sep 17 00:00:00 2001 From: ivyzheng Date: Fri, 2 Feb 2024 19:59:29 -0800 Subject: [PATCH] Change PRNGKey to key in jax-101 PRNG guide --- docs/jax-101/05-random-numbers.ipynb | 8 ++++---- docs/jax-101/05-random-numbers.md | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/docs/jax-101/05-random-numbers.ipynb b/docs/jax-101/05-random-numbers.ipynb index 65977674a062..904eab51cbb5 100644 --- a/docs/jax-101/05-random-numbers.ipynb +++ b/docs/jax-101/05-random-numbers.ipynb @@ -282,7 +282,7 @@ "source": [ "from jax import random\n", "\n", - "key = random.PRNGKey(42)\n", + "key = random.key(42)\n", "\n", "print(key)" ] @@ -381,7 +381,7 @@ "source": [ "`split()` is a deterministic function that converts one `key` into several independent (in the pseudorandomness sense) keys. We keep one of the outputs as the `new_key`, and can safely use the unique extra key (called `subkey`) as input into a random function, and then discard it forever.\n", "\n", - "If you wanted to get another sample from the normal distribution, you would split `key` again, and so on. The crucial point is that you never use the same PRNGKey twice. Since `split()` takes a key as its argument, we must throw away that old key when we split it.\n", + "If you wanted to get another sample from the normal distribution, you would split `key` again, and so on. The crucial point is that you never use the same key twice. Since `split()` takes a key as its argument, we must throw away that old key when we split it.\n", "\n", "It doesn't matter which part of the output of `split(key)` we call `key`, and which we call `subkey`. They are all pseudorandom numbers with equal status. The reason we use the key/subkey convention is to keep track of how they're consumed down the road. Subkeys are destined for immediate consumption by random functions, while the key is retained to generate more randomness later.\n", "\n", @@ -460,12 +460,12 @@ } ], "source": [ - "key = random.PRNGKey(42)\n", + "key = random.key(42)\n", "subkeys = random.split(key, 3)\n", "sequence = np.stack([random.normal(subkey) for subkey in subkeys])\n", "print(\"individually:\", sequence)\n", "\n", - "key = random.PRNGKey(42)\n", + "key = random.key(42)\n", "print(\"all at once: \", random.normal(key, shape=(3,)))" ] }, diff --git a/docs/jax-101/05-random-numbers.md b/docs/jax-101/05-random-numbers.md index f9f3ae178efe..aacd4f967422 100644 --- a/docs/jax-101/05-random-numbers.md +++ b/docs/jax-101/05-random-numbers.md @@ -150,7 +150,7 @@ To avoid this issue, JAX does not use a global state. Instead, random functions from jax import random -key = random.PRNGKey(42) +key = random.key(42) print(key) ``` @@ -201,7 +201,7 @@ key = new_key # If we wanted to do this again, we would use new_key as the key. `split()` is a deterministic function that converts one `key` into several independent (in the pseudorandomness sense) keys. We keep one of the outputs as the `new_key`, and can safely use the unique extra key (called `subkey`) as input into a random function, and then discard it forever. -If you wanted to get another sample from the normal distribution, you would split `key` again, and so on. The crucial point is that you never use the same PRNGKey twice. Since `split()` takes a key as its argument, we must throw away that old key when we split it. +If you wanted to get another sample from the normal distribution, you would split `key` again, and so on. The crucial point is that you never use the same key twice. Since `split()` takes a key as its argument, we must throw away that old key when we split it. It doesn't matter which part of the output of `split(key)` we call `key`, and which we call `subkey`. They are all pseudorandom numbers with equal status. The reason we use the key/subkey convention is to keep track of how they're consumed down the road. Subkeys are destined for immediate consumption by random functions, while the key is retained to generate more randomness later. @@ -240,12 +240,12 @@ In the example below, sampling 3 values out of a normal distribution individuall :id: 4nB_TA54D-HT :outputId: 2f259f63-3c45-46c8-f597-4e53dc63cb56 -key = random.PRNGKey(42) +key = random.key(42) subkeys = random.split(key, 3) sequence = np.stack([random.normal(subkey) for subkey in subkeys]) print("individually:", sequence) -key = random.PRNGKey(42) +key = random.key(42) print("all at once: ", random.normal(key, shape=(3,))) ```