# GraphETM Dev Notebook

In [1]:
### Imports
# Local
from model.graphetm_trainer import GraphETMTrainer

# External
import time
import random
import numpy as np
import pandas as pd
import scanpy as sc
from typing import Dict, Any

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torch_geometric as pyg

from tqdm.notebook import tqdm, trange
from sklearn.model_selection import train_test_split
from sklearn.metrics import adjusted_rand_score

import wandb

### Weights & Biases
# Log in up front; the run itself is created later via wandb.init(...) in the training cell.
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33mloicduch[0m ([33mloicduch-mcgill-university[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [2]:
# Seeds
# One global seed (10) for the RNGs listed in the trailing comment below.
# Note: the train/validation split further down passes its own random_state=0.
pyg.seed_everything(10) # random, np, torch, torch.cuda

In [3]:
# Helper
def evaluate_ari(cell_embed, adata, n_neighbors=30, resolution=0.15, label_key='Celltype'):
    """
        Score a cell embedding by how well Louvain clusters recover known labels.

        A neighbourhood graph is built on ``cell_embed``, Louvain clustering is run
        on it, and the resulting clusters are compared with the reference labels
        using the Adjusted Rand Index (ARI).

        Note: ``adata`` is modified in place. The embedding is stored under
        ``adata.obsm['cell_embed']`` and the cluster assignments are read back from
        ``adata.obs['louvain']`` once ``sc.tl.louvain`` has run.

        :param cell_embed: lower-dimensional embedding of the cells in ``adata``
            (e.g. an NxK matrix generated from NMF or scETM)
        :param adata: single-cell AnnData data object holding the reference labels
            in ``adata.obs[label_key]``
        :param n_neighbors: neighbourhood size passed to ``sc.pp.neighbors``
        :param resolution: resolution passed to ``sc.tl.louvain``
        :param label_key: column of ``adata.obs`` containing the reference labels
        :return: ARI score of the clustering results produced by Louvain
    """
    # Register the embedding so scanpy can use it as the representation.
    adata.obsm['cell_embed'] = cell_embed
    sc.pp.neighbors(adata, use_rep="cell_embed", n_neighbors=n_neighbors)
    sc.tl.louvain(adata, resolution=resolution)
    return adjusted_rand_score(adata.obs[label_key], adata.obs['louvain'])

---
# Model
Model Implementation:

In [4]:
# @title ENCODER
class Encoder(nn.Module):
    """
        Encoder module for GraphETM.

        Maps a normalized bag-of-words batch to the parameters of the variational
        distribution over the (pre-softmax) topic variable theta.

        Attributes:
                q_theta: MLP (Linear -> activation -> Linear -> activation) producing the hidden representation.
                theta_act: Activation module used inside q_theta.
                mu_q_theta: Linear head producing the mean.
                logsigma_q_theta: Linear head whose output is treated as a log-variance by the KL term in forward().
                dropout: nn.Dropout applied to the hidden representation in forward().
                thres_dropout: Raw dropout rate; dropout is only applied when it is > 0.
    """
    # Activation name -> module class accepted for `theta_act`.
    _ACTIVATIONS = {
        'tanh': nn.Tanh,
        'relu': nn.ReLU,
        'softplus': nn.Softplus,
        'rrelu': nn.RReLU,
        'leakyrelu': nn.LeakyReLU,
        'elu': nn.ELU,
        'selu': nn.SELU,
        'glu': nn.GLU,
    }

    def __init__(
            self,
            num_topics: int,
            vocab_size: int,
            encoder_hidden_size: int,
            dropout: float = 0.5,
            theta_act: str = 'tanh'
    ):
        """
            Initialize the Encoder module.

            Args:
                num_topics: Number of topics.
                vocab_size: Size of vocabulary.
                encoder_hidden_size: Size of the hidden layer in the encoder.
                dropout: Dropout rate applied to the hidden representation.
                theta_act: Name of the activation function for theta; unknown names fall back to tanh.
        """
        super().__init__()

        # Dropout
        self.thres_dropout = dropout
        self.dropout = nn.Dropout(dropout)

        # Theta Activation
        self.theta_act = self._get_activation(theta_act)

        # nn.GLU halves the last dimension, so the linear layers feeding it must
        # emit twice the hidden width for the hidden size to stay encoder_hidden_size.
        if isinstance(self.theta_act, nn.GLU):
            pre_act_size = 2 * encoder_hidden_size
        else:
            pre_act_size = encoder_hidden_size

        ## define variational distribution for theta_{1:D} via amortization
        self.q_theta = nn.Sequential(
            nn.Linear(vocab_size, pre_act_size),
            self.theta_act,
            nn.Linear(encoder_hidden_size, pre_act_size),
            self.theta_act,
        )
        self.mu_q_theta = nn.Linear(encoder_hidden_size, num_topics, bias=True)
        self.logsigma_q_theta = nn.Linear(encoder_hidden_size, num_topics, bias=True)

    def infer_topic_distribution(self, normalized_bows: torch.Tensor) -> torch.Tensor:
        """
            Returns a deterministic topic distribution for evaluation purposes bypassing the stochastic reparameterization step.

            The mean head is used directly (no sampling, and no dropout call) and
            passed through a softmax over the topic dimension.

            Args:
                normalized_bows (torch.Tensor): Normalized bag-of-words input.

            Returns:
                torch.Tensor: Deterministic topic proportions.
        """
        hidden = self.q_theta(normalized_bows)
        mu_theta = self.mu_q_theta(hidden)
        return F.softmax(mu_theta, dim=-1)

    def forward(self, bow_norm: torch.Tensor):
        """
        Returns parameters of the variational distribution for theta.

        Args:
            bow_norm: (batch, V) Normalized batch of Bag-of-Words.

        Returns:
            mu_theta: Output of the mean head.
            logsigma_theta: Output of the log-variance head.
            kl_theta: KL divergence to a standard normal prior, summed over topics
                and averaged over the batch (scalar tensor).

        """
        hidden = self.q_theta(bow_norm)
        if self.thres_dropout > 0:
            hidden = self.dropout(hidden)
        mu_theta = self.mu_q_theta(hidden)
        logsigma_theta = self.logsigma_q_theta(hidden)

        # KL[q(theta)||p(theta)] = lnq(theta) - lnp(theta)
        kl_theta = -0.5 * torch.sum(1 + logsigma_theta - mu_theta.pow(2) - logsigma_theta.exp(), dim=-1).mean()

        return mu_theta, logsigma_theta, kl_theta

    def _get_activation(self, act): # TODO: Redundant method.
        """Return a fresh activation module for `act`, defaulting to tanh for unknown names."""
        act_cls = self._ACTIVATIONS.get(act) if isinstance(act, str) else None
        if act_cls is None:
            print('Defaulting to tanh activations...')
            act_cls = nn.Tanh
        return act_cls()

In [5]:
# @title DECODER
class Decoder(nn.Module):
    """
        Decoder module for GraphETM.

        Turns topic proportions into log-probabilities over the vocabulary using
        a topic-word matrix beta derived from the feature embeddings.

        Attributes:
            embedding: Word (feature) embedding matrix rho, V x L. A Parameter when
                trainable, otherwise a registered buffer.
            alphas: Bias-free linear layer mapping the L embedding dimensions to
                num_topics (the topic embeddings).
    """
    def __init__(
            self,
            embedding: torch.Tensor,
            num_topics: int,
            trainable: bool = True,
    ):
        """
            Initialize the Decoder module.

            Args:
                embedding: Embedding matrix rho of shape V x L. It is copied, so the
                    caller's tensor is never modified.
                num_topics: Number of topics.
                trainable: If True the copy is a learnable Parameter; if False it
                    is stored as a buffer that does not track gradients.

            Raises:
                TypeError: If `embedding` is None.
        """
        super().__init__()

        if embedding is None:
            raise TypeError('Decoder requires an embedding tensor, got None.')

        # TODO: Replace word embedding matrix with embeddings derived from the iBKH.
        # TODO: 1) Use GCN to process the iBKH and produce graph embeddings (iBKH-embeddings). (DONE)
        # TODO: 2) Replace rho in the Decoder with the iBKH-embeddings. (DONE)
        # TODO: 3) Pass the iBKH-embeddings through linear -> alpha; followed by softmax -> Beta (beta represents the Decoder weights).
        # TODO: 4) (Optional) Add graph reconstruction loss to the training objective.
        # TODO: Objective: The latent topic distribution theta for (scRNA and EHR) are multiplied with Beta (essentially grounding the latent topics with the knowledge).

        ## define the word embedding matrix \rho
        # detach() first: a source tensor with requires_grad=True would otherwise
        # leave the "frozen" buffer attached to the autograd graph.
        rho = embedding.detach().clone() # V x L
        if trainable: # Trainable
            self.embedding = nn.Parameter(rho)
        else: # Frozen
            self.register_buffer('embedding', rho)

        ## define the matrix containing the topic embeddings
        self.alphas = nn.Linear(self.embedding.size(1), num_topics, bias=False)

    def get_beta(self):
        """
            Retrieve beta by doing softmax over the vocabulary dimension.

            Returns:
                Beta (K x V) which represents the topic-word (or topic-feature)
                distributions; each row sums to 1.
        """
        logits = self.alphas(self.embedding) # V x K
        # Normalise over the vocabulary axis (dim 0), then flip to K x V.
        return F.softmax(logits, dim=0).transpose(1, 0)

    def forward(self, theta):
        """
            Compute log-probabilities over the vocabulary for each row of theta.

            Args:
                theta: Topic proportions, (batch, K).

            Returns:
                log(theta @ beta + 1e-6), shape (batch, V). The constant keeps the
                logarithm finite.
        """
        beta = self.get_beta()
        return torch.log(torch.mm(theta, beta) + 1e-6)

In [6]:
# @title GraphETM
class GraphETM(nn.Module):
    """
        Two-modality embedded topic model.

        Holds one Encoder/Decoder pair for single-cell RNA data ('sc') and one for
        EHR data ('ehr'). Both pairs are built with the same number of topics.

        Attributes:
            encoder_params: The per-modality encoder settings passed to __init__.
            enc_sc, enc_ehr: Encoders for the two modalities.
            dec_sc, dec_ehr: Decoders for the two modalities.
    """
    def __init__(
            self,
            encoder_params: Dict[str, Dict[str, int]],
            theta_act: str,
            num_topics: int,
            embedding_sc: torch.Tensor = None,
            embedding_ehr: torch.Tensor = None,
            trainable_embeddings = True,
            dropout=0.2
    ):
        """
            Initialize the ETM model.

            Args:
                encoder_params: Dictionary of the parameters for the encoders. Dictionary {'sc': {str: Any}, 'ehr': {str: Any}}.
                    vocab_size: Size of vocabulary.
                    encoder_hidden_size: Size of the hidden layer in the encoder.
                theta_act: Activation function for theta.
                num_topics: Number of topics.
                embedding_sc:  Word embedding rho for single-celled RNA from a knowledge graph.
                embedding_ehr: Word embedding rho for Diseases from a knowledge graph.
                trainable_embeddings: Whether to fine-tune word embeddings.
                dropout: Dropout rate.

        """
        super().__init__()

        self.encoder_params = encoder_params

        self.enc_sc  = Encoder(**encoder_params['sc'],  num_topics=num_topics, dropout=dropout, theta_act=theta_act)
        self.enc_ehr = Encoder(**encoder_params['ehr'], num_topics=num_topics, dropout=dropout, theta_act=theta_act)
        self.dec_sc  = Decoder(embedding=embedding_sc,  num_topics=num_topics, trainable=trainable_embeddings)
        self.dec_ehr = Decoder(embedding=embedding_ehr, num_topics=num_topics, trainable=trainable_embeddings)

    # theta ~ mu + std N(0,1)
    def reparameterize(self, mu, logvar):
        """
            Returns a sample from a Gaussian distribution via reparameterization.

            In training mode a sample mu + std * eps is drawn with std = exp(0.5 * logvar);
            in eval mode mu is returned unchanged.
        """
        if not self.training:
            return mu
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps.mul_(std).add_(mu)

    def infer_topic_distribution(self, normalized_bows: torch.Tensor, modality: str = 'sc') -> torch.Tensor:
        """
            Returns a deterministic topic distribution for evaluation purposes bypassing the stochastic reparameterization step.

            Args:
                normalized_bows (torch.Tensor): Normalized bag-of-words input.
                modality (str): Which encoder to use, 'sc' (default) or 'ehr'.

            Returns:
                torch.Tensor: Deterministic topic proportions.

            Raises:
                ValueError: If `modality` is neither 'sc' nor 'ehr'.
        """
        # The model has no single `encoder`; pick the one matching the modality.
        encoders = {'sc': self.enc_sc, 'ehr': self.enc_ehr}
        if modality not in encoders:
            raise ValueError(f"Unknown modality {modality!r}; expected 'sc' or 'ehr'.")
        return encoders[modality].infer_topic_distribution(normalized_bows)

    def _bow_forward(self, encoder, decoder, bow, aggregate=True):
        """
            Run one encoder/decoder pair on a batch of bag-of-words counts.

            Args:
                encoder: Module returning (mu, logvar, kld) for a normalized batch.
                decoder: Module mapping topic proportions to per-feature predictions
                    (Decoder.forward returns log-probabilities).
                bow: Batch of counts; rows are summed along dim 1, so (batch, V) is expected.
                aggregate: If True the reconstruction loss is averaged over the batch,
                    otherwise it is returned per row.

            Returns:
                Dict with 'rec_loss', 'kl', and the detached 'theta' and 'preds'.
        """
        bow_raw  = bow # integer counts
        lengths  = bow_raw.sum(1, keepdim=True) + 1e-8 # epsilon guards against empty rows
        bow_norm = bow_raw / lengths # Normalize

        mu, logvar, kld = encoder(bow_norm)
        z = self.reparameterize(mu, logvar)
        theta = F.softmax(z, dim=-1)

        preds = decoder(theta)
        # Negative log-likelihood of the counts, normalised by each row's total count.
        rec_loss = -(preds * bow_raw).sum(1) / lengths.squeeze(1)
        if aggregate:
            rec_loss = rec_loss.mean()

        return {
            'rec_loss': rec_loss,
            'kl'      : kld,
            'theta'   : theta.detach(),
            'preds'   : preds.detach(),
        }

    def forward(self, bow_sc, bow_ehr):
        """
            Compute the joint loss over one scRNA batch and one EHR batch.

            Args:
                bow_sc: Batch of scRNA counts, fed through enc_sc / dec_sc.
                bow_ehr: Batch of EHR counts, fed through enc_ehr / dec_ehr.

            Returns:
                Dict with 'loss' (sum of both reconstruction losses and both KL
                terms) and the per-modality outputs under 'sc' and 'ehr'.
        """
        # Encoder-Decoder: ScRNA
        output_sc = self._bow_forward(
            bow=bow_sc,
            encoder=self.enc_sc,
            decoder=self.dec_sc)

        # Encoder-Decoder: EHR
        output_ehr = self._bow_forward(
            bow=bow_ehr,
            encoder=self.enc_ehr,
            decoder=self.dec_ehr)

        # Total ELBO Loss
        elbo_loss = (output_sc['rec_loss'] + output_ehr['rec_loss']).mean() + output_sc['kl'] + output_ehr['kl']

        # Update outputs
        output_sc['rec_loss'] = output_sc['rec_loss'].mean()
        output_ehr['rec_loss'] = output_ehr['rec_loss'].mean()
        return {
            'loss': elbo_loss,
            'sc' : output_sc,
            'ehr': output_ehr,
        }

---
# Training

In [7]:
### Device
# Prefer CUDA, then Apple MPS, then CPU. torch.backends.mps.is_available() is used
# for the MPS probe because it also exists on PyTorch builds without torch.mps.is_available().
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')
# Manual override for debugging:
# device = torch.device('cpu')
print(f'Using device: {device}')

Using device: mps


In [8]:
### Data
test_size = 0.2

# Load Rho (Graph Embeddings)
# The id_embed_*.npy arrays are used below as row indices into the full embedding matrix.
sc_indices  = np.load('inputs/GraphETM/id_embed_sc.npy')
ehr_indices = np.load('inputs/GraphETM/id_embed_ehr.npy')

# SECURITY: weights_only=False lets torch.load unpickle arbitrary Python objects.
# Only load files that come from a trusted source.
embedding_full = torch.load('inputs/GraphETM/embedding_full.pt', weights_only=False)

embedding_sc  = embedding_full[sc_indices,  :]
embedding_ehr = embedding_full[ehr_indices, :]

# Load Input Data
X_sc  = torch.load('inputs/GraphETM/X_sc.pt',  weights_only=False)
X_ehr = torch.load('inputs/GraphETM/X_ehr.pt', weights_only=False)

# Sanity check: the decoder emits one prediction per embedding row and multiplies it
# element-wise with the input features, so the two sizes must match per modality.
assert embedding_sc.shape[0] == X_sc.shape[1], (
    f'sc: {embedding_sc.shape[0]} embedding rows vs {X_sc.shape[1]} input features')
assert embedding_ehr.shape[0] == X_ehr.shape[1], (
    f'ehr: {embedding_ehr.shape[0]} embedding rows vs {X_ehr.shape[1]} input features')

# Hold out `test_size` of each modality for validation (fixed random_state for a repeatable split).
X_sc,  X_sc_val  = train_test_split(X_sc , test_size=test_size, random_state=0)
X_ehr, X_ehr_val = train_test_split(X_ehr, test_size=test_size, random_state=0)

In [9]:
### Parameters
config = {
    'model': dict(
        theta_act='relu',
        num_topics = 50, # K = 50
        encoder_params = { ## Encoder parameters
            'sc': { # Encoder SC
                'vocab_size': X_sc.shape[1],
                'encoder_hidden_size': 64
            },
            'ehr': { # Encoder EHR
                'vocab_size': X_ehr.shape[1],
                'encoder_hidden_size': 64
            }
        }, ## Embedding Parameters
        embedding_sc  = embedding_sc ,
        embedding_ehr = embedding_ehr,
        trainable_embeddings = True
    ),

    'dataloader': dict(
        batch_size=32,
        shuffle=True
    ),

    'training': dict(
        lr = 0.001,
        epochs = 25,
    ),

    'device': device,
}

# Validation batches reuse the training loader settings but are not reshuffled,
# so every evaluation pass sees the held-out data in the same order.
val_dataloader_kwargs = {**config['dataloader'], 'shuffle': False}

# Run config sent to W&B: hyperparameters only. Embedding matrices are replaced by
# their shapes and the device by its string form, instead of passing raw objects.
wandb_config = {
    'model': {
        name: (tuple(value.shape) if isinstance(value, (torch.Tensor, np.ndarray)) else value)
        for name, value in config['model'].items()
    },
    'dataloader': config['dataloader'],
    'training': config['training'],
    'device': str(device),
}


### Model
trainer = GraphETMTrainer(
    model = GraphETM(**config['model']).to(device), # Model
    dataloader_sc  = DataLoader(**config['dataloader'], dataset = X_sc ), # Dataloaders
    dataloader_ehr = DataLoader(**config['dataloader'], dataset = X_ehr),
    val_dataloader_sc  = DataLoader(**val_dataloader_kwargs, dataset = X_sc_val ),
    val_dataloader_ehr = DataLoader(**val_dataloader_kwargs, dataset = X_ehr_val),
    device = device,
    wandb_run = wandb.init(
        project ='GraphETM',
        group = 'GraphETM',
        name = f'GraphETM_{int(time.time())}',
        config=wandb_config, save_code=True) # Start Wandb
)

### Training
# The optimizer is built from trainer.model, i.e. after the model was moved to `device`.
trainer.train(
    epochs = config['training']['epochs'],
    optimizer = optim.Adam(trainer.model.parameters(), lr=config['training']['lr']) # Optimizer
)
# NOTE: the W&B run is still open here; it is closed by trainer.wandb.finish() in the next cell.

Training GraphETM:   0%|          | 0/25 [00:00<?, ?epoch/s]

In [10]:
# Finish the W&B run held by the trainer now that training is over
# (presumably the run passed in as `wandb_run` — confirm in GraphETMTrainer).
trainer.wandb.finish()

0,1
batch,▁▁▁▂▂▂▂▂▂▂▂▃▃▃▃▄▄▄▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇▇█████
epoch,▁▁▂▂▂▂▃▃▃▄▄▄▅▅▅▅▆▆▆▇▇▇▇██
train/ehr/kld,▅▃▁▄▄▄▄▅▅▅▅▆▅▆▅▆▇▄▆▅▄▅▅▅▅▅▅▆▅█▆▇▅▅▅█▅▇▇▆
train/ehr/recon_loss,▇▇▃▃▅▁▃▃▂▃▅▅▂▄▃▇▅▄▄▄▂▁▃▄▂█▄▁▂▄▁▂▃▃▅▃▄▄▃▃
train/sc/kld,█▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/sc/recon_loss,█▇▅▄▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/total_loss,██▂▂▂▂▁▂▂▂▂▂▁▁▂▁▁▂▂▁▂▂▁▂▂▁▂▂▂▂▂▁▂▂▂▂▂▁▂▂
val/ehr/kld,▁▄▆▆▆▇▆▇▇▇▇▇▆▇▇▇▇██████▇█
val/ehr/recon_loss,█▃▂▂▂▂▂▁▂▁▂▂▁▂▁▂▁▂▂▁▁▁▁▁▂
val/sc/kld,█▅▄▄▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
batch,2149.0
epoch,24.0
train/ehr/kld,0.04152
train/ehr/recon_loss,2.4151
train/sc/kld,0.0
train/sc/recon_loss,6.90927
train/total_loss,9.3659
val/ehr/kld,0.02905
val/ehr/recon_loss,2.32473
val/sc/kld,0.0


In [None]:
# DONE