In [None]:
import Random
import Pkg; Pkg.add("MLDatasets")
import MLDatasets
# Load the MNIST training images and labels into globals. `next_batch` below reads
# `train_x[:, :, i]` (one image per index along the third dimension) and `train_y[i]`.
train_x, train_y = MLDatasets.MNIST.traindata()

"""
    MNISTTrainDataLoader

Mutable cursor over the MNIST training set.

# Fields
- `cur_id`: position of the next example to read; advanced by `next_batch`.
- `order`: a vector of example indices; the zero-argument constructor fills it with a
  random permutation of `1:60000`.
"""
mutable struct MNISTTrainDataLoader
    cur_id::Int
    order::Vector{Int}
end

"""
    MNISTTrainDataLoader()

Create a loader positioned at the first example, together with a freshly shuffled
ordering of the indices `1:60000`.
"""
function MNISTTrainDataLoader()
    shuffled = Random.shuffle(1:60000)
    return MNISTTrainDataLoader(1, shuffled)
end

"""
    next_batch(loader::MNISTTrainDataLoader, batch_size)

Return the next `batch_size` training examples as a tuple `(x, y)`.

Examples are visited in the order given by `loader.order`, so the shuffled ordering
stored in the loader decides which images end up in each batch. `loader.cur_id` is
advanced by one per example and wraps back to 1 after the last entry of
`loader.order`, so a batch may span the end of one pass and the start of the next.

# Arguments
- `loader`: the loader whose cursor (`cur_id`) is read and mutated.
- `batch_size`: number of examples to return.

# Returns
- `x`: a `batch_size × 784` `Float64` matrix; row `i` holds the flattened pixels of
  the `i`-th selected image from the global `train_x`.
- `y`: a `Vector{Int}` of length `batch_size` holding the matching entries of the
  global `train_y` plus one (presumably so that digit labels 0–9 become the 1–10
  outcomes used by `categorical` — confirm against the label encoding).
"""
function next_batch(loader::MNISTTrainDataLoader, batch_size)
    num_examples = length(loader.order)
    x = zeros(Float64, batch_size, 784)
    y = Vector{Int}(undef, batch_size)
    for i in 1:batch_size
        # Look the example up through the shuffled ordering instead of using the
        # cursor as the dataset index directly; otherwise `order` has no effect and
        # the data is always served in its stored order.
        example = loader.order[loader.cur_id]
        # `view` + `vec` flattens the 28x28 image without copying the slice first.
        x[i, :] = vec(view(train_x, :, :, example))
        y[i] = train_y[example] + 1
        loader.cur_id += 1
        if loader.cur_id > num_examples
            loader.cur_id = 1
        end
    end
    return x, y
end

"""
    load_mnist_test_set()

Load the MNIST test split via `MLDatasets.MNIST.testdata()` and return `(x, y)`.

# Returns
- `x`: a `Float64` matrix with one row per test image and 784 columns, each row
  being the flattened pixels of one image.
- `y`: a `Vector{Int}` with each test label incremented by one.
"""
function load_mnist_test_set()
    images, labels = MLDatasets.MNIST.testdata()
    num_images = length(labels)
    x = zeros(Float64, num_images, 784)
    y = Vector{Int}(undef, num_images)
    for row in 1:num_images
        # flatten the row-th image into a length-784 vector and store it as a row
        flat = reshape(images[:, :, row], 784)
        x[row, :] = flat
        y[row] = labels[row] + 1
    end
    return x, y
end

In [None]:
using Gen
using PyPlot

# First, we load the GenTF package and the PyCall package. The PyCall package is used because TensorFlow computation graphs are constructed using the TensorFlow Python API, and the PyCall package allows Python code to be run from Julia.
using PyCall
using GenTF
# Next, we load the TensorFlow Python module and bind its `nn` submodule to a short name. The `pyimport` function is provided by PyCall.
tf = pyimport("tensorflow")
# Turn off eager execution before building the graph; the placeholder-based `tf.compat.v1` API used below defines a computation graph that is run later in a session.
tf.compat.v1.disable_eager_execution()
nn = tf.nn

# Next, we define a TensorFlow computation graph. The graph has a placeholder for an N x 784 matrix of pixel values, where N is the number of images that will be processed in a batch, and 784 is the number of pixels in an MNIST image (28x28). There are 10 possible digit classes. The `probs` Tensor is an N x 10 matrix, where each row of the matrix is the vector of normalized probabilities of each digit class for a single input image. Note that this code is largely identical to the corresponding Python code. We provide initial values for the weight and bias parameters that are computed in Julia (it is also possible to use TensorFlow initializers for this purpose).
# input images, shape (N, 784); `nothing` leaves the batch dimension N unspecified
xs = tf.compat.v1.placeholder(tf.float64, shape=(nothing, 784))

# weight matrix parameter for soft-max regression, shape (784, 10)
# initialized to a matrix of zeros generated by Julia.
init_W = zeros(Float64, 784, 10)
W = tf.compat.v1.Variable(init_W)

# bias vector parameter for soft-max regression, shape (10,)
# initialized to a vector of zeros generated by Julia.
init_b = zeros(Float64, 10)
b = tf.compat.v1.Variable(init_b)

# probabilities for each class, shape (N, 10): a softmax over the class axis of `xs * W + b`
probs = nn.softmax(tf.add(tf.matmul(xs, W), b), axis=1);

# Next, we construct the generative function from this graph. The GenTF package provides a `TFFunction` type that implements the generative function interface. The `TFFunction` constructor takes:
# (i) A vector of Tensor objects that will be the trainable parameters of the generative function (`[W, b]`). These should be TensorFlow variables.
# (ii) A vector of Tensor objects that are the inputs to the generative function (`[xs]`). These should be TensorFlow placeholders.
# (iii) The Tensor object that is the return value of the generative function (`probs`).
tf_softmax_model = TFFunction([W, b], [xs], probs);

# The `TFFunction` constructor creates a new TensorFlow session that will be used to execute all TensorFlow code for this generative function. It is also possible to supply a TensorFlow session explicitly to the constructor. See the [GenTF documentation](https://probcomp.github.io/GenTF/dev/) for more details.
# We can run the resulting generative function on some fake input data. This causes TensorFlow to execute code in the TensorFlow session associated with `tf_softmax_model`:
fake_xs = rand(5, 784)
# NOTE: this rebinds the global `probs` from the TensorFlow Tensor defined above to the value returned by the call; `tf_softmax_model` was already constructed from the Tensor.
probs = tf_softmax_model(fake_xs)
println("Size probs: ", size(probs))

# We can also use `Gen.generate` to obtain a trace of this generative function.
(trace, _) = Gen.generate(tf_softmax_model, (fake_xs,));

# Note that generative functions constructed using GenTF do not make random choices:
println("Get choices: ", Gen.get_choices(trace))

# The return value is the Julia value corresponding to the Tensor that was passed as the return value of the `TFFunction` (the `probs` Tensor of the graph):
println("Get size retval: ", size(Gen.get_retval(trace)))

# Finally, we write a generative function using the built-in modeling DSL that invokes the `TFFunction` generative function we just defined. Note that the call to `tf_softmax_model` is wrapped in a `@trace` expression with the address `:softmax`.
@gen function digit_model(xs::Matrix{Float64})

    # each row of `xs` is one input image
    num_images = size(xs, 1)

    # call the `tf_softmax_model` generative function at address `:softmax` to get
    # the digit label probabilities for every input image at once
    probs = @trace(tf_softmax_model(xs), :softmax)
    @assert size(probs) == (num_images, 10)

    # draw one digit label per image from that image's row of probabilities,
    # recording it at the address `(:y, img)`
    for img in 1:num_images
        @trace(categorical(probs[img, :]), (:y, img))
    end
end;

# Let's obtain a trace of `digit_model` on the tiny fake input:
(trace, _) = Gen.generate(digit_model, (fake_xs,));
# We see that the `tf_softmax_model` generative function does not make any random choices. The only random choices are the digit labels, one for each input image:
println("Digit model choices: ", Gen.get_choices(trace))


In [None]:
# Create the loader that `data_generator` below draws training batches from.
training_data_loader = MNISTTrainDataLoader();

# Now, we train the trainable parameters of the `tf_softmax_model` generative function (`W` and `b`) on the MNIST training data. Note that these parameters are stored as the state of the TensorFlow variables. We will use the [`Gen.train!`](https://probcomp.github.io/Gen/dev/ref/inference/#Gen.train!) method, which supports supervised training of generative functions using stochastic gradient optimization methods. In particular, this method takes the generative function to be trained (`digit_model`), a Julia function of no arguments that generates a batch of training data, and the update to apply to the trainable parameters.
# The `ParamUpdate` constructor takes the type of update to perform (in this case a gradient descent update with step size 0.00001) and a specification of which trainable parameters should be updated. Here, we request that the `W` and `b` trainable parameters of the `tf_softmax_model` generative function should be trained.
update = Gen.ParamUpdate(Gen.FixedStepGradientDescent(0.00001), tf_softmax_model => [W, b]);

# For the data generator, we obtain a batch of 100 MNIST training images. The data generator must return a tuple, where the first element is a set of arguments to the generative function being trained (`(xs,)`) and the second element contains the values of random choices. `train!` attempts to maximize the expected log probability of these random choices given their corresponding input values.
# Produce one training example for `Gen.train!`: a tuple whose first element is the
# argument tuple `(xs,)` for the generative function and whose second element is a
# choice map constraining the label address `(:y, i)` for each of the 100 images in
# the batch drawn from `training_data_loader`.
function data_generator()
    batch_size = 100
    (xs, ys) = next_batch(training_data_loader, batch_size)

    # sanity-check the shapes returned by the loader
    @assert size(xs) == (batch_size, 784)
    @assert size(ys) == (batch_size,)

    # constrain the i-th label choice to the i-th observed label
    constraints = Gen.choicemap()
    for i in eachindex(ys)
        constraints[(:y, i)] = ys[i]
    end
    return ((xs,), constraints)
end;


In [None]:
# We run 1000 iterations of stochastic gradient descent (`num_epoch=1000`), where each iteration uses a batch of 100 images to get a noisy gradient estimate. This might take one or two minutes.

@time scores = Gen.train!(digit_model, data_generator, update;
    num_epoch=1000, epoch_size=1, num_minibatch=1, minibatch_size=1, verbose=false);

# We plot an estimate of the objective function over time:

plot(scores)
xlabel("iterations of stochastic gradient descent")
ylabel("Estimate of expected conditional log likelihood");