In [2]:
using Gen;
using Random;
using StatsBase;
using CairoMakie;

##### Types

In [3]:
@enum Button a=1 b=2 c=3 d=4
@enum Direction up=1 down=2 left=3 right=4
@enum Field empty=0 obstacle=1 goal=2

struct Pos
    x::Int
    y::Int
end

const Policy = Dict{Pos, Button}

const Maze = Dict{Pos, Field}

const Controller = Dict{Direction, Button}

struct MovementProbabilities
    forward::Float64
    back::Float64
    left::Float64
    right::Float64
end;

struct Environment
    controller::Controller
    movementProbabilities::MovementProbabilities
    maze::Maze
end

struct Episode
    rewards::Vector{Int}
    visited::Vector{Pos}
end;

##### Generators

In [4]:
function generate_maze(n::Int)::Maze
    maze = Maze()
    for x in 1:n
        for y in 1:n
            maze[Pos(x, y)] = empty
        end
    end
    for i in 1:rand(1:n)
        x = rand(1:n)
        y = rand(1:n)
        maze[Pos(x, y)] = obstacle
    end
    x = rand(1:n)
    y = rand(1:n)
    maze[Pos(x, y)] = goal
    return maze
end;

In [5]:
function generate_environment(n::Int)::Tuple{Environment, Pos}
    directions = shuffle([up, down, left, right])
    controller = Dict(directions .=> [a, b, c, d])
    movementProbs = MovementProbabilities(1, 0, 0, 0)
    maze = generate_maze(n) 
    start = collect(maze)[findfirst(x -> x[2] == empty, collect(maze))][1]
    return Environment(controller, movementProbs, maze), start
end;

##### Movement in Maze

In [6]:
FORWARD_MAP = [up, down, left, right]
BACK_MAP = [down, up, right, left]
LEFT_MAP = [left, right, down, up]
RIGHT_MAP = [right, left, up, down]
DIRECTION_MAP::Vector{Vector{Direction}} = [FORWARD_MAP, BACK_MAP, LEFT_MAP, RIGHT_MAP]


function get_new_pos(pos::Pos, maze::Maze, direction::Direction)
    if direction == up
        new_pos = Pos(pos.x, pos.y + 1)
    elseif direction == down
        new_pos = Pos(pos.x, pos.y - 1)
    elseif direction == left
        new_pos = Pos(pos.x - 1, pos.y)
    elseif direction == right
        new_pos = Pos(pos.x + 1, pos.y)
    end
    if haskey(maze, new_pos) && maze[new_pos] != obstacle
        return new_pos
    else
        return pos
    end
end;


In [7]:

@dist function stoch_direction(labels, probs)::Direction
    index = categorical(probs)
    labels[index]
end;


##### Action

In [8]:

@gen function simulate_action(pos::Pos, environment::Environment, button)::Pos
    direction = collect(environment.controller)[findfirst(x -> x[2] == button, collect(environment.controller))][1]
    
    prob = environment.movementProbabilities
    stoch_dir = {:direction} ~ stoch_direction(DIRECTION_MAP[Int(direction)], [prob.forward, prob.back, prob.left, prob.right])
    return get_new_pos(pos, environment.maze, stoch_dir)
end;

##### Rewards

In [16]:
function get_reward(pos::Pos, maze::Maze)::Int
    if maze[pos] == goal
        return 100
    else
        return -1
    end
end;

### Model

In [17]:
@gen function maze_model(environment::Environment, start::Pos, episode_length::Int, controller_estimate::Controller, policy::Policy )
    maze = environment.maze
    pos = start
    playing = true
    visited = [pos]
    rewards = []
    for t in 1:episode_length
        button = policy[pos]
        new_pos = {pos => t} ~ simulate_action(pos, environment, button)
        if playing 
            pos = new_pos
            push!(visited, pos)
            if maze[pos] == goal
                playing = false
                push!(rewards, get_reward(pos, maze))
                break
            else
                push!(rewards, get_reward(pos, maze))
            end
        end
    end
    return Episode(rewards, visited)
end;

##### Policy Sampling

In [18]:
@gen function select_policy()::Policy
    policy = Policy()
    for x in 1:6
        for y in 1:6
            pos = Pos(x, y)
            button = {pos} ~ uniform_discrete(1, 4)
            policy[pos] = Button(button)
        end
    end
    return policy
end;

##### Controller Sampling

In [19]:
# @gen function select_controller()::Controller
#     buttons = [a, b, c, d]
#     controller = Controller()
#     for direction in [up, down, left, right]
#         button = {:button => direction} ~ uniform_discrete(1, length(buttons))
#         controller[direction] = buttons[button]
#         deleteat!(buttons, button)
#     end
#     return controller
# end;

In [20]:
Random.seed!(1)
environment, start = generate_environment(6);

##### Visualization

In [21]:
function print_maze(environment::Environment, start::Pos, policy=Policy(), visit_count=Dict{Pos, Int}())
    direction_symbols = Dict(
        up => "↑", 
        down => "↓", 
        left => "←", 
        right => "→"
    )
    n = sqrt(length(environment.maze))
    for y in reverse(1:n)
        row = ""
        for x in 1:n
            pos = Pos(x, y)
            if pos == start
                row *= " S"  # Start
            elseif haskey(environment.maze, pos) && environment.maze[pos] == obstacle
                row *= " #"  # Obstacle
            elseif haskey(environment.maze, pos) && environment.maze[pos] == goal
                row *= " G"  # Goal
            elseif haskey(environment.maze, pos) && environment.maze[pos] == empty
                row *= " ."  
            end
            if haskey(policy, pos)
                direction = collect(environment.controller)[findfirst(x -> x[2] == policy[pos], collect(environment.controller))][1]
                row *= direction_symbols[direction] 
            end
            if haskey(visit_count, pos)
                row *= string(visit_count[pos]) * " " # Print the number of times this position was visited
            else
                row *= "  "  # If no visit count is defined for this position, print an empty space
            end
        end
        println(row)
        println()
    end
end;


In [22]:
print_maze(environment, start, select_policy())

 .←   .→   .↓   .↓   .↓   .↑  

 .↓   .↓   .↓   .←   #→   .↓  

 .←   .→   .↑   .→   S↓   .→  

 .←   .↓   .→   #↓   .↑   .←  

 .↑   .↑   .→   .↑   #→   #←  

 .↑   G↑   .↓   .→   .←   .↓  



## Evaluation

In [43]:
function policy_evaluation(policy::Policy, environment::Environment, eps::Float64)
    state_values = Dict{Pos, Float64}()
    discount = 0.9
    for pos in keys(policy)
        state_values[pos] = 0
    end
    changes = [eps + 1]
    # while(maximum(changes) > eps)
    for i in 1:1e4
        changes = []
        for pos in keys(policy)
            values = []
            directions = [up, down, left, right]
            new_positions = [get_new_pos(pos, environment.maze, dir) for dir in directions]
            for (i, new_pos) in enumerate(new_positions)
                new_dir = directions[i]
                (weight, actual_pos) = assess(simulate_action, (pos, environment, policy[pos]), choicemap((:direction, new_dir)))
                @assert actual_pos == new_pos "actual position is not as predicted"
                if weight < Inf && weight > -Inf
                    retval = get_reward(actual_pos, environment.maze)
                    new_value = weight * (discount * state_values[actual_pos] + retval)
                    push!(values, new_value)
                end
            end
            new_state_val = sum(values)
            push!(changes, abs(new_state_val - state_values[pos]))
            state_values[pos] = new_state_val
        end
    end
    
    return state_values
end;

policy_evaluation(select_policy(), environment, 0.01)

Dict{Pos, Float64} with 36 entries:
  Pos(5, 4) => -0.0
  Pos(6, 6) => -0.0
  Pos(1, 6) => -0.0
  Pos(6, 4) => -0.0
  Pos(2, 4) => -0.0
  Pos(1, 4) => -0.0
  Pos(4, 1) => -0.0
  Pos(4, 2) => -0.0
  Pos(4, 5) => -0.0
  Pos(2, 3) => -0.0
  Pos(5, 2) => -0.0
  Pos(3, 6) => -0.0
  Pos(3, 2) => -0.0
  Pos(5, 3) => -0.0
  Pos(6, 3) => -0.0
  Pos(3, 1) => 0.0
  Pos(2, 2) => -0.0
  Pos(1, 2) => -0.0
  Pos(5, 6) => -0.0
  Pos(5, 5) => -0.0
  Pos(2, 1) => -0.0
  Pos(3, 4) => -0.0
  Pos(3, 5) => -0.0
  Pos(2, 6) => -0.0
  Pos(4, 6) => -0.0
  ⋮         => ⋮

#### Direct utility estimation

Problem: Very slow

In [172]:
# perform trial from start until terminal state is reached
policy = select_policy()
print_maze(environment, start, policy, Dict())
γ = 0.5
state_values = Dict{Pos, Float64}()
for pos in keys(policy)
    state_values[pos] = 0
end
visit_count = Dict{Pos, Int}()
for i in 1:100000
    trace = simulate(maze_model, (environment, start, 100, environment.controller, policy))
    episode = get_retval(trace)

    for (j, pos) in enumerate(episode.visited)
        # println("Position ", pos)
        if haskey(visit_count, pos)
            visit_count[pos] += 1
        else
            visit_count[pos] = 1
        end
        state_values[pos] = state_values[pos] + (1 / visit_count[pos]) * (sum(Float64[episode.rewards[k] * (γ^(k-j)) for k in j:length(episode.rewards)]) - state_values[pos]) 
    end

    for pos in keys(policy)
        # find the best new position around the currrent position
        directions = [up, down, left, right]
        new_positions = [get_new_pos(pos, environment.maze, dir) for dir in directions]
        new_dir = directions[argmax([state_values[np] for np in new_positions])]
        probabilities = []
        for button in [a, b, c, d]
            (weight, retval) = assess(simulate_action, (pos, environment, button), choicemap((:direction, new_dir)))
            push!(probabilities, weight)
        end
        new_button = Button(argmax(probabilities))
        policy[pos] = new_button
    end
end

 .↓   .↓   .↑   .←   .↑   .↓  

 .↓   .↓   .↓   .←   #←   .↓  

 .↓   .↓   .↑   .←   S←   .↓  

 .←   .↓   .↓   #↑   .↓   .↑  

 .→   .↑   .→   .↓   #↓   #←  

 .←   G↑   .↑   .←   .↓   .↓  



LoadError: InterruptException:

In [171]:

print_maze(environment, start, policy, state_values)



 .↑0.0  .↑0.0  .←-1.958762886597938  .→-1.9591836734693877  .↑0.0  .←-1.9591836734693877 

 .↑0.0  .↑0.0  .←-2.0  .↑-2.0  #↑0.0  .↑-2.0 

 .↑0.0  .↑0.0  .↓-2.0  .↓-1.960000805623239  S←-2.0  .↓-2.0 

 .↑0.0  .↑0.0  .↓0.0  #↓0.0  .→-1.9611650485436904  .↓-1.9595959595959602 

 .↑0.0  .↑0.0  .↑0.0  .↑0.0  #↓0.0  #↓0.0 

 .↑0.0  G↑0.0  .↑0.0  .↑0.0  .↑0.0  .↑0.0 



### Q-learning

In [63]:

##### Q-learning Algorithm
const QTable = Dict{Pos, Vector{Float64}}

function initialize_q_table(maze::Maze)::QTable
    q_table = QTable()
    for pos in keys(maze)
        q_table[pos] = [0.0, 0.0, 0.0, 0.0]  # Four possible directions (up, down, left, right)
    end
    return q_table
end

function epsilon_greedy_action(q_values::Vector{Float64}, epsilon::Float64)::Int
    if rand() < epsilon
        return rand(1:4)  # Explore: Random action
    else
        return argmax(q_values)  # Exploit: Best action
    end
end

function q_learning(environment::Environment, start::Pos, episodes::Int, alpha::Float64, gamma::Float64, epsilon::Float64)
    q_table = initialize_q_table(environment.maze)
    
    for episode in 1:episodes
        pos = start
        total_reward = 0
        while true
            if environment.maze[pos] == goal
                break  # Exit when reaching goal
            end
            
            # Choose action using epsilon-greedy policy
            action = epsilon_greedy_action(q_table[pos], epsilon)
            direction = Direction(action)
            
            # Get new position based on action
            new_pos = get_new_pos(pos, environment.maze, direction)
            
            # Get reward
            reward = get_reward(new_pos, environment.maze)
            
            # Update Q-value
            best_next_q = maximum(q_table[new_pos])
            q_table[pos][action] += alpha * (reward + gamma * best_next_q - q_table[pos][action])
            
            pos = new_pos
            total_reward += reward
            if goal == environment.maze[new_pos]
                break  # Stop if goal reached
            end
        end
    end
    return q_table
end


q_learning (generic function with 1 method)

In [70]:

function visualize_q_table(environment::Environment, start::Pos, q_table::QTable)
    direction_symbols = Dict(
        1 => "↑", 
        2 => "↓", 
        3 => "←", 
        4 => "→"
    )
    n = sqrt(length(environment.maze)) |> Int
    for y in reverse(1:n)
        row = ""
        for x in 1:n
            pos = Pos(x, y)
            if pos == start
                row *= " S"
            elseif environment.maze[pos] == obstacle
                row *= " #"
            elseif environment.maze[pos] == goal
                row *= " G"
            else
                best_action = argmax(q_table[pos])
                row *= " " * direction_symbols[best_action]
            end
        end
        println(row)
    end
end

visualize_q_table (generic function with 1 method)

In [71]:

q_table = q_learning(environment, start, 1000, 0.1, 0.9, 0.1)

# Visualize the learned policy (Q-table)
visualize_q_table(environment, start, q_table)

 ← ↓ ↑ ↓ → ↑
 ↓ ↓ ↓ ↓ # ←
 → ↓ ← ← S ←
 → ↓ ← # ↑ ←
 → ↓ ← → # #
 → G ↓ ↓ ↑ ↑


### Value iteration

In [72]:
function value_iteration(environment::Environment, discount::Float64, eps::Float64)
    # Initialize state values for all positions in the maze
    state_values = Dict{Pos, Float64}()
    for pos in keys(environment.maze)
        state_values[pos] = 0.0
    end
    
    policy = Dict{Pos, Button}()
    
    # Define possible actions
    actions = [a, b, c, d]
    
    # Value iteration loop
    delta = eps + 1  # Arbitrary large number to start
    while delta > eps
        delta = 0
        
        # Loop through each position in the maze
        for pos in keys(environment.maze)
            if environment.maze[pos] == goal || environment.maze[pos] == obstacle
                continue  # Skip terminal or invalid states
            end
            
            old_value = state_values[pos]
            action_values = []
            
            # Compute expected values for each action
            for action in actions
                direction = collect(environment.controller)[findfirst(x -> x[2] == action, collect(environment.controller))][1]
                new_pos = get_new_pos(pos, environment.maze, direction)
                reward = environment.maze[new_pos] == goal ? 100 : -1
                action_value = reward + discount * state_values[new_pos]
                push!(action_values, action_value)
            end
            
            # Update the state value with the best action's value
            state_values[pos] = maximum(action_values)
            delta = max(delta, abs(old_value - state_values[pos]))
        end
    end
    
    # Extract policy
    for pos in keys(environment.maze)
        if environment.maze[pos] == goal || environment.maze[pos] == obstacle
            continue
        end
        
        action_values = []
        for action in actions
            direction = collect(environment.controller)[findfirst(x -> x[2] == action, collect(environment.controller))][1]
            new_pos = get_new_pos(pos, environment.maze, direction)
            reward = environment.maze[new_pos] == goal ? 100 : -1
            action_value = reward + discount * state_values[new_pos]
            push!(action_values, action_value)
        end
        
        # Select the best action
        best_action_idx = argmax(action_values)
        policy[pos] = actions[best_action_idx]
    end
    
    return policy, state_values
end


value_iteration (generic function with 1 method)

In [73]:
optimal_policy, state_values = value_iteration(environment, 0.9, 0.01)
print_maze(environment, start, optimal_policy)

 .↓   .↓   .←   .←   .←   .←  

 .↓   .↓   .←   .←   #   .↓  

 .↓   .↓   .←   .←   S←   .←  

 .↓   .↓   .←   #   .↑   .←  

 .↓   .↓   .←   .←   #   #  

 .→   G   .←   .←   .←   .←  



### Policy iteration

In [74]:
function policy_iteration(environment::Environment, discount::Float64, eps::Float64)
    # Initialize the policy and state values
    policy = Dict{Pos, Button}()
    state_values = Dict{Pos, Float64}()
    for pos in keys(environment.maze)
        policy[pos] = a  # Start with some arbitrary action
        state_values[pos] = 0.0
    end
    
    stable = false
    actions = [a, b, c, d]
    
    # Policy iteration loop
    while !stable
        # Policy Evaluation
        while true
            delta = 0
            for pos in keys(environment.maze)
                if environment.maze[pos] == goal || environment.maze[pos] == obstacle
                    continue  # Skip terminal or invalid states
                end
                
                old_value = state_values[pos]
                action = policy[pos]
                direction = collect(environment.controller)[findfirst(x -> x[2] == action, collect(environment.controller))][1]
                new_pos = get_new_pos(pos, environment.maze, direction)
                reward = environment.maze[new_pos] == goal ? 100 : -1
                state_values[pos] = reward + discount * state_values[new_pos]
                delta = max(delta, abs(old_value - state_values[pos]))
            end
            
            if delta < eps
                break
            end
        end
        
        # Policy Improvement
        stable = true
        for pos in keys(environment.maze)
            if environment.maze[pos] == goal || environment.maze[pos] == obstacle
                continue
            end
            
            old_action = policy[pos]
            action_values = []
            
            for action in actions
                direction = collect(environment.controller)[findfirst(x -> x[2] == action, collect(environment.controller))][1]
                new_pos = get_new_pos(pos, environment.maze, direction)
                reward = environment.maze[new_pos] == goal ? 100 : -1
                action_value = reward + discount * state_values[new_pos]
                push!(action_values, action_value)
            end
            
            # Choose the best action
            best_action_idx = argmax(action_values)
            policy[pos] = actions[best_action_idx]
            
            if policy[pos] != old_action
                stable = false  # The policy has changed, so we continue iterating
            end
        end
    end
    
    return policy, state_values
end


policy_iteration (generic function with 1 method)

In [75]:

optimal_policy, state_values = policy_iteration(environment, 0.9, 0.01)
print_maze(environment, start, optimal_policy)

 .↓   .↓   .←   .←   .←   .←  

 .↓   .↓   .←   .←   #←   .↓  

 .↓   .↓   .←   .←   S←   .←  

 .↓   .↓   .←   #←   .↑   .←  

 .↓   .↓   .←   .←   #←   #←  

 .→   G←   .←   .←   .←   .←  



## Conditioning Attempt

In [None]:
@gen function condition_on_outcome(trace, pos::Pos, button::Button, new_pos::Pos, reward::Int)::Pos

    ## sample maze based on outcome
    belief_maze, belief_controller = get_retval(trace)
    n = sqrt(length(belief_maze))
    direction = belief_controller[button]
    labels = [goal, obstacle, empty]
    has_goal = false
    maze = Maze()
    if reward == 100
        maze[new_pos] = {:maze => new_pos} ~ labeled_categorical(labels, [0.8, 0.1, 0.1])
    elseif new_pos == pos
        maze[new_pos] = {:maze => new_pos} ~ labeled_categorical(labels, [0.1, 0.8, 0.1])
    else
        maze[new_pos] = {:maze => new_pos} ~ labeled_categorical(labels, [0.1, 0.1, 0.8])
    end
    if maze[new_pos] == goal
        has_goal = true
    end
    
    for x in 1:n
        for y in 1:n
            if x == new_pos.x && y == new_pos.y
                continue
            end
            if has_goal
                maze[Pos(x,y)] = {:maze => Pos(x,y)} ~ labeled_categorical(labels, [0.0, 0.15, 0.85])
            else
                maze[Pos(x,y)] = {:maze => Pos(x,y)} ~ labeled_categorical(labels, [0.1, 0.1, 0.8])
            end
            if maze[Pos(x, y)] == goal
                has_goal = true
            end
        end
    end

    ## sample controller based on outcome
    directions = [up, down, left, right]

    direction = new_pos.x - pos.x > 0 ? right : new_pos.x - pos.x < 0 ? left : new_pos.y - pos.y > 0 ? up : down
    
    controller = Controller()
    controller[button] = {:ctrl => button} ~ labeled_categorical([direction], [1.0])
    filter!(x -> x != direction, directions)
    for b in filter(x -> x != button, [a, b, c, d])
        direction = {:ctrl => b} ~ labeled_categorical(directions, [1/length(directions) for _ in directions])
        controller[b] = direction
        filter!(x -> x != direction, directions)
    end

    ## sample transitions based on maze, controller and outcome
    for k in keys(maze)
        for b in keys(controller)
            if k == pos  && b == button
                target = {k => b} ~ labeled_categorical([new_pos], [1.0])
            else
                possible_targets = possible_target_pos(k, maze)
                target = {k => b} ~ labeled_categorical(possible_targets, fill(1/length(possible_targets), length(possible_targets)))
            end
        end
    end
end;



Random.seed!(1)
(trace,_) = generate(belief_distribution, (6,))
pos = start
belief_maze, belief_controller = get_retval(trace)


for i in 1:1000
    policy, _ = value_iteration(belief_maze, belief_controller, 0.9, 0.01)
    # print_maze(environment, pos, policy)
    new_pos = simulate_action(pos, environment, policy[pos])
    reward = get_reward(new_pos, environment.maze)
    println("Pos: ", pos, " Button: ", policy[pos], " New Pos: ", new_pos, " Reward: ", reward)

    (trace,_) = mh(trace, condition_on_outcome, (pos, policy[pos], new_pos, reward))
    belief_maze, belief_controller = get_retval(trace)

    println(belief_controller)
    print_maze(belief_maze, start)
    print_maze(environment, pos, policy)
    println(get_choices(trace))
    pos = new_pos
end

In [1]:
@gen function simulate_episode(environment::Environment, start::Pos, episode_length::Int, controller_estimate::Controller, policy::Policy )
    maze = environment.maze
    pos = start
    playing = true
    visited = [pos]
    rewards = []
    for t in 1:episode_length
        button = policy[pos]
        new_pos = {pos => t} ~ simulate_action(pos, environment, button)
        if playing 
            pos = new_pos
            push!(visited, pos)
            if maze[pos] == goal
                playing = false
                push!(rewards, get_reward(pos, maze))
                break
            else
                push!(rewards, get_reward(pos, maze))
            end
        end
    end
    return Episode(rewards, visited)
end;

LoadError: LoadError: UndefVarError: `@gen` not defined
in expression starting at In[1]:1

In [None]:
@gen function maze_model(as)
    T = length(as)

    # prior for reward sampling variance
    σ ~ exponential(1)

    # prior for reward means
    q = Vector{Real}(undef, k)
    for i in 1:k
        q[i] = {(:q, i)} ~ normal(0, 1) # For the moment we assume that the prior for generating random bandits is known.
    end

    rs = []
    for t in 1:T
        r = {(:r,t)} ~ normal(q[as[t]], σ) # We also do not assume that the reward variance is known. We infer it from the data.
        push!(rs, r)
    end
    
    return rs 
end

In [None]:
## version where a new policy is learned in each step

Random.seed!(2)
policies = [random_policy(6)]
episodes = choicemap()

steps = 10000
traces = []
for i in 1:10
    policy = policies[end]
    episode = simulate_episode(environment.maze, environment.controller, start, 100, policy)
    for (j, pos) in enumerate(episode.visited)
        if j < length(episode.visited)
            episodes[i => :episode => pos => j => :new_pos] = episode.visited[j+1]
        end
    end
    trace, _ = generate(maze_model, (policies, start), episodes)
    traces = [trace]
    for i in 1:steps
        trace = block_update(trace)
        push!(traces, trace)
    end
    belief_maze, belief_controller = get_retval(traces[end])
    new_policy, _ = value_iteration(belief_maze, belief_controller, 0.9, 0.01)
    # print_maze(environment, start, new_policy)
    # println()
    push!(policies, new_policy)
end


In [None]:
# function get_moves(trace)

#     choices = get_choices(trace)
#     episodes = nested_view(choices)[:episode]
#     moves = Dict()

#     for ep_key in collect(keys(episodes))
#         episode = episodes[ep_key]
#         for step_key in collect(keys(episode))
#             step = episode[step_key]
#             for pos_key in collect(keys(step))
#                 pos = step[pos_key]
#                 if !haskey(moves, pos_key)
#                     moves[pos_key] = Dict()
#                 end
#                 for button_key in collect(keys(pos))
#                     if haskey(moves[pos_key], button_key)
#                         push!(moves[pos_key][button_key], pos[button_key][:new_pos])
#                     else 
#                         moves[pos_key][button_key] = [pos[button_key][:new_pos]]
#                     end
#                 end
#             end
#         end
#     end

#     return moves
# end;


In [None]:
# @gen function propose_field_based_on_moves(trace, pos::Pos, moves::Dict)
#     goal_x = get_choices(trace)[:maze => :goal_x]
#     goal_y = get_choices(trace)[:maze => :goal_y]
#     frac_obstacles = trace[:maze => :frac_obstacles]
#     gp = 0.99

#     labels = [obstacle, empty, goal]
    
#     if length(keys(moves)) > 0
#         # set new sample for current position
#         {:maze => pos} ~ labeled_categorical(labels, [frac_obstacles*gp, (1-frac_obstacles)*gp, 1-gp])
#         # and set probable fields for target positions of moves
#         for button in keys(moves)
#             probs = get_choices(trace)[:ctrl => button]
#             dir = DIRECTIONS[argmax(probs)]
#             if dir == up
#                 target =  Pos(pos.x, pos.y + 1)
#             elseif dir == down
#                 target =  Pos(pos.x, pos.y - 1)
#             elseif dir == left
#                 target = Pos(pos.x - 1, pos.y)
#             elseif dir == right
#                 target = Pos(pos.x + 1, pos.y)
#             end
#             # if any(new_pos -> new_pos == pos, moves[button]) && has_value(get_choices(trace), :maze => target)
#             #     {:maze => target} ~ labeled_categorical(labels, [0.8*gp, 0.2*gp, 1-gp])
#             # else
#             #     {:maze => target} ~ labeled_categorical(labels, [0.2*gp, 0.8*gp, 1-gp])
#             # end
#         end
#     else
#         ## no moves means this is either the goal, or an obstacle
#         # if pos.x == goal_x && pos.y == goal_y
#         #     new_goal_x, new_goal_y = {*} ~ propose_goal(trace)
#         #     println("proposed goal: ", new_goal_x, new_goal_y)
#         #     if new_goal_x !== pos.x || new_goal_y !== pos.y
#         #         new_goal = {:maze => Pos(new_goal_x, new_goal_y)} ~ labeled_categorical([goal], [1.0])
#         #         # if there is a new goal, then this should be an obstacle
#         #         {:maze => pos} ~ labeled_categorical(labels, [0.8*gp, 0.2*gp, 1-gp])
#         #     else
#         #         # this still needs to be resampled for the assess in metropolis_hastings
#         #         {:maze => pos} ~ labeled_categorical(labels, [0.8*gp, 0.2*gp, 1-gp])
#         #     end
#         # else
#             # not the goal, so it is probably an obstacle
#             {:maze => pos} ~ labeled_categorical(labels, [0.8*gp, 0.2*gp, 1-gp])
#         # end
#     end
    

# end;

In [1]:
# @gen function propose_field_based_on_moves(trace, pos::Pos, moves::Dict, new_goal_x, new_goal_y)
#     println("pos is ", pos)
#     goal_x = get_choices(trace)[:maze => :goal_x]
#     goal_y = get_choices(trace)[:maze => :goal_y]
#     frac_obstacles = trace[:maze => :frac_obstacles]

#     labels = [obstacle, empty]
#     field_above = Pos(pos.x, pos.y + 1)
#     field_below = Pos(pos.x, pos.y - 1)
#     field_left  = Pos(pos.x - 1, pos.y)
#     field_right = Pos(pos.x + 1, pos.y)
#     adj_fields = [field_above, field_below, field_left, field_right]

#     if pos.x == new_goal_x && pos.y == new_goal_y
#         {:maze => :goal_x} ~ labeled_categorical([new_goal_x], [1.0])
#         {:maze => :goal_y} ~ labeled_categorical([new_goal_y], [1.0])
#         {:maze => pos} ~ labeled_categorical([goal], [1.0])
#         return
#     end


#     # sample adjacent fields ------------
#     for button in keys(moves)
#         probs = get_choices(trace)[:ctrl => button]
#         dir = argmax(probs)
#         target = adj_fields[dir]
#         if length(keys(moves)) > 0
#             if any(new_pos -> new_pos == pos, moves[button]) && has_value(get_choices(trace), :maze => target)
#                 {:maze => target} ~ labeled_categorical(labels, [0.8, 0.2])
#             else
#                 {:maze => target} ~ labeled_categorical(labels, [0.2, 0.8])
#             end
#         elseif target.x != new_goal_x || target.y != new_goal_y
#             {:maze => target} ~ labeled_categorical(labels, [frac_obstacles, (1-frac_obstacles)])
#         end
#         println("target ", target)
#     end

#     # sample this field --------------
#     if length(keys(moves)) > 0 && !(pos.x == new_goal_x && pos.y == new_goal_y)
#             println("there are moves ")
#             {:maze => pos} ~ labeled_categorical(labels, [frac_obstacles, (1-frac_obstacles)])

#         # no moves means this is either the goal or an obstacle
#     else 
#         if pos.x == goal_x && pos.y == goal_y
#             println("pos is goal ")
#             if new_goal_x != pos.x || new_goal_y != pos.y
#                 # if there is a new goal, then this should be an obstacle
#                 println("goal is new ")
#                 {:maze => pos} ~ labeled_categorical(labels, [0.8, 0.2])
#             end
#         elseif !(pos.x == new_goal_x && pos.y == new_goal_y)
#             println("pos is not goal ")
#             # not the goal, so it is probably an obstacle
#             {:maze => pos} ~ labeled_categorical(labels, [0.8, 0.2])

#         end
#     end   

# end;

In [None]:
function metropolis_hastings_luisa(
    trace, proposal::GenerativeFunction, proposal_args::Tuple;
    check=false, observations=EmptyChoiceMap())
    
    model_args = get_args(trace)
    argdiffs = map((_) -> NoChange(), model_args)
    proposal_args_forward = (trace, proposal_args...,)
    (fwd_choices, fwd_weight, _) = propose(proposal, proposal_args_forward)
    println("fwd_choices", fwd_choices)
    (new_trace, weight, _, discard) = update(trace,
        model_args, argdiffs, fwd_choices)
        println("after update")
    proposal_args_backward = (new_trace, proposal_args...,)
    (bwd_weight, _) = assess(proposal, proposal_args_backward, discard)
    alpha = weight - fwd_weight + bwd_weight
    check && check_observations(get_choices(new_trace), observations)
    if log(rand()) < alpha
        # accept
        return (new_trace, true)
    else
        # reject
        return (trace, false)
    end
end;


In [None]:
@gen function propose_smart_maze(trace, all_moves::Dict)

    ## propose_frac_obstacles
    
    current_value = get_choices(trace)[:maze => :frac_obstacles]
    concentration = 10
    
    # Calculate alpha and beta
    α = current_value * concentration
    β = (1 - current_value) * concentration
    
    # Create Beta distribution
    frac_obstacles = {:maze => :frac_obstacles} ~ beta(α, β)
    
        ## Propose goal -------------------------------------------------
    
        choices = get_choices(trace)
        maze_choices = get_submap(choices, :maze)
        episodes = nested_view(choices)[:episode]
        goal_x = maze_choices[:goal_x]
        goal_y = maze_choices[:goal_y]
        last_positions = []
        for i in collect(keys(episodes))
            actions = episodes[i]
            max_action = maximum(collect(keys(actions)))
            if max_action == 100
                continue
            else
                pos = collect(keys(actions[max_action]))[1]
                new_pos = actions[max_action][pos]
                action = collect(keys(new_pos))[1]
                push!(last_positions, new_pos[action][:new_pos])
            end
        end
        maze = collect(get_values_shallow(maze_choices))
        n = Int(sqrt(length(maze) - 3))
        xs = [pos.x for pos in last_positions]
        x_probs = generate_probabilities(xs, n)
        ys = [pos.y for pos in last_positions]
        y_probs = generate_probabilities(ys, n)
    
        new_goal_x = {:maze => :goal_x} ~ categorical(x_probs)
        new_goal_y = {:maze => :goal_y} ~ categorical(y_probs)   
        {:maze => Pos(new_goal_x, new_goal_y)} ~ labeled_categorical([goal], [1.0])
    
        labels = [obstacle, empty]
        targets = Dict()
    
        # Propose fields -------------------------------------------------
        # for x in 1:n
        #     for y in 1:n
        #         pos = Pos(x,y)
    
        #         field_above = Pos(pos.x, pos.y + 1)
        #         field_below = Pos(pos.x, pos.y - 1)
        #         field_left  = Pos(pos.x - 1, pos.y)
        #         field_right = Pos(pos.x + 1, pos.y)
        #         adj_fields = [field_above, field_below, field_left, field_right]
        #         moves = haskey(all_moves, pos) ? all_moves[pos] : Dict()
    
        #         # sample adjacent fields ------------
        #         for button in keys(moves)
        #             probs = get_choices(trace)[:ctrl => button]
        #             dir = argmax(probs)
        #             target = adj_fields[dir]
        #             println("target ", target)
        #             if haskey(targets, target) || (target.x == new_goal_x && target.y == new_goal_y) ||(target.x == goal_x && target.y == goal_y)
        #                 continue
        #             end
        #             if length(keys(moves)) > 0
        #                 if any(new_pos -> new_pos == pos, moves[button]) && has_value(get_choices(trace), :maze => target)
        #                     {:maze => target} ~ labeled_categorical(labels, [0.8, 0.2])
        #                      targets[target] = true
        #                 else
        #                     {:maze => target} ~ labeled_categorical(labels, [0.2, 0.8])
        #                     targets[target] = true
        #                 end
        #             else
        #                 {:maze => target} ~ labeled_categorical(labels, [frac_obstacles, (1-frac_obstacles)])
        #                 targets[target] = true
        #             end
        #         end
        #     end
        # end
        for x in 1:n
            for y in 1:n
                pos = Pos(x,y)
                # sample this field --------------
                if haskey(targets, pos)
                    continue
                end
                if length(keys(moves)) > 0 && !(pos.x == new_goal_x && pos.y == new_goal_y)
                        # println("there are moves ", pos)
                        {:maze => pos} ~ labeled_categorical(labels, [0.2, 0.8])
    
                    # no moves means this is either the goal or an obstacle
                else 
                    # println("no moves ", pos)
                    if pos.x == goal_x && pos.y == goal_y
                        # println("pos is goal ")
                        if new_goal_x != pos.x || new_goal_y != pos.y
                            # if there is a new goal, then this should be an obstacle
                            # println("goal is new ")
                            {:maze => pos} ~ labeled_categorical(labels, [frac_obstacles, (1-frac_obstacles)])
                        end
                    elseif !(pos.x == new_goal_x && pos.y == new_goal_y)
                        # println("pos is not goal ")
                        # not the goal, 
                        {:maze => pos} ~ labeled_categorical(labels, [frac_obstacles, (1-frac_obstacles)])
    
                    end
                end   
            end
        end
    
    
    end;

In [None]:
@gen function maze_model_unknown_maze(past_policies::Vector{Policy}, start::Pos)
    T = length(past_policies)
    @assert T > 0
    n = Int(sqrt(length(past_policies[1])))

    controller = {:ctrl} ~ select_controller()
    maze = {:maze} ~ select_maze(n)

    episodes = []
    for t in 1:T
        policy = past_policies[t]
        episode = {:episode => t} ~ simulate_episode(maze, controller, start, 100, policy)
        push!(episodes, episode)
    end
    
    return maze, controller
end;

In [None]:

function block_update_maze(trace)
    for button in [a,b,c,d]
        (trace,_) = mh(trace, select(:ctrl => button))
    end
    
    maze_choices = get_submap(get_choices(trace), :maze)
    for (pos, field) in collect(get_values_shallow(maze_choices))
        (trace,_) = mh(trace, select(:maze => pos))
    end

    return trace
end;

In [None]:
Random.seed!(7)
policies = [random_policy(6) for i in 1:99]
push!(policies, optimal_policy)

function generate_episodes(environment, policies)
    episodes = choicemap()

    for (i, policy) in enumerate(policies)
        episode = simulate_episode(environment.maze, environment.controller, start, 100, policy)
        for (j, pos) in enumerate(episode.visited)
            if j < length(episode.visited)
                episodes[:episode => i => j => pos => policy[pos] => :new_pos] = episode.visited[j+1]
            end
        end
    end
    return episodes
end

episodes = generate_episodes(environment, policies)


In [1]:
using Serialization

# Save block_maze_traces to a file
serialize("block_maze_traces.jls", block_maze_traces)
# Load block_maze_traces from a file
# block_maze_traces = deserialize("block_maze_traces.jls")

LoadError: UndefVarError: `block_maze_traces` not defined

In [None]:
using StatsBase  # For the `countmap` function

function generate_probabilities(xs, n)
    # Step 1: Count occurrences of each number from 1 to n in xs
    counts = countmap(xs)  # Returns a dictionary with counts for each number
    
    # Step 2: Create an array to hold the probabilities
    probs = zeros(Float64, n)

    # Step 3: Assign probabilities proportional to counts and distance from high-count indices
    # Define a decay function that decreases with distance (Gaussian-like decay for example)
    function decay(distance, sigma=1.0)
        return exp(-distance^2 / (2 * sigma^2))
    end
    
    for i in 1:n
        # Base probability is proportional to count
        count_prob = haskey(counts, i) && counts[i] > 0 ? counts[i] * 10 : 1
        
        # Compute distance-based contribution
        distance_effect = 0.0
        for (j, count) in counts
            if count > 0
                distance_effect += count * decay(abs(i - j))
            end
        end
        
        # Combine count-based probability and distance effect
        probs[i] = count_prob + distance_effect
    end
    
    # Step 4: Normalize the probabilities so they sum to 1
    total = sum(probs)
    probs ./= total  # Divide each probability by the total sum to normalize
    
    return probs
end;


In [None]:
@gen function propose_frac_obstacles(trace)

    current_value = get_choices(trace)[:maze => :frac_obstacles]
    concentration = 1

    # Calculate alpha and beta
    α = current_value * concentration
    β = (1 - current_value) * concentration

    # Create Beta distribution
    frac_obstacles = {:maze => :frac_obstacles} ~ beta(α, β)
end;

In [None]:
@gen function propose_maze(trace)

    ## propose goal ---------------------------
    choices = get_choices(trace)
    maze_choices = get_submap(choices, :maze)
    episodes = nested_view(choices)[:episode]
    goal_x = maze_choices[:goal_x]
    goal_y = maze_choices[:goal_y]
    frac_obstacles = choices[:maze => :frac_obstacles]
    last_positions = []
    for i in collect(keys(episodes))
        actions = episodes[i]
        max_action = maximum(collect(keys(actions)))
        if max_action == 100
            continue
        else
            pos = collect(keys(actions[max_action]))[1]
            new_pos = actions[max_action][pos]
            action = collect(keys(new_pos))[1]
            push!(last_positions, new_pos[action][:new_pos])
        end
    end
    maze = collect(get_values_shallow(maze_choices))
    n = Int(sqrt(length(maze) - 3))
    xs = [pos.x for pos in last_positions]
    x_probs = generate_probabilities(xs, n)
    ys = [pos.y for pos in last_positions]
    y_probs = generate_probabilities(ys, n)

    new_goal_x = {:maze => :goal_x} ~ categorical(x_probs)
    new_goal_y = {:maze => :goal_y} ~ categorical(y_probs)   

    ## propose fields ---------------------------
    labels = [obstacle, empty]
    for x in 1:n
        for y in 1:n
            addr = Pos(x,y)
            if x == goal_x && y == goal_y
                if new_goal_x !== x || new_goal_y !== y
                    new_goal = {:maze => Pos(new_goal_x, new_goal_y)} ~ labeled_categorical([goal], [1.0])
                    {:maze => addr} ~ labeled_categorical(labels, [frac_obstacles, 1-frac_obstacles])
                end
            elseif x != new_goal_x || y != new_goal_y
                {:maze => addr} ~ labeled_categorical(labels, [frac_obstacles, 1-frac_obstacles])
            end
        end
    end
end;

In [None]:
function drift_update_unknown_maze(trace)
    for button in [a,b,c,d]
        (trace,_) = mh(trace, propose_controller, (button,))
    end
    (trace,_) = mh(trace, propose_frac_obstacles, ())
    
    (trace,_) = mh(trace, propose_maze, ())
    return trace
end;

In [None]:
@gen function propose_goal(trace)
    choices = get_choices(trace)
    maze_choices = get_submap(choices, :maze)
    episodes = nested_view(choices)[:episode]
    frac_obstacles = nested_view(choices)[:maze => :frac_obstacles]
    goal_x = maze_choices[:goal_x]
    goal_y = maze_choices[:goal_y]
    last_positions = []
    for i in collect(keys(episodes))
        actions = episodes[i]
        max_action = maximum(collect(keys(actions)))
        if max_action == 100
            continue
        else
            pos = collect(keys(actions[max_action]))[1]
            new_pos = actions[max_action][pos]
            action = collect(keys(new_pos))[1]
            push!(last_positions, new_pos[action][:new_pos])
        end
    end
    maze = collect(get_values_shallow(maze_choices))
    n = Int(sqrt(length(maze) - 3))
    xs = [pos.x for pos in last_positions]
    x_probs = generate_probabilities(xs, n)
    ys = [pos.y for pos in last_positions]
    y_probs = generate_probabilities(ys, n)

    new_goal_x = {:maze => :goal_x} ~ categorical(x_probs)
    new_goal_y = {:maze => :goal_y} ~ categorical(y_probs)   
    {:maze => Pos(new_goal_x, new_goal_y)} ~ labeled_categorical([goal], [1.0])
    if goal_x != new_goal_x || goal_y != new_goal_y
        {:maze => Pos(goal_x, goal_y)} ~ labeled_categorical([obstacle, empty], [frac_obstacles, 1-frac_obstacles])
    end

end;


In [None]:
@gen function propose_field_based_on_moves(trace, pos::Pos, moves::Dict)
    # println("pos is ", pos)
    labels = [obstacle, empty]
    field_above = Pos(pos.x, pos.y + 1)
    field_below = Pos(pos.x, pos.y - 1)
    field_left  = Pos(pos.x - 1, pos.y)
    field_right = Pos(pos.x + 1, pos.y)
    adj_fields = [field_above, field_below, field_left, field_right]

    frac_obstacles = get_choices(trace)[:maze => :frac_obstacles]
    targets = Dict()

    # sample adjacent fields ------------
    for button in keys(moves)
        probs = get_choices(trace)[:ctrl => button]
        dir = argmax(probs)
        target = adj_fields[dir]
        targets[target] = true
        if length(keys(moves)) > 0
            if any(new_pos -> new_pos == target, moves[button])
                # println((pos, dir, target, moves[button]))
                {:maze => target} ~ labeled_categorical([empty], [1.0])
            elseif has_value(get_choices(trace), :maze => target)
                {:maze => target} ~ labeled_categorical(labels, [0.8, 0.2])
            end
        elseif has_value(get_choices(trace), :maze => target)
            {:maze => target} ~ labeled_categorical(labels, [0.8, 0.2])
        end
        # println("target ", target)
    end

    # sample this field --------------
    if any(b -> length(moves[b]) > 0, keys(moves))
        # println("there are moves ")
        {:maze => pos} ~ labeled_categorical([empty], [1.0])
    else 
        {:maze => pos} ~ labeled_categorical(labels, [0.8, 0.2])
    end   

    for field in adj_fields
        if !haskey(targets, field) && has_value(get_choices(trace), :maze => field)
            {:maze => field} ~ labeled_categorical(labels, [frac_obstacles, (1-frac_obstacles)])
        end
    end

end;

In [None]:
function get_moves(trace)

    choices = get_choices(trace)
    episodes = nested_view(choices)[:episode]
    moves = Dict()

    for ep_key in collect(keys(episodes))
        episode = episodes[ep_key]
        for step_key in collect(keys(episode))
            step = episode[step_key]
            for pos_key in collect(keys(step))
                pos = step[pos_key]
                if !haskey(moves, pos_key)
                    moves[pos_key] = Dict()
                end
                for button_key in collect(keys(pos))
                    if haskey(moves[pos_key], button_key)
                        push!(moves[pos_key][button_key], pos[button_key][:new_pos])
                    else 
                        moves[pos_key][button_key] = [pos[button_key][:new_pos]]
                    end
                end
            end
        end
    end

    return moves
end;


In [None]:
function drift_update_smart_maze(trace, moves::Dict)
    for button in [a,b,c,d]
        (trace,_) = mh(trace, propose_controller, (button,))
    end
    
    (trace,_) = mh(trace, propose_frac_obstacles, ())

    old_goal_x = get_choices(trace)[:maze => :goal_x]
    old_goal_y = get_choices(trace)[:maze => :goal_y]

    (trace,_) = mh(trace, propose_goal, ())

    new_goal_x = get_choices(trace)[:maze => :goal_x]
    new_goal_y = get_choices(trace)[:maze => :goal_y]    

    maze_choices = get_submap(get_choices(trace), :maze)
    for (pos, field) in collect(get_values_shallow(maze_choices))
        if pos !== :frac_obstacles && pos !== :goal_x && pos !== :goal_y && pos != Pos(old_goal_x, old_goal_y) && pos != Pos(new_goal_x, new_goal_y)
            m = haskey(moves, pos) ? moves[pos] : Dict()
            (trace,_) = mh(trace, propose_field_based_on_moves, (pos, m) ) 
        end
    end


    return trace
end;

In [None]:
function interact(environment, trace, smart_maze_traces, episodes, policies)
    maze, controller = get_retval(trace)
    policy, state_values = value_iteration(maze, controller, 0.9, 0.01)
    print_maze(environment, start, policy)

    i = length(get_submaps_shallow(get_submap(episodes, :episode))) + 1
    episode = simulate_episode(environment.maze, environment.controller, start, 100, policy)
    for (j, pos) in enumerate(episode.visited)
        if j < length(episode.visited)
            episodes[:episode => i => j => pos => policy[pos] => :new_pos] = episode.visited[j+1]
        end
    end
    
    push!(policies, policy)

    trace, _ = generate(maze_model_unknown_maze, (policies, start), episodes)

    steps = 20
    moves = get_moves(trace)
    for i in 1:steps
        trace = drift_update_smart_maze(trace, moves)
        push!(smart_maze_traces, trace)
    end
    return trace
end;

In [None]:
    
# @gen function select_mazemodel(n::Int, start_pos::Pos, goal_pos::Pos)::MazeModel
#     mazemodel = Dict{Pos, Float64}()
#     for x in 1:n
#         for y in 1:n
#             pos = Pos(x, y)
#             if x == start_pos.x && y == start_pos.y
#                 mazemodel[pos] = {pos} ~ labeled_categorical([Float64(1)], [1]) # because this field is empty if it is the start
#             elseif x == goal_pos.x && y == goal_pos.y
#                 mazemodel[pos] = {pos} ~ labeled_categorical([Float64(1)], [1]) # because this field is empty if it is the goal
#             else
#                 mazemodel[pos] = {pos} ~ beta(2,1) 
#             end
#         end
#     end
#     return mazemodel
# end;

In [None]:
# mazemodel = select_mazemodel(3, Pos(1,1),Pos(3,2))

# plot_mazemodel(mazemodel)

