# Choosing group representatives

In [1]:
using Revise
using Knockoffs
using LinearAlgebra
using Random
using StatsBase
using Statistics
using ToeplitzMatrices
using Distributions
using Clustering
using ProgressMeter
using LowRankApprox
using Test
using RCall
using CSV, DataFrames
using BenchmarkTools

# using Plots
# gr(fmt=:png);

# some helper functions to compute power and empirical FDR
function TP(correct_groups, signif_groups)
    return length(signif_groups ∩ correct_groups) / max(1, length(correct_groups))
end
function TP(correct_groups, β̂, groups)
    signif_groups = get_signif_groups(β̂, groups)
    return TP(correct_groups, signif_groups)
end
function power(correct_snps, discovered_snps)
    return length(discovered_snps ∩ correct_snps) / length(correct_snps)
end
function FDR(correct_groups, signif_groups)
    FP = length(signif_groups) - length(signif_groups ∩ correct_groups) # number of false positives
    return FP / max(1, length(signif_groups))
end
function FDR(correct_groups, β̂, groups)
    signif_groups = get_signif_groups(β̂, groups)
    return FDR(correct_groups, signif_groups)
end
function get_signif_groups(β, groups)
    correct_groups = Int[]
    for i in findall(!iszero, β)
        g = groups[i]
        g ∈ correct_groups || push!(correct_groups, g)
    end
    return correct_groups
end

R"""
# zihuai's code for finding representative variants per group
Get.group.rep<-function(Sigma,clusters,inv.Sigma=NULL,thres=0.75,search.method='subsetC',stop.method='R2.ratio'){
  if(length(inv.Sigma)==0 & stop.method=='R2.ratio'){inv.Sigma<-solve(Sigma)}
  rep.data<-c()
  for(j in 1:max(clusters)){
    # print(j)
    if(sum(clusters==j)==1){
      rep.data<-rbind(rep.data,cbind(j,which(clusters==j)))
    }else{
      cor.G<-Sigma[clusters==j,clusters==j]
      if(search.method=='ID'){
        #interpolative decomposition
        A<-chol(cor.G)
        temp.fit<-rid(A,ncol(A),rand=F,idx_only=T)
        index.all<-temp.fit$idx
      }
      if(search.method=='subsetC'){
        index.all<-subsetC(cor.G, k=nrow(cor.G), traceit=FALSE)$indices
      }
      index<-index.all[1]
      for(i in 1:(nrow(cor.G)-1)){
        # print(i)
        #for(i in 1:4){
        temp.A<-cor.G[index,index,drop=F]
        #pre-compute some matrices
        if(i==1){inv.A<-solve(temp.A)}
        B<-cor.G[index,(1:nrow(cor.G))[-index],drop=F]
        #representative residual R2
        R2.R<-colSums(B*inv.A%*%B)
        inv.AB<-inv.A%*%B
        
        if(stop.method=='R2.ratio'){
          #representative plus other groups R2
          index.O<-which(clusters!=j)
          index.OR<-c(which(clusters==j)[index],index.O)
          inv.A.OR<-inv.Sigma[index.OR,index.OR]-
            inv.Sigma[index.OR,-index.OR,drop=F]%*%solve(inv.Sigma[-index.OR,-index.OR])%*%t(inv.Sigma[index.OR,-index.OR,drop=F])
          B.OR<-Sigma[which(clusters==j)[-index],index.OR,drop=F]
          R2.OR<-rowSums(B.OR%*%inv.A.OR*B.OR)#diag(B%*%inv.A%*%t(B))

        #print(R2.R)
        #print(R2.OR)
        #print(B.OR[1:4])
        #print(R2.R/R2.OR)

          if(mean(R2.R/R2.OR)>=thres){
            #print(mean(R2.R/R2.OR))
            break
          }
        }
        if(stop.method=='R2'){
          if(mean(R2.R)>=thres){
            print(min(R2.R))
            break
          }
        }
        index.add<-index.all[i+1]
        b<-cor.G[index,index.add,drop=F]
        c<-cor.G[index.add,index.add,drop=F]
        #R<-as.numeric(solve(c-t(b)%*%inv.A%*%b))
        R<-solve(c-t(b)%*%inv.A%*%b)
        inv.Ab<-inv.A%*%b
        inv.A<-rbind(cbind(inv.A+inv.Ab%*%R%*%t(inv.Ab),-inv.Ab%*%R),cbind(-R%*%t(inv.Ab),R))
        #update results
        index<-c(index,index.add)
        
      }
      index<-which(clusters==j)[index]
      rep.data<-rbind(rep.data,cbind(j,as.numeric(index)))
    }
  }
  return(rep.data)
}

step1 <- function(C,vlist=seq(ncol(C)),RSS0=sum(diag(C)),zero=1e-12){
  dC <- diag(C)
  rs <- colSums(C^2)/dC
  imax <- order(rs,decreasing=TRUE)[1]
  vmin <- sum(dC) - rs[imax]
  residC = C - outer(C[,imax],C[,imax],"*")/C[imax,imax]
  index = vlist[imax]
  izero = diag(residC) <= zero
  list(index = index, variance = vmin, R2 = 1-vmin/RSS0, C=residC[!izero,!izero],vlist=vlist[!izero])
}

subsetC <- function(C, k=NA, traceit=FALSE){
  ## C correlation matrix
  ## k subset size
  do.adaptive <- is.na(k)
  p <- ncol(C)
  if (do.adaptive) {
    k <- p-1
  }
  indices <- rep(0, k)
  RSS0 <- p
  R2 <- double(k)
  vlist = seq(p)
  for(i in 1:k){
    fit1 <- step1(C, RSS0=RSS0, vlist=vlist)
    indices[i] <- fit1$index
    C <- as.matrix(fit1$C)
    vlist <- fit1$vlist
    R2[i] <- fit1$R2
    if(traceit)cat(i, "index", fit1$index, "Variance Explained", fit1$variance,"R-squared",fit1$R2,"\n")
    
    # if there is at least 3 R2 values,
    # check early stopping rule
    if (do.adaptive && (i >= 3)) {
      rsq_u <- R2[i]
      rsq_m <- R2[i-1]
      rsq_l <- R2[i-2]
      if (check_early_stopping_rule(rsq_l, rsq_m, rsq_u)) {
        indices <- indices[1:i]
        R2 <- R2[1:i]
        break
      }
    }
  }
  list(indices = indices, R2=R2)
}

check_early_stopping_rule <- function(rsq_l, rsq_m, rsq_u, cond_0_thresh=1e-2, cond_1_thresh=1e-2) 
{
  delta_u <- (rsq_u-rsq_m)
  delta_m <- (rsq_m-rsq_l)
  (delta_u < cond_0_thresh*rsq_u) && ((delta_m*rsq_u-delta_u*rsq_m) < cond_1_thresh*rsq_m*rsq_u)
}
""";

┌ Info: Precompiling Knockoffs [878bf26d-0c49-448a-9df5-b057c815d613]
└ @ Base loading.jl:1423


## Test data

In [187]:
# simulate sigma
m = 1
p = 1000
k = 10 # number of causal groups
n = 250 # sample size
μ = zeros(p)
Sigma = simulate_AR1(p, a=3, b=1)
groups = repeat(1:200, inner=5)

1000-element Vector{Int64}:
   1
   1
   1
   1
   1
   2
   2
   2
   2
   2
   3
   3
   3
   ⋮
 198
 198
 199
 199
 199
 199
 199
 200
 200
 200
 200
 200

In [2]:
# sigma from gnomAD
p = 1000
datadir = "/Users/biona001/Benjamin_Folder/research/4th_project_PRS/group_knockoff_test_data"
covfile = CSV.read(joinpath(datadir, "CorG_2_127374341_128034347.txt"), DataFrame) # 3782 SNPs
Sigma = covfile |> Matrix{Float64}
Sigma = 0.99Sigma + 0.01I #ensure PSD
Sigma = Sigma[1:p, 1:p]
@time groups = hc_partition_groups(Symmetric(Sigma))

  0.935046 seconds (3.08 M allocations: 185.249 MiB, 1.43% gc time, 98.36% compilation time)


1000-element Vector{Int64}:
   1
   2
   3
   4
   5
   5
   5
   6
   7
   7
   7
   8
   9
   ⋮
 218
 218
 231
 218
 123
 218
 218
 218
 210
 218
 218
 218

### Julia implementation of Zihuai's method for searching representatives

In [15]:
@btime group_reps = Knockoffs.choose_group_reps(Symmetric(Sigma), groups)

  248.929 ms (29271 allocations: 56.48 MiB)


236-element Vector{Int64}:
   1
   2
   3
   4
   6
  20
  47
  49
  55
  56
  57
  59
  60
   ⋮
 896
 904
 931
 937
 938
 941
 955
 956
 961
 964
 966
 991

### Zihuai's R code for searching representatives

In [16]:
SigmaInv = inv(Sigma)
@rput groups Sigma SigmaInv
@btime begin
    R"""
    rep_data <- Get.group.rep(Sigma,groups,inv.Sigma=SigmaInv,thres=0.5,search.method='subsetC',stop.method="R2.ratio")
    """
end
@rget rep_data # default: reps not sorted
rep_data[:, 2] |> sort

  1.369 s (42 allocations: 1.14 KiB)


236-element Vector{Float64}:
   1.0
   2.0
   3.0
   4.0
   6.0
  20.0
  47.0
  49.0
  55.0
  56.0
  57.0
  59.0
  60.0
   ⋮
 896.0
 904.0
 931.0
 937.0
 938.0
 941.0
 955.0
 956.0
 961.0
 964.0
 966.0
 991.0

### Profile code

Out of 161 samples, 
+ 140 samples spent on `Σ_RORO_inv .= @view(Σinv[RO, RO])`

In [17]:
using Profile
@profile Knockoffs.choose_group_reps(Symmetric(Sigma), groups)
Profile.clear()
@profile Knockoffs.choose_group_reps(Symmetric(Sigma), groups)

Profile.print()

Overhead ╎ [+additional indent] Count File:Line; Function
   ╎161 @Base/task.jl:429; (::IJulia.var"#15#18")()
   ╎ 161 @IJulia/src/eventloop.jl:8; eventloop(socket::ZMQ.Socket)
   ╎  161 @Base/essentials.jl:714; invokelatest
   ╎   161 @Base/essentials.jl:716; #invokelatest#2
   ╎    161 .../execute_request.jl:67; execute_request(socket::ZMQ.So...
   ╎     161 .../SoftGlobalScope.jl:65; softscope_include_string(m::Mo...
   ╎    ╎ 161 @Base/loading.jl:1196; include_string(mapexpr::type...
   ╎    ╎  161 @Base/boot.jl:373; eval
   ╎    ╎   161 ...offs/src/group.jl:1852; choose_group_reps(Σ::Symmet...
   ╎    ╎    1   ...ffs/src/group.jl:1856; choose_group_reps(Σ::Symme...
   ╎    ╎     1   ...src/symmetric.jl:677; inv(A::Symmetric{Float64, ...
   ╎    ╎    ╎ 1   ...src/symmetric.jl:662; _inv(A::Symmetric{Float64...
   ╎    ╎    ╎  1   ...gebra/src/lu.jl:278; lu
   ╎    ╎    ╎   1   ...gebra/src/lu.jl:278; lu
   ╎    ╎    ╎    1   ...gebra/src/lu.jl:279; lu(A::Symmetric{Float64,...
   ╎  

### Julia implementation of Trevor's method 

This is old code that have some bug (result different than Trevor's code)

In [3]:
# Julia implementation that have some bug (result different than Trevor's code)
function select_one(C::AbstractMatrix{T}) where T
    p = size(C, 2)
#     return [tr(C - C[:,i]*C[i,:]'/C[i,i]) for i in 1:p] |> argmin
    best_val, min_idx = typemax(T), 0
    for i in 1:p
        val = zero(T)
        for j in 1:p
            j == i && continue
            val += C[j, j] - C[j,i]*C[i,j]/C[i,i]
        end
        if val < best_val
            best_val = val
            min_idx = i
        end
    end
    return min_idx
end

function select_k(C::AbstractMatrix, k::Int)
    p = size(C, 2)
    selected, not_selected = Int[], collect(1:p)
    C̃ = copy(C)
    for i in 1:min(k, p)
        idx = select_one(C̃)
        rep = not_selected[idx]
        push!(selected, rep)
        deleteat!(not_selected, findfirst(x -> x == rep, not_selected))
        C̃ -= C̃[:, idx] * C̃[idx, :]' ./ C̃[idx, idx]
        C̃ = @view(C[not_selected, not_selected])
    end
    return selected
end

selected = select_k(C, 10)

[selected result[:indices]]

LoadError: UndefVarError: C not defined

In [4]:
# faithful re-implementation of Trevor's R code. Probably not the most Julian/efficient Julia code
function step1(C::AbstractMatrix, vlist, RSS0, tol=1e-12)
    dC = diag(C)
    rs = vec(sum(C.^2, dims=1)) ./ dC
    v, imax = findmax(rs)
    vmin = sum(dC) - rs[imax]
    residC = C - (C[:,imax] * C[:,imax]' ./ C[imax, imax])
    index = vlist[imax]
    nzero = findall(x -> x > tol, diag(residC))
    R2 = 1 - vmin/RSS0
    return index, R2, residC[nzero, nzero], vlist[nzero]
end

function subsetC(C::AbstractMatrix, k::Int)
    p = size(C, 2)
    indices = zeros(Int, k)
    RSS0 = p
    R2 = zeros(k)
    vlist = collect(1:p)
    for i in 1:k
        idx, r2, Cnew, vnew = step1(C, vlist, RSS0)
        indices[i] = idx
        C = Cnew
        vlist = vnew
        R2[i] = r2
    end
    return indices, R2
end

@time selected, R2 = subsetC(C, 10)
[selected result[:indices]]

LoadError: UndefVarError: C not defined

In [183]:
[R2 result[:R2]]

10×2 Matrix{Float64}:
 0.0213559  0.0213559
 0.0422617  0.0422617
 0.0596227  0.0596227
 0.0758653  0.0758653
 0.0913927  0.0913927
 0.106448   0.106448
 0.121425   0.121425
 0.13594    0.13594
 0.150434   0.150434
 0.164428   0.164428

## Some test

### Interpolative decomposition, selecting group reps by ID

In [264]:
Random.seed!(2022)
nrep = 3
rep_method = :id
groups1, group_reps = id_partition_groups(X, rep_method=rep_method, nrep=nrep)
@test countmap(groups1[group_reps]) |> values |> collect |> maximum == 3
groups1, group_reps = id_partition_groups(Symmetric(cor(X)), rep_method=rep_method, nrep=nrep)
@test countmap(groups1[group_reps]) |> values |> collect |> maximum == 3

[32m[1mTest Passed[22m[39m
  Expression: ((countmap(groups1[group_reps]) |> values) |> collect) |> maximum == 3
   Evaluated: 3 == 3

### Interpolative decomposition, selecting group reps by Trevor's method

In [265]:
Random.seed!(2022)
nrep = 2
rep_method = :rss
groups2, group_reps = id_partition_groups(X, rep_method=rep_method, nrep=nrep)
@test countmap(groups2[group_reps]) |> values |> collect |> maximum == 2

[32m[1mTest Passed[22m[39m
  Expression: ((countmap(groups2[group_reps]) |> values) |> collect) |> maximum == 2
   Evaluated: 2 == 2

### hierarchical clustering, using ID to choose reps

In [269]:
Random.seed!(2022)
nrep = 2
rep_method = :id
groups1, group_reps1 = hc_partition_groups(X, rep_method=rep_method, nrep=nrep)
@test countmap(groups[group_reps]) |> values |> collect |> maximum == 2

[32m[1mTest Passed[22m[39m
  Expression: ((countmap(groups[group_reps]) |> values) |> collect) |> maximum == 2
   Evaluated: 2 == 2

### hierarchical clustering, using Trevor's method to choose reps

In [273]:
Random.seed!(2022)
nrep = 2
rep_method = :rss
groups2, group_reps2 = hc_partition_groups(X, rep_method=rep_method, nrep=nrep)
groups2, group_reps2 = hc_partition_groups(Symmetric(cor(X)), rep_method=rep_method, nrep=nrep)
@test countmap(groups[group_reps]) |> values |> collect |> maximum == 2

[32m[1mTest Passed[22m[39m
  Expression: ((countmap(groups[group_reps]) |> values) |> collect) |> maximum == 2
   Evaluated: 2 == 2

In [272]:
[group_reps1 group_reps2]

285×2 Matrix{Int64}:
   1    1
   2    2
   3    3
   4    4
   5    5
   6    7
   9    9
  10   10
  11   11
  12   12
  14   13
  18   16
  19   19
   ⋮  
 494  494
 495  495
 496  496
 497  497
 498  498
 499  499
 501  501
 502  502
 503  503
 505  504
 506  506
 510  508