In [1]:
import warnings
warnings.resetwarnings()

import scprep
import matplotlib.pyplot as plt
import gc
    
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch_geometric
from torch.nn.functional import relu, softplus
from torch.nn import Linear, Module, Dropout, MSELoss, CrossEntropyLoss, BatchNorm1d

from torch_geometric.nn import GCNConv, GATConv, GraphNorm
from torch_geometric.data import Data
from torch_sparse import SparseTensor
from sklearn.metrics.pairwise import pairwise_kernels
from sklearn.metrics import adjusted_rand_score, adjusted_mutual_info_score, normalized_mutual_info_score
from sklearn.cluster import KMeans, SpectralClustering
from sklearn.cluster import SpectralClustering

import pandas as pd
import numpy as np
import random
import optuna

import os

os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
device = 1
device = torch.device("cuda:{}".format(device) if torch.cuda.is_available() else "cpu")

from tqdm import tqdm

from sklearn.metrics import mean_squared_error as mse

In [2]:
def get_topX(X):
    return X * np.array(X > np.percentile(X, 85), dtype=int)

In [3]:
def get_adj(x):
    adj = SparseTensor(
        row= torch.tensor(np.array(x.nonzero()))[0], 
        col= torch.tensor(np.array(x.nonzero()))[1], 
        sparse_sizes=(x.shape[0], x.shape[0])
    ).to(device)
    return adj

In [4]:
def get_data(X, metric='linear'):
    dist = pairwise_kernels(X, metric=metric)
    dist_x = get_topX(dist)
    return torch.tensor(X.values, dtype=torch.float).to(device), get_adj(dist_x)

In [5]:
def get_data_for_i(i):
    df = pd.read_csv('../data/{}/data.csv.gz'.format(i), index_col=0)
    tmp = np.sign(df)
    cols = (np.sum(tmp) > int((df.shape[0])*0.05))
    rows = (np.sum(tmp, axis=1) > int((df.shape[1])*0.05))
    df = np.log(df.loc[rows, cols] + 1)
    df_norm = df.copy()
    df_norm = scprep.normalize.library_size_normalize(df_norm)    
    df_norm = scprep.transform.sqrt(df_norm)
    X_norm = pd.DataFrame(df_norm, columns=df.columns)
    labels = df.index
    data = torch.tensor(df_norm.values, dtype=torch.float).to(device)
    return df_norm, labels, data

In [6]:
def ZINBLoss(y_true, y_pred, theta, pi, eps=1e-10):
    """
    Compute the ZINB Loss.
    
    y_true: Ground truth data.
    y_pred: Predicted mean from the model.
    theta: Dispersion parameter.
    pi: Zero-inflation probability.
    eps: Small constant to prevent log(0).
    """
    
    # Negative Binomial Loss
    nb_terms = -torch.lgamma(y_true + theta) + torch.lgamma(y_true + 1) + torch.lgamma(theta) \
               - theta * torch.log(theta + eps) \
               + theta * torch.log(theta + y_pred + eps) \
               - y_true * torch.log(y_pred + theta + eps) \
               + y_true * torch.log(y_pred + eps)
    
    # Zero-Inflation
    zero_inflated = torch.log(pi + (1 - pi) * torch.pow(1 + y_pred / theta, -theta))
    
    result = -torch.sum(torch.log(pi + (1 - pi) * torch.pow(1 + y_pred / theta, -theta)) * (y_true < eps).float() \
                        + (1 - (y_true < eps).float()) * nb_terms)
    
    return torch.round(result, decimals=3)

In [7]:
def compute_loss(x_original, x_recon, z_mean, z_dropout, z_dispersion, alpha):
    """
    Compute the combined loss: ZINB Loss + MSE Loss.
    
    Parameters:
    - x_original: Original data matrix.
    - x_recon: Reconstructed matrix from the model.
    - z_mean, z_dropout, z_dispersion: Outputs from the model, used for ZINB Loss calculation.
    - device: Device to which tensors should be moved before computation.
    - lambda_1, lambda_2: Weights for ZINB Loss and MSE Loss respectively.
    
    Returns:
    - total_loss: Combined loss value.
    """
    
    # Compute ZINB Loss (assuming ZINBLoss is a properly defined function or class)
    zinb_loss = ZINBLoss(x_original, z_mean, z_dispersion, z_dropout)
    
    # Compute MSE Loss
    mse_loss = MSELoss()(x_recon, x_original)
    
    # Combine the losses
    total_loss = alpha * zinb_loss + (1-alpha) * mse_loss
    
    return total_loss


In [8]:
class VGAE(Module):
    def __init__(
        self, input_dim, hidden0, hidden1, hidden2, 
        # hidden3, 
        dropout1, dropout2, 
        # dropout4
    ):
        super(VGAE, self).__init__()
        
        self.dropout1 = nn.Dropout(dropout1)
        self.dropout2 = nn.Dropout(dropout2)
        # self.dropout4 = nn.Dropout(dropout4)
        
        # Encoder with 2 gat layers
        self.gat1 = GCNConv(input_dim, hidden1)
        self.gn1 = GraphNorm(hidden1)  # Batch normalization after first gat layer
        self.gat2_mean = GCNConv(hidden1, input_dim)
        self.gat2_dropout = GCNConv(hidden1, input_dim)
        self.gat2_dispersion = GCNConv(hidden1, input_dim)

        # Decoder with 2 Linear layers
        self.fc1 = Linear(input_dim, hidden2)
        self.bn2 = BatchNorm1d(hidden2)  # Batch normalization after first linear layer
        self.fc2 = Linear(hidden2, input_dim)
        
        self.batch_norm1 = BatchNorm1d(input_dim)
        self.batch_norm2 = BatchNorm1d(hidden0)
        
    def encode(self, x, adj):
        x = relu(self.gn1(self.gat1(x, adj)))  # Apply ReLU and GraphNorm
        x = self.dropout1(x)
        
        z_mean = torch.exp(self.gat2_mean(x, adj.t()))
        z_dropout = torch.sigmoid(self.gat2_dropout(x, adj.t()))
        z_dispersion = torch.exp(self.gat2_dispersion(x, adj.t()))
        return z_mean, z_dropout, z_dispersion

    def decode(self, z):
        z = relu(self.bn2(self.fc1(z)))  # Apply ReLU and BatchNorm
        z = self.dropout2(z)
        return relu(self.fc2(z))

    def forward(self, x, adj, x_t, adj_t, ):
        z_mean, z_dropout, z_dispersion = self.encode(x, adj.t())
        x_recon = self.decode(z_mean) + self.batch_norm1(x) + self.batch_norm2(x_t).T
        return x_recon, z_mean, z_dropout, z_dispersion


In [9]:
res = []

alpha=0.05
dropout1=0.2
dropout2=0.4
epochs=100
hidden1=128
hidden2=1024
lr=0.0001

df_norm, labels, data = get_data_for_i('brosens')
x, adj = get_data(df_norm)
x_t, adj_t = get_data(df_norm.T)
torch.cuda.empty_cache()

input_dim = df_norm.shape[1]
hidden0 = df_norm.shape[0]

model = VGAE(input_dim, hidden0, hidden1, hidden2, 
             dropout1, dropout2, 
#                  dropout4
            ).to(device)
optimizer_name = 'Adam'
optimizer = getattr(torch.optim, optimizer_name)(
    model.parameters(), 
    lr=lr, 
)

losses = []
for epoch in tqdm(range(epochs)): 
    # Forward pass
    x_recon, z_mean, z_dropout, z_dispersion = model(x, adj, x_t, adj_t)

    # Compute the ZINB Loss using the outputs from the model
    loss = compute_loss(x, x_recon, z_mean, z_dispersion, z_dropout, alpha).to(device)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step() 

    losses.append(loss.item())

pred = x_recon.cpu().detach().numpy()

100%|██████████| 100/100 [00:38<00:00,  2.57it/s]


In [10]:
pred

array([[ 1.9266362 , -0.8055201 ,  3.1580687 , ..., -1.2080903 ,
         1.3530047 , -0.46975672],
       [ 1.2539519 , -0.43764263, -0.49765468, ...,  3.139907  ,
         1.2629893 , -0.8915373 ],
       [ 2.9251595 , -0.66047966, -1.0427344 , ..., -1.1094288 ,
         2.0136003 ,  2.9641027 ],
       ...,
       [ 0.05846319, -0.6997969 , -0.4236164 , ...,  6.136872  ,
         0.03928235, -0.17400971],
       [-0.24071535, -0.46588382, -0.0536544 , ..., -0.12295473,
        -0.43186972, -0.1757173 ],
       [ 5.253864  , -0.03770646, -0.07805464, ...,  6.4354267 ,
        -0.6020671 , -0.03226513]], dtype=float32)

In [12]:
pd.DataFrame(pred).to_csv('result/scVGAE.csv')