# GraphETM Dev Notebook

In [1]:
### Imports
## Local
from model.graphetm_trainer import GraphETMTrainer

## External
import numpy as np
import pandas as pd
from typing import Dict

# Torch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

# Torch-Geometric
import torch_geometric as pyg
from torch_geometric.nn import GCNConv
import torch_geometric.utils as pyg_utils
from torch_geometric.data import Data
from torch_geometric.loader import NeighborLoader

# Sklearn
from sklearn.model_selection import train_test_split

# Plot
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

import wandb

### Weights & Biases login (experiment tracking)
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33mloicduch[0m ([33mloicduch-mcgill-university[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [2]:
# Seeds: fix the random, numpy and torch (incl. torch.cuda) RNGs so runs are reproducible.
pyg.seed_everything(10) # random, np, torch, torch.cuda

---
# Model

In [3]:
class Encoder(nn.Module):
    """
        Encoder module for GraphETM: amortized variational posterior over topic proportions.

        Attributes:
                q_theta: MLP mapping a normalized BoW (vocab_size) to q_theta_hidden_size features.
                theta_act: Activation module inserted after each linear layer of q_theta.
                mu_q_theta: Linear head producing the posterior mean.
                logsigma_q_theta: Linear head producing the posterior log-variance (see the KL term in forward()).
                dropout: Dropout applied to the q_theta features in forward() when the rate is > 0.
    """
    # Supported activation names -> module constructors (a fresh module is built per lookup).
    _ACTIVATIONS = {
        'tanh': nn.Tanh,
        'relu': nn.ReLU,
        'softplus': nn.Softplus,
        'rrelu': nn.RReLU,
        'leakyrelu': nn.LeakyReLU,
        'elu': nn.ELU,
        'selu': nn.SELU,
        # NOTE: nn.GLU halves its input's last dimension, so it does not fit the layer sizes of
        # q_theta as built below (the second Linear expects encoder_hidden_size inputs).
        'glu': nn.GLU,
    }

    def __init__(
            self,
            num_topics: int,
            vocab_size: int,
            encoder_hidden_size: int,
            dropout: float = 0.5,
            theta_act: str = 'tanh',
            q_theta_hidden_size: int = 64,
    ):
        """
            Initialize the Encoder module.

            Args:
                num_topics: Number of topics K.
                vocab_size: Size of vocabulary V.
                encoder_hidden_size: Size of the first hidden layer in the encoder.
                dropout: Dropout rate applied to the encoder features (disabled when <= 0).
                theta_act: Name of the activation function (unknown names fall back to tanh).
                q_theta_hidden_size: Size of the second hidden layer, i.e. the input size of the
                    mean / log-variance heads (default 64).
        """
        super().__init__()

        # Dropout
        self.thres_dropout = dropout
        self.dropout = nn.Dropout(dropout)

        # Theta activation; the same module instance is reused after both linear layers.
        self.theta_act = self._get_activation(theta_act)

        ## define variational distribution for \theta_{1:D} via amortization
        self.q_theta = nn.Sequential(
            nn.Linear(vocab_size, encoder_hidden_size),
            self.theta_act,
            nn.Linear(encoder_hidden_size, q_theta_hidden_size),
            self.theta_act,
        )
        self.mu_q_theta = nn.Linear(q_theta_hidden_size, num_topics, bias=True)
        self.logsigma_q_theta = nn.Linear(q_theta_hidden_size, num_topics, bias=True)

    def infer_topic_distribution(self, normalized_bows: torch.Tensor) -> torch.Tensor:
        """
            Returns a deterministic topic distribution for evaluation purposes bypassing the stochastic reparameterization step.

            Args:
                normalized_bows (torch.Tensor): Normalized bag-of-words input.

            Returns:
                torch.Tensor: Deterministic topic proportions (softmax of the posterior mean).
        """
        q_theta = self.q_theta(normalized_bows)
        mu_theta = self.mu_q_theta(q_theta)
        return F.softmax(mu_theta, dim=-1)

    def forward(self, bow_norm: torch.Tensor, free_nat: float = 0.5):
        """
        Returns the parameters of the variational distribution over theta and its KL term.

        Args:
            bow_norm: (batch, V) Normalized batch of Bag-of-Words.
            free_nat: Free-bits threshold. Currently unused: the clamp below is disabled.

        Returns:
            mu_theta: (batch, K) posterior means.
            logsigma_theta: (batch, K) posterior log-variances.
            kl_theta: Scalar KL[q(theta) || N(0, I)], summed over topics and averaged over the batch.
        """
        q_theta = self.q_theta(bow_norm)
        if self.thres_dropout > 0:
            q_theta = self.dropout(q_theta)
        mu_theta = self.mu_q_theta(q_theta)
        logsigma_theta = self.logsigma_q_theta(q_theta)

        # Closed-form Gaussian KL; logsigma_theta is used as a log-variance here.
        kl_theta = -0.5 * torch.sum(1 + logsigma_theta - mu_theta.pow(2) - logsigma_theta.exp(), dim=1).mean()
        # Free-bits (disabled)
        # kl_theta = torch.clamp(kl_theta, min=free_nat).mean()

        return mu_theta, logsigma_theta, kl_theta

    def _get_activation(self, act):
        """Return a new activation module for the name `act`; unknown names default to tanh."""
        factory = self._ACTIVATIONS.get(act) if isinstance(act, str) else None
        if factory is None:
            print('Defaulting to tanh activations...')
            factory = nn.Tanh
        return factory()

In [4]:
class Decoder(nn.Module):
    """
        Decoder module for GraphETM.

        Reconstructs log-probabilities over the vocabulary from topic proportions theta, using
        beta = softmax(alphas(rho)^T) over the vocabulary dimension.

        Attributes:
            rho: Word (feature) embedding matrix, V x L; replaced on every forward() call.
            alphas: Linear map L -> num_topics (no bias) whose weight rows act as topic embeddings.
    """
    def __init__(
            self,
            in_dim: int,
            num_topics: int,
    ):
        """
            Initialize the Decoder module.

            Args:
                in_dim: Embedding width L, i.e. the last dimension of the rho passed to forward().
                num_topics: Number of topics K.
        """
        super().__init__()

        # Objective: The latent topic distribution theta for (scRNA and EHR) are multiplied with Beta (essentially grounding the latent topics with the knowledge).

        ## define the word embedding matrix \rho (set by forward())
        self.rho = None # V x L

        ## define the matrix containing the topic embeddings
        self.alphas = nn.Linear(in_dim, num_topics, bias=False)

    def get_beta(self):
        """
            Retrieve beta by doing softmax over the vocabulary dimension.

            Returns:
                K x V tensor: topic-word (or topic-feature) distributions; each row sums to 1.

            Raises:
                RuntimeError: If called before forward() has set rho.
        """
        if self.rho is None:
            raise RuntimeError('Decoder.get_beta() called before forward(): rho is not set.')
        logits = self.alphas(self.rho).T # K x V
        return F.softmax(logits, dim=1)

    def forward(self, theta, rho):
        """
            Args:
                theta: (batch, K) topic proportions.
                rho: (V, L) embeddings of this decoder's vocabulary.

            Returns:
                (batch, V) log of the topic mixture theta @ beta.
        """
        self.rho = rho # Update embeddings

        beta = self.get_beta()
        # Small epsilon keeps the log finite when a mixture probability underflows to 0.
        return torch.log(torch.mm(theta, beta) + 1e-8)

In [5]:
class GraphFilter(nn.Module):
    """
        Two-layer GCN that refines node embeddings, trained with an inner-product graph
        reconstruction loss (positive edges vs. sampled negative edges).
    """
    def __init__(
            self,
            in_dim: int,
            hidden_dim: int,
            out_dim: int,
            edge_index: torch.LongTensor,
    ):
        """
            Args:
                in_dim: Width of the input node embeddings.
                hidden_dim: Width of the first GCN layer.
                out_dim: Width of the filtered embeddings.
                edge_index: Full-graph connectivity (2 x E), used when forward() gets no mini-batch.
        """
        super(GraphFilter, self).__init__()

        self.edge_index = edge_index

        self.conv1 = GCNConv(in_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, out_dim)

    def forward(self, embedding: torch.Tensor, batch=None):
        """
            Filter `embedding` through the GCN.

            Args:
                embedding: Node embeddings, one row per graph node.
                batch: Optional mini-batch exposing `n_id` (global ids of its nodes) and a local
                    `edge_index` (e.g. from a PyG NeighborLoader). If None, the whole graph is
                    filtered with the stored full-graph `edge_index`.

            Returns:
                x: Filtered embeddings, [N, out_dim] (N = batch nodes, or all nodes if batch is None).
                loss: Graph reconstruction loss on the same edges.
        """
        device = embedding.device
        if batch is None:
            if self.edge_index is None:
                raise ValueError('GraphFilter.forward() needs either a mini-batch or a full-graph edge_index.')
            x_in = embedding
            edge_index = self.edge_index.to(device)
        else:
            # Loader batches are typically produced on CPU; align them with the embedding's device.
            x_in = embedding[batch.n_id.to(device)]
            edge_index = batch.edge_index.to(device)

        # Forward
        x = self.conv1(x_in, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index) # [N, out_dim]
        # TODO: Normalize/dropout potentially to be added.

        return x, self._reconstruction_loss(x, edge_index)

    # Graph Reconstruction Loss
    def loss(self, x: torch.Tensor, batch=None):
        """Graph reconstruction loss on `batch.edge_index` (or on the full graph if batch is None)."""
        edge_index = self.edge_index if batch is None else batch.edge_index
        return self._reconstruction_loss(x, edge_index.to(x.device))

    def _reconstruction_loss(self, x: torch.Tensor, pos_edge_index: torch.Tensor):
        """BCE on inner-product scores: observed edges labelled 1, sampled non-edges labelled 0."""
        # Negative edges, as many as positive ones.
        neg_edge_index = pyg_utils.negative_sampling(
            edge_index=pos_edge_index,
            num_nodes=x.size(0),
            num_neg_samples=pos_edge_index.size(1))

        # Inner-product score of each edge's endpoint embeddings
        pos_scores = (x[pos_edge_index[0]] * x[pos_edge_index[1]]).sum(dim=1)
        neg_scores = (x[neg_edge_index[0]] * x[neg_edge_index[1]]).sum(dim=1)

        pos_loss = F.binary_cross_entropy_with_logits(pos_scores, torch.ones_like(pos_scores))
        neg_loss = F.binary_cross_entropy_with_logits(neg_scores, torch.zeros_like(neg_scores))
        return pos_loss + neg_loss

In [6]:
class Model(nn.Module):
    """
        GraphETM: a two-modality (scRNA / EHR) Embedded Topic Model whose word embeddings (rho)
        come from a knowledge graph and are refined by a trainable GCN filter (GraphFilter).
    """
    def __init__(
            self,
            encoder_params  : Dict[str, Dict[str, int]],
            graphconv_params: Dict,
            theta_act: str,
            num_topics: int,
            embedding: torch.Tensor = None,
            embedding_dataloader = None,
            edge_index: torch.LongTensor = None,
            id_embed_sc : np.ndarray = None,
            id_embed_ehr: np.ndarray = None,
            trainable_embeddings = False, # Must be false to keep embeddings stable.
            dropout = 0.2
    ):
        """
            Initialize the GraphETM model.

            Args:
                encoder_params: Dictionary of the parameters for the encoders. Dictionary {'sc': {str: Any}, 'ehr': {str: Any}}.
                    vocab_size: Size of vocabulary.
                    encoder_hidden_size: Size of the hidden layer in the encoder.
                graphconv_params: Keyword arguments for GraphFilter; must contain 'hidden_dim' and 'out_dim'.
                theta_act: Activation function for theta.
                num_topics: Number of topics.

                embedding: Initial embedding (also known as rho) computed from the knowledge graph (e.g.: TransE embeddings), one row per graph node.
                embedding_dataloader: Iterable of graph mini-batches over `embedding` (e.g. a PyG NeighborLoader). Each batch must expose `n_id`, `edge_index` and `batch_size` (seed nodes first). Required by forward().
                edge_index: Full-graph connectivity, forwarded to GraphFilter.
                id_embed_sc : Index map for the genes found in the Gene Expression (BoW) matrix input and the Knowledge Graph embedding genes. Each index maps a gene of the input to a row of the embeddings.
                id_embed_ehr: Index map for the diseases found in the Electronic Health Record (BoW) matrix input and the Knowledge Graph embedding diseases. Each index maps a disease of the input to a row of the embeddings.
                trainable_embeddings: Whether the base embedding `rho` receives gradients.

                dropout: Dropout rate of both encoders.
        """
        super(Model, self).__init__()

        self.encoder_params = encoder_params

        self.rho = nn.Parameter(embedding.clone(), requires_grad=trainable_embeddings) # V x L
        self.edge_index = edge_index
        # Non-persistent buffers: they follow the model across `.to(device)` without adding state_dict keys.
        self.register_buffer('id_embed_sc' , torch.tensor(id_embed_sc , dtype=torch.long), persistent=False)
        self.register_buffer('id_embed_ehr', torch.tensor(id_embed_ehr, dtype=torch.long), persistent=False)

        self.embedding_dataloader = embedding_dataloader

        in_dim = embedding.shape[1]
        self.graph_filter = GraphFilter(**graphconv_params, in_dim=in_dim, edge_index=self.edge_index)

        self.enc_sc  = Encoder(**encoder_params['sc'],  num_topics=num_topics, dropout=dropout, theta_act=theta_act)
        self.enc_ehr = Encoder(**encoder_params['ehr'], num_topics=num_topics, dropout=dropout, theta_act=theta_act)
        # The decoders project rows of the *filtered* embedding, whose width is the GCN output
        # size, not the vocabulary size.
        rho_dim = graphconv_params['out_dim']
        self.dec_sc  = Decoder(in_dim=rho_dim, num_topics=num_topics)
        self.dec_ehr = Decoder(in_dim=rho_dim, num_topics=num_topics)

    # theta ~ mu + std N(0,1)
    def reparameterize(self, mu, logvar):
        """
            Returns a sample from a Gaussian distribution via reparameterization (the mean in eval mode).
        """
        if self.training:
            std = torch.exp(0.5 * logvar)
            eps = torch.randn_like(std)
            return eps.mul_(std).add_(mu)
        else:
            return mu

    def _modality_modules(self, modality: str):
        """Return the (encoder, decoder) pair of `modality` ("sc" or "ehr")."""
        if modality == 'sc':
            return self.enc_sc, self.dec_sc
        if modality == 'ehr':
            return self.enc_ehr, self.dec_ehr
        raise ValueError('The modality parameter must be either "sc" or "ehr".')

    def infer_topic_distribution(self, normalized_bows: torch.Tensor, modality: str = 'sc') -> torch.Tensor:
        """
            Returns a deterministic topic distribution for evaluation purposes bypassing the stochastic reparameterization step.

            Args:
                normalized_bows (torch.Tensor): Normalized bag-of-words input.
                modality (str): "sc" (default) or "ehr"; selects which encoder is used.

            Returns:
                torch.Tensor: Deterministic topic proportions.
        """
        encoder, _ = self._modality_modules(modality)
        return encoder.infer_topic_distribution(normalized_bows)

    def get_beta(self, modality: str):
        """
            Retrieve beta for the selected modality which represents the topic-word (or topic-feature) distributions for that modality. It performs softmax over the vocabulary dimension, without tracking gradients. It uses the embeddings of the decoder's last forward pass.

            Args:
                modality (str): "sc" single-cell RNA modality or "ehr" Electronic Health Record (diseases) modality.

            Returns:
                 np.ndarray: Beta representing the topic-word (or topic-feature) distributions.
        """
        _, decoder = self._modality_modules(modality)
        with torch.no_grad():
            beta = decoder.get_beta().cpu().numpy()
        return beta

    def _filter_embeddings(self):
        """
            Run the graph filter over every mini-batch of the embedding graph and reassemble the
            filtered embeddings into one matrix indexed by global node id.

            Only the seed nodes of each batch (the first `batch_size` nodes) are kept, since the other
            nodes are sampled neighbours used as context. Nodes never seeded keep a zero row.

            Returns:
                rho: [num_nodes, out_dim] filtered embeddings.
                graph_loss: Sum of the per-batch graph reconstruction losses.
        """
        if self.embedding_dataloader is None:
            # TODO: Full-graph filtering (GraphFilter with batch=None) is not wired in yet.
            raise NotImplementedError('Model requires an embedding_dataloader to filter the embeddings.')

        device = self.rho.device
        seed_ids, seed_outs = [], []
        graph_loss = 0
        for embed_batch in self.embedding_dataloader:
            embed_batch = embed_batch.to(device)
            out, batch_graph_loss = self.graph_filter(self.rho, embed_batch)
            graph_loss = graph_loss + batch_graph_loss

            n_seed = embed_batch.batch_size
            seed_ids.append(embed_batch.n_id[:n_seed])
            seed_outs.append(out[:n_seed])

        if not seed_outs:
            raise RuntimeError('embedding_dataloader yielded no batches.')

        ids = torch.cat(seed_ids)
        outs = torch.cat(seed_outs)
        # Out-of-place scatter keeps the result differentiable w.r.t. the GCN outputs.
        rho = outs.new_zeros(self.rho.size(0), outs.size(1)).index_copy(0, ids, outs)
        return rho, graph_loss

    def step_forward(self, encoder, decoder, bow, rho, aggregate=True):
        """
            Run one modality through its encoder/decoder pair.

            Args:
                encoder: The modality's Encoder.
                decoder: The modality's Decoder.
                bow: (batch, V) bag-of-words counts.
                rho: (V, L) filtered embeddings of this modality's vocabulary.
                aggregate: If True, average the reconstruction loss over the batch.

            Returns:
                dict with 'rec_loss', 'kl', 'theta' (detached) and 'preds' (detached log-probabilities).
        """
        bow_raw  = bow # integer counts
        lengths  = bow_raw.sum(1, keepdim=True) + 1e-8
        bow_norm = bow_raw / lengths # Normalize

        mu, logvar, kld = encoder(bow_norm)
        z = self.reparameterize(mu, logvar)
        theta = F.softmax(z, dim=-1)

        preds = decoder(theta, rho=rho)
        # Per-document negative log-likelihood, divided by the document length.
        rec_loss = -(bow_raw * preds).sum(1) / lengths.squeeze(1)
        if aggregate:
            rec_loss = rec_loss.mean()

        return {
            'rec_loss': rec_loss,
            'kl'      : kld,
            'theta'   : theta.detach(),
            'preds'   : preds.detach(),
        }

    def forward(self, bow_sc, bow_ehr, kl_annealing=1.0):
        """
            Compute the total loss for one scRNA batch and one EHR batch.

            Args:
                bow_sc: scRNA bag-of-words batch.
                bow_ehr: EHR bag-of-words batch.
                kl_annealing: Weight applied to the summed KL terms.

            Returns:
                dict with 'loss' (reconstruction + weighted KL + graph loss), 'graph_loss',
                and the per-modality outputs under 'sc' and 'ehr'.
        """
        # Refine the knowledge-graph embedding with the GCN filter.
        # NOTE: this iterates over the whole embedding graph on every call.
        rho, graph_loss = self._filter_embeddings()

        rho_sc  = rho[self.id_embed_sc , :]
        rho_ehr = rho[self.id_embed_ehr, :]

        # Encoder-Decoder: ScRNA
        output_sc = self.step_forward(
            bow=bow_sc,
            encoder=self.enc_sc,
            decoder=self.dec_sc,
            rho=rho_sc)

        # Encoder-Decoder: EHR
        output_ehr = self.step_forward(
            bow=bow_ehr,
            encoder=self.enc_ehr,
            decoder=self.dec_ehr,
            rho=rho_ehr)

        # Total ELBO Loss
        elbo_loss = (output_sc['rec_loss'] + output_ehr['rec_loss']).mean() + (output_sc['kl'] + output_ehr['kl']) * kl_annealing + graph_loss

        # Update outputs
        output_sc['rec_loss'] = output_sc['rec_loss'].mean()
        output_ehr['rec_loss'] = output_ehr['rec_loss'].mean()
        return {
            'loss': elbo_loss,
            'graph_loss': graph_loss,
            'sc' : output_sc,
            'ehr': output_ehr,
        }

---
# Training

In [7]:
### Device
# Prefer CUDA, then Apple-silicon MPS, then CPU.
# `torch.backends.mps.is_available()` is the documented MPS availability check.
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')
# device = torch.device('cpu')
print(f'Using device: {device}')

Using device: mps


In [8]:
### Data
test_size = 0.2

# Metadata tables; their column names label beta's columns in the top-k plots.
input_sc, input_ehr = (
    pd.read_csv('inputs/GraphETM/input_PBMC.csv'),
    pd.read_csv('inputs/GraphETM/input_EHR.csv'),
)

# Index maps from each modality's features to rows of the knowledge-graph embedding (rho).
sc_indices, ehr_indices = (
    np.load('inputs/GraphETM/id_embed_sc.npy'),
    np.load('inputs/GraphETM/id_embed_ehr.npy'),
)

# Knowledge-graph embedding and its connectivity.
# NOTE(review): weights_only=False unpickles arbitrary objects; only load trusted files.
embedding_full = torch.load('inputs/GraphETM/embedding_full.pt', weights_only=False)
edge_index = torch.load('inputs/GraphETM/edge_index.pt', weights_only=True)

# Graph mini-batches: 10 neighbours sampled at the first hop, 5 at the second.
embedding_data = Data(edge_index=edge_index, x=embedding_full)
embedding_dataloader = NeighborLoader(
    data=embedding_data,
    batch_size=1024,
    num_neighbors=[10, 5],
)

# Bag-of-words inputs, split into train / validation with a fixed seed.
X_sc  = torch.load('inputs/GraphETM/X_sc.pt',  weights_only=False)
X_ehr = torch.load('inputs/GraphETM/X_ehr.pt', weights_only=False)
X_sc,  X_sc_val  = train_test_split(X_sc,  random_state=0, test_size=test_size)
X_ehr, X_ehr_val = train_test_split(X_ehr, random_state=0, test_size=test_size)

In [9]:
### Parameters
# Hyper-parameters only: this dict is sent to wandb as the run config (see wandb.init below).
config = {
    'model': dict(
        num_topics = 15, # K
        theta_act='tanh',
        dropout=0.2,
        encoder_params = { ## Encoder parameters
            'sc': { # Encoder SC
                'vocab_size': X_sc.shape[1],
                'encoder_hidden_size': 64 },
            'ehr': { # Encoder EHR
                'vocab_size': X_ehr.shape[1],
                'encoder_hidden_size': 64 }
        },
        graphconv_params = { ## Graph-Conv Filter parameters
            # NOTE(review): the GCN hidden width equals the EHR vocabulary size. This looks
            # unintended and makes the first GCN layer very wide; confirm the intended value.
            'hidden_dim': X_ehr.shape[1],
            'out_dim': 64
        },
    ),

    'dataloader': dict(
        batch_size=64,
        shuffle=False
    ),

    'training': dict(
        optimizer = torch.optim.Adam,
        lr = 0.001,
        epochs = 100,
        kl_annealing_epochs = None
    ),

    'device': device,
}

## Embedding data: passed to the model directly so large tensors are not logged in the wandb config.
model_inputs = dict(
    embedding  = embedding_full,
    embedding_dataloader = embedding_dataloader, # Required: Model.forward raises NotImplementedError without it.
    edge_index = edge_index,
    id_embed_sc  = sc_indices ,
    id_embed_ehr = ehr_indices,
)


### Model
trainer = GraphETMTrainer(
    model = Model(**config['model'], **model_inputs).to(device), # Model
    dataloader_sc  = DataLoader(**config['dataloader'], dataset = X_sc ), # Dataloaders
    dataloader_ehr = DataLoader(**config['dataloader'], dataset = X_ehr),
    val_dataloader_sc  = DataLoader(**config['dataloader'], dataset = X_sc_val ),
    val_dataloader_ehr = DataLoader(**config['dataloader'], dataset = X_ehr_val),
    device = device,
    wandb_run = wandb.init(
        project ='GraphETM',
        group = 'GraphETM',
        # name = f'GraphETM_{int(time.time())}',
        name = 'GraphETM_conv_filter',
        config=config, save_code=True) # Start Wandb
)

### Training
trainer.train( # TODO: Check for KL Divergence being magnitude smaller than Recon. Loss.
    epochs = config['training']['epochs'],
    optimizer = config['training']['optimizer']([ # Optimizer
        {'params': trainer.model.enc_sc.parameters()},
        {'params': trainer.model.enc_ehr.parameters()},
        {'params': trainer.model.dec_sc.parameters()},
        {'params': trainer.model.dec_ehr.parameters()},
        # The GCN filter feeds the decoders and its reconstruction loss is part of the total loss,
        # so its weights must be optimized as well.
        {'params': trainer.model.graph_filter.parameters()},
        # The base embedding `trainer.model.rho` is frozen (trainable_embeddings defaults to False).
    ],
    lr = config['training']['lr']),
    kl_annealing_epochs = config['training']['kl_annealing_epochs'],
)
# TODO: Close the trainer.wandb instance.

Training GraphETM:   0%|          | 0/100 [00:00<?, ?epoch/s]

RuntimeError: Invalid buffer size: 39.21 GB

In [None]:
### FUNCTION TOP_K
TOP_K = 5

# TODO: CORRESPONDING (reverse ranking done: terms are listed most probable first)

def top_k_per_topic(input_df, modality, k=5):
    """
        Return the k most probable terms of every topic for one modality.

        Args:
            input_df: DataFrame whose columns name the modality's vocabulary. They are assumed to be
                in the same order as beta's columns; confirm against how the inputs were built.
            modality: "sc" or "ehr" (forwarded to Model.get_beta).
            k: Number of terms kept per topic.

        Returns:
            DataFrame with k rows per topic (topic 0 first, most probable term first within each
            topic), indexed by term name; column j holds the term's probability under topic j.
    """
    beta = trainer.model.get_beta(modality=modality) # K x V

    # argsort is ascending: reverse it so the most probable term comes first, then keep k.
    # (Slicing after the reversal also behaves correctly for k=0, unlike [:, -k:].)
    top_k_indices = np.argsort(beta, axis=1)[:, ::-1][:, :k].flatten()
    top_k = input_df.columns[top_k_indices]

    prob = beta[:, top_k_indices].T
    return pd.DataFrame(prob, index=top_k)

### GET TOP K PER TOPIC
sc_prob_df = top_k_per_topic(input_df=input_sc, modality='sc', k=TOP_K) # SC
ehr_prob_df = top_k_per_topic(input_df=input_ehr, modality='ehr', k=TOP_K) # EHR

### PLOT
fig = make_subplots(
    rows=1, cols=2,
    horizontal_spacing=0.06,
    subplot_titles=['Top Gene per Topics', 'Top ICD-9 Code per Topics'],
)

fig.update_layout(
    template='plotly_white',
    width=1200, height=1500,
    font=dict(color='black', size=10),
)

# PARAMS
heatmap_params = dict(
    colorscale='OrRd',
    xgap=0.9,
    ygap=0.9,
)

yaxes_params = dict(
    tickfont=dict(size=10, color='black')
)

# SC Plot (rows = top terms, columns = topics); integer y positions keep duplicate term names apart.
fig.add_trace(
    go.Heatmap(
        name='SC',
        z=sc_prob_df.values,
        x=sc_prob_df.columns,
        y=list(range(len(sc_prob_df.index))),
        **heatmap_params
    ),
    row=1, col=1
)
fig.update_yaxes(
    tickvals=list(range(len(sc_prob_df.index))),
    ticktext=sc_prob_df.index,
    autorange='reversed', type='category',
    row=1, col=1,
    **yaxes_params
)

# EHR Plot
fig.add_trace(
    go.Heatmap(
        name='EHR',
        z=ehr_prob_df.values,
        x=ehr_prob_df.columns,
        y=list(range(len(ehr_prob_df.index))),
        **heatmap_params
    ),
    row=1, col=2
)
fig.update_yaxes(
    tickvals=list(range(len(ehr_prob_df.index))),
    ticktext=ehr_prob_df.index,
    autorange='reversed', type='category',
    row=1, col=2,
    **yaxes_params
)

# Horizontal separations between topic groups (both panels have TOP_K rows per topic)
for i in range(TOP_K, ehr_prob_df.shape[0], TOP_K):
    fig.add_hline(
        y = i - 0.5,
        line_width=4,
        line_color='white'
    )

# Adjust vertical title location
for annotation in fig['layout']['annotations']:
    annotation['y'] += 0.01

fig.show()

In [None]:
# fig.write_html('top_k.html')

In [None]:
### BETA DIAGNOSTICS
# Decoder.get_beta() reuses the rho stored by the decoder's last forward pass, so the model must
# have run forward at least once. no_grad: these are read-only diagnostics.
with torch.no_grad():
    beta_sc = trainer.model.dec_sc.get_beta()    # K x V_sc
    beta_ehr = trainer.model.dec_ehr.get_beta()  # K x V_ehr

# Distinct top-1 tokens across topics: fewer than K means several topics share their top term.
uniq_top1_sc = np.unique(beta_sc.numpy(force=True).argmax(1)).size
uniq_top1_ehr = np.unique(beta_ehr.numpy(force=True).argmax(1)).size
print(f'unique top-1 tokens: sc = {uniq_top1_sc}/{beta_sc.shape[0]}, ehr = {uniq_top1_ehr}/{beta_ehr.shape[0]}')

# Entropy (nats) of each topic's word distribution: higher values mean more diffuse topics.
entropy_sc = -(beta_sc * beta_sc.clamp_min(1e-9).log()).sum(1)
entropy_ehr = -(beta_ehr * beta_ehr.clamp_min(1e-9).log()).sum(1)
print(f'entropy per topic: sc = {entropy_sc.numpy(force=True)}, ehr = {entropy_ehr.numpy(force=True)}')

In [None]:
### OCCURRENCE COUNT
TOP_N = 25
K = sc_prob_df.shape[1] # number of topics (one probability column per topic)

# Number of topics listing each term among their top-k terms.
gene_counts = sc_prob_df.index.value_counts()
icd_counts  = ehr_prob_df.index.value_counts()

gene_counts_top = gene_counts.head(TOP_N)
icd_counts_top  = icd_counts.head(TOP_N)

# fig_num_topic = make_subplots(
#     rows=1, cols=2,
#     shared_xaxes=False,
#     # horizontal_spacing=0.06,
#     subplot_titles=[f'Top {TOP_N} genes by num_topics (K={K})', f'Top {TOP_N} ICD-9 codes by num_topics (K={K})']
# )
#
# fig_num_topic.update_layout(
#     template='plotly_white',
#     font=dict(color='black', size=10)
# )
#
# fig_num_topic.add_bar() # TODO: Got lazy.

fig_gene_count = px.bar(
    gene_counts_top.sort_values(ascending=False).reset_index(),
    x='index', y='count',
    title=f'Top {TOP_N} genes by num_topics (K={K})'
)

fig_icd_count = px.bar(
    icd_counts_top.sort_values(ascending=False).reset_index(),
    x='index', y='count',
    title=f'Top {TOP_N} ICD-9 codes by num_topics (K={K})'
)

### PROBABILITY WEIGHTED IMPORTANCE
# Sum of a term's beta over all its (topic, term) rows and all topic columns.
gene_weight = sc_prob_df.groupby(sc_prob_df.index).sum().sum(axis=1)
icd_weight  = ehr_prob_df.groupby(ehr_prob_df.index).sum().sum(axis=1)

gene_weight_top = gene_weight.sort_values(ascending=False).head(TOP_N)
icd_weight_top  = icd_weight.sort_values(ascending=False).head(TOP_N)

fig_gene_weight = px.bar(
    gene_weight_top.reset_index(),
    x='index', y=0,
    title=f'Top {TOP_N} genes by cumulative beta-probability',
    labels={'index':'Gene', 0:'Σ β'},
    template='plotly_white'
)

fig_icd_weight = px.bar(
    icd_weight_top.reset_index(),
    x='index', y=0,
    title=f'Top {TOP_N} ICD-9 codes by cumulative beta-probability',
    labels={'index':'ICD-9', 0:'Σ β'},
    template='plotly_white'
)

### FORMAT FIGURES
font_params = dict(color='black', size=12)
# `bar_fig` (not `fig`) so the top-k heatmap stored in `fig` is not overwritten.
for bar_fig in [fig_gene_count, fig_icd_count, fig_gene_weight, fig_icd_weight]:
    bar_fig.update_layout(
        template='plotly_white',
        font=font_params,
        title_font=dict(color='black', size=16)
    )
    bar_fig.update_xaxes(tickfont=font_params, title_font=dict(color='black', size=14))
    bar_fig.update_yaxes(tickfont=font_params, title_font=dict(color='black', size=14))

fig_gene_count.show()
fig_icd_count.show()
fig_gene_weight.show()
fig_icd_weight.show()


############################################################################
### CUMULATIVE VS UBIQUITY
fig_scatter = px.scatter(
    data_frame = pd.DataFrame({
        'term':  list(gene_counts.index) + list(icd_counts.index),
        'num_topics':  gene_counts.tolist()    + icd_counts.tolist(),
        # value_counts orders terms by count while groupby orders them by name:
        # reindex the weights so every row refers to the same term.
        'cum_beta': pd.concat([
            gene_weight.reindex(gene_counts.index),
            icd_weight.reindex(icd_counts.index),
        ]).values,
        'type': ['Gene']*len(gene_counts) + ['ICD-9']*len(icd_counts)
    })
    .query('cum_beta > 0')
    ,
    x='num_topics', y='cum_beta',
    color='type', # two colors = Genes vs ICD-9
    hover_data=['term', 'num_topics', 'cum_beta'],
    marginal_x='violin',
    marginal_y='violin',
    # log_y=True, # keeps long-tail terms visible
    template='plotly_white',
    title='Term ubiquity vs cumulative probability',
)

# Update visuals
fig_scatter.update_layout(
    font=font_params,
    title_font=dict(color='black', size=16),
    legend_title_text='Term type',
)
fig_scatter.update_xaxes(title_font=font_params, tickfont=font_params,
                         rangemode='tozero')
fig_scatter.update_yaxes(title_font=font_params, tickfont=font_params,
                         rangemode='tozero')

fig_scatter.show()

In [None]:
# wandb.log({'Top K per Topics': fig}) # TODO: Fix visualization.

# Bar charts from the occurrence / importance cell, keyed by their wandb panel name.
plotly_panels = {
    'Gene freq':        fig_gene_count,
    'ICD freq':         fig_icd_count,
    'Gene importance':  fig_gene_weight,
    'ICD importance':   fig_icd_weight,
}
wandb.log({panel_name: wandb.Plotly(panel_fig) for panel_name, panel_fig in plotly_panels.items()})

# fig_scatter.write_html('scatter.html', include_plotlyjs='cdn') # TODO: Fix visualization
# scatter_artifact = wandb.Artifact('ubiquity_vs_importance', type='visualization')
# scatter_artifact.add_file('scatter.html')
# wandb.log_artifact(scatter_artifact)

In [None]:
# TODO: Implement Plotly Clustergram.

In [10]:
# Close the wandb run held by the trainer (presumably the run passed as `wandb_run`; confirm in GraphETMTrainer).
trainer.wandb.finish()

In [None]:
# DONE