In [13]:
import warnings
warnings.resetwarnings()

import scprep
import matplotlib.pyplot as plt
import gc
    
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch_geometric
from torch.nn.functional import relu, softplus
from torch.nn import Linear, Module, Dropout, MSELoss, CrossEntropyLoss, BatchNorm1d

from torch_geometric.nn import GCNConv, GATConv, GraphNorm
from torch_geometric.data import Data
from torch_sparse import SparseTensor
from sklearn.metrics.pairwise import pairwise_kernels
from sklearn.metrics import adjusted_rand_score, adjusted_mutual_info_score, normalized_mutual_info_score
from sklearn.cluster import KMeans, SpectralClustering
from sklearn.cluster import SpectralClustering

import pandas as pd
import numpy as np
import random
import optuna

import os

# NOTE(review): this is an XLA/JAX setting; nothing visible in this notebook
# uses JAX, so it likely has no effect on the PyTorch code below.
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
gpu_index = 0  # CUDA device ordinal to use when a GPU is available
if torch.cuda.is_available():
    device = torch.device("cuda:{}".format(gpu_index))
else:
    device = torch.device("cpu")

from tqdm import tqdm

from sklearn.metrics import mean_squared_error as mse

In [14]:
def get_cluster_metrics(pred, labels):
    """Best clustering agreement between `pred` and the ground-truth `labels`.

    The rows of `pred` are clustered once with KMeans and once per affinity
    ('cosine', 'linear', 'poly') with SpectralClustering. Every run asks for as
    many clusters as `labels` has distinct values. Each attempt is scored with
    ARI, AMI and NMI, and the maximum of each metric over all attempts is
    returned.

    Parameters
    ----------
    pred : array-like, shape (n_samples, n_features)
        Representation to cluster (the model reconstruction in this notebook).
    labels : array-like, length n_samples
        Ground-truth labels.

    Returns
    -------
    tuple of (float, float, float)
        (max ARI, max AMI, max NMI) over all clustering attempts.
    """
    n_clusters = len(np.unique(labels))

    def _scores(assignment):
        # (ARI, AMI, NMI) of one cluster assignment against the ground truth.
        return (
            adjusted_rand_score(labels, assignment),
            adjusted_mutual_info_score(labels, assignment),
            normalized_mutual_info_score(labels, assignment),
        )

    attempts = [_scores(KMeans(n_clusters=n_clusters, random_state=42).fit_predict(pred))]

    for affinity in ('cosine', 'linear', 'poly'):
        # Any warning raised while fitting or scoring a spectral clustering is
        # escalated to an error, so that run counts as a failed attempt (0).
        # catch_warnings() restores the previous warning filters on exit
        # instead of wiping every filter with resetwarnings().
        with warnings.catch_warnings():
            warnings.simplefilter("error")
            try:
                assignment = SpectralClustering(
                    n_clusters=n_clusters,
                    random_state=42,
                    affinity=affinity,
                ).fit_predict(pred)
                # All three scores are appended together, so the lists can
                # never get out of step when one metric fails.
                attempts.append(_scores(assignment))
            except Exception:
                # Narrowed from a bare `except:` so KeyboardInterrupt and
                # SystemExit still stop the run.
                attempts.append((0, 0, 0))

    best_ari, best_ami, best_nmi = (max(column) for column in zip(*attempts))
    return best_ari, best_ami, best_nmi

In [15]:
def get_topX(X, q=85):
    """Zero out every entry of `X` that is not strictly above its `q`-th percentile.

    The threshold is one global percentile over all entries of `X` (not per
    row). Roughly the top ``100 - q`` percent of values survive, depending on ties.

    Parameters
    ----------
    X : numpy.ndarray
        Matrix to sparsify (a pairwise kernel matrix in this notebook).
    q : float, default 85
        Percentile in [0, 100] used as the strict lower cut-off.

    Returns
    -------
    numpy.ndarray
        Same shape as `X`; entries ``<=`` the threshold are set to 0, the
        rest keep their original value.
    """
    threshold = np.percentile(X, q)
    return X * np.array(X > threshold, dtype=int)

In [16]:
def get_adj(x):
    """Build an unweighted sparse adjacency from the non-zero pattern of `x`.

    Parameters
    ----------
    x : numpy.ndarray, shape (n, n)
        Dense matrix whose non-zero entries become edges. It is assumed to be
        square: both sparse dimensions are taken from ``x.shape[0]``.

    Returns
    -------
    torch_sparse.SparseTensor
        ``n x n`` sparse tensor on the global ``device`` with one entry per
        non-zero of `x`. Only the pattern is stored; edge values are dropped.
    """
    # Compute the non-zero indices once; the dense matrix can be large.
    row_idx, col_idx = x.nonzero()
    adj = SparseTensor(
        row=torch.as_tensor(row_idx, dtype=torch.long),
        col=torch.as_tensor(col_idx, dtype=torch.long),
        sparse_sizes=(x.shape[0], x.shape[0]),
    ).to(device)
    return adj

In [17]:
def get_data(X, metric='linear'):
    """Turn a feature table into GNN inputs: a feature tensor plus a graph.

    Rows of `X` are nodes. Pairwise kernel similarities between rows (sklearn
    ``pairwise_kernels`` with `metric`) are sparsified with ``get_topX``, and
    the surviving entries become the edges of an adjacency built by ``get_adj``.

    Returns
    -------
    (torch.Tensor, SparseTensor)
        ``X.values`` as float32 on the global ``device``, and the adjacency.
    """
    similarity = pairwise_kernels(X, metric=metric)
    strongest = get_topX(similarity)
    features = torch.tensor(X.values, dtype=torch.float).to(device)
    graph = get_adj(strongest)
    return features, graph

In [18]:
def get_data_for_i(i, data_dir='../data', min_frac=0.05):
    """Load, filter and normalise the expression matrix of dataset `i`.

    Reads ``{data_dir}/{i}/data.csv.gz`` with the first column as the index.
    It then:

    1. Keeps columns whose number of non-zero rows exceeds
       ``int(n_rows * min_frac)``.
    2. Keeps rows whose number of non-zero columns exceeds
       ``int(n_cols * min_frac)``. Both masks are computed on the unfiltered table.
    3. Applies ``log(1 + x)``, then scprep library-size normalisation, then
       scprep's sqrt transform.

    Parameters
    ----------
    i : str
        Dataset directory name under `data_dir`.
    data_dir : str, default '../data'
        Root directory holding one sub-directory per dataset.
    min_frac : float, default 0.05
        Minimum fraction of non-zero entries for a row/column to be kept.

    Returns
    -------
    df_norm : pandas.DataFrame
        Normalised matrix (used downstream via ``.values`` / ``.T``).
    labels : pandas.Index
        Row index of the filtered matrix; used as ground-truth labels.
    data : torch.Tensor
        ``df_norm`` as float32 on the global ``device``.
    """
    raw = pd.read_csv('{}/{}/data.csv.gz'.format(data_dir, i), index_col=0)

    # Explicit axes: np.sum(DataFrame) without an axis triggers a pandas 2.x
    # FutureWarning and will eventually reduce to a scalar, breaking `.loc`.
    signs = np.sign(raw)
    keep_cols = signs.sum(axis=0) > int(raw.shape[0] * min_frac)
    keep_rows = signs.sum(axis=1) > int(raw.shape[1] * min_frac)

    filtered = np.log1p(raw.loc[keep_rows, keep_cols])
    labels = filtered.index

    df_norm = scprep.normalize.library_size_normalize(filtered)
    df_norm = scprep.transform.sqrt(df_norm)

    data = torch.tensor(df_norm.values, dtype=torch.float).to(device)
    return df_norm, labels, data

In [19]:
def ZINBLoss(y_true, y_pred, theta, pi, eps=1e-10):
    """
    Compute the zero-inflated negative binomial (ZINB) negative log-likelihood.

    y_true: Ground truth data (non-negative).
    y_pred: Predicted NB mean (mu) from the model, > 0.
    theta: NB dispersion parameter, > 0.
    pi: Zero-inflation probability, in [0, 1].
    eps: Small constant to prevent log(0).

    Entries with ``y_true < eps`` count as zeros and cost
    ``-log(pi + (1 - pi) * NB(0; mu, theta))``. Every other entry costs
    ``-log(1 - pi) - log NB(y; mu, theta)``. Returns the sum over all entries
    as a scalar tensor, differentiable w.r.t. y_pred, theta and pi.
    """
    log_theta = torch.log(theta + eps)
    log_theta_mu = torch.log(theta + y_pred + eps)

    # -log NB(y; mu, theta)
    #   = lgamma(theta) + lgamma(y + 1) - lgamma(y + theta)
    #     - theta*log(theta) + (theta + y)*log(theta + mu) - y*log(mu)
    nb_nll = (
        torch.lgamma(theta) + torch.lgamma(y_true + 1) - torch.lgamma(y_true + theta)
        - theta * log_theta
        + (theta + y_true) * log_theta_mu
        - y_true * torch.log(y_pred + eps)
    )
    # A non-zero observation must come from the NB component (probability 1 - pi).
    nonzero_nll = nb_nll - torch.log(1.0 - pi + eps)

    # NB(0) = (theta / (theta + mu)) ** theta, evaluated in log space for stability.
    nb_zero = torch.exp(theta * (log_theta - log_theta_mu))
    zero_nll = -torch.log(pi + (1.0 - pi) * nb_zero + eps)

    # No rounding: torch.round has zero gradient and would silence this term.
    return torch.sum(torch.where(y_true < eps, zero_nll, nonzero_nll))

In [20]:
def compute_loss(x_original, x_recon, z_mean, z_dropout, z_dispersion, alpha):
    """
    Compute the combined loss: alpha * ZINB loss + (1 - alpha) * MSE loss.

    Parameters:
    - x_original: Original data matrix; target of both loss terms.
    - x_recon: Reconstructed matrix from the model, same shape as x_original.
    - z_mean: ZINB mean, passed to ZINBLoss as `y_pred`.
    - z_dropout: Zero-inflation probability, passed to ZINBLoss as `pi`.
    - z_dispersion: ZINB dispersion, passed to ZINBLoss as `theta`.
    - alpha: Weight of the ZINB term; the MSE term is weighted by (1 - alpha).

    Returns:
    - total_loss: Scalar tensor with the weighted sum of both terms.

    Notes:
    - Pass z_dropout / z_dispersion by keyword. Swapping them silently feeds
      the dispersion in as a probability and produces a wrong loss.
    - ZINBLoss sums over all entries while the MSE term is a mean, so the
      two terms live on very different scales for the same alpha.
    """
    zinb_loss = ZINBLoss(x_original, z_mean, z_dispersion, z_dropout)
    mse_loss = F.mse_loss(x_recon, x_original)
    return alpha * zinb_loss + (1 - alpha) * mse_loss

In [31]:
class VGAE(Module):
    """Graph autoencoder that outputs ZINB parameters and a reconstruction.

    NOTE(review): despite the name there is no variational part (no sampling,
    no reparameterisation, no KL term), so this is a deterministic graph
    autoencoder.

    Parameters
    ----------
    input_dim : number of input features (columns of ``x``).
    hidden0 : feature size of ``batch_norm2`` (applied to ``x_t``) and of the
        unused ``gcn5`` / ``gcn8`` / ``graph_norm8`` layers.
    hidden1 : width of the first GCN layer.
    hidden2 : width of the decoder's hidden Linear layer.
    hidden3 : width of the unused ``gcn5`` / ``graph_norm5`` layers.
    dropout1 : dropout rate after the first encoder layer.
    dropout2 : dropout rate inside the decoder.
    dropout4 : rate of ``self.dropout4``, which is never applied.
    """
    def __init__(
        self, input_dim, hidden0, hidden1, hidden2, 
        hidden3, 
        dropout1, dropout2, 
        dropout4
    ):
        super(VGAE, self).__init__()
        # Keep the construction order below: layers draw their initial
        # parameters from the global RNG in the order they are created.
        
        self.dropout1 = nn.Dropout(dropout1)
        self.dropout2 = nn.Dropout(dropout2)
        self.dropout4 = nn.Dropout(dropout4)  # NOTE(review): never used
        
        # Encoder: one GCNConv + GraphNorm, then three GCNConv heads producing
        # the ZINB mean, dropout and dispersion. The "gat" prefix is
        # misleading: these are GCNConv layers, not GATConv.
        self.gat1 = GCNConv(input_dim, hidden1)
        self.gn1 = GraphNorm(hidden1)  # GraphNorm (not BatchNorm) after the first GCN layer
        self.gat2_mean = GCNConv(hidden1, input_dim)
        self.gat2_dropout = GCNConv(hidden1, input_dim)
        self.gat2_dispersion = GCNConv(hidden1, input_dim)

        # Decoder: Linear -> BatchNorm1d -> ReLU -> dropout -> Linear -> sigmoid
        self.fc1 = Linear(input_dim, hidden2)
        self.bn2 = BatchNorm1d(hidden2)  # Batch normalization after first linear layer
        self.fc2 = Linear(hidden2, input_dim)
        
        # gene_recon
        # NOTE(review): graph_norm5, graph_norm8, gcn5 and gcn8 are not used by
        # encode/decode/forward. They still register parameters that the
        # optimizer receives but that never get a gradient.
        self.graph_norm5 = GraphNorm(hidden3)
        self.graph_norm8 = GraphNorm(hidden0)
        
        self.gcn5 = GCNConv(hidden0, hidden3)
        self.gcn8 = GCNConv(hidden3, hidden0)

        # Normalisations of the raw inputs that forward() adds to the reconstruction.
        self.batch_norm1 = BatchNorm1d(input_dim)
        self.batch_norm2 = BatchNorm1d(hidden0)
        
    def encode(self, x, adj):
        """Encode node features into ZINB parameters.

        Returns (z_mean, z_dropout, z_dispersion), each with input_dim
        columns: exp(.) > 0, sigmoid(.) in (0, 1) and exp(.) > 0 respectively.
        The first layer propagates over ``adj``; the three heads over ``adj.t()``.
        """
        x = relu(self.gn1(self.gat1(x, adj)))  # GCN -> GraphNorm -> ReLU
        x = self.dropout1(x)
        
        z_mean = torch.exp(self.gat2_mean(x, adj.t()))
        z_dropout = torch.sigmoid(self.gat2_dropout(x, adj.t()))
        z_dispersion = torch.exp(self.gat2_dispersion(x, adj.t()))
        return z_mean, z_dropout, z_dispersion

    def decode(self, z):
        """Map ``z`` (input_dim columns) through the MLP decoder; output in (0, 1)."""
        z = relu(self.bn2(self.fc1(z)))  # Linear -> BatchNorm1d -> ReLU
        z = self.dropout2(z)
        return torch.sigmoid(self.fc2(z))
    
    def forward(self, x, adj, x_t, adj_t, ):
        """Run the encoder and decoder.

        ``adj`` is transposed before it reaches ``encode``, so the first GCN
        layer sees ``adj.t()`` and the heads see ``adj`` again. For a symmetric
        adjacency both are the same. ``adj_t`` is unused.

        ``x_t`` must have shape (input_dim, hidden0) so that
        ``batch_norm2(x_t).T`` matches ``x`` — i.e. presumably the transpose
        of ``x``; TODO confirm for other callers.

        NOTE(review): the reconstruction adds batch-normalised copies of the
        inputs themselves, so part of it can match ``x`` without going through
        the encoder/decoder.

        Returns (x_recon, z_mean, z_dropout, z_dispersion).
        """
        z_mean, z_dropout, z_dispersion = self.encode(x, adj.t())
        x_recon = self.decode(z_mean) + self.batch_norm1(x) + self.batch_norm2(x_t).T
        return x_recon, z_mean, z_dropout, z_dispersion

In [33]:
# NOTE(review): `dir_list` (dataset directory names under ../data) is not
# defined in any visible cell; it must be set before running this cell.
res = []

# Hyper-parameters shared by every dataset.
alpha = 0.05
dropout1 = 0.2
dropout2 = 0.4
epochs = 100
hidden1 = 128
hidden2 = 1024
lr = 0.0001
optimizer_name = 'Adam'
# VGAE's constructor requires these, but the layers they configure (gcn5,
# gcn8, graph_norm5, dropout4) are never used in its forward pass. They are
# not set in any visible cell, so define them here to run on a fresh kernel.
hidden3 = 128
dropout4 = 0.0
SEED = 42

for i in tqdm(dir_list):
    # Re-seed per dataset so each result is reproducible and does not depend
    # on the order of dir_list.
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)

    df_norm, labels, data = get_data_for_i(i)
    del data  # same values as `x` below; avoid holding two GPU copies
    x, adj = get_data(df_norm)
    x_t, adj_t = get_data(df_norm.T)
    torch.cuda.empty_cache()

    input_dim = df_norm.shape[1]
    hidden0 = df_norm.shape[0]

    model = VGAE(
        input_dim, hidden0, hidden1, hidden2,
        hidden3,
        dropout1, dropout2,
        dropout4,
    ).to(device)
    optimizer = getattr(torch.optim, optimizer_name)(model.parameters(), lr=lr)

    model.train()
    losses = []
    for _ in tqdm(range(epochs), leave=False):
        x_recon, z_mean, z_dropout, z_dispersion = model(x, adj, x_t, adj_t)
        # Pass by keyword: positionally compute_loss expects z_dropout before
        # z_dispersion, and mixing them up feeds the exp()-activated
        # dispersion in as the zero-inflation probability.
        loss = compute_loss(
            x, x_recon,
            z_mean=z_mean, z_dropout=z_dropout, z_dispersion=z_dispersion,
            alpha=alpha,
        )
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses.append(loss.item())

    # Cluster a deterministic reconstruction (eval mode: no dropout, BatchNorm
    # running statistics) computed without autograd. The last training-step
    # output had dropout applied and kept the autograd graph alive.
    model.eval()
    with torch.no_grad():
        x_recon = model(x, adj, x_t, adj_t)[0]
    pred = x_recon.cpu().numpy()

    del model, optimizer, loss, x_recon, z_mean, z_dropout, z_dispersion
    torch.cuda.empty_cache()

    res.append(get_cluster_metrics(pred, labels))

    del df_norm, labels, pred, x, adj, x_t, adj_t
    gc.collect()
    torch.cuda.empty_cache()

In [34]:
# One row per dataset directory; each column holds the best score over all
# clustering attempts made by get_cluster_metrics.
scores_by_dataset = pd.DataFrame(
    data=res,
    index=dir_list,
    columns=['ARI', 'AMI', 'NMI'],
)
scores_by_dataset

[I 2023-09-13 12:50:14,674] Using an existing study with name 'cell+x' instead of creating a new one.
  0%|          | 0/15 [00:00<?, ?it/s]
  0%|          | 0/900 [00:00<?, ?it/s][A
  0%|          | 1/900 [00:00<06:45,  2.22it/s][A
  0%|          | 2/900 [00:00<06:18,  2.37it/s][A
  0%|          | 3/900 [00:01<06:08,  2.43it/s][A
  0%|          | 4/900 [00:01<06:04,  2.46it/s][A
  1%|          | 5/900 [00:02<06:01,  2.48it/s][A
  1%|          | 6/900 [00:02<05:58,  2.49it/s][A
  1%|          | 7/900 [00:02<05:56,  2.50it/s][A
  1%|          | 8/900 [00:03<05:55,  2.51it/s][A
  1%|          | 9/900 [00:03<05:54,  2.51it/s][A
  1%|          | 10/900 [00:04<05:53,  2.52it/s][A
  1%|          | 11/900 [00:04<05:54,  2.51it/s][A
  1%|▏         | 12/900 [00:04<05:53,  2.51it/s][A
  1%|▏         | 13/900 [00:05<05:53,  2.51it/s][A
  2%|▏         | 14/900 [00:05<05:53,  2.50it/s][A
  2%|▏         | 15/900 [00:06<05:54,  2.50it/s][A
  2%|▏         | 16/900 [00:06<05:54,  2.50it

KeyboardInterrupt: 

[100, 300, 500, 700, 900]