In [1]:
using Pkg
pkg"add https://github.com/kose-y/ParProx.jl"
pkg"add Printf Statistics CSV Mmap CodecZlib ROCAnalysis DataFrames"
pkg"add CUDA Adapt"

    Updating git-repo `https://github.com/kose-y/ParProx.jl`
    Updating registry at `~/.julia/registries/General.toml`
   Resolving package versions...
  No Changes to `~/.julia/environments/v1.8/Project.toml`
  No Changes to `~/.julia/environments/v1.8/Manifest.toml`
   Resolving package versions...
  No Changes to `~/.julia/environments/v1.8/Project.toml`
  No Changes to `~/.julia/environments/v1.8/Manifest.toml`
   Resolving package versions...
  No Changes to `~/.julia/environments/v1.8/Project.toml`
  No Changes to `~/.julia/environments/v1.8/Manifest.toml`


In [2]:
versioninfo()

Julia Version 1.8.5
Commit 17cfb8e65ea (2023-01-08 06:45 UTC)
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 32 × 13th Gen Intel(R) Core(TM) i9-13900KF
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-13.0.1 (ORCJIT, goldmont)
  Threads: 1 on 32 virtual cores
Environment:
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64


## **Cox regression with non-overlapping groups**

Cox regression on somatic mutation count data

In [3]:
using ParProx, Printf, Statistics # load the packages
using CSV, DataFrames, CodecZlib, Mmap # packages for data reading; CodecZlib decompresses the gzipped text files
using Random, CUDA, Adapt

In [4]:
CUDA.device()

CuDevice(0): NVIDIA GeForce RTX 3080

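**Data description of `somatic_sum_table_GBM.tsv.gz`**

row: somatic mutation unit

col 1-9: annotation (UniProt accession, positions, unit name, gene name and ID, unit label) and the total count `sum`; col 10 onward: mutation count of each subject (GPD)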

In [5]:
somatic = DataFrame(CSV.File(transcode(GzipDecompressor, Mmap.mmap(
            ParProx.datadir("somatic_sum_table_GBM.tsv.gz")))));

In [6]:
first(somatic[:, 1:10], 5)

| Row | uniprot_accession | start_position | end_position | center_position | unit_name | gene_name | gene_id | unit_label | sum | TCGA-02-0003 |
|---|---|---|---|---|---|---|---|---|---|---|
| | String15 | String7 | String7 | String7 | String15 | String31? | String15 | String3 | Int64 | Int64 |
| 1 | Q9H2S6 | | | | NCU | TNMD | ENSG00000000005 | NCU | 1 | 0 |
| 2 | O60762 | 28.0 | 199.0 | 113.0 | Glycos_transf_2 | DPM1 | ENSG00000000419 | PIU | 1 | 0 |
| 3 | Q8IZE3 | 26.0 | 245.0 | 135.0 | Pkinase | SCYL3 | ENSG00000000457 | PIU | 1 | 0 |
| 4 | Q8IZE3 | | | | LU | SCYL3 | ENSG00000000457 | LU | 2 | 0 |
| 5 | Q8IZE3 | | | | NCU | SCYL3 | ENSG00000000457 | NCU | 1 | 0 |


In [7]:
size(somatic)

(25307, 399)

**Data description of `primary_TCGA_CDR_GBM.tsv`**

row: subject (GPD)

col: clinical information

In [8]:
cdr = DataFrame(CSV.File(ParProx.datadir("primary_TCGA_CDR_GBM.tsv")));

In [9]:
first(cdr[:, 1:5], 5)

Row,V1,bcr_patient_barcode,type,age_at_initial_pathologic_diagnosis,gender
Unnamed: 0_level_1,Int64,String15,String3,Int64,String7
1,2647,TCGA-02-0003,GBM,50,MALE
2,2664,TCGA-02-0033,GBM,54,MALE
3,2671,TCGA-02-0047,GBM,78,MALE
4,2676,TCGA-02-0055,GBM,62,FEMALE
5,2734,TCGA-02-2466,GBM,61,MALE


In [10]:
size(cdr)

(390, 34)

In [11]:
survival_event = convert(Vector, cdr[!, :OS])
survival_time = convert(Vector, cdr[!, Symbol("OS.time")]);

In [12]:
age_at_diagnosis = cdr[!, :age_at_initial_pathologic_diagnosis]
gender = map(x -> x == "MALE", cdr[!, :gender])
clinical = hcat(age_at_diagnosis, gender);

In [13]:
somatic_predictors = hcat(
    convert(
        Array{Float64}, 
        transpose(convert(Array{Float64}, Matrix(somatic[:, 10:end])))
    ),
    clinical
)

390×25309 Matrix{Float64}:
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  …  0.0  0.0  0.0  0.0  0.0  50.0  1.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  54.0  1.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  78.0  1.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  62.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  61.0  1.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  …  0.0  0.0  0.0  0.0  0.0  57.0  1.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  43.0  1.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  53.0  1.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  64.0  1.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  81.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  …  0.0  0.0  0.0  0.0  0.0  84.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  67.0  1.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  63.0  0.0
 ⋮         

Description of `somatic_predictors`

row: subject (GPD), 390 in total

col: somatic mutation units (25307) + clinical covariates age and gender (2) = 25309

In [14]:
normalize(x) = (x .- mean(x; dims=1)) ./ std(x; dims=1)
somatic_predictors = normalize(somatic_predictors)

390×25309 Matrix{Float64}:
 -0.050637  -0.050637  -0.050637  …  -0.0879321  -0.748927    0.764111
 -0.050637  -0.050637  -0.050637     -0.0879321  -0.450504    0.764111
 -0.050637  -0.050637  -0.050637     -0.0879321   1.34003     0.764111
 -0.050637  -0.050637  -0.050637     -0.0879321   0.146342   -1.30536
 -0.050637  -0.050637  -0.050637     -0.0879321   0.0717363   0.764111
 -0.050637  -0.050637  -0.050637  …  -0.0879321  -0.226687    0.764111
 -0.050637  -0.050637  -0.050637     -0.0879321  -1.27117     0.764111
 -0.050637  -0.050637  -0.050637     -0.0879321  -0.52511     0.764111
 -0.050637  -0.050637  -0.050637     -0.0879321   0.295553    0.764111
 -0.050637  -0.050637  -0.050637     -0.0879321   1.56385    -1.30536
 -0.050637  -0.050637  -0.050637  …  -0.0879321   1.78767    -1.30536
 -0.050637  -0.050637  -0.050637     -0.0879321   0.519371    0.764111
 -0.050637  -0.050637  -0.050637     -0.0879321   0.220948   -1.30536
  ⋮                               ⋱                   
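`normalize` divides each column by its standard deviation, so a constant column (for example, a mutation unit observed in no subject) would turn into `NaN`. A minimal guarded variant is sketched below; `normalize_safe` is a hypothetical helper, not part of the original notebook or of ParProx.

```julia
using Statistics

# Hypothetical helper: column-wise standardization that maps constant
# (zero-variance) columns to zeros instead of NaN.
function normalize_safe(x::AbstractMatrix)
    μ = mean(x; dims=1)
    σ = std(x; dims=1)
    σ = map(s -> s == 0 ? one(s) : s, σ)   # avoid 0/0 for constant columns
    return (x .- μ) ./ σ
end

Z_toy = normalize_safe([1.0 2.0; 3.0 2.0])   # second column is constant
```

For non-constant columns this gives the same result as `normalize`.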

In [15]:
all(cdr[!, :bcr_patient_barcode] .== names(somatic)[10:end])

true

**Sort subjects in descending order of observed survival time**

In [16]:
sortorder = sortperm(survival_time; rev=true)
somatic_predictors = somatic_predictors[sortorder, :] # permuting rows leaves the column standardization unchanged
survival_event = survival_event[sortorder]
survival_time = survival_time[sortorder]

390-element Vector{Int64}:
 3881
 3667
 2883
 2791
 2246
 2126
 1818
 1788
 1481
 1458
 1448
 1426
 1417
    ⋮
   15
   13
   12
    6
    6
    6
    5
    4
    4
    3
    3
    0
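Sorting in descending order of time makes the risk set of subject i, {j : t_j ≥ t_i}, equal to the prefix 1:i of the sorted data when there are no tied times, so the denominator of the Cox partial likelihood becomes a running cumulative sum. The sketch below illustrates this in plain Julia; `cox_partial_loglik` is a hypothetical helper (not ParProx's objective, which may be scaled or handle ties differently) and assumes distinct times.

```julia
# Hypothetical helper: Cox partial log-likelihood for subjects already sorted
# in descending order of time. Assumes no tied times.
#   η: linear predictors (Xβ), δ: event indicators (1 = event, 0 = censored)
function cox_partial_loglik(η::AbstractVector, δ::AbstractVector)
    riskset_sum = cumsum(exp.(η))   # riskset_sum[i] = Σ_{j ≤ i} exp(η_j), the risk-set total
    return sum(δ .* (η .- log.(riskset_sum)))
end

# Toy check against a direct evaluation that builds each risk set explicitly
times_toy = [9.0, 7.0, 5.0, 2.0]    # already in descending order
η_toy     = [0.3, -0.1, 0.5, 0.2]
δ_toy     = [1, 0, 1, 1]
direct_toy = sum(δ_toy[i] * (η_toy[i] - log(sum(exp(η_toy[j]) for j in eachindex(times_toy)
                                                if times_toy[j] >= times_toy[i])))
                 for i in eachindex(times_toy))
```

The cumulative sum turns an O(n²) risk-set computation into O(n), which is why the data are sorted once up front.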

Build a tab-separated variable name for each mutation unit from the first eight columns of the `somatic` data frame, then append names for the two clinical covariates.

In [17]:
somatic_variable_names = map(x -> "$(x[1])\t$(x[2])\t$(x[3])\t" * 
    "$(x[4])\t$(x[5])\t$(x[6])\t$(x[7])\t$(x[8])", eachrow(somatic))
somatic_variable_names = vcat(somatic_variable_names, map(x -> "$x\t\t\t\t\t\t\t", 
        ["age_at_diagnosis", "gender"]));

**Group Definition**

somatic_groups: unique gene IDs (`gene_id`), one group per gene

somatic_group_sizes: number of mutation units (rows) belonging to each gene

somatic_grpidx: for each mutation unit, the index of its gene in `somatic_groups`

In [18]:
somatic_groups = unique(somatic[!, :gene_id])      # unique gene IDs, in order of first appearance
somatic_group_sizes = map(somatic_groups) do x     # number of rows for each gene
    count(y -> y == x, somatic[!, :gene_id])
end
somatic_grpidx = map(somatic[!, :gene_id]) do x    # group index of each row
    findfirst(y -> y == x, somatic_groups)
end;
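The `count` and `findfirst` calls above scan a full list once per element, which is quadratic in size (25,307 rows against 14,254 genes here). A dictionary lookup produces the same three outputs in a single pass; `group_indices` below is a hypothetical helper sketching that approach.

```julia
# Hypothetical helper: group IDs, group sizes, and per-row group indices in one pass.
function group_indices(gene_ids::AbstractVector)
    groups = unique(gene_ids)                               # order of first appearance
    index_of = Dict(g => i for (i, g) in enumerate(groups)) # gene ID => group index
    grpidx = [index_of[g] for g in gene_ids]
    sizes = zeros(Int, length(groups))
    for i in grpidx
        sizes[i] += 1
    end
    return groups, sizes, grpidx
end

groups_toy, sizes_toy, grpidx_toy = group_indices(["ENSG1", "ENSG2", "ENSG2", "ENSG3", "ENSG2"])
```

Applied to this notebook it would read `somatic_groups, somatic_group_sizes, somatic_grpidx = group_indices(somatic[!, :gene_id])`.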

In [19]:
somatic_grpidx

25307-element Vector{Int64}:
     1
     2
     3
     3
     3
     4
     5
     5
     5
     5
     5
     5
     5
     ⋮
 14247
 14248
 14249
 14250
 14251
 14252
 14253
 14253
 14253
 14253
 14253
 14254

**Cross Validation**

The concordance index (C-index) is used as the metric for cross-validation in Cox regression.
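The C-index is the fraction of comparable subject pairs whose predicted risks are ordered consistently with their observed survival: a pair is comparable when the subject with the shorter time had an observed event, and concordant when that subject also has the higher predicted risk. A minimal sketch of this usual definition follows; `c_index` is a hypothetical helper, and the implementation inside `ParProx.cross_validate` may treat ties or censoring differently.

```julia
# Hypothetical helper: concordance index from times, event indicators, and risk scores
# (higher risk = expected shorter survival). Tied risks count as half-concordant.
# Returns NaN if no pair is comparable.
function c_index(time::AbstractVector, event::AbstractVector, risk::AbstractVector)
    concordant = 0.0
    comparable = 0
    for i in eachindex(time), j in eachindex(time)
        # (i, j) is comparable if i had an observed event strictly before time[j]
        if event[i] == 1 && time[i] < time[j]
            comparable += 1
            if risk[i] > risk[j]
                concordant += 1.0
            elseif risk[i] == risk[j]
                concordant += 0.5
            end
        end
    end
    return concordant / comparable
end
```

A value of 0.5 corresponds to random ordering and 1.0 to perfect ordering, which puts the 0.45 to 0.69 entries of the score matrix in context.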

In [20]:
Random.seed!(222)
T = Float64
A = CuArray
U = ParProx.COXUpdate(; maxiter=10000, step=10, tol=5e-4, verbose=true)
lambdas = 10 .^ (range(-5, stop=-7, length=21)) # 21 values equally log-spaced in 10^-5 .. 10^-7
penalties = [GroupNormL2{T, A}(l, somatic_grpidx) for l in lambdas]

score = ParProx.cross_validate(U, somatic_predictors, 
    survival_event, survival_time, penalties, 5; T=T, A=A)

10	(-3.6430362798761964, Inf, 45.0)
20	(-3.6427784146926614, 5.5541135178655456e-5, 41.0)
  6.521090 seconds (33.39 M allocations: 1.730 GiB, 2.70% gc time, 73.82% compilation time: 0% of which was recompilation)
10	(-3.6422442769463803, Inf, 117.0)
20	(-3.6417392606602674, 0.00010879893457028088, 114.0)
  0.016042 seconds (17.47 k allocations: 885.898 KiB)
10	(-3.640765166412631, Inf, 333.0)
20	(-3.6398518417450587, 0.00019684349818131227, 329.0)
  0.014586 seconds (17.59 k allocations: 889.688 KiB)
10	(-3.638224396142005, Inf, 688.0)
20	(-3.636694153869072, 0.000330028727828894, 675.0)
  0.015438 seconds (17.59 k allocations: 888.969 KiB)
10	(-3.63424043813899, Inf, 1300.0)
20	(-3.6319410893842314, 0.0004964114850312903, 1296.0)
  0.014613 seconds (17.59 k allocations: 888.969 KiB)
10	(-3.6285561713048318, Inf, 2532.0)
20	(-3.6253805348679653, 0.0006865676051791315, 2517.0)
30	(-3.6223799927060587, 0.0006491335992803138, 2486.0)
40	(-3.619528900771111, 0.0006171824002382359, 2480.0)


21×5 Matrix{Float64}:
 0.635705  0.583592  0.612992  0.6929    0.602394
 0.636611  0.581818  0.615737  0.681779  0.604959
 0.635251  0.576497  0.614364  0.662532  0.604959
 0.632986  0.566741  0.613907  0.634731  0.604104
 0.637064  0.556098  0.610704  0.602652  0.586575
 0.617127  0.529047  0.612077  0.583405  0.55451
 0.602628  0.520177  0.611619  0.579127  0.525011
 0.582691  0.523725  0.611162  0.591959  0.503206
 0.571364  0.512639  0.610247  0.594525  0.487815
 0.571364  0.511752  0.609332  0.585115  0.476272
 0.574082  0.512639  0.60796   0.578272  0.469859
 0.575442  0.51663   0.607045  0.578272  0.461308
 0.569551  0.515299  0.60613   0.578272  0.45746
 0.564567  0.510865  0.605215  0.577844  0.45404
 0.565927  0.506874  0.604758  0.574423  0.458316
 0.567286  0.50643   0.6043    0.573139  0.457033
 0.56638   0.503326  0.603385  0.567579  0.457033
 0.565473  0.505987  0.602928  0.567579  0.455323
 0.564567  0.505987  0.602928  0.568435  0.456178
 0.565473  0.505543  0.602928  
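`GroupNormL2` is presumably a group-lasso penalty of the form λ Σ_g ‖β_g‖₂ over the gene groups in `somatic_grpidx`. Proximal-gradient methods handle such a penalty with group soft-thresholding, which shrinks each group toward zero and sets a whole group to zero once its norm falls below the threshold. The sketch below is the standard unweighted operator; whether ParProx weights groups (for example by group size) is not shown here, so `prox_group_l2` is a hypothetical illustration.

```julia
# Hypothetical helper: proximal operator of t * Σ_g ‖β_g‖₂ (t = step size × λ).
# grpidx[j] is the group of coefficient j; β and grpidx are 1-based and equal length.
function prox_group_l2(β::AbstractVector, grpidx::AbstractVector{<:Integer}, t::Real)
    out = similar(β, float(eltype(β)))
    norms = zeros(float(eltype(β)), maximum(grpidx))
    for (j, g) in enumerate(grpidx)
        norms[g] += β[j]^2          # accumulate squared norm of each group
    end
    norms .= sqrt.(norms)
    for (j, g) in enumerate(grpidx)
        # shrink the group by t in norm, or zero it if its norm is at most t
        scale = norms[g] > t ? 1 - t / norms[g] : zero(norms[g])
        out[j] = scale * β[j]
    end
    return out
end

β_toy = prox_group_l2([3.0, 4.0, 0.5], [1, 1, 2], 1.0)   # group 1 shrunk, group 2 zeroed
```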

In [21]:
lambda_idx = argmax(mean(score; dims=2)[:])
lambda = lambdas[lambda_idx]

1.0e-5

In [22]:
p = GroupNormL2{T,A}(lambda, somatic_grpidx) # lambda selected by cross-validation (1.0e-5)
U = ParProx.COXUpdate(; maxiter=20000, step=20, tol=1e-6, verbose=true)
V = ParProx.COXVariables{T,A}(adapt(A{T}, somatic_predictors), 
    adapt(A{T}, survival_event), 
    adapt(A{T}, survival_time), p; eval_obj=true)
@time ParProx.fit!(U, V)

20	(-3.8007367881101715, Inf, 64.0)
40	(-3.8000982685581324, 0.00013302218336269303, 64.0)
60	(-3.799519652528114, 0.00012055707068805052, 64.0)
80	(-3.798989782156095, 0.00011041289856229299, 64.0)
100	(-3.7984998925870324, 0.0001020922329954877, 63.0)
120	(-3.7980424695234998, 9.53353511224772e-5, 63.0)
140	(-3.797613148997801, 8.948627418794014e-5, 63.0)
160	(-3.7971949761103683, 8.717029212176865e-5, 47.0)
180	(-3.7967749444466903, 8.756543063673786e-5, 43.0)
200	(-3.796367539344382, 8.494034265853564e-5, 43.0)
220	(-3.7959750068970943, 8.184622453687983e-5, 43.0)
240	(-3.7955959090885143, 7.905124113180029e-5, 43.0)
260	(-3.7952290006731277, 7.651530622104868e-5, 43.0)
280	(-3.7948731975623224, 7.420490514455833e-5, 43.0)
300	(-3.7945275512472096, 7.209184041979033e-5, 43.0)
320	(-3.79419122794955, 7.015224918416324e-5, 43.0)
340	(-3.7938634915131537, 6.836582580552943e-5, 43.0)
360	(-3.7935435052704376, 6.675359102599125e-5, 41.0)
380	(-3.7932219217543532, 6.70913054588247e-5, 38

In [23]:
nonzero_idxs = findall(!iszero, Array(V.β)); # copy β to the host first to avoid scalar indexing of a GPU array


In [24]:
for (variable_name, value) in zip(somatic_variable_names[nonzero_idxs], Array(V.β)[nonzero_idxs])
    println(variable_name, "\t", value)
end

Q9BS26	30	139	84	Thioredoxin	ERP44	ENSG00000023318	PIU	-0.00021515512685920483
Q6NUM9	NA	NA	NA	LU	RETSAT	ENSG00000042445	LU	-0.0013266367715626894
Q8TAP4	NA	NA	NA	NCU	LMO3	ENSG00000048540	NCU	-1.705536300596848e-5
Q5JXM2	115	342	228	Methyltransf_22	METTL24	ENSG00000053328	PIU	-0.0013266367715626894
P27815	NA	NA	NA	LU	PDE4A	ENSG00000065989	LU	-0.004071125250190334
Q12888	NA	NA	NA	NCU	TP53BP1	ENSG00000067369	NCU	-0.003759039353448859
P25490	NA	NA	NA	LU	YY1	ENSG00000100811	LU	-0.004071125250190334
O00264	NA	NA	NA	LU	PGRMC1	ENSG00000101856	LU	-0.0013266367715626894
Q9HC57	NA	NA	NA	LU	WFDC1	ENSG00000103175	LU	-0.004071125250190334
Q9NUM4	NA	NA	NA	NCU	TMEM106B	ENSG00000106460	NCU	-0.0013266367715626894
Q9Y2J4	NA	NA	NA	LU	AMOTL2	ENSG00000114019	LU	-0.003759039353448859
Q16644	NA	NA	NA	LU	MAPKAPK3	ENSG00000114738	LU	-0.00021515512685920483
Q9H019	10	255	132	Mito_fiss_reg	MTFR1L	ENSG00000117640	PIU	-1.705536300596848e-5
Q8TEH3	NA	NA	NA	LU	DENND1A	ENSG00000119522	LU	-0.0013266367715626894
Q8N490

## **Logistic Regression with overlapping groups**
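With overlapping groups, a gene can belong to many pathways. A common way to handle this is to give each gene one copy of its coefficient per group it belongs to, so the copies form non-overlapping groups and the original coefficient is the sum of its copies. The counts later in this notebook (21,963 replicated coefficients for 1,043 genes) are consistent with that construction. A minimal sketch with a hypothetical helper `replicate_groups`:

```julia
# Hypothetical helper: expand (possibly overlapping) groups into replicated coordinates.
# Returns, for every copy, the original variable index and the group it belongs to.
function replicate_groups(group_to_variables::Vector{Vector{Int}})
    cols = reduce(vcat, group_to_variables)
    grp  = reduce(vcat, [fill(g, length(v)) for (g, v) in enumerate(group_to_variables)])
    return cols, grp
end

groups_toy = [[1, 2], [2, 3]]            # variable 2 belongs to both groups
cols_toy, grp_toy = replicate_groups(groups_toy)
X_toy = [1.0 2.0 3.0; 4.0 5.0 6.0]
X_rep_toy = X_toy[:, cols_toy]           # column 2 appears twice
```

Because each copy belongs to exactly one group, a non-overlapping group penalty applies unchanged to the replicated coefficients, and `X_rep * β_rep` equals `X` times the summed coefficients.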

In [5]:
brca_data = DataFrame(CSV.File(transcode(GzipDecompressor, Mmap.mmap(
            ParProx.datadir("exprdata_set2_cut.txt.gz")))))
gene_names = brca_data[!, :Gene]

1043-element Vector{String15}:
 "RFC2"
 "PAX8"
 "THRA"
 "EPHB3"
 "CFL1"
 "YY1"
 "ZPR1"
 "RHOA"
 "GUK1"
 "DSP"
 "RAD21"
 "SF3B2"
 "NDRG1"
 ⋮
 "ORAI3"
 "FAM174B"
 "ZNF335"
 "SEH1L"
 "LOC389906"
 "KIF18B"
 "RACGAP1"
 "CASP8AP2"
 "KLK5"
 "OR7E47P"
 "SCAF4"
 "SNHG17"

In [6]:
# patient names (column headers) in the gene expression file
patient_barcode = map(x -> string(x), names(brca_data)[2:end])

# rows: patients, columns: genes
X = collect(transpose(Matrix(brca_data[:, 2:end])))

subject_data = CSV.File(transcode(GzipDecompressor, Mmap.mmap(
            ParProx.datadir("sample_set2.txt.gz")))) |> DataFrame

# match subject ids before extracting covariates and outcomes
subject_data = filter(x -> x[:GEO_accession] in patient_barcode, subject_data) 
@assert all(subject_data[!, :GEO_accession] .== patient_barcode)

normalize(x) = (x .- mean(x; dims=1)) ./ std(x; dims=1)
Age = normalize(subject_data[!, :Age])
interceptcol = ones(size(X, 1))

X_unpen = [Age interceptcol] # unpenalized covariates: age and intercept

y = subject_data[!, :PCR_outcome]
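The assertion above only checks that the sample table is already in the same order as the expression columns. If it were not, the rows could be reordered by lookup instead; `align_rows` below is a hypothetical helper built on `Base.indexin`, and the `GSM…` IDs are made-up examples.

```julia
# Hypothetical helper: row permutation that puts `ids_table` in the order of `ids_wanted`.
function align_rows(ids_wanted::AbstractVector, ids_table::AbstractVector)
    pos = indexin(ids_wanted, ids_table)   # `nothing` where an ID is missing
    any(isnothing, pos) && error("some requested IDs are absent from the table")
    return Vector{Int}(pos)
end

perm_toy = align_rows(["GSM3", "GSM1", "GSM2"], ["GSM1", "GSM2", "GSM3", "GSM9"])
```

With the notebook's data this would be used as `subject_data[align_rows(patient_barcode, subject_data[!, :GEO_accession]), :]`.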

In [7]:
group_info = DataFrame(CSV.File(transcode(GzipDecompressor, Mmap.mmap(
            ParProx.datadir("PATHWAYgroup.txt.gz")))))
first(group_info, 5)

| Row | ENSG | Pathwayid | Function | Genesym |
|---|---|---|---|---|
| | String15 | String15 | String | String15 |
| 1 | ENSG00000148584 | GO:0006397 | mRNA processing | A1CF |
| 2 | ENSG00000148584 | GO:0010467 | gene expression | A1CF |
| 3 | ENSG00000148584 | GO:0016554 | cytidine to uridine editing | A1CF |
| 4 | ENSG00000148584 | GO:0016556 | mRNA modification | A1CF |
| 5 | ENSG00000148584 | GO:0050821 | protein stabilization | A1CF |


In [8]:
pathwayids = Vector{String}()
variable_to_groups = [Int[] for i in 1:length(gene_names)] 
group_to_variables = Vector{Vector{Int}}()
genesym_to_variableidx = Dict{String, Int}()
pathwayid_to_groupidx = Dict{String, Int}();

In [9]:
pathwayid_to_pathwayfunction = Dict{String, String}()
genesym_to_genefunction = Dict{String, String}();

In [10]:
for (i, v) in enumerate(gene_names)
    genesym_to_variableidx[v] = i
end

gene_info = CSV.File(transcode(GzipDecompressor, Mmap.mmap(
            ParProx.datadir("gene_info.txt.gz")))) |> DataFrame
for r in eachrow(gene_info)
    genesym_to_genefunction[r[2]] = r[4]
end

In [11]:
num_groups = 0
for r in eachrow(group_info)
    if !(r.Genesym in keys(genesym_to_variableidx)) # skip any gene not in our gene list
        continue
    end
    # get variable index from the gene symbol
    variable_idx = genesym_to_variableidx[r.Genesym]
    
    if !haskey(pathwayid_to_groupidx, r.Pathwayid) # if the pathway id is new ...
        # add a new group
        num_groups += 1
        push!(pathwayids, r.Pathwayid)
        push!(group_to_variables, Int[])
        pathwayid_to_groupidx[r.Pathwayid] = num_groups
        group_idx = num_groups
    else
        group_idx = pathwayid_to_groupidx[r.Pathwayid]
    end
    
    # insert membership information
    push!(variable_to_groups[variable_idx], group_idx)
    push!(group_to_variables[group_idx], variable_idx)
    
    # retrieve pathway information
    if !haskey(pathwayid_to_pathwayfunction, r.Pathwayid)
        pathwayid_to_pathwayfunction[r.Pathwayid] = r.Function
    else
        # if redundant, just check correctness
        @assert pathwayid_to_pathwayfunction[r.Pathwayid] == r.Function
    end
end

In [12]:
cnt = 0
for (i, v) in enumerate(gene_names)
    if length(variable_to_groups[i]) > 0 # gene is in some of the groups
        continue
    else # gene is not in any of the groups
        num_groups += 1
        cnt += 1
        # update group info
        push!(variable_to_groups[i], num_groups)
        push!(group_to_variables, [i])
        
        # update variable names: a singleton group gets a placeholder pathway id
        push!(pathwayids, "Singleton of: $v")
    end
end
cnt

90

In [13]:
variable_names = [gene_names; "Age\t\t"; "intercept\t\t"]

1045-element Vector{String}:
 "RFC2"
 "PAX8"
 "THRA"
 "EPHB3"
 "CFL1"
 "YY1"
 "ZPR1"
 "RHOA"
 "GUK1"
 "DSP"
 "RAD21"
 "SF3B2"
 "NDRG1"
 ⋮
 "ZNF335"
 "SEH1L"
 "LOC389906"
 "KIF18B"
 "RACGAP1"
 "CASP8AP2"
 "KLK5"
 "OR7E47P"
 "SCAF4"
 "SNHG17"
 "Age\t\t"
 "intercept\t\t"

In [14]:
pathways_with_functions = map(pathwayids) do x
    if haskey(pathwayid_to_pathwayfunction, x)
        "$x: $(pathwayid_to_pathwayfunction[x])"
    else
        x
    end
end

5719-element Vector{String}:
 "KEGG:00280: Valine, leucine and isoleucine degradation"
 "KEGG:04727: GABAergic synapse"
 "KEGG:01100: Metabolic pathways"
 "KEGG:00250: Alanine, aspartate and glutamate metabolism"
 "KEGG:00640: Propanoate metabolism"
 "KEGG:00410: beta-Alanine metabolism"
 "KEGG:00650: Butanoate metabolism"
 "GO:0001666: response to hypoxia"
 "GO:0007268: synaptic transmission"
 "GO:0007269: neurotransmitter secretion"
 "GO:0007620: copulation"
 "GO:0007626: locomotory behavior"
 "GO:0009449: gamma-aminobutyric acid biosynthetic process"
 ⋮
 "Singleton of: TEX13B"
 "Singleton of: IQCG"
 "Singleton of: LOC100131532"
 "Singleton of: C6orf25"
 "Singleton of: TMEM14B"
 "Singleton of: FAM64A"
 "Singleton of: GPR124"
 "Singleton of: FAM174B"
 "Singleton of: LOC389906"
 "Singleton of: OR7E47P"
 "Singleton of: SCAF4"
 "Singleton of: SNHG17"

In [15]:
variable_names_replicated = String[]
for i in 1:length(group_to_variables)
    for v in group_to_variables[i]
        pf = pathways_with_functions[i]
        g = gene_names[v]
        gt = get(genesym_to_genefunction, g, "") # some genes have no entry in the gene function list
        push!(variable_names_replicated, "$pf\t$g\t$gt")
    end
end

In [16]:
variable_names_replicated = [variable_names_replicated; "Age\t\t"; "intercept\t\t"] 

21965-element Vector{String}:
 "KEGG:00280: Valine, leucine and" ⋯ 30 bytes ⋯ "-aminobutyrate aminotransferase"
 "KEGG:00280: Valine, leucine and" ⋯ 39 bytes ⋯ "ing protein/enoyl-CoA hydratase"
 "KEGG:00280: Valine, leucine and" ⋯ 41 bytes ⋯ "thyl-3-methylglutaryl-CoA lyase"
 "KEGG:00280: Valine, leucine and" ⋯ 25 bytes ⋯ "VD\tisovaleryl-CoA dehydrogenase"
 "KEGG:00280: Valine, leucine and" ⋯ 40 bytes ⋯ "onoyl-CoA carboxylase 1 (alpha)"
 "KEGG:00280: Valine, leucine and" ⋯ 39 bytes ⋯ "tonoyl-CoA carboxylase 2 (beta)"
 "KEGG:00280: Valine, leucine and" ⋯ 21 bytes ⋯ "on\tMUT\tmethylmalonyl CoA mutase"
 "KEGG:04727: GABAergic synapse\tABAT\t4-aminobutyrate aminotransferase"
 "KEGG:04727: GABAergic synapse\tG" ⋯ 25 bytes ⋯ "acid (GABA) A receptor, alpha 4"
 "KEGG:04727: GABAergic synapse\tPRKCA\tprotein kinase C, alpha"
 "KEGG:04727: GABAergic synapse\tPRKX\tprotein kinase, X-linked"
 "KEGG:01100: Metabolic pathways\tABAT\t4-aminobutyrate aminotransferase"
 "KEGG:01100: Metabolic pathways\

In [17]:
lambdas = 10 .^ (range(-6, stop=-9, length=31))

31-element Vector{Float64}:
 1.0e-6
 7.943282347242822e-7
 6.30957344480193e-7
 5.011872336272725e-7
 3.981071705534969e-7
 3.162277660168379e-7
 2.5118864315095823e-7
 1.9952623149688787e-7
 1.584893192461114e-7
 1.2589254117941662e-7
 1.0000000000000001e-7
 7.943282347242822e-8
 6.30957344480193e-8
 ⋮
 1.2589254117941661e-8
 1.0e-8
 7.943282347242822e-9
 6.309573444801943e-9
 5.011872336272715e-9
 3.981071705534969e-9
 3.1622776601683795e-9
 2.511886431509582e-9
 1.9952623149688828e-9
 1.584893192461111e-9
 1.2589254117941663e-9
 1.0e-9

In [22]:
group_to_variables

5719-element Vector{Vector{Int64}}:
 [489, 386, 157, 243, 876, 661, 174]
 [489, 602, 515, 282]
 [489, 461, 522, 234, 981, 536, 41, 530, 386, 726  …  709, 926, 221, 514, 591, 137, 363, 206, 423, 594]
 [489, 234, 530, 704]
 [489, 523, 174]
 [489]
 [489, 234, 726, 157]
 [489, 171, 690, 86, 180, 826, 614, 496, 641, 49  …  44, 186, 541, 183, 677, 666, 591, 36, 634, 707]
 [489, 234, 500, 418, 653, 496, 307, 62, 589, 602  …  428, 340, 732, 515, 44, 211, 481, 159, 993, 274]
 [489, 234, 481, 159, 993, 274, 697]
 [489, 246]
 [489, 569, 496, 494, 165, 540, 122, 281, 591]
 [489, 159]
 ⋮
 [1003]
 [1009]
 [1010]
 [1017]
 [1019]
 [1024]
 [1030]
 [1033]
 [1036]
 [1041]
 [1042]
 [1043]

In [38]:
using Random
Random.seed!(222)
T = Float64
A = CuArray
U = ParProx.LogisticUpdate(; maxiter=30000, step=20, tol=5e-4, verbose=true)
lambdas = 10 .^ (range(-6, stop=-9, length=31))

scores = ParProx.cross_validate(U, 
    adapt(A{T}, X), adapt(A{T}, X_unpen), adapt(A{Int32}, y), group_to_variables, lambdas, 5; 
    T=Float64, criteria=auc)

20	(-0.6931278407968279, Inf, 24.0)
40	(-0.6931085039174216, 1.1420933366946265e-5, 24.0)
  0.965229 seconds (3.62 M allocations: 188.632 MiB, 4.58% gc time, 84.16% compilation time)
20	(-0.6930719830960618, Inf, 101.0)
40	(-0.693035492909439, 2.1553113786212154e-5, 101.0)
  0.005232 seconds (30.97 k allocations: 1.410 MiB)
20	(-0.6929267374464422, Inf, 526.0)
40	(-0.692818447699606, 6.397008904489673e-5, 526.0)
  0.004937 seconds (30.97 k allocations: 1.410 MiB)
20	(-0.6924619379860725, Inf, 3343.0)
40	(-0.6921116026960825, 0.0002070402977154451, 3309.0)
  0.004891 seconds (30.97 k allocations: 1.410 MiB)
20	(-0.6910827445082395, Inf, 10166.0)
40	(-0.6900947584096456, 0.0005845743818077585, 10063.0)
60	(-0.6891454048844164, 0.0005620318549747001, 9976.0)
80	(-0.6882328789844652, 0.0005405213411672185, 9915.0)
100	(-0.6873551759353641, 0.0005201649668182754, 9700.0)
120	(-0.6865096828074815, 0.0005013271708438443, 9501.0)
140	(-0.6856943820469292, 0.00048365870423216096, 9342.0)
  0.01

31×5 Matrix{Float64}:
 0.621754  0.621754  0.668563  0.548421  0.567719
 0.693333  0.670877  0.774538  0.780351  0.712281
 0.715088  0.676491  0.943812  0.811228  0.711579
 0.729825  0.703158  0.931721  0.812632  0.723509
 0.756491  0.727719  0.918919  0.81614   0.73614
 0.762807  0.734737  0.914651  0.812632  0.735439
 0.765614  0.740351  0.916785  0.81193   0.733333
 0.767018  0.742456  0.921764  0.813333  0.73614
 0.766316  0.743158  0.923186  0.814737  0.738246
 0.767719  0.74807   0.926031  0.814737  0.740351
 0.767018  0.748772  0.927454  0.815439  0.741754
 0.767018  0.74807   0.928165  0.814737  0.740351
 0.767719  0.745263  0.928876  0.815439  0.739649
 ⋮                                       
 0.775439  0.744561  0.932432  0.815439  0.744561
 0.775439  0.744561  0.932432  0.81614   0.74386
 0.77614   0.74386   0.932432  0.81614   0.745965
 0.77614   0.745263  0.932432  0.81614   0.747368
 0.777544  0.744561  0.932432  0.814035  0.747368
 0.778947  0.744561  0.933144  0.814737
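With `criteria=auc`, each entry of `scores` is a held-out area under the ROC curve (rows: the 31 values of `lambdas`, columns: the 5 folds). AUC equals the probability that a randomly chosen positive case scores higher than a randomly chosen negative one, with ties counted as one half. A hypothetical helper computing it directly from that definition is sketched below; it is not the `auc` function used in the cell above.

```julia
# Hypothetical helper: AUC as the Mann-Whitney probability P(score_pos > score_neg).
function auc_mannwhitney(y::AbstractVector, score::AbstractVector)
    pos = score[y .== 1]
    neg = score[y .== 0]
    wins = 0.0
    for p in pos, q in neg
        wins += p > q ? 1.0 : (p == q ? 0.5 : 0.0)   # tie counts as half a win
    end
    return wins / (length(pos) * length(neg))
end

auc_toy = auc_mannwhitney([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8])
```

In this toy case three of the four positive/negative pairs are ordered correctly, so the AUC is 0.75.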

In [39]:
lambda_idx = argmax(mean(scores; dims=2)[:])
lambda = lambdas[lambda_idx]

1.2589254117941663e-9

In [40]:
U = ParProx.LogisticUpdate(; maxiter=30000, step=20, tol=1e-5, verbose=true)
V = ParProx.LogisticVariables{Float64}(adapt(A{T}, X), 
    adapt(A{T}, X_unpen), adapt(A{Int}, y), lambda, group_to_variables)
@time ParProx.fit!(U, V)

20	(-0.6882550184514394, Inf, 21956.0)
40	(-0.6837195250794479, 0.002693734499383167, 21956.0)
60	(-0.6795117875823292, 0.0025053337096107796, 21965.0)
80	(-0.6756051178381279, 0.0023314978586611657, 21965.0)
100	(-0.6719749245958293, 0.002171200769159959, 21965.0)
120	(-0.6685985823457865, 0.0020234598577305463, 21965.0)
140	(-0.6654553003492103, 0.0018873409547029748, 21965.0)
160	(-0.6625259943670732, 0.0017619610111734433, 21965.0)
180	(-0.6597931594205434, 0.001646491269721723, 21965.0)
200	(-0.6572407509578937, 0.0015401555031605378, 21965.0)
220	(-0.6548540705430974, 0.0014422301381613735, 21965.0)
240	(-0.6526196585152259, 0.0013520425080015294, 21965.0)
260	(-0.6505251938689246, 0.0012689686010740066, 21965.0)
280	(-0.6485593985535516, 0.0011924322029875233, 21965.0)
300	(-0.6467119475916011, 0.0011219029318712856, 21965.0)
320	(-0.6449733953001221, 0.0010568877870282007, 21965.0)
340	(-0.643335091809648, 0.0009969381769058471, 21965.0)
360	(-0.6417891189675797, 0.000941639108

In [41]:
_, grpmat, _ = ParProx.mapper_mat_idx(group_to_variables, length(gene_names));
size(grpmat)

(1043, 21963)
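`grpmat` has one row per gene and one column per replicated coefficient, so `grpmat * β` adds up the copies of each gene. `ParProx.mapper_mat_idx` is used as a black box here; the sketch below builds a matrix of the same shape with `SparseArrays.sparse`, under the assumption that it is a 0/1 indicator matrix (`mapping_matrix` is a hypothetical helper).

```julia
using SparseArrays

# Hypothetical helper: p × (total copies) indicator matrix linking each copy to its gene.
function mapping_matrix(group_to_variables::Vector{Vector{Int}}, p::Integer)
    rows = reduce(vcat, group_to_variables)   # original variable index of each copy
    cols = collect(1:length(rows))            # one column per replicated coefficient
    return sparse(rows, cols, ones(length(rows)), p, length(rows))
end

M_toy = mapping_matrix([[1, 2], [2, 3]], 3)
β_orig_toy = M_toy * [1.0, 2.0, 3.0, 4.0]     # copies of variable 2 are summed
```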

In [42]:
β_orig = vcat(grpmat * collect(V.β[1:end-2]), collect(V.β)[end-1:end]);

In [43]:
β_rep = collect(V.β) # copy coefficients to the host before logical indexing
for (v, β) in zip(variable_names_replicated[β_rep .!= 0], β_rep[β_rep .!= 0])
    println("$v\t$β")
end

KEGG:00280: Valine, leucine and isoleucine degradation	ABAT	4-aminobutyrate aminotransferase	0.0008323543838552684
KEGG:00280: Valine, leucine and isoleucine degradation	AUH	AU RNA binding protein/enoyl-CoA hydratase	-0.0014593425218672941
KEGG:00280: Valine, leucine and isoleucine degradation	HMGCL	3-hydroxymethyl-3-methylglutaryl-CoA lyase	-0.0010615724131936334
KEGG:00280: Valine, leucine and isoleucine degradation	IVD	isovaleryl-CoA dehydrogenase	-0.0009348557467914733
KEGG:00280: Valine, leucine and isoleucine degradation	MCCC1	methylcrotonoyl-CoA carboxylase 1 (alpha)	8.306609564157845e-5
KEGG:00280: Valine, leucine and isoleucine degradation	MCCC2	methylcrotonoyl-CoA carboxylase 2 (beta)	-0.0007834943178521192
KEGG:00280: Valine, leucine and isoleucine degradation	MUT	methylmalonyl CoA mutase	0.002754887637335804
KEGG:04727: GABAergic synapse	ABAT	4-aminobutyrate aminotransferase	0.0008232735040322656
KEGG:04727: GABAergic synapse	GABRA4	gamma-aminobutyric acid (GABA) A receptor

Excessive output truncated after 524339 bytes.

KEGG:05166: HTLV-I infection	PDGFRB	platelet-derived growth factor receptor, beta polypeptide	-0.00036110176032153475
KEGG:05166: HTLV-I infection	PIK3CD	phosphatidylinositol-4,5-bisphosphate 3-kinase, catalytic subunit delta	-0.0007803531999851555
KEGG:05166: HTLV-I infection	POLD1	polymerase (DNA directed), delta 1, catalytic subunit	-0.00044447607543509694
KEGG:05166: HTLV-I infection	POLD2	polymerase (DNA directed), delta 2, accessory subunit	0.0007395022041768161
KEGG:05166: HTLV-I infection	POLD3	polymerase (DNA-directed), delta 3, accessory subunit	0.0017395009022891312
KEGG:05166: HTLV-I infection	POLD4	polymerase (DNA-directed), delta 4, accessory subunit	-0.0006380252484317265
KEGG:05166: HTLV-I infection	POLE	polymerase (DNA directed), epsilon, catalytic subunit	0.0029465902557444783
KEGG:05166: HTLV-I infection	POLE2	polymerase (DNA directed), epsilon 2, accessory subunit	0.0013251313195992874
KEGG:05166: HTLV-I infection	POLE3	polymerase (DNA directed), epsilon 3, accessor

In [44]:
for (v, β) in zip(variable_names[β_orig .!= 0], β_orig[β_orig .!= 0])
    println("$v\t$β")
end

RFC2	0.017828368438572818
PAX8	0.03684757263825838
THRA	0.0280671041410603
EPHB3	0.019420775056801853
CFL1	-0.030460548008858876
YY1	0.013402782148339785
ZPR1	-0.004170459880499052
RHOA	0.043220987345152805
GUK1	0.05284160850694734
DSP	-0.0008924328419420445
RAD21	-0.0016314687869252976
SF3B2	0.004117778158978442
NDRG1	-0.006543627891350259
CD63	-0.015885902110306396
DNAJB1	-0.002074023780979478
XBP1	0.021161231401654882
SPTBN1	0.03684346301197065
HMGB1	-0.03075659133448801
GPX1	0.02123400166278225
NUMA1	0.021341449861578796
ARL6IP5	0.0010924165303893335
STMN1	0.007247922658647789
ODC1	0.009846895321239725
DAZAP2	-0.000785501138690692
TMBIM6	-0.0022818758187170543
GSTP1	0.008155196317299402
ZNF207	-0.0027205559055907868
DHCR24	-0.03582950873937385
DEK	0.017525129994537966
HIF1A	-0.04980123716866088
SNX17	0.011478187829172756
CD9	-0.03889279905594768
JUP	-0.03819876252318151
TGM2	-0.0017271186714704757
MMP2	-0.02598193106155893
THBS1	-0.08446414608472251
TUFM	-0.006526014437838651
POLD2