# 2 methods for choosing group representatives

+ `:id`: Does another interpolative decomp: basically chooses roughly independent variables in each group
+ `:rss`: This minimizes the RSS of non-selected variables, as proposed by Trevor Hastie. This could selected highly correlated variables. 

In [13]:
# load packages for this tutorial
using Revise
using Knockoffs
using LinearAlgebra
using Random
using StatsBase
using Statistics
using ToeplitzMatrices
using Distributions
using Clustering
using ProgressMeter
using LowRankApprox
using Test
using RCall
# using Plots
# gr(fmt=:png);

# some helper functions to compute power and empirical FDR
function TP(correct_groups, signif_groups)
    return length(signif_groups ∩ correct_groups) / max(1, length(correct_groups))
end
function TP(correct_groups, β̂, groups)
    signif_groups = get_signif_groups(β̂, groups)
    return TP(correct_groups, signif_groups)
end
function power(correct_snps, discovered_snps)
    return length(discovered_snps ∩ correct_snps) / length(correct_snps)
end
function FDR(correct_groups, signif_groups)
    FP = length(signif_groups) - length(signif_groups ∩ correct_groups) # number of false positives
    return FP / max(1, length(signif_groups))
end
function FDR(correct_groups, β̂, groups)
    signif_groups = get_signif_groups(β̂, groups)
    return FDR(correct_groups, signif_groups)
end
function get_signif_groups(β, groups)
    correct_groups = Int[]
    for i in findall(!iszero, β)
        g = groups[i]
        g ∈ correct_groups || push!(correct_groups, g)
    end
    return correct_groups
end

R"""
step1 <- function(C,vlist=seq(ncol(C)),RSS0=sum(diag(C)),zero=1e-12){
    dC <- diag(C)
    rs <- colSums(C^2)/dC
    imax <- order(rs,decreasing=TRUE)[1]
    vmin <- sum(dC) - rs[imax]
    residC = C - outer(C[,imax],C[,imax],"*")/C[imax,imax]
    index = vlist[imax]
    izero = diag(residC) <= zero
    list(index = index, variance = vmin, R2 = 1-vmin/RSS0, C=residC[!izero,!izero],vlist=vlist[!izero])
}

subsetC <- function(C, k, traceit=FALSE){
    ## C correlation matrix
    ## k subset size
    indices <- rep(0, k)
    p <- ncol(C)
    RSS0 <- p
    R2 <- double(k)
    vlist = seq(p)
    for(i in 1:k){
        fit1 <- step1(C, RSS0=RSS0, vlist=vlist)
        indices[i] <- fit1$index
        C <- fit1$C
        vlist <- fit1$vlist
        R2[i] <- fit1$R2
        if(traceit)cat(i, "index", fit1$index, "Variance Explained", fit1$variance,"R-squared",fit1$R2,"\n")
    }
    list(indices = indices, R2=R2)
}
""";

## Simulate data

In [187]:
m = 1
p = 510
k = 10 # number of causal groups
n = 250 # sample size
μ = zeros(p)
Σ = simulate_AR1(p, a=3, b=1)
X = rand(MvNormal(μ, Σ), n)' |> Matrix
zscore!(X, mean(X, dims=1), std(X, dims=1));

In [115]:
# Julia implementation that have some bug (result different than Trevor's code)
function select_one(C::AbstractMatrix{T}) where T
    p = size(C, 2)
#     return [tr(C - C[:,i]*C[i,:]'/C[i,i]) for i in 1:p] |> argmin
    best_val, min_idx = typemax(T), 0
    for i in 1:p
        val = zero(T)
        for j in 1:p
            j == i && continue
            val += C[j, j] - C[j,i]*C[i,j]/C[i,i]
        end
        if val < best_val
            best_val = val
            min_idx = i
        end
    end
    return min_idx
end

function select_k(C::AbstractMatrix, k::Int)
    p = size(C, 2)
    selected, not_selected = Int[], collect(1:p)
    C̃ = copy(C)
    for i in 1:min(k, p)
        idx = select_one(C̃)
        rep = not_selected[idx]
        push!(selected, rep)
        deleteat!(not_selected, findfirst(x -> x == rep, not_selected))
        C̃ -= C̃[:, idx] * C̃[idx, :]' ./ C̃[idx, idx]
        C̃ = @view(C[not_selected, not_selected])
    end
    return selected
end

selected = select_k(C, 10)

[selected result[:indices]]

10×2 Matrix{Float64}:
 354.0  354.0
 294.0  294.0
 355.0  443.0
 295.0  211.0
 359.0  311.0
 443.0  498.0
 352.0  116.0
 293.0   58.0
 211.0  197.0
 311.0  104.0

In [186]:
# faithful re-implementation of Trevor's R code. Probably not the most Julian/efficient Julia code
function step1(C::AbstractMatrix, vlist, RSS0, tol=1e-12)
    dC = diag(C)
    rs = vec(sum(C.^2, dims=1)) ./ dC
    v, imax = findmax(rs)
    vmin = sum(dC) - rs[imax]
    residC = C - (C[:,imax] * C[:,imax]' ./ C[imax, imax])
    index = vlist[imax]
    nzero = findall(x -> x > tol, diag(residC))
    R2 = 1 - vmin/RSS0
    return index, R2, residC[nzero, nzero], vlist[nzero]
end

function subsetC(C::AbstractMatrix, k::Int)
    p = size(C, 2)
    indices = zeros(Int, k)
    RSS0 = p
    R2 = zeros(k)
    vlist = collect(1:p)
    for i in 1:k
        idx, r2, Cnew, vnew = step1(C, vlist, RSS0)
        indices[i] = idx
        C = Cnew
        vlist = vnew
        R2[i] = r2
    end
    return indices, R2
end

@time selected, R2 = subsetC(C, 10)
[selected result[:indices]]

  0.229662 seconds (473.46 k allocations: 121.045 MiB, 7.99% gc time, 76.44% compilation time)


10×2 Matrix{Float64}:
 354.0  354.0
 294.0  294.0
 443.0  443.0
 211.0  211.0
 311.0  311.0
 498.0  498.0
 116.0  116.0
  58.0   58.0
 197.0  197.0
 104.0  104.0

In [183]:
[R2 result[:R2]]

10×2 Matrix{Float64}:
 0.0213559  0.0213559
 0.0422617  0.0422617
 0.0596227  0.0596227
 0.0758653  0.0758653
 0.0913927  0.0913927
 0.106448   0.106448
 0.121425   0.121425
 0.13594    0.13594
 0.150434   0.150434
 0.164428   0.164428

## Some test

### Interpolative decomposition, selecting group reps by ID

In [264]:
Random.seed!(2022)
nrep = 3
rep_method = :id
groups1, group_reps = id_partition_groups(X, rep_method=rep_method, nrep=nrep)
@test countmap(groups1[group_reps]) |> values |> collect |> maximum == 3
groups1, group_reps = id_partition_groups(Symmetric(cor(X)), rep_method=rep_method, nrep=nrep)
@test countmap(groups1[group_reps]) |> values |> collect |> maximum == 3

[32m[1mTest Passed[22m[39m
  Expression: ((countmap(groups1[group_reps]) |> values) |> collect) |> maximum == 3
   Evaluated: 3 == 3

### Interpolative decomposition, selecting group reps by Trevor's method

In [265]:
Random.seed!(2022)
nrep = 2
rep_method = :rss
groups2, group_reps = id_partition_groups(X, rep_method=rep_method, nrep=nrep)
@test countmap(groups2[group_reps]) |> values |> collect |> maximum == 2

[32m[1mTest Passed[22m[39m
  Expression: ((countmap(groups2[group_reps]) |> values) |> collect) |> maximum == 2
   Evaluated: 2 == 2

### hierarchical clustering, using ID to choose reps

In [269]:
Random.seed!(2022)
nrep = 2
rep_method = :id
groups1, group_reps1 = hc_partition_groups(X, rep_method=rep_method, nrep=nrep)
@test countmap(groups[group_reps]) |> values |> collect |> maximum == 2

[32m[1mTest Passed[22m[39m
  Expression: ((countmap(groups[group_reps]) |> values) |> collect) |> maximum == 2
   Evaluated: 2 == 2

### hierarchical clustering, using Trevor's method to choose reps

In [273]:
Random.seed!(2022)
nrep = 2
rep_method = :rss
groups2, group_reps2 = hc_partition_groups(X, rep_method=rep_method, nrep=nrep)
groups2, group_reps2 = hc_partition_groups(Symmetric(cor(X)), rep_method=rep_method, nrep=nrep)
@test countmap(groups[group_reps]) |> values |> collect |> maximum == 2

[32m[1mTest Passed[22m[39m
  Expression: ((countmap(groups[group_reps]) |> values) |> collect) |> maximum == 2
   Evaluated: 2 == 2

In [272]:
[group_reps1 group_reps2]

285×2 Matrix{Int64}:
   1    1
   2    2
   3    3
   4    4
   5    5
   6    7
   9    9
  10   10
  11   11
  12   12
  14   13
  18   16
  19   19
   ⋮  
 494  494
 495  495
 496  496
 497  497
 498  498
 499  499
 501  501
 502  502
 503  503
 505  504
 506  506
 510  508