In [None]:
using PyCall
using SparseArrays
using Optim
using CUDA
using Zygote
using Plots
using StatsBase
using LinearAlgebra
using DataFrames
using StatsPlots
using ChainRulesCore
using CSV
using Random
using Juliana
using LineSearches
using Statistics
using NPZ

# Seed Julia's default RNG and CUDA's RNG so that random draws in this notebook are
# repeatable between runs.
Random.seed!(7432059)
CUDA.seed!(3875)

# Import Python code

In [None]:
# Python helpers from the `pyftpp` package, plus pydicom and Python's logging module.
# NOTE(review): `standalone` is not referenced anywhere else in this notebook.
standalone = pyimport("pyftpp.standalone")
# NOTE: the global name `dose` is rebound to dose arrays further down; only `Dose`
# (extracted from the module just below) is used after that, so this keeps working.
dose = pyimport("pyftpp.dose");
dicom = pyimport("pyftpp.dicom")
pydicom = pyimport("pydicom")
logging = pyimport("logging")
# Show INFO-level log messages emitted by the Python side.
logging.basicConfig(level="INFO")

# Python classes/functions used directly from Julia.
Dose = dose.Dose;
export_to_dicom = dicom.export_to_dicom;

# Config

In [None]:
# Patient to process and where its data lives / results go.
patient_ID = "test_06"
data_dir = "/data/user/bellotti_r/data";
output_dir = "./output/$patient_ID"
# Clinical plan file; passed to `Juliana.PatientInfo`, from which the field angles are read.
plan_file = "$data_dir/clinical_plans_300_iter/$(patient_ID)_0.json";

In [None]:
# Location of the Fiona standalone binaries and planner JAR.
fiona_standalone_bin_path = "./pyftpp/bin"
fiona_jar_path = "$fiona_standalone_bin_path/ch.psi.ftpp.standalone.planner-1.0.7.jar";

In [None]:
# NOTE(review): neither constant is referenced anywhere else in this notebook; the
# optimisation grid is built with the CT grid spacing (see `get_optimisation_grid`).
optimisation_grid_resolution = 0.17f0;
spot_spacing = 0.4f0;

In [None]:
# NOTE(review): no-op — interpolates `output_dir` into an identical string.
output_dir = "$(output_dir)"

In [None]:
# Create the output directory (including parents) if it does not exist yet.
mkpath(output_dir)

# Load data

In [None]:
# Path to the CT (handed to Fiona later) and the Juliana patient data (CT, structures,
# prescriptions) used throughout the notebook.
ct_path, patient_data = Juliana.load_patient_data(data_dir, patient_ID);

In [None]:
# Per-field beam geometry from the clinical plan: (gantry angle, couch angle, nozzle extraction).
info = Juliana.PatientInfo(data_dir, plan_file, patient_ID)
angles = [(info.plan.gantry_angle(i), info.plan.couch_angle(i), info.plan.nozzle_extraction(i)) for i in info.plan.field_IDs]
# Turn the vector of per-field tuples into a matrix: one row per field, columns
# (gantry angle, couch angle, nozzle extraction).
angles = collect(hcat(collect.(angles)...)')

# Calculate Dij matrix

In [None]:
"""
    build_checker_board_mask(grid)

Return a `Float32` array of size `Tuple(grid.size)` filled like a 3-D checker board:
voxel `(ix, iy, iz)` holds `1` when `ix + iy + iz` is even and `0` otherwise, so
neighbouring voxels along any axis always differ. `grid.size` is expected to have
three entries.
"""
function build_checker_board_mask(grid)
    nx, ny, nz = grid.size[1], grid.size[2], grid.size[3]
    board = zeros(Float32, Tuple(grid.size))
    # `ix` is the innermost loop, matching Julia's column-major layout.
    for iz in 1:nz, iy in 1:ny, ix in 1:nx
        # The Bool is stored as 1f0 / 0f0.
        board[ix, iy, iz] = iseven(ix + iy + iz)
    end
    return board
end



In [None]:
# Voxel-wise minimum of `structures[name].distanceFromStructure` over all `names`
# (`Inf32` everywhere when `names` is empty).
function _minimum_distance_map(grid, structures, names)
    distance = fill(Inf32, Tuple(grid.size))
    for name in names
        distance .= min.(distance, structures[name].distanceFromStructure)
    end
    return distance
end

"""
    get_optimisation_points_from_prescription(grid, prescriptions, structures; interest_distance=2)

Select the voxels of `grid` that serve as optimisation points.

A voxel is selected when all of the following hold:
1. it lies within `interest_distance` of a structure named in `prescriptions` (a target in
   `prescriptions.target_doses` or the structure of any entry in `prescriptions.constraints`);
2. it lies within `interest_distance` of a target;
3. it is set in `build_checker_board_mask(grid)`, which keeps every other voxel.

Every target is also part of the set in (1), so (2) implies (1): constrained non-target
structures do not change the selection. They are still looked up in `structures`, so a
constraint naming an unknown structure raises a `KeyError`. Distances are compared in the
units of `distanceFromStructure` (TODO confirm units).

# Returns
- `optimisation_roi_mask`: `BitArray` of size `Tuple(grid.size)` marking the selected voxels.
- `optimisation_points`: `N×3` matrix with the `Juliana.index_to_xyz` coordinates of each
  selected voxel, one row per point.
- `optimisation_point_indices`: `N×3` `Int` matrix of voxel indices `(ix, iy, iz)`, in the
  same row order (column-major order of the grid).

# Throws
- `ArgumentError` if no voxel satisfies all three conditions.
"""
function get_optimisation_points_from_prescription(grid, prescriptions, structures; interest_distance=2)
    target_names = [name for (name, _) in prescriptions.target_doses]

    roi_structure_names = Set{String}(target_names)
    for constraint in prescriptions.constraints
        push!(roi_structure_names, constraint.structure_name)
    end

    distance_from_roi = _minimum_distance_map(grid, structures, roi_structure_names)
    distance_from_targets = _minimum_distance_map(grid, structures, target_names)

    optimisation_roi_mask = distance_from_roi .<= interest_distance
    optimisation_roi_mask .&= distance_from_targets .<= interest_distance
    # The checker board is a 0/1 Float32 array; compare explicitly rather than relying
    # on an implicit Float32 -> Bool conversion.
    optimisation_roi_mask .&= build_checker_board_mask(grid) .!= 0

    voxel_indices = findall(optimisation_roi_mask)
    if isempty(voxel_indices)
        throw(ArgumentError(
            "no optimisation points lie within interest_distance=$interest_distance of a target",
        ))
    end

    # N×3 Int matrix, one (ix, iy, iz) row per selected voxel. Built directly instead of
    # splatting N row matrices into `vcat`, which is slow for large N.
    optimisation_point_indices = [Tuple(index)[axis] for index in voxel_indices, axis in 1:3]

    # `reduce(hcat, ...)` gives one column per point; transpose to one row per point.
    point_coordinates = [
        Juliana.index_to_xyz(index, grid) for index in eachrow(optimisation_point_indices)
    ]
    optimisation_points = permutedims(reduce(hcat, point_coordinates))

    return optimisation_roi_mask, optimisation_points, optimisation_point_indices
end

## Calculate Dij using Fiona for the spot placement

In [None]:
# Select the optimisation points on the CT grid. `optimisation_points_before` holds their
# coordinates as computed in Julia; they are handed to Fiona below and compared with the
# points Fiona returns.
optimisation_mask, optimisation_points_before, optimisation_point_indices = get_optimisation_points_from_prescription(
    patient_data.ct.grid,
    patient_data.prescriptions,
    patient_data.structures,
);

In [None]:
"""
    get_optimisation_grid(optimisation_points, grid)

Return a `Juliana.Grid` with the spacing of `grid` that encloses `optimisation_points`
(one point per row, one coordinate per column).

The origin is the per-column minimum of the points. Along each axis the grid has
`ceil((max - min) / spacing) + 1` voxels, so the per-column maximum is covered.
"""
function get_optimisation_grid(optimisation_points, grid)
    spacing = grid.spacing
    origin = vec(minimum(optimisation_points; dims=1))
    extent = vec(maximum(optimisation_points; dims=1)) .- origin
    n_voxels = ceil.(Int64, extent ./ spacing) .+ 1
    return Juliana.Grid(spacing, origin, n_voxels)
end

# Bounding-box grid around the optimisation points, using the CT voxel spacing.
optimisation_grid = get_optimisation_grid(optimisation_points_before, patient_data.ct.grid)

In [None]:
# Target selected by `Juliana.coldest_target` (presumably the one with the lowest
# prescribed dose — confirm). Its structure is handed to Fiona below; the returned dose
# is not used further in this notebook.
coldest_target_name, coldest_target_dose = Juliana.coldest_target(patient_data.prescriptions)

In [None]:
# Dummy value; we're not optimising with Fiona.
target_dose = 1

# Fiona standalone places the spots for the given fields (columns of `angles`: gantry,
# couch, nozzle extraction). It returns `Dij`, which maps spot weights to dose
# (`Dij * w`), together with the optimisation points it used.
Dij, optimisation_points = Juliana.FionaStandalone.calculate_Dij(
    output_dir,
    ct_path,
    target_dose,
    patient_data.structures[coldest_target_name],
    fiona_standalone_bin_path,
    fiona_jar_path,
    optimisation_grid,
    angles[:, 1],
    angles[:, 2],
    angles[:, 3],
    debugging=false,
    optimization_points=optimisation_points_before,
);

In [None]:
# Display: number of optimisation points (rows) and coordinates (columns).
size(optimisation_points)

In [None]:
# Display: (number of optimisation points, number of spots), assuming rows = points.
size(Dij)

## Inspect Fiona's plan and verify the optimisation points

In [None]:
# Spot plan found in `output_dir` (presumably written by `calculate_Dij` above). The
# DICOM-export section re-reads the same file, so this binding is only for inspection.
plan = Juliana.FionaStandalone.read_plan_file("$output_dir/result_plan.json");

In [None]:
# Save the optimisation points selected in Julia (one row per point).
npzwrite("$(output_dir)/optimisation_points_juliana.npy", optimisation_points_before)

In [None]:
# Sanity check: Fiona must have used (up to floating-point tolerance) the points we passed in.
@assert optimisation_points_before ≈ optimisation_points

# Optimise

In [None]:
# Bundle CT, prescriptions, structures, the Fiona Dij and the optimisation point indices
# into Juliana's optimisation configuration.
config = Juliana.get_optimisation_configuration(
    patient_data.ct,
    patient_data.prescriptions,
    patient_data.structures,
    Dij,
    optimisation_point_indices,
);

In [None]:
# Initial spot weights: all ones, rescaled so that the mask-weighted mean dose over the
# normalisation structure equals `config.normalisationDose`, then moved to the GPU.
w = ones(size(Dij, 2))
mean_dose = sum(collect((config.Dij * cu(w))) .* collect(config.normalisationStructureMask)) / sum(config.normalisationStructureMask)
w *= config.normalisationDose / mean_dose;
w = cu(w);

In [None]:
# Weights of the individual loss terms. The global terms all get weight 1.
subloss_weights = Dict{String, Float32}(
    "ideal_dose_loss" => 1.f0,
    "maximum_loss" => 1.f0,
    "minimum_loss" => 1.f0,
    "normalisation_variance" => 1.f0,
)

# Every constraint whose priority is not `Juliana.soft` adds a per-structure term with
# weight 1: "<structure>_mean_loss" for mean constraints, "<structure>_max_loss" for
# maximum constraints. Other constraint kinds add nothing.
for constraint in config.prescriptions.constraints
    if constraint.priority == Juliana.soft
        continue
    end

    if constraint.kind == Juliana.constraint_mean
        subloss_weights["$(constraint.structure_name)_mean_loss"] = 1f0
    elseif Juliana.is_maximum_constraint(constraint)
        subloss_weights["$(constraint.structure_name)_max_loss"] = 1f0
    end
end

## Loss and gradient diagnostics

In [None]:
# List the available methods of the loss functions (interactive inspection).
methods(Juliana.dose_loss!)

In [None]:
methods(Juliana.loss!)

In [None]:
# Time one evaluation of the dose-based loss (the first call includes compilation).
@time Juliana.dose_loss!(config.Dij * w, config, Dict{String, Float32}(), subloss_weights)

In [None]:
# Time one evaluation of the weight-based loss and its gradient.
@time Juliana.loss!(w, config, Dict{String, Float32}(), subloss_weights);
@time Juliana.loss_gradient(w, config, subloss_weights);

In [None]:
# Type-stability check of the loss, called exactly like the timed call above and like the
# optimiser calls it: spot weights `w`, a loss-parts dict, and the subloss weights.
@code_warntype Juliana.loss!(w, config, Dict{String, Float32}(), subloss_weights)

In [None]:
# Type-stability check of the gradient, with the same argument convention.
@code_warntype Juliana.loss_gradient(w, config, subloss_weights)

## Optimisation loop

In [None]:
"""
Stops the iteration if the function value has not decreased by more than delta in the last patience iterations.
"""
function build_early_stopping(delta::T, patience) where {T}
    # Best value seen so far and the iteration it was recorded at. The best value starts
    # at typemax(T), so for floating-point T any finite first value counts as an improvement.
    best_value = typemax(T)
    best_iteration = 1

    # An improvement is a drop of at least `delta` below the best value (note `<=`). The
    # callback returns `true` once more than `patience` iterations have passed since the
    # last improvement. NaN values never count as improvements.
    function should_stop(value, iteration)
        improved = value <= best_value - delta
        if improved
            best_value, best_iteration = value, iteration
        end
        iterations_since_improvement = iteration - best_iteration
        return iterations_since_improvement > patience
    end

    return should_stop
end

In [None]:
"""
    optimise_using_optim(w, config, subloss_weights; n_iterations=1000,
                         early_stopping_delta=0.5f0, early_stopping_patience=25)

Minimise `Juliana.loss!` over the spot weights with L-BFGS (Hager-Zhang line search).

Optim's iteration loop is driven by hand (`update_state!`, `update_g!`, `update_h!`)
instead of calling `Optim.optimize`, so that the sublosses can be logged after every
iteration (see https://github.com/JuliaNLSolvers/Optim.jl/issues/1024).

Non-negativity is enforced by clamping the iterate in place to `[0, typemax(T)]` before
every loss and gradient evaluation. This is a projection heuristic, not a
bound-constrained solver.

# Arguments
- `w`: initial spot weights.
- `config`: Juliana optimisation configuration (Dij, prescriptions, structures, ...).
- `subloss_weights`: weight of each loss term, keyed by term name.

# Keywords
- `n_iterations`: maximum number of L-BFGS iterations.
- `early_stopping_delta`, `early_stopping_patience`: passed to `build_early_stopping`.
  The loop stops once the logged loss has not dropped by at least `delta` for more than
  `patience` iterations.

# Returns
- `w_opt`: the final iterate (Optim's state array), clamped to be non-negative.
- `history`: one `Dict{String, T}` per completed iteration, holding the sublosses
  reported by `Juliana.loss!` plus `"total_loss"`.
- `gradients`: the loss gradient at each logged iterate, stored as `Vector{T}`.

Prints the total loss after every iteration.
"""
function optimise_using_optim(
    w::AbstractArray{T, N},
    config::Juliana.OptimisationConfiguration,
    subloss_weights::Dict{String, T};
    n_iterations::Integer=1000,
    early_stopping_delta::Real=0.5f0,
    early_stopping_patience::Integer=25,
) where {T, N}
    lower, upper = zero(T), typemax(T)

    # Objective and in-place gradient for Optim. They only use their own locals. Assigning
    # to a name that also exists in this function (e.g. `loss`) would share and box that
    # variable between the callbacks and the logging loop below.
    function objective(x)
        clamp!(x, lower, upper)
        return Juliana.loss!(x, config, Dict{String, T}(), subloss_weights)
    end

    function objective_gradient!(storage, x)
        clamp!(x, lower, upper)
        copyto!(storage, Juliana.loss_gradient(x, config, subloss_weights))
        return storage
    end

    should_stop = build_early_stopping(float(early_stopping_delta), early_stopping_patience)

    alg = LBFGS(linesearch=LineSearches.HagerZhang())
    options = Optim.Options()
    d = Optim.promote_objtype(alg, w, :finite, true, objective, objective_gradient!)
    state = Optim.initial_state(alg, options, d, w)

    history = Vector{Dict{String, T}}()
    gradients = Vector{Vector{T}}()
    # Loss logged at the end of the previous iteration (sentinel before the first one).
    loss = typemax(T)
    for iteration in 1:n_iterations
        if should_stop(loss, iteration)
            break
        end

        # One L-BFGS step: new iterate, its gradient, and the quasi-Newton update.
        Optim.update_state!(d, state, alg)
        Optim.update_g!(d, state, alg)
        Optim.update_h!(d, state, alg)

        # Project the iterate onto non-negative weights and log the sublosses there.
        w_current = state.x
        clamp!(w_current, lower, upper)
        loss_parts = Dict{String, T}()
        loss = Juliana.loss!(w_current, config, loss_parts, subloss_weights)
        println(loss)
        loss_parts["total_loss"] = loss
        push!(history, copy(loss_parts))
        push!(gradients, Juliana.loss_gradient(w_current, config, subloss_weights))
    end

    w_opt = clamp!(state.x, lower, upper)
    return w_opt, history, gradients
end

In [None]:
# Wall-clock start of the optimisation, in seconds. Start and stop are taken in separate
# cells, so the measured time also includes any delay between running those cells.
start = time()

In [None]:
w_opt, history, gradients = optimise_using_optim(w, config, subloss_weights);

In [None]:
stop = time()

In [None]:
# Record the elapsed wall-clock time, e.g. "123.4s".
open("$(output_dir)/optimisation_time.txt", "w") do file
    write(file, "$(stop - start)s")
end

In [None]:
# Rescale the weights so that the mean dose over the normalisation structure (as computed
# by `Juliana.mean_dose`) equals `config.normalisationDose`.
w_opt .*= (config.normalisationDose / Juliana.mean_dose(config.Dij * w_opt, config.normalisationStructureMask));

In [None]:
# Keep the weights non-negative and make a CPU copy for saving and export.
w_opt = clamp!(w_opt, 0., typemax(Float32));
w_opt_cpu = collect(w_opt);

In [None]:
# Logged gradients as a matrix with one row per optimiser iteration.
npzwrite("$output_dir/gradients.npy", collect(hcat(gradients...)'))

In [None]:
npzwrite("$output_dir/w_opt.npy", w_opt_cpu)

In [None]:
# Dose at the optimisation points from Fiona's Dij and the CPU copy of the weights.
dose_cpu = Dij * w_opt_cpu;

In [None]:
maximum(dose_cpu)

In [None]:
# Dose at the optimisation points from the optimiser's Dij, copied to the CPU.
# NOTE: this rebinds the global `dose`, which previously held the `pyftpp.dose` module.
dose = collect(config.Dij * w_opt);
npzwrite("$output_dir/dose_at_optimisation_points_from_Dij.npy", dose)

In [None]:
# Scatter the dose at the optimisation points (the points Fiona returned) into a
# CT-sized array; all other voxels stay zero.
dose_matrix = zeros(Float32, Tuple(patient_data.ct.grid.size));

for (i, p) in enumerate(eachrow(collect(optimisation_points)))
    indices = Juliana.xyz_to_index(p, patient_data.ct.grid)
    dose_matrix[indices...] = dose[i]
end

# NOTE: despite its name, `dose_fiona` holds the dose from the Julia-optimised weights
# (via `config.Dij`), mapped onto the CT grid.
dose_fiona = Dose(dose_matrix, patient_data.ct.grid.spacing, patient_data.ct.grid.origin)
dose_fiona.save("$output_dir/dose_on_optimisation_grid_mapped_to_ct_grid.dat")

In [None]:
# The normalisation mask wrapped as a Dose on the CT grid, so it can be exported as an
# overlay alongside the real dose distributions.
normalisation_mask_overlay = Dose(
    convert.(Float32, collect(Juliana.calculate_normalisation_mask(patient_data.prescriptions, patient_data.structures))),
    patient_data.ct.grid.spacing,
    patient_data.ct.grid.origin,
)

## Analysis

### Loss

In [None]:
# One DataFrame row per logged optimiser iteration (assumes every iteration reported the
# same set of subloss names, which `vcat` requires).
# Code taken from: https://stackoverflow.com/a/54170025
loss_df = vcat(DataFrame.(history)...);

In [None]:
# CSV.write("$output_dir/$patient_ID/losses_evaluations.csv", loss_df)
CSV.write("$output_dir/losses.csv", loss_df)

In [None]:
# Sublosses after the first logged iteration.
loss_df[1, :]

In [None]:
# Sublosses after the last logged iteration.
loss_df[end, :]

In [None]:
plot(loss_df[!, "maximum_loss"])

In [None]:
# Plot every logged subloss (one line per column) against the iteration.
@df loss_df plot(cols(propertynames(loss_df)))

### Gradient

In [None]:
# Gradient of the loss at the final (normalised) weights, for inspection.
grad = Juliana.loss_gradient(w_opt, config, subloss_weights);

In [None]:
histogram(collect(grad))

### DVH curves

In [None]:
# Dose at the optimisation points and the normalisation-structure mask.
dose = config.Dij * w_opt;
mask = config.normalisationStructureMask;

# Compare one `dvh_d` call per volume level with the vectorised call.
# NOTE: `doses` is rebound to a Dict of Dose objects in the DICOM-export section.
volumes = collect(LinRange(0.f0, 100.f0, 401));
@time doses = [Juliana.dvh_d(dose, mask, v) for v in volumes];
@time dose_values = Juliana.dvh_d(dose, mask, volumes)

# NOTE(review): `volumes[end:-1:1]` equals `100 .- volumes` for this symmetric LinRange;
# whether that flip is correct depends on `dvh_d`'s volume convention — confirm.
plot(dose_values * 100 / config.normalisationDose, volumes[end:-1:1], marker=2, xlims=(75, 110))

In [None]:
# Rescale the dose so its mean over the normalisation structure equals the normalisation
# dose, then evaluate `dvh_v` at 121 dose levels from 0 to 120 % of that dose.
# This normalised `dose` is what the DVH-metrics cell below uses.
dose = config.Dij * w_opt;
mask = config.normalisationStructureMask;
dose *= config.normalisationDose / Juliana.mean_dose(dose, mask)

dose_values = collect(LinRange(0.f0, 1.2f0 * config.normalisationDose, 121))
@time volumes = Juliana.dvh_v(dose, mask, dose_values)

plot(dose_values * 100 / config.normalisationDose, volumes, marker=2, xlim=(75, 110))

### DVH metrics

In [None]:
# Keep columns that are not always zero.
active_losses = loss_df[:, .!all.(eachcol(loss_df .== 0))];
# Final value of every active subloss, sorted by value.
final_losses = sort(Dict(names(active_losses[end, :]) .=> values(active_losses[end, :])); byvalue=true)

In [None]:
# Target coverage on the normalised dose at the optimisation points (the global `dose`
# from the DVH-curve cell above). Both thresholds use the same vector form of
# `Juliana.dvh_v` (one dose level in, one volume out), so V95 and V80 are scalars in
# the same units.
for (name, target_dose) in config.prescriptions.target_doses
    mask = config.structures[name]
    V95 = Juliana.dvh_v(dose, mask, [0.95f0 * target_dose])[1]
    V80 = Juliana.dvh_v(dose, mask, [0.80f0 * target_dose])[1]

    println("V95 $name = $(V95)%")
    println("V80 $name = $(V80)%")
end

In [None]:
config.normalisationDose

# Save the results to DICOM

In [None]:
# Write the optimised spot weights into Fiona's plan file and call Fiona standalone to
# calculate the full dose distribution.
w_opt_cpu = collect(w_opt)
plan = Juliana.FionaStandalone.read_plan_file("$output_dir/result_plan.json");

In [None]:
# Overwrites result_plan.json in place with the optimised weights.
new_plan = Juliana.FionaStandalone.update_spot_weights(plan, w_opt_cpu)
Juliana.FionaStandalone.write_plan_config("$output_dir/result_plan.json", new_plan);

In [None]:
# Reset the CUDA device so that Fiona standalone can use the GPU (assumption, not
# verified: Fiona cannot use it while this session holds it). Per CUDA.jl, this releases
# the device context, so GPU arrays created earlier (e.g. `w`, `w_opt`) become invalid.
CUDA.device_reset!()
# Short pause before launching Fiona (heuristic).
sleep(1)
# Recalculate the dose with the updated plan; the output is presumably the
# `result_dose.dat` loaded in the DICOM-export section.
Juliana.FionaStandalone.run_dose_calculation(fiona_jar_path, output_dir, false, false)

In [None]:
# Ideal dose distribution from the target prescriptions (`importance` is not used further
# in this notebook).
ideal_dose_distribution, importance = Juliana.calculate_ideal_dose_distribution(
    patient_data.ct,
    patient_data.prescriptions.target_doses,
    patient_data.structures,
);

In [None]:
# Dose distributions to export, keyed by their name in the DICOM output.
doses = Dict(
    "reference" => Dose.load("$data_dir/clinical_dose_distributions/$(patient_ID)_0.dat"),
    # "ideal" => Dose(ideal_dose_distribution, info.patient.ct.spacing, info.patient.ct.origin),
    "ideal" => Dose(ideal_dose_distribution, patient_data.ct.grid.spacing, patient_data.ct.grid.origin),
)


normalisation_mask = convert.(Float32, Juliana.calculate_normalisation_mask(
    patient_data.prescriptions,
    patient_data.structures,
))

# Normalise each distribution to `config.normalisationDose` over the normalisation mask.
# Only existing keys are overwritten, so no entries are added while iterating.
for (name, dose) in doses
    normalised_dose = Juliana.normalise_dose(
        dose.data,
        normalisation_mask,
        convert(Float32, config.normalisationDose),
    )
    doses[name] = Dose(normalised_dose, dose.spacing, dose.origin);
end

In [None]:
# TODO: Change in future. Placeholder patient ID for the exported prototype plan; the
# study UID is derived from it (pydicom's `generate_uid` with `entropy_srcs`), so
# re-running the export gives the same UID.
new_patient_ID = "new_optimiser_prototype_$patient_ID"
study_instance_UID = pydicom.uid.generate_uid(entropy_srcs=[new_patient_ID])

In [None]:
# DICOM files go to a "DICOM/<patient_ID>" folder next to the per-patient output folder.
dicom_output_dir = "$output_dir/../DICOM/$patient_ID"
mkpath(dicom_output_dir)

In [None]:
# pyftpp classes used to build Python-side CT and structure objects for the DICOM export.
pyftpp_CT = pyimport("pyftpp.ct").CT;
pyftpp_structure_module = pyimport("pyftpp.structure")
pyftpp_Structure = pyftpp_structure_module.Structure
pyftpp_StructureSet = pyftpp_structure_module.StructureSet

ct_for_export = pyftpp_CT(
    patient_data.ct.data,
    patient_data.ct.grid.spacing,
    patient_data.ct.grid.origin,
)

# One pyftpp Structure per Julia structure (built from its `points`), bundled with the CT.
# NOTE(review): `export_to_dicom` is called with `info.patient.structures`, not with this
# object — confirm which structure set is intended.
structures_for_export = pyftpp_StructureSet(
    [pyftpp_Structure(name, structure.points) for (name, structure) in patient_data.structures],
    ct_for_export,
)

In [None]:
# Additional distributions to export next to the reference and ideal doses.
# NOTE(review): "fiona_during_optimisation" and "new_during_optimisation" both store the
# same `dose_fiona` object — confirm this duplication is intended.
doses["fiona_during_optimisation"] = dose_fiona
doses["normalisation_mask"] = normalisation_mask_overlay
doses["minimum"] = Dose(
    convert.(Float32, Juliana.calculate_minimum_dose_distribution(
        patient_data.prescriptions,
        patient_data.structures,
        ideal_dose_distribution,
    )),
    patient_data.ct.grid.spacing,
    patient_data.ct.grid.origin,
)
doses["new_during_optimisation"] = dose_fiona
# Presumably written by Fiona's `run_dose_calculation` above.
doses["new_recalculated"] = Dose.load("$output_dir/result_dose.dat")

In [None]:
# Export CT, structures and all collected dose distributions as DICOM.
# NOTE(review): the structures come from `info.patient.structures`; the
# `structures_for_export` object built above is not used here — confirm.
export_to_dicom(
    ct_for_export,
    info.patient.structures,
    dicom_output_dir,
    study_instance_UID,
    new_patient_ID,
    doses,
    decrease_precision=true,
)

In [None]:
# Sublosses of the second-to-last logged iteration.
loss_df[end-1, :]