# ISMB figures notebook

In [None]:
# Load the project engine scripts. Helpers used in later cells (e.g. `MLSurvDataset`,
# `split_train_test`, `format_train_test`, `build_vaecox`) are not defined in this notebook,
# so they presumably come from these files.
include("engines/init.jl")
include("engines/data_processing.jl")
include("engines/deep_learning.jl")
include("engines/cross_validation.jl")
# `set_dirs()` is destructured into an output path and a session id; `session_id` is later
# stored in `params_dict`. The trailing `;` suppresses the cell output.
outpath, session_id = set_dirs() ;

In [None]:
# Load the BRCA cohort from its HDF5 file. Later cells read the fields
# `data`, `samples`, `genes`, `biotypes`, `survt` and `surve` of such a dataset.
BRCA = MLSurvDataset("Data/TCGA_BRCA_tpm_n1049_btypes_labels_surv.h5")

In [None]:
# Load the AML cohort from its HDF5 file (same dataset type as `BRCA`).
LAML = MLSurvDataset("Data/LGN_AML_tpm_n300_btypes_labels_surv.h5") 

In [None]:
# Mask of genes whose biotype is exactly "protein_coding" in the BRCA cohort.
BRCA_pcoding = BRCA.biotypes .== "protein_coding"
# Number of genes selected by the mask.
println(count(BRCA_pcoding))
# Length of the first dimension of the sample list.
println(first(size(BRCA.samples)))
# Fraction of entries of `surve` that are non-zero.
println(mean(!=(0), BRCA.surve))


In [None]:
# Mask of LAML genes whose biotype string contains "protein_coding".
pcoding = map(bt -> occursin("protein_coding", bt), LAML.biotypes)
# Number of genes selected by the mask.
println(first(size(LAML.genes[pcoding])))
# Fraction of entries of `surve` that are non-zero.
println(mean(!=(0), LAML.surve))

In [None]:
# VAE-Cox test run on the BRCA cohort, restricted to the protein-coding gene columns.
DATA = BRCA
keep = BRCA_pcoding
# Partition the selected expression columns, together with the `survt` / `surve` vectors
# and the sample ids, into 5 train/test folds.
folds = split_train_test(
    Matrix(DATA.data[:, keep]),
    DATA.survt,
    DATA.surve,
    DATA.samples;
    nfolds=5,
)


In [None]:
# NOTE(review): `device!` is not defined in this notebook (presumably CUDA.jl's `device!`
# or a project helper from the included engine files) — confirm that a zero-argument
# method exists and what device it selects.
device!()

In [None]:
# testing VAE-Cox on a single cross-validation fold
# set hyper parameters
DATA, nfolds, nepochs, dim_redux = LAML, 5, 100, 125
# Tag recorded under params_dict["dataset"]; it must be changed together with DATA
# so that the stored run metadata names the cohort that was actually used.
dataset_name = "LAML"
# Mask of genes whose biotype string contains "protein_coding".
keep = [occursin("protein_coding", bt) for bt in DATA.biotypes]
nsamples = size(DATA.samples)[1]
ngenes = size(DATA.genes[keep])[1]
# Nominal size of one held-out fold; the remainder is the nominal training size.
nsamples_test = Int(round(nsamples / nfolds))
println(sum(keep))
println(nsamples)
println(mean(DATA.surve .!= 0))
params_dict = Dict(
        ## run infos
        "session_id" => session_id, "nfolds" => nfolds,  "modelid" => "$(bytes2hex(sha256("$(now())"))[1:Int(floor(end/3))])",
        "machine_id"=>strip(read(`hostname`, String)), "device" => "$(device())", "model_title"=>"AECPHDNN",
        ## data infos
        "dataset" => "$(dataset_name)_data(norm=true)", "nsamples" => nsamples,
        "nsamples_test" => nsamples_test, "ngenes" => ngenes,
        "nsamples_train" => nsamples - nsamples_test,
        ## optim infos
        "nepochs" => nepochs, "ae_lr" =>1e-5, "cph_lr" => 1e-5, "ae_wd" => 1e-6, "cph_wd" => 1e-6,
        ## model infos
        "model_type"=> "aecphdnn", "dim_redux" => dim_redux, "ae_nb_hls" => 2, "ae_hl_size"=> 128,
        "enc_nb_hl" => 2, "enc_hl_size"=> 128,  "dec_nb_hl" => 2 , "dec_hl_size"=> 128,
        "nb_clinf" => 0, "cph_nb_hl" => 2, "cph_hl_size" => 64,
        "insize" => ngenes,
        ## metrics
        "model_cv_complete" => false
    )
# split train test (same fold count as recorded in params_dict)
folds = split_train_test(Matrix(DATA.data[:,keep]), DATA.survt, DATA.surve, DATA.samples;nfolds = nfolds)
# only the first fold is used in this test run
fold = folds[1]
# format input data
train_x, train_y_t, train_y_e, NE_frac_tr, test_x, test_y_t, test_y_e, NE_frac_tst = format_train_test(fold)
# create model
model = build_vaecox(params_dict)
# The training loop that consumes `model`, `params_dict` and the train/test arrays
# is in the next cell.
# Still to do here:
# report learning curves
# test model
# report c-index

In [None]:
# Joint training of the CPH head and the auto-encoder on the current fold.
# Relies on `model`, `params_dict`, `fold`, `nepochs` and the train/test arrays
# prepared in the previous cell.
#
# The trainable-parameter collections do not change between epochs, so they are
# built once instead of on every iteration.
cph_ps = Flux.params(model["cph"].model, model["enc"])
ae_ps = Flux.params(model["ae"].net)
for iter in 1:nepochs
    ## gradient CPH
    # NOTE(review): here `lossf` receives the encoder and the raw input as two
    # arguments, whereas the reporting code below passes the already-encoded input
    # `model["enc"](x)`. Both call shapes are kept as in the original — confirm that
    # `lossf` has a method for each.
    cph_gs = gradient(cph_ps) do
        model["cph"].lossf(model["cph"],model["enc"], train_x, train_y_e, NE_frac_tr, params_dict["cph_wd"])
    end
    ## gradient Auto-Encoder
    ae_gs = gradient(ae_ps) do
        model["ae"].lossf(model["ae"], train_x, train_x, weight_decay = params_dict["ae_wd"])
    end
    # Both gradients are computed before either update is applied.
    Flux.update!(model["cph"].opt, cph_ps, cph_gs)
    Flux.update!(model["ae"].opt, ae_ps, ae_gs)

    # Metrics are only reported every 10th epoch, so they are only evaluated then
    # (the forward passes and concordance computations are skipped otherwise).
    iter % 10 == 0 || continue

    ###### training-set metrics (after this epoch's update)
    OUTS_tr = vec(model["cph"].model(model["enc"](train_x)))
    ae_loss = model["ae"].lossf(model["ae"], train_x, train_x, weight_decay = params_dict["ae_wd"])
    ae_cor = my_cor(vec(train_x), vec(model["ae"].net(train_x)))
    cph_loss = model["cph"].lossf(model["cph"],model["enc"](train_x), train_y_e, NE_frac_tr, params_dict["cph_wd"])
    # only the first returned value (the c-index) is reported for the training set
    cind_tr = first(concordance_index(train_y_t, train_y_e, OUTS_tr))

    ###### test-set metrics
    OUTS_tst = vec(model["cph"].model(model["enc"](test_x)))
    ae_loss_test = model["ae"].lossf(model["ae"], test_x, test_x, weight_decay = params_dict["ae_wd"])
    ae_cor_test = my_cor(vec(test_x), vec(model["ae"].net(test_x)))
    # Kept unrounded: it is divided by the sample count below, and rounding to
    # 3 digits first would distort the 6-digit per-sample value.
    cph_loss_test = model["cph"].lossf(model["cph"],model["enc"](test_x), test_y_e, NE_frac_tst, params_dict["cph_wd"])
    cind_test, cdnt_tst, ddnt_tst, tied_tst = concordance_index(test_y_t, test_y_e, OUTS_tst)

    println("FOLD $(fold["foldn"]) $iter\t TRAIN AE-loss $(round(ae_loss,digits =3)) \t AE-cor: $(round(ae_cor, digits = 3))\t cph-loss-avg: $(round(cph_loss / params_dict["nsamples_train"],digits =6)) \t cph-cind: $(round(cind_tr,digits =3))")
    println("\t\tTEST AE-loss $(round(ae_loss_test,digits =3)) \t AE-cor: $(round(ae_cor_test, digits = 3))\t cph-loss-avg: $(round(cph_loss_test / params_dict["nsamples_test"],digits =6)) \t cph-cind: $(round(cind_test,digits =3)) [$(Int(cdnt_tst)), $(Int(ddnt_tst)), $(Int(tied_tst))]")
end

In [102]:

"""
    VariationalEncoder(linear, mu, log_sigma)
    VariationalEncoder(input_dim::Int, latent_dim::Int, hidden_dim::Int; device = gpu)

Encoder with one shared layer followed by two heads.

Calling `encoder(x)` applies `linear` to `x` and returns the tuple
`(mu(h), log_sigma(h))` of both heads evaluated on that hidden activation `h`.

The three-integer form builds the layers as `Dense` layers
(`input_dim -> hidden_dim` with `leakyrelu`, then two `hidden_dim -> latent_dim`
heads) and passes each one through `device`.

The fields are type parameters so that each layer keeps its concrete type.

NOTE(review): callers collect the weights with `Flux.params(venc, ...)`. Whether a
plain struct is traversed without `Flux.@functor VariationalEncoder` depends on the
installed Flux/Functors version — confirm for the version in use.
"""
struct VariationalEncoder{L,M,S}
    linear::L
    mu::M
    log_sigma::S
end

function VariationalEncoder(input_dim::Int, latent_dim::Int, hidden_dim::Int; device = gpu)
    # Construct this very type; `Encoder` is not defined in this notebook.
    return VariationalEncoder(
        device(Dense(input_dim, hidden_dim, leakyrelu)),  # linear
        device(Dense(hidden_dim, latent_dim)),            # mu
        device(Dense(hidden_dim, latent_dim)),            # log sigma
    )
end

function (encoder::VariationalEncoder)(x)
    h = encoder.linear(x)
    return encoder.mu(h), encoder.log_sigma(h)
end
"""
    Decoder(input_dim::Int, latent_dim::Int, hidden_dim::Int; device = gpu)

Build a two-layer `Chain`: a `Dense(latent_dim, hidden_dim, leakyrelu)` layer followed
by a `Dense(hidden_dim, input_dim)` layer. Each layer is passed through `device`
before being placed in the chain.
"""
function Decoder(input_dim::Int, latent_dim::Int, hidden_dim::Int; device = gpu)
    hidden_layer = device(Dense(latent_dim, hidden_dim, leakyrelu))
    output_layer = device(Dense(hidden_dim, input_dim))
    return Chain(hidden_layer, output_layer)
end
"""
    MyReconstruct(encoder, decoder, x; device = gpu)

Encode `x`, draw a latent sample and decode it.

`encoder(x)` must return a `(mu, log_sigma)` pair. A `Float32` standard-normal array
of the same size as `log_sigma` is drawn, passed through `device`, scaled element-wise
by `exp.(log_sigma)` and added to `mu` to give the latent sample `z`.

Returns the tuple `(mu, log_sigma, decoder(z))`.
"""
function MyReconstruct(encoder, decoder, x; device = gpu)
    mu, log_sigma = encoder(x)
    noise = device(randn(Float32, size(log_sigma)))
    sigma = exp.(log_sigma)
    z = mu + noise .* sigma
    return mu, log_sigma, decoder(z)
end

MyReconstruct (generic function with 1 method)

In [111]:
"""
    VAE_lossf(venc, vdec, X)

VAE objective for the batch `X`: a KL term plus a reconstruction term.

`MyReconstruct(venc, vdec, X)` supplies `mu`, `log_sigma` and the decoded sample.
The KL term is `0.5 * sum(exp(2 * log_sigma) + mu^2 - 1 - 2 * log_sigma)` and the
reconstruction term is the summed squared error between `X` and the decoded sample.
Both are divided by the second dimension of `X` (used as the number of samples)
before being added.
"""
function VAE_lossf(venc, vdec, X)
    mu, log_sigma, reconstruction = MyReconstruct(venc, vdec, X)
    _, n_obs = size(X)
    kl_terms = @. exp(log_sigma * 2f0) + mu ^ 2 - 1f0 - 2 * log_sigma
    kl_per_sample = 0.5f0 * sum(kl_terms) / n_obs
    sse_per_sample = Flux.mse(X, reconstruction, agg=sum) / n_obs
    return kl_per_sample + sse_per_sample
end

VAE_lossf (generic function with 1 method)

In [121]:

# Instantiate the VAE and its optimiser: 125 latent units, 600 hidden units.
# The input width is the number of gene columns selected by the boolean mask `keep`.
# Counting the mask gives the same value as `size(DATA.data[:, keep])[2]` without
# materialising a copy of the expression matrix.
n_features = count(keep)
venc = VariationalEncoder(n_features, 125, 600)
vdec = Decoder(n_features, 125, 600)
# `Adam` is the current name of the optimiser (`ADAM` is its deprecated alias).
opt = Flux.Adam(1e-4)


Adam(0.0001, (0.9, 0.999), 1.0e-8, IdDict{Any, Any}())

In [122]:
# Train the VAE on `train_x` for 10000 full-batch steps, logging every 100th step.
# The parameter collection does not change between steps, so it is built once.
vae_ps = Flux.params(venc, vdec)
for i in 1:10000
    gs = gradient(vae_ps) do
        VAE_lossf(venc, vdec, train_x)
    end
    if i % 100 == 0
        # Metrics are evaluated with the weights as they are before this step's update,
        # and only on logging steps. Each evaluation draws a fresh latent sample, so the
        # values are stochastic.
        VAE_loss = VAE_lossf(venc, vdec, train_x)
        VAE_cor = round(my_cor(vec(train_x), vec(MyReconstruct(venc, vdec, train_x)[end])),digits = 3)
        VAE_test = round(my_cor(vec(test_x), vec(MyReconstruct(venc, vdec, test_x)[end])),digits = 3)
        println("TRAIN - VAE-loss-avg: $VAE_loss\tVAE-cor: $VAE_cor \tTEST - VAE-cor: $VAE_test")
    end
    Flux.update!(opt, vae_ps, gs)
end

TRAIN - VAE-loss-avg: 3522.0117	VAE-cor: 0.887 	TEST - VAE-cor: 0.878


TRAIN - VAE-loss-avg: 2874.2417	VAE-cor: 0.922 	TEST - VAE-cor: 0.918


TRAIN - VAE-loss-avg: 2655.922	VAE-cor: 0.935 	TEST - VAE-cor: 0.938


TRAIN - VAE-loss-avg: 2509.874	VAE-cor: 0.943 	TEST - VAE-cor: 0.947


TRAIN - VAE-loss-avg: 2432.244	VAE-cor: 0.949 	TEST - VAE-cor: 0.951


TRAIN - VAE-loss-avg: 2373.6978	VAE-cor: 0.953 	TEST - VAE-cor: 0.954


TRAIN - VAE-loss-avg: 2309.1296	VAE-cor: 0.955 	TEST - VAE-cor: 0.957


TRAIN - VAE-loss-avg: 2277.393	VAE-cor: 0.958 	TEST - VAE-cor: 0.96


TRAIN - VAE-loss-avg: 2234.5557	VAE-cor: 0.961 	TEST - VAE-cor: 0.962


TRAIN - VAE-loss-avg: 2201.6514	VAE-cor: 0.962 	TEST - VAE-cor: 0.963


TRAIN - VAE-loss-avg: 2185.481	VAE-cor: 0.964 	TEST - VAE-cor: 0.965


TRAIN - VAE-loss-avg: 2158.3784	VAE-cor: 0.965 	TEST - VAE-cor: 0.967


TRAIN - VAE-loss-avg: 2137.4517	VAE-cor: 0.968 	TEST - VAE-cor: 0.968


TRAIN - VAE-loss-avg: 2123.0598	VAE-cor: 0.968 	TEST - VAE-cor: 0.968


TRAIN - VAE-loss-avg: 2106.874	VAE-cor: 0.969 	TEST - VAE-cor: 0.969


TRAIN - VAE-loss-avg: 2092.8774	VAE-cor: 0.97 	TEST - VAE-cor: 0.969


TRAIN - VAE-loss-avg: 2085.3708	VAE-cor: 0.97 	TEST - VAE-cor: 0.97


TRAIN - VAE-loss-avg: 2072.2346	VAE-cor: 0.971 	TEST - VAE-cor: 0.97


TRAIN - VAE-loss-avg: 2073.3389	VAE-cor: 0.972 	TEST - VAE-cor: 0.971


TRAIN - VAE-loss-avg: 2051.6719	VAE-cor: 0.972 	TEST - VAE-cor: 0.971


TRAIN - VAE-loss-avg: 2048.0503	VAE-cor: 0.973 	TEST - VAE-cor: 0.971


TRAIN - VAE-loss-avg: 2040.255	VAE-cor: 0.973 	TEST - VAE-cor: 0.972


TRAIN - VAE-loss-avg: 2036.3137	VAE-cor: 0.973 	TEST - VAE-cor: 0.972


TRAIN - VAE-loss-avg: 2023.928	VAE-cor: 0.975 	TEST - VAE-cor: 0.973


TRAIN - VAE-loss-avg: 2026.8007	VAE-cor: 0.974 	TEST - VAE-cor: 0.972


TRAIN - VAE-loss-avg: 2014.8679	VAE-cor: 0.975 	TEST - VAE-cor: 0.973


TRAIN - VAE-loss-avg: 2005.923	VAE-cor: 0.975 	TEST - VAE-cor: 0.973


TRAIN - VAE-loss-avg: 2008.8706	VAE-cor: 0.975 	TEST - VAE-cor: 0.973


TRAIN - VAE-loss-avg: 2002.4517	VAE-cor: 0.976 	TEST - VAE-cor: 0.973


TRAIN - VAE-loss-avg: 1995.9733	VAE-cor: 0.976 	TEST - VAE-cor: 0.974


TRAIN - VAE-loss-avg: 1995.9058	VAE-cor: 0.976 	TEST - VAE-cor: 0.974


TRAIN - VAE-loss-avg: 1994.1041	VAE-cor: 0.976 	TEST - VAE-cor: 0.974


TRAIN - VAE-loss-avg: 1992.3171	VAE-cor: 0.976 	TEST - VAE-cor: 0.974


TRAIN - VAE-loss-avg: 1993.1235	VAE-cor: 0.976 	TEST - VAE-cor: 0.974


TRAIN - VAE-loss-avg: 1986.3955	VAE-cor: 0.976 	TEST - VAE-cor: 0.974


TRAIN - VAE-loss-avg: 1984.8516	VAE-cor: 0.977 	TEST - VAE-cor: 0.974


TRAIN - VAE-loss-avg: 1986.441	VAE-cor: 0.977 	TEST - VAE-cor: 0.974


TRAIN - VAE-loss-avg: 1978.5842	VAE-cor: 0.977 	TEST - VAE-cor: 0.974


TRAIN - VAE-loss-avg: 1985.529	VAE-cor: 0.977 	TEST - VAE-cor: 0.973


TRAIN - VAE-loss-avg: 1979.0437	VAE-cor: 0.977 	TEST - VAE-cor: 0.975


TRAIN - VAE-loss-avg: 1980.7997	VAE-cor: 0.977 	TEST - VAE-cor: 0.974


TRAIN - VAE-loss-avg: 1978.622	VAE-cor: 0.978 	TEST - VAE-cor: 0.975


TRAIN - VAE-loss-avg: 1977.8146	VAE-cor: 0.978 	TEST - VAE-cor: 0.974


TRAIN - VAE-loss-avg: 1971.5024	VAE-cor: 0.977 	TEST - VAE-cor: 0.975


TRAIN - VAE-loss-avg: 1974.6846	VAE-cor: 0.977 	TEST - VAE-cor: 0.974


TRAIN - VAE-loss-avg: 1974.1614	VAE-cor: 0.977 	TEST - VAE-cor: 0.975


TRAIN - VAE-loss-avg: 1972.4421	VAE-cor: 0.978 	TEST - VAE-cor: 0.975


TRAIN - VAE-loss-avg: 1968.2688	VAE-cor: 0.977 	TEST - VAE-cor: 0.975


TRAIN - VAE-loss-avg: 1967.8777	VAE-cor: 0.978 	TEST - VAE-cor: 0.975


TRAIN - VAE-loss-avg: 1972.5979	VAE-cor: 0.977 	TEST - VAE-cor: 0.975


TRAIN - VAE-loss-avg: 1968.2457	VAE-cor: 0.978 	TEST - VAE-cor: 0.975


TRAIN - VAE-loss-avg: 1968.6609	VAE-cor: 0.978 	TEST - VAE-cor: 0.974


TRAIN - VAE-loss-avg: 1964.4517	VAE-cor: 0.978 	TEST - VAE-cor: 0.975


TRAIN - VAE-loss-avg: 1961.3691	VAE-cor: 0.978 	TEST - VAE-cor: 0.975


TRAIN - VAE-loss-avg: 1964.3784	VAE-cor: 0.978 	TEST - VAE-cor: 0.975


TRAIN - VAE-loss-avg: 1961.5537	VAE-cor: 0.978 	TEST - VAE-cor: 0.975


TRAIN - VAE-loss-avg: 1964.698	VAE-cor: 0.978 	TEST - VAE-cor: 0.975


TRAIN - VAE-loss-avg: 1966.4698	VAE-cor: 0.978 	TEST - VAE-cor: 0.975


TRAIN - VAE-loss-avg: 1964.3489	VAE-cor: 0.978 	TEST - VAE-cor: 0.975


TRAIN - VAE-loss-avg: 1959.436	VAE-cor: 0.978 	TEST - VAE-cor: 0.975


TRAIN - VAE-loss-avg: 1958.1172	VAE-cor: 0.978 	TEST - VAE-cor: 0.975


TRAIN - VAE-loss-avg: 1955.9873	VAE-cor: 0.978 	TEST - VAE-cor: 0.975


TRAIN - VAE-loss-avg: 1962.4176	VAE-cor: 0.978 	TEST - VAE-cor: 0.975


TRAIN - VAE-loss-avg: 1956.3342	VAE-cor: 0.978 	TEST - VAE-cor: 0.975


TRAIN - VAE-loss-avg: 1956.4895	VAE-cor: 0.978 	TEST - VAE-cor: 0.975


TRAIN - VAE-loss-avg: 1956.5356	VAE-cor: 0.979 	TEST - VAE-cor: 0.975


TRAIN - VAE-loss-avg: 1960.15	VAE-cor: 0.979 	TEST - VAE-cor: 0.975


TRAIN - VAE-loss-avg: 1955.7245	VAE-cor: 0.979 	TEST - VAE-cor: 0.975


TRAIN - VAE-loss-avg: 1959.3923	VAE-cor: 0.979 	TEST - VAE-cor: 0.975


TRAIN - VAE-loss-avg: 1949.9253	VAE-cor: 0.979 	TEST - VAE-cor: 0.975


TRAIN - VAE-loss-avg: 1957.6525	VAE-cor: 0.979 	TEST - VAE-cor: 0.976


TRAIN - VAE-loss-avg: 1948.9547	VAE-cor: 0.979 	TEST - VAE-cor: 0.975


TRAIN - VAE-loss-avg: 1946.9344	VAE-cor: 0.979 	TEST - VAE-cor: 0.975


TRAIN - VAE-loss-avg: 1949.4592	VAE-cor: 0.979 	TEST - VAE-cor: 0.975


TRAIN - VAE-loss-avg: 1946.5848	VAE-cor: 0.979 	TEST - VAE-cor: 0.976


TRAIN - VAE-loss-avg: 1944.8589	VAE-cor: 0.979 	TEST - VAE-cor: 0.975


TRAIN - VAE-loss-avg: 1945.2628	VAE-cor: 0.979 	TEST - VAE-cor: 0.975


TRAIN - VAE-loss-avg: 1944.3635	VAE-cor: 0.979 	TEST - VAE-cor: 0.975


TRAIN - VAE-loss-avg: 1948.185	VAE-cor: 0.979 	TEST - VAE-cor: 0.975


TRAIN - VAE-loss-avg: 1949.4038	VAE-cor: 0.979 	TEST - VAE-cor: 0.975


TRAIN - VAE-loss-avg: 1944.261	VAE-cor: 0.979 	TEST - VAE-cor: 0.975


TRAIN - VAE-loss-avg: 1936.4939	VAE-cor: 0.979 	TEST - VAE-cor: 0.975


TRAIN - VAE-loss-avg: 1942.6035	VAE-cor: 0.979 	TEST - VAE-cor: 0.975


TRAIN - VAE-loss-avg: 1939.3389	VAE-cor: 0.979 	TEST - VAE-cor: 0.975


TRAIN - VAE-loss-avg: 1942.9274	VAE-cor: 0.98 	TEST - VAE-cor: 0.975


TRAIN - VAE-loss-avg: 1936.2825	VAE-cor: 0.98 	TEST - VAE-cor: 0.975


TRAIN - VAE-loss-avg: 1932.1849	VAE-cor: 0.979 	TEST - VAE-cor: 0.975


TRAIN - VAE-loss-avg: 1932.427	VAE-cor: 0.98 	TEST - VAE-cor: 0.976


TRAIN - VAE-loss-avg: 1939.3226	VAE-cor: 0.98 	TEST - VAE-cor: 0.975


TRAIN - VAE-loss-avg: 1942.4849	VAE-cor: 0.98 	TEST - VAE-cor: 0.975


TRAIN - VAE-loss-avg: 1931.2852	VAE-cor: 0.98 	TEST - VAE-cor: 0.975


TRAIN - VAE-loss-avg: 1930.1016	VAE-cor: 0.98 	TEST - VAE-cor: 0.975


TRAIN - VAE-loss-avg: 1937.567	VAE-cor: 0.98 	TEST - VAE-cor: 0.976


TRAIN - VAE-loss-avg: 1929.2812	VAE-cor: 0.98 	TEST - VAE-cor: 0.975


TRAIN - VAE-loss-avg: 1929.7426	VAE-cor: 0.98 	TEST - VAE-cor: 0.975


TRAIN - VAE-loss-avg: 1933.022	VAE-cor: 0.98 	TEST - VAE-cor: 0.975


TRAIN - VAE-loss-avg: 1928.1499	VAE-cor: 0.98 	TEST - VAE-cor: 0.976


TRAIN - VAE-loss-avg: 1925.0447	VAE-cor: 0.98 	TEST - VAE-cor: 0.975


TRAIN - VAE-loss-avg: 1927.2032	VAE-cor: 0.98 	TEST - VAE-cor: 0.974


TRAIN - VAE-loss-avg: 1923.4485	VAE-cor: 0.98 	TEST - VAE-cor: 0.976
