Skip to content

Commit

Permalink
Only allow sparse special optim for non-binary variables
Browse files Browse the repository at this point in the history
  • Loading branch information
jtackm committed Jun 13, 2023
1 parent 46fd3cd commit a2ba21a
Show file tree
Hide file tree
Showing 4 changed files with 125 additions and 11 deletions.
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

### general

- strongly speed up contingency table computation for heterogeneous=true and max_k=0/1
- use the more compiler-friendly `stack` (introduced in Julia v1.9) instead of `hcat` for large numbers of columns (if available)
- improve univariate pvalue filtering
- remove performance bottleneck in three-way `adjust_df`
Expand Down
108 changes: 99 additions & 9 deletions src/contingency.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,23 +75,49 @@ function contingency_table!(X::Int, Y::Int, data::SparseArrays.AbstractSparseMat
# Get the pointers to the start and end of the non-zero elements in each column
ptr_X, ptr_Y = data.colptr[X], data.colptr[Y]
ptr_X_end, ptr_Y_end = data.colptr[X + 1], data.colptr[Y + 1]
row_X, row_Y = data.rowval[ptr_X], data.rowval[ptr_Y]

# While there are non-zero elements remaining in both columns
@inbounds while ptr_X < ptr_X_end && ptr_Y < ptr_Y_end
row_X, row_Y = data.rowval[ptr_X], data.rowval[ptr_Y]
#row_X, row_Y = data.rowval[ptr_X], data.rowval[ptr_Y]

if row_X == row_Y
val_X, val_Y = data.nzval[ptr_X] + 1, data.nzval[ptr_Y] + 1
test_obj.ctab[val_X, val_Y] += 1
ptr_X += 1
ptr_Y += 1
row_X, row_Y = data.rowval[ptr_X], data.rowval[ptr_Y]
elseif row_X < row_Y
val_X = data.nzval[ptr_X] + 1
val_Y = 1
ptr_X += 1
row_X = data.rowval[ptr_X]
else
val_Y = data.nzval[ptr_Y] + 1
val_X = 1
ptr_Y += 1
row_Y = data.rowval[ptr_Y]
end

test_obj.ctab[val_X, val_Y] += 1
end

# Finish zero / non-zero pairs at the tail of the
# columns
@inbounds while ptr_X < ptr_X_end
val_X = data.nzval[ptr_X] + 1
ptr_X += 1
test_obj.ctab[val_X, 1] += 1
end

@inbounds while ptr_Y < ptr_Y_end
val_Y = data.nzval[ptr_Y] + 1
ptr_Y += 1
test_obj.ctab[1, val_Y] += 1
end

# add double-zero entries
test_obj.ctab[1, 1] = size(data, 1) - sum(test_obj.ctab)

return nothing
end

Expand All @@ -114,25 +140,26 @@ function find_next_Z(row_Z, ptr_Z, ptr_Z_end, row_next, A)
while ptr_Z < (ptr_Z_end-1) && row_Z < row_next
ptr_Z += 1
row_Z = A.rowval[ptr_Z]
end
end

val_Z = row_Z == row_next ? A.nzval[ptr_Z] + 1 : 1

return (val_Z, ptr_Z, row_Z)
end

"""
    find_next_XorY(row, ptr, ptr_end, A)

Auxiliary function for the 3-way special case (max_k = 1 / heterogeneous = true).

Consume the stored entry at position `ptr` of the sparse matrix `A` and step to
the following position.

# Arguments
- `row`: accepted for call-site symmetry; its incoming value is not read.
- `ptr`: position into `A.nzval` / `A.rowval` of the entry to consume.
- `ptr_end`: position at which the column is considered exhausted.
- `A`: object providing `nzval`, `rowval` and `size(A, 1)`.

# Returns
A tuple `(val, ptr, row)`:
- `val`: `A.nzval[ptr] + 1`, evaluated at the incoming `ptr`.
- `ptr`: the incoming `ptr` plus one.
- `row`: `A.rowval[ptr]` at the new `ptr`, or the sentinel `size(A, 1) + 1`
  when the new `ptr` equals `ptr_end`.
"""
function find_next_XorY(row, ptr, ptr_end, A)
    val = A.nzval[ptr] + 1
    ptr += 1

    # Once the column is exhausted, report a row one past the last row of `A`.
    # NOTE(review): presumably this lets callers that merge columns by comparing
    # row numbers treat an exhausted column as "greater than any real row".
    row = ptr == ptr_end ? (size(A, 1) + 1) : A.rowval[ptr]

    return (val, ptr, row)
end

# 3-way, optimized for max_k = 1 and heterogeneous = true
function contingency_table!(X::Int, Y::Int, Z::Int, data::SparseArrays.AbstractSparseMatrixCSC{<:Integer},
test_obj::MiTestCond{<:Integer, Nz})
"""Not implemented for binary variables (for which zeros have to be recorded) since
slowdown may be too big"""
fill!(test_obj.ctab, 0)
# only reset the z_map elements that will be used
# (corresponding to abundances 0, 1, 2)
Expand All @@ -143,9 +170,14 @@ function contingency_table!(X::Int, Y::Int, Z::Int, data::SparseArrays.AbstractS
ptr_X, ptr_Y, ptr_Z = data.colptr[X], data.colptr[Y], data.colptr[Z]
ptr_X_end, ptr_Y_end, ptr_Z_end = data.colptr[X + 1], data.colptr[Y + 1], data.colptr[Z + 1]
row_X, row_Y, row_Z = data.rowval[ptr_X], data.rowval[ptr_Y], data.rowval[ptr_Z]

#@show row_X row_Y row_Z ptr_X ptr_Y ptr_Z
# While there are non-zero elements remaining in either column
@inbounds while ptr_X < ptr_X_end && ptr_Y < ptr_Y_end
#rows_checked = Set()
@inbounds while ptr_X < ptr_X_end || ptr_Y < ptr_Y_end
#min_row = min(row_X, row_Y)
#push!(rows_checked, min_row)
#cmp_trip = Tuple(data[min_row, [X, Y, Z]] .+ 1)

if row_X == row_Y
val_Z, ptr_Z, row_Z = find_next_Z(row_Z, ptr_Z, ptr_Z_end, row_X, data)
val_X, ptr_X, row_X = find_next_XorY(row_X, ptr_X, ptr_X_end, data)
Expand All @@ -159,6 +191,13 @@ function contingency_table!(X::Int, Y::Int, Z::Int, data::SparseArrays.AbstractS
val_Y, ptr_Y, row_Y = find_next_XorY(row_Y, ptr_Y, ptr_Y_end, data)
val_X = 1
end

#val_trip = (val_X, val_Y, val_Z)
#@show row_X row_Y row_Z val_trip cmp_trip ptr_X ptr_Y ptr_Z
#if val_trip != cmp_trip
# @show row_X row_Y row_Z val_trip cmp_trip ptr_X ptr_Y ptr_Z
# error()
#end

test_obj.ctab[val_X, val_Y, val_Z] += 1

Expand All @@ -168,13 +207,64 @@ function contingency_table!(X::Int, Y::Int, Z::Int, data::SparseArrays.AbstractS
end
end

#@show row_X row_Y row_Z ptr_X ptr_X_end ptr_Y ptr_Y_end ptr_Z ptr_Z_end

#=
# Finish zero / non-zero pairs at the tail of the X or Y
# column
@inbounds while ptr_X < ptr_X_end
push!(rows_checked, row_X)
val_Z, ptr_Z, row_Z = find_next_Z(row_Z, ptr_Z, ptr_Z_end, row_X, data)
val_X, ptr_X, row_X = find_next_XorY(row_X, ptr_X, ptr_X_end, data)
test_obj.ctab[val_X, 1, val_Z] += 1
if test_obj.zmap.z_map_arr[val_Z] == -1
test_obj.zmap.levels_total += 1
test_obj.zmap.z_map_arr[val_Z] = 1
end
end
@inbounds while ptr_Y < ptr_Y_end
push!(rows_checked, row_Y)
val_Z, ptr_Z, row_Z = find_next_Z(row_Z, ptr_Z, ptr_Z_end, row_Y, data)
val_Y, ptr_Y, row_Y = find_next_XorY(row_Y, ptr_Y, ptr_Y_end, data)
test_obj.ctab[1, val_Y, val_Z] += 1
if test_obj.zmap.z_map_arr[val_Z] == -1
test_obj.zmap.levels_total += 1
test_obj.zmap.z_map_arr[val_Z] = 1
end
end
=#
#=
# go to the first Z row beyond Y (if there is any)
val_Z, ptr_Z, row_Z = find_next_XorY(row_Z, ptr_Z, ptr_Z_end, data)
@inbounds while ptr_Z < ptr_Z_end
val_Z, ptr_Z, row_Z = find_next_XorY(row_Z, ptr_Z, ptr_Z_end, data)
test_obj.ctab[1, 1, val_Z] += 1
if test_obj.zmap.z_map_arr[val_Z] == -1
test_obj.zmap.levels_total += 1
test_obj.zmap.z_map_arr[val_Z] = 1
end
end
=#

# add triple-zero entries
test_obj.ctab[1, 1, 1] = size(data, 1) - sum(test_obj.ctab)

#rows_nz = Set(findall(vec(any(.!iszero.(data[:, [X, Y, Z]]), dims=2))))
#@show length(rows_checked) length(rows_nz) setdiff(rows_checked, rows_nz) setdiff(rows_nz, rows_checked)

return nothing
end

function contingency_table!(X::Int, Y::Int, Zs::NTuple{N,T} where {N,T<:Integer}, data::SparseArrays.AbstractSparseMatrixCSC{<:Integer},
test_obj::ContTest3D)
# Special case: max_k = 1 / heterogeneous = true
if length(Zs) == 1 && is_zero_adjusted(test_obj)
# Special case: max_k = 1 / heterogeneous = true (not implemented for binary variables)
if length(Zs) == 1 && is_zero_adjusted(test_obj) && test_obj.max_vals[X] > 1 && test_obj.max_vals[Y] > 1
contingency_table!(X, Y, Zs[1], data, test_obj)
# Otherwise use flexible general-purpose backend
else
Expand Down
11 changes: 9 additions & 2 deletions src/tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ function test(X::Int, Y::Int, data::AbstractMatrix{<:Integer}, test_obj::Abstrac
else
contingency_table!(X, Y, data, test_obj)
end

if is_zero_adjusted(test_obj)
sub_ctab = nz_adjust_cont_tab(max_val_x, max_val_y, test_obj.ctab)
levels_x = size(sub_ctab, 1)
Expand All @@ -61,7 +61,14 @@ function test(X::Int, Y::Int, data::AbstractMatrix{<:Integer}, test_obj::Abstrac
pval = 1.0
suff_power = false
else
mi_stat = mutual_information(sub_ctab, levels_x, levels_y, test_obj.marg_i, test_obj.marg_j)
try
mi_stat = mutual_information(sub_ctab, levels_x, levels_y, test_obj.marg_i, test_obj.marg_j)
catch DomainError
display(test_obj.ctab)
display(sub_ctab)
@show X Y test_obj.marg_i test_obj.marg_j levels_x levels_y n_obs
error("debug")
end

df = adjust_df(test_obj.marg_i, test_obj.marg_j, levels_x, levels_y)
pval = mi_pval(abs(mi_stat), df, n_obs)
Expand Down
16 changes: 16 additions & 0 deletions test/learning.jl
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,22 @@ end
end
end

# Check that the sparse special-case backends (max_k 0 / 1 with the "mi_nz" test)
# yield the same network as the dense code path.
@testset "sparse special optim (max_k 0 / 1, mi_nz)" begin
    norm_res = FlashWeave.normalize_data(data, test_name="mi_nz", make_sparse=false)
    A = norm_res.data
    # make some variables binary to test Nz behaviour
    # (the boolean mask has the shape of the last six columns only, so it must
    # index a view of those columns; a lone logical index has to match the
    # shape of the array it indexes, otherwise a BoundsError is thrown)
    bin_cols = @view A[:, end-5:end]
    bin_cols[.!iszero.(bin_cols)] .= 1
    A_sp = sparse(A)

    for max_k in [0, 1]
        @testset "max_k $max_k" begin
            # learn once from the dense and once from the sparse representation
            net, net_sp = [learn_network(x; sensitive=false, heterogeneous=true, max_k, normalize=false, verbose=false,
                                         make_sparse=y) for (x, y) in [(A, false), (A_sp, true)]]
            @test net == net_sp
        end
    end
end

# smoke test fast elimination heuristic
@testset "fast_elim" begin
@test isa(learn_network(data, sensitive=true, heterogeneous=false,
Expand Down

0 comments on commit a2ba21a

Please sign in to comment.