# Code For Debugging

In [5]:
import argparse
import time
import warnings

# tools
import torch
from torch.autograd import Variable
import matplotlib.pyplot as plt
import numpy as np 
import pandas as pd
import magic
from tqdm import tqdm
import sys
import os

from mousipy import translate

module_path = os.path.abspath(os.path.join('.', 'src'))
if module_path not in sys.path:
    sys.path.append(module_path)

from ClassFlux import FLUX
from DatasetFlux import MyDataset
# from scFEA
# from scFEA_grad
from util import *

# Set parameters

In [6]:
data_path ='data'
input_path ='input'
res_dir ='output'
test_file ='Seurat_geneExpr.csv'
moduleGene_file ='module_gene_glutaminolysis1_m23.csv'
cm_file ='cmMat_glutaminolysis1_c17_m23.csv'
sc_imputation =True
cName_file='cName_glutaminolysis1_c17_m23.csv'
fileName='output/ad1212_flux.csv'
balanceName='output/ad1212_balance.csv'
EPOCH=100

# Check if CUDA is available
cuda_available = torch.cuda.is_available()

print(f"CUDA available: {cuda_available}")

# If CUDA is available, print the number of CUDA devices and their names
if cuda_available:
    print(f"Number of CUDA devices: {torch.cuda.device_count()}")
    for i in range(torch.cuda.device_count()):
        print(f"CUDA Device {i}: {torch.cuda.get_device_name(i)}")


# choose cpu or gpu automatically
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

print("now you are using device: ", device)

CUDA available: False
now you are using device:  cpu


# Load preprocessed files

In [87]:
# read data
print("Starting load data...")
# geneExpr = pd.read_csv(
#             input_path + '/' + test_file,
#             index_col=0)
# geneExpr = geneExpr.T
# geneExpr = geneExpr * 1.0
# if sc_imputation == True:
#     magic_operator = magic.MAGIC()
#     with warnings.catch_warnings():
#         warnings.simplefilter("ignore")
#         geneExpr = magic_operator.fit_transform(geneExpr)
# if geneExpr.max().max() > 50:
#     geneExpr = (geneExpr + 1).apply(np.log2)  
# geneExprSum = geneExpr.sum(axis=1)
# stand = geneExprSum.mean()
# geneExprScale = geneExprSum / stand

geneExpr = pd.read_pickle('./input/geneExpr1.pkl')
geneExprScale = pd.read_pickle('./input/geneExprScale1.pkl')

Starting load data...
4486


In [89]:
geneExprScale.shape

torch.Size([4486])

In [None]:
geneExprScale = torch.FloatTensor(geneExprScale.values).to(device)

BATCH_SIZE = geneExpr.shape[0]
print(BATCH_SIZE)

In [None]:
geneExprScale

# Load modules data

In [9]:
moduleGene = pd.read_csv(
            data_path + '/' + moduleGene_file,
            sep=',',
            index_col=0)

In [None]:
moduleGene

In [90]:
moduleGene.shape

(23, 28)

In [91]:
moduleGene_lower = moduleGene.applymap(lambda x: x.lower() if isinstance(x, str) else x)

print(moduleGene_lower)

            A      A.1      A.2     A.3     A.4     A.5      A.6      A.7  \
1                                                                           
M_1     adpgk     g6pc    g6pc2   g6pc3    galm     gck      hk1      hk2   
M_2     aldoa    aldob    aldoc    fbp1    fbp2     gpi     pfkl     pfkm   
M_3      bpgm    cl640  coq10d1    coq2   gapdh  gapdhs     msa1    pgam1   
M_4      bpgm    cl640  coq10d1    coq2    eno1    eno2     eno3   hiper1   
M_5      dlat      dld     dldd    dldh      e3    gcsl      lad    pdha1   
M_6      ldha  ldhal6a  ldhal6b    ldhb    ldhc    ldhd      NaN      NaN   
M_7      acly       cs      NaN     NaN     NaN     NaN      NaN      NaN   
M_8      aco1     aco2     idh1    idh2   idh3a   idh3b    idh3g      NaN   
M_9       dld     dldd     dldh    dlst      e3    gcsl      lad     ogdh   
M_10   sucla2   suclg1   suclg2     NaN     NaN     NaN      NaN      NaN   
M_11     sdha     sdhb      NaN     NaN     NaN     NaN      NaN      NaN   

In [92]:
moduleGene = moduleGene_lower

In [None]:
moduleGene.shape[0]

In [93]:
moduleLen = [moduleGene.iloc[i,:].notna().sum() for i in range(moduleGene.shape[0]) ]
moduleLen = np.array(moduleLen)

# find existing gene
module_gene_all = []
for i in range(moduleGene.shape[0]):
    for j in range(moduleGene.shape[1]):
        if pd.isna(moduleGene.iloc[i,j]) == False:
            module_gene_all.append(moduleGene.iloc[i,j])


In [94]:
print(len(moduleLen))
print(module_gene_all[:10])

['adpgk', 'g6pc', 'g6pc2', 'g6pc3', 'galm', 'gck', 'hk1', 'hk2', 'hk3', 'hkdc1']


In [95]:
data_gene_all = geneExpr.columns

# Convert sets to lowercase
data_gene_all = [gene.lower() for gene in data_gene_all]
module_gene_all = [gene.lower() for gene in module_gene_all]

print("data_gene_all: ", sorted(data_gene_all))
print("module_gene_all: ", sorted(module_gene_all))
print("data_gene_all len: ", len(data_gene_all))
print("module_gene_all len: ", len(module_gene_all))

data_gene_all = set(data_gene_all)
module_gene_all = set(module_gene_all)
print("data_gene_all len: ", len(data_gene_all))
print("module_gene_all len: ", len(module_gene_all))

data_gene_all:  ['aacs', 'aadat', 'aass', 'abat', 'acaa1', 'acaa2', 'acaca', 'acacb', 'acad8', 'acadl', 'acadm', 'acads', 'acadsb', 'acadvl', 'acat1', 'acat2', 'acly', 'aco1', 'aco2', 'acox1', 'acox3', 'acsbg1', 'acsbg2', 'acsf3', 'acsl1', 'acsl3', 'acsl4', 'acsl5', 'acsl6', 'acy3', 'ada', 'adc', 'adi1', 'adk', 'adpgk', 'adsl', 'adss', 'adssl1', 'agmat', 'agxt', 'agxt2', 'ahcy', 'ahcyl1', 'ahcyl2', 'ak1', 'ak2', 'ak3', 'ak4', 'ak5', 'ak7', 'ak8', 'alas1', 'alas2', 'aldh1a3', 'aldh1b1', 'aldh2', 'aldh3a1', 'aldh3a2', 'aldh3b1', 'aldh3b2', 'aldh4a1', 'aldh5a1', 'aldh6a1', 'aldh7a1', 'aldh9a1', 'aldoa', 'aldob', 'aldoc', 'alg10', 'alg10b', 'alg5', 'alg6', 'alg8', 'amd1', 'amdhd1', 'amdhd2', 'ampd1', 'ampd2', 'ampd3', 'anpep', 'aoc2', 'aoc3', 'aox1', 'apip', 'aprt', 'arg1', 'arg2', 'asl', 'asns', 'aspa', 'aspg', 'asrgl1', 'ass1', 'atic', 'auh', 'b3galt6', 'b3galtl', 'b3gat2', 'b3gat3', 'b4galt1', 'b4galt2', 'b4galt3', 'b4galt7', 'bcat1', 'bcat2', 'bckdha', 'bckdhb', 'bhmt', 'bhmt2', 'bpgm'

In [97]:
# Find intersection in lowercase
gene_overlap = list(data_gene_all.intersection(module_gene_all))
gene_overlap.sort()

# Optional: Map back to original case (choosing data_gene_all as source)
# gene_overlap = [gene for gene in data_gene_all if gene.lower() in gene_overlap]

print("Gene overlap:", gene_overlap)
print("Gene overlap len:", len(gene_overlap))

Gene overlap: ['abat', 'acly', 'aco1', 'aco2', 'adpgk', 'aldh5a1', 'aldoa', 'aldob', 'aldoc', 'bpgm', 'coq2', 'cs', 'dlat', 'dld', 'dlst', 'eno1', 'eno2', 'eno3', 'fbp1', 'fbp2', 'fh', 'g6pc', 'g6pc2', 'g6pc3', 'g6pd', 'gad1', 'gad2', 'galm', 'gapdh', 'gapdhs', 'gck', 'gls', 'gls2', 'glud1', 'glud2', 'glul', 'got1', 'got2', 'gpi', 'h6pd', 'hk1', 'hk2', 'hk3', 'hkdc1', 'idh1', 'idh2', 'idh3a', 'idh3b', 'idh3g', 'ldha', 'ldhal6a', 'ldhal6b', 'ldhb', 'ldhc', 'ldhd', 'mdh1', 'mdh2', 'minpp1', 'ogdh', 'ogdhl', 'pck1', 'pck2', 'pdha1', 'pdha2', 'pdhb', 'pfkl', 'pfkm', 'pfkp', 'pgam1', 'pgam2', 'pgam4', 'pgd', 'pgk1', 'pgk2', 'pgls', 'pgm1', 'pgm2', 'phb', 'phgdh', 'pklr', 'pkm', 'prps1', 'prps1l1', 'prps2', 'psat1', 'psph', 'rbks', 'rpe', 'rpia', 'sdha', 'sdhb', 'slc1a1', 'slc1a2', 'slc1a3', 'slc1a5', 'slc1a6', 'slc1a7', 'slc2a1', 'slc2a2', 'slc2a3', 'slc2a4', 'sucla2', 'suclg1', 'suclg2', 'taldo1', 'tkt', 'tktl1', 'tktl2', 'tpi1']
Gene overlap len: 109


In [17]:
# Assuming gene_overlap is a list of genes you're interested in
missing_genes = [gene for gene in gene_overlap if gene not in data_gene_all]
print(f'missing_genes: {missing_genes}')
print(f'missing_genes len: {len(missing_genes)}')

missing_genes: []
missing_genes len: 0


# Load stochiometirc matrix

In [18]:
cmMat = pd.read_csv(
        data_path + '/' + cm_file,
        sep=',',
        header=None)
cmMat = cmMat.values
print(cmMat[:5,:5])
cmMat = torch.FloatTensor(cmMat).to(device)


[[ 0  0  0  1 -1]
 [ 0  0  0  0  1]
 [ 0  0  0  0  0]
 [ 0  0  0  0  0]
 [ 0  0  0  0  0]]


In [19]:
if cName_file != 'noCompoundName':
    print("Load compound name file, the balance output will have compound name.")
    cName = pd.read_csv(
            data_path + '/' + cName_file,
            sep=',',
            header=0)
    cName = cName.columns
print("Load data done.")
print(cName[:5])

Load compound name file, the balance output will have compound name.
Load data done.
Index(['Pyruvate', 'Acetyl-Coa', 'Glutamate', '2OG', 'Oxaloacetate'], dtype='object')


In [20]:
print(gene_overlap)

['abat',
 'acly',
 'aco1',
 'aco2',
 'adpgk',
 'aldh5a1',
 'aldoa',
 'aldob',
 'aldoc',
 'bpgm',
 'coq2',
 'cs',
 'dlat',
 'dld',
 'dlst',
 'eno1',
 'eno2',
 'eno3',
 'fbp1',
 'fbp2',
 'fh',
 'g6pc',
 'g6pc2',
 'g6pc3',
 'g6pd',
 'gad1',
 'gad2',
 'galm',
 'gapdh',
 'gapdhs',
 'gck',
 'gls',
 'gls2',
 'glud1',
 'glud2',
 'glul',
 'got1',
 'got2',
 'gpi',
 'h6pd',
 'hk1',
 'hk2',
 'hk3',
 'hkdc1',
 'idh1',
 'idh2',
 'idh3a',
 'idh3b',
 'idh3g',
 'ldha',
 'ldhal6a',
 'ldhal6b',
 'ldhb',
 'ldhc',
 'ldhd',
 'mdh1',
 'mdh2',
 'minpp1',
 'ogdh',
 'ogdhl',
 'pck1',
 'pck2',
 'pdha1',
 'pdha2',
 'pdhb',
 'pfkl',
 'pfkm',
 'pfkp',
 'pgam1',
 'pgam2',
 'pgam4',
 'pgd',
 'pgk1',
 'pgk2',
 'pgls',
 'pgm1',
 'pgm2',
 'phb',
 'phgdh',
 'pklr',
 'pkm',
 'prps1',
 'prps1l1',
 'prps2',
 'psat1',
 'psph',
 'rbks',
 'rpe',
 'rpia',
 'sdha',
 'sdhb',
 'slc1a1',
 'slc1a2',
 'slc1a3',
 'slc1a5',
 'slc1a6',
 'slc1a7',
 'slc2a1',
 'slc2a2',
 'slc2a3',
 'slc2a4',
 'sucla2',
 'suclg1',
 'suclg2',
 'taldo1',
 't

In [21]:
geneExpr.columns = geneExpr.columns.str.lower()
geneExpr

Unnamed: 0,aldoc,maoa,ass1,ahcyl2,fut4,gstp1,pgm1,pecr,auh,alas1,...,pomt2,gmps,aldob,hkdc1,nme3,uap1,dhodh,coq2,mat2a,slc2a11
Cy72_CD45_H02_S758_comb,4.162494,0.379308,4.751620,2.030544,1.542162,7.404518,4.033231,4.572774,3.218883,3.796396,...,1.271138,5.080247,1.399227,1.592996,4.403850,3.399304,3.949059,3.372828,6.774857,2.964164
CY58_1_CD45_B02_S974_comb,4.859296,0.795616,3.330219,2.361895,1.477859,9.025884,4.023865,4.827697,3.645465,4.742562,...,1.390741,5.914823,0.474377,1.783641,5.430277,4.074199,3.830121,3.561910,7.262892,2.999523
Cy71_CD45_D08_S524_comb,5.182878,1.517675,2.871011,3.260576,0.843524,10.617559,5.981542,4.526658,4.424308,5.099694,...,2.776770,6.036689,0.069217,1.561195,5.346934,5.710514,4.419876,3.388582,7.324919,2.717605
Cy81_FNA_CD45_B01_S301_comb,4.992984,1.954019,3.732040,2.566664,0.991903,9.459421,6.488310,4.646741,3.912212,5.969391,...,2.054426,6.204179,0.028445,3.245249,5.989700,6.570992,4.176897,3.989681,7.641259,3.024334
Cy80_II_CD45_B07_S883_comb,7.029913,2.433470,2.178750,3.863141,1.097812,10.997918,6.783676,3.881353,4.159309,5.196337,...,2.833173,5.992828,0.010164,0.560663,6.023677,5.510684,3.721844,3.588498,7.380202,2.921199
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
CY75_1_CD45_CD8_1__S25_comb,5.809091,2.266891,3.929175,2.610841,1.161813,9.125980,6.174220,3.709370,4.082228,5.544936,...,1.407577,6.248338,0.085321,0.995510,5.246235,6.390967,3.894019,4.020771,7.626858,2.519521
CY75_1_CD45_CD8_7__S223_comb,4.916321,1.293391,3.672784,2.980495,1.591621,7.884446,4.333308,4.050476,4.110849,4.542700,...,1.063681,5.355580,0.703826,1.618579,5.394490,3.865086,4.109748,5.272684,6.757629,2.868305
CY75_1_CD45_CD8_1__S65_comb,5.562529,1.122860,3.847335,1.515839,1.804045,7.456344,4.681336,4.647157,5.017220,5.743851,...,0.677716,5.640190,0.146503,1.498265,5.792783,3.713064,4.581443,4.082584,7.277182,2.723543
CY75_1_CD45_CD8_1__S93_comb,4.537577,0.656279,3.308691,2.223227,2.264631,7.513527,4.672429,3.935887,3.533805,4.118017,...,0.830313,5.628801,0.333738,1.492997,5.691360,4.105751,4.079076,4.374409,7.395787,3.103700


In [22]:
geneExpr[gene_overlap] 

Unnamed: 0,abat,acly,aco1,aco2,adpgk,aldh5a1,aldoa,aldob,aldoc,bpgm,...,slc2a3,slc2a4,sucla2,suclg1,suclg2,taldo1,tkt,tktl1,tktl2,tpi1
Cy72_CD45_H02_S758_comb,1.619334,5.393955,3.077637,6.071186,6.555041,3.837932,8.410754,1.399227,4.162494,4.312776,...,7.360305,2.893232,4.508215,5.904954,5.668446,6.262318,6.337298,0.918780,0.262385,7.719445
CY58_1_CD45_B02_S974_comb,3.210308,5.501335,3.340630,5.683188,6.074348,3.392198,9.255990,0.474377,4.859296,5.086820,...,7.121979,2.731376,5.125515,6.388649,5.909568,6.715170,6.647697,1.805171,0.286715,8.676920
Cy71_CD45_D08_S524_comb,1.298633,6.429331,4.407308,6.514881,5.381328,2.495401,10.499494,0.069217,5.182878,5.430829,...,6.413446,1.540063,5.531359,6.683166,5.697473,8.022339,8.721973,2.160871,0.000000,10.128833
Cy81_FNA_CD45_B01_S301_comb,2.827067,6.860292,5.015911,5.927352,5.942245,3.278216,10.610909,0.028445,4.992984,5.757218,...,6.385994,1.677644,6.574049,6.985858,6.482216,8.424013,8.468856,3.927885,0.000000,10.422541
Cy80_II_CD45_B07_S883_comb,0.004257,5.898215,4.751307,6.887203,5.587563,2.628317,12.132136,0.010164,7.029913,4.503378,...,7.299772,1.736276,5.402984,6.923287,5.081079,8.067327,8.913856,0.145644,0.000000,10.760754
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
CY75_1_CD45_CD8_1__S25_comb,2.895584,6.519866,4.445059,6.795247,5.810038,3.343531,11.071963,0.085321,5.809091,5.488430,...,7.864660,2.282407,6.646685,6.621746,6.384924,8.361448,7.895446,2.210032,0.000000,10.300524
CY75_1_CD45_CD8_7__S223_comb,2.766231,5.583633,3.532842,5.997116,6.196168,4.229354,9.105878,0.703826,4.916321,5.304255,...,7.209435,2.861909,4.560218,6.426677,5.753618,6.613837,6.931993,1.262908,0.048131,9.551816
CY75_1_CD45_CD8_1__S65_comb,4.186460,6.196514,2.641443,6.436015,6.097563,3.501956,9.817492,0.146503,5.562529,5.421930,...,7.220867,3.053332,5.459949,6.425533,5.996223,7.019283,6.884874,1.367629,0.014555,9.692799
CY75_1_CD45_CD8_1__S93_comb,2.298494,5.831191,3.873472,5.859530,5.908105,3.711157,9.925250,0.333738,4.537577,4.759855,...,7.483905,2.844059,5.060556,5.975061,6.243924,6.387062,6.314195,0.949864,0.069182,8.293973


In [23]:
moduleGene

Unnamed: 0_level_0,A,A.1,A.2,A.3,A.4,A.5,A.6,A.7,A.8,A.9,...,A.18,A.19,A.20,A.21,A.22,A.23,A.24,A.25,A.26,A.27
1,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
M_1,adpgk,g6pc,g6pc2,g6pc3,galm,gck,hk1,hk2,hk3,hkdc1,...,,,,,,,,,,
M_2,aldoa,aldob,aldoc,fbp1,fbp2,gpi,pfkl,pfkm,pfkp,tpi1,...,,,,,,,,,,
M_3,bpgm,cl640,coq10d1,coq2,gapdh,gapdhs,msa1,pgam1,pgam2,pgam4,...,,,,,,,,,,
M_4,bpgm,cl640,coq10d1,coq2,eno1,eno2,eno3,hiper1,minpp1,minpp2,...,pklr,pkm,ppt,,,,,,,
M_5,dlat,dld,dldd,dldh,e3,gcsl,lad,pdha1,pdha2,pdhb,...,,,,,,,,,,
M_6,ldha,ldhal6a,ldhal6b,ldhb,ldhc,ldhd,,,,,...,,,,,,,,,,
M_7,acly,cs,,,,,,,,,...,,,,,,,,,,
M_8,aco1,aco2,idh1,idh2,idh3a,idh3b,idh3g,,,,...,,,,,,,,,,
M_9,dld,dldd,dldh,dlst,e3,gcsl,lad,ogdh,ogdhl,phe3,...,,,,,,,,,,
M_10,sucla2,suclg1,suclg2,,,,,,,,...,,,,,,,,,,


In [24]:
print("Starting process data...")
emptyNode = []
# extract overlap gene
geneExpr = geneExpr[gene_overlap] 
print(f'geneExpr: {geneExpr.head()}')

gene_names = geneExpr.columns
print(f'gene_names: {gene_names[:5]}')

cell_names = geneExpr.index.astype(str)
print(f'cell_names: {cell_names[:5]}')

n_modules = moduleGene.shape[0]
n_genes = len(gene_names)
n_cells = len(cell_names)
n_comps = cmMat.shape[0]
print(f'n_modules: {n_modules}, n_genes: {n_genes}, n_cells: {n_cells}, n_comps: {n_comps}')

geneExprDf = pd.DataFrame(columns = ['Module_Gene'] + list(cell_names))
print(geneExprDf)

Starting process data...
geneExpr:                                  abat      acly      aco1      aco2     adpgk  \
Cy72_CD45_H02_S758_comb      1.619334  5.393955  3.077637  6.071186  6.555041   
CY58_1_CD45_B02_S974_comb    3.210308  5.501335  3.340630  5.683188  6.074348   
Cy71_CD45_D08_S524_comb      1.298633  6.429331  4.407308  6.514881  5.381328   
Cy81_FNA_CD45_B01_S301_comb  2.827067  6.860292  5.015911  5.927352  5.942245   
Cy80_II_CD45_B07_S883_comb   0.004257  5.898215  4.751307  6.887203  5.587563   

                              aldh5a1      aldoa     aldob     aldoc  \
Cy72_CD45_H02_S758_comb      3.837932   8.410754  1.399227  4.162494   
CY58_1_CD45_B02_S974_comb    3.392198   9.255990  0.474377  4.859296   
Cy71_CD45_D08_S524_comb      2.495401  10.499494  0.069217  5.182878   
Cy81_FNA_CD45_B01_S301_comb  3.278216  10.610909  0.028445  4.992984   
Cy80_II_CD45_B07_S883_comb   2.628317  12.132136  0.010164  7.029913   

                                 bpgm  ...   

In [25]:
geneExprDf

Unnamed: 0,Module_Gene,Cy72_CD45_H02_S758_comb,CY58_1_CD45_B02_S974_comb,Cy71_CD45_D08_S524_comb,Cy81_FNA_CD45_B01_S301_comb,Cy80_II_CD45_B07_S883_comb,Cy81_Bulk_CD45_B10_S118_comb,Cy72_CD45_D09_S717_comb,Cy74_CD45_A03_S387_comb,Cy71_CD45_B05_S497_comb,...,CY75_1_CD45_CD8_7__S242_comb,CY75_1_CD45_CD8_8__S334_comb,CY75_1_CD45_CD8_3__S127_comb,CY75_1_CD45_CD8_1__S61_comb,CY75_1_CD45_CD8_1__S12_comb,CY75_1_CD45_CD8_1__S25_comb,CY75_1_CD45_CD8_7__S223_comb,CY75_1_CD45_CD8_1__S65_comb,CY75_1_CD45_CD8_1__S93_comb,CY75_1_CD45_CD8_1__S76_comb


In [26]:
n_modules

23

In [27]:
moduleGene

Unnamed: 0_level_0,A,A.1,A.2,A.3,A.4,A.5,A.6,A.7,A.8,A.9,...,A.18,A.19,A.20,A.21,A.22,A.23,A.24,A.25,A.26,A.27
1,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
M_1,adpgk,g6pc,g6pc2,g6pc3,galm,gck,hk1,hk2,hk3,hkdc1,...,,,,,,,,,,
M_2,aldoa,aldob,aldoc,fbp1,fbp2,gpi,pfkl,pfkm,pfkp,tpi1,...,,,,,,,,,,
M_3,bpgm,cl640,coq10d1,coq2,gapdh,gapdhs,msa1,pgam1,pgam2,pgam4,...,,,,,,,,,,
M_4,bpgm,cl640,coq10d1,coq2,eno1,eno2,eno3,hiper1,minpp1,minpp2,...,pklr,pkm,ppt,,,,,,,
M_5,dlat,dld,dldd,dldh,e3,gcsl,lad,pdha1,pdha2,pdhb,...,,,,,,,,,,
M_6,ldha,ldhal6a,ldhal6b,ldhb,ldhc,ldhd,,,,,...,,,,,,,,,,
M_7,acly,cs,,,,,,,,,...,,,,,,,,,,
M_8,aco1,aco2,idh1,idh2,idh3a,idh3b,idh3g,,,,...,,,,,,,,,,
M_9,dld,dldd,dldh,dlst,e3,gcsl,lad,ogdh,ogdhl,phe3,...,,,,,,,,,,
M_10,sucla2,suclg1,suclg2,,,,,,,,...,,,,,,,,,,


In [28]:
genes = moduleGene.iloc[21,:].values.astype(str)
genes

array(['slc1a1', 'slc1a2', 'slc1a3', 'slc1a5', 'slc1a6', 'slc1a7',
       'slc17a6', 'slc17a8', 'slc17a7', 'nan', 'nan', 'nan', 'nan', 'nan',
       'nan', 'nan', 'nan', 'nan', 'nan', 'nan', 'nan', 'nan', 'nan',
       'nan', 'nan', 'nan', 'nan', 'nan'], dtype='<U7')

In [29]:
for i in range(n_modules):
    genes = moduleGene.iloc[i,:].values.astype(str)
    genes = [g for g in genes if g != 'nan']
    if not genes:
        emptyNode.append(i)
        continue
    temp = geneExpr.copy()
    temp.loc[:, [g for g in gene_names if g not in genes]] = 0
    temp = temp.T
    temp['Module_Gene'] = ['%02d_%s' % (i,g) for g in gene_names]
    # geneExprDf = geneExprDf.append(temp, ignore_index = True, sort=False)
    geneExprDf = pd.concat([geneExprDf, temp], ignore_index=True, sort=False)
geneExprDf.index = geneExprDf['Module_Gene']
geneExprDf.drop('Module_Gene', axis = 'columns', inplace = True)

In [30]:
geneExprDf

Unnamed: 0_level_0,Cy72_CD45_H02_S758_comb,CY58_1_CD45_B02_S974_comb,Cy71_CD45_D08_S524_comb,Cy81_FNA_CD45_B01_S301_comb,Cy80_II_CD45_B07_S883_comb,Cy81_Bulk_CD45_B10_S118_comb,Cy72_CD45_D09_S717_comb,Cy74_CD45_A03_S387_comb,Cy71_CD45_B05_S497_comb,Cy80_II_CD45_C09_S897_comb,...,CY75_1_CD45_CD8_7__S242_comb,CY75_1_CD45_CD8_8__S334_comb,CY75_1_CD45_CD8_3__S127_comb,CY75_1_CD45_CD8_1__S61_comb,CY75_1_CD45_CD8_1__S12_comb,CY75_1_CD45_CD8_1__S25_comb,CY75_1_CD45_CD8_7__S223_comb,CY75_1_CD45_CD8_1__S65_comb,CY75_1_CD45_CD8_1__S93_comb,CY75_1_CD45_CD8_1__S76_comb
Module_Gene,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
00_abat,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000
00_acly,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000
00_aco1,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000
00_aco2,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000
00_adpgk,6.555041,6.074348,5.381328,5.942245,5.587563,5.644646,5.785019,6.188581,5.474063,5.862765,...,6.128818,5.936722,6.037975,6.213277,6.442883,5.810038,6.196168,6.097563,5.908105,5.359468
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
22_taldo1,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000
22_tkt,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000
22_tktl1,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000
22_tktl2,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000


In [None]:
geneExprDf.max()

In [31]:
geneExprDf.iloc[21,]

Cy72_CD45_H02_S758_comb         2.290205
CY58_1_CD45_B02_S974_comb       2.287347
Cy71_CD45_D08_S524_comb         1.392454
Cy81_FNA_CD45_B01_S301_comb     1.304417
Cy80_II_CD45_B07_S883_comb      0.757544
                                  ...   
CY75_1_CD45_CD8_1__S25_comb     1.464093
CY75_1_CD45_CD8_7__S223_comb    2.054929
CY75_1_CD45_CD8_1__S65_comb     2.245117
CY75_1_CD45_CD8_1__S93_comb     2.117607
CY75_1_CD45_CD8_1__S76_comb     1.484051
Name: 00_g6pc, Length: 4486, dtype: float64

In [32]:
X = geneExprDf.values.T
X = torch.FloatTensor(X).to(device)

In [33]:
df = geneExprDf
print(df.index[20:50])

Index(['00_fh', '00_g6pc', '00_g6pc2', '00_g6pc3', '00_g6pd', '00_gad1',
       '00_gad2', '00_galm', '00_gapdh', '00_gapdhs', '00_gck', '00_gls',
       '00_gls2', '00_glud1', '00_glud2', '00_glul', '00_got1', '00_got2',
       '00_gpi', '00_h6pd', '00_hk1', '00_hk2', '00_hk3', '00_hkdc1',
       '00_idh1', '00_idh2', '00_idh3a', '00_idh3b', '00_idh3g', '00_ldha'],
      dtype='object', name='Module_Gene')


In [34]:
df

Unnamed: 0_level_0,Cy72_CD45_H02_S758_comb,CY58_1_CD45_B02_S974_comb,Cy71_CD45_D08_S524_comb,Cy81_FNA_CD45_B01_S301_comb,Cy80_II_CD45_B07_S883_comb,Cy81_Bulk_CD45_B10_S118_comb,Cy72_CD45_D09_S717_comb,Cy74_CD45_A03_S387_comb,Cy71_CD45_B05_S497_comb,Cy80_II_CD45_C09_S897_comb,...,CY75_1_CD45_CD8_7__S242_comb,CY75_1_CD45_CD8_8__S334_comb,CY75_1_CD45_CD8_3__S127_comb,CY75_1_CD45_CD8_1__S61_comb,CY75_1_CD45_CD8_1__S12_comb,CY75_1_CD45_CD8_1__S25_comb,CY75_1_CD45_CD8_7__S223_comb,CY75_1_CD45_CD8_1__S65_comb,CY75_1_CD45_CD8_1__S93_comb,CY75_1_CD45_CD8_1__S76_comb
Module_Gene,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
00_abat,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000
00_acly,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000
00_aco1,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000
00_aco2,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000
00_adpgk,6.555041,6.074348,5.381328,5.942245,5.587563,5.644646,5.785019,6.188581,5.474063,5.862765,...,6.128818,5.936722,6.037975,6.213277,6.442883,5.810038,6.196168,6.097563,5.908105,5.359468
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
22_taldo1,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000
22_tkt,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000
22_tktl1,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000
22_tktl2,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000


In [35]:
len(df.index.unique())

2507

In [36]:
#prepare data for constraint of module variation based on gene
df.index = [i.split('_')[0] for i in df.index]
df.index = df.index.astype(int)   # mush change type to ensure correct order, T column name order change!
#module_scale = df.groupby(df.index).sum(axis=1).T   # pandas version update

In [37]:
len(df.index.unique())

23

In [38]:
module_scale = df.groupby(df.index).sum().T  
module_scale = torch.FloatTensor(module_scale.values/ moduleLen) 
print("Process data done.")

Process data done.


In [39]:
module_scale.shape

torch.Size([4486, 23])

In [40]:
LEARN_RATE = 0.001

In [41]:
X.shape

torch.Size([4486, 2507])

# NN Model

In [78]:
import torch
import torch.nn as nn

class FLUX(nn.Module):
    def __init__(self, matrix, n_modules, f_in = 50, f_out = 1):
        super(FLUX, self).__init__()
        # gene to flux
        self.inSize = f_in     
        
        self.m_encoder = nn.ModuleList([
                                        nn.Sequential(nn.Linear(self.inSize,8, bias = False),
                                                      nn.Tanhshrink(),
                                                      nn.Linear(8, f_out),
                                                      nn.Tanhshrink()
                                                      )
                                        for i in range(n_modules)])

    
    def updateC(self, m, n_comps, cmMat): # stoichiometric matrix
        
        c = torch.zeros((m.shape[0], n_comps))
        for i in range(c.shape[1]):
            tmp = m * cmMat[i,:]
            c[:,i] = torch.sum(tmp, dim=1)
        
        return c
        

    def forward(self, x, n_modules, n_genes, n_comps, cmMat):
        print(x.shape)
        for i in range(n_modules):
            x_block = x[:,i*n_genes: (i+1)*n_genes,]
            print(x_block.shape)
            subnet = self.m_encoder[i]
            if i == 0:
                m = subnet(x_block) 
            else:
                m = torch.cat((m, subnet(x_block)),1)

        c = self.updateC(m, n_comps, cmMat)
        
        return m, c

In [79]:
# =============================================================================
#NN
torch.manual_seed(16)
print(f'X: {X.size()}, n_modules: {n_modules}, n_genes: {n_genes}')
net = FLUX(X, n_modules, f_in = n_genes, f_out = 1).to(device)
optimizer = torch.optim.Adam(net.parameters(), lr = LEARN_RATE)
# =============================================================================

X: torch.Size([4486, 2507]), n_modules: 23, n_genes: 109


In [80]:
#Dataloader
dataloader_params = {'batch_size': BATCH_SIZE,
                        'shuffle': False,
                        'num_workers': 0,
                        'pin_memory': False}
# dataloader_params = {
#     'batch_size': BATCH_SIZE,
#     'shuffle': True,  # Shuffling might be beneficial depending on the training regime
#     'num_workers': 4,  # Adjust based on your system's capabilities and the dataset size
#     'pin_memory': True,  # Helps with faster data transfer to CUDA devices
# }

dataSet = MyDataset(X, geneExprScale, module_scale)
train_loader = torch.utils.data.DataLoader(dataset=dataSet,
                                            **dataloader_params)

In [81]:
for i, (X_batch, X_scale_batch, m_scale_batch) in enumerate(train_loader):
    print(f"Batch {i}: X_batch shape: {X_batch.shape}")


Batch 0: X_batch shape: torch.Size([4486, 2507])


In [82]:
print("Starting train neural network...")
start = time.time()  
#   training
loss_v = []
loss_v1 = []
loss_v2 = []
loss_v3 = []
loss_v4 = []
net.train()
timestr = time.strftime("%Y%m%d-%H%M%S")
lossName = "./output/lossValue_" + timestr + ".txt"
file_loss = open(lossName, "a")

Starting train neural network...


In [83]:
cmMat.size()

torch.Size([17, 23])

In [84]:
LAMB_BA = 1
LAMB_NG = 1 
LAMB_CELL =  1
LAMB_MOD = 1e-2 

In [85]:

def myLoss(m, c, lamb1 = 0.2, lamb2= 0.2, lamb3 = 0.2, lamb4 = 0.2, geneScale = None, moduleScale = None):    
    
    # balance constrain
    total1 = torch.pow(c, 2)
    total1 = torch.sum(total1, dim = 1) 
    
    # non-negative constrain
    error = torch.abs(m) - m
    total2 = torch.sum(error, dim=1)
    
    
    # sample-wise variation constrain 
    diff = torch.pow(torch.sum(m, dim=1) - geneScale, 2)
    #total3 = torch.pow(diff, 0.5)
    if sum(diff > 0) == m.shape[0]: # solve Nan after several iteraions
        total3 = torch.pow(diff, 0.5)
    else:
        total3 = diff
    
    # module-wise variation constrain
    if lamb4 > 0 :
        corr = torch.FloatTensor(np.ones(m.shape[0]))
        for i in range(m.shape[0]):
            corr[i] = pearsonr(m[i, :], moduleScale[i, :])
        corr = torch.abs(corr)
        penal_m_var = torch.FloatTensor(np.ones(m.shape[0])) - corr
        total4 = penal_m_var
    else:
        total4 = torch.FloatTensor(np.zeros(m.shape[0]))
            
    # loss
    loss1 = torch.sum(lamb1 * total1)
    loss2 = torch.sum(lamb2 * total2)
    loss3 = torch.sum(lamb3 * total3)
    loss4 = torch.sum(lamb4 * total4)
    loss = loss1 + loss2 + loss3 + loss4
    return loss, loss1, loss2, loss3, loss4
   

In [86]:
for epoch in tqdm(range(EPOCH)):
    loss, loss1, loss2, loss3, loss4 = 0,0,0,0,0
    # print(f'X: {X.size()}, X_scale: {X_scale.size()}, m_scale: {m_scale.size()}')
    for i, (X, X_scale, m_scale) in enumerate(train_loader):

        X_batch = Variable(X.float().to(device))
        print(f"X_batch: {X_batch.shape}")
        print(f"n_modules: {n_modules}")
        print(f"n_genes: {n_genes}")
        print(f"n_comps: {n_comps}")
        print(f"cmMat: {cmMat.shape}")
        X_scale_batch = Variable(X_scale.float().to(device))
        m_scale_batch = Variable(m_scale.float().to(device))
        
        out_m_batch, out_c_batch = net(X_batch, n_modules, n_genes, n_comps, cmMat)
        loss_batch, loss1_batch, loss2_batch, loss3_batch, loss4_batch  = myLoss(out_m_batch, out_c_batch, 
                                                                                    lamb1 = LAMB_BA, lamb2 = LAMB_NG, lamb3 = LAMB_CELL, lamb4 = LAMB_MOD, 
                                                                                    geneScale = X_scale_batch, moduleScale = m_scale_batch)

        optimizer.zero_grad()
        loss_batch.backward()
        optimizer.step()
        
        loss += loss_batch.cpu().data.numpy()
        loss1 += loss1_batch.cpu().data.numpy()
        loss2 += loss2_batch.cpu().data.numpy()
        loss3 += loss3_batch.cpu().data.numpy()
        loss4 += loss4_batch.cpu().data.numpy()
        
    #print('epoch: %02d, loss1: %.8f, loss2: %.8f, loss3: %.8f, loss4: %.8f, loss: %.8f' % (epoch+1, loss1, loss2, loss3, loss4, loss))
    file_loss.write('epoch: %02d, loss1: %.8f, loss2: %.8f, loss3: %.8f, loss4: %.8f, loss: %.8f. \n' % (epoch+1, loss1, loss2, loss3, loss4, loss))
    
    loss_v.append(loss)
    loss_v1.append(loss1)
    loss_v2.append(loss2)
    loss_v3.append(loss3)
    loss_v4.append(loss4)

  0%|          | 0/100 [00:00<?, ?it/s]

X_batch: torch.Size([4486, 2507])
n_modules: 23
n_genes: 109
n_comps: 17
cmMat: torch.Size([17, 23])
torch.Size([4486, 2507])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])


  1%|          | 1/100 [00:00<01:34,  1.05it/s]

X_batch: torch.Size([4486, 2507])
n_modules: 23
n_genes: 109
n_comps: 17
cmMat: torch.Size([17, 23])
torch.Size([4486, 2507])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])


  2%|▏         | 2/100 [00:02<02:09,  1.32s/it]

X_batch: torch.Size([4486, 2507])
n_modules: 23
n_genes: 109
n_comps: 17
cmMat: torch.Size([17, 23])
torch.Size([4486, 2507])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])


  3%|▎         | 3/100 [00:03<02:02,  1.26s/it]

X_batch: torch.Size([4486, 2507])
n_modules: 23
n_genes: 109
n_comps: 17
cmMat: torch.Size([17, 23])
torch.Size([4486, 2507])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])


  4%|▍         | 4/100 [00:05<02:04,  1.29s/it]

X_batch: torch.Size([4486, 2507])
n_modules: 23
n_genes: 109
n_comps: 17
cmMat: torch.Size([17, 23])
torch.Size([4486, 2507])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])


  5%|▌         | 5/100 [00:06<02:04,  1.31s/it]

X_batch: torch.Size([4486, 2507])
n_modules: 23
n_genes: 109
n_comps: 17
cmMat: torch.Size([17, 23])
torch.Size([4486, 2507])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])


  6%|▌         | 6/100 [00:07<01:54,  1.22s/it]

X_batch: torch.Size([4486, 2507])
n_modules: 23
n_genes: 109
n_comps: 17
cmMat: torch.Size([17, 23])
torch.Size([4486, 2507])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])


  7%|▋         | 7/100 [00:08<01:48,  1.16s/it]

X_batch: torch.Size([4486, 2507])
n_modules: 23
n_genes: 109
n_comps: 17
cmMat: torch.Size([17, 23])
torch.Size([4486, 2507])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])


  8%|▊         | 8/100 [00:09<01:45,  1.15s/it]

X_batch: torch.Size([4486, 2507])
n_modules: 23
n_genes: 109
n_comps: 17
cmMat: torch.Size([17, 23])
torch.Size([4486, 2507])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])


  9%|▉         | 9/100 [00:10<01:40,  1.11s/it]

X_batch: torch.Size([4486, 2507])
n_modules: 23
n_genes: 109
n_comps: 17
cmMat: torch.Size([17, 23])
torch.Size([4486, 2507])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])


 10%|█         | 10/100 [00:11<01:37,  1.08s/it]

X_batch: torch.Size([4486, 2507])
n_modules: 23
n_genes: 109
n_comps: 17
cmMat: torch.Size([17, 23])
torch.Size([4486, 2507])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])


 11%|█         | 11/100 [00:12<01:36,  1.08s/it]

X_batch: torch.Size([4486, 2507])
n_modules: 23
n_genes: 109
n_comps: 17
cmMat: torch.Size([17, 23])
torch.Size([4486, 2507])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])
torch.Size([4486, 109])


 12%|█▏        | 12/100 [00:13<01:40,  1.15s/it]


KeyboardInterrupt: 

In [None]:
end = time.time()
print("Training time: ", end - start) 

file_loss.close()
plt.plot(loss_v, '--')
plt.plot(loss_v1)
plt.plot(loss_v2)
plt.plot(loss_v3)
plt.plot(loss_v4)
plt.legend(['total', 'balance', 'negative', 'cellVar', 'moduleVar']);
imgName = './' + res_dir + '/loss_' + timestr + ".png"
plt.savefig(imgName)
timeName =  './' + res_dir + '/time_' + timestr + ".txt"
f = open(timeName, "a")
runTimeStr = str(end - start)
f.write(runTimeStr)
f.close()   

In [None]:
#    Dataloader
dataloader_params = {'batch_size': 1,
                        'shuffle': False,
                        'num_workers': 0,
                        'pin_memory': False}

dataSet = MyDataset(X, geneExprScale, module_scale)
test_loader = torch.utils.data.DataLoader(dataset=dataSet,
                        **dataloader_params)

#testing
fluxStatuTest = np.zeros((n_cells, n_modules), dtype='f') #float32
balanceStatus = np.zeros((n_cells, n_comps), dtype='f')
net.eval()
for epoch in range(1):
    loss, loss1, loss2 = 0,0,0
    
    for i, (X, X_scale, _) in enumerate(test_loader):

        X_batch = Variable(X.float().to(device))
        out_m_batch, out_c_batch = net(X_batch, n_modules, n_genes, n_comps, cmMat)
        
        # save data
        # fluxStatuTest[i, :] = out_m_batch.detach().numpy()
        fluxStatuTest[i, :] = out_m_batch.detach().cpu().numpy()
        # balanceStatus[i, :] = out_c_batch.detach().numpy()
        balanceStatus[i, :] = out_c_batch.detach().cpu().numpy()

        
                

# save to file
if fileName == 'NULL':
    # user do not define file name of flux
    fileName = "./" + res_dir + "/" + test_file[-len(test_file):-4] + "_module" + str(n_modules) + "_cell" + str(n_cells) + "_batch" + str(BATCH_SIZE) + \
                "_LR" + str(LEARN_RATE) + "_epoch" + str(EPOCH) + "_SCimpute_" + str(sc_imputation)[0] + \
                "_lambBal" + str(LAMB_BA) + "_lambSca" + str(LAMB_NG) + "_lambCellCor" + str(LAMB_CELL) + "_lambModCor_1e-2" + \
                '_' + timestr + ".csv"
setF = pd.DataFrame(fluxStatuTest)
setF.columns = moduleGene.index
setF.index = geneExpr.index.tolist()
setF.to_csv(fileName)

setB = pd.DataFrame(balanceStatus)
setB.rename(columns = lambda x: x + 1)
setB.index = setF.index
if cName_file != 'noCompoundName':
    setB.columns = cName
if balanceName == 'NULL':
    # user do not define file name of balance
    balanceName = "./output/balance_" + timestr + ".csv"
setB.to_csv(balanceName)


print("scFEA job finished. Check result in the desired output folder.")