In [27]:
using Pkg
using Plots
using ViscousStreaming
using NBInclude
using HDF5
using JLD2
using Distributions
using PyCall
using Distances
using JSON

In [28]:
@nbinclude("DQN.ipynb")
@nbinclude("Shared.ipynb")

motion (generic function with 1 method)

In [29]:
"""
    train(agent)

Train `agent` online for `params["episodes"]` episodes (reads the global `params`).

Each episode starts from `reset1()` and runs at most `params["episode_length"]` steps. At
every step the agent picks an action index with `select_action_index(state, true)`, the
environment is advanced with `motion(state, index + 1)`, the transition is handed to
`agent.on_new_sample`, and `agent.replay_mem(params["batch_size"])` is called. The episode
ends early when `motion` reports `done == true`.
"""
function train(agent)
    for _ in 1:params["episodes"]
        state = reset1()
        for _ in 1:params["episode_length"]
            # The agent returns a 0-based index (converted from Python to a Julia Int);
            # `motion` is called with it shifted by one.
            # NOTE(review): assumes `motion` expects a 1-based index — confirm in Shared.ipynb.
            chosen = py"int"(agent.select_action_index(state, true))
            next_state, reward, done = motion(state, chosen + 1)
            agent.on_new_sample(state, chosen, reward, next_state, done)
            agent.replay_mem(params["batch_size"])
            state = next_state
            done == true && break
        end
    end
end

train (generic function with 1 method)

In [30]:
#agent = py"DQNAgent"(params, [2,5,8,10,12,13,14,15])

In [31]:
# Roll out a single evaluation episode for `agent`; return `true` as soon as `motion`
# reports `done == true` within `params["episode_length"]` steps, `false` otherwise.
function _evaluation_episode_done(agent)
    state = reset1()
    for _ in 1:params["episode_length"]
        # `false` is passed where `train` passes `true`.
        # NOTE(review): presumably this disables exploration — confirm in DQN.ipynb.
        idx = py"int"(agent.select_action_index(state, false))
        state, _, done = motion(state, idx + 1)
        done == true && return true
    end
    return false
end

"""
    get_success_rate(agent, count)

Run `count` evaluation episodes of `agent` and return the fraction of them in which
`motion` reported `done == true` before the step limit `params["episode_length"]`.
The result is a Float64 (`0 / 0` yields `NaN` when `count == 0`).
"""
function get_success_rate(agent, count)
    completed = 0
    episode = 1
    while episode <= count
        if _evaluation_episode_done(agent)
            completed += 1
        end
        episode += 1
    end
    return completed / count
end

get_success_rate (generic function with 1 method)

In [32]:
# Return a memoised success rate for removing `testing_action` under `prefix`, or `nothing`.
# Besides the direct key "<prefix>-<testing_action>", a two-part prefix "<phase>-<other>"
# is also looked up in mirrored form "<phase>-<testing_action>-<other>": both keys exclude
# the same pair of actions, so the trained action set (and hence the result) is the same.
function _lookup_sweep_rate(process_data, prefix, testing_action)
    direct_key = "$prefix-$testing_action"
    haskey(process_data, direct_key) && return process_data[direct_key]
    parts = split(prefix, "-")
    if length(parts) > 1
        mirrored_key = join([parts[1], testing_action, parts[2]], "-")
        haskey(process_data, mirrored_key) && return process_data[mirrored_key]
    end
    return nothing
end

"""
    sweep(params, base_actions, testing_actions, process_data, process_data_prefix;
          baseline_success_rate=success_rate)

Leave-one-out test of every action in `testing_actions`.

For each `testing_action`, a `DQNAgent` is trained on `base_actions` plus every *other*
entry of `testing_actions`, and its success rate is measured with `get_success_rate`.
If that rate is below
`baseline_success_rate * (1 - params["success_rate_acceptable_error_ratio"])`,
the removed action is classified as essential; otherwise it is optional.

Rates are memoised in `process_data` under the key "<process_data_prefix>-<testing_action>".
Cached entries are reused instead of retraining. After each action, the dictionary is
persisted with `record_process_data` and a line is appended to the debug log.

`baseline_success_rate` defaults to the global `success_rate`, which is what this function
always used implicitly. Passing it explicitly removes the hidden global dependency.

Returns `(essential_actions, optional_actions)`.
"""
function sweep(params, base_actions, testing_actions, process_data, process_data_prefix;
               baseline_success_rate=success_rate)
    essential_actions = eltype(testing_actions)[]
    optional_actions = eltype(testing_actions)[]
    # Minimum acceptable rate; identical for every action, so compute it once.
    threshold = baseline_success_rate * (1.0 - params["success_rate_acceptable_error_ratio"])
    for testing_action in testing_actions
        process_data_key = "$process_data_prefix-$testing_action"
        testing_success_rate = _lookup_sweep_rate(process_data, process_data_prefix, testing_action)
        if testing_success_rate === nothing
            # Not cached: train a fresh agent without `testing_action` and measure it.
            target_actions = vcat(base_actions, [x for x in testing_actions if x != testing_action])
            testing_agent = py"DQNAgent"(params, target_actions)
            train(testing_agent)
            testing_success_rate = get_success_rate(testing_agent, params["success_rate_average_run"])
        end
        # Store under the direct key (also when found via the mirrored key) and checkpoint.
        process_data[process_data_key] = testing_success_rate
        record_process_data(process_data)
        write_("sweep - tested with $testing_action - $testing_success_rate")
        if testing_success_rate < threshold
            push!(essential_actions, testing_action)
        else
            push!(optional_actions, testing_action)
        end
    end
    return essential_actions, optional_actions
end

sweep (generic function with 1 method)

In [33]:
"""
    record_phase(dict, phase, testing_actions, selected_actions; path="result.json")

Record the state of `phase` in `dict` and save `dict` as JSON to `path`.

Every action in `testing_actions` is marked `"t"` and every action in `selected_actions`
is marked `"s"` under the key "<phase>-a<action>". Selected entries are written last, so
an action present in both lists ends up as `"s"`. `dict["phase"]` is set to `phase`.
`dict` is mutated in place.

The JSON is first written to `path * ".tmp"` and then moved over `path`. An interrupt
during serialization therefore leaves the previous checkpoint intact instead of a
truncated file that the resume logic could not parse.
"""
function record_phase(dict, phase, testing_actions, selected_actions; path="result.json")
    for testing_action in testing_actions
        dict["$phase-a$testing_action"] = "t"
    end
    for selected_action in selected_actions
        dict["$phase-a$selected_action"] = "s"
    end
    dict["phase"] = phase
    tmp_path = path * ".tmp"
    open(tmp_path, "w") do file
        JSON.print(file, dict)
    end
    mv(tmp_path, path; force=true)
    return nothing
end

record_phase (generic function with 1 method)

In [34]:
"""
    record_process_data(dict; path="process_data.json")

Serialize `dict` (e.g. the sweep success-rate cache) as JSON to `path`.

The JSON is first written to `path * ".tmp"` and then moved over `path`. An interrupt
during serialization therefore leaves the previous file intact instead of a truncated one
that a later `JSON.parsefile` would fail on.
"""
function record_process_data(dict; path="process_data.json")
    tmp_path = path * ".tmp"
    open(tmp_path, "w") do file
        JSON.print(file, dict)
    end
    mv(tmp_path, path; force=true)
    return nothing
end

record_process_data (generic function with 1 method)

In [35]:
"""
    write_(text)

Append `text` plus a trailing newline to `debug.txt` in the current working directory,
creating the file if needed. Serves as the notebook's progress/debug log.
"""
write_(text) = open(io -> println(io, text), "debug.txt", "a")

write_ (generic function with 1 method)

In [36]:
result = Dict()
process_data = Dict()

# Phase 0 setup: either resume from a checkpoint or establish the all-actions baseline.
params = py"parameters"()
all_actions = collect(0:params["number_of_actions"]-1)
# Resume automatically when a checkpoint written by `record_phase` exists;
# otherwise start a fresh run (delete result.json to force a fresh run).
does_load = isfile("result.json")
if does_load
    # Rebuild the loop state from the checkpoint: an action marked "s" in any phase is
    # selected; an action with any other mark in the latest phase is still under test.
    selected_actions = []
    testing_actions = []
    result = JSON.parsefile("result.json")
    success_rate = result["success_rate"]
    phase = result["phase"]
    for p in 0:phase, action in all_actions
        key = "$p-a$action"
        haskey(result, key) || continue
        if result[key] == "s"
            action in selected_actions || push!(selected_actions, action)
        elseif p == phase
            push!(testing_actions, action)
        end
    end
    # process_data.json is first written after the first sweep result, so a run
    # interrupted before that point has result.json but no process_data.json.
    process_data = isfile("process_data.json") ? JSON.parsefile("process_data.json") : Dict()
else
    # Fresh run: train on every action to obtain the baseline success rate.
    agent = py"DQNAgent"(params, all_actions)
    train(agent)

    success_rate = get_success_rate(agent, params["success_rate_average_run"])
    result["success_rate"] = success_rate
    write_("success_rate: $success_rate")
    testing_actions = copy(all_actions)
    selected_actions = []
    phase = 0
end
record_phase(result, phase, testing_actions, selected_actions)
# Main selection loop: each iteration is one "phase". It runs until no candidate actions
# remain in `testing_actions`. After each phase the state is checkpointed via `record_phase`.
while !isempty(testing_actions)
    # Bind these names to the script-level globals inside the loop body
    # (`testing_actions`, `selected_actions` and `phase` are reassigned below).
    global success_rate
    global testing_actions
    global selected_actions
    global phase
    write_("$phase - phase entered")
    if phase == 0
        # Phase 0: leave-one-out sweep over all candidates. Actions whose removal drops
        # the success rate below the tolerated threshold (see `sweep`) are selected now.
        process_data_key_prefix = "$phase"
        essential_actions, optional_actions = sweep(params, selected_actions, testing_actions, process_data, process_data_key_prefix)
        selected_actions = vcat(selected_actions, essential_actions)
        write_("essential_actions: $essential_actions")
        write_("optional_actions: $optional_actions")
    else
        # Later phases: every remaining candidate is treated as optional.
        optional_actions = testing_actions
    end
    write_("selected_actions: $selected_actions")
    write_("optional_actions: $optional_actions")

    # For each optional action `o`: drop `o` from the candidates and run a leave-one-out
    # sweep over the remaining optional actions. Keep the outcome with the fewest
    # essential actions; the first outcome is always kept.
    minimal_essential_actions = []
    minimal_optional_actions = []
    for optional_action in optional_actions
        process_data_key_prefix = "$phase-$optional_action"
        write_("optional_actions selection - testing with $optional_action")
        target_optional_actions = [x for x in optional_actions if x != optional_action]
        target_essential_actions, target_optional_actions = sweep(params, selected_actions, target_optional_actions, process_data, process_data_key_prefix)
        write_("selected_actions: $selected_actions")
        write_("target_essential_actions: $target_essential_actions")
        write_("target_optional_actions: $target_optional_actions")
        # An empty `minimal_essential_actions` here means "nothing recorded yet": if an
        # outcome with zero essential actions is recorded, the loop breaks immediately.
        if length(minimal_essential_actions) == 0 || length(target_essential_actions) < length(minimal_essential_actions)
            minimal_essential_actions = copy(target_essential_actions)
            minimal_optional_actions = copy(target_optional_actions)

            # Zero essential actions cannot be improved on; stop searching.
            if length(minimal_essential_actions) == 0
                break
            end
        end
        # With `o` dropped, every other candidate proved essential: select `o` itself and
        # carry those candidates forward as the next phase's testing set.
        if length(target_optional_actions) == 0
            minimal_essential_actions = [optional_action]
            minimal_optional_actions = copy(target_essential_actions)
            break
        end
    end

    selected_actions = vcat(selected_actions, minimal_essential_actions)
    testing_actions = minimal_optional_actions
    write_("selected_actions: $selected_actions")
    write_("testing_actions: $testing_actions")
    phase += 1
    record_phase(result, phase, testing_actions, selected_actions)

    # A single leftover candidate is dropped without another sweep: advance the phase
    # once more and record an empty testing set, which ends the loop.
    if length(testing_actions) == 1
        testing_actions = []
        phase += 1
        record_phase(result, phase, testing_actions, selected_actions)
    end
end

InterruptException: InterruptException: