# In [None]:
using Pkg
using CSV 
using DataFrames
using Dates
include("embeddings.jl")
## Loading Leucegene dataset
# Root of the shared Leucegene data tree. Set the LEUCEGENE_DATA environment
# variable to run this notebook against another copy; the default is the
# original hard-coded location.
data_dir = get(ENV, "LEUCEGENE_DATA", "/u/sauves/leucegene-shared/Data")

# Each load ends with `;` so the notebook does not display the whole DataFrame
# (this replaces the former no-op `print()` calls).

### Loading the complete transcriptomic profile of the prognostic subset
filename = joinpath(data_dir, "lgn_pronostic_GE_TRSC_TPM.csv")
@time GE_TRSC_TPM = CSV.read(filename, DataFrame);

### Loading LSC17 expressions only
filename = joinpath(data_dir, "SIGNATURES", "LSC17_lgn_pronostic_expressions.csv")
LSC17_TPM = CSV.read(filename, DataFrame);

### Loading the clinical features file
filename = joinpath(data_dir, "LEUCEGENE", "lgn_pronostic_CF")
CF = CSV.read(filename, DataFrame);
"""
    Data{T}

Container for a two-dimensional table of values together with its labels.

# Fields
- `name::String`: label for the dataset (e.g. `"LSC17"`).
- `data::Matrix{T}`: the values; one row per `d1_index` entry and one column
  per `d2_index` entry.
- `d1_index::Vector{String}`: row labels.
- `d2_index::Vector{String}`: column labels.
- `d3_index::Vector{Int32}`: one integer label per row.

The element type `T` is a type parameter so that code reading `data` is
type-stable. An unparameterized `Matrix` field is an abstract type.
"""
mutable struct Data{T}
    name::String
    data::Matrix{T}
    d1_index::Vector{String}
    d2_index::Vector{String}
    d3_index::Vector{Int32}
end

# Keep the original un-parameterized call form working. Julia's generated
# outer constructor for a parametric type only accepts the exact field types,
# so `Vector{Int}` labels or non-`String` ids would raise a MethodError there.
# This method forwards to `Data{T}`, whose constructor converts each argument
# to its field type as before.
Data(name, data::AbstractMatrix{T}, d1_index, d2_index, d3_index) where {T} =
    Data{T}(name, data, d1_index, d2_index, d3_index)
# Encode each sample's cytogenetic group as an integer id. Ids are 1-based and
# assigned in order of first appearance in `CF`.
cyt_grp = CF[:, "Cytogenetic group"]
grp_unq = unique(cyt_grp)
dct_grp = Dict(grp => i for (i, grp) in enumerate(grp_unq))
# Every value of `cyt_grp` is a key of `dct_grp` by construction, so index
# directly. The former `-1` fallback was unreachable, and if it were ever hit
# it would silently produce an invalid id instead of an error.
groups = [dct_grp[grp] for grp in cyt_grp]


# Wrap the LSC17 table as a `Data`:
# - the first column provides the row labels,
# - the remaining columns are the LSC17 features,
# - each row is labelled with its cytogenetic-group id from `groups`.
# NOTE(review): rows of `LSC17_TPM` and `CF` are paired by position only;
# confirm both files list samples in the same order.
size(LSC17_TPM, 1) == length(groups) || throw(DimensionMismatch(
    "LSC17_TPM has $(size(LSC17_TPM, 1)) rows but CF has $(length(groups)) samples"))
data_matrix = Data("LSC17",
    Matrix(LSC17_TPM[!, 2:end]),  # `!` selects the columns without copying them first
    LSC17_TPM[:, 1],
    names(LSC17_TPM)[2:end],      # feature names without copying the whole frame
    groups)
"""
    prep_data(data::Data; device = gpu)

Flatten the `n × m` matrix `data.data` into `n * m` examples, one per
(row, column) cell, where `n = length(data.d1_index)` and
`m = length(data.d2_index)`.

Cells are laid out row by row: example `k = (i - 1) * m + j` is cell `(i, j)`.
The values are only converted to `Float32`. No other transformation (e.g. log)
is applied.

# Returns
`((d1, d2, d3), values)`, with every array passed through `device`:
- `d1::Vector{Int32}`: row index `i` of each example.
- `d2::Vector{Int32}`: column index `j` of each example.
- `d3::Vector{Int32}`: `data.d3_index[i]`, the label of the example's row.
- `values::Matrix{Float32}` of size `1 × (n * m)`: `data.data[i, j]`.

# Keywords
- `device`: function applied to each returned array. The default `gpu` is
  presumably Flux's (brought in via `embeddings.jl`; TODO confirm). Pass
  `identity` to keep the arrays on the CPU.

# Throws
- `DimensionMismatch` if `size(data.data) != (n, m)` or
  `length(data.d3_index) != n`.
"""
function prep_data(data::Data; device = gpu)
    n = length(data.d1_index)
    m = length(data.d2_index)
    # Reject an inconsistent `Data` up front. Otherwise it would be silently
    # truncated or fail with a BoundsError halfway through.
    size(data.data) == (n, m) || throw(DimensionMismatch(
        "data matrix is $(size(data.data)) but the indices describe ($n, $m)"))
    length(data.d3_index) == n || throw(DimensionMismatch(
        "d3_index has $(length(data.d3_index)) entries, expected $n"))
    return _flatten_cells(data.data, data.d3_index, device)
end

# Row-major flattening of `mat` and its per-row `labels` into index/value
# arrays. Kept as a separate function so it is compiled for the concrete type
# of `mat`.
function _flatten_cells(mat::AbstractMatrix, labels::AbstractVector, device)
    n, m = size(mat)
    # `permutedims` turns rows into columns, so the column-major `reshape`
    # walks `mat` row by row and puts cell (i, j) at position (i - 1) * m + j.
    values = reshape(Float32.(permutedims(mat)), 1, n * m)
    d1 = repeat(Int32(1):Int32(n), inner = m)
    d2 = repeat(Int32(1):Int32(m), outer = n)
    d3 = repeat(Int32.(labels), inner = m)
    @debug "prep_data" size(values)
    return (device(d1), device(d2), device(d3)), device(values)
end
## Training Factorized Embedding models on Leucegene
### Experiment 1: training on Leucegene (300 samples) with LSC17 gene expressions.
# Flatten once. `prep_data` builds the arrays and passes them through its
# device function, so a second identical call would only repeat that work.
X, Y = prep_data(data_matrix)
# `X_`/`Y_` previously held a separate, identical copy; keep the names bound
# (now aliasing `X`/`Y`) for any later cell that still refers to them.
X_, Y_ = X, Y

## training
# `X` is a tuple of three index vectors and `Y` is a 1 × N matrix. The loader
# presumably batches along the last dimension of each (TODO confirm for the
# Flux version in use).
data = Flux.Data.DataLoader((X, Y), batchsize = 4096)
# NOTE(review): `train_plot` is not defined in this file (presumably in
# embeddings.jl); the meaning of `(2, 2, 2)` and `2000` should be documented there.
train_plot(data, X, Y, (2, 2, 2), "embeddings_$(now())", data_matrix, 2000)

## plotting results (TODO)
### scatterplot - predicted expr. vs true
### training curve - MSE vs epoch
### scatterplot - trained embedding (UMAP) - colors by cyto-group
