In [216]:
using Gen, PyCall, Plots
import Random, Logging

Logging.disable_logging(Logging.Info);
Random.seed!(42)


Random.TaskLocalRNG()

In [217]:
#dataset
function make_dataset(n)
    # true parameters
    prob_outlier = 0.2
    true_inlier_noise = 0.5
    true_outlier_noise = 5.0
    
    true_slope = randn()*2
    true_intercept = randn()*5
    #print the true parameters
    println("True slope: ", true_slope)
    println("True intercept: ", true_intercept)    

    xs = collect(range(-5, stop=5, length=n))
    ys = Float64[]
    for (i, x) in enumerate(xs)
        if randn() < prob_outlier
            y = randn() * true_outlier_noise
        else
            y = true_slope * x + true_intercept + randn() * true_inlier_noise
        end
        push!(ys, y)
    end
    (xs, ys)
end

make_dataset (generic function with 1 method)

In [218]:
#model, adapted from playground to only use normal distributions
@gen function regression(xs::Vector{<:Real})
    # ...

    slope ~ normal(0, 2)
    intercept ~ normal(2, 2)

    noise ~ normal(0.2, 0.03)
    prob_outlier ~ normal(0, 1)

    # Next, we generate the actual y coordinates.
    n = length(xs)
    ys = Float64[]

    for i = 1:n
        # Decide whether this point is an outlier, and set
        # mean and standard deviation accordingly
        outlier= {:data => (i, :is_outlier)} ~ normal(prob_outlier,0.5)

        if outlier > 0
            (mu, std) = (0., 10.)
        else
            (mu, std) = (xs[i] * slope + intercept, noise)
        end
        # Sample a y value for this point

        y={:data => i => :y} ~ normal(mu, std)

        push!(ys, y)

    end
    ys
end;

In [219]:
@gen function testProposal(trace)
    slope ~ normal(trace[:slope], 0.1)
    intercept ~ normal(trace[:intercept], 0.1)
    noise ~ normal(trace[:noise], 0.01)
    prob_outlier ~ normal(trace[:prob_outlier], 0.01)

    println(Gen.get_choices(trace))

    for i =1:10
        outlier=@trace(normal(prob_outlier,0.5),(:data => i => :is_outlier))
        if outlier > 0
            (mu, std) = (0., 10.)
        else
            (mu, std) = (xs[i] * slope + intercept, noise)
        end
        # Sample a y value for this point
        y=@trace(normal(mu, std), (:data,i,:y))
    end
    println("TEsting")
    return nothing
end;

In [220]:
#nn model
#inputs are:                                values
# - x values                                10    
# - y values                                10
# - address of sample (one hot encoded)     5
# - instance id (one hot encoded)           10

# - if lstm: previous hidden state

#outputs are:
# - mean
# - std

tf = pyimport("tensorflow")
keras = pyimport("keras")

nn_model = keras.models.Sequential()
nn_model.add(keras.layers.Dense(10, input_shape=(35,), activation="relu"))
nn_model.add(keras.layers.Dense(10, activation="relu"))
nn_model.add(keras.layers.Dense(2, activation="linear"))

#compile the model
nn_model.compile(loss="mse", optimizer="adam")

#train the model 

#predict the mean and std
test=rand(1,35)
mean, std = nn_model.predict(test)





1Ã—2 Matrix{Float32}:
 0.231763  0.0127213

In [221]:
# main 

xs, ys = make_dataset(10)

constraints = choicemap()
for (i, y) in enumerate(ys)
    constraints[:data => i => :y] = y
end

(trace, _) = Gen.generate(regression, (xs,), constraints)

# include a proposal
for i=1:10
    (trace, _) = Gen.mh(trace, select(:slope, :intercept, :noise, :prob_outlier))
    #(prop,_,_)=Gen.propose(testProposal,(trace,))
    #trace = Gen.update(trace,prop)    
end

println(trace)

True slope: 0.20954158357784222
True intercept: -4.827452435098614
Gen.DynamicDSLTrace{DynamicDSLFunction{Any}}(DynamicDSLFunction{Any}(Dict{Symbol, Any}(), Dict{Symbol, Any}(), Type[Vector{<:Real}], false, Union{Nothing, Some{Any}}[nothing], var"##regression#324", Bool[0], false), Trie{Any, Gen.ChoiceOrCallRecord}(Dict{Any, Gen.ChoiceOrCallRecord}(:intercept => Gen.ChoiceOrCallRecord{Float64}(2.4551847561846762, -1.6379848590474808, NaN, true), :slope => Gen.ChoiceOrCallRecord{Float64}(-0.6027441421662084, -1.6574982763790778, NaN, true), :prob_outlier => Gen.ChoiceOrCallRecord{Float64}(0.5003066411786888, -1.0440919008084233, NaN, true), :noise => Gen.ChoiceOrCallRecord{Float64}(0.2677205327596949, 0.03980238786147261, NaN, true)), Dict{Any, Trie{Any, Gen.ChoiceOrCallRecord}}(:data => Trie{Any, Gen.ChoiceOrCallRecord}(Dict{Any, Gen.ChoiceOrCallRecord}((10, :is_outlier) => Gen.ChoiceOrCallRecord{Float64}(-1.3791202753354233, -7.290282421680214, NaN, true), (9, :is_outlier) => Gen.Choi