# Check if representative group knockoffs are working

In [1]:
# load packages for this tutorial
using Revise
using Knockoffs
using LinearAlgebra
using Random
using StatsBase
using Statistics
using ToeplitzMatrices
using Distributions
using Clustering
using ProgressMeter
using LowRankApprox
using Plots
using CSV, DataFrames
gr(fmt=:png);

# some helper functions to compute power and empirical FDR
"""
    TP(correct_groups, signif_groups)

Return the fraction of `correct_groups` that also appear in `signif_groups`.
The denominator is clamped to at least 1, so an empty `correct_groups` yields `0.0`.
"""
function TP(correct_groups, signif_groups)
    n_hits = length(intersect(signif_groups, correct_groups))
    n_truth = max(1, length(correct_groups))
    return n_hits / n_truth
end
"""
    TP(correct_groups, β̂, groups)

Convert the nonzero entries of `β̂` into group labels via `get_signif_groups`
and forward them to the two-argument `TP`.
"""
function TP(correct_groups, β̂, groups)
    return TP(correct_groups, get_signif_groups(β̂, groups))
end
"""
    power(correct_snps, discovered_snps)

Return the fraction of `correct_snps` that also appear in `discovered_snps`.

The denominator is clamped to at least 1 (as in `TP` and `FDR`), so an empty
`correct_snps` yields `0.0` instead of `NaN` from a `0 / 0` division.
"""
function power(correct_snps, discovered_snps)
    n_found = length(discovered_snps ∩ correct_snps)
    return n_found / max(1, length(correct_snps))
end
"""
    FDR(correct_groups, signif_groups)

Return the empirical false discovery proportion: the number of entries of
`signif_groups` not matched in `correct_groups`, divided by the number of
discoveries (clamped to at least 1 so that no discoveries gives `0.0`).
"""
function FDR(correct_groups, signif_groups)
    n_discovered = length(signif_groups)
    n_true_hits = length(intersect(signif_groups, correct_groups))
    n_false = n_discovered - n_true_hits # number of false positives
    return n_false / max(1, n_discovered)
end
"""
    FDR(correct_groups, β̂, groups)

Convert the nonzero entries of `β̂` into group labels via `get_signif_groups`
and forward them to the two-argument `FDR`.
"""
function FDR(correct_groups, β̂, groups)
    return FDR(correct_groups, get_signif_groups(β̂, groups))
end
"""
    get_signif_groups(β, groups)

Return, as a `Vector{Int}`, the distinct group labels `groups[i]` of every
index `i` where `β[i]` is nonzero, in order of first appearance.
"""
function get_signif_groups(β, groups)
    nonzero_idx = findall(!iszero, β)
    # the typed comprehension converts labels to Int; unique! keeps first occurrences
    return unique!(Int[groups[j] for j in nonzero_idx])
end

┌ Info: Precompiling Knockoffs [878bf26d-0c49-448a-9df5-b057c815d613]
└ @ Base loading.jl:1423


get_signif_groups (generic function with 1 method)

## Simulate data with gnomAD panel

In [35]:
# Load a matrix from disk, simulate a Gaussian design X from it, and partition the
# variables into groups.
# NOTE: `datadir` is an absolute path on the author's machine; change it to run elsewhere.
datadir = "/Users/biona001/Benjamin_Folder/research/4th_project_PRS/group_knockoff_test_data"
# read the text file as a DataFrame, then convert it to a dense Float64 matrix
# (presumably a correlation matrix, given the "CorG" file name — confirm)
covfile = CSV.read(joinpath(datadir, "CorG_2_127374341_128034347.txt"), DataFrame)
Σ = covfile |> Matrix{Float64}
# blend with the identity (0.99Σ + 0.01I); the author's stated purpose is to ensure PSD
Σ = 0.99Σ + 0.01I #ensure PSD

# test on smaller data: keep only the leading idx × idx block of Σ
idx = 500 # 1241 # includes largest group with 192 members
Σ = Σ[1:idx, 1:idx];

# simulate data
m = 5 # passed as the `m` keyword to the knockoff generator in a later cell
p = size(Σ, 1) # number of variables
k = 10 # number of causal groups (not used in this cell)
n = 1000 # sample size

# simulate X: draw n samples from N(μ, Σ); rand returns p × n, so transpose to n × p
μ = zeros(p)
X = rand(MvNormal(μ, Σ), n)' |> Matrix
zscore!(X, mean(X, dims=1), std(X, dims=1)); # standardize columns of X

# define groups
nrep = 5 # not used in this cell
# group labels come from Knockoffs.jl's id_partition_groups
# (presumably interpolative-decomposition based — confirm against the package docs)
groups = id_partition_groups(X, force_contiguous=false)
unique_groups = unique(groups)
# display the sorted group sizes
countmap(groups) |> values |> collect |> sort

136-element Vector{Int64}:
  1
  1
  1
  1
  1
  1
  1
  1
  1
  1
  1
  1
  1
  ⋮
 15
 15
 17
 18
 18
 22
 24
 25
 25
 32
 35
 44

In [36]:
# group knockoffs: time the generation of group knockoffs for X with the :maxent
# option and `m` passed through; verbose=true prints the per-iteration objective
@time ko = modelX_gaussian_group_knockoffs(X, :maxent, groups, μ, Σ, m=m, verbose=true);

Maxent initial obj = -21697.062824271106
Iter 1 (PCA): obj = -17873.19705916164, δ = 2.7894170882622116, t1 = 0.2, t2 = 0.07
Iter 2 (PCA): obj = -14422.241803784958, δ = 0.9222596242085215, t1 = 0.37, t2 = 0.13
Iter 3 (PCA): obj = -12753.226804435964, δ = 0.5414552226809967, t1 = 0.58, t2 = 0.2
Iter 4 (PCA): obj = -11795.165658886217, δ = 0.4192851451235552, t1 = 0.86, t2 = 0.26
Iter 5 (PCA): obj = -11206.72302765663, δ = 0.48592870354058665, t1 = 1.07, t2 = 0.33
Iter 6 (PCA): obj = -10838.549047856957, δ = 0.6366331320900753, t1 = 1.33, t2 = 0.39
Iter 7 (PCA): obj = -10617.833462405928, δ = 0.6754313214808505, t1 = 1.57, t2 = 0.45
Iter 8 (PCA): obj = -10482.505674070737, δ = 0.27530848796806606, t1 = 1.79, t2 = 0.51
Iter 9 (PCA): obj = -10373.903657779418, δ = 0.17728527440017383, t1 = 1.95, t2 = 0.57
Iter 10 (PCA): obj = -10284.36124155696, δ = 0.2618276162815484, t1 = 2.1, t2 = 0.64
Iter 11 (CCD): obj = -10068.859996265239, δ = 0.5205009289960065, t1 = 2.38, t2 = 1.05, t3 = 0.0
Iter

In [37]:
# rep group knockoffs: time the generation of representative group knockoffs
# NOTE(review): `m` is not passed here although the call that produces `ko` uses
# `m=m` — confirm the default is intended before comparing `ko` and `rko`.
@time rko = modelX_gaussian_rep_group_knockoffs(X, :maxent, μ, Σ, groups, verbose=true)

# split the original and knockoff columns into representatives (r) and the rest (c)
X̃ = rko.X̃
group_reps = rko.group_reps # column indices of the representatives
Xr = X[:, group_reps]
Xc = X[:, setdiff(1:p, group_reps)]
X̃r = X̃[:, group_reps]
X̃c = X̃[:, setdiff(1:p, group_reps)];

140 representatives for 500 variables, 148 optimization variables
  0.416092 seconds (19.74 k allocations: 209.171 MiB)


In [38]:
# number of nonzero entries in the S matrix of each knockoff object
count(!iszero, ko.S), count(!iszero, rko.S)

(8836, 230548)

Check whether knockoffs generated under the conditional independence assumption satisfy exchangeability

In [39]:
# left column: group label of each of the first 10 representatives;
# right column: the representative's column index in X
[groups[group_reps[1:10]] group_reps[1:10]]

10×2 Matrix{Int64}:
  1   1
  2   2
  3   3
  4   4
  6   6
 90   8
  5  13
 17  38
  7  47
  8  49

In [40]:
# group labels of the first 5 variables
groups[1:5]

5-element Vector{Int64}:
 1
 2
 3
 4
 6

In [41]:
# leading 5×5 block of the sample correlation matrix of the representative columns of X
cor(Xr)[1:5, 1:5]

5×5 Matrix{Float64}:
  1.0          0.00752692   0.387493   -0.178493    0.122001
  0.00752692   1.0          0.0207748  -0.0882473   0.486391
  0.387493     0.0207748    1.0        -0.0668503   0.240392
 -0.178493    -0.0882473   -0.0668503   1.0        -0.123443
  0.122001     0.486391     0.240392   -0.123443    1.0

In [42]:
# same block for the knockoff columns at the representative indices, to compare with cor(Xr)
cor(X̃r)[1:5, 1:5]

5×5 Matrix{Float64}:
  1.0         0.0414789    0.286198   -0.0836722    0.141817
  0.0414789   1.0          0.0868466  -0.00633061   0.496822
  0.286198    0.0868466    1.0        -0.0590992    0.253416
 -0.0836722  -0.00633061  -0.0590992   1.0         -0.0279316
  0.141817    0.496822     0.253416   -0.0279316    1.0

In [43]:
# leading 5×5 block of the sample correlation matrix of the non-representative columns of X
cor(Xc)[1:5, 1:5]

5×5 Matrix{Float64}:
  1.0        0.98148   -0.312384  -0.309631  -0.302546
  0.98148    1.0       -0.312411  -0.310994  -0.30411
 -0.312384  -0.312411   1.0        0.977494   0.979676
 -0.309631  -0.310994   0.977494   1.0        0.983289
 -0.302546  -0.30411    0.979676   0.983289   1.0

In [44]:
# same block for the knockoff columns at the non-representative indices, to compare with cor(Xc)
cor(X̃c)[1:5, 1:5]

5×5 Matrix{Float64}:
  1.0        0.687178  -0.468214  -0.371239  -0.292553
  0.687178   1.0       -0.206367  -0.268106  -0.269473
 -0.468214  -0.206367   1.0        0.84027    0.739155
 -0.371239  -0.268106   0.84027    1.0        0.95733
 -0.292553  -0.269473   0.739155   0.95733    1.0

**Conclusion**
It seems that $X_r$ and $\tilde{X}_r$ agree fairly well, while $X_c$ and $\tilde{X}_c$ agree only somewhat. Let's try this in simulations.

## One simulation

In [47]:
# One simulated replicate: draw X and y, then compare fully general maxent group
# knockoffs ("me") against representative maxent group knockoffs ("rme") on
# group-level power and FDR.
# NOTE(review): no Random.seed! is set in this cell, so its output is not reproducible.

# simulate data
m = 5 # forwarded as the `m` keyword to both knockoff generators below
p = size(Σ, 1)
k = 10 # number of causal variables; they may share groups, so causal groups ≤ k
n = 500 # sample size
μ = zeros(p)

# simulate X: draw n samples from N(μ, Σ); rand returns p × n, so transpose to n × p
X = rand(MvNormal(μ, Σ), n)' |> Matrix
zscore!(X, mean(X, dims=1), std(X, dims=1)); # standardize columns of X

# define groups
groups = id_partition_groups(X, force_contiguous=false)

# simulate βtrue: k nonzero effects placed at random positions
βtrue = zeros(p)
βtrue[1:k] .= rand(-1:2:1, k) .* randn(k) # random sign times a standard normal draw
shuffle!(βtrue)
causal_groups = get_signif_groups(βtrue, groups) # distinct groups containing a causal variable

# simulate y: linear model with standard normal noise
y = X * βtrue + randn(n)

# fully general me
@time me = modelX_gaussian_group_knockoffs(
    X, :maxent, groups, μ, Σ, 
    m = m,
    tol = 0.0001,    # convergence tolerance
    verbose=false, # whether to print informative intermediate results
)
me_ko_filter = fit_lasso(y, me)
# βs[3] is presumably the estimate at the third target FDR level — confirm against fit_lasso
me_power = round(TP(causal_groups, me_ko_filter.βs[3], groups), digits=3)
me_fdr = round(FDR(causal_groups, me_ko_filter.βs[3], groups), digits=3)
me_ssum = sum(abs.(me_ko_filter.ko.S)) # sum of |S| entries; not used further in this cell
@show me_power, me_fdr

# representative ME knockoffs
@time rme = modelX_gaussian_rep_group_knockoffs(
    X, :maxent, μ, Σ, groups, 
    m = m, 
);
rme_ko_filter = fit_lasso(y, rme)
# distinct groups of the variables with nonzero coefficients
discovered_groups = groups[findall(!iszero, rme_ko_filter.βs[3])] |> unique
rme_power = round(TP(causal_groups, discovered_groups), digits=3)
rme_fdr = round(FDR(causal_groups, discovered_groups), digits=3)
@show rme_power, rme_fdr

 14.923825 seconds (115.80 k allocations: 265.523 MiB, 0.07% gc time)
(me_power, me_fdr) = (0.444, 0.2)
135 representatives for 500 variables, 141 optimization variables
  0.868293 seconds (19.71 k allocations: 382.991 MiB, 0.64% gc time)
(rme_power, rme_fdr) = (0.444, 0.0)


(0.444, 0.0)

### Interpolative decomposition, target FDR = 0.1, m=5

In [49]:
# Average the empirical group FDR of representative maxent ("rme") knockoffs over
# `nsims` seeded replicates. Each replicate also runs the fully general maxent
# ("me") knockoffs and prints both methods' power and FDR for comparison.
fdr_hat = 0.0
nsims = 10
for i in 1:nsims
    Random.seed!(i) # one seed per replicate so every run is reproducible

    # simulate data
    m = 5 # forwarded as the `m` keyword to both knockoff generators below
    p = size(Σ, 1)
    k = 10 # number of causal variables; they may share groups, so causal groups ≤ k
    n = 500 # sample size
    μ = zeros(p)

    # simulate X: draw n samples from N(μ, Σ); rand returns p × n, so transpose to n × p
    X = rand(MvNormal(μ, Σ), n)' |> Matrix
    zscore!(X, mean(X, dims=1), std(X, dims=1)); # standardize columns of X

    # define groups
    groups = id_partition_groups(X, force_contiguous=false)

    # simulate βtrue: k nonzero effects placed at random positions
    βtrue = zeros(p)
    βtrue[1:k] .= rand(-1:2:1, k) .* randn(k)
    shuffle!(βtrue)
    causal_groups = get_signif_groups(βtrue, groups)

    # simulate y: linear model with standard normal noise
    y = X * βtrue + randn(n)

    # fully general me
    @time me = modelX_gaussian_group_knockoffs(
        X, :maxent, groups, μ, Σ,
        m = m,
        tol = 0.0001,    # convergence tolerance
        verbose=false, # whether to print informative intermediate results
    )
    me_ko_filter = fit_lasso(y, me)
    # βs[3] is presumably the estimate at the 0.1 target FDR named in the section
    # heading — confirm against fit_lasso
    me_power = round(TP(causal_groups, me_ko_filter.βs[3], groups), digits=3)
    me_fdr = round(FDR(causal_groups, me_ko_filter.βs[3], groups), digits=3)
    me_ssum = sum(abs.(me_ko_filter.ko.S)) # sum of |S| entries; not used further in this cell
    @show me_power, me_fdr

    # representative ME knockoffs
    @time rme = modelX_gaussian_rep_group_knockoffs(
        X, :maxent, μ, Σ, groups,
        m = m,
    );
    rme_ko_filter = fit_lasso(y, rme)
    # distinct groups of the variables with nonzero coefficients
    discovered_groups = groups[findall(!iszero, rme_ko_filter.βs[3])] |> unique
    rme_power = round(TP(causal_groups, discovered_groups), digits=3)
    rme_fdr_exact = FDR(causal_groups, discovered_groups)
    rme_fdr = round(rme_fdr_exact, digits=3) # rounded only for display
    @show rme_power, rme_fdr

    # Accumulate the unrounded FDR so the reported mean carries no rounding error.
    # `global` keeps this working as a plain script as well as in a notebook.
    global fdr_hat += rme_fdr_exact
end
fdr_hat /= nsims
println("representative ME knockoff has avg FDR $fdr_hat")

 14.410588 seconds (116.16 k allocations: 265.621 MiB, 1.19% gc time)
(me_power, me_fdr) = (0.25, 0.0)
  1.029325 seconds (536.70 k allocations: 402.994 MiB, 1.17% gc time, 23.00% compilation time)
(rme_power, rme_fdr) = (0.5, 0.0)
  9.613800 seconds (74.35 k allocations: 262.346 MiB, 0.09% gc time)
(me_power, me_fdr) = (0.222, 0.0)
  0.767761 seconds (19.79 k allocations: 390.207 MiB, 0.95% gc time)
(rme_power, rme_fdr) = (0.222, 0.0)
  9.289118 seconds (72.26 k allocations: 262.174 MiB)
(me_power, me_fdr) = (0.4, 0.2)
  0.826819 seconds (19.98 k allocations: 397.904 MiB)
(rme_power, rme_fdr) = (0.5, 0.167)
  9.366307 seconds (70.27 k allocations: 261.993 MiB)
(me_power, me_fdr) = (0.333, 0.0)
  0.760721 seconds (19.87 k allocations: 393.700 MiB, 0.92% gc time)
(rme_power, rme_fdr) = (0.444, 0.2)
 12.459039 seconds (90.86 k allocations: 263.743 MiB)
(me_power, me_fdr) = (0.778, 0.0)
  0.814730 seconds (19.66 k allocations: 379.557 MiB, 1.28% gc time)
(rme_power, rme_fdr) = (0.889, 0.1