# MAP on Steinmetz
We run our algorithm on the Steinmetz dataset, with an anatomically inspired connectivity matrix.

## Load libraries

In [None]:
# Mount Google Drive so files under /content/drive (e.g. the connectivity
# CSV read further down) are accessible from this Colab runtime.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
pip install pyro-ppl

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting pyro-ppl
  Downloading pyro_ppl-1.8.4-py3-none-any.whl (730 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m730.7/730.7 KB[0m [31m13.8 MB/s[0m eta [36m0:00:00[0m
Collecting pyro-api>=0.1.1
  Downloading pyro_api-0.1.2-py3-none-any.whl (11 kB)
Installing collected packages: pyro-api, pyro-ppl
Successfully installed pyro-api-0.1.2 pyro-ppl-1.8.4


In [None]:
import torch
import networkx as nx
import numpy as np
import matplotlib.pyplot as plt
import random
import numpy as np

from tqdm.auto import trange

from scipy.spatial import distance
from scipy.stats import pearsonr 

import pandas as pd

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torch.distributions.laplace import Laplace
from torch.distributions.multivariate_normal import MultivariateNormal
from torch.distributions.normal import Normal
from torch.distributions.dirichlet import Dirichlet
from torch.distributions.gamma import Gamma
from torch.distributions.exponential import Exponential
from torch.distributions.bernoulli import Bernoulli

import pyro
import pyro.distributions as dist
from pyro.infer import MCMC, NUTS, HMC

def seed_everything(seed: int):
    """Seed every random number generator used in this notebook.

    Covers Python's ``random`` module, NumPy's global generator and PyTorch
    (the default generator and all CUDA devices).

    Args:
        seed (int): Value handed to each generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for set_seed in seeders:
        set_seed(seed)

In [None]:
def nx_to_laplacian(g, alpha=0.5):
  """
  Build the diagonally augmented graph Laplacian of g as a float32 torch tensor.

  g: networkx graph
  alpha: float, constant added to every diagonal entry on top of the node degree
  """
  adjacency = nx.convert_matrix.to_numpy_array(g)
  degrees = adjacency.sum(axis=1)
  # (alpha + degree) on the diagonal, minus the adjacency off the diagonal
  laplacian = np.diag(alpha + degrees) - adjacency
  return torch.from_numpy(laplacian).float()

In [None]:
def initialize_model(K, N, T, g, sigma=0.01, random_seed=0, mask=True, alpha=0.5):
  """
  Draw a random starting point (S, A, D, L) for MAP estimation.

  K: number of factors
  N: number of neurons
  T: number of time steps
  g: the base graph (with N neurons)
  sigma, mask: accepted for backward compatibility; they do not influence
    the returned values
  random_seed: seed handed to seed_everything before any sampling
  alpha: diagonal augmentation forwarded to nx_to_laplacian

  Returns the tuple (S, A, D, L):
    S: T x K, Laplace(0, 1/sqrt(2T)) draws
    A: K x N, absolute value of Gaussian draws whose precision matrix is the
       augmented graph Laplacian
    D: K x N, element-wise square root of Dirichlet(1/K) draws (one K-vector
       per neuron)
    L: K, sorted ascending; each entry is a Gamma(10, 10) draw with
       probability 0.8 and an Exponential(1) draw otherwise
  """
  seed_everything(random_seed)
  S = Laplace(0, 1/np.sqrt(2 * T)).sample(sample_shape=(T,K)) # T x K
  D = torch.sqrt(Dirichlet(torch.ones(K) / K).sample(sample_shape=(N,)).T) # K x N
  A = torch.abs(MultivariateNormal(torch.zeros(N),  precision_matrix=nx_to_laplacian(g, alpha=alpha)).sample(sample_shape=(K, ))) # K x N

  # Lambdas: B selects, per factor, between the Gamma and the Exponential draw
  B = Bernoulli(0.8).sample(sample_shape=(K,))
  L = B * Gamma(10,10).sample(sample_shape=(K,))
  E = (1 - B) * Exponential(1).sample(sample_shape=(K,))
  L = (L + E).sort().values

  return (S,A,D,L)

In [None]:
def stable_softmax(x):
    """
    Softmax of tensor x along its first dimension.

    The per-column maximum is subtracted before exponentiating so that large
    entries do not overflow.
    """
    shifted = x - torch.max(x, dim=0, keepdim=True).values
    weights = shifted.exp()
    return weights / weights.sum(dim=0, keepdim=True)
def stable_dirichlet(logit_D):
    """
    Return the sum of all entries of logit_D (a 0-dim tensor).

    get_loss adds this value as its Dirichlet term.
    NOTE(review): this is a surrogate, not the exact Dirichlet log-density --
    confirm that is intended.
    """
    return torch.sum(logit_D)

In [None]:
# compute losses
# the problem is the log density of the Dirichlet

def get_loss(Y, S, log_A, logit_D, log_L, g, sigma=0.01, alpha=0.5, beta=0, mask=True):
  """
  Objective minimised during MAP estimation: the negative log-prior terms plus
  the negative Gaussian log-likelihood (plus an optional norm penalty on A),
  divided by the number of time steps T.

  Y: T x N observations
  S: T x K sources
  log_A: K x N, log of the non-negative loadings A
  logit_D: same shape as log_A, unnormalised logits of D^2 (only read when
    mask is True)
  log_L: K, log of the per-factor scales Lambda
  g: the base graph (with N neurons)
  sigma: standard deviation of the Gaussian observation noise
  alpha: diagonal augmentation forwarded to nx_to_laplacian
  beta: weight of the norm penalty on A
  mask: if True the effective loadings are A * D, otherwise A

  Returns a 0-dim tensor.
  """
  T, N = Y.shape
  A = torch.exp(log_A)
  L = torch.exp(log_L)
  if mask:
    # Each neuron's K-vector D_i^2 is Dirichlet, so the softmax has to run
    # over the factor axis (dim 0 of the K x N logits). This matches how
    # logit_D is initialised and how D is rebuilt in map_estimation.
    D = torch.sqrt(stable_softmax(logit_D))

  loss = 0
  # priors
  loss += -Laplace(0, 1/np.sqrt(2 * T)).log_prob(S).sum()
  loss += -MultivariateNormal(torch.zeros(N), precision_matrix=nx_to_laplacian(g, alpha=alpha)).log_prob(A).sum()
  if mask:
    loss += stable_dirichlet(logit_D)

  # Lambda prior: mixture 0.8 * Gamma(10, 10) + 0.2 * Exponential(1),
  # combined per factor in log space
  log_gamma_part = torch.log(torch.tensor(0.8)) + Gamma(10,10).log_prob(L)
  log_exp_part = torch.log(torch.tensor(0.2)) + Exponential(1).log_prob(L)
  loss += -torch.logsumexp(torch.vstack([log_gamma_part, log_exp_part]), 0).sum()

  # likelihood
  if mask:
    Atilde = A * D
    if torch.isnan(D).any():
      print("D is nan")
      print(D)
  else:
    Atilde = A
  loss += -Normal(S @ torch.diag(L) @ Atilde, sigma).log_prob(Y).sum()
  # norm penalty on A; torch.linalg.norm of a matrix defaults to the
  # Frobenius norm (this is not an l1 penalty)
  loss += beta * torch.linalg.norm(A)
  return loss / T

In [None]:
def map_estimation(Y,
                   g,
                   K, 
                   N,
                   num_steps=2000,
                   alpha=0.5,
                   beta=0,
                   mask=False,
                   random_seed=1,
                   tol_steps=200):
  """
  MAP estimation of (S, A, D, Lambda) by Adam on the objective of get_loss.

  Y: T x N tensor of observations
  g: the base graph (with N neurons)
  K: number of factors
  N: kept for backward compatibility; the value actually used is Y.shape[1]
  num_steps: maximum number of optimisation steps
  alpha: diagonal augmentation of the graph Laplacian, used both for the
    initial draw and inside the loss
  beta: weight of the norm penalty on A in the loss
  mask: whether the mask D multiplies A in the loss
  random_seed: seed for the random initialisation
  tol_steps: stop once this many steps have passed without a new best loss

  Returns (S, A, D, L, train_losses) taken at the step with the lowest loss,
  with the factors permuted so that L is ascending; train_losses holds the
  loss recorded at every executed step.
  """
  pbar = trange(num_steps)
  pbar.set_description("---")
  T, N = Y.shape
  # initialization
  S, A, D, L = initialize_model(K, N, T, g, random_seed=random_seed, mask=mask, alpha=alpha)
  S = nn.parameter.Parameter(S)
  log_A = nn.parameter.Parameter(torch.log(A))
  logit_D = nn.parameter.Parameter(2 * torch.log(D)) #D_i^2 ~ Dirichlet
  log_L = nn.parameter.Parameter(torch.log(L))

  S_best, log_A_best, logit_D_best, log_L_best = S, log_A, logit_D, log_L
  best_step = 0
  best_loss = float('inf')

  optimizer = optim.Adam([S,log_A, logit_D, log_L], lr=1e-1)
  train_losses = []
  for step in pbar:
    # early stopping: no improvement for more than tol_steps steps
    if step - best_step > tol_steps:
      print(best_step)
      break
    with torch.set_grad_enabled(True):
        optimizer.zero_grad()
        # alpha must reach the loss too, otherwise the prior used for
        # optimisation differs from the one used for initialisation
        loss = get_loss(Y, S, log_A, logit_D, log_L, g, alpha=alpha, beta=beta, mask=mask)
        loss_value = loss.item()
        if loss_value < best_loss:
          best_step = step
          best_loss = loss_value

          # snapshot the parameters, reordering factors by ascending Lambda
          order = log_L.detach().clone().sort()
          log_L_best = order.values
          idxs = order.indices
          S_best = S.detach().clone()[:, idxs]
          log_A_best = log_A.detach().clone()[idxs, :]
          logit_D_best = logit_D.detach().clone()[idxs, :]
        loss.backward()
        optimizer.step()
        train_losses.append(loss.detach().numpy())

  return S_best, torch.exp(log_A_best), torch.sqrt(stable_softmax(logit_D_best)), torch.exp(log_L_best), train_losses

## Application Steinmetz

### load in Steinmetz data

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import zscore
from sklearn.decomposition import PCA
import torch

In [None]:
# @title Figure settings
from matplotlib import rcParams

# Notebook-wide matplotlib defaults: wide figures, larger font, no top/right
# spines, automatic tight layout.
rcParams.update({
    'figure.figsize': [20, 4],
    'font.size': 15,
    'axes.spines.top': False,
    'axes.spines.right': False,
    'figure.autolayout': True,
})

In [None]:
# @title Data retrieval
import os, requests

# Local file names and the matching OSF download links, in the same order.
fname = ['steinmetz_part%d.npz'%j for j in range(3)]
url = [
  "https://osf.io/agvxh/download",
  "https://osf.io/uv3mw/download",
  "https://osf.io/ehmw2/download",
]

for j, link in enumerate(url):
  # parts that are already on disk are not downloaded again
  if os.path.isfile(fname[j]):
    continue
  try:
    r = requests.get(link)
  except requests.ConnectionError:
    print("!!! Failed to download data !!!")
    continue
  if r.status_code != requests.codes.ok:
    print("!!! Failed to download data !!!")
    continue
  with open(fname[j], "wb") as fid:
    fid.write(r.content)

In [None]:
# @title Data loading
# Concatenate the 'dat' arrays of all downloaded parts into a single array.
# NOTE(review): allow_pickle=True unpickles arbitrary objects; only load
# files from a trusted source.
parts = [np.array([])]
for j in range(len(fname)):
  parts.append(np.load('steinmetz_part%d.npz'%j, allow_pickle=True)['dat'])
alldat = np.hstack(parts)

In [None]:
# Take a single mouse: entry 11 of alldat (the sample notebook takes it since it has some neurons from vis_ctx).
# The rest of this notebook focuses on this entry only.
dat = alldat[11]

In [None]:
def get_successful_trials(dat): 
  """
  Return the spike data restricted to the trials the mouse answered correctly.

  The expected response of a trial is -1 when the right contrast is higher,
  1 when the left contrast is higher and 0 when both are equal; a trial is
  successful when dat["response"] equals that value.

  dat: mapping with per-trial arrays "contrast_right", "contrast_left" and
    "response", and an array "spks" whose second axis indexes trials.

  Returns dat["spks"][:, successful_trials, :]. The trial axis is always
  kept, including when zero or exactly one trial is successful.

  TODO: might want to wrap this into a whole data cleaning function that returns 
  the cleaned dat dictionary object instead of just spikes.
  """
  contrast_right = np.asarray(dat["contrast_right"])
  contrast_left = np.asarray(dat["contrast_left"])

  expected = np.zeros(contrast_right.shape)
  expected[contrast_right > contrast_left] = -1
  expected[contrast_left > contrast_right] = 1

  # flatnonzero always yields a 1-D index array, so indexing with it never
  # drops the trial axis (squeezing a single index to 0-d would)
  success_idx = np.flatnonzero(expected == np.asarray(dat["response"]))

  return dat["spks"][:, success_idx, :]

In [None]:
# Keep only the correctly answered trials; the last expression displays the
# shape of the resulting spike array.
success_dat = get_successful_trials(dat)
success_dat.shape

In [None]:
# Connectivity matrix stored on the mounted Google Drive, parsed as integers.
# NOTE(review): read_csv treats the first row as a header by default --
# confirm the CSV actually has a header row.
steinmetz_adj = pd.read_csv("/content/drive/MyDrive/connectome_prior/steinmetz.csv", dtype=int)

# Undirected graph built from that matrix; passed as the base graph g below.
G = nx.Graph(steinmetz_adj.values)

In [None]:
# Use only the first successful trial (index 0 on the second axis).
Y = success_dat[:, 0, :]
Y.shape

In [None]:
MASK = False  # forwarded as `mask` to map_estimation
N = 698  # NOTE(review): presumably the number of neurons in Y -- confirm against Y.shape
ALPHA = 0.5  # forwarded as `alpha` to map_estimation

In [None]:
# Two MAP runs on the same data (Y.T passed as a tensor) that differ only in
# the random seed used for initialisation.
K = 10
S1_map, A1_map, D1_map, L1_map, train1_losses = map_estimation(torch.tensor(Y.T), G, K, N, mask=MASK, num_steps=3000, alpha=ALPHA, tol_steps=300, random_seed=1)
S2_map, A2_map, D2_map, L2_map, train2_losses = map_estimation(torch.tensor(Y.T), G, K, N, mask=MASK, num_steps=3000, alpha=ALPHA, tol_steps=300, random_seed=2)

In [None]:
# Loss trajectory of the first MAP run.
plt.plot(train1_losses)

In [None]:
# Zoom in on the last 100 recorded losses of the first MAP run.
plt.plot(train1_losses[-100:])

In [None]:
# sim = np.zeros((num_subgraphs,K))
# for i in range(num_subgraphs):
#   for j in range(K):
#     if MASK:
#       sim[i,j] = pearsonr( L1_map[j] * D1_map[j] * A1_map[j], L2_map[j] * D2_map[j] * A2_map[j])[0]
#     else:
#       sim[i,j] = pearsonr(L1_map[i] * A1[i], L2_map[j] * A2_map[j])[0]

# plt.imshow(sim, vmin=0, vmax=1)
# plt.colorbar()

# Pairwise Pearson correlation between the scaled factors of the two MAP runs:
# sim[i, j] compares factor i of run 1 with factor j of run 2.
sim = np.zeros((K,K))
for i in range(K):
  for j in range(K):
    if MASK:
      # run 1 is indexed by i and run 2 by j, exactly as in the unmasked branch
      sim[i,j] = pearsonr(L1_map[i] * D1_map[i] * A1_map[i], L2_map[j] * D2_map[j] * A2_map[j])[0]
    else:
      sim[i,j] = pearsonr(L1_map[i] * A1_map[i], L2_map[j] * A2_map[j])[0]

plt.imshow(sim, vmin=0, vmax=1)
plt.colorbar()

In [None]:
# Display the Lambda estimates of the first MAP run.
L1_map

In [None]:
# Display the Lambda estimates of the second MAP run.
L2_map

## MCMC for Steinmetz

In [None]:
# N is set here but replaced two lines below by the second dimension of Y.
N = 698
K = 10
# NOTE: this cell is not idempotent -- it rebinds Y to a tensor of its own
# transpose, so running it a second time transposes the data again.
Y=torch.tensor(Y.T)
T,N=Y.shape

In [None]:

# Priors for the Pyro model. Only S_prior and A_prior are sampled by the
# active code of model(); D_prior and L_prior are referenced solely by its
# commented-out lines.
S_prior = dist.Laplace(0,1/np.sqrt(2 * T))
A_prior = dist.MultivariateNormal(torch.zeros(N), precision_matrix=nx_to_laplacian(G))
D_prior = dist.Dirichlet(torch.Tensor(K*[1/K]))
L_prior = dist.Gamma(torch.Tensor([10.0]),torch.Tensor([10.0]))

def model(data):
  """
  Pyro model of the observations: data ~ Normal(S @ A, 0.01 * I).

  Every row of A is the absolute value of one draw from A_prior (site
  'A_tilde_prior_{k}') and every entry of S is its own draw from S_prior
  (site 'S_prior_{t}_{k}'). K, N, T and the priors are read from module
  scope. The mask D, the scales Lambda and a learned noise level exist only
  as commented-out code below.

  data: observations; len(data) is used as the size of the "data" plate.
  """
  A = torch.zeros(K,N)
  D = torch.zeros(K,N)  # only referenced by the commented-out mask code
  S = torch.zeros(T,K)
  Lambda = torch.zeros(K)  # only referenced by the commented-out Lambda code
  for k in range(K):
    A[k,:] = torch.abs(pyro.sample(f'A_tilde_prior_{k}', A_prior))
    #Lambda[k] = pyro.sample(f'L_prior_{k}',L_prior)
    # one scalar sample site per (t, k) pair, i.e. T * K sites in total
    for t in range(T):
      S[t,k] = pyro.sample(f'S_prior_{t}_{k}',S_prior)
  #for i in range(N):
  #  D[:,i] = torch.sqrt(pyro.sample(f'D_prior_{i}',D_prior))
  #Lambda,_ = torch.sort(Lambda)
  #A = Lambda.unsqueeze(1)*A*D
  #sigma_2 = 1/pyro.sample("Sigma_prior",dist.Gamma(torch.Tensor([1.0]),torch.Tensor([1.0])))
  # likelihood with a fixed covariance of 0.01 * I
  with pyro.plate("data", len(data)):
        pyro.sample("obs", dist.MultivariateNormal(S@A, 0.01*torch.eye(N)), obs=data)
  
pyro.clear_param_store()

# 2. Build the NUTS kernel on top of the model function defined above.
my_kernel = NUTS(model)


# 3. Configure MCMC with that kernel: 50 warm-up steps followed by
# 100 retained samples.

my_mcmc = MCMC(my_kernel,
               num_samples=100,
               warmup_steps=50)

# 4. Run the sampler; Y is the `data` argument that model(data) receives.
# NOTE(review): MCMC.run appears to return None in Pyro, so mc_results carries
# no information -- the draws are read through my_mcmc.get_samples() below.
mc_results=my_mcmc.run(Y)

In [None]:
# Autocorrelation of the posterior draws of |A| for factor 0, column index 1.
samples = torch.abs(my_mcmc.get_samples()["A_tilde_prior_0"])
# acorr needs maxlags to be smaller than the number of draws, so derive it
# from the chain length (the first axis of samples) instead of hard-coding it
plt.acorr(samples[:, 1].flatten().cpu(), maxlags=samples.shape[0] - 1)

In [None]:
# Posterior-mean estimate of |A|: row k averages the absolute draws of site
# 'A_tilde_prior_{k}' over the sample axis.
Lambda = torch.zeros(K)
D = torch.zeros(K,N)
A_sample = torch.zeros(K,N)
posterior = my_mcmc.get_samples()
for k in range(K):
    draws = torch.abs(posterior[f'A_tilde_prior_{k}'])
    A_sample[k,:] = draws.mean(axis=0)

In [None]:
# Heat map of the posterior-mean loadings (one row per factor).
plt.imshow(A_sample.cpu())

In [None]:
# Pearson correlation between every MAP factor of run 1 (rows of A1_map) and
# every posterior-mean MCMC factor (rows of A_sample).
num_subgraphs=K
sim=np.zeros((num_subgraphs,num_subgraphs))
for i in range(num_subgraphs):
  for j in range(num_subgraphs):
    # the MAP loadings of run 1 are named A1_map; no variable A1 exists
    sim[i,j]=pearsonr(A1_map.cpu()[i], A_sample.cpu()[j])[0]
plt.imshow(sim,vmin=0,vmax=1)
plt.colorbar()