In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import torch
import torchvision
import torch.optim as optim
import argparse
import matplotlib
import torch.nn as nn
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
from tqdm import tqdm
from torchvision import datasets
from torch.utils.data import DataLoader
from torchvision.utils import save_image
import pandas as pd
import numpy as np
matplotlib.style.use('ggplot')
from sklearn.model_selection import train_test_split

  from .autonotebook import tqdm as notebook_tqdm


### Import Data

In [2]:
genes = pd.read_csv('../datasets_transpose_csv/genes_transpose.csv')

In [3]:
genes_ = genes.drop(columns=['Unnamed: 0'])
genes_ = genes_.drop(columns=['CELL_LINE'])
genes_np = genes_.to_numpy()

In [4]:
genes_train, genes_test = train_test_split(genes_np, test_size=0.10, random_state=42)

In [5]:
genes_train.shape

(917, 57820)

In [6]:
genes_test.shape

(102, 57820)

In [7]:
genes_

Unnamed: 0,ENSG00000000003.10,ENSG00000000005.5,ENSG00000000419.8,ENSG00000000457.9,ENSG00000000460.12,ENSG00000000938.8,ENSG00000000971.11,ENSG00000001036.9,ENSG00000001084.6,ENSG00000001167.10,...,ENSGR0000237531.1,ENSGR0000237801.1,ENSGR0000263835.1,ENSGR0000263980.1,ENSGR0000264510.1,ENSGR0000264819.1,ENSGR0000265350.1,ENSGR0000265658.1,ENSGR0000266731.1,ENSGR0000270726.1
0,5.28,0.0,73.38,9.76,24.51,0.01,0.08,54.86,118.50,38.05,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1,7.01,0.0,108.99,16.76,13.32,0.00,0.23,170.91,93.00,18.64,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2,22.80,0.0,56.51,2.58,10.86,0.00,0.06,30.78,22.16,14.33,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
3,22.88,0.0,45.39,3.25,5.26,0.00,0.28,45.05,19.45,7.91,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
4,23.09,0.0,99.28,2.73,9.27,0.02,0.45,53.29,8.36,15.27,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1014,47.96,0.0,119.50,5.06,9.58,0.03,0.10,115.62,40.13,21.08,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1015,28.92,0.0,64.09,4.94,13.35,0.22,139.44,118.19,13.53,26.15,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1016,61.08,0.0,109.13,5.91,17.40,0.13,53.25,92.96,23.09,33.22,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1017,8.12,0.0,139.23,15.96,17.45,0.06,1.23,19.75,20.37,12.47,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


### Convert data to Tensor

In [8]:
genes_train_t = torch.Tensor(genes_train)
genes_test_t = torch.Tensor(genes_test)
genes_full = torch.Tensor(genes_np)

In [9]:
genes_full.dtype

torch.float32

### Training Parameters

In [10]:
epochs = 20
batch_size = 256
lr = 0.05
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [11]:
torch.cuda.get_device_name(device)

'NVIDIA GeForce RTX 3060'

### Define DataLoader

In [12]:
train_loader = DataLoader(genes_train_t, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(genes_test_t,batch_size=batch_size, shuffle=False)
full_loader = DataLoader(genes_full,batch_size=batch_size)

### Define VAE

In [13]:
import torch
import torch.nn as nn

class VAE(nn.Module):
    def __init__(self, input_size, level_2, level_3, level_4, latent_dim):
        super(VAE, self).__init__()
        
        # Encoder layers
        self.enc_fc1 = nn.Sequential(
                        nn.Linear(input_size, level_2),
                        nn.BatchNorm1d(level_2),
                        nn.ReLU())
        
        self.enc_fc2 = nn.Sequential(
                        nn.Linear(level_2, level_3),
                        nn.BatchNorm1d(level_3),
                        nn.ReLU())
        
        self.enc_fc3 = nn.Sequential(
                        nn.Linear(level_3, level_4),
                        nn.BatchNorm1d(level_4),
                        nn.ReLU())

        self.enc_fc4_mean = nn.Sequential(
                    nn.Linear(level_4, latent_dim),
                    nn.BatchNorm1d(latent_dim))
        
        self.enc_fc4_log_var = nn.Sequential(
                    nn.Linear(level_4, latent_dim),
                    nn.BatchNorm1d(latent_dim))
        
        
        # Decoder layers
        self.dec_fc4 = nn.Sequential(
                        nn.Linear(latent_dim, level_4),
                        nn.BatchNorm1d(level_4),
                        nn.ReLU())
        
        self.dec_fc3 = nn.Sequential(
                        nn.Linear(level_4, level_3),
                        nn.BatchNorm1d(level_3),
                        nn.ReLU())
        
        self.dec_fc2 = nn.Sequential(
                        nn.Linear(level_3, level_2),
                        nn.BatchNorm1d(level_2),
                        nn.ReLU())
        
        self.dec_fc1 = nn.Sequential(
                    nn.Linear(level_2, input_size),
                    nn.BatchNorm1d(input_size),
                    nn.Sigmoid())


    def encode(self, x):
        l2_layer = self.enc_fc1(x)
        l3_layer = self.enc_fc2(l2_layer)
        l4_layer = self.enc_fc3(l3_layer)
        
        mu = self.enc_fc4_mean(l4_layer)
        logvar = self.enc_fc4_log_var(l4_layer)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        z = mu + eps * std
        return z

    def decode(self, z):
        l4_layer = self.dec_fc4(z)
        l3_layer = self.dec_fc3(l4_layer)
        l2_layer = self.dec_fc2(l3_layer)
        x_hat = self.dec_fc1(l2_layer)
        return x_hat

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        z = z.to(device)
        x_hat = self.decode(z)
        return x_hat, mu, logvar
    




### Initialize model parameters

In [14]:
# Initialize the model
input_size = 57820 #dimension of gene expressions
level_2 = 4096
level_3 = 2048
level_4 = 1024
latent_dim = 512 # target latent size
model = VAE(input_size, level_2, level_3, level_4, latent_dim).to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)

### Define Reconstruction Loss and KL Divergence Loss

In [15]:
def recon_loss(x_hat, x): # Reconstruction Loss
        lossFunc = torch.nn.MSELoss()
        loss = lossFunc(x_hat, x)
        return loss

def kl_loss(mean, log_var): # KL Divergence
    loss = -0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp())
    return loss

### Train Function

In [16]:
def train(model, dataloader):
    model.train()
    total_loss = 0.0
    total_recon = 0.0
    total_kl = 0.0
    for batch_index, sample in enumerate(dataloader):
        data = sample
        data = data.to(device)

        optimizer.zero_grad()

        x_hat, mu, logvar = model(data)
        x_hat = x_hat.to(device)

        train_recon_loss = recon_loss(x_hat, data)
        train_kl_loss = kl_loss(mu, logvar)
        loss = (train_recon_loss + train_kl_loss)
        total_recon += train_recon_loss.item()
        total_kl += train_kl_loss.item()
        total_loss +=loss.item()

        loss.backward()
        optimizer.step()
        
    train_loss = total_loss/len(dataloader.dataset)
    return train_loss, total_recon/len(dataloader.dataset), total_kl/len(dataloader.dataset)

### Validate Function

In [20]:
def validate(model, dataloader):
    model.eval()
    total_loss = 0.0
    total_kl = 0.0
    total_recon = 0.0
    with torch.no_grad():
        mean_store = torch.zeros(1, latent_dim).to(device)
        for batch_index, sample in enumerate(dataloader):
            data = sample
            data = data.to(device)

            optimizer.zero_grad()

            x_hat, mu, logvar = model(data)

            val_recon_loss = recon_loss(x_hat, data)
            val_kl = kl_loss(mu, logvar)
            loss = (val_recon_loss + val_kl)
            
            total_kl += val_kl.item()
            total_recon+= val_recon_loss.item()
            total_loss +=loss.item()

            mean_store = torch.cat((mean_store, mu), 0)

    all_data_mean = mean_store[1:]
    all_data_mean_np = all_data_mean.cpu().numpy()

    input_path = "C:/Users/mdzak/Desktop/GitHub/FYP_Zaki/results"
    input_path_name = input_path.split('/')[-1]
    latent_space_path = '../results/' + input_path_name + str(latent_dim) + 'D_latent_space_gene_exp.tsv'

    all_data_mean_df = pd.DataFrame(all_data_mean_np)
    all_data_mean_df.to_csv(latent_space_path, sep='\t')

    val_avg_total_loss = total_loss/len(dataloader.dataset)
    val_avg_kl = total_kl/len(dataloader.dataset)
    val_avg_recon = total_recon/len(dataloader.dataset)

    return val_avg_total_loss, val_avg_kl, val_avg_recon


### Execute Training and Validation

In [21]:
from tqdm import tqdm

In [22]:
train_loss = []
val_loss = []
for epoch in range(epochs):
    print(f"Epoch {epoch+1} of {epochs}")
    train_epoch_loss, train_recons_loss, train_kl_loss = train(model, train_loader)
    val_epoch_loss, val_kl_loss, val_recon_loss = validate(model, full_loader)
    train_loss.append(train_epoch_loss)
    val_loss.append(val_epoch_loss)
    print(f"Training Loss (KL plus MSE): {train_epoch_loss:.4f}")
    print(f"Training Loss (MSE): {train_recons_loss:.4f}")
    print(f"Training Loss (KL): {train_kl_loss:.4f}")
    print(f"Val Loss: {val_epoch_loss:.4f}")
    print(f"Val KL Loss: {val_kl_loss:.4f}")
    print(f"Val Recon Loss (MSE): {val_recon_loss:.4f}")
    print()

Epoch 1 of 20
Training Loss (KL plus MSE): 708.3214
Training Loss (MSE): 529.5356
Training Loss (KL): 178.7858
Val Loss: 2198262902637674775969792.0000
Val KL Loss: 2198262902637674775969792.0000
Val Recon Loss (MSE): 481.7674

Epoch 2 of 20
Training Loss (KL plus MSE): 635.1151
Training Loss (MSE): 526.2105
Training Loss (KL): 108.9046
Val Loss: 29271.1119
Val KL Loss: 28789.3337
Val Recon Loss (MSE): 481.7781

Epoch 3 of 20
Training Loss (KL plus MSE): 587.6288
Training Loss (MSE): 525.9501
Training Loss (KL): 61.6787
Val Loss: 3269.7618
Val KL Loss: 2787.9913
Val Recon Loss (MSE): 481.7705

Epoch 4 of 20
Training Loss (KL plus MSE): 556.3485
Training Loss (MSE): 524.1376
Training Loss (KL): 32.2109
Val Loss: 893.2107
Val KL Loss: 411.4673
Val Recon Loss (MSE): 481.7434

Epoch 5 of 20
Training Loss (KL plus MSE): 541.3860
Training Loss (MSE): 524.2798
Training Loss (KL): 17.1062
Val Loss: 592.8197
Val KL Loss: 111.1272
Val Recon Loss (MSE): 481.6925

Epoch 6 of 20
Training Loss (KL p