# Discrete Policy Gradients

##### Imports

In [1]:
import random

import jax
import jax.numpy as np
import matplotlib.animation as animation
from kilroyplot.plot import plt
from tqdm import tqdm

In [2]:
jax.config.update(
    "jax_platform_name", "cpu"
)  # explicitly use CPU, disable GPU warnings

In [3]:
plt.rcParams["animation.html"] = "jshtml"  # for animated plots

##### Utilities

In [4]:
class BernoulliDistribution:
    @staticmethod
    def pdf(x, p):
        """Returns probability of x (0 or 1) in a Bernoulli distribution with parameter p."""
        return x * p + (1.0 - x) * (1.0 - p)

    @staticmethod
    def logpdf(x, p):
        """Returns natural logarithm of probability of x (0 or 1) in a Bernoulli distribution with parameter p."""
        return np.log(BernoulliDistribution.pdf(x, p))

    @staticmethod
    def sample(p, n=1):
        """Samples n items from Bernoulli distribution with parameter p."""
        return random.choices([1.0, 0.0], weights=[p, 1.0 - p], k=n)

    @staticmethod
    def plot(p, ax=None):
        """Plots Bernoulli distribution with parmeter p."""
        ax = ax or plt.gca()
        x = np.array([0.0, 1.0])
        return ax.bar(x, BernoulliDistribution.pdf(x, p), width=0.25)

In [5]:
class BernoulliDistributionBlackboxOptimizer:
    """ "Utility for optimizing parameter p of Bernoulli distribution using Policy Gradients."""

    def __init__(self, r, lr, p=0.5, n=100):
        self.r = r
        self.lr = lr
        self.p = p
        self.n = n

    def step(self):
        # get some samples
        samples = BernoulliDistribution.sample(self.p, n=self.n)
        # estimate gradient dr/dp by r(x) * grad(log(pdf(x))) using generated samples
        dj_dp = np.array(
            [
                self.r(x)
                * jax.grad(BernoulliDistribution.logpdf, 1)(x, self.p)
                for x in samples
            ]
        ).mean()
        # change parameter in direction of gradient
        self.p = max(0.0, min(1.0, self.p + self.lr * dj_dp))

In [6]:
def make_animation_for_function(r, steps=200, lr=0.1):
    """Produces animated plot with value of parameter p in all steps."""
    optimizer = BernoulliDistributionBlackboxOptimizer(r, lr=lr)
    frames = []
    fig, ax = plt.subplots()
    for _ in tqdm(range(steps)):
        optimizer.step()
        bars = BernoulliDistribution.plot(optimizer.p, ax=ax)
        frames.append((bars, optimizer.p))
    plt.close()

    fig, ax = plt.subplots()
    bars_d = ax.bar([0, 1], [1.0 - optimizer.p, optimizer.p])
    ax.set_ylim([0, 1])

    def animate(frames):
        frame_d, p = frames
        bars_d[0].set_height(frame_d[0].get_height())
        bars_d[1].set_height(frame_d[1].get_height())
        bars_d.set_label(f"p = {p:.3f}")
        ax.legend()
        return bars_d[0], bars_d[1]

    plt.close()
    return animation.FuncAnimation(fig, animate, frames=frames, blit=True)

## Results

Below you can see plots of Bernoulli distribution in each step of optimization using different reward functions.

### Example 1
Reward function: $r(x) = x$

In [7]:
make_animation_for_function(lambda x: x, lr=0.01, steps=100)

100%|██████████| 100/100 [00:28<00:00,  3.45it/s]


### Example 2
Reward function: $r(x) = -x$

In [8]:
make_animation_for_function(lambda x: -x, lr=0.01, steps=100)

100%|██████████| 100/100 [00:25<00:00,  4.00it/s]


### Example 3
Reward function: $r(x) = $ random $0$ or $1$

In [9]:
make_animation_for_function(lambda x: random.randint(0, 1), lr=0.01, steps=100)

100%|██████████| 100/100 [00:23<00:00,  4.21it/s]


### Example 4
Reward function: $r(x) = 0$ for $x=0$, oscillating between $-1$ and $1$ for $x=1$

In [10]:
from itertools import count


class Counter(count):
    def __call__(self):
        return next(self)


counter = Counter(1)

make_animation_for_function(
    lambda x: x * (2 * int(np.sin((counter() + 2500) * np.pi / 5000) > 0) - 1),
    lr=0.01,
    steps=250,
)

100%|██████████| 250/250 [01:00<00:00,  4.13it/s]
