In [None]:
using Pkg
using CSV 
using DataFrames
using Dates
using CUDA
include("embeddings.jl")


## Loading Leucegene dataset
### Loading the complete transcriptomic profile (TPM) of the prognostic subset


In [None]:
# Full TPM transcriptome of the prognostic subset; @time reports load time/allocations.
# The trailing semicolon suppresses display of the (large) DataFrame in the notebook.
filename = "/u/sauves/leucegene-shared/Data/lgn_pronostic_GE_TRSC_TPM.csv"
@time GE_TRSC_TPM = CSV.read(filename, DataFrame);


### Loading LSC17 signature gene expressions only


In [None]:
# Expressions of the LSC17 signature genes for the prognostic subset.
# Trailing semicolon: load without displaying the table.
filename = "/u/sauves/leucegene-shared/Data/SIGNATURES/LSC17_lgn_pronostic_expressions.csv"
LSC17_TPM = CSV.read(filename, DataFrame);


In [None]:
### Loading Clinical Features file


In [None]:
# Clinical features table of the prognostic subset (note: the path has no extension).
# Trailing semicolon: load without displaying the table.
filename = "/u/sauves/leucegene-shared/Data/LEUCEGENE/lgn_pronostic_CF"
CF = CSV.read(filename, DataFrame);


In [None]:
"""
    Data{T}

Labelled 2-D dataset from which factorized-embedding training examples are built.

# Fields
- `name::String`: dataset label (e.g. `"LSC17"`).
- `data::Matrix{T}`: values, expected to be `length(d1_index) × length(d2_index)`.
- `d1_index::Vector{String}`: row labels (factor 1).
- `d2_index::Vector{String}`: column labels (factor 2).
- `d3_index::Vector{Int32}`: one integer code per row (factor 3).

`data` has the concrete type `Matrix{T}` (rather than the abstract `Matrix`) so that
code reading it can specialize on the element type instead of dispatching dynamically
on every element access.
"""
mutable struct Data{T}
    name::String
    data::Matrix{T}
    d1_index::Vector{String}
    d2_index::Vector{String}
    d3_index::Vector{Int32}
end

# Infer `T` from the matrix and let the default inner constructor convert the other
# arguments to the field types (e.g. `Vector{Int}` -> `Vector{Int32}`, inline strings ->
# `String`), exactly as the non-parametric struct's default constructor did.
Data(name, data::AbstractMatrix{T}, d1_index, d2_index, d3_index) where {T} =
    Data{T}(name, data, d1_index, d2_index, d3_index)


In [None]:
# Integer-encode each patient's cytogenetic group. Codes 1, 2, … follow the order of
# first appearance in CF. The -1 fallback never triggers, because every value of cyt_grp
# is by construction a key of dct_grp.
cyt_grp = CF[:, "Cytogenetic group"]
grp_unq = unique(cyt_grp)
dct_grp = Dict(grp => code for (code, grp) in enumerate(grp_unq))
groups = [get(dct_grp, g, -1) for g in cyt_grp]
print("done")


In [None]:

# Bundle the LSC17 expressions with their labels:
#   rows    -> first column of LSC17_TPM (presumably sample/patient IDs),
#   columns -> remaining column headers (gene names),
#   d3      -> per-row cytogenetic group code.
data_matrix = Data(
    "LSC17",
    Matrix(LSC17_TPM[:, 2:end]),
    LSC17_TPM[:, 1],
    names(LSC17_TPM)[2:end],
    groups,
)
data_matrix.d3_index[2]


In [None]:
"""
    prep_data(data::Data; device = gpu)

Flatten `data.data` into one training example per matrix cell.

With `n = length(data.d1_index)` and `m = length(data.d2_index)`, cells are laid out
row by row. Example `k = (i - 1) * m + j` has:
- value `data.data[i, j]`;
- factor-1 index `i`;
- factor-2 index `j`;
- factor-3 index `data.d3_index[i]`.

No column removal or log transform is done here; `data.data` is used as given,
converted to `Float32`. Prints `size(values)` as a diagnostic.

# Keywords
- `device`: applied to every returned array (default `gpu`, presumably Flux's
  GPU-transfer function; pass `identity` to stay on the CPU).

# Returns
`((d1, d2, d3), values)`, where:
- `d1`, `d2`, `d3` are `Int32` vectors of length `n * m`;
- `values` is a `1 × (n * m)` `Float32` matrix.

Each is passed through `device`.

# Throws
`DimensionMismatch` if `size(data.data) != (n, m)` or `length(data.d3_index) != n`
(previously an oversized matrix was silently truncated to its top-left `n × m` block).
"""
function prep_data(data::Data; device = gpu)
    n = length(data.d1_index)
    m = length(data.d2_index)
    size(data.data) == (n, m) || throw(DimensionMismatch(
        "data matrix has size $(size(data.data)) but index lengths give ($n, $m)"))
    length(data.d3_index) == n || throw(DimensionMismatch(
        "d3_index has length $(length(data.d3_index)), expected $n"))

    # Transposing before flattening turns Julia's column-major order into row-major
    # order, so values[1, (i - 1) * m + j] == data.data[i, j].
    values = reshape(Matrix{Float32}(permutedims(data.data)), 1, n * m)
    print(size(values))

    # Index vectors matching that layout: row index repeated per column (inner),
    # column index cycled per row (outer), row group code repeated per column.
    d1_index = Vector{Int32}(repeat(1:n; inner = m))
    d2_index = Vector{Int32}(repeat(1:m; outer = n))
    d3_index = Vector{Int32}(repeat(data.d3_index; inner = m))
    return (device(d1_index), device(d2_index), device(d3_index)), device(values)
end


## Experiments 
### 1 - Training Factorized Embeddings on Leucegene data using patient (factor 1), LSC17 gene (factor 2), and cytogenetic group (factor 3) embedding layers; the model predicts the LSC17 gene expression for each (patient, gene) pair.


In [None]:
# Flatten data_matrix into the (d1, d2, d3) index tuple X and expression values Y,
# then stream them in mini-batches of 4096 examples.
X, Y = prep_data(data_matrix)
data = Flux.Data.DataLoader((X, Y); batchsize = 4096)

## training + plotting (train_plot is defined elsewhere, presumably in embeddings.jl)
# Planned outputs:
#   - scatterplot: predicted vs. true expression
#   - training curve: MSE vs. epoch
#   - scatterplot: trained embeddings (UMAP), colored by cytogenetic group
# NOTE(review): (2, 2, 2) looks like per-factor embedding sizes and 2000 like the epoch
# count — confirm against train_plot's signature.
train_plot(data, X, Y, (2, 2, 2), "embeddings_$(now())", data_matrix, 2000)
data_matrix.d1_index


### 2 - Training Factorized Embeddings on Leucegene data using all clinical factors and the transcriptome profile. Report performance (R²) on the test set.
Embedding layers: patient (factor 1), gene* (factor 2), cytogenetic group (factor 3), NPM1 mutation (factor 4), FLT3-ITD mutation (factor 5), IDH1 mutation (factor 6), sex (factor 7), age_gt_60 (factor 8). The model predicts gene expression. *Each patient's gene-expression vector contains the 50% most variable genes across the dataset.


In [None]:
# set params
## FE (factorized embeddings)
# Held fixed across runs:
#   - architecture (hidden layers and their sizes)
#   - regularization (L2 / weight decay)
#   - optimizer parameters (mini-batch size, number of epochs, learning rate)
nepochs = 2000
# Varied across runs: embedding size.
# Embedding factors, in factor order (factor 1 = patient, factor 2 = gene, ...).
# A vector literal is used here; `Array("patient", ...)` has no matching method and
# would raise a MethodError.
factors = ["patient", "gene", "Cytogenetic group", "NPM1 mutation",
    "FLT3-ITD mutation", "IDH1 mutation", "Sex", "age_gt_60"]
# Build the dataset; the arguments are presumably the gene-expression and
# clinical-feature file paths.
# NOTE(review): create_dataset, gene_exp_fpath and clin_f_fpath are not defined in this
# notebook section — presumably provided by embeddings.jl or another cell; confirm.
dataset = create_dataset(gene_exp_fpath, clin_f_fpath)
# Split into train/test folds. 5 is presumably the number of folds; the loop below
# reads fold_data["train"] and fold_data["test"]. TODO confirm against split_train_test.
splits = split_train_test(dataset, 5)
# cycle through folds 
for fold_data in splits
    # prep data 
    X_train, Y_train = prep_data(fold_data["train"], factors)
    # train
    train_data_loader = Flux.Data.DataLoader((X_train, Y_train), batchsize = 4096)
    model = train_plot(train_data_loader, X_train, Y_train, emb_sizes, "embeddings_$(now())", fold_data["train"], nepochs)
    # save embeddings, cphdnn_train_data
    # test (interpolate, report R2)
    X_test, Y_test = prep_data(fold_data["test"], factors)
    test_data_loader =  Flux.Data.DataLoader((X_test, Y_test), batchsize = 4096)
    # save embeddings, cphdnn_test_data
    evaluate(test_data_loader, X_test, Y_test, model)
    