In [None]:
import numpy as np
import pandas as pd
import os
import torch
from sklearn.model_selection import train_test_split
import scanpy as sc
from scipy.spatial.distance import cdist
from model import *
import torch.optim as optim
from sklearn.metrics import mean_squared_error
from skimage.metrics import structural_similarity as ssim
import warnings
warnings.filterwarnings("ignore")

2024-11-27 18:01:25.982966: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [None]:
# Paths used by the cells below, relative to the notebook's working directory.
#   data_dir    -> folder holding the *RNA.tsv.gz / *protein.tsv.gz tables per tissue
#   results_dir -> folder receiving the per-tissue metric CSV files
data_dir, results_dir = (
    'spatial_datasets/GSE213264_RAW/',
    'results_neigh_loss/',
)

In [None]:
# Per-tissue pipeline: load paired RNA/protein tables, train the spatially
# regularised VAE on the training split, then score protein prediction from
# RNA alone on the training and held-out splits, writing one CSV per split.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

tissues = ['humanGBM', 'humanskin', 'humanthymus', 'humanspleen', 'humantonsil',
           'mousekidney', 'mouseintestine', 'mousecolon', 'mousespleen']

# Hyper-parameters (previously inline magic numbers; values unchanged).
SEED = 42
TEST_SIZE = 0.2
N_TOP_GENES = 4000
NUMBER_NEIGHBORS = 15
LATENT_DIM = 32
LEARNING_RATE = 1e-3
NUM_EPOCHS = 100
LAMBDA_KL = 0.0001
LAMBDA_NL = 1

# DataFrame.to_csv does not create missing parent directories.
os.makedirs(results_dir, exist_ok=True)


def _load_tissue(tissue):
    """Read the RNA and protein tables whose file names contain ``tissue``.

    Returns ``(rna_data, protein_data)`` as raw DataFrames.
    Raises FileNotFoundError if either table is missing, instead of failing
    later with an AttributeError on ``None``.
    """
    rna, protein = None, None
    for filename in os.listdir(data_dir):
        if tissue not in filename:
            continue
        file_path = os.path.join(data_dir, filename)
        if filename.endswith("RNA.tsv.gz"):
            rna = pd.read_csv(file_path, sep="\t")
        elif filename.endswith("protein.tsv.gz"):
            protein = pd.read_csv(file_path, sep="\t")
    if rna is None or protein is None:
        raise FileNotFoundError(
            f"Missing RNA and/or protein table for {tissue!r} in {data_dir!r}")
    return rna, protein


def _prepare_tables(rna_data, protein_data):
    """Align RNA and protein rows on the 'X' column and extract coordinates.

    Both tables are sorted by their 'X' column and re-indexed with string
    labels "0".."n-1", so that label i refers to the same spot in each table.
    The RNA 'X' value is split on 'x' into numeric X/Y coordinates.
    Returns ``(rna, protein, spatial)``; 'X' (and the derived 'Y') are removed
    from the expression tables.
    Raises ValueError if the two tables do not list the same 'X' values in
    sorted order, or if any coordinate fails to parse.
    """
    rna = rna_data.copy()
    protein = protein_data.copy()
    rna.columns = rna.columns.astype(str)
    protein.columns = protein.columns.astype(str)

    rna = rna.sort_values(by='X').reset_index(drop=True)
    protein = protein.sort_values(by='X').reset_index(drop=True)
    # Rows are later paired purely by position, so the barcodes must match
    # exactly; otherwise protein of one spot would be paired with RNA of another.
    if not np.array_equal(rna['X'].astype(str).values, protein['X'].astype(str).values):
        raise ValueError("RNA and protein tables do not contain the same 'X' spot barcodes")
    rna.index = rna.index.astype(str)
    protein.index = protein.index.astype(str)

    rna[['X', 'Y']] = rna['X'].str.split('x', expand=True)
    rna['X'] = pd.to_numeric(rna['X'], errors='coerce')
    rna['Y'] = pd.to_numeric(rna['Y'], errors='coerce')
    spatial = rna[['X', 'Y']].copy()
    if spatial.isna().any().any():
        # NaN coordinates would silently poison the cdist-based neighbour search.
        raise ValueError("Some spot coordinates could not be parsed as numbers")
    rna = rna.drop(columns=['X', 'Y'])
    protein = protein.drop(columns=['X'])
    return rna, protein, spatial


def _normalize_log(frame):
    """Wrap ``frame`` in AnnData, scale each row to 1e4 total, then apply log1p."""
    adata = sc.AnnData(frame)
    sc.pp.normalize_total(adata, target_sum=1e4)
    sc.pp.log1p(adata)
    return adata


def _to_tensor(matrix):
    """Convert an AnnData ``.X`` (dense, or sparse with ``toarray``) to float32 on ``device``."""
    dense = matrix.toarray() if hasattr(matrix, "toarray") else matrix
    return torch.tensor(np.asarray(dense), dtype=torch.float32, device=device)


def _spatial_neighbors(coords, k):
    """Map each row index of ``coords`` to its k closest and k furthest other rows.

    Returns two dicts ``{i: np.ndarray of row indices}``. Uses one batched,
    stable sort, so ties (common on a regular grid) break deterministically
    by index.
    """
    distances = torch.cdist(coords, coords, p=2)
    # Force each row's own entry to sort first, so slicing [1:k+1] always drops
    # exactly the spot itself. This holds even with duplicate coordinates or
    # tiny non-zero self-distances from cdist's matmul code path.
    distances.fill_diagonal_(-1.0)
    order = torch.sort(distances, dim=1, stable=True).indices.cpu().numpy()
    closest = {i: order[i, 1:k + 1] for i in range(order.shape[0])}
    furthest = {i: order[i, -k:] for i in range(order.shape[0])}
    return closest, furthest


def _protein_metrics(model, rna_counts, protein_counts):
    """Predict protein from RNA alone (protein slots zero-filled) and score it.

    Returns a one-row DataFrame with RMSE, the mean per-spot Pearson
    correlation, and SSIM between observed and reconstructed protein matrices.
    Expects to be called under ``torch.no_grad()`` with the model in eval mode.
    """
    masked_input = torch.cat(
        [rna_counts, rna_counts.new_zeros(rna_counts.shape[0], protein_counts.shape[1])],
        dim=1,
    )
    reconstructed, _, _ = model(masked_input)
    predicted = reconstructed[:, rna_counts.shape[1]:].cpu().numpy()
    observed = protein_counts.cpu().numpy()

    rmse = np.sqrt(mean_squared_error(observed, predicted))
    pcc = pd.DataFrame(observed).corrwith(pd.DataFrame(predicted), axis=1, method='pearson')
    # data_range describes the reference image, not the prediction; using the
    # prediction's range made the score depend on how spread the output is.
    ssim_val = ssim(observed, predicted, data_range=observed.max() - observed.min())
    return pd.DataFrame({
        'RMSE': [rmse],
        'Pearson Correlation': [pcc.mean()],
        'SSIM': [ssim_val],
    })


for tissue in tissues:
    rna_raw, protein_raw = _load_tissue(tissue)
    rna_data, protein_data, spatial = _prepare_tables(rna_raw, protein_raw)

    rna_train, rna_test = train_test_split(rna_data, test_size=TEST_SIZE, random_state=SEED)
    protein_train = protein_data.loc[rna_train.index]
    protein_test = protein_data.loc[rna_test.index]

    # HVGs are chosen on the training split only; the test split reuses them.
    adata_rna_train = _normalize_log(rna_train)
    sc.pp.highly_variable_genes(adata_rna_train, n_top_genes=N_TOP_GENES, flavor='seurat', subset=True)
    rna_counts_norm = _to_tensor(adata_rna_train.X)
    protein_counts_norm = _to_tensor(_normalize_log(protein_train).X)

    spatial_train = torch.tensor(spatial.loc[rna_train.index].values, dtype=torch.float32, device=device)
    combined_data = torch.cat([rna_counts_norm, protein_counts_norm], dim=1)
    closest_neighbors, furthest_neighbors = _spatial_neighbors(spatial_train, NUMBER_NEIGHBORS)

    # Seed per tissue so each tissue's result does not depend on loop order.
    torch.manual_seed(SEED)
    model = VAE(combined_data.shape[1], LATENT_DIM).to(device)
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    for epoch in range(NUM_EPOCHS):
        model.train()
        reconstructed_data, mean, logvar = model(combined_data)
        loss = vae_loss2(reconstructed_data, combined_data, mean, logvar,
                         closest_neighbors, furthest_neighbors,
                         lambda_kl=LAMBDA_KL, lambda_nl=LAMBDA_NL)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        print(f'Epoch [{epoch + 1}/{NUM_EPOCHS}], Loss: {loss.item() / len(combined_data):.4f}')

    model.eval()
    with torch.no_grad():
        train_results = _protein_metrics(model, rna_counts_norm, protein_counts_norm)
        train_results.to_csv(os.path.join(results_dir, f"{tissue}_training_results.csv"), index=False)

        adata_rna_test = _normalize_log(rna_test)
        rna_test_norm = _to_tensor(adata_rna_test[:, adata_rna_train.var_names].X)
        protein_test_norm = _to_tensor(_normalize_log(protein_test).X)
        results_df = _protein_metrics(model, rna_test_norm, protein_test_norm)

    results_df.to_csv(os.path.join(results_dir, f"{tissue}_results.csv"), index=False)
    print(f"Processed {tissue} successfully.")


Using device: cuda
Epoch [1/100], Loss: 6254.7686
Epoch [2/100], Loss: 5470.3965
Epoch [3/100], Loss: 4999.7456
Epoch [4/100], Loss: 4509.5044
Epoch [5/100], Loss: 4092.1528
Epoch [6/100], Loss: 3748.9604
Epoch [7/100], Loss: 3441.2361
Epoch [8/100], Loss: 3170.8616
Epoch [9/100], Loss: 2972.5476
Epoch [10/100], Loss: 2776.9275
Epoch [11/100], Loss: 2598.2104
Epoch [12/100], Loss: 2437.5288
Epoch [13/100], Loss: 2301.6252
Epoch [14/100], Loss: 2171.8333
Epoch [15/100], Loss: 2056.2671
Epoch [16/100], Loss: 1962.5521
Epoch [17/100], Loss: 1870.1027
Epoch [18/100], Loss: 1777.7982
Epoch [19/100], Loss: 1694.8459
Epoch [20/100], Loss: 1610.7327
Epoch [21/100], Loss: 1547.0195
Epoch [22/100], Loss: 1479.4385
Epoch [23/100], Loss: 1422.4302
Epoch [24/100], Loss: 1370.3320
Epoch [25/100], Loss: 1321.1257
Epoch [26/100], Loss: 1279.3480
Epoch [27/100], Loss: 1241.6064
Epoch [28/100], Loss: 1208.7111
Epoch [29/100], Loss: 1178.7463
Epoch [30/100], Loss: 1149.4192
Epoch [31/100], Loss: 1128.827

In [None]:
# Sanity check of table sizes per tissue. The training cell pairs RNA and
# protein rows by position, so both tables should have the same row count.
tissues = ['humanGBM', 'humanskin', 'humanthymus', 'humanspleen', 'humantonsil',
           'mousekidney', 'mouseintestine', 'mousecolon', 'mousespleen']

filenames = os.listdir(data_dir)  # list once instead of once per tissue
for tissue in tissues:
    rna_data = None
    protein_data = None

    for filename in filenames:
        file_path = os.path.join(data_dir, filename)
        if tissue in filename and filename.endswith("RNA.tsv.gz"):
            rna_data = pd.read_csv(file_path, sep="\t")
        elif tissue in filename and filename.endswith("protein.tsv.gz"):
            protein_data = pd.read_csv(file_path, sep="\t")

    if rna_data is None or protein_data is None:
        # Report and keep going: this cell is a diagnostic, not a pipeline step.
        print(f"{tissue}: missing RNA and/or protein table in {data_dir}")
        continue

    mismatch = "" if rna_data.shape[0] == protein_data.shape[0] else "  <-- row count mismatch"
    print(f"{tissue}: RNA {rna_data.shape}, protein {protein_data.shape}{mismatch}")

(1710, 23900)
(1710, 231)
(1691, 15487)
(1691, 284)
(2500, 28279)
(2500, 284)
(2494, 20237)
(2494, 284)
(2492, 28418)
(2492, 284)
(2419, 23751)
(2419, 200)
(902, 20445)
(902, 200)
(2037, 19469)
(2037, 200)
(1303, 19924)
(1303, 200)
