Skip to content

Commit

Permalink
bugfix GPU
Browse files Browse the repository at this point in the history
  • Loading branch information
rafaqz committed Sep 4, 2020
1 parent 7014d81 commit a7561a1
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 20 deletions.
39 changes: 21 additions & 18 deletions src/framework.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,33 +20,41 @@ mapgrowth(wrapper::ModelWrapper; kwargs...) =
mapgrowth(layers...; kwargs...) =
mapgrowth(layers; kwargs...)
mapgrowth(layers::Tuple; series::AbstractGeoSeries, tspan::AbstractRange, arraytype=Array) = begin
period = step(tspan)
nperiods = length(tspan)
period = step(tspan); nperiods = length(tspan)
startdate, enddate = first(tspan), last(tspan)
required_keys = Tuple(union(keys(layers)))

# Copy only the required keys to a memory-backed stack
stack = GeoStack(deepcopy(first(series)); keys=required_keys)

A = first(values(stack));
missingval = eltype(A)(NaN)

# Replace false with NaN
mask = map(x -> x ? eltype(A)(x) : missingval, boolmask(A)) |> parent |> arraytype
stackbuffer = GeoData.modify(arraytype, stack)
# Setup output vector

# Make a 3 dimensional GeoArray for output, adding the time dimension
# to init (there should be a function for this in DimensionalData.jl - growdim?
ti = Ti(tspan; mode=Sampled(Ordered(), Regular(period), Intervals(Start())))
output = GeoArray(
arraytype(zeros(size(A)..., nperiods)), (dims(A)..., ti);
name="growthrate",
missingval=missingval,
)
outdims = (dims(A)..., ti)
outA = arraytype(zeros(eltype(A), size(A)..., nperiods))
output = GeoArray(outA, outdims; name="growthrate", missingval=missingval)

runperiods!(output, stackbuffer, series, mask, layers, tspan)

# Return a GeoArray wrapping a regular Array, not arraytype
GeoData.modify(Array, output)
end

function runperiods!(output, stackbuffer, series, mask, layers, tspan)
period = step(tspan); nperiods = length(tspan)
println("Running for $(1:nperiods)")
for p in 1:nperiods
n = 0
periodstart = tspan[p]
periodend = periodstart + period
println("\n", "Processing period between: $periodstart and $periodend")

# We don't use `Between` as it might unintentionally cut off the
# last time if it partially extends beyond the period.
# So we jsut work with time as Points using `Where`.
Expand All @@ -55,24 +63,19 @@ mapgrowth(layers::Tuple; series::AbstractGeoSeries, tspan::AbstractRange, arrayt
println(" ", val(dims(subseries, Ti))[t])
# Copy the arrays we need from disk to the buffer stack
copy!(stackbuffer, subseries[t])
output[Ti(p)] .+= combinelayers(layers, stackbuffer)
# For some reason now this is broken with DD getindex, view is a workaround
parent(view(output, Ti(p))) .+= combinelayers(layers, stackbuffer)
n += 1
end
if n > 0
output[Ti(p)] .*= mask ./ n
parent(view(output, Ti(p))) .*= mask ./ n
else
@warn ("No files found for the $period period starting $periodstart")
output[Ti(p)] .*= mask
parent(view(output, Ti(p))) .*= mask
end
end

# Return a GeoArray wrapping a regular Array, not arraytype
rebuild(output, Array(parent(output)))
end

periodstartdates(startdate, period, nperiods) =
[startdate + p * period for p in 0:nperiods-1]

@inline combinelayers(layer, stackbuffer) = combinelayers((layer,), stackbuffer)
@inline combinelayers(layers::Tuple, stackbuffer::AbstractGeoStack) =
conditionalrate.(Ref(first(layers)), parent(stackbuffer[keys(first(layers))])) .+ combinelayers(tail(layers), stackbuffer)
Expand Down
3 changes: 1 addition & 2 deletions src/models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,8 @@ condition(m, x) = true

abstract type AbstractLowerStress <: StressModel end

# TODO set units in the model, this is a temporary hack
@inline condition(m::AbstractLowerStress, x) = x < m.threshold
@inline rate(m::AbstractLowerStress, x) = (m.threshold - x) * m.mortalityrate
@inline condition(m::AbstractLowerStress, x) = x < m.threshold

"""
LowerStress(key::Symbol, threshold, mortalityrate)
Expand Down

0 comments on commit a7561a1

Please sign in to comment.