In [1]:
import numpy as np
import pandas as pd
import os
import scanpy as sc
from sklearn.model_selection import train_test_split
import torch
from model import *
import torch.optim as optim
from sklearn.metrics import mean_squared_error
from skimage.metrics import structural_similarity as ssim
import warnings
warnings.filterwarnings("ignore")

2024-11-27 22:01:29.180646: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
# Folder scanned for the per-tissue "*RNA.tsv.gz" / "*protein.tsv.gz" tables.
data_dir = "spatial_datasets/GSE213264_RAW/"
# Folder that receives the per-tissue metric CSV files.
results_dir = "results_spatial_coord/"

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

tissues = [
    'humanGBM', 'humanskin', 'humanthymus', 'humanspleen', 'humantonsil',
    'mousekidney', 'mouseintestine', 'mousecolon', 'mousespleen',
]

# One seed for both the train/test split and the torch RNG (weight init,
# batch permutation, VAE sampling) so a re-run starts from the same state.
SEED = 42

# Loop-invariant hyperparameters, hoisted out of the per-tissue loop.
latent_dim = 32
num_epochs = 100
batch_size = 64


def _load_tissue_tables(tissue):
    """Read the RNA and protein tables of one tissue from ``data_dir``.

    A file belongs to ``tissue`` when its name contains the tissue string and
    ends with ``RNA.tsv.gz`` or ``protein.tsv.gz``. Files are visited in sorted
    order, so the pick is deterministic if several files match (the last wins).

    Returns:
        (rna_table, protein_table) as pandas DataFrames.

    Raises:
        FileNotFoundError: if either table is missing for ``tissue``.
    """
    rna_table = None
    protein_table = None
    for filename in sorted(os.listdir(data_dir)):
        if tissue not in filename:
            continue
        file_path = os.path.join(data_dir, filename)
        if filename.endswith("RNA.tsv.gz"):
            rna_table = pd.read_csv(file_path, sep="\t")
        elif filename.endswith("protein.tsv.gz"):
            protein_table = pd.read_csv(file_path, sep="\t")

    missing = [
        label
        for label, table in (("RNA", rna_table), ("protein", protein_table))
        if table is None
    ]
    if missing:
        raise FileNotFoundError(
            f"No {' / '.join(missing)} table found for tissue '{tissue}' in '{data_dir}'"
        )
    return rna_table, protein_table


def _normalize_log1p(counts):
    """Wrap ``counts`` in an AnnData, scale each row to 1e4 total and log1p it."""
    adata = sc.AnnData(counts)
    sc.pp.normalize_total(adata, target_sum=1e4)
    sc.pp.log1p(adata)
    return adata


def _evaluate_protein_prediction(model, rna_tensor, protein_tensor, spatial_tensor):
    """Reconstruct proteins from RNA + coordinates and score the result.

    The protein slot of the model input is filled with zeros, the model output
    columns at the protein positions are taken as the prediction, and the
    prediction is compared against ``protein_tensor``.

    Returns:
        (results_df, predicted): a one-row DataFrame with the columns
        'RMSE', 'Pearson Correlation' and 'SSIM', and the predicted protein
        tensor.
    """
    n_rna = rna_tensor.shape[1]
    n_protein = protein_tensor.shape[1]

    # NOTE(review): during training this slot holds the real protein values,
    # here it is all zeros - confirm the model is meant to be evaluated on an
    # input pattern it never saw while training.
    masked_protein = torch.zeros(rna_tensor.shape[0], n_protein, device=rna_tensor.device)
    model_input = torch.cat([rna_tensor, masked_protein, spatial_tensor], dim=1)

    model.eval()
    with torch.no_grad():
        reconstructed, _, _ = model(model_input)
    predicted = reconstructed[:, n_rna:n_rna + n_protein]

    truth_np = protein_tensor.cpu().numpy()
    predicted_np = predicted.cpu().numpy()

    rmse = np.sqrt(mean_squared_error(truth_np, predicted_np))
    # axis=1: one Pearson coefficient per row, averaged afterwards.
    pcc = pd.DataFrame(truth_np).corrwith(pd.DataFrame(predicted_np), axis=1, method='pearson')
    # NOTE(review): data_range is derived from the prediction only - confirm
    # this is intended rather than the range of the ground truth.
    ssim_val = ssim(truth_np, predicted_np, data_range=predicted_np.max() - predicted_np.min())

    results = pd.DataFrame({
        'RMSE': [rmse],
        'Pearson Correlation': [pcc.mean()],
        'SSIM': [ssim_val],
    })
    return results, predicted


# to_csv does not create missing folders, so make sure the target exists.
os.makedirs(results_dir, exist_ok=True)

for tissue in tissues:
    # Re-seed per tissue so each tissue is reproducible on its own,
    # independent of which tissues ran before it.
    torch.manual_seed(SEED)

    # ---- load and align the two modalities --------------------------------
    rna_data, protein_data = _load_tissue_tables(tissue)

    rna_data.columns = rna_data.columns.astype(str)
    protein_data.columns = protein_data.columns.astype(str)

    rna_data = rna_data.sort_values(by='X').reset_index(drop=True)
    protein_data = protein_data.sort_values(by='X').reset_index(drop=True)

    # Rows are paired purely by position from here on, so both tables must
    # list exactly the same spot ids in the same order.
    if not rna_data['X'].equals(protein_data['X']):
        raise ValueError(
            f"RNA and protein tables of '{tissue}' do not contain the same spots in column 'X'"
        )

    rna_data.index = rna_data.index.astype(str)
    protein_data.index = protein_data.index.astype(str)

    # The 'X' column holds "<x>x<y>" strings; split them into numeric coordinates.
    rna_data[['X', 'Y']] = rna_data['X'].str.split('x', expand=True)
    rna_data['X'] = pd.to_numeric(rna_data['X'], errors='coerce')
    rna_data['Y'] = pd.to_numeric(rna_data['Y'], errors='coerce')
    spatial = rna_data[['X', 'Y']].copy()
    # errors='coerce' turns unparsable coordinates into NaN; stop here rather
    # than feed NaN into the model input.
    if spatial.isna().any().any():
        raise ValueError(f"Non-numeric spot coordinates found for tissue '{tissue}'")

    rna_data = rna_data.drop(columns=['X', 'Y'])
    protein_data = protein_data.drop(columns=['X'])

    # ---- train / test split ----------------------------------------------
    rna_train, rna_test = train_test_split(rna_data, test_size=0.2, random_state=SEED)
    protein_train = protein_data.loc[rna_train.index]
    protein_test = protein_data.loc[rna_test.index]

    # ---- training tensors --------------------------------------------------
    adata_rna_train = _normalize_log1p(rna_train)
    # subset=True keeps only the selected genes; the same genes are reused
    # for the test set below via adata_rna_train.var_names.
    sc.pp.highly_variable_genes(adata_rna_train, n_top_genes=4000, flavor='seurat', subset=True)
    rna_counts_norm = torch.FloatTensor(adata_rna_train.X).to(device)

    adata_protein_train = _normalize_log1p(protein_train)
    protein_counts_norm = torch.FloatTensor(adata_protein_train.X).to(device)

    spatial_train = torch.FloatTensor(spatial.loc[rna_train.index].values).to(device)
    # Model input layout: [RNA | protein | (X, Y)].
    combined_data = torch.cat([rna_counts_norm, protein_counts_norm, spatial_train], dim=1)

    # ---- model training ----------------------------------------------------
    input_dim = combined_data.shape[1]
    model = VAE(input_dim, latent_dim).to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        permutation = torch.randperm(combined_data.size(0))

        for i in range(0, combined_data.size(0), batch_size):
            optimizer.zero_grad()
            indices = permutation[i:i + batch_size]
            batch_data = combined_data[indices]

            reconstructed_data, mean, logvar = model(batch_data)
            loss = vae_loss(reconstructed_data, batch_data, mean, logvar, lambda_kl=0.0001)

            loss.backward()
            optimizer.step()

            running_loss += loss.item()
        print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {running_loss / len(combined_data):.4f}')

    # ---- evaluation on the training spots ----------------------------------
    results_df, reconstructed_protein_counts = _evaluate_protein_prediction(
        model, rna_counts_norm, protein_counts_norm, spatial_train
    )
    results_file_path = os.path.join(results_dir, f"{tissue}_training_results.csv")
    results_df.to_csv(results_file_path, index=False)

    # ---- evaluation on the held-out spots ----------------------------------
    # The tensor names are reused for the test set, as in the original cell,
    # so after the loop they hold the test-set values of the last tissue.
    adata_rna_test = _normalize_log1p(rna_test)
    counts_norm = adata_rna_test[:, adata_rna_train.var_names].X
    rna_counts_norm = torch.FloatTensor(counts_norm).to(device)

    adata_protein_test = _normalize_log1p(protein_test)
    protein_counts_norm = torch.FloatTensor(adata_protein_test.X).to(device)

    spatial_test = torch.FloatTensor(spatial.loc[rna_test.index].values).to(device)

    results_df, reconstructed_protein_counts = _evaluate_protein_prediction(
        model, rna_counts_norm, protein_counts_norm, spatial_test
    )
    results_file_path = os.path.join(results_dir, f"{tissue}_results.csv")
    results_df.to_csv(results_file_path, index=False)

    print(f"Processed {tissue} successfully.")


Using device: cuda
Epoch [1/100], Loss: 4917.5757
Epoch [2/100], Loss: 2664.1547
Epoch [3/100], Loss: 1795.2347
Epoch [4/100], Loss: 1463.9311
Epoch [5/100], Loss: 1335.2752
Epoch [6/100], Loss: 1279.2134
Epoch [7/100], Loss: 1252.7686
Epoch [8/100], Loss: 1236.4646
Epoch [9/100], Loss: 1218.5511
Epoch [10/100], Loss: 1216.7726
Epoch [11/100], Loss: 1206.2041
Epoch [12/100], Loss: 1208.9180
Epoch [13/100], Loss: 1202.4056
Epoch [14/100], Loss: 1197.5154
Epoch [15/100], Loss: 1205.5190
Epoch [16/100], Loss: 1197.2306
Epoch [17/100], Loss: 1197.5259
Epoch [18/100], Loss: 1195.7079
Epoch [19/100], Loss: 1190.9355
Epoch [20/100], Loss: 1189.7309
Epoch [21/100], Loss: 1191.1119
Epoch [22/100], Loss: 1191.0002
Epoch [23/100], Loss: 1187.6081
Epoch [24/100], Loss: 1186.1822
Epoch [25/100], Loss: 1167.3913
Epoch [26/100], Loss: 1127.9677
Epoch [27/100], Loss: 1100.1129
Epoch [28/100], Loss: 1084.6855
Epoch [29/100], Loss: 1073.8904
Epoch [30/100], Loss: 1071.0550
Epoch [31/100], Loss: 1067.379