In [10]:
import numpy as np
import pandas as pd
import os
import scanpy as sc
from sklearn.decomposition import PCA
from sklearn.neighbors import NearestNeighbors
from sklearn.model_selection import train_test_split
import torch
from model import *
import torch.optim as optim
from sklearn.metrics import mean_squared_error
from skimage.metrics import structural_similarity as ssim
import warnings
warnings.filterwarnings("ignore")

In [12]:
# Directory scanned for the per-tissue "*RNA.tsv.gz" / "*protein.tsv.gz" inputs.
data_dir = 'spatial_datasets/GSE213264_RAW/'
# Directory that receives the per-tissue metric CSV files.
results_dir = 'results_dca_2/'

In [13]:
# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

def build_knn_graph(spatial_data, k=15):
    """Build a row-normalised k-nearest-neighbour graph from point coordinates.

    Args:
        spatial_data: array-like of shape (n_points, n_dims) handed to
            scikit-learn's ``NearestNeighbors``.
        k: number of neighbours linked to each point (Euclidean distance).

    Returns:
        A dense float tensor of shape (n_points, n_points) on the module-level
        ``device``; every row is divided by its own sum.
    """
    knn = NearestNeighbors(n_neighbors=k, metric='euclidean').fit(spatial_data)
    # NOTE(review): the fitted points are passed back in as queries, so each
    # point is presumably counted among its own k neighbours — confirm intended.
    dense_adjacency = knn.kneighbors_graph(spatial_data).toarray()
    row_totals = dense_adjacency.sum(axis=1, keepdims=True)
    return torch.FloatTensor(dense_adjacency / row_totals).to(device)

def random_walk_with_restart(adjacency_matrix, alpha=0.15, max_iter=100):
    """Compute a random-walk-with-restart profile for every node of a graph.

    For seed node ``i`` the walker follows the transition probabilities in
    ``adjacency_matrix`` with probability ``1 - alpha`` and jumps back to the
    seed with probability ``alpha``::

        p_i <- (1 - alpha) * (p_i @ A) + alpha * e_i

    All seeds are advanced together as the rows of one matrix, so each
    iteration is a single matrix product instead of ``num_nodes`` separate
    matrix-vector products.

    Restarting at the seed (rather than at a uniform distribution) is what
    makes the profiles node-specific: with a uniform restart and a
    row-stochastic ``A`` every seed converges to the same vector.

    Args:
        adjacency_matrix: square (num_nodes, num_nodes) floating-point tensor.
            Expected to be row-stochastic (each row sums to 1), in which case
            every returned row is a probability distribution.
        alpha: restart probability.
        max_iter: number of propagation steps; ``0`` returns the identity.

    Returns:
        Tensor of shape (num_nodes, num_nodes) with the dtype and device of
        ``adjacency_matrix``; row ``i`` is the profile of seed node ``i``.
    """
    num_nodes = adjacency_matrix.shape[0]

    # Row i of the identity is the one-hot start/restart vector of seed i.
    restart = torch.eye(
        num_nodes,
        dtype=adjacency_matrix.dtype,
        device=adjacency_matrix.device,
    )

    walk_vectors = restart.clone()
    for _ in range(max_iter):
        # Row-vector form: mass at node u is spread along row u of the matrix.
        walk_vectors = (1 - alpha) * walk_vectors.matmul(adjacency_matrix) + alpha * restart

    return walk_vectors

def perform_dca_random_walk(adjacency_matrix, n_components=16):
    """Embed graph nodes by running PCA on their random-walk profiles.

    Args:
        adjacency_matrix: graph tensor forwarded to ``random_walk_with_restart``.
        n_components: number of principal components to keep.

    Returns:
        Float tensor of shape (num_nodes, n_components) on the module-level
        ``device``.
    """
    # PCA runs in scikit-learn, so the profiles are moved to host memory first.
    diffusion_profiles = random_walk_with_restart(adjacency_matrix).cpu().numpy()
    reducer = PCA(n_components=n_components)
    embedding = reducer.fit_transform(diffusion_profiles)
    return torch.FloatTensor(embedding).to(device)

def _normalize_log(frame):
    """Wrap a DataFrame in AnnData, scale each row to a total of 1e4 and log1p it."""
    adata = sc.AnnData(frame)
    sc.pp.normalize_total(adata, target_sum=1e4)
    sc.pp.log1p(adata)
    return adata


def _score_protein_reconstruction(model, rna_counts, protein_counts, spatial_feats):
    """Predict protein values from RNA + spatial input and score the prediction.

    The protein slice of the model input is replaced by zeros, the model is
    run without gradients, and the protein slice of its reconstruction is
    compared with ``protein_counts``.

    Args:
        model: callable returning a 3-tuple whose first element is the
            reconstruction of the concatenated [RNA | protein | spatial] input.
        rna_counts: (n_spots, n_genes) tensor.
        protein_counts: (n_spots, n_proteins) tensor of observed values.
        spatial_feats: (n_spots, n_spatial) tensor.

    Returns:
        One-row DataFrame with columns 'RMSE', 'Pearson Correlation' (mean of
        the per-row Pearson correlations) and 'SSIM'.
    """
    n_rna = rna_counts.shape[1]
    n_protein = protein_counts.shape[1]

    masked_protein = torch.zeros(rna_counts.shape[0], n_protein, device=rna_counts.device)
    model_input = torch.cat([rna_counts, masked_protein, spatial_feats], dim=1)
    with torch.no_grad():
        reconstructed, _, _ = model(model_input)

    predicted = reconstructed[:, n_rna:n_rna + n_protein].cpu().numpy()
    observed = protein_counts.cpu().numpy()

    rmse = np.sqrt(mean_squared_error(observed, predicted))
    per_row_pcc = pd.DataFrame(observed).corrwith(pd.DataFrame(predicted), axis=1, method='pearson')
    ssim_val = ssim(observed, predicted, data_range=predicted.max() - predicted.min())

    return pd.DataFrame({
        'RMSE': [rmse],
        'Pearson Correlation': [per_row_pcc.mean()],
        'SSIM': [ssim_val],
    })


tissues = ['humanGBM', 'humanskin', 'humanthymus', 'humanspleen', 'humantonsil', 'mousekidney', 'mouseintestine', 'mousecolon', 'mousespleen']

# DataFrame.to_csv does not create missing directories, so make sure it exists.
os.makedirs(results_dir, exist_ok=True)

for tissue in tissues:
    rna_data = None
    protein_data = None

    # Pick up the RNA / protein tables whose file name mentions this tissue.
    for filename in os.listdir(data_dir):
        file_path = os.path.join(data_dir, filename)
        if tissue in filename and filename.endswith("RNA.tsv.gz"):
            rna_data = pd.read_csv(file_path, sep="\t")
        elif tissue in filename and filename.endswith("protein.tsv.gz"):
            protein_data = pd.read_csv(file_path, sep="\t")

    if rna_data is None or protein_data is None:
        print(f"Skipping {tissue}: RNA and/or protein file not found in {data_dir}")
        continue

    rna_data.columns = rna_data.columns.astype(str)
    protein_data.columns = protein_data.columns.astype(str)

    # Sort both tables by the spot label so that row i refers to the same spot.
    rna_data = rna_data.sort_values(by='X').reset_index(drop=True)
    protein_data = protein_data.sort_values(by='X').reset_index(drop=True)
    rna_data.index = rna_data.index.astype(str)
    protein_data.index = protein_data.index.astype(str)

    # Rows are paired purely by position below, so the spot labels must agree.
    if list(rna_data['X']) != list(protein_data['X']):
        raise ValueError(
            f"{tissue}: RNA and protein tables do not contain the same spots; "
            "pairing them by row would mix up measurements."
        )

    # The spot label has the form "<x>x<y>"; split it into numeric coordinates.
    rna_data[['X', 'Y']] = rna_data['X'].str.split('x', expand=True)
    rna_data['X'] = pd.to_numeric(rna_data['X'], errors='coerce')
    rna_data['Y'] = pd.to_numeric(rna_data['Y'], errors='coerce')
    spatial = rna_data[['X', 'Y']].copy()
    rna_data.drop(['X', 'Y'], axis=1, inplace=True)
    protein_data.drop(['X'], axis=1, inplace=True)

    rna_train, rna_test = train_test_split(rna_data, test_size=0.2, random_state=42)
    protein_train = protein_data.loc[rna_train.index]
    protein_test = protein_data.loc[rna_test.index]

    adjacency_matrix = build_knn_graph(spatial.values)
    # NOTE(review): this embedding is computed but never used — the model below
    # is fed the raw X/Y coordinates. Confirm whether it should replace them.
    spatial_coords = perform_dca_random_walk(adjacency_matrix)

    # --- training features -------------------------------------------------
    adata_rna_train = _normalize_log(rna_train)
    # Gene selection is fitted on the training spots only and reused for test.
    sc.pp.highly_variable_genes(adata_rna_train, n_top_genes=4000, flavor='seurat', subset=True)
    rna_train_counts = torch.FloatTensor(adata_rna_train.X).to(device)
    protein_train_counts = torch.FloatTensor(_normalize_log(protein_train).X).to(device)
    spatial_train = torch.FloatTensor(spatial.loc[rna_train.index].values).to(device)
    train_inputs = torch.cat([rna_train_counts, protein_train_counts, spatial_train], dim=1).to(device)

    # --- model training ----------------------------------------------------
    # Seed torch so weight init, batch order and sampling repeat across runs
    # (the train/test split above is already seeded).
    torch.manual_seed(42)

    input_dim = train_inputs.shape[1]
    latent_dim = 32
    model = VAE(input_dim, latent_dim).to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    num_epochs = 100
    batch_size = 64

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        permutation = torch.randperm(train_inputs.size(0))

        for i in range(0, train_inputs.size(0), batch_size):
            optimizer.zero_grad()
            batch_data = train_inputs[permutation[i:i + batch_size]]

            reconstructed_data, mean, logvar = model(batch_data)
            loss = vae_loss(reconstructed_data, batch_data, mean, logvar, lambda_kl=0.0001)

            loss.backward()
            optimizer.step()

            running_loss += loss.item()
        print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {running_loss / len(train_inputs):.4f}')

    # --- evaluation --------------------------------------------------------
    model.eval()

    train_results = _score_protein_reconstruction(
        model, rna_train_counts, protein_train_counts, spatial_train
    )
    train_results.to_csv(os.path.join(results_dir, f"{tissue}_training_results.csv"), index=False)

    # Restrict the test RNA matrix to the genes selected on the training set.
    adata_rna_test = _normalize_log(rna_test)
    rna_test_counts = torch.FloatTensor(adata_rna_test[:, adata_rna_train.var_names].X).to(device)
    protein_test_counts = torch.FloatTensor(_normalize_log(protein_test).X).to(device)
    spatial_test = torch.FloatTensor(spatial.loc[rna_test.index].values).to(device)

    test_results = _score_protein_reconstruction(
        model, rna_test_counts, protein_test_counts, spatial_test
    )
    test_results.to_csv(os.path.join(results_dir, f"{tissue}_results.csv"), index=False)

    print(f"Processed {tissue} successfully.")


Using device: cuda
Epoch [1/100], Loss: 4872.5705
Epoch [2/100], Loss: 2596.9005
Epoch [3/100], Loss: 1779.2873
Epoch [4/100], Loss: 1462.5997
Epoch [5/100], Loss: 1337.2755
Epoch [6/100], Loss: 1284.2700
Epoch [7/100], Loss: 1252.2813
Epoch [8/100], Loss: 1228.3434
Epoch [9/100], Loss: 1223.2982
Epoch [10/100], Loss: 1213.6352
Epoch [11/100], Loss: 1215.1675
Epoch [12/100], Loss: 1209.4060
Epoch [13/100], Loss: 1204.8574
Epoch [14/100], Loss: 1200.9913
Epoch [15/100], Loss: 1196.7482
Epoch [16/100], Loss: 1196.7839
Epoch [17/100], Loss: 1199.9000
Epoch [18/100], Loss: 1192.6798
Epoch [19/100], Loss: 1194.8823
Epoch [20/100], Loss: 1194.0638
Epoch [21/100], Loss: 1193.8033
Epoch [22/100], Loss: 1190.3849
Epoch [23/100], Loss: 1189.7940
Epoch [24/100], Loss: 1185.8849
Epoch [25/100], Loss: 1190.2297
Epoch [26/100], Loss: 1180.0890
Epoch [27/100], Loss: 1171.3156
Epoch [28/100], Loss: 1127.2103
Epoch [29/100], Loss: 1095.9123
Epoch [30/100], Loss: 1077.4171
Epoch [31/100], Loss: 1071.062