In [3]:
include("../../src/industrial_stats.jl")
using .IndustrialStats: Designs, Models, OptimalityCriteria, TensorOps
using Random
using Distributions
using StatsBase
using Logging
using HDF5
using LinearAlgebra

## State

In [4]:
struct WorldState 
    countries::Array{Float32, 3} # Array of all countries
    imperialists::Vector{Int} # Vector of indices for imperialists 
    colonies::Vector{Int} # Vector of indices pointing to the imperialists
end

function get_colonies(world::WorldState)
    colony_indices = setdiff(1:size(world.countries, 1), world.imperialists)
    return world.countries[colony_indices, :, :]
end

function get_imperialists(world::WorldState)
    return world.countries[world.imperialists, :, :]
end

function get_minimizer(world::WorldState, objective::Function)
    return world.countries[partialsortperm(objective(world.countries), 1), :, :]
end

function summarize_world(w::WorldState) 
    println("-- World State --")
    println("Size of Countries: $(size(w.countries))")
    println("Num Imperialists: $(length(w.imperialists))")
    println("Num Colonies: $(size(get_colonies(w), 1))")
end

summarize_world (generic function with 1 method)

## Initialization

In [3]:
function normalize_scores_max(imperialist_scores)
    (imperialist_scores .- maximum(imperialist_scores)) ./ sum(imperialist_scores)
end

function normalize_scores_softmax(scores)
    exp.(scores) ./ sum(exp.(scores))
end

function initialize_world_empires(initializer, objective; colony_weight = 0.1, num_countries = 1500, num_imperialists = 100)
    # Initialize and compute objective
    initial_countries = initializer(num_countries)
    initial_scores = objective(initial_countries)

    # Find best scores and assign imperialists
    imperialists = partialsortperm(initial_scores, 1:num_imperialists)
    imperialist_scores = initial_scores[imperialists]

    # Use normalized scores to distribute colonies to imperialists
    normalized_imperialist_scores = normalize_scores_softmax(imperialist_scores)
    
    # Distribute colonies using imperialist score weights
    colonies = sample(imperialists, Weights(imperialist_scores), num_countries, replace = true)

    # Make imperialists their own imperialist
    colonies[imperialists] .= imperialists

    # Return initial world state
    return WorldState(initial_countries, imperialists, colonies)
end

initialize_world_empires (generic function with 1 method)

## State Update

In [4]:
function replace_imperialists(w::WorldState, old_imperialists, new_imperialists)    
    swap_dict = Dict(zip(old_imperialists, new_imperialists))
    new_colonies = [get(swap_dict, c, c) for c in w.colonies]
    return new_colonies
end

function score_empires(w::WorldState, objective::Function; weight = 0.1)
    imperialist_scores = objective(get_imperialists(w))
    active_imperialists = unique(w.colonies)
    colony_scores = [(mean ∘ objective)(w.countries[w.colonies .== i, :, :]) for i in active_imperialists]
    return [i + weight * c for (i, c) in zip(imperialist_scores, colony_scores)]
end

score_empires (generic function with 1 method)

In [5]:
function revolutionize(w::WorldState, initializer::Function; p = 0.05)
    num_countries = size(w.countries, 1)
    num_revolutions = Int(round(p * num_countries))
    colonies = get_colonies(w)
    revolution_indices = sample(1:size(colonies, 1), num_revolutions, replace = false)
    colonies[revolution_indices, :, :] .= initializer(num_revolutions)
    return WorldState(w.countries, w.imperialists, w.colonies)
end

function assimilate(world::WorldState; beta = 2)
    imperialist_view = view(world.countries, world.colonies, :, :)
    n = size(world.countries, 1)
    movement_factor = rand(n)
    rand_pert = rand(-1:1, size(imperialist_view)...)
    updated_positions = (movement_factor) .* world.countries .+ (1 .- movement_factor) .* (imperialist_view .+ rand_pert)
    updated_positions = clamp.(updated_positions, -1.0, 1.0)
    return WorldState(updated_positions, world.imperialists, world.colonies)
end

function update_imperialists(w::WorldState, objective::Function)
    num_imperialists = length(w.imperialists)
    all_scores = objective(w.countries)
    imperialists = partialsortperm(all_scores, 1:num_imperialists)
    new_imperialists = setdiff(imperialists, w.imperialists)
    old_imperialists = setdiff(w.imperialists, imperialists)
    new_colonies = replace_imperialists(w, old_imperialists, new_imperialists)
    return WorldState(w.countries, imperialists, new_colonies)
end

function update_empires(w::WorldState, objective::Function; weight = 0.1, num_bad_colonies = 3)
    empire_scores = score_empires(w, objective; weight = 0.1)
    empire_scores = normalize_scores_softmax(empire_scores)
    rand_factor = rand(length(empire_scores))
    colony_scores = objective(get_colonies(w))
    worst_colonies = partialsortperm(colony_scores, 1:num_bad_colonies, rev=true)
    imperialists = unique(w.colonies)
    fallen_empires = setdiff(w.imperialists, imperialists)
    empire_scores[isnan.(empire_scores)] .= 10e10
    new_imperialist_assignment = sample(imperialists, Weights(empire_scores), num_bad_colonies + length(fallen_empires), replace = true)
    new_colonies = replace_imperialists(w, vcat(fallen_empires, w.colonies[worst_colonies]), new_imperialist_assignment)
    return WorldState(w.countries, imperialists, new_colonies)
end

update_empires (generic function with 1 method)

## Simulation Runner

In [6]:
function run_simulation(initializer, objective; colony_weight = 0.1, num_countries = 1500, num_imperialists = 100, num_bad_colonies = 3, colony_power_weight = 0.1, max_iterations = 1000, snapshot = (x) -> nothing)
    world = initialize_world_empires(initializer, objective; colony_weight = colony_weight, num_countries = num_countries, num_imperialists = num_imperialists)
    
    for i in 1:max_iterations
        world = revolutionize(world, initializer)
        world = assimilate(world)
        world = update_imperialists(world, objective)
        world = update_empires(world, objective; weight = colony_power_weight, num_bad_colonies = num_bad_colonies)

        snapshot(world)
        
        if length(world.imperialists) == 1
            @info "Stopping after $i iterations"
            return world
        end
    end

    return world
end

run_simulation (generic function with 1 method)

## Test

In [7]:
initializer = Designs.make_initializer(8, 4)
model = Models.linear
num_bad_colonies = 3
colony_power_weight = 0.1
num_bad_colonies = 3
objective = TensorOps.squeeze ∘ OptimalityCriteria.d_criterion ∘ model
i = 1
snapshot = (w) -> h5write("world_$i.h5", "world", w.countries)


Main.IndustrialStats.TensorOps.squeeze ∘ Main.IndustrialStats.OptimalityCriteria.d_criterion ∘ Main.IndustrialStats.Models.var"#model_builder#19"{Int64, Int64, Bool, Vector{Any}, Bool, Bool}(1, 0, true, Any[], false, true)

In [8]:
world = run_simulation(
    initializer, 
    objective; 
    colony_weight = 0.1, 
    num_countries = 5000, 
    num_imperialists = 200, 
    num_bad_colonies = num_bad_colonies, 
    colony_power_weight = colony_power_weight, 
    max_iterations=10000,
    snapshot = snapshot)

WorldState(Float32[0.04404493 -0.4363388 0.14355062 -0.5571092; -0.889516 -0.55706835 1.0 -0.18299487; 0.32235792 0.2938431 -0.41585544 0.35440305; 0.1384951 0.08998396 0.2223096 0.24398533; -0.5263076 -0.9099544 -1.0 -1.0; 0.30601275 0.33448246 0.16923024 -0.42143944; -0.9645337 0.33998743 -0.83883584 0.069970064; -0.66223115 0.36040562 0.2094874 -1.0; -0.025639175 0.5440412 -1.0 -0.022812912; 0.24161692 0.12900467 -0.1503357 -0.58330375; -0.3295179 0.3453697 -0.39967692 -0.9029032; -1.0 1.0 0.105030455 -0.58999306; -0.8360618 -0.53382504 0.6323513 0.8394162; 0.80888104 0.20407964 -0.15642405 -1.0; 0.5215924 -0.4758688 -0.33990738 0.89233625; 0.3331596 0.5868452 0.86048496 0.47647238; -1.0 0.920355 1.0 0.8348902; 0.38435158 -0.9998313 0.5508039 0.19050322; -0.084912546 0.39339238 -0.21889207 -0.6374528; -0.63868856 -0.36145106 -0.1664795 0.6192451; -0.11475714 -0.033331305 -0.77637476 -1.0; -0.35210118 0.0924228 -0.95574456 0.24785763; 0.56242025 0.7419209 -0.38643795 0.6454978; 1.0 1

┌ Info: Stopping after 76 iterations
└ @ Main /workspaces/Pearson-Exploratorium/notebooks/jl_notebook_cell_df34fa98e69747e1a8f8a730347b8e2f_X12sdnNjb2RlLXJlbW90ZQ==.jl:11


WorldState(Float32[1.0 1.0 … 1.0 0.040028468; 0.35882208 0.99991703 … 0.87312305 0.67169225; … ; 0.70652074 1.0 … 1.0 -0.033492252; 0.57419246 1.0 … -0.31619775 1.0;;; 0.872703 0.8416451 … -0.9605297 -1.0; 0.023357773 0.92846894 … -1.0 -0.24901994; … ; 0.94099617 0.71490115 … -0.2081143 -0.16253315; -0.7030072 1.0 … -1.0 -0.17082007;;; -0.51982373 -0.040091313 … 1.0 -0.5094538; -1.0 0.74873286 … 0.2612113 -0.93679935; … ; 0.2272056 0.5757349 … 0.19858609 -0.7896606; -0.8276742 -0.73488307 … -0.027433546 -0.1960951;;; -0.99272376 1.0 … 1.0 1.0; -0.95911 0.25384152 … 0.19577013 0.9363902; … ; -0.8420134 1.0 … 0.32982183 1.0; 0.16758217 0.9904391 … 0.021975957 0.362308], [683], [683, 683, 683, 683, 683, 683, 683, 683, 683, 683  …  683, 683, 683, 683, 683, 683, 683, 683, 683, 683])

In [9]:
minimizer = get_minimizer(world, objective)

8×4 Matrix{Float32}:
  1.0         -1.0       -0.111047   0.0146674
  1.0          1.0       -0.88094    1.0
  0.772591     0.997878   0.999596  -1.0
 -1.0          0.914088  -1.0       -0.999206
 -1.0          1.0       -0.95493    1.0
 -1.0         -1.0       -0.880961  -1.0
 -0.85479     -0.111047   1.0        1.0
  0.00864908  -0.999723   0.772187   0.111047