In [1]:
# Standard library imports
import os
import math
import random
from time import time
from collections import deque

# Scientific computing and data handling
import numpy as np
import pandas as pd
import scipy.stats as stats

# Machine learning and clustering
from sklearn.preprocessing import MinMaxScaler, StandardScaler
from sklearn.cluster import KMeans
from sklearn.metrics import pairwise_distances
from sklearn.decomposition import PCA
from sklearn.neighbors import kneighbors_graph

# Visualization
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from matplotlib.colors import Normalize

# Single-cell and spatial analysis
import scanpy as sc
import harmonypy as hm
import igraph as ig
import leidenalg
from umap import umap_ as umap

# PyTorch and deep learning utilities
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import *
from torch.utils.data import DataLoader, TensorDataset, random_split
from torch.distributions import Normal, kl_divergence

# spaVAE-specific imports
from spaVAE.spaVAE import SPAVAE
from spaVAE.SVGP import SVGP
from spaVAE.I_PID import PIDControl
from spaVAE.VAE_utils import *
from spaVAE.preprocess import normalize, geneSelection


# Names of the lipids to keep: first column of the clustering table, as a plain list.
lipid_N_reduced_list = (
    pd.read_csv('data/negative_globus_pallidus_ion_clustering_using_logfc_patterns_over_time.csv')
    .iloc[:, 0]
    .tolist()
)

# Number of grid intervals per axis used to lay out the inducing points in train().
inducing_point_steps = 6
# Upper bound of the rescaled spatial coordinates (they are mapped onto [0, loc_range]).
loc_range = 20

# One lipid name per line of the text file, with surrounding whitespace removed.
with open("data/lipid_names.txt", "r") as file:
    lipid_names = list(map(str.strip, file))
def fetch_data(id):
    """Load the two arrays stored on disk for one sample.

    Args:
        id: Sample identifier (a string); it is concatenated into the file names.

    Returns:
        A tuple ``(lipids, coords)`` holding the contents of
        ``data/<id>_lipids.npy`` and ``data/<id>_coords.npy``, in that order.
    """
    stem = 'data/' + id
    lipids = np.load(stem + '_lipids.npy')
    coords = np.load(stem + '_coords.npy')
    return lipids, coords

def extract_best_features(adata, reduced_list):
    """Keep only the variables of ``adata`` whose names appear in ``reduced_list``.

    Args:
        adata: Object exposing ``var_names`` (with a ``tolist()`` method) and
            supporting ``adata[:, mask]`` column selection.
        reduced_list: Iterable of variable names to keep. The names must be
            hashable because they are collected into a set for fast lookup.

    Returns:
        ``adata[:, mask]`` where ``mask`` is a boolean array marking the
        variables whose name is in ``reduced_list``; column order is unchanged.
    """
    # Set lookup keeps the scan linear instead of len(var_names) * len(reduced_list).
    wanted = set(reduced_list)
    # Explicit dtype so that an empty var_names still yields a boolean mask.
    keep_mask = np.array([name in wanted for name in adata.var_names.tolist()], dtype=bool)
    return adata[:, keep_mask]



In [2]:
def fetch_preprocess_data(names):
    """Load and preprocess every sample listed in ``names``.

    For each sample the coordinates are min-max scaled (scaler fitted on that
    sample alone) and multiplied by ``loc_range``; the lipid matrix is wrapped
    in an AnnData, labelled with ``lipid_names``, reduced to the variables in
    ``lipid_N_reduced_list`` and passed through spaVAE's ``normalize``.

    Args:
        names: Iterable of sample identifiers accepted by ``fetch_data``.

    Returns:
        Dict mapping each sample name to ``[adata, scaled_coordinates]``.
    """
    def _prepare(sample):
        # Build the [adata, coordinates] pair for a single sample.
        lipids, coords = fetch_data(sample)
        coords = MinMaxScaler().fit_transform(coords) * loc_range

        adata = sc.AnnData(lipids, dtype="float64")
        adata.var_names = lipid_names
        adata = extract_best_features(adata, lipid_N_reduced_list)
        adata = normalize(
            adata,
            logtrans_input=False,
            normalize_input=False,
            total_ion_current=True,
            size_factors=False,
        )
        return [adata, coords]

    return {sample: _prepare(sample) for sample in names}

def normalize_across_brains(data_dict):
    """Log-transform and scale all samples jointly, then split them back apart.

    The ``X`` matrices of every entry are stacked row-wise, ``sc.pp.log1p`` and
    ``sc.pp.scale`` are applied to the stacked matrix, and the result is cut
    back into per-sample pieces using the original row counts.

    Args:
        data_dict: Dict mapping a sample key to ``[adata, coordinates]``.

    Returns:
        Dict with the same keys mapping to ``[new_adata, coordinates]``.
        ``new_adata.X`` holds the transformed rows of that sample and
        ``new_adata.raw`` the corresponding rows from before the transform.
        The new AnnData objects are built from bare arrays, so they do not
        carry over names or annotations from the inputs.
    """
    keys = list(data_dict)
    row_counts = [data_dict[key][0].shape[0] for key in keys]
    stacked = np.concatenate([data_dict[key][0].X for key in keys], axis=0)

    pooled = sc.AnnData(stacked)
    # Snapshot taken before log1p/scale so the untransformed values stay available.
    pooled.raw = pooled.copy()
    sc.pp.log1p(pooled)
    sc.pp.scale(pooled)

    normalized = {}
    start = 0
    for key, n_rows in zip(keys, row_counts):
        stop = start + n_rows
        piece = sc.AnnData(pooled.X[start:stop], dtype="float64")
        piece.raw = sc.AnnData(pooled.raw.X[start:stop], dtype="float64")
        normalized[key] = [piece, data_dict[key][1]]
        start = stop
    return normalized


In [3]:
def train(datasets, maxiter, samp):
    """Build a SPAVAE model and train it, or restore it from a saved checkpoint.

    The checkpoint path is ``'checkpoints/' + str(samp) + '.pt'``. If that file
    exists the weights are loaded from it; otherwise the model is trained on
    the first entry of ``datasets`` and saved to that path.

    Args:
        datasets: Sequence of ``[adata, coordinates]`` pairs. The input
            dimension is taken from ``datasets[0][0].n_vars`` and only
            ``datasets[0]`` is passed to ``train_model``.
        maxiter: Forwarded to ``SPAVAE.train_model`` as ``maxiter``.
        samp: Identifier used to name the checkpoint; it is passed through
            ``str()``, so a list such as ``['a']`` yields ``"['a'].pt"``.

    Returns:
        The trained or restored ``SPAVAE`` instance.
    """
    # Regular grid over [0, 1]^2 scaled to [0, loc_range]; eps makes the upper
    # bound inclusive so the grid has inducing_point_steps + 1 points per axis.
    eps = 1e-5
    grid_step = 1. / inducing_point_steps
    initial_inducing_points = np.mgrid[0:(1+eps):grid_step, 0:(1+eps):grid_step].reshape(2, -1).T * loc_range
    print(initial_inducing_points.shape)

    # Model hyperparameters, passed to SPAVAE unchanged.
    noise = 0
    dropoutE = 0
    dropoutD = 0
    encoder_layers = [256, 128, 64]
    GP_dim = 1
    Normal_dim = 9
    decoder_layers = [64, 128, 256]
    dynamicVAE = True
    init_beta = 15
    min_beta = 10
    max_beta = 25
    KL_loss = 1
    fix_inducing_points = True
    fixed_gp_params = True
    kernel_scale = 20
    device = 'cpu'

    # NOTE(review): N_train counts observations across ALL datasets while
    # train_model below only receives datasets[0] - confirm this is intended
    # when more than one dataset is passed.
    total_nt = sum(dset[0].n_obs for dset in datasets)
    vae = SPAVAE(input_dim=datasets[0][0].n_vars,
                 GP_dim=GP_dim, Normal_dim=Normal_dim,
                 encoder_layers=encoder_layers, decoder_layers=decoder_layers,
                 noise=noise, encoder_dropout=dropoutE, decoder_dropout=dropoutD,
                 fixed_inducing_points=fix_inducing_points, initial_inducing_points=initial_inducing_points,
                 fixed_gp_params=fixed_gp_params, kernel_scale=kernel_scale, N_train=total_nt, KL_loss=KL_loss,
                 dynamicVAE=dynamicVAE, init_beta=init_beta, min_beta=min_beta, max_beta=max_beta,
                 dtype=torch.float64, device=device)

    # Optimisation settings, passed to train_model unchanged.
    num_samples = 1
    lr = 1e-3
    weight_decay = 1e-6
    batch_size = 256
    model_file = 'checkpoints/' + str(samp) + '.pt'
    if not os.path.isfile(model_file):
        # Create the checkpoint directory up front so a finished training run
        # is not lost because the save location does not exist.
        os.makedirs(os.path.dirname(model_file), exist_ok=True)
        t0 = time()
        vae.train_model(pos=datasets[0][1], ncounts=datasets[0][0].X,
                        lr=lr, weight_decay=weight_decay, batch_size=batch_size, num_samples=num_samples,
                        maxiter=maxiter, save_model=True, model_weights=model_file, print_kernel_scale=False)
        print('Training time: %d seconds.' % int(time() - t0))
    else:
        vae.load_model(model_file)
    return vae

In [4]:
def refine(sample_id, pred, dis, shape=4):
    """Smooth cluster labels using each sample's nearest neighbours.

    For every sample, the ``shape + 1`` entries with the smallest distance in
    its row of ``dis`` are collected (this normally includes the sample itself
    at distance 0). If the sample's own label occurs at most ``shape / 2``
    times among them, it is replaced by the most frequent label in that group;
    otherwise it is kept.

    Args:
        sample_id: Sequence of unique sample identifiers, aligned with ``pred``
            and with the rows/columns of ``dis``.
        pred: Sequence of cluster labels, one per sample.
        dis: Square pairwise distance matrix of shape (n_samples, n_samples).
        shape: Number of neighbours (besides the sample itself) that vote.

    Returns:
        ``np.ndarray`` with the refined label of every sample, in input order.
    """
    refined_pred = []
    pred = pd.DataFrame({"pred": pred}, index=sample_id)
    dis_df = pd.DataFrame(dis, index=sample_id, columns=sample_id)
    num_nbs = shape
    for i, index in enumerate(sample_id):
        # Closest num_nbs + 1 entries of this sample's distance row.
        dis_tmp = dis_df.loc[index, :].sort_values()
        nbs = dis_tmp.iloc[0:(num_nbs + 1)]
        nbs_pred = pred.loc[nbs.index, "pred"]
        self_pred = pred.loc[index, "pred"]
        v_c = nbs_pred.value_counts()
        if v_c.loc[self_pred] <= num_nbs / 2:
            # Own label is in the minority: adopt the most frequent one.
            refined_pred.append(v_c.idxmax())
        else:
            refined_pred.append(self_pred)
        if (i + 1) % 1000 == 0:
            print("Processed", i + 1, "lines")
    return np.array(refined_pred)

def leiden_clustering(data, res=0.5, n_neighbors=10, seed=123):
    """Cluster the rows of ``data`` with the Leiden algorithm on a kNN graph.

    Args:
        data: Array-like of shape (n_samples, n_features) accepted by
            ``kneighbors_graph``.
        res: Resolution parameter passed to ``RBConfigurationVertexPartition``.
        n_neighbors: Number of neighbours per sample in the kNN graph.
        seed: Seed for the Leiden optimiser so repeated runs return the same
            labels. Pass ``None`` to let leidenalg pick a random seed.

    Returns:
        List with one integer cluster id per row of ``data``.
    """
    adjacency_matrix = kneighbors_graph(data, n_neighbors=n_neighbors, mode='connectivity', include_self=False)
    sources, targets = adjacency_matrix.nonzero()

    # One undirected edge is added per non-zero entry of the kNN matrix.
    g = ig.Graph(directed=False)
    g.add_vertices(adjacency_matrix.shape[0])
    g.add_edges(list(zip(sources, targets)))
    # The weights are stored on the edges but not handed to find_partition.
    g.es['weight'] = adjacency_matrix.data

    partition = leidenalg.find_partition(
        g,
        leidenalg.RBConfigurationVertexPartition,
        resolution_parameter=res,
        seed=seed,
    )
    return partition.membership

def get_colors_for_labels(labels):
    """Map every label to a colour of matplotlib's ``tab20`` palette.

    Distinct labels are ordered with ``np.unique`` and assigned palette entries
    in that order. When there are more distinct labels than palette entries the
    palette is reused from its start, so colours repeat cyclically.

    Args:
        labels: Sequence of hashable, sortable labels.

    Returns:
        List with one RGBA tuple per element of ``labels``, in input order.
    """
    unique_labels = np.unique(labels)
    cmap = plt.get_cmap('tab20')
    # Wrap the index explicitly: calling the colormap with an integer beyond
    # its size clamps to the last entry instead of cycling.
    n_colors = cmap.N
    label_to_color = {label: cmap(i % n_colors) for i, label in enumerate(unique_labels)}
    return [label_to_color[label] for label in labels]

def plot_embeddings_2d(latent, colors, neighbors):
    """Project ``latent`` to 2-D with UMAP and show it as a scatter plot.

    Args:
        latent: Array-like of shape (n_samples, n_features) to embed.
        colors: One colour per sample, used as-is for the markers.
        neighbors: ``n_neighbors`` setting for UMAP.
    """
    umap_res = umap.UMAP(random_state=123, n_neighbors=neighbors).fit_transform(latent)
    spaVAE_umap = pd.DataFrame(umap_res, columns=['X1', 'X2'])
    plt.figure(figsize=(5, 5))
    # Explicit colours are supplied, so no colormap is passed: matplotlib
    # ignores (and warns about) cmap when there is no data to colour-map.
    plt.scatter(x=spaVAE_umap['X1'], y=spaVAE_umap['X2'], color=colors, s=1)
    plt.show()

In [7]:
def get_clusters(datasets, model, batch_ids, res=0.8, n_neigh=5):
    """Compute latent embeddings for the first dataset and cluster them.

    Args:
        datasets: Sequence of ``[adata, coordinates]`` pairs; only
            ``datasets[0]`` is read.
        model: Object providing ``batching_latent_samples(X=..., Y=...)``.
        batch_ids: Accepted for interface compatibility; not used here.
        res: Resolution forwarded to ``leiden_clustering``.
        n_neigh: Neighbour count forwarded to ``leiden_clustering``.

    Returns:
        Tuple ``(final_latent, labels)`` where ``labels`` is the Leiden
        membership as a NumPy array.
    """
    first_adata = datasets[0][0]
    first_coords = datasets[0][1]
    final_latent = model.batching_latent_samples(X=first_coords, Y=first_adata.X)
    membership = leiden_clustering(final_latent, res=res, n_neighbors=n_neigh)
    return final_latent, np.array(membership)

def plot_results(datasets_comb, final_latent, clusters):
    """Show the clusters in tissue space and in a UMAP of the latent space.

    One spatial scatter plot is drawn per entry of ``datasets_comb``; the
    colour list is consumed in consecutive chunks, one chunk per dataset,
    sized by that dataset's number of coordinates. A UMAP plot of
    ``final_latent`` with the same colours follows.

    Args:
        datasets_comb: Sequence of ``[adata, coordinates]`` pairs.
        final_latent: Latent embedding passed on to ``plot_embeddings_2d``.
        clusters: Cluster label per sample, in the order the datasets are laid out.
    """
    segmentation_colors = get_colors_for_labels(clusters)

    offset = 0
    for entry in datasets_comb:
        coords = entry[1]
        n_points = len(coords)
        plt.figure(figsize=(5, 5))
        plt.scatter(coords[:, 0], coords[:, 1], c=segmentation_colors[offset:offset + n_points], s=5)
        plt.xlabel("X Coordinates")
        plt.ylabel("Y Coordinates")
        plt.show()
        offset += n_points

    plot_embeddings_2d(final_latent, segmentation_colors, 10)

In [None]:
# Every inner list is one group that is processed (and trained) on its own.
brain_names = [
    ['6moWT_0'], ['6moWT_1'], ['6moWT_2'],
    ['6moAD_0'], ['6moAD_1'], ['6moAD_2'],
    ['12moWT_0'], ['12moWT_1'], ['12moWT_2'],
    ['12moAD_0'], ['12moAD_1'], ['12moAD_2'],
    ['22moWT_0'], ['22moWT_1'], ['22moWT_2'],
    ['22moAD_0'], ['22moAD_1'], ['22moAD_2'],
]
# One batch id per group; all groups share batch 0.
batch_ids = [0] * 18

for i, (group_brain_names, group_batch_ids) in enumerate(zip(brain_names, batch_ids)):
    preprocessed = fetch_preprocess_data(group_brain_names)
    data_arr = list(normalize_across_brains(preprocessed).values())
    # samp receives the group list itself, so the checkpoint name is built from its str().
    model = train(data_arr, samp=group_brain_names, maxiter=50)
    latent, clusters = get_clusters(data_arr, model, group_batch_ids)
    plot_results(data_arr, latent, clusters)
