# GraphETM Dev Notebook

In [1]:
### Imports
## Local
from model.graphetm import GraphETM

## External
import numpy as np
import pandas as pd

# Torch
import torch
from torch.utils.data import DataLoader

# Torch-Geometric
import torch_geometric as pyg
from torch_geometric.data import Data
from torch_geometric.loader import NeighborLoader

# Sklearn
from sklearn.model_selection import train_test_split

# Plot
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

import wandb

### Weights & Biases login (this cell sets no parameters; it only authenticates)
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33mloicduch[0m ([33mloicduch-mcgill-university[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [2]:
# Seeds
# One named value so the seed used for this notebook is easy to find and change.
SEED = 10
pyg.seed_everything(SEED) # random, np, torch, torch.cuda

---
# Training

In [3]:
### Device
# Preference order: CUDA GPU, then Apple-Silicon MPS, then CPU.
# MPS is probed through torch.backends.mps, the long-standing documented check;
# torch.mps.is_available() is not present in every PyTorch release that ships MPS,
# and on such a release the lookup would raise AttributeError on non-CUDA machines.
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')
# device = torch.device('cpu')
print(f'Using device: {device}')

Using device: mps


In [4]:
### Data
DATA_DIR = 'inputs/GraphETM'  # single place to repoint the input folder
test_size = 0.2               # fraction of samples held out for validation

# Load metadata
input_sc  = pd.read_csv(f'{DATA_DIR}/input_PBMC.csv')
input_ehr = pd.read_csv(f'{DATA_DIR}/input_EHR.csv')

# Load Rho (Graph Embeddings)
sc_indices  = np.load(f'{DATA_DIR}/id_embed_sc.npy') # TODO: Check whether these indexes are still relevant.
ehr_indices = np.load(f'{DATA_DIR}/id_embed_ehr.npy')

# SECURITY: weights_only=False makes torch.load unpickle arbitrary Python objects,
# which can execute code embedded in the file. Keep it only for files produced
# locally by this project. TODO: try weights_only=True for the three loads below
# (edge_index already uses it) and switch if the stored objects allow it.
embedding_full = torch.load(f'{DATA_DIR}/embedding_full.pt', weights_only=False) # V x L
edge_index = torch.load(f'{DATA_DIR}/edge_index.pt', weights_only=True)

# Load Input Data
X_sc  = torch.load(f'{DATA_DIR}/X_sc.pt',  weights_only=False) # (num_samples (cells), num_genes)
X_ehr = torch.load(f'{DATA_DIR}/X_ehr.pt', weights_only=False)

# Train / validation split. X_sc and X_ehr are rebound to their training portion;
# the cell reloads them from disk above, so re-running the whole cell is safe.
X_sc,  X_sc_val  = train_test_split(X_sc , test_size=test_size, random_state=0)
X_ehr, X_ehr_val = train_test_split(X_ehr, test_size=test_size, random_state=0)

In [9]:
### Parameters
# Single source of truth for every tunable value used by the model, loaders and training.
config = {
    ### ETM Model Config
    'etm_model': {
        'num_topics': 15, # K = 15
        'theta_act': 'tanh',
        'dropout': 0.2,
        ## Encoder parameters: one entry per modality ('sc', then 'ehr').
        ## vocab_size is the second dimension of that modality's training matrix.
        'encoder_params': {
            modality: {
                'vocab_size': features.shape[1],
                'encoder_hidden_size': 64}
            for modality, features in (('sc', X_sc), ('ehr', X_ehr))
        },
    },

    ### Graph Model Config
    'graph_model': {
        'hidden_dim': 64,
        # TODO: I should use this instead to define the in_features of alphas. TODO: Rename this "embedding dimension".
        # NOTE(review): the recorded run of this notebook failed with
        # "value tensor of shape [59, 32] cannot be broadcast to indexing result of shape [59, 128]";
        # 32 matches this value, so it presumably has to agree with the feature width of
        # embedding_full - confirm against GraphETM.
        'embedding_dim': 32,
    },

    # Settings shared by the four sample DataLoaders.
    'dataloader': {
        'batch_size': 64,
        'shuffle': False,
    },

    # Settings forwarded to the model as graphloader_cfg.
    'neighborloader': {
        'num_neighbors': [10, 10],
        'batch_size': 2048,
        'shuffle': True,
    },

    'training': {
        'optimizer': torch.optim.Adam,
        'lr': 0.001,
        'epochs': 25,
        'kl_annealing_epochs': None,
    },

    'device': device,
}


### Model
def _make_loader(dataset):
    """Wrap `dataset` in a DataLoader configured from config['dataloader']."""
    return DataLoader(dataset=dataset, **config['dataloader'])


model = GraphETM(
    model_cfg = config['etm_model'],   # ETM Model config
    gcn_cfg   = config['graph_model'], # Graph Model config
    ## Embedding Parameters
    embedding  = embedding_full,
    edge_index = edge_index,
    id_embed_sc  = sc_indices,  # TODO: Shape issue source. Linear layer is created with shape 1000 from these.
    id_embed_ehr = ehr_indices, # TODO: These are likely a different shape, but still wrong.
    graphloader_cfg = config['neighborloader'],

    ## Data Parameters (training loaders first, then validation loaders)
    dataloader_sc  = _make_loader(X_sc),
    dataloader_ehr = _make_loader(X_ehr),
    val_dataloader_sc  = _make_loader(X_sc_val),
    val_dataloader_ehr = _make_loader(X_ehr_val),
    device = device,
    # W&B tracking is currently disabled; uncomment to attach a run.
    # wandb_run = wandb.init(
    #     project ='GraphETM',
    #     group = 'GraphETM',
    #     # name = f'GraphETM_{int(time.time())}',
    #     name = 'GraphETM_filter_rev1',
    #     config=config, save_code=True) # Start Wandb
)

### Training
# NOTE(review): the output saved with this cell is a RuntimeError (shape mismatch
# [59, 32] vs [59, 128]) raised at 0/25 epochs - the cell does not currently run clean.
train_cfg = config['training']

# A single optimizer over two parameter groups (ETM model, graph model) sharing one lr.
optimizer = train_cfg['optimizer'](
    [
        {'params': model.etm_model.parameters()},
        {'params': model.graph_model.parameters()},
    ],
    lr = train_cfg['lr'],
)

model.train(
    epochs = train_cfg['epochs'],
    optimizer = optimizer,
    kl_annealing_duration = train_cfg['kl_annealing_epochs'],
)
# TODO: Close the trainer.wandb instance.

Training GraphETM:   0%|          | 0/25 [00:00<?, ?epoch/s]

RuntimeError: shape mismatch: value tensor of shape [59, 32] cannot be broadcast to indexing result of shape [59, 128]

In [None]:
### FUNCTION TOP_K
TOP_K = 5

# TODO REVERSE RANKING + CORRESPONDING (pass descending=True for the reversed order)

def top_k_per_topic(input_df, modality, k=5, descending=False):
    """Collect the k highest-beta vocabulary terms of every topic.

    Reads the topic-by-vocabulary matrix from the notebook-level `model` via
    `model.etm_model.get_beta(modality=modality)`.
    Assumes that matrix is a 2-D NumPy array whose columns line up with
    `input_df.columns` - TODO confirm against GraphETM.

    Parameters
    ----------
    input_df : pandas.DataFrame
        Frame whose column labels name the vocabulary terms.
    modality : str
        Forwarded unchanged to `get_beta` (the notebook uses 'sc' and 'ehr').
    k : int, default 5
        Number of terms kept per topic. 0 yields an empty frame; values larger
        than the vocabulary keep every term.
    descending : bool, default False
        False keeps the historical order (lowest to highest beta within each
        topic); True lists each topic's terms from highest to lowest.

    Returns
    -------
    pandas.DataFrame
        One row per selected (topic, term) pair, topics in order, k rows each.
        The index holds the term label; each row is that term's beta across
        all topics, so a term selected by several topics appears several times.

    Raises
    ------
    ValueError
        If k is negative.
    """
    if k < 0:
        raise ValueError(f'k must be non-negative, got {k}')

    beta = model.etm_model.get_beta(modality=modality)

    ranked = np.argsort(beta, axis=1)
    # Slice from an explicit start: `[:, -k:]` selects every column when k == 0.
    start = max(ranked.shape[1] - k, 0)
    top_k_indices = ranked[:, start:]
    if descending:
        top_k_indices = top_k_indices[:, ::-1]
    top_k_indices = top_k_indices.flatten()
    top_k = input_df.columns[top_k_indices]

    prob = beta[:, top_k_indices].T
    return pd.DataFrame(prob, index=top_k)

### GET TOP K PER TOPIC
sc_prob_df = top_k_per_topic(input_df=input_sc, modality='sc', k=TOP_K) # SC
ehr_prob_df = top_k_per_topic(input_df=input_ehr, modality='ehr', k=TOP_K) # EHR

### PLOT
# Two side-by-side heatmaps: terms on the rows, frame columns on the x axis.
fig = make_subplots(
    rows=1, cols=2,
    horizontal_spacing=0.06,
    subplot_titles=['Top Gene per Topics', 'Top ICD-9 Code per Topics'],
)

fig.update_layout(
    template='plotly_white',
    width=1200, height=1500,
    font={'color': 'black', 'size': 10},
)

# PARAMS
heatmap_params = {
    'colorscale': 'OrRd',
    'xgap': 0.9,
    'ygap': 0.9,
}

yaxes_params = {
    'tickfont': {'size': 10, 'color': 'black'},
}

# One panel per modality: (trace name, probability frame, subplot column).
panels = (
    ('SC',  sc_prob_df,  1),
    ('EHR', ehr_prob_df, 2),
)
for panel_name, prob_df, panel_col in panels:
    # Rows are placed by position because term labels can repeat across topics;
    # the labels themselves are attached as tick text.
    row_positions = list(range(len(prob_df.index)))
    fig.add_trace(
        go.Heatmap(
            name=panel_name,
            z=prob_df.values,
            x=prob_df.columns,
            y=row_positions,
            **heatmap_params
        ),
        row=1, col=panel_col
    )
    fig.update_yaxes(
        tickvals=row_positions,
        ticktext=prob_df.index,
        autorange='reversed', type='category',
        row=1, col=panel_col,
        **yaxes_params
    )

# Horizontal separations: a white rule after every TOP_K rows
for boundary in range(TOP_K, ehr_prob_df.shape[0], TOP_K):
    fig.add_hline(
        y = boundary - 0.5,
        line_width=4,
        line_color='white'
    )

# Adjust vertical title location (subplot titles are stored as layout annotations)
for title_annotation in fig['layout']['annotations']:
    title_annotation['y'] = title_annotation['y'] + 0.01

fig.show()

In [None]:
# fig.write_html('top_k.html')

In [None]:
def _n_unique_top1(beta):
    """Count the distinct column indices that are the argmax of at least one row of beta."""
    return np.unique(beta.numpy(force=True).argmax(1)).size


def _topic_entropy(beta):
    """Per-row entropy of beta; values are clamped at 1e-9 before the log to avoid log(0)."""
    return -(beta * beta.clamp_min(1e-9).log()).sum(1)


beta_sc = model.etm_model.dec_sc.get_beta()      # K × 4340
# NOTE(review): the original comment here also said "K × 4340"; that looks copy-pasted
# from the sc line - confirm the EHR vocabulary size.
beta_ehr = model.etm_model.dec_ehr.get_beta()

# Diversity check: how many different top-1 tokens the topics have.
uniq_top1_sc = _n_unique_top1(beta_sc)
uniq_top1_ehr = _n_unique_top1(beta_ehr)
print(f'unique top-1 tokens: sc = {uniq_top1_sc}/{beta_sc.shape[0]}, ehr = {uniq_top1_ehr}/{beta_ehr.shape[0]}')

# Sharpness check: entropy of each topic's distribution.
entropy_sc = _topic_entropy(beta_sc)
entropy_ehr = _topic_entropy(beta_ehr)
print(f'entropy per topic: sc = {entropy_sc.numpy(force=True)}, ehr = {entropy_ehr.numpy(force=True)}')

In [None]:
### OCCURRENCE COUNT
# How many topics list each term among their top-k (one index entry per listing).
TOP_N = 25

gene_counts = sc_prob_df.index.value_counts()
icd_counts  = ehr_prob_df.index.value_counts()

gene_counts_top = gene_counts.head(TOP_N)
icd_counts_top  = icd_counts.head(TOP_N)

# fig_num_topic = make_subplots(
#     rows=1, cols=2,
#     shared_xaxes=False,
#     # horizontal_spacing=0.06,
#     subplot_titles=[f'Top {TOP_N} genes by num_topics (K={K})', f'Top {TOP_N} ICD-9 codes by num_topics (K={K})']
# )
#
# fig_num_topic.update_layout(
#     template='plotly_white',
#     font=dict(color='black', size=10)
# )
#
# fig_num_topic.add_bar() # TODO: Got lazy.


def _counts_frame(counts):
    """Turn a value_counts Series into a frame with explicit 'index' and 'count' columns.

    Naming both columns here keeps x='index' / y='count' valid regardless of how the
    installed pandas version names the value_counts result.
    """
    return (
        counts
        .sort_values(ascending=False)
        .rename_axis('index')
        .reset_index(name='count')
    )


# K in the titles is the number of topics, i.e. the number of columns of the
# probability frames - not TOP_N, which only caps how many bars are drawn.
fig_gene_count = px.bar(
    _counts_frame(gene_counts_top),
    x='index', y='count',
    title=f'Top {TOP_N} genes by num_topics (K={sc_prob_df.shape[1]})'
)

fig_icd_count = px.bar(
    _counts_frame(icd_counts_top),
    x='index', y='count',
    title=f'Top {TOP_N} ICD-9 codes by num_topics (K={ehr_prob_df.shape[1]})'
)

### PROBABILITY WEIGHTED IMPORTANCE
def _weight_bar(weights_top, title, term_label):
    """Bar chart of a weight Series: terms on x, summed beta on y."""
    return px.bar(
        weights_top.reset_index(),
        x='index', y=0,
        title=title,
        labels={'index': term_label, 0: 'Σ β'},
        template='plotly_white'
    )


# Rows sharing a term label are summed per topic column, then across topics.
# NOTE(review): each row already holds the term's beta for every topic, so a term
# selected by m topics contributes m identical rows and ends up weighted by m -
# confirm this multiplicity weighting is intended.
sc_sum_per_term  = sc_prob_df.groupby(sc_prob_df.index).sum()
ehr_sum_per_term = ehr_prob_df.groupby(ehr_prob_df.index).sum()
gene_weight = sc_sum_per_term.sum(axis=1)
icd_weight  = ehr_sum_per_term.sum(axis=1)

# Keep the TOP_N heaviest terms for plotting.
gene_weight_top = gene_weight.sort_values(ascending=False).head(TOP_N)
icd_weight_top  = icd_weight.sort_values(ascending=False).head(TOP_N)

fig_gene_weight = _weight_bar(
    gene_weight_top,
    title=f'Top {TOP_N} genes by cumulative beta-probability',
    term_label='Gene',
)

fig_icd_weight = _weight_bar(
    icd_weight_top,
    title=f'Top {TOP_N} ICD-9 codes by cumulative beta-probability',
    term_label='ICD-9',
)

### FORMAT FIGURES
# Apply one shared look to the four bar charts, then display them.
font_params = dict(color='black', size=12)
# The loop variable is deliberately not named `fig`: that name is bound to the
# top-k heatmap figure created earlier in the notebook, and rebinding it here
# would make any later use of `fig` silently point at the last bar chart.
for bar_fig in [fig_gene_count, fig_icd_count, fig_gene_weight, fig_icd_weight]:
    bar_fig.update_layout(
        template='plotly_white',
        font=font_params,
        title_font=dict(color='black', size=16)
    )
    bar_fig.update_xaxes(tickfont=font_params, title_font=dict(color='black', size=14))
    bar_fig.update_yaxes(tickfont=font_params, title_font=dict(color='black', size=14))

fig_gene_count.show()
fig_icd_count.show()
fig_gene_weight.show()
fig_icd_weight.show()


############################################################################
### CUMULATIVE VS UBIQUITY
font_params = dict(color='black', size=12)


def _ubiquity_frame(counts, weights, term_type):
    """Build one row per term: its topic count, its cumulative beta, and its type.

    `counts` is ordered by frequency while `weights` is ordered by term label, so the
    weights are re-aligned to the counts' labels before the columns are combined.
    Pairing the raw value arrays would attach each term to another term's weight.
    """
    return pd.DataFrame({
        'term': counts.index,
        'num_topics': counts.to_numpy(),
        'cum_beta': weights.reindex(counts.index).to_numpy(),
        'type': term_type,
    })


ubiquity_df = pd.concat(
    [
        _ubiquity_frame(gene_counts, gene_weight, 'Gene'),
        _ubiquity_frame(icd_counts, icd_weight, 'ICD-9'),
    ],
    ignore_index=True,
).query('cum_beta > 0')

fig_scatter = px.scatter(
    data_frame = ubiquity_df,
    x='num_topics', y='cum_beta',
    color='type', # two colors = Genes vs ICD-9
    hover_data=['term', 'num_topics', 'cum_beta'],
    marginal_x='violin',
    marginal_y='violin',
    # log_y=True, # keeps long-tail terms visible
    template='plotly_white',
    title='Term ubiquity vs cumulative probability',
)

# Update visuals
fig_scatter.update_layout(
    font=font_params,
    title_font=dict(color='black', size=16),
    legend_title_text='Term type',
)
fig_scatter.update_xaxes(title_font=font_params, tickfont=font_params,
                         rangemode='tozero')
fig_scatter.update_yaxes(title_font=font_params, tickfont=font_params,
                         rangemode='tozero')

fig_scatter.show()

In [None]:
# wandb.log({'Top K per Topics': fig}) # TODO: Fix visualization.

# Bar charts to publish, keyed by the panel name they are logged under.
bar_figures = {
    'Gene freq':        fig_gene_count,
    'ICD freq':         fig_icd_count,
    'Gene importance':  fig_gene_weight,
    'ICD importance':   fig_icd_weight,
}
# All four are sent in a single wandb.log call, each wrapped in wandb.Plotly.
wandb.log({
    panel_name: wandb.Plotly(figure)
    for panel_name, figure in bar_figures.items()
})

# fig_scatter.write_html('scatter.html', include_plotlyjs='cdn') # TODO: Fix visualization
# scatter_artifact = wandb.Artifact('ubiquity_vs_importance', type='visualization')
# scatter_artifact.add_file('scatter.html')
# wandb.log_artifact(scatter_artifact)

In [None]:
# TODO: Implement Plotly Clustergram.

In [8]:
# Close the W&B run held by the model.
# NOTE(review): this cell's recorded execution count (8) is lower than the training
# cell's (9), so the summary shown below comes from an earlier run of the model, and
# wandb_run is currently commented out where the model is built - confirm that
# `model.wandb` still exists after a fresh Restart & Run All.
model.wandb.finish()

0,1
batch,▁▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▄▄▄▅▅▅▅▅▅▆▆▇▇▇▇▇███
epoch,▁▁▂▂▂▂▃▃▃▄▄▄▅▅▅▅▆▆▆▇▇▇▇██
train/ehr/kld,▃▁▁▂▂▃▄▃▄▄▄▆▆▄▅▅▆▅▆▇▆▅▆▄▅▅▅▅▆█▅▆▇▆▆▇▆▇▆▇
train/ehr/recon_loss,▂▄▃▆▃▅▆▃▃█▃▅▃▃▂▁▃▃▅▄▃▄▃▂▂▄▂▃▂▂▂▁▂▅▂▄▂▅▅▆
train/graph_recon_loss,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/sc/kld,▁▁▁▁▁▁▁▂▂▂▃▃▃▄▄▅▄▅▅▅▅▅▅▅▅▆▆▅▆▆▇▆▆▆▇▆▆▆██
train/sc/recon_loss,█▄▅▆▃▅▃▁▃▁▃▄▅▄▄▄▃▄▄▄▇▆▂▂▃▂▄▅▄▇▅█▄▁▄▇▆▁▄▅
train/total_loss,▅█▃▁▆▄▃▃▅▅▄▁▅▅▄▄▅▂▁▇▃▄▁▃▄▅▄▄▃▄▂▄▇▂▃▇▃▄▃▁
val/ehr/ari,█▆▆▆▄▅▃▂▁▃▄▃▃▃▃▃▂▃▂▂▂▂▁▁
val/ehr/kld,▁▂▄▄▅▅▆▆▆▆▆▇▇▇▇▇▇▇█▇█████

0,1
batch,1349.0
epoch,24.0
train/ehr/kld,0.24611
train/ehr/recon_loss,19.82779
train/graph_recon_loss,1.28858
train/sc/kld,5.31191
train/sc/recon_loss,890.68433
train/total_loss,917.3587
val/ehr/ari,0.1365
val/ehr/kld,0.25535


In [None]:
# DONE