In [1]:
import torch
import matplotlib.pyplot as plt
from torch import optim, distributions, nn
import torch.nn.utils as nn_utils
from tqdm import tqdm
from gpzoo.gp import SVGP, VNNGP
from gpzoo.kernels import NSF_RBF, RBF
from gpzoo.likelihoods import NSF2
from gpzoo.utilities import rescale_spatial_coords, dims_autocorr
import squidpy as sq
import numpy as np
import time
import random
import scanpy as sc
import anndata as ad
from anndata import AnnData
from squidpy.gr import spatial_neighbors,spatial_autocorr
from matplotlib.animation import FuncAnimation
from matplotlib import animation

  from tqdm.autonotebook import tqdm


In [2]:
torch.manual_seed(0)

# use GPU if available
device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')
device

torch.cuda.empty_cache()

adata = sq.datasets.visium_hne_adata()

In [3]:
def run_benchmarking_experiments(X_train, Y_train, M, L, K, steps=1000, batch_size=64):
    for k in tqdm(range(len(K))):
        for m in range(len(M)):
            for l in range(len(L)):
                # make data dictionary
                exp_data = {
                    'model' :  [],
                    'Z' : [],
                    'L' : [],
                    'K': [],
                    'time' : []}
                # run model
                #idx = torch.multinomial(torch.ones(X_train.shape[0]), num_samples=M[m], replacement=False)
                model = initialize_model(M[m], L[l], K[k], X_train, Y_train, 1.0, 1.0)
                optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3)
                model.to(device)
                start_time = time.time()
                X_train = torch.tensor(X_train, device='cuda', dtype=torch.float32, requires_grad=True)
                losses, means, scales = train(model, optimizer, X_train, Y_train, device, steps=steps, E=10)
                #losses, means, scales, idxs = train_batched(model, optimizer, X_train, Y_train, device, steps=steps, E=10, batch_size=batch_size)
                end_time = time.time()
                final_time = end_time - start_time

                # update data dictionary
                exp_data['L'].append(L)
                exp_data['Z'].append(M)
                exp_data['K'].append(K)
                alg = 'NNNSF'
                exp_data['model'].append(alg)
                exp_data['time'].append(final_time)

                # make + save loss plot
                fig1, ax1 = plt.subplots()
                plt.plot(losses)
                plt.title(f"Visium Losses")
                plt.close(fig1)

                # make + save factors plot
                size=5
                #X_train = torch.tensor(X_train, device='cuda', dtype=torch.float32, requires_grad=True)
                qF, qU, pU = model.prior(X_train)
                loadings = torch.exp(qF.mean).detach().cpu().numpy()
                del qF, qU, pU
                with torch.no_grad():
                    if device.type=='cuda':
                        torch.cuda.empty_cache()
                X_train = X_train.detach().cpu().numpy()
                moran_idx, moranI = dims_autocorr(np.log(loadings).T, X_train)

                fig2, ax2 = plt.subplots(2, 5, figsize=(size*5, size*2), tight_layout=True)
                fig2.suptitle("NNNSF Factors", size=20)
                plot(X_train, loadings, moran_idx, L[l], ax=ax2, size=5, alpha=0.8)
                plt.close(fig2)

                # make + save animation
                fig3, ax = plt.subplots(2, 5, figsize=(size*5, size*2), tight_layout=True)
                fig3.suptitle("NNNSF Factors", size=20)

                def update(iteration):
                    for ax_row in ax:
                        for element in ax_row:
                            element.cla()
                    curr_mean = means[iteration]
                    #curr_idx = idxs[iteration]
                    plot(X_train, curr_mean, moran_idx, L=10, ax=ax, size=size, alpha=0.9)

                anim = FuncAnimation(fig3, update, frames=np.arange(0, len(means), 1), interval=100)
                plt.show()
                plt.close()
                anim.save(f'.visium_K={K[k]}_Z={M[m]}_factors_anim.gif', writer="pillow")

In [4]:
def train(model, optimizer, X, y, device, steps=200, E=10, batch_size=64, **kwargs):
    losses = []
    means = []
    scales = []
    #idxs = []
    
    for it in tqdm(range(steps)):   
        idx = torch.multinomial(torch.ones(X.shape[0]), num_samples=batch_size, replacement=False)

        optimizer.zero_grad()
        pY, qF, qU, pU = model.forward(X=X, E=E, **kwargs)

        logpY = pY.log_prob(y)

        ELBO = (logpY).mean(axis=0).sum()
        ELBO -= torch.sum(distributions.kl_divergence(qU, pU))

        loss = -ELBO
        loss.backward()

        optimizer.step()

        losses.append(loss.item())
        if (it%10)==0:
            #idxs.append(idx.detach().cpu().numpy())
            means.append(torch.exp(qF.mean.detach().cpu()).numpy())
            scales.append(qF.scale.detach().cpu().numpy())

    with torch.no_grad():
        if device.type=='cuda':
            torch.cuda.empty_cache()

    return losses, means, scales

In [5]:
def initialize_model(M, L, K, X, Y, sigma=1.0, lengthscale=1.0):
    idx = torch.multinomial(torch.ones(X.shape[0]), num_samples=M, replacement=False)
    kernel = NSF_RBF(L=L, sigma=1.0, lengthscale=1.0)
    gp = VNNGP(kernel, M=M, jitter=1e-2, K=K)
    gp.Lu = nn.Parameter(torch.eye(M).expand(L, M, M).clone())
    gp.mu = nn.Parameter(torch.randn((L, M)))
    gp.Z = nn.Parameter(torch.tensor(X[idx]))

    model = NSF2(gp=gp, y=Y, L=L)
    model.prior.kernel.lengthscale.requires_grad = True
    model.prior.kernel.sigma.requires_grad = False
    model.prior.Z.requires_grad = False
    model.prior.mu.requires_grad = True
    model.prior.Lu.requires_grad = True

    return model

In [6]:
def dims_autocorr(factors,coords,sort=True):
    """
    factors: (num observations) x (num latent dimensions) array
    coords: (num observations) x (num spatial dimensions) array
    sort: if True (default), returns the index and I statistics in decreasing
    order of autocorrelation. If False, returns the index and I statistics
    according to the ordering of factors.

    returns: an integer array of length (num latent dims), "idx"
    and a numpy array containing the Moran's I values for each dimension

    indexing factors[:,idx] will sort the factors in decreasing order of spatial
    autocorrelation.
    """
    #from anndata import AnnData
    #from squidpy.gr import spatial_neighbors,spatial_autocorr

    ad = AnnData(X=factors,obsm={"spatial":coords})
    spatial_neighbors(ad)
    df = spatial_autocorr(ad,mode="moran",copy=True)
    if not sort: #revert to original sort order
        df.sort_index(inplace=True)
    
    idx = np.array([int(i) for i in df.index])
    return idx,df["I"].to_numpy()

In [7]:
def plot(X_train, factors, moran_idx, L, ax=None, size=7, alpha=0.9):
    L=10
    max_val = np.percentile(factors, 99)
    min_val = np.percentile(factors, 1)
    #size = 5
    #fig, ax = plt.subplots(L//5, 5, figsize=(size*5, size*2), tight_layout=True)
    #fig.suptitle("NNNSF Factors", size=25)
    for i in range(L):
        plt.subplot(L//5, 5, i+1)
        
        curr_ax = ax[i//5, i%5]
        curr_ax.scatter(X_train[:, 0], X_train[:,1], c=factors[moran_idx][i], vmin=min_val, vmax=max_val, alpha=alpha, cmap='turbo')#, s=0.1)
        curr_ax.set_xlim([X_train[:,0].min()*1.1, X_train[:,0].max()*1.1])
        curr_ax.set_ylim([X_train[:,1].min()*1.1, X_train[:,1].max()*1.1])
        curr_ax.invert_yaxis()
        curr_ax.set_xticks([])
        curr_ax.set_yticks([])
        curr_ax.set_facecolor('xkcd:gray')


def plot_anim(factors, moran_idx, curr_idx, L, ax=None, size=7, alpha=0.9):
    max_val = np.percentile(factors, 99)
    min_val = np.percentile(factors, 1)
    if ax is None:
        fig, ax = plt.subplots(4, 5, figsize=(size*5, size*4), tight_layout=True)
    for i in range(L):
        plt.subplot(L//5, 5, i+1)
        curr_ax = ax[i//5, i%5]
        curr_ax.scatter(X[curr_idx, 0], X[curr_idx,1], c=factors[moran_idx][i], vmin=min_val, vmax=max_val, alpha=alpha, cmap='turbo')#, s=0.1)
        curr_ax.invert_yaxis()
        curr_ax.set_xticks([])
        curr_ax.set_yticks([])
        curr_ax.set_facecolor('xkcd:gray')

In [8]:
def main():
	# load data
	Y_sums = np.array(np.sum(adata.raw.X > 0, axis=0))[0]
	Y = np.array(adata.raw.X[:, Y_sums>200].todense(), dtype=int).T
	X = adata.obsm['spatial']
	X = X.astype('float64')
	X = rescale_spatial_coords(X)
	
	X = torch.tensor(X, dtype=torch.float)
	Y = torch.tensor(Y, dtype=torch.float)	

	X_train = X.to(device)
	Y_train = Y.to(device)
	neighbors = [8]#,2,3,4,5,6,7,8,9,10]
	inducing_pts = [1000] #[100, 500, 1000, 1500, 2000, 2500, 3000]
	gps = [10]
	run_benchmarking_experiments(X_train, Y_train, inducing_pts, gps, neighbors, steps=1000)

In [9]:
main()

  gp.Z = nn.Parameter(torch.tensor(X[idx]))
  X_train = torch.tensor(X_train, device='cuda', dtype=torch.float32, requires_grad=True)

  0%|                                                                                                      | 0/1000 [00:00<?, ?it/s][A
  0%|                                                                                                         | 0/1 [00:00<?, ?it/s]


TypeError: forward() got an unexpected keyword argument 'return_distance'