In [None]:
using NCDatasets, Flux, Distributed, DataFrames, CSV, StatsPlots
using Statistics: mean 
#Random, StatsBase
#import YAML
using DelimitedFiles: readdlm

using Measures, Plots

In [None]:
# Select GR as the plotting backend for the plots produced below.
gr()

In [None]:
# Start three extra worker processes, each launched with the `--project` flag.
addprocs(3; exeflags="--project")

In [None]:
# Change the working directory on every process, so the relative file names
# used in the following cells refer to the same run directory everywhere.
@everywhere cd("/home/ebr/projects/tsunami-inundation-emulator/article_runs_X/t457/mc8_l8_rel")

In [None]:
# Show the current working directory as a sanity check.
pwd()

In [None]:
# List the contents of the current working directory.
readdir()

In [None]:
# Per-process setup, evaluated on every process.
@everywhere begin
    using Logging

    # Log records at Info level and above are written to stdout.
    global_logger(SimpleLogger(stdout, Logging.Info))

    # Project-local sources; `DataReader` is used right below.
    for source_file in ("datareader.jl", "model_config.jl")
        include(source_file)
    end

    config = DataReader.parse_config("config.yml")

    # Session-specific overrides of the parsed configuration.
    config["gpu"] = false
    config["batch_size"] = 10
end

In [None]:
# Read the tab-separated Boolean mask from disk and keep it as a BitArray.
const ct_mask = readdlm("ct_mask.txt", '\t', Bool, '\n') |> BitArray;

In [None]:
# Number of `true` entries in the mask.
count(ct_mask)

In [None]:
# Data set to evaluate (taken from the "train_data" config entry) and the
# directory under the run directory where the evaluation output is written.
dataset = config["train_data"]
eval_dir = joinpath(config["rundir"], "evaluation", "train")

In [None]:
# Load the model stored as "<model_name>.jls" in the run directory, then
# reset its state via Flux.
model = ModelConfig.load(joinpath(config["rundir"], string(config["model_name"], ".jls")))
Flux.reset!(model)

In [None]:
# Load batches on workers.
# Every process builds its own reader from the shared configuration.
@everywhere begin
    reader = DataReader.Reader(config)
end

# Channel (capacity 4) through which the workers hand batches to this process.
batches = RemoteChannel(() -> Channel(4))
scenarios = DataReader.scenarios(dataset)

# Ask each worker to run the reader on the scenarios, writing into `batches`.
for worker in workers()
    remote_do(reader, worker, scenarios, batches)
end

# Number of batches needed to cover the data set. `cld` is integer ceiling
# division, so the result is an Int (not a Float64 as with `ceil(a / b)`) and
# `1:nr_of_batches` iterates over integers.
nr_of_batches = cld(countlines(dataset), config["batch_size"])

In [None]:
"""
    accumulate_targets(batches, nr_of_batches, mask)

Take `nr_of_batches` batches from `batches` and return `(mean_target, hits)`,
both vectors with one entry per `true` element of `mask`:

- `mean_target`: mean of `batch.flow_depths[mask, 1, :]` over all samples seen.
- `hits`: number of samples in which that value is strictly positive.

Sums and sample counts are accumulated (instead of averaging per-batch means),
so a final batch that is smaller than the others is weighted correctly.
The last index of `flow_depths` is assumed to enumerate the samples of the
batch (TODO confirm against `DataReader`).
"""
function accumulate_targets(batches, nr_of_batches, mask)
    n_masked = count(mask)
    sum_target = zeros(Float64, n_masked)
    hits = zeros(Int32, n_masked)
    n_samples = 0

    for batch_nr in 1:nr_of_batches
        batch = take!(batches)
        @info "Batch: $batch_nr"

        # One row per masked element, one column per sample in the batch.
        target = batch.flow_depths[mask, 1, :]

        sum_target .+= vec(sum(target, dims=2))
        hits .+= vec(count(>(0), target, dims=2))
        n_samples += size(target, 2)
    end

    return sum_target ./ n_samples, hits
end

@info "First pass. Computing mean and hits."
# No model evaluation is needed here: only the targets enter the mean and the
# hit count.
mean_target, count_hits = accumulate_targets(batches, nr_of_batches, ct_mask);

In [None]:
# Scatter the per-pixel hit counts back onto the full grid; elements outside
# the mask stay at zero.
hit_map = zeros(Float32, config["dims"]);
hit_map[ct_mask] = count_hits

# Show the hit counts as a map (transposed for plotting).
let dims = config["dims"]
    heatmap(
        transpose(hit_map[:, :]);
        margins=3mm,
        xlim=(1, dims[1]),
        ylim=(1, dims[2]),
        aspect_ratio=1.0,
    )
end

In [None]:
"""
    accumulate_squares(batches, nr_of_batches, mask, model, mean_target)

Take `nr_of_batches` batches from `batches` and return
`(mean_residual_squares, mean_total_squares)`, both vectors with one entry per
`true` element of `mask`:

- residual squares: `(target - prediction)^2`,
- total squares: `(target - mean_target)^2`,

each averaged over all samples seen. The prediction is
`relu.(model(batch.etas) .- batch.deformed_topographies[mask, 1, :])`.

`mean_target` is broadcast against the targets, so batches of any size are
accepted (including a final batch smaller than `config["batch_size"]`), and
sums plus a sample count are accumulated so every sample has equal weight.
"""
function accumulate_squares(batches, nr_of_batches, mask, model, mean_target)
    n_masked = count(mask)
    sum_residual = zeros(Float64, n_masked)
    sum_total = zeros(Float64, n_masked)
    n_samples = 0

    for batch_nr in 1:nr_of_batches
        batch = take!(batches)
        @info "Batch: $batch_nr"

        # One row per masked element, one column per sample in the batch.
        target = batch.flow_depths[mask, 1, :]
        # Activation functions are applied element-wise via broadcasting.
        preds = relu.(model(batch.etas) .- batch.deformed_topographies[mask, 1, :])

        sum_residual .+= vec(sum(abs2, target .- preds, dims=2))
        sum_total .+= vec(sum(abs2, target .- mean_target, dims=2))
        n_samples += size(target, 2)
    end

    return sum_residual ./ n_samples, sum_total ./ n_samples
end

@info "Second pass. Computing R-squared."
mean_residual_squares, mean_total_squares =
    accumulate_squares(batches, nr_of_batches, ct_mask, model, mean_target);

In [None]:
# Coefficient of determination per masked element:
# one minus the ratio of mean residual squares to mean total squares.
rsquared = @. 1 - mean_residual_squares / mean_total_squares

In [None]:
# Full-grid Float32 map, NaN everywhere except at the masked elements, which
# receive the R-squared values.
rsquared_map = fill(NaN32, config["dims"]);
rsquared_map[ct_mask] = rsquared

In [None]:
# Map of R-squared (transposed for plotting), colour scale clipped to [0, 1].
let dims = config["dims"]
    heatmap(
        transpose(rsquared_map);
        c=cgrad([:red, :orange, :green], [0.5, 0.95, 0.98]),
        clim=(0, 1),
        xlim=(1, dims[1]),
        ylim=(1, dims[2]),
        aspect_ratio=1.0,
        margins=3mm,
    )
end

In [None]:
# Show the output directory before writing to it.
eval_dir

In [None]:
# Write the R-squared map and the hit-count map to a new NetCDF file.
# The hit counts are stored alongside R-squared because the cells that read
# "r_squared.nc" back expect a "hit-count" variable in the same file.
mkpath(eval_dir)  # make sure the output directory exists
Dataset(joinpath(eval_dir, "r_squared.nc"), "c") do ds
    defDim(ds, "x", config["dims"][1])
    defDim(ds, "y", config["dims"][2])

    r = defVar(ds, "r-squared", Float32, ("x", "y"))
    r[:, :] = rsquared_map

    # `hit_map` holds whole-number counts stored as Float32; save as integers.
    h = defVar(ds, "hit-count", Int32, ("x", "y"))
    h[:, :] = round.(Int32, hit_map)
end

# Plot R-squared against hits.

In [None]:
# Read from file
model_folder = "/mnt/NGI_disks/ebr/P/2022/01/20220127/Calculations/AP3/models/tsunami-inundation-emulator/article_runs/t591/mc8_l8_rel"

# Each file is opened with a do-block so it is closed again even if reading
# fails, and the variables are copied into memory (`[:, :]`) so they can be
# indexed with the Boolean mask after the files are closed.
r_square_test, hits_test = NCDataset(
    joinpath(model_folder, "evaluation/test", "r_squared.nc"), "r"
) do ds
    ds["r-squared"][:, :], ds["hit-count"][:, :]
end

hits_train = NCDataset(
    joinpath(model_folder, "evaluation/train", "r_squared.nc"), "r"
) do ds
    ds["hit-count"][:, :]
end;

In [None]:
# Left panel: R-squared on the test data, colour scale clipped to [0, 1].
p1 = heatmap(
    transpose(r_square_test[:, :]);
    c=cgrad([:red, :orange, :green], [0.5, 0.95, 0.98]),
    clim=(0, 1),
    xlim=(1, config["dims"][1]),
    ylim=(1, config["dims"][2]),
    aspect_ratio=1.0,
    margins=3mm,
)

# Right panel: hit counts from the training data.
p2 = heatmap(
    transpose(hits_train[:, :]);
    xlim=(1, config["dims"][1]),
    ylim=(1, config["dims"][2]),
    aspect_ratio=1.0,
    margins=3mm,
)

# Show both panels side by side.
plot(p1, p2; layout=(1, 2))

In [None]:
# One row per masked element: hit counts (train and test) and test R-squared.
df = DataFrame(;
    hits_train=hits_train[ct_mask],
    hits_test=hits_test[ct_mask],
    r2=r_square_test[ct_mask],
);
# Keep only rows with more than 5 test hits and R-squared above -0.2.
df_filtered = filter([:hits_test, :r2] => (hits, r2) -> hits > 5 && r2 > -0.2, df);

In [None]:
# 2-D histogram: test R-squared (x) against training hit count (y).
@df df_filtered histogram2d(
    :r2,
    :hits_train,
    # binning and colours
    bins=(30, 30),
    show_empty_bins=false,
    color=cgrad(:amp, 16, categorical=true),
    alpha=0.8,
    colorbar_title="\n Number of pixels on map (colour)",
    colorbar_titlefontsize=9,
    # axes
    xlims=(-0.2, 1.0),
    ylims=(0, 600),
    aspect_ratio=1.5/300,
    xlabel="\$r^2\$ on test data",
    ylabel="Number of training scenarios which inundates pixel",
    labelfontsize=9,
    tickfontsize=7,
    xrotation=90,
    # grid
    minorgrid=true,
    gridstyle=:dash,
    gridlinewidth=2.0,
    gridlinealpha=1.0,
)

In [None]:
# Same histogram with the axes swapped: training hit count (x) against
# test R-squared (y).
@df df_filtered histogram2d(
    :hits_train,
    :r2,
    # binning and colours
    bins=(30, 30),
    show_empty_bins=false,
    color=cgrad(:amp, 16, categorical=true),
    alpha=0.8,
    # axes
    xlims=(0, 600),
    ylims=(-0.2, 1.0),
    xlabel="Hits in training set.",
    ylabel="R-squared",
    # grid
    minorgrid=true,
    gridstyle=:dash,
    gridlinewidth=2.0,
    gridlinealpha=1.0,
)

In [None]:
# Save the current figure as a PDF under the model folder.
savefig(joinpath(model_folder, "evaluation/test-train", "R2_hits_hist_transpose.pdf"))

In [None]:
# Scatter plot of test R-squared against training hit count for (at most) the
# first 5000 filtered rows. `first` returns fewer rows when the table is
# shorter, where `df_filtered[1:5000, :]` would throw a BoundsError.
@df first(df_filtered, 5000) scatter(
    :hits_train,  # the table has no `:hits` column; training hits are meant
    :r2,
    # Labels follow the plotted columns: x is the hit count, y is R-squared.
    xlabel="Hits in training set.",
    ylabel="R-squared",
    markershape=:circle,
    markersize=1.5,
    alpha=0.1,
    ylims=(-1, 1),
    dpi=300,
)