### Discriminator Rejection Sampling
https://arxiv.org/pdf/1810.06758.pdf

In [28]:
import numpy as np

# An arbitrarily large number is needed to find a good maximum M. 
# We're just randomly sampling to find a large M, but we have no way of calculating it for real.
arbitrary_large_number = 1000
batch_size = 100
N = 16
gamma = 1e-4
epsilon=1e-8

def generator(z):
    return np.tanh(z)

def discriminator(x):
    return 1 / (1 + np.exp(-x))

def F(Dx, Dm, gamma=1e-4, epsilon=1e-8):
    return Dx - Dm - np.log(1 - np.exp(Dx - Dm - epsilon)) - gamma

In [39]:
# This is the burn in where we find the maximum M
M_max = 0
for _ in range(arbitrary_large_number):
    z = np.random.uniform(-1, 1, 100)
    x = generator(z)
    M_max = max(M_max, np.exp(discriminator(x)).max())

# Now that we have the maximum M, we want to get N samples.
samples = []
n = 0
while n < N:
    z = np.random.uniform(-1, 1, 100)
    x = generator(z)
    Dx = np.exp(discriminator(x))
    M_max = max(M_max, Dx.max())
    Dm = np.log(M_max)
    
    p = F(Dx, Dm, gamma=1e-4, epsilon=1e-8)
    
    select = p < np.random.uniform(0, 1)
    n += select.sum()
    samples.append(x[select])
# Our final samples
np.concatenate(samples)[:N]



KeyboardInterrupt: 