In [1]:
using CSV
using DataFrames
using JSON
using ArchGDAL
using Proj
using Rasters
using Base.Threads
using JLD2
using Lux
using LuxCore
using EasyHybrid
using Optimisers
using Statistics
using Plots
using Distributed
using Parquet
include("helpers.jl")
using .Helpers


# define key parameters

In [2]:
# Release tag shared by the model/predictor inputs and the output file names of this run.
version = "v20251219"
res_m = 30  # target grid resolution, in meters

30

## load in necessary data

In [3]:
# Predictor list and covariate scaler used for preprocessing, then the trained model
# (`hmb`) together with its parameters (`pss`) and state (`stt`).
@load "./data/predictors_$(version).jld2" predictors
@load "./data/covs_scaler.jld2" cov_scaler
@load "./map/prod_SiNN_model_$(version).jld2" hmb pss stt

# Covariate name => raster path mapping. `keys` and `values` of one unmodified Dict
# iterate in matching order, so tnames[i] pairs with tpaths[i].
cov = JSON.parsefile("./cov_path_full.json");
tnames, tpaths = collect(keys(cov)), collect(values(cov))

# Study-area polygon: geometry of the first feature in the first layer of the NUTS 2021
# GeoPackage (presumably Germany). Its envelope defines the mapping extent.
# NOTE(review): assumes the GeoPackage is already in EPSG:3035 — the printed extent is
# consistent with that; confirm.
geom = ArchGDAL.read("nuts_de_2021.gpkg") do dataset
    layer = ArchGDAL.getlayer(dataset, 0)
    ArchGDAL.getgeom(first(layer))
end
env = ArchGDAL.envelope(geom)
bbox3035 = (env.MinX, env.MinY, env.MaxX, env.MaxY)

# Alternative: a smaller test area in northern Germany (WGS84 bbox), suggested by Bernhard.
# bboxmine = (8.956051, 51.815757, 10.450192, 53.154421)
# bbox3035 = convert_bbox_wgs84_to_3035(bboxmine);

(4.0319524076000005e6, 2.6840749562e6, 4.671179791300001e6, 3.5430855823e6)

## tiling

In [4]:
# Build the EPSG:3035 grid over the study-area bounding box and cut it into tiles of
# 2048 cells per side.
xs, ys = make_grid_3035(bbox3035, res_m)
tiles = make_tiles(xs, ys; tilesize = 2048)

for (label, value) in (
    ("Julia threads: ", Threads.nthreads()),
    ("tiles: ", length(tiles)),
    ("length: ", length(tpaths)),
)
    println(label, value)
end

# Column subset for an optional parquet export (currently only referenced by
# commented-out export code).
save_col = [:x3035, :y3035, :pred_SOCconc, :pred_CF, :pred_BD, :pred_SOCdensity,
            :soc, :ocd, :bd, :cf];

Julia threads: 96
tiles: 154
length: 362


In [5]:
# Start local worker processes only if none exist yet. Without the guard, every re-run
# of this cell would add another full set of workers (`addprocs()` with no argument
# adds `Sys.CPU_THREADS` workers).
if nprocs() == 1
    addprocs()
end

# Workers need the raster/projection packages and the project helpers, because `pmap`
# in `process_tile` calls `Helpers.sample_tiff_onto_grid` on them. The `isdefined`
# guard avoids re-including helpers.jl where it is already loaded (the main process).
@everywhere begin
    using ArchGDAL
    using Proj
    if !isdefined(Main, :Helpers)
        include("helpers.jl")
    end
    using .Helpers
end

# Covariate columns used to detect tiles outside the data footprint: a tile is skipped
# when every one of these columns holds a single repeated value (e.g. a no-data fill).
const DEFAULT_VALIDITY_COLS = (
    :soil_moisture_s1_clms_qr_1_p0_05_m_1km_20140101_20241231_eu_epsg3035_v20250211,
    :green_glad_landsat_ard2_seasconv_m_yearly_p50_30m_s_YYYY0101_YYYY1231_eu_epsg_3035_v20231127,
)

# True when `col` is empty or holds one repeated value. `isequal` (not `==`) is used so
# that an all-NaN or all-`missing` column also counts as constant.
_isconstant(col) = isempty(col) || all(isequal(first(col)), col)

"""
    process_tile(tid, tile, xs, ys, tpaths, tnames, predictors, cov_scaler,
                 hmb, ps, st; outdir = "./map", res_m = 0, version = "",
                 ntiles = nothing, validity_cols = DEFAULT_VALIDITY_COLS)

Sample all covariate rasters onto one tile of the grid, run the model on it and write
the back-transformed predictions as GeoTIFFs.

# Arguments
- `tid`: tile number, used in log messages and output file names.
- `tile`: `(xind, yind)` indices selecting this tile's slice of `xs` and `ys`.
- `xs`, `ys`: x / y coordinates of the full grid (stored as `x3035` / `y3035`).
- `tpaths`, `tnames`: raster paths and matching column names; `tpaths[i]` is stored
  under `tnames[i]`.
- `predictors`: columns fed to the model after `Helpers.preprocess_predictors!`.
- `cov_scaler`: scaler forwarded to `Helpers.preprocess_predictors!`.
- `hmb`, `ps`, `st`: model, its parameters and its state (run via `LuxCore.testmode`).

# Keywords
- `outdir`: directory the GeoTIFFs are written to.
- `res_m`, `version`: only used to build the output file names.
- `ntiles`: total number of tiles, shown in the progress message when given.
- `validity_cols`: columns checked by the skip test (see Returns).

# Returns
`nothing` when every column in `validity_cols` is constant over the tile (the tile is
skipped and nothing is written). Otherwise, the tile `DataFrame` holding coordinates,
covariates, `pred_*` model outputs and the back-transformed `soc`, `cf`, `bd`, `ocd`.

# Throws
`DimensionMismatch` if `tpaths` and `tnames` differ in length.
"""
function process_tile(
    tid::Integer,
    tile,
    xs, ys,
    tpaths, tnames,
    predictors, cov_scaler,
    hmb, ps, st;
    outdir = "./map",
    res_m = 0,
    version = "",
    ntiles = nothing,
    validity_cols = DEFAULT_VALIDITY_COLS
)
    length(tpaths) == length(tnames) || throw(DimensionMismatch(
        "tpaths has $(length(tpaths)) entries but tnames has $(length(tnames))"))

    println("----------------------------------------------------")
    xind, yind = tile
    xs_t = xs[xind]
    ys_t = ys[yind]

    # Sample every covariate raster onto the tile grid on the worker processes.
    # pmap returns results in input order, so out[i] belongs to tpaths[i].
    out = pmap(path -> Helpers.sample_tiff_onto_grid(path, xs_t, ys_t), tpaths)

    nx, ny = length(xs_t), length(ys_t)
    df = DataFrame()
    # One row per grid cell, x varying slowest and y fastest. NOTE(review): this must
    # match the cell order of the vectors returned by sample_tiff_onto_grid — confirm.
    df.x3035 = repeat(xs_t, inner = ny)
    df.y3035 = repeat(ys_t, outer = nx)
    for (name, vals) in zip(tnames, out)
        df[!, Symbol(name)] = vals
    end

    # Skip tiles where all validity columns are constant (presumably pure no-data).
    if !isempty(validity_cols) && all(c -> _isconstant(df[!, c]), validity_cols)
        println("invalid tile $(tid), skip")
        return nothing
    end

    println("finish overlay")
    Helpers.preprocess_predictors!(df, predictors, cov_scaler)
    x_test = to_keyedArray(Float32.(df[!, predictors]))
    # Use the ps/st arguments (not notebook globals) and run the state in test mode.
    ŷ_test, _ = hmb(x_test, ps, LuxCore.testmode(st))

    # Copy model outputs into `pred_*` columns: per-row vectors as-is, scalars (or
    # length-1 vectors) repeated on every row. Outputs of any other shape are ignored.
    for name in (:BD, :SOCconc, :CF, :SOCdensity, :oBD, :mBD)
        hasproperty(ŷ_test, name) || continue
        val = getproperty(ŷ_test, name)
        col = Symbol("pred_", name)
        if val isa AbstractVector && length(val) == nrow(df)
            df[!, col] = val
        elseif val isa Number || (val isa AbstractVector && length(val) == 1)
            df[!, col] = fill(Float32(only(val)), nrow(df))
        end
    end

    # Invert the target transforms: SOCconc and CF as exp(x / scale) - 1, with CF then
    # divided by 100 (presumably percent -> fraction); SOCdensity as exp(x / scale);
    # BD is only rescaled.
    df[!, :soc] = exp.(df[!, :pred_SOCconc] ./ Helpers.scalers[:SOCconc]) .- 1
    df[!, :cf] = (exp.(df[!, :pred_CF] ./ Helpers.scalers[:CF]) .- 1) ./ 100
    df[!, :bd] = df[!, :pred_BD] ./ Helpers.scalers[:BD]
    df[!, :ocd] = exp.(df[!, :pred_SOCdensity] ./ Helpers.scalers[:SOCdensity])

    println("finish prediction, ", size(df))
    # write_parquet("./data/out_$(res_m)m_$(version).pq", df[:, save_col])

    for name in (:bd, :soc, :cf, :ocd, :pred_oBD, :pred_mBD)
        Helpers.write_geotiff_from_grid(
            df, name,
            joinpath(outdir, "$(name)_$(res_m)m_$(version)_tile$(tid).tif")
        )
    end

    println(isnothing(ntiles) ? "finished tile $tid" : "finished tile $tid / $ntiles")
    return df
end


process_tile (generic function with 1 method)

In [6]:
# Process every tile in turn. Each call spreads the raster sampling over the workers and
# writes that tile's GeoTIFFs; invalid tiles are skipped inside `process_tile`.
@time for (tid, tile) in enumerate(tiles)
    process_tile(tid, tile, xs, ys, tpaths, tnames, predictors, cov_scaler,
                 hmb, pss, stt; res_m, version)
end

# Debug a single tile:
# @time df = process_tile(100, tiles[100], xs, ys, tpaths, tnames, predictors, cov_scaler,
#                         hmb, pss, stt; res_m, version);

----------------------------------------------------
invalid tile 1, skip
----------------------------------------------------
invalid tile 2, skip
----------------------------------------------------
finish overlay
finish prediction, (4194304, 374)
finished tile 3 / 154
----------------------------------------------------
finish overlay
finish prediction, (4194304, 374)
finished tile 4 / 154
----------------------------------------------------
finish overlay
finish prediction, (4194304, 374)
finished tile 5 / 154
----------------------------------------------------
finish overlay
finish prediction, (4194304, 374)
finished tile 6 / 154
----------------------------------------------------
finish overlay
finish prediction, (4194304, 374)
finished tile 7 / 154
----------------------------------------------------
finish overlay
finish prediction, (4194304, 374)
finished tile 8 / 154
----------------------------------------------------
finish overlay
finish prediction, (4194304, 374)
finish

## check cov scalers

In [None]:
# clean(x) = filter(!isnan, skipmissing(x))

# rows = Vector{NamedTuple}()

# for col in predictors
#     v_o = clean(oridf[!, col])
#     v_d = clean(df[!, col])

#     push!(rows, (
#         variable   = String(col),
#         q05_oridf  = quantile(v_o, 0.05),
#         q05_df     = quantile(v_d, 0.05),
#         q50_oridf  = quantile(v_o, 0.50),
#         q50_df     = quantile(v_d, 0.50),
#         q95_oridf  = quantile(v_o, 0.95),
#         q95_df     = quantile(v_d, 0.95)
#     ))
# end

# qt = DataFrame(rows)
# CSV.write("./data/predictor_quantiles_check.csv", qt)

# write_parquet("./data/production_preprocessed_$(res_m)m_$(version).pq", df)
# CSV.write("production_preprocessed_$(res_m)m_$(version).csv", df)

## check predictions

In [None]:
# for col in ["pred_BD", "pred_SOCconc", "pred_CF", "pred_SOCdensity"]

#     vals = df[:, col]

#     # valid values (neither missing nor NaN)
#     valid_vals = filter(x -> !ismissing(x) && !isnan(x), vals)

#     n_valid = length(valid_vals)
#     vmin = minimum(valid_vals)
#     vmax = maximum(valid_vals)

#     println("Variable: $col")
#     println("  Valid count = $n_valid")
#     println("  Min = $vmin")
#     println("  Max = $vmax\n")

#     histogram(
#         vals;
#         bins = 50,
#         xlabel = col,
#         ylabel = "Frequency",
#         title = "Histogram of $col",
#         lw = 1,
#         legend = false
#     )
#     display(current())
# end