<a href="https://colab.research.google.com/github/joaquim-teixeira/HVGFGL-metabolite/blob/main/notebooks/HVGFGL_Kidney_Data_VI.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Hierarchical Variational Graph Fused Lasso (HVGFGL)

This notebook provides an implementation of the **Hierarchical Variational Graph Fused Lasso (HVGFGL)** introduced in  
[*A Hierarchical Variational Graph Fused Lasso for Recovering Relative Rates in Spatial Compositional Data*](https://arxiv.org/abs/2509.20636).  

It contains the core code required to apply HVGFGL to real imaging mass spectrometry (IMS) data—specifically, mouse kidney data from [Wang et al., 2022](https://pubmed.ncbi.nlm.nih.gov/35132243/).  
The corresponding results and analyses are presented in **Section 6** of the manuscript.  

---

## Inputs
- **Dataset**: Mouse kidney IMS data from Wang et al., 2022.  
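- **Preprocessing**: The input table (`kidney_M2_tic.csv`) is standardized beforehand; in this notebook, all-empty columns are dropped, intensities are floored to integers, and entries equal to their column minimum are treated as censored (below the detection limit), matching the experimental setup described in the manuscript.  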

---

## Outputs
Running this notebook will generate:  
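- **Variational parameters** of HVGFGL fitted to the kidney data: periodic checkpoints in `model_checkpoint.pt`, the final means of the node logits in `real_data_theta_full_kidney.pt`, and the per-step training losses in `real_data_losses_full_kidney.pt`.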

In [None]:
#@title Dependencies and Utils
# Mount Google Drive and work from the top of "My Drive"; the data cell reads
# 'kidney_M2_tic.csv' and the training cell writes its .pt files relative to
# this directory.
from google.colab import drive
drive.mount("/content/drive")
%cd '/content/drive/MyDrive/'
from google.colab import files
import os
# NOTE(review): presumably works around the Intel OpenMP (KMP) "duplicate
# runtime" abort; confirm it is still needed on the target runtime.
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import torch
from tqdm import tqdm
import torch  # NOTE: repeats the import above; harmless
import torch.nn.functional as F
from torch.distributions import Dirichlet, Multinomial, Laplace, Gamma,InverseGamma, Normal, MultivariateNormal
import time
import scipy as sp
from scipy.spatial import KDTree
import pandas as pd
import matplotlib.colors as mcolors
import time  # NOTE: repeats the import above; harmless
# Prefer the GPU when available; later cells place model tensors on `device`.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("You are using device: %s" % device)



def sample_neg_mult_gamma_cont(n, p, size):
    """Draw reparameterizable, continuous negative-multinomial-like samples.

    A Gamma draw with shape ``n`` and rate ``p0 / (1 - p0)`` sets a total,
    which is distributed over the categories in proportion to ``p``; the
    per-category Poisson step is replaced by a Normal whose variance equals
    its mean, so gradients flow through ``rsample``. The resulting mean of
    category ``i`` is ``n * p_i / p0``.

    Args:
        n: Gamma shape; must broadcast against ``p.sum(-1)``.
        p: Category probabilities along the last dimension. The leftover
            mass ``p0 = 1 - p.sum(-1)`` acts as the reference category.
        size: Number of independent draws, prepended as a new leading dim.

    Returns:
        Tensor of shape ``(size, *p.shape)``.
    """
    # NOTE(review): 1e-40 is subnormal in float32, so this guard is very weak
    # for float32 inputs (e.g. when p0 == 0).
    tiny = 1e-40
    p_ref = 1 - p.sum(axis=-1)
    # Equivalent to p0 / (1 - p0), written with the original double reciprocal.
    rate = 1 / (1 / (p_ref + tiny) - 1) + tiny
    totals = torch.distributions.Gamma(n, rate).rsample([size])
    means = (totals / (1 - p_ref + tiny)).unsqueeze(-1) * p.unsqueeze(0)
    return torch.distributions.Normal(means, torch.sqrt(means + tiny)).rsample()

def find_neighbors(points, threshold):
    """For every point, list the indices of all points within ``threshold``.

    Uses a single ``scipy.spatial.KDTree`` queried against itself, so each
    point's own index appears in its list (distance 0).

    Args:
        points: Array-like of shape ``(n, d)``.
        threshold: Search radius (``KDTree``'s default Euclidean metric).

    Returns:
        List of ``n`` lists of integer indices.
    """
    index = KDTree(points)
    return index.query_ball_tree(index, threshold)

def create_adjacency_matrix(points, neighbors):
    """Build a dense, symmetric 0/1 adjacency matrix.

    Args:
        points: Sequence of points; only its length (the node count) is used.
        neighbors: ``neighbors[i]`` lists the node indices linked to node
            ``i`` (for example the output of ``find_neighbors``). Every link
            is recorded in both directions.

    Returns:
        Integer ``numpy`` array of shape ``(len(points), len(points))``.
    """
    size = len(points)
    adjacency = np.zeros((size, size), dtype=int)
    pairs = [(src, dst) for src, linked in enumerate(neighbors) for dst in linked]
    for src, dst in pairs:
        adjacency[src, dst] = 1
        adjacency[dst, src] = 1
    return adjacency
def columnwise_corr(A, B, eps=1e-8):
    """Pearson correlation between matching columns of two tensors.

    Args:
        A: Tensor of shape ``(n, p)``.
        B: Tensor of shape ``(n, p)``; column ``j`` is correlated with
            column ``j`` of ``A``.
        eps: Added to the denominator for numerical stability, so a constant
            column yields 0 rather than NaN.

    Returns:
        Tensor of shape ``(p,)`` with one correlation per column.
    """
    dev_a = A - A.mean(dim=0)
    dev_b = B - B.mean(dim=0)
    cross = (dev_a * dev_b).sum(dim=0)
    norm_a = dev_a.square().sum(dim=0).sqrt()
    norm_b = dev_b.square().sum(dim=0).sqrt()
    return cross / (norm_a * norm_b + eps)

You are using device: cuda


In [None]:
#@title HVGFGL Lasso
def sample_neg_mult_gamma_cont_cens(n, p, size, cens):
    """Masked variant of the continuous negative-multinomial sampler.

    The probabilities are multiplied by ``cens`` so that only flagged
    categories receive mass. The reference probability ``p0`` is whatever
    is left after the flagged categories (``1 - sum(p * cens)``). Unflagged
    categories get zero-mean draws with a negligible spread.

    Args:
        n: Gamma shape; must broadcast against ``p.sum(-1)``.
        p: Category probabilities along the last dimension.
        size: Number of independent draws, prepended as a new leading dim.
        cens: Mask broadcastable to ``p`` (bool or 0/1).

    Returns:
        Tensor of shape ``(size, *broadcast(p, cens).shape)``.
    """
    # NOTE(review): 1e-40 is subnormal in float32; the guard is weak there.
    tiny = 1e-40
    masked = p * (cens)
    p_ref = 1 - masked.sum(axis=-1)
    rate = 1 / (1 / (p_ref + tiny) - 1) + tiny
    totals = torch.distributions.Gamma(n, rate).rsample([size])
    means = (totals / (1 - p_ref + tiny)).unsqueeze(-1) * masked.unsqueeze(0)
    return torch.distributions.Normal(means, torch.sqrt(means + tiny)).rsample()

def elbo_lasso_cens(theta_q_loc, x, lam, adj, cens, lods, cr, samp_size, n_samp, cens_exp, gamma_1_raw, gamma_2_raw, lam_1_raw, lam_2_raw, dat_exp, a1, a2, cens_inf):
    """Monte Carlo estimate of the HVGFGL evidence lower bound (ELBO).

    Variational family (positive parameters are stored as logs and
    exponentiated here):
      * ``lam ~ Gamma(exp(lam_1_raw), exp(lam_2_raw))``;
      * ``gamma ~ Gamma(exp(gamma_1_raw), exp(gamma_2_raw))``, one per
        (edge, feature);
      * ``theta | gamma ~ Normal(theta_q_loc, 1 / sqrt(s))``, where ``s``
        sums ``1/gamma`` over the edges incident to each node;
        ``softmax(theta)`` gives the per-node category probabilities.

    Args:
        theta_q_loc: Variational means of the node logits, ``(R, K)``.
        x: Unused; kept for call-site compatibility.
        lam: Rate of the ``Gamma(1, lam)`` prior that ``q(lam)`` is
            regularised towards (KL term).
        adj: Edge list of node-index pairs, ``(E, 2)``.
        cens: Unused in the body (``cens_exp`` and ``cens_inf`` are used).
        lods: Detection limits, indexed by data row like ``cens``.
        cr: Boolean row mask, True for rows without censored entries; only
            rows with ``cr == False`` enter the censoring term.
        samp_size: Number of pseudo-count draws per censoring estimate.
        n_samp: Number of variational draws of ``lam`` and ``gamma``.
        cens_exp: Censoring mask expanded to ``(n_samp, R, K)``.
        gamma_1_raw, gamma_2_raw: Log shape / log rate of ``q(gamma)``.
        lam_1_raw, lam_2_raw: Log shape / log rate of ``q(lam)``.
        dat_exp: Observed counts expanded to ``(n_samp, R, K)``.
        a1, a2: ``adj[:, 0]`` / ``adj[:, 1]`` expanded to match the gamma
            draws ``(n_samp, E, K)``; used as scatter indices (int64).
        cens_inf: ``+inf`` at censored entries and 0 elsewhere; subtracting
            it drops censored logits from the softmax normaliser.

    Returns:
        Scalar tensor with the ELBO estimate (maximise it).
    """
    # Variational draws of the per-feature (lam) and per-edge (gamma) scales.
    q_lam = torch.distributions.Gamma(torch.exp(lam_1_raw), torch.exp(lam_2_raw))
    lam_samp = q_lam.rsample([n_samp])
    q_gamma = torch.distributions.Gamma(torch.exp(gamma_1_raw), torch.exp(gamma_2_raw))
    q_gamma_samp = q_gamma.rsample([n_samp])

    # Node precision: accumulate 1/gamma from every incident edge, at both
    # endpoints. Allocate on the draws' own device/dtype (not a global
    # `device` / fixed float32) so scatter_add_ always sees matching tensors.
    # NOTE(review): N comes from the largest node index present in `adj`; an
    # isolated node gets zero precision (infinite scale), and an isolated
    # last node would make N smaller than theta_q_loc's row count.
    p1, E, q = q_gamma_samp.shape
    N = adj.max().item() + 1
    sums = torch.zeros(p1, N, q, dtype=q_gamma_samp.dtype, device=q_gamma_samp.device)
    inv_gamma = 1 / q_gamma_samp
    sums.scatter_add_(1, a1, inv_gamma)
    sums.scatter_add_(1, a2, inv_gamma)

    theta_q = torch.distributions.Normal(theta_q_loc, (1 / sums).sqrt())
    samps_q = theta_q.rsample([1]).squeeze(0)
    theta = torch.softmax(samps_q, dim=-1)

    # Censoring term for rows with at least one censored entry: simulate
    # pseudo-counts for the censored categories and softly score the event
    # that every category of the row stays below its detection limit.
    wot = sample_neg_mult_gamma_cont_cens(dat_exp.sum(axis=-1)[:, ~cr], theta[:, ~cr, :], samp_size, cens_exp[:, ~cr, :])
    soft_indicator = torch.sigmoid(1e5 * (lods[~cr] - wot))
    all_below = (torch.sigmoid(1e5 * (soft_indicator.mean(axis=-1) - 1)) * 2.).mean(axis=0)
    # Floor before the log. 1e-90 rounds to exactly 0 in float32/float16
    # (log -> -inf), so never use less than the dtype's smallest normal;
    # for float64 this is still exactly 1e-90.
    floor = max(1e-90, torch.finfo(all_below.dtype).tiny)
    vals = (all_below + floor).log()

    # Multinomial log-likelihood of the uncensored entries, with theta
    # renormalised over the uncensored categories only.
    log_theta = samps_q - torch.logsumexp(samps_q - cens_inf, axis=-1).unsqueeze(-1)
    log_likelihood = (dat_exp[~cens_exp] * log_theta[~cens_exp]).sum() / n_samp
    log_likelihood = log_likelihood + vals.sum(axis=1).mean()

    # Fused-lasso term on neighbouring logit differences, scaled by 1/gamma.
    # NOTE(review): unsqueeze(0) on mu and unsqueeze(1) on gamma broadcast to
    # an (n_samp, n_samp, E, K) outer product, pairing every theta draw with
    # every gamma draw even though theta draw s was generated from gamma draw
    # s. Confirm this all-pairs estimator is intended.
    mu = (samps_q[:, adj[:, 0], :] - samps_q[:, adj[:, 1], :]).unsqueeze(0)
    laplace_prior = -(((mu.abs() * (1 / q_gamma_samp.unsqueeze(1))).sum([-1, -2])).mean())
    # Prior term on gamma, scaled by the matching per-feature lam draw.
    local_prior = -(q_gamma_samp / lam_samp.unsqueeze(1)).sum([-1, -2]).mean()

    # log q terms; subtracting them adds the variational entropy.
    ent_gamma = q_gamma.log_prob(q_gamma_samp).sum([-1, -2]).mean()
    ent_norm = theta_q.log_prob(samps_q).sum([-1, -2]).mean()

    lam_prior = torch.distributions.Gamma(1, lam)
    lam_kl = torch.distributions.kl_divergence(q_lam, lam_prior)
    return log_likelihood + laplace_prior + local_prior - ent_norm - ent_gamma - lam_kl.sum()


In [None]:
#@title Read Kidney Tic Data
# Load the table, keep the pixel coordinates for the spatial graph, and turn
# the remaining columns into integer-valued intensities plus a censoring mask.
df = pd.read_csv('kidney_M2_tic.csv')
pts = np.array([df.x, df.y]).T
df = df.drop(columns=['x', 'y', 'Unnamed: 0'])
print(df.shape)
df = df.dropna(axis=1, how='all')
df = np.floor(df)

# An entry equal to its column minimum is treated as censored (below the
# detection limit); the column minima themselves serve as the limits (lods).
cens = torch.tensor(np.array(df.eq(df.min())))
# `df` is already floored above, so no second floor is needed.
data = torch.from_numpy(np.array(df)).to(torch.float64)
lods = torch.from_numpy(np.array(df.min())).expand(data.size(0), -1).to(device)
# +inf at censored entries and 0 elsewhere (bool * 1. gives a float tensor).
cens_inf = cens.clone() * 1.
cens_inf[cens] = torch.inf
# Rows with no censored entry; the ELBO's censoring term skips them.
cr = (cens.sum(axis=1) == 0)

# Sanitise feature names. pandas >= 2.0 defaults `str.replace` to
# regex=False, so the regex is requested explicitly. Slashes are converted
# before the strip so they become underscores rather than being removed.
# Note that the strip also removes '.', '-', etc.
df.columns = df.columns.str.strip()
df.columns = df.columns.str.replace(' ', '_', regex=False)
df.columns = df.columns.str.replace('/', '_', regex=False)
df.columns = df.columns.str.replace(r'\W', '', regex=True)


(15403, 353)


In [None]:
#@title Get Adjacency Graph
from scipy.sparse import coo_matrix
# Undirected pixel graph linking every pair of pixels within Euclidean
# distance 1.5 (if coordinates are integer pixel indices this presumably
# gives the 8-neighbourhood; confirm against the coordinate units).
neighbors = find_neighbors(pts, 1.5)
n_nodes = len(neighbors)

# Sparse (COO) adjacency assembled from the neighbour lists; edges are
# unweighted.
rows = [i for i, nbrs in enumerate(neighbors) for _ in nbrs]
cols = [j for nbrs in neighbors for j in nbrs]
vals = [1] * len(rows)

# Exactly one row/column per pixel: every neighbour index is < n_nodes.
n = n_nodes
adj = coo_matrix((vals, (rows, cols)), shape=(n, n))

# Canonical edge list: orient each pair as (min, max), deduplicate with
# np.unique along the pair axis, then drop self-loops (each pixel is its own
# neighbour in `neighbors`).
edges = np.vstack((adj.row, adj.col))  # shape: [2, nnz]
sorted_edges = np.sort(edges, axis=0)
edges_unique = np.unique(sorted_edges, axis=1).transpose()  # shape: [E, 2]
inds = edges_unique[:, 0] == edges_unique[:, 1]
edges_unique = edges_unique[~inds]
# scipy may store COO indices as int32; torch scatter ops need int64.
edges_unique = torch.from_numpy(edges_unique.astype(np.int64))


In [None]:
#@title Initialize model parameters and hyperparameters

# `data` is already a tensor here; convert with `.to` instead of
# `torch.tensor(data, ...)`, which copy-constructs and emits a UserWarning.
data = data.to(dtype=torch.float32, device=device)
cpd = data.clone()
lam_use = data.mean(dim=0).cpu()  # per-feature mean intensity, kept on CPU

# Parameters
# Initial logits: add-one-smoothed log proportions
# log((x_k + 1) / (sum_j x_j + K)). Dividing by the row total plus K makes
# the K smoothed proportions of each pixel sum to one.
theta_q_loc = torch.log((cpd + 1) / (cpd.sum(dim=1, keepdim=True) + data.size(1)))
theta_q_loc = theta_q_loc.to(dtype=torch.float32, device=device).requires_grad_(True)

# q(gamma): one Gamma(exp(gamma_1_raw), exp(gamma_2_raw)) per (edge, feature);
# the log-rate starts at the log of each feature's mean intensity.
gamma_1_raw = torch.ones((edges_unique.size(0), data.size(1)), dtype=torch.float32, device=device).requires_grad_(True)
gamma_2_raw = (torch.ones((edges_unique.size(0), data.size(1)), dtype=torch.float32) * lam_use.log()).to(device=device).requires_grad_(True)

# q(lam): one Gamma(exp(lam_1_raw), exp(lam_2_raw)) per feature. Its prior
# rate (`lam_group`) is the per-feature variance of the data.
lam_1_raw = torch.zeros_like(lam_use, device=device).requires_grad_(True)
lam_2_raw = lam_use.log().to(dtype=torch.float32, device=device).clone().requires_grad_(True)
lam_group = data.var(dim=0).to(device)

# Training settings
num_steps = 1_000_000
samp_size = 15  # pseudo-count draws per censoring estimate
n_samp = 2      # variational draws per ELBO evaluation
losses = torch.zeros(num_steps, device=device)

# Move censoring tensors and the edge list to the device. The ELBO's
# scatter_add_ requires int64 indices, and scipy may yield int32 ones.
cens = cens.to(device)
cens_inf = cens_inf.to(device)
edges_unique = edges_unique.to(device=device, dtype=torch.long)

# Expanded views (no copy) of the data and mask, one per variational draw.
dat_exp = data.unsqueeze(0).expand(n_samp, -1, -1)
cens_exp = cens.unsqueeze(0).expand(n_samp, -1, -1)

# Graph-related setup
N = edges_unique.max().item() + 1
feat_dim = gamma_2_raw.size(-1)

# NOTE(review): `sums` is not read by elbo_lasso_cens, which allocates its
# own; kept only in case another cell uses it.
sums = torch.zeros(n_samp, N, feat_dim, dtype=torch.float32, device=device)
# Edge endpoints broadcast to (n_samp, E, feat_dim) for use as scatter indices.
a1 = edges_unique[:, 0].unsqueeze(0).unsqueeze(-1).expand(n_samp, -1, feat_dim)
a2 = edges_unique[:, 1].unsqueeze(0).unsqueeze(-1).expand(n_samp, -1, feat_dim)

# Optimizer
optimizer = torch.optim.Adam(
    [theta_q_loc, gamma_1_raw, gamma_2_raw, lam_1_raw, lam_2_raw],
    lr=0.01,
)



  data = torch.tensor(data, dtype=torch.float32, device=device)
  gamma_2_raw = torch.tensor(torch.ones((edges_unique.size(0),data.size(1)), dtype=torch.float32)*(lam_use).log(),device=device,requires_grad=True)


In [None]:
#@title run model and save output
def _save_checkpoint(step, loss):
    """Save the current variational parameters and optimizer state to 'model_checkpoint.pt'."""
    torch.save({
        'step': step,
        'theta_q_loc': theta_q_loc.detach().cpu(),
        'gamma_1_raw': gamma_1_raw.detach().cpu(),
        'gamma_2_raw': gamma_2_raw.detach().cpu(),
        'lam_1_raw': lam_1_raw.detach().cpu(),
        'lam_2_raw': lam_2_raw.detach().cpu(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss.item(),
    }, 'model_checkpoint.pt')


# Early stopping: every `check_every` steps the mean loss of the last window
# is compared with the best window mean so far (`TH`). Training stops at the
# first window that is worse; otherwise a checkpoint is written. TH starts at
# +inf so the first window always counts as an improvement.
check_every = 5000
TH = float('inf')
for step in range(num_steps):
    optimizer.zero_grad()
    loss = -elbo_lasso_cens(
        theta_q_loc, data, lam_group, edges_unique, cens, lods, cr,
        samp_size, n_samp, cens_exp, gamma_1_raw, gamma_2_raw, lam_1_raw, lam_2_raw, dat_exp, a1, a2, cens_inf
    )
    loss.backward()
    optimizer.step()
    # Keep the loss on-device; `.item()` here would force a host sync each step.
    losses[step] = loss.detach()
    with torch.no_grad():
        if step != 0 and step % check_every == 0:
            print(step)
            val = losses[(step - check_every):step].mean()
            if val > TH:
                break
            if step > check_every:
                # Change of the window-mean loss relative to the current loss.
                print((val - TH) / loss.item())
            TH = val
            _save_checkpoint(step, loss)
    if step == 100:
        # Early smoke test that checkpoint writing works.
        _save_checkpoint(step, loss)
        print('test complete')

# These are the final iterates; 'model_checkpoint.pt' holds the parameters
# from the last window that improved.
filename = 'real_data_theta_full_kidney.pt'
print(filename)
torch.save(theta_q_loc.detach().cpu(), filename)
file2 = 'real_data_losses_full_kidney.pt'
torch.save(losses.cpu(), file2)