Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 12 additions & 12 deletions src/CustomLossFunction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,9 @@ function _sigmoid(ŷ::Matrix{T}, y::T) where {T<:AbstractFloat}
end;

function _leaky_relu(ŷ::Matrix{T}, y::T) where {T<:AbstractFloat}
return min.(0.001 .* (y .- ŷ) .+ 1., leakyrelu.((y .- ŷ) .* 10, 0.001))
return min.(0.001 .* (y .- ŷ) .+ 1.0, leakyrelu.((y .- ŷ) .* 10, 0.001))
end;


"""
ψₘ(y, m)

Expand Down Expand Up @@ -105,7 +104,7 @@ function adaptative_block_learning(nn_model, data, hparams)
@showprogress for epoch in 1:(hparams.epochs)
loss, grads = Flux.withgradient(nn_model) do nn
aₖ = zeros(hparams.K + 1)
for i in 1:hparams.samples
for i in 1:(hparams.samples)
x = rand(hparams.transform, hparams.K)
yₖ = nn(x')
aₖ += generate_aₖ(yₖ, data.data[i])
Expand Down Expand Up @@ -149,22 +148,23 @@ function convergence_to_uniform(aₖ::Vector{T}) where {T<:Int}
end;

function get_better_K(nn_model, data, min_K, hparams)
K = hparams.max_k
for k in min_K:hparams.max_k
if !convergence_to_uniform(get_window_of_Aₖ(hparams.transform, nn_model, data, k))
K = k
break
end
range = min_K:1:(hparams.max_k)
index = findfirst(
k -> !convergence_to_uniform(get_window_of_Aₖ(hparams.transform, nn_model, data, k)),
range
)
if index === nothing
return hparams.max_k
end
return K
return range[index]
end;

"""
auto_adaptative_block_learning(model, data, hparams)

Custom loss function for the model.
"""
function auto_adaptative_block_learning(nn_model, data, hparams)
function auto_adaptative_block_learning(nn_model, data, hparams::AutoAdaptativeHyperParams)
@assert length(data) == hparams.samples

K = 2
Expand All @@ -179,7 +179,7 @@ function auto_adaptative_block_learning(nn_model, data, hparams)
end
loss, grads = Flux.withgradient(nn_model) do nn
aₖ = zeros(K + 1)
for i in 1:hparams.samples
for i in 1:(hparams.samples)
x = rand(hparams.transform, K)
yₖ = nn(x')
aₖ += generate_aₖ(yₖ, data.data[i])
Expand Down