# Make fastPHASE knockoffs

In [1]:
using Revise
using LinearAlgebra
using DelimitedFiles
using Distributions
using ProgressMeter
using SnpArrays
using Random
using Knockoffs
using BenchmarkTools
# Input locations: `plinkname` is handed to SnpData / hmm_knockoff in later cells,
# `datadir` is the directory passed to process_fastphase_output.
plinkname = "/Users/biona001/.julia/dev/Knockoffs/fastphase/ukb.10k.chr10"
datadir = "/Users/biona001/.julia/dev/Knockoffs/fastphase"
# Alternative values for the same two variables (NOTE(review): presumably cluster paths — confirm)
# plinkname = "/scratch/users/bbchu/ukb_SHAPEIT/subset/ukb.10k.chr10"
# datadir = "/scratch/users/bbchu/fastphase"
# Forwarded as `T` to hmm_knockoff and process_fastphase_output.
# NOTE(review): presumably the number of fastPHASE runs/initializations to read — confirm
T = 10
# Forwarded as `extension`; the thetahat file read later in this notebook
# ("ukb_chr10_n1000_thetahat.txt") starts with this same string.
extension="ukb_chr10_n1000_"

"ukb_chr10_n1000_"

## Make knockoffs

In [14]:
# One-call pipeline: build knockoffs for the PLINK data from the fastPHASE output in `datadir`
X̃ = hmm_knockoff(plinkname, extension; T=T, datadir=datadir)

[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:05[39m


10000×29481 SnpArray:
 0x03  0x02  0x00  0x02  0x03  0x02  …  0x00  0x02  0x00  0x02  0x02  0x03
 0x03  0x00  0x00  0x02  0x03  0x00     0x00  0x00  0x00  0x02  0x03  0x00
 0x02  0x00  0x00  0x02  0x02  0x00     0x00  0x00  0x00  0x00  0x03  0x00
 0x03  0x00  0x03  0x00  0x02  0x00     0x00  0x00  0x00  0x00  0x02  0x00
 0x02  0x00  0x03  0x00  0x02  0x00     0x00  0x00  0x02  0x00  0x03  0x00
 0x03  0x00  0x00  0x00  0x00  0x00  …  0x00  0x00  0x00  0x00  0x00  0x02
 0x03  0x00  0x00  0x00  0x03  0x00     0x00  0x00  0x00  0x00  0x00  0x00
 0x03  0x00  0x00  0x00  0x03  0x00     0x02  0x00  0x00  0x00  0x03  0x00
 0x03  0x00  0x02  0x00  0x03  0x00     0x00  0x00  0x00  0x00  0x03  0x02
 0x03  0x00  0x00  0x00  0x03  0x00     0x00  0x00  0x00  0x00  0x02  0x00
 0x00  0x00  0x00  0x00  0x00  0x00  …  0x00  0x00  0x00  0x00  0x00  0x00
 0x00  0x00  0x00  0x00  0x00  0x00     0x00  0x00  0x00  0x00  0x00  0x00
 0x00  0x00  0x00  0x00  0x00  0x00     0x00  0x00  0x00  0x00  0x00  0x00
   

In [116]:
# Load genotypes. `plinkname` is an absolute path here, so joinpath simply returns it.
snpdata = SnpData(joinpath(datadir, plinkname))
Xfull = snpdata.snparray
n, p = size(Xfull)

# get r, α, θ estimated by fastPHASE
r, θ, α = process_fastphase_output(datadir, T, extension=extension)
K = size(θ, 2)
statespace = (K * (K + 1)) >> 1 # K(K+1)/2
table = MarkovChainTable(K)

# get initial states (marginal distribution vector) and Markov transition matrices
q = get_initial_probabilities(α, table)
Q = get_genotype_transition_matrix(r, θ, α, q, table)

# preallocated arrays
# full_knockoff = SnpArray(outfile * ".bed", n, 2p)
full_knockoff = zeros(Int, n, p)
X = zeros(Float64, p)
Z = zeros(Int, p)
Z̃ = zeros(Int, p)
X̃ = zeros(Int, p)
N = zeros(p, statespace)
d_K = Categorical(fill(1 / statespace, statespace)) # for sampling markov chains (length statespace)
# Genotypes take 3 values, so this sampler gets exactly 3 categories. A longer
# probability vector would leave stale trailing weights that can be drawn whenever
# the three meaningful probabilities sum to slightly less than 1.
d_3 = Categorical(fill(1 / 3, 3)); # for sampling genotypes (length 3)

In [117]:
# Generate knockoffs for samples 1 through 10 only
@showprogress for row in 1:10
    # Algorithm 3 in Sesia et al: sample the hidden states Z for this sample
    copyto!(X, view(Xfull, row, :))
    forward_backward_sampling!(Z, X, d_K, Q, q, θ, table)

    # Algorithm 2 in Sesia et al: sample a knockoff copy Z̃ of the Markov chain
    markov_knockoffs!(Z̃, Z, N, d_K, Q, q)

    # Eq 6 in Sesia et al: sample knockoff genotypes X̃ given Z̃
    genotype_knockoffs!(X̃, Z̃, table, θ, d_3)

    # store this sample's knockoff as one row of the result matrix
    view(full_knockoff, row, :) .= X̃
end

[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:05[39m


## (Scaled) forward-backward algorithm to get Z

In [17]:
# Load the genotype matrix and record its dimensions
xdata = SnpData(plinkname)
x = xdata.snparray
n, p = size(x)

# r, α, θ as estimated by fastPHASE
r, θ, α = process_fastphase_output(datadir, T; extension=extension)
K = size(θ, 2) # number of haplotype motifs
statespace = div(K * (K + 1), 2) # K(K+1)/2
table = MarkovChainTable(K)

# Haplotype-level transitions first, then the genotype-level transitions derived
# from them, then the initial state probabilities
H = get_haplotype_transition_matrix(r, θ, α)
Q = get_genotype_transition_matrix(H, table)
q = get_initial_probabilities(α, table);

In [18]:
# Buffers: genotypes of one sample (Float64) and its sampled hidden states
xi = zeros(p)
Z = zeros(Int, p);

# Fixed seed so the sampled path is reproducible
Random.seed!(2022)
i = 1
xi = copyto!(xi, view(x, i, :))
@time forward_backward_sampling!(Z, xi, Q, q, θ, table)
# Column 1: state index; column 2: the pair that index maps to in `table`
hcat(Z, [index_to_pair(table, state) for state in Z])

  0.096534 seconds (29.49 k allocations: 27.892 MiB)


29481×2 Matrix{Any}:
 10  (1, 10)
 10  (1, 10)
 10  (1, 10)
 10  (1, 10)
 10  (1, 10)
 10  (1, 10)
 10  (1, 10)
 10  (1, 10)
 10  (1, 10)
 10  (1, 10)
 10  (1, 10)
 10  (1, 10)
 10  (1, 10)
  ⋮  
 43  (6, 8)
 43  (6, 8)
 43  (6, 8)
 43  (6, 8)
 43  (6, 8)
 43  (6, 8)
 43  (6, 8)
 43  (6, 8)
 43  (6, 8)
 43  (6, 8)
 43  (6, 8)
 43  (6, 8)

## Sample knockoff of Markov chain

In [19]:
# Preallocate the knockoff chain, the normalizing-constant matrix and a uniform
# categorical sampler over the state space, then benchmark the in-place sampler
Z̃ = fill(0, p)
N = zeros(Float64, p, statespace)
d = Categorical(fill(1 / statespace, statespace))
@btime markov_knockoffs!($Z̃, $Z, $N, $d, $Q, $q);

  383.333 ms (0 allocations: 0 bytes)


In [38]:
# Seed the global RNG, then draw a knockoff chain with the allocating
# (non-`!`) variant of markov_knockoffs
Random.seed!(2022)
Z̃ = markov_knockoffs(Z, Q, q)

29481-element Vector{Int64}:
 10
 10
 10
 10
 10
 10
 10
 10
 10
 10
 10
 10
  4
  ⋮
 43
 43
 43
 43
 43
 43
 43
 43
 43
 43
 43
 43

## Sample genotype knockoffs

In [44]:
# Sample knockoff genotypes from the knockoff chain, report their correlation with
# the original genotypes, and show the two side by side (knockoff, original)
X̃ = genotype_knockoffs(Z̃, table, θ)
@show cor(X̃, xi)
hcat(X̃, xi)

cor(X̃, xi) = 0.4848137894799969


29481×2 Matrix{Float64}:
 2.0  1.0
 0.0  0.0
 0.0  1.0
 1.0  0.0
 2.0  1.0
 0.0  0.0
 0.0  0.0
 0.0  0.0
 0.0  0.0
 0.0  0.0
 0.0  0.0
 1.0  2.0
 2.0  1.0
 ⋮    
 1.0  0.0
 0.0  0.0
 0.0  0.0
 1.0  0.0
 0.0  0.0
 0.0  0.0
 0.0  0.0
 0.0  0.0
 0.0  0.0
 1.0  0.0
 1.0  0.0
 2.0  0.0

## Numerical issues with sampling knockoffs

In [9]:
# Fresh buffers for stepping through the Markov-chain knockoff sampler
Z̃ = zeros(Int, p)
N = zeros(p, statespace)
d = Categorical([1 / statespace for _ in 1:statespace])

# Run the per-position update for the first 8 positions only, so the
# normalizing constants in N can be inspected afterwards
for j in 1:8
    update_normalizing_constants!(N, Z, Z̃, Q, q, j) # equation 5
    Knockoffs.single_state_dmc_knockoff!(Z̃, Z, d, N, Q, q, j)
end

In [53]:
# display the normalizing-constant matrix
N

29481×55 Matrix{Float64}:
 0.0104452   0.0202199    0.0202368    …  0.0102105    0.0181799   0.00809235
 0.00206479  0.000988552  0.000876434     4.75048e-8   0.00160533  0.00250341
 8.28113e-7  4.18914e-7   2.98606e-7      2.42606e-11  6.68505e-7  9.49961e-7
 0.0         0.0          0.0             0.0          0.0         0.0
 0.0         0.0          0.0             0.0          0.0         0.0
 0.0         0.0          0.0          …  0.0          0.0         0.0
 0.0         0.0          0.0             0.0          0.0         0.0
 0.0         0.0          0.0             0.0          0.0         0.0
 0.0         0.0          0.0             0.0          0.0         0.0
 0.0         0.0          0.0             0.0          0.0         0.0
 0.0         0.0          0.0          …  0.0          0.0         0.0
 0.0         0.0          0.0             0.0          0.0         0.0
 0.0         0.0          0.0             0.0          0.0         0.0
 ⋮                            

In [11]:
# Fresh buffers, then run the per-position knockoff update over every position
Z̃ = fill(0, p)
N = zeros(Float64, p, statespace)
d = Categorical(fill(1 / statespace, statespace))

for j in eachindex(Z̃)
    update_normalizing_constants!(N, Z, Z̃, Q, q, j) # equation 5
    Knockoffs.single_state_dmc_knockoff!(Z̃, Z, d, N, Q, q, j)
end

# Check the sampler's probability vector still sums to 1, then look at the
# row sums of N
@show sum(d.p)
sum(N; dims=2)

sum(d.p) = 1.0000000000000002


29481×1 Matrix{Float64}:
  1.0000000000000002
 56.06005372525685
  0.01783801368582635
 56.08064192825152
  0.017831491854678007
 56.273251658420556
  0.01782567349876249
 56.613311080204994
  0.01771881035180266
 56.76635465678353
  0.017662207758523424
 56.96109951679734
  0.01758068922862137
  ⋮
  3.132412630011038e25
  3.196259933853093e-26
  3.1286567772538117e25
  3.200514475361951e-26
  3.128465486712592e25
  3.196656612371382e-26
  3.1297517035959344e25
  3.1971495204023127e-26
  3.1286675650940035e25
  3.196725686271059e-26
  3.1282072034338844e25
  1.7582232816686372e-24

In [20]:
# row sums of N (one value per row)
sum(N, dims=2)

29481×1 Matrix{Float64}:
  1.0000000000000002
 54.2639683695825
  0.018428434849456683
 54.28294099029847
  0.01842202157592164
 54.46560222966858
  0.018421161559418645
 54.76540497321972
  0.018324716116195197
 54.878190326055076
  0.018273299653702453
 55.0139156956132
  0.018204465860137662
  ⋮
  3.6609090374126475e86
  2.73506239070679e-87
  3.65622384277509e86
  2.7386355293396198e-87
  3.656499771203274e86
  2.735025320919982e-87
  3.657915374966508e86
  2.735829710737848e-87
  3.656225791496842e86
  2.7354690477186257e-87
  3.6556919305423893e86
  2.3076678095328526e-92

In [92]:
# display the normalizing-constant matrix again
N

29481×55 Matrix{Float64}:
     0.0104452        0.0202199    …      0.0181799        0.00809235
     0.00206479       0.000988552         0.00160533       0.00250341
     8.28113e-7       4.18914e-7          6.68505e-7       9.49961e-7
     0.022767         0.0144971           0.0215867        0.0202142
     2.19784e-5       1.68489e-5          1.57211e-5       7.47459e-6
     0.072532         0.028756     …      0.0335778        0.00271332
     1.52591e-5       1.51584e-5          2.28177e-5       2.76152e-5
     0.0296311        0.0766969           0.0997475        0.121466
     6.97654e-6       1.59861e-5          2.34815e-5       2.22566e-5
     0.0697864        0.055066            0.0305128        0.0566642
     3.31068e-5       1.93318e-5   …      1.20735e-5       1.20862e-5
     0.126527         0.142382            0.149429         0.133096
     0.000110562      9.05297e-5          8.60514e-5       8.62945e-5
     ⋮                             ⋱                   
   813.006    

In [27]:
# natural log of a very large magnitude
# NOTE(review): the literal is presumably copied from an earlier run's output — confirm
log(1.7789e74)

170.96729207730846

In [26]:
# natural log of a very small magnitude
# NOTE(review): the literal is presumably copied from an earlier run's output — confirm
log(4.19576e-94)

-215.00892424987288

## Debug: Look at Markov chains and their knockoffs

In [128]:
# show the current working directory
pwd()

"/Users/biona001/.julia/dev/Knockoffs/test"

In [131]:
# Read the fastPHASE thetahat file as delimited data. Comment handling is on with
# '>' as the comment character, and no header row is split off.
θ_full = readdlm("/Users/biona001/.julia/dev/Knockoffs/fastphase/ukb_chr10_n1000_thetahat.txt", 
    comments=true, comment_char = '>', header=false);

In [139]:
# Index range of the i-th consecutive block of p entries: (i-1)p+1 through ip
i = 2
rows = ((i - 1) * p + 1):(i * p)

29482:58962

In [136]:
# Index range of the i-th consecutive block of p entries: (i-1)p+1 through ip,
# then a non-copying view of that block of θ_full
i = 5
rows = ((i - 1) * p + 1):(i * p)
@view(θ_full[rows])

29481-element view(::Vector{Float64}, 117925:147405) with eltype Float64:
 0.999
 0.001
 0.969835
 0.001
 0.999
 0.001
 0.001
 0.001
 0.001
 0.001
 0.001
 0.481411
 0.999
 ⋮
 0.999
 0.984593
 0.001
 0.999
 0.001
 0.001
 0.238347
 0.001
 0.001
 0.001
 0.999
 0.956837

In [123]:
# First column of the genotype matrix (all samples at SNP 1), converted to Float64.
# NOTE(review): the recorded output below reports Vector{Int64}, so it may come
# from an earlier version of this cell — confirm
snp1 = convert(Vector{Float64}, @view(Xfull[:, 1]))

10000-element Vector{Int64}:
 2
 0
 0
 0
 0
 0
 0
 0
 0
 0
 0
 0
 0
 ⋮
 2
 0
 0
 0
 0
 0
 0
 2
 0
 0
 0
 0

In [124]:
# total of the first SNP's values across all samples
sum(snp1)

1496

In [120]:
# display the raw genotype SnpArray
Xfull

10000×29481 SnpArray:
 0x02  0x00  0x02  0x00  0x02  0x00  …  0x00  0x00  0x00  0x00  0x00  0x00
 0x00  0x00  0x03  0x00  0x03  0x00     0x00  0x00  0x00  0x02  0x02  0x00
 0x00  0x02  0x03  0x02  0x00  0x00     0x00  0x00  0x00  0x00  0x00  0x00
 0x00  0x00  0x00  0x00  0x00  0x00     0x00  0x00  0x00  0x02  0x02  0x00
 0x00  0x02  0x03  0x00  0x02  0x00     0x00  0x00  0x02  0x00  0x02  0x02
 0x00  0x00  0x03  0x00  0x02  0x00  …  0x00  0x00  0x00  0x00  0x00  0x00
 0x00  0x00  0x03  0x00  0x00  0x00     0x00  0x00  0x00  0x00  0x02  0x00
 0x00  0x00  0x03  0x00  0x02  0x00     0x00  0x00  0x00  0x02  0x02  0x00
 0x00  0x00  0x03  0x00  0x02  0x00     0x00  0x00  0x00  0x00  0x02  0x00
 0x00  0x00  0x03  0x00  0x03  0x00     0x00  0x00  0x00  0x00  0x00  0x00
 0x00  0x02  0x03  0x02  0x02  0x00  …  0x00  0x00  0x00  0x00  0x00  0x00
 0x00  0x00  0x03  0x00  0x02  0x00     0x00  0x00  0x00  0x00  0x00  0x00
 0x00  0x00  0x03  0x00  0x02  0x00     0x02  0x00  0x00  0x00  0x02  0x02
   

In [107]:
# Load genotypes. `plinkname` is an absolute path here, so joinpath simply returns it.
snpdata = SnpData(joinpath(datadir, plinkname))
Xfull = snpdata.snparray
n, p = size(Xfull)

# get r, α, θ estimated by fastPHASE
r, θ, α = process_fastphase_output(datadir, T, extension=extension)
K = size(θ, 2)
statespace = (K * (K + 1)) >> 1 # K(K+1)/2
table = MarkovChainTable(K)

# get initial states (marginal distribution vector) and Markov transition matrices
q = get_initial_probabilities(α, table)
Q = get_genotype_transition_matrix(r, θ, α, q, table)

# preallocated arrays
# full_knockoff = SnpArray(outfile * ".bed", n, 2p)
full_knockoff = zeros(Int, n, p)
X = zeros(Float64, p)
Z = zeros(Int, p)
Z̃ = zeros(Int, p)
X̃ = zeros(Int, p)
N = zeros(p, statespace)
d_K = Categorical(fill(1 / statespace, statespace)) # for sampling markov chains (length statespace)
# Genotypes take 3 values, so this sampler gets exactly 3 categories. A longer
# probability vector would leave stale trailing weights that can be drawn whenever
# the three meaningful probabilities sum to slightly less than 1.
d_3 = Categorical(fill(1 / 3, 3)); # for sampling genotypes (length 3)

In [110]:
# Run the three sampling steps for one sample and compare chains and genotypes
i = 100

# Algorithm 3 in Sesia et al: sample the hidden states Z for sample i
copyto!(X, view(Xfull, i, :))
forward_backward_sampling!(Z, X, d_K, Q, q, θ, table)

# Algorithm 2 in Sesia et al: sample a knockoff copy Z̃ of the Markov chain
markov_knockoffs!(Z̃, Z, N, d_K, Q, q)

# Eq 6 in Sesia et al: sample knockoff genotypes X̃ given Z̃
genotype_knockoffs!(X̃, Z̃, table, θ, d_3)

# Columns: knockoff states, original states, knockoff genotypes, original genotypes
hcat(Z̃, Z, X̃, X)

29481×4 Matrix{Float64}:
  2.0   2.0  2.0  0.0
  2.0   2.0  0.0  0.0
  2.0   2.0  1.0  2.0
  2.0   2.0  0.0  0.0
  2.0   2.0  2.0  1.0
  2.0   2.0  0.0  0.0
  2.0   2.0  0.0  0.0
  2.0   2.0  0.0  0.0
  2.0   2.0  0.0  0.0
  2.0   2.0  0.0  0.0
  2.0   2.0  0.0  0.0
  2.0   2.0  1.0  1.0
  2.0   2.0  2.0  0.0
  ⋮               
 50.0  50.0  1.0  1.0
 50.0  50.0  1.0  1.0
 50.0  50.0  0.0  0.0
 50.0  50.0  1.0  1.0
 50.0  50.0  0.0  0.0
 50.0  50.0  0.0  0.0
 50.0  50.0  0.0  1.0
 50.0  50.0  0.0  0.0
 50.0  50.0  0.0  0.0
 50.0  50.0  0.0  0.0
 50.0  50.0  0.0  1.0
 50.0  50.0  0.0  1.0

In [98]:
# Map state index 10 to its pair (a, b), then print the emission probability of
# each genotype value 0, 1, 2 for that pair (last argument 1 is passed as-is;
# NOTE(review): presumably the SNP position — confirm)
a, b = Knockoffs.index_to_pair(table, 10)
@show get_genotype_emission_probabilities(θ, 0, a, b, 1)
@show get_genotype_emission_probabilities(θ, 1, a, b, 1)
@show get_genotype_emission_probabilities(θ, 2, a, b, 1)

get_genotype_emission_probabilities(θ, 0, a, b, 1) = 0.0034533329529600264
get_genotype_emission_probabilities(θ, 1, a, b, 1) = 0.15869613409408023
get_genotype_emission_probabilities(θ, 2, a, b, 1) = 0.8378505329529597


0.8378505329529597

In [83]:
# Positions where the knockoff chain differs from the original chain.
# Columns: position, original state, knockoff state
idx = findall(Z .≠ Z̃)
hcat(idx, Z[idx], Z̃[idx])

1071×3 Matrix{Int64}:
    13  47  31
    35  28  32
    57  45  42
   108  42  48
   111  26  48
   130  40  27
   157  52   6
   158   6   7
   255  52   8
   260  41  43
   261   6  41
   275   8   6
   314  38   8
     ⋮      
 29103  28  32
 29141  13  42
 29142  16  42
 29166  23  41
 29167  41  43
 29225  25  17
 29242  25  52
 29310  19  10
 29319   7  10
 29327   7   1
 29354   1   7
 29454  11  17

In [84]:
# Window of positions 100 to 120. Columns: position, original state, knockoff state
idx = 100:120
hcat(idx, Z[idx], Z̃[idx])

21×3 Matrix{Int64}:
 100  42  42
 101  42  42
 102  42  42
 103  42  42
 104  42  42
 105  42  42
 106  42  42
 107  42  42
 108  42  48
 109  48  48
 110  48  48
 111  26  48
 112  26  26
 113  26  26
 114  26  26
 115  26  26
 116  26  26
 117  26  26
 118  26  26
 119  26  26
 120  27  27

In [77]:
# Window of positions 50 to 70. Columns: position, original state, knockoff state
idx = 50:70
hcat(idx, Z[idx], Z̃[idx])

21×3 Matrix{Int64}:
 50  7  7
 51  7  7
 52  7  7
 53  7  7
 54  7  7
 55  7  7
 56  2  7
 57  2  2
 58  2  2
 59  2  2
 60  2  2
 61  2  2
 62  2  2
 63  2  2
 64  2  2
 65  2  2
 66  2  7
 67  7  7
 68  7  7
 69  7  7
 70  7  7

In [76]:
# Window of positions 29275 to 29285. Columns: position, original state, knockoff state
idx = 29275:29285
hcat(idx, Z[idx], Z̃[idx])

11×3 Matrix{Int64}:
 29275  51  51
 29276  51  51
 29277  51  51
 29278  51  51
 29279  51  51
 29280  51  44
 29281  44  15
 29282  15  15
 29283  15  15
 29284  15  15
 29285  17  17