## Dependencies and Setup

Load the target image and a saved simulation. Starting from the simulation's final image, we will try to find an optimal placement for the next shape.

In [None]:
using Revise
using Paint
using Serialization
using Images, ImageShow
using Plots
using StaticArrays
using ImageFeatures
using IntervalSets
using Combinatorics
using StatsBase

In [None]:
# The image being approximated, plus a serialized simulation log whose most recent
# history entry provides the starting image (`current`) for the next placement.
target = float.(load("../lisa.png"))
hist = Serialization.deserialize("../output/simresult/simlog_100-prims_100000-batch_3-epoch_100-refine.bin")
baseimage = last(hist.history).current
baseimage

### Raw Difference Map

Take the absolute RGB difference between the current image and the target, convert it to grayscale, and normalize it so that the largest value is 1.

In [None]:
# Per-pixel absolute error, collapsed to one gray channel and scaled so its peak is 1.
# NOTE: this global named `diff` shadows `Base.diff` in this notebook's scope.
diff = let err = Gray.(abs.(baseimage .- target))
    err ./ maximum(err)
end

### Blurred Difference Map

Apply a Gaussian blur to the difference map to emulate the local accumulation of error over the area of a shape.
Some sampling algorithms need more contrast, so we also keep a re-normalized copy for those cases.

In [None]:
# Smooth the error map with a σ = 2 Gaussian (a 9x9 filter kernel). Keep both the raw
# blurred map and a copy rescaled so its maximum is 1 for higher contrast.
blurdiffraw, blurdiff = let blurred = imfilter(diff, Kernel.gaussian(2))
    (blurred, blurred ./ maximum(blurred))
end
plot(plot(blurdiffraw), plot(blurdiff), axis=false, ticks=false)

## Sampling Strategies

### Top-N Selection

In [None]:
"""
    sample_topN(image; N)

Return the `(row, col)` indices of the `N` largest pixels of `image`, ordered from
largest to smallest value. If `image` has fewer than `N` pixels, every pixel is
returned instead of throwing.
"""
function sample_topN(image ; N)
    # `vec` reshapes without copying; splatting the whole image (`[image...]`) into
    # varargs is very slow for large arrays.
    order = reverse(sortperm(vec(image)))
    cart = CartesianIndices(image)
    # Clamp so requesting more points than pixels is not a BoundsError.
    top = order[1:min(N, length(order))]
    return [Tuple(cart[k]) for k in top]
end

### Probabilistic Selection

In [None]:
"""
    sample_prob(image; N, scale_factor = 5.0)

Draw `N` pixel positions `(row, col)` from `image` using StatsBase's weighted `sample`.
Each pixel's weight is its intensity (as `Float32`) raised to `scale_factor`, so larger
exponents concentrate the draws on the brightest pixels.
"""
function sample_prob(image ; N, scale_factor = 5.0)
    w = vec(Float32.(image) .^ scale_factor)
    picked = sample(1:length(image), Weights(w), N)
    cart = CartesianIndices(image)
    return [Tuple(cart[k]) for k in picked]
end

### Pure Random

In [None]:
"""
    sample_random(image; N)

Draw `N` pixel positions `(row, col)` of `image` uniformly at random, ignoring pixel
values entirely (a baseline for the other samplers).
"""
function sample_random(image ; N)
    cart = CartesianIndices(image)
    picked = sample(1:length(image), N)
    return [Tuple(cart[k]) for k in picked]
end

### Fastcorners Feature Detection

In [None]:
"""
    cornerthreshold(image, N; upper = 0.5, tol = 0.01, contiguous = 12)

Binary-search, within `[0, upper]` and to a precision of `tol`, for the largest
threshold at which `fastcorners(image, contiguous, threshold)` still reports at least
`N` keypoints.

The search assumes the keypoint count does not increase as the threshold grows. If
fewer than `N` keypoints are found even at threshold `0.0`, a warning is logged and
`0.0` is returned.
"""
function cornerthreshold(image, N; upper = 0.5, tol = 0.01, contiguous = 12)
    npoints(thresh) = length(Keypoints(fastcorners(image, contiguous, thresh)))

    # The search needs a lower bound that already yields at least N points.
    if npoints(0.0) < N
        @warn "Fastcorners failure: fewer than $N points detected even at threshold 0"
        return 0.0
    end

    lo, hi = 0.0, float(upper)
    # Invariant: `lo` always yields at least N keypoints.
    while hi - lo > tol
        mid = (lo + hi) / 2
        if npoints(mid) < N
            hi = mid
        else
            lo = mid
        end
    end
    return lo
end

"""
    sample_fastcorners(image; N)

Return up to `N` FAST corner locations of `image` as `(row, col)` tuples. The
detection threshold comes from `cornerthreshold(image, N)`, which is chosen so that
roughly `N` corners are found.
"""
function sample_fastcorners(image ; N)
    threshold = cornerthreshold(image, N)
    corners = Keypoints(fastcorners(image, 12, threshold))
    # Clamp instead of `corners[1:N]`: when the threshold search fails and falls back
    # to 0.0, fewer than N corners may exist and plain slicing would throw.
    return [corner.I for corner in corners[1:min(N, length(corners))]]
end

### ORB Feature Detection

In [None]:
"""
    sample_orb(image; N)

Return the ORB keypoint locations of `image` as `(row, col)` tuples. At most `N`
keypoints are requested via `num_keypoints`. The ORB threshold is half of the FAST
threshold found by `cornerthreshold(image, N)`.
"""
function sample_orb(image ; N)
    # NOTE(review): halving presumably gives ORB extra candidates to choose from — confirm.
    thresh = cornerthreshold(image, N) / 2.0
    params = ORB(num_keypoints = N, threshold = thresh)
    _, keypoints = create_descriptor(image, params)
    return [kp.I for kp in keypoints]
end

## Evaluation

### Plots

In [None]:
"""
    plotpoints(image, points; title = "")

Plot `image` with axes, ticks and legend hidden, and overlay `points` (given as
`(row, col)` tuples) as a scatter. Returns the plot object.
"""
function plotpoints(image, points ; title = "")
    # `size` returns (rows, cols) = (height, width). The x axis spans the columns and
    # the y axis spans the rows, so the limits must not be swapped on non-square images.
    h, w = size(image)
    plt = plot(image, axis=false, ticks=false, xlims=(1, w), ylims=(1, h),
               legend=false, title=title)
    # Scatter takes (x, y) = (col, row), so flip each (row, col) tuple.
    scatter!(plt, map(reverse, points), markersize=4)
    return plt
end

NumPoints = 300
Samplers = [(x; N) -> [], sample_random, sample_topN, sample_prob, sample_fastcorners, sample_orb]
Titles = ["Reference", "Random", "Top N", "Probabilistic", "Fastcorners", "ORB Features"]

# One column per sampler ("Reference" samples nothing):
#   top    - points sampled from the raw difference map, drawn over it;
#   bottom - points sampled from the re-normalized blurred map, drawn over the raw blur.
plots = map(zip(Samplers, Titles)) do (sampler, title)
    top = plotpoints(diff, sampler(diff, N=NumPoints), title=title)
    bottom = plotpoints(blurdiffraw, sampler(blurdiff, N=NumPoints))
    plot(top, bottom, layout=(2,1))
end

plot(plots..., layout=grid(1,length(Samplers)), size=(256 * length(Samplers),256 * 2))

### Combinatorial Search

In [None]:
"""
    evaluatesampler(sampler; N, refine = 0, epochs = 3, step = 0.025f0, scale = 200.0)

Search for the best single triangle to add to the global `baseimage`, using candidate
vertices proposed by `sampler`.

Steps:
1. Call `sampler(diff, N = N)` on the global difference map; it must return
   `(row, col)` tuples.
2. Turn every 3-combination of the resulting points into a candidate `Triangle`.
3. Color each candidate with `averagepixel_batch` over the global `target` and score it
   with `drawloss_batch` (`SELoss`) against `baseimage`.
4. Refine for `epochs * refine` rounds. In each round every triangle is perturbed via
   `mutate_batch` with `randn(Float32, ntris, 6) * step`, and each perturbation is
   kept only if it lowers that triangle's loss.

# Keywords
- `N`: number of points requested from `sampler`.
- `refine`: refinement rounds per epoch (`0` disables refinement).
- `epochs`: number of epochs of `refine` rounds each.
- `step`: multiplier on the Float32 Gaussian noise passed to `mutate_batch`.
- `scale`: divisor mapping pixel indices to `Point` coordinates.
  NOTE(review): presumably the side length of a 200x200 target — confirm.

# Returns
`(minloss, triangle, color)` for the lowest-loss candidate.

# Throws
`ArgumentError` if `sampler` yields fewer than 3 points (no triangles can be formed).
"""
function evaluatesampler(sampler ; N, refine=0, epochs=3, step=0.025f0, scale=200.0)
    raster = RasterAlgorithmScanline()

    # Pixel (row, col) indices -> scaled Point coordinates.
    pixels = sampler(diff, N = N)
    points = map(p -> Point(p[1] / scale, p[2] / scale), pixels)
    tris = map(Triangle, combinations(points, 3))
    if isempty(tris)
        throw(ArgumentError(
            "evaluatesampler needs at least 3 sampled points, got $(length(points))"))
    end

    cols = averagepixel_batch(target, tris, raster)
    losses = drawloss_batch(target, baseimage, tris, cols, SELoss(), raster)

    # Greedy per-triangle hill climbing: perturb every candidate, keep improvements.
    for _ in 1:epochs, _ in 1:refine
        noise = randn(Float32, length(tris), 6) * step
        newtris = mutate_batch(tris, noise)
        newcols = averagepixel_batch(target, newtris, raster)
        newlosses = drawloss_batch(target, baseimage, newtris, newcols, SELoss(), raster)
        for j in eachindex(tris, losses, newlosses)
            if newlosses[j] < losses[j]
                losses[j] = newlosses[j]
                tris[j] = newtris[j]
                cols[j] = newcols[j]
            end
        end
    end

    minloss, minidx = findmin(losses)
    return minloss, tris[minidx], cols[minidx]
end

# Run the probabilistic sampler with refinement and report the best loss. Then preview
# the winning triangle on a copy of the base image, painted magenta (not `mincol`) so
# that it stands out.
@time minloss, mintri, mincol = evaluatesampler(sample_prob; N = 300, refine = 100)
println(minloss)
newimage = let canvas = copy(baseimage)
    magenta = RGB{Float32}(1, 0, 1)
    draw!(canvas, mintri, magenta, RasterAlgorithmScanline())
    canvas
end
newimage