In [1]:
import warnings
warnings.resetwarnings()

import scprep
import matplotlib.pyplot as plt
import gc
    
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch_geometric
from torch.nn.functional import relu, softplus
from torch.nn import Linear, Module, Dropout, MSELoss, CrossEntropyLoss, BatchNorm1d

from torch_geometric.nn import GCNConv, GATConv, GraphNorm
from torch_geometric.data import Data
from torch_sparse import SparseTensor
from sklearn.metrics.pairwise import pairwise_kernels

import pandas as pd
import numpy as np
import random
import optuna

import os

os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
device = 0
device = torch.device("cuda:{}".format(device) if torch.cuda.is_available() else "cpu")

from tqdm import tqdm

from sklearn.metrics import mean_squared_error as mse

In [2]:
def get_topX(X):
    return X * np.array(X > np.percentile(X, 85), dtype=int)

In [3]:
def get_adj(x):
    adj = SparseTensor(
        row= torch.tensor(np.array(x.nonzero()))[0], 
        col= torch.tensor(np.array(x.nonzero()))[1], 
        sparse_sizes=(x.shape[0], x.shape[0])
    ).to(device)
    return adj

In [4]:
def get_data(X, metric='linear'):
    dist = pairwise_kernels(X, metric=metric)
    dist_x = get_topX(dist)
    return torch.tensor(X.values, dtype=torch.float).to(device), get_adj(dist_x)

In [5]:
def get_data_for_i(i):
    original_ = pd.read_csv('simulation/data.csv', index_col=0)
    df_ = pd.read_csv('simulation/drp_{}0.csv'.format(i), index_col=0)
    df_.index = [int(i) for i in df_.index]
    df_.columns = [int(i) for i in df_.columns]

    original_.columns = df_.columns
    original_.index = df_.index

    n = original_.size
    original_val = original_.values.copy()
    t = list(np.ndindex(original_.shape))
    random.Random(42).shuffle(t)

    mask = t[:int(len(t)/10 * i)]

    thr = np.sum(np.sign(df_)) > 0
    original_ = original_.loc[:, list(thr)]
    df_ = df_.loc[:, list(thr)]

    # original = original_.values
    original = np.log(original_+1)

    # df = df_.values
    df = np.log(df_+1)

    tmp = pd.DataFrame(thr)
    remove = [int(i) for i in tmp[tmp[0] == False].index]
    mask = [i for i in mask if i[1] not in remove]
    
    x, adj = get_data(df)
    data = torch.tensor(df.values, dtype=torch.float).to(device)
    return df, data, original, mask, x, adj

In [6]:
df, data, original, mask, x, adj = get_data_for_i(1)
origin = np.array([original.loc[i] for i in mask])

In [7]:
def ZINBLoss(y_true, y_pred, theta, pi, eps=1e-10):
    """
    Compute the ZINB Loss.
    
    y_true: Ground truth data.
    y_pred: Predicted mean from the model.
    theta: Dispersion parameter.
    pi: Zero-inflation probability.
    eps: Small constant to prevent log(0).
    """
    
    # Negative Binomial Loss
    nb_terms = -torch.lgamma(y_true + theta) + torch.lgamma(y_true + 1) + torch.lgamma(theta) \
               - theta * torch.log(theta + eps) \
               + theta * torch.log(theta + y_pred + eps) \
               - y_true * torch.log(y_pred + theta + eps) \
               + y_true * torch.log(y_pred + eps)
    
    # Zero-Inflation
    zero_inflated = torch.log(pi + (1 - pi) * torch.pow(1 + y_pred / theta, -theta))
    
    result = -torch.sum(torch.log(pi + (1 - pi) * torch.pow(1 + y_pred / theta, -theta)) * (y_true < eps).float() \
                        + (1 - (y_true < eps).float()) * nb_terms)
    
    return torch.round(result, decimals=3)

In [8]:
def compute_loss(x_original, x_recon, z_mean, z_dropout, z_dispersion, alpha):
    """
    Compute the combined loss: ZINB Loss + MSE Loss.
    
    Parameters:
    - x_original: Original data matrix.
    - x_recon: Reconstructed matrix from the model.
    - z_mean, z_dropout, z_dispersion: Outputs from the model, used for ZINB Loss calculation.
    - device: Device to which tensors should be moved before computation.
    - lambda_1, lambda_2: Weights for ZINB Loss and MSE Loss respectively.
    
    Returns:
    - total_loss: Combined loss value.
    """
    
    # Compute ZINB Loss (assuming ZINBLoss is a properly defined function or class)
    zinb_loss = ZINBLoss(x_original, z_mean, z_dispersion, z_dropout)
    
    # Compute MSE Loss
    mse_loss = MSELoss()(x_recon, x_original)
    
    # Combine the losses
    total_loss = alpha * zinb_loss + (1-alpha) * mse_loss
    
    return total_loss

In [9]:
class VGAE(Module):
    def __init__(
        self, trial, input_dim, hidden1, hidden2, dropout1, dropout2, 
    ):
        super(VGAE, self).__init__()
        
        self.dropout1 = nn.Dropout(dropout1)
        self.dropout2 = nn.Dropout(dropout2)
        
        # Encoder with 2 gat layers
        self.gat1 = GCNConv(input_dim, hidden1)
        self.gn1 = GraphNorm(hidden1)  # Batch normalization after first gat layer
        self.gat2_mean = GCNConv(hidden1, input_dim)
        self.gat2_dropout = GCNConv(hidden1, input_dim)
        self.gat2_dispersion = GCNConv(hidden1, input_dim)

        # Decoder with 2 Linear layers
        self.gcn1 = GCNConv(input_dim, hidden2)
        self.gn2 = GraphNorm(hidden2)  # Batch normalization after first linear layer
        self.gcn2 = GCNConv(hidden2, input_dim)
        
    def encode(self, x, adj):
        x = relu(self.gn1(self.gat1(x, adj)))  # Apply ReLU and GraphNorm
        x = self.dropout1(x)
        
        z_mean = softplus(self.gat2_mean(x, adj.t()))
        z_dropout = torch.sigmoid(self.gat2_dropout(x, adj.t()))
        z_dispersion = torch.exp(self.gat2_dispersion(x, adj.t()))
        return z_mean, z_dropout, z_dispersion

    def decode(self, z, adj):
        z = relu(self.gn2(self.gcn1(z, adj.t())))  # Apply ReLU and BatchNorm
        z = self.dropout2(z)
        return torch.sigmoid(self.gcn2(z, adj.t()))

    def forward(self, x, adj):
        z_mean, z_dropout, z_dispersion = self.encode(x, adj.t())
        x_recon = self.decode(z_mean, adj.t())
        return x_recon, z_mean, z_dropout, z_dispersion

In [10]:
def objective(trial):

    input_dim = df.shape[1]
    hidden1 = trial.suggest_categorical('hidden1', [128, 256, 512, 1024])
    hidden2 = trial.suggest_categorical('hidden2', [128, 256, 512, 1024])
    
    dropout1 = trial.suggest_categorical("dropout1", [i/10 for i in range(1, 6)])
    dropout2 = trial.suggest_categorical("dropout2", [i/10 for i in range(1, 6)])
    
    alpha = trial.suggest_categorical("alpha", [0.01, 0.05, 0.1, 0.5, 0.9, 0.95, 0.99])
    epochs = trial.suggest_categorical('epochs', list(range(500, 10500, 500)))
    lr = trial.suggest_categorical("lr", [0.01, 0.001, 0.0001])

    model = VGAE(trial, input_dim, hidden1, hidden2, dropout1, dropout2, ).to(device)
    optimizer_name = 'Adam'
    optimizer = getattr(torch.optim, optimizer_name)(
        model.parameters(), 
        lr=lr, 
    )

    losses = []
    for epoch in tqdm(range(epochs)): 
        # Forward pass
        x_recon, z_mean, z_dropout, z_dispersion = model(x, adj)

        # Compute the ZINB Loss using the outputs from the model
        loss = compute_loss(x, x_recon, z_mean, z_dispersion, z_dropout, alpha).to(device)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step() 

        losses.append(loss.item())
#         print(loss.item())

    pred = x_recon.cpu().detach().numpy()
    pred = pd.DataFrame(pred, columns=df.columns, index=df.index)
    predict = np.array([pred.loc[i] for i in mask])
    
    return mse(origin, predict)

In [11]:
study = optuna.create_study(
    direction="minimize",
    storage="sqlite:///vgae_gcn.sqlite3",
    study_name="vgae_gcn",
    load_if_exists=True,
)

study.optimize(objective, n_trials=1000, gc_after_trial=True)

[I 2023-08-27 17:24:20,063] A new study created in RDB with name: vgae_gcn
100%|██████████| 9500/9500 [12:57<00:00, 12.22it/s]
[I 2023-08-27 17:37:25,310] Trial 0 finished with value: 0.2304583991921863 and parameters: {'alpha': 0.95, 'dropout1': 0.4, 'dropout2': 0.3, 'epochs': 9500, 'hidden1': 512, 'hidden2': 256, 'lr': 0.01}. Best is trial 0 with value: 0.2304583991921863.
100%|██████████| 4000/4000 [05:00<00:00, 13.32it/s]
[I 2023-08-27 17:42:32,979] Trial 1 finished with value: 0.22665514183805494 and parameters: {'alpha': 0.01, 'dropout1': 0.4, 'dropout2': 0.3, 'epochs': 4000, 'hidden1': 256, 'hidden2': 128, 'lr': 0.001}. Best is trial 1 with value: 0.22665514183805494.
100%|██████████| 9000/9000 [13:37<00:00, 11.01it/s]
[I 2023-08-27 17:56:18,289] Trial 2 finished with value: 0.23882693296274782 and parameters: {'alpha': 0.5, 'dropout1': 0.5, 'dropout2': 0.2, 'epochs': 9000, 'hidden1': 512, 'hidden2': 1024, 'lr': 0.001}. Best is trial 1 with value: 0.22665514183805494.
100%|█████

KeyboardInterrupt: 