In [15]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import torch
import torchvision
import torch.optim as optim
import argparse
import matplotlib
import torch.nn as nn
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
from tqdm import tqdm
from torchvision import datasets
from torch.utils.data import DataLoader
from torchvision.utils import save_image
import pandas as pd
import numpy as np
matplotlib.style.use('ggplot')
from sklearn.model_selection import train_test_split
import lightning.pytorch as pl
from pytorch_lightning.loggers import TensorBoardLogger

### Define PyTorch Model

In [16]:
class VAE(nn.Module):
    def __init__(self, input_size, level_2, level_3, latent_dim):
        super(VAE, self).__init__()
        
        # Encoder layers
        self.enc_fc1 = nn.Sequential(
                        nn.Linear(input_size, level_2),
                        nn.BatchNorm1d(level_2),
                        nn.ReLU())
        
        self.enc_fc2 = nn.Sequential(
                        nn.Linear(level_2, level_3),
                        nn.BatchNorm1d(level_3),
                        nn.ReLU())

        self.enc_fc3_mean = nn.Sequential(
                    nn.Linear(level_3, latent_dim),
                    nn.BatchNorm1d(latent_dim))
        
        self.enc_fc3_log_var = nn.Sequential(
                    nn.Linear(level_3, latent_dim),
                    nn.BatchNorm1d(latent_dim))
        
        
        # Decoder layers
        self.dec_fc3 = nn.Sequential(
                        nn.Linear(latent_dim, level_3),
                        nn.BatchNorm1d(level_3),
                        nn.ReLU())
        
        self.dec_fc2 = nn.Sequential(
                        nn.Linear(level_3, level_2),
                        nn.BatchNorm1d(level_2),
                        nn.ReLU())
        
        self.dec_fc1 = nn.Sequential(
                    nn.Linear(level_2, input_size),
                    nn.BatchNorm1d(input_size),
                    nn.Sigmoid())
        
       


    def encode(self, x):
        l2_layer = self.enc_fc1(x)
        l3_layer = self.enc_fc2(l2_layer)
        
        mu = self.enc_fc3_mean(l3_layer)
        logvar = self.enc_fc3_log_var(l3_layer)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        z = mu + eps * std
        return z

    def decode(self, z):
        l3_layer = self.dec_fc3(z)
        l2_layer = self.dec_fc2(l3_layer)
        x_hat = self.dec_fc1(l2_layer)
        return x_hat

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        x_hat = self.decode(z)
        return x_hat, mu, logvar, z

### Loss Function

In [17]:
def loss_function(x_hat, x, mean, log_var): # recon loss and kld loss
        bce = torch.nn.functional.binary_cross_entropy(x_hat, x, reduction = 'sum')
        kld = 0.5 * torch.sum(log_var.exp() + mean.pow(2) - 1 - log_var)
        loss = kld + bce
        return loss, kld, bce

### Define Lightning Module

In [18]:
class LitVAE(pl.LightningModule):
    def __init__ (self, VAE):
        super().__init__()
        self.VAE = VAE

    def training_step(self, batch, batch_idx):
        x = batch
        x = x.view(x.size(0), -1)
        x_hat, mu, logvar, z = self.VAE(x)
        loss, kld, bce = loss_function(x_hat, x, mu, logvar)
        self.log('Total Train Loss', loss/len(x), on_epoch=True)
        self.log('KL Train Loss', kld/len(x), on_epoch=True)
        self.log('Recon TRain Loss', bce/len(x), on_epoch=True)
        return loss
    
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=0.001)
        return optimizer
    
    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        x = batch
        x = x.view(x.size(0), -1)
        x_hat, mu, logvar, z = self.VAE(x)
        return mu

### Define Dataloader

In [19]:
# import csv
dna_meth = pd.read_csv('../datasets_transpose_csv/dna_meth_transpose.csv')
dna_meth = dna_meth.drop(columns=['Unnamed: 0'])
cell_lines = dna_meth['CpG_sites_hg19']
dna_meth = dna_meth.drop(columns=['CpG_sites_hg19'])
dna_meth = dna_meth.drop(columns=['Unnamed: 81039'])
dna_meth_np = dna_meth.to_numpy()

In [20]:
batch_size = 32
dna_meth_full = torch.Tensor(dna_meth_np)
full_loader = DataLoader(dna_meth_full,batch_size=batch_size, shuffle=False, num_workers=16, pin_memory=True)

### Train model

In [21]:
torch.set_float32_matmul_precision("medium")
torch.manual_seed(33)

<torch._C.Generator at 0x25153fce130>

In [22]:
input_size = 81037 #dimension of gene expressions
level_2 = 2048
level_3 = 1500
latent_dim = 1024 # target latent size
# model
VAE_model = LitVAE(VAE(input_size, level_2, level_3, latent_dim))

# train model
logger = TensorBoardLogger('tb_logs', name= 'vae_dna_meth')
trainer = pl.Trainer(accelerator="gpu", max_epochs=100, enable_checkpointing= False, logger=logger)
trainer.fit(model=VAE_model, train_dataloaders=full_loader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name | Type | Params
------------------------------
0 | VAE  | VAE  | 342 M 
------------------------------
342 M     Trainable params
0         Non-trainable params
342 M     Total params
1,371.800 Total estimated model params size (MB)
  rank_zero_warn(


Epoch 99: 100%|██████████| 27/27 [00:06<00:00,  4.34it/s, v_num=4]

`Trainer.fit` stopped: `max_epochs=100` reached.


Epoch 99: 100%|██████████| 27/27 [00:06<00:00,  4.34it/s, v_num=4]


In [9]:
pred = trainer.predict(VAE_model, full_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting DataLoader 0: 100%|██████████| 27/27 [00:00<00:00, 64.90it/s]


PREDICT Profiler Report

-----------------------------------------------------------------------------------------------------------------------------------------------------------
|  Action                                             	|  Mean duration (s)	|  Num calls      	|  Total time (s) 	|  Percentage %   	|
-----------------------------------------------------------------------------------------------------------------------------------------------------------
|  Total                                              	|  -              	|  82072          	|  872.38         	|  100 %          	|
-----------------------------------------------------------------------------------------------------------------------------------------------------------
|  run_training_epoch                                 	|  6.2807         	|  100            	|  628.07         	|  71.996         	|
|  run_training_batch                                 	|  0.11025        	|  2700           	|  297.67    

In [10]:
pred_list = []

for i in range(len(pred)):
    for j in range(pred[0].shape[0]):
        try:
            pred_list.append(pred[i][j])
        except:
            pass

In [11]:
# Assuming you have a list of tensors named "tensor_list"
pred_list_np = [tensor.numpy() for tensor in pred_list]

# Convert the list of NumPy arrays to a single NumPy array
pred_np = np.array(pred_list_np)

In [12]:
file_path = 'test.tsv'
pred_df = pd.DataFrame(pred_np)
pred_df.to_csv(file_path, sep='\t')

In [13]:
pred_df

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,1014,1015,1016,1017,1018,1019,1020,1021,1022,1023
0,0.429570,-1.909910,-1.472437,0.748720,-0.016270,0.655753,-0.239222,0.928533,-1.746679,-1.498429,...,-1.167183,-0.828607,1.255535,0.486317,-0.253607,1.616354,-0.522952,-1.643377,1.878682,-1.064827
1,0.799768,-1.317464,-0.283014,-0.348736,0.054551,-0.999007,0.488884,-0.439297,-1.248220,0.057584,...,-0.464248,-1.455493,0.738102,-0.426222,0.916154,0.237763,0.262881,-0.647636,-0.563344,-0.511420
2,-0.063261,0.137814,0.425855,1.471842,1.735726,-0.114720,0.246615,-1.031568,1.238343,-0.124922,...,0.169023,0.792818,-0.337984,1.076447,0.334452,-0.209356,-0.257603,0.988091,-0.565004,0.755501
3,-0.134225,1.318798,0.302894,0.531689,-0.053797,-1.038170,0.176276,-0.443220,0.889723,-0.112907,...,-0.365432,0.479045,-0.164296,0.110523,-0.554673,-0.434859,0.692086,0.668916,-0.138584,0.248618
4,0.029383,0.827876,0.367270,0.795097,0.525147,-0.352542,0.858132,-0.536660,1.324821,0.810718,...,2.028498,0.786475,-0.896838,0.305146,0.597017,0.786401,-1.115265,0.561232,-0.741169,0.706336
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
838,0.406915,0.407979,0.790799,-0.065704,-0.178693,0.133986,-0.294570,0.907708,-0.532263,0.511409,...,1.018077,0.388512,0.163653,-0.042266,0.109107,-1.311447,0.074531,-0.708374,-0.111912,0.241496
839,-1.108940,1.210906,1.041628,-0.037763,-0.117001,0.395379,-0.342543,0.380366,0.976832,0.584943,...,0.453464,0.929581,-0.500538,0.133473,-0.978117,0.073899,0.300841,0.970426,-0.090345,0.103342
840,-0.353198,0.095412,0.901895,0.286198,0.049067,0.979891,-0.113612,0.622324,0.131119,0.881184,...,0.595284,0.551816,-0.225171,-0.099614,0.022435,-0.195926,0.816269,0.122387,-0.223316,-0.330512
841,-0.852346,0.393635,0.945212,0.263071,-0.075022,1.196889,-0.259242,0.906962,0.612942,0.578142,...,0.169133,0.684954,-0.239488,0.127601,-0.640551,0.255922,0.485135,0.284296,0.089397,-0.179119
