# Solving "MNIST" With Gen
We model the data generating process as follows:
1. Choose a number from 0 to 9 with uniform probability.
2. For exam number, set the sampling probabilities for a grid of 9 by 9 pixels.
3. Sample each pixel from a bernoulli distribution with the corresponding probability.

In [1]:
using Gen
using Plots
using Printf

### Define prior distributions

We encode the prior knowledge as a set of matricies.

In [2]:

function get_probs(number)
    # We generate image distributions based on the
    # provide number.  The number must be between 0 and 9.
    if number == 0
        template = [1 1 1 1 1 1 1 1 1;
                    1 0 0 0 0 0 0 0 1;
                    1 0 0 0 0 0 0 0 1;
                    1 0 0 0 0 0 0 0 1;
                    1 0 0 0 0 0 0 0 1;
                    1 0 0 0 0 0 0 0 1;
                    1 0 0 0 0 0 0 0 1;
                    1 0 0 0 0 0 0 0 1;
                    1 1 1 1 1 1 1 1 1]
    elseif number == 1
        template = [0 0 0 0 0 0 0 0 1;
                    0 0 0 0 0 0 0 0 1;
                    0 0 0 0 0 0 0 0 1;
                    0 0 0 0 0 0 0 0 1;
                    0 0 0 0 0 0 0 0 1;
                    0 0 0 0 0 0 0 0 1;
                    0 0 0 0 0 0 0 0 1;
                    0 0 0 0 0 0 0 0 1;
                    0 0 0 0 0 0 0 0 1;]
    elseif number == 2
        template = [1 1 1 1 1 1 1 1 1;
                    0 0 0 0 0 0 0 0 1;
                    0 0 0 0 0 0 0 0 1;
                    0 0 0 0 0 0 0 0 1;
                    1 1 1 1 1 1 1 1 1;
                    1 0 0 0 0 0 0 0 0;
                    1 0 0 0 0 0 0 0 0;
                    1 0 0 0 0 0 0 0 0;
                    1 1 1 1 1 1 1 1 1;]
    elseif number == 3
        template = [1 1 1 1 1 1 1 1 1;
                    0 0 0 0 0 0 0 0 1;
                    0 0 0 0 0 0 0 0 1;
                    0 0 0 0 0 0 0 0 1;
                    1 1 1 1 1 1 1 1 1;
                    0 0 0 0 0 0 0 0 1;
                    0 0 0 0 0 0 0 0 1;
                    0 0 0 0 0 0 0 0 1;
                    1 1 1 1 1 1 1 1 1;]
    elseif number == 4
        template = [1 0 0 0 0 0 0 0 1;
                    1 0 0 0 0 0 0 0 1;
                    1 0 0 0 0 0 0 0 1;
                    1 0 0 0 0 0 0 0 1;
                    1 1 1 1 1 1 1 1 1;
                    0 0 0 0 0 0 0 0 1;
                    0 0 0 0 0 0 0 0 1;
                    0 0 0 0 0 0 0 0 1;
                    0 0 0 0 0 0 0 0 1;]
    elseif number == 5
        template = [1 1 1 1 1 1 1 1 1;
                    1 0 0 0 0 0 0 0 0;
                    1 0 0 0 0 0 0 0 0;
                    1 0 0 0 0 0 0 0 0;
                    1 1 1 1 1 1 1 1 1;
                    0 0 0 0 0 0 0 0 1;
                    0 0 0 0 0 0 0 0 1;
                    0 0 0 0 0 0 0 0 1;
                    1 1 1 1 1 1 1 1 1;]
    elseif number == 6
        template = [1 1 1 1 1 1 1 1 1;
                    1 0 0 0 0 0 0 0 0;
                    1 0 0 0 0 0 0 0 0;
                    1 0 0 0 0 0 0 0 0;
                    1 1 1 1 1 1 1 1 1;
                    1 0 0 0 0 0 0 0 1;
                    1 0 0 0 0 0 0 0 1;
                    1 0 0 0 0 0 0 0 1;
                    1 1 1 1 1 1 1 1 1;]
    elseif number == 7
        template = [1 1 1 1 1 1 1 1 1;
                    0 0 0 0 0 0 0 0 1;
                    0 0 0 0 0 0 0 0 1;
                    0 0 0 0 0 0 0 0 1;
                    0 0 0 0 0 0 0 0 1;
                    0 0 0 0 0 0 0 0 1;
                    0 0 0 0 0 0 0 0 1;
                    0 0 0 0 0 0 0 0 1;
                    0 0 0 0 0 0 0 0 1;]
    elseif number == 8
        template = [1 1 1 1 1 1 1 1 1;
                    1 0 0 0 0 0 0 0 1;
                    1 0 0 0 0 0 0 0 1;
                    1 0 0 0 0 0 0 0 1;
                    1 1 1 1 1 1 1 1 1;
                    1 0 0 0 0 0 0 0 1;
                    1 0 0 0 0 0 0 0 1;
                    1 0 0 0 0 0 0 0 1;
                    1 1 1 1 1 1 1 1 1;]
    elseif number == 9
        template = [1 1 1 1 1 1 1 1 1;
                    1 0 0 0 0 0 0 0 1;
                    1 0 0 0 0 0 0 0 1;
                    1 0 0 0 0 0 0 0 1;
                    1 1 1 1 1 1 1 1 1;
                    0 0 0 0 0 0 0 0 1;
                    0 0 0 0 0 0 0 0 1;
                    0 0 0 0 0 0 0 0 1;
                    1 1 1 1 1 1 1 1 1;]
    else
        throw(ArgumentError("Number must be between 0 and 9"))
    end

    occupied_prob = 0.97
    unoccupied_prob = 0.03
    probs = template * occupied_prob + (1 .- template) * unoccupied_prob
    return probs
end

get_probs (generic function with 1 method)

### Define the model
Here, we render each pixel as a bernoulli random variable with the corresponding probability.

In [3]:
@gen function number_model()
    # We generate a number by first choosing a number between 0 and 9
    # uniformly at random, and then generating an image of that number
    # using the get_probs function.

    function render_number(probs)
        # We render the number by sampling from a Bernoulli distribution
        # for each pixel in the image.
        rows, cols = size(probs)
        image = zeros(rows, cols)
        for i in 1:rows
            for j in 1:cols
                # We encode the noise probability as prob of 1 or 0
                # Maybe we should add a prob outlier here too?
                image[i, j] = ({(:image, i, j)} ~ bernoulli(probs[i, j]))
            end
        end
        return image
    end

    number = ({:number} ~ uniform_discrete(0, 9))
    probs = get_probs(number)
    image = render_number(probs)

    return number, image
end

function render_number(number_matrix)
    rows, cols = size(number_matrix)
    for i in 1:rows
        for j in 1:cols
            @printf("%c ", (number_matrix[i, j] > 0.5) ? '#' : ' ')
        end
        println()
    end
end;

function render_all_numbers()
    spacing = 11
    width = 5
    height = 2
    final_matrix = zeros(height * spacing, width * spacing)

    for i in 0:9
        start_x = spacing * (i % width) + 2
        start_y = spacing * (i ÷ width) + 2
        end_x = start_x + spacing - 3
        end_y = start_y + spacing - 3

        digit = get_probs(i)

        constraints = Gen.choicemap()
        constraints[:number] = i
        obs_trace, _ = Gen.generate(number_model, (), constraints)
        image = Gen.get_retval(obs_trace)[2]

        final_matrix[start_y : end_y, start_x : end_x] = image

    end
    render_number(final_matrix)
end;

### Visualize Data

In [4]:
# Here, we can simulate a random number
trace  = Gen.simulate(number_model, ())
println("Number: ", trace[:number])
render_number(Gen.get_retval(trace)[2])

Number: 5
# # # # # # # # # 
#                 
#                 
#   #             
# # #   # # # # # 
                # 
                # 
                # 
# # # # # # # # # 


In [5]:
# We can also render all numbers
render_all_numbers()

                                                                                                              
  # # # # # # # # #                     #     # # # # # # # #       # # # # # # # # #     #       #       #   
  #               #                     #                     #                     #     #               #   
  #               #                     #                     #                     #     #               #   
  #     #         #                     #                     #                     #     #               #   
  # #             #                     #     # # # # # # #   #     # # # #   # # # #     # # #   # # # # #   
  #               #                     #     #                                     #                     #   
  #               #                     #     # #                                   #                     #   
  #               #                     #     #           #                         #                     #   
 

### Do inference

We can define an inference step to use importance sampling.

In [6]:
function do_inference(model, obs_number, amount_of_computation)
    rows, cols = size(obs_number)

    observations = Gen.choicemap()
    # Create the choice map
    for i in 1:rows
        for j in 1:cols
            observations[(:data, i, j)] = obs_number[i, j]
        end
    end

    # Call importance_resampling to obtain a likely trace consistent
    # with our observations.
    observations[(:number)] = 2
    (trace, _) = Gen.importance_resampling(model, (), observations, amount_of_computation);
    return trace
end;

Here, we attempt to infer the number given observations.  Note that this
does not work!


In [7]:
constraints = Gen.choicemap()
constraints[:number] = 3
obs_trace, _ = Gen.generate(number_model, (), constraints)
image = Gen.get_retval(obs_trace)[2]
println("Number: ", obs_trace[:number])
render_number(image)

trace = do_inference(number_model, image, 10000)
println("Infered  number")
println("Number: ", trace[:number])
render_number(Gen.get_retval(trace)[2])


Number: 3
# # # # # # # # # 
                # 
                # 
                # 
# # # # # # # # # 
                # 
    #           # 
                  
# # # # # # # # # 
Infered  number


Number: 2
# # # # # #   # # 
                # 
                # 
            #   # 
# # # # # # # # # 
#                 
                  
#                 
# # # # # # # # # 
