In [2]:
import os
import sys
import pickle

import numpy as np
import matplotlib.pyplot as plt
from astropy.convolution import Gaussian2DKernel, convolve

import umap
from sklearn.decomposition import PCA

from datashader import transfer_functions as tf
from datashader.colors import inferno, viridis


sys.path.append("../src") if "../src" not in sys.path else None 

from Experiment import Experiment
from methods import filenames

In [3]:
def load_experiment(name, loc):
    """Create the Experiment called `name` under `loc`, run its setup, and return it.

    Parameters
    ----------
    name : passed to `Experiment` as `name`.
    loc : passed to `Experiment` as `base_path`.
    """
    exp = Experiment(base_path=loc, name=name)
    exp.setup()
    return exp

def load_ratemaps(experiment):
    """Load one smoothed ratemap array per environment and concatenate them.

    For every environment index `i`, the last entry returned by `filenames`
    for the directory `experiment.paths['ratemaps'] / 'env_<i>'` is unpickled
    and convolved with a Gaussian kernel (x_stddev=1). The per-environment
    results are joined with `np.concatenate` (default axis 0).
    """
    # The kernel array gets a leading length-1 axis before convolving —
    # presumably so each 2D map in the stack is smoothed on its own; confirm.
    smoothing_kernel = Gaussian2DKernel(x_stddev=1).array[None]
    n_envs = len(experiment.environments)

    smoothed = []
    for env_i in range(n_envs):
        env_dir = experiment.paths['ratemaps'] / f'env_{env_i}'
        last_file = filenames(env_dir)[-1]
        # NOTE: pickle.load can execute arbitrary code; only read trusted files.
        with open(env_dir / last_file, "rb") as handle:
            raw = pickle.load(handle)
            smoothed.append(convolve(raw, smoothing_kernel))
    return np.concatenate(smoothed)

def env_split(stats, n_envs, Ng=4096):
    """Stack `stats` by environment.

    Chunk `i` is `stats[Ng*i : Ng*(i+1)]`; the `n_envs` chunks are returned
    together as a single `np.array` with the environment index first.
    """
    per_env = [stats[Ng * idx:Ng * (idx + 1)] for idx in range(n_envs)]
    return np.array(per_env)

In [4]:
def pca_UMAP(states, random_state=None):
    """ Run PCA followed by UMAP on states; Similar to Gardner et al. (2022)

    states.shape = (Nsamples, Nfeatures)
    PCA down to 6 features/principal components
    UMAP of PCA components to 3 features.

    Note: Gardner et al. use n_neighbors = 5000 (more than number of samples in our case...)

    Parameters
    ----------
    states : array of shape (Nsamples, Nfeatures).
    random_state : optional seed forwarded to both PCA and UMAP. The default
        (None) keeps the previous unseeded behaviour; pass an int to make the
        embedding reproducible between runs.

    Returns
    -------
    (pca_result, umap_result) : the 6-component PCA projection of `states`
        and the 3-component UMAP embedding of that projection.
    """
    pca_model = PCA(n_components=6, random_state=random_state).fit(states)
    pca_result = pca_model.transform(states)
    umap_model = umap.UMAP(
        n_components=3,
        min_dist=0.8,  # Almost Gardner et al. params
        n_neighbors=1000,
        metric="cosine",
        init="spectral",
        random_state=random_state,
    )
    umap_result = umap_model.fit_transform(pca_result)
    return pca_result, umap_result

In [5]:
def scatter3d(data, ncols=4, nrows=4, s=1, alpha=0.5, azim_elev_title=True, **kwargs):
    """Scatter the same 3D point cloud in a grid of subplots with varying view angles.

    Parameters
    ----------
    data : array whose last axis has length 3; any other shape is flattened
        to (-1, 3) before plotting.
    ncols, nrows : grid size. Azimuth takes `ncols` evenly spaced values in
        [0, 180) across columns, elevation `nrows` values in [0, 90) across rows.
    s, alpha : marker size and opacity passed to `Axes.scatter`.
    azim_elev_title : when True, each subplot is titled with its view angles.
    **kwargs : forwarded to `plt.subplots` (e.g. `figsize`).

    Returns
    -------
    (fig, axs) exactly as returned by `plt.subplots`.
    """
    assert data.shape[-1] == 3, "data must have three axes. No more, no less."
    if data.ndim != 2:
        # Also covers a single point of shape (3,), which `data[:, 0]` cannot index.
        data = data.reshape(-1, 3)
    fig, axs = plt.subplots(ncols=ncols, nrows=nrows, subplot_kw={"projection": "3d"}, **kwargs)
    azims = np.linspace(0, 180, ncols + 1)[:-1]
    elevs = np.linspace(0, 90, nrows + 1)[:-1]
    # Row-major list of (azim, elev) pairs, matching the order of `axs.flat`.
    view_angles = np.stack(np.meshgrid(azims, elevs), axis=-1).reshape(-1, 2)
    # plt.subplots returns a bare Axes (no `.flat`) for a 1x1 grid.
    axes_list = list(axs.flat) if isinstance(axs, np.ndarray) else [axs]
    for i, ax in enumerate(axes_list):
        ax.scatter(xs=data[:, 0], ys=data[:, 1], zs=data[:, 2], s=s, alpha=alpha)
        ax.azim = view_angles[i, 0]
        ax.elev = view_angles[i, 1]
        ax.axis("off")
        if azim_elev_title:
            ax.set_title(f"azim={ax.azim}, elev={ax.elev}")
    return fig, axs

In [6]:
# Load ratemaps and module indices
# NOTE: both locations are absolute, machine-specific paths.
experiment_root = "/home/users/vemundss/"
module_indices_file = "/home/users/markusbp/data/emergent-grid-cells/module_indices.npz"

experiment = load_experiment("gg-3ME", experiment_root)
ratemaps = load_ratemaps(experiment)

# np.load on an .npz gives a lazily-read, dict-like archive keyed by array name.
module_indices = np.load(module_indices_file)

Experiment <gg-3ME> already EXISTS. Loading experiment settings!
Loading experiment details
This experiment has ALREADY been setup - SKIPPING.


In [7]:
# Split the concatenated ratemaps into one stack per environment.
# `ratemaps` is overwritten in place, so guard on ndim: the concatenated array
# is 3D and the split one is 4D, and re-running this cell on the already-split
# array would slice along the environment axis and corrupt it.
if ratemaps.ndim == 3:
    ratemaps = env_split(ratemaps, len(experiment.environments))
ratemaps.shape

(3, 4096, 64, 64)

In [None]:
g = np.reshape(ratemaps, (*ratemaps.shape[:2], -1))  # flatten bin dims
g = np.transpose(g, (0, 2, 1))  # reshape so cell dim is last

n_envs = g.shape[0]  # taken from the data rather than hard-coded

for key in module_indices:
    inds = module_indices[key]

    for j in range(n_envs):
        # Select this module's cells in environment j; the transpose puts the
        # spatial bins on the sample axis that pca_UMAP expects.
        _, um = pca_UMAP(g[j, :, inds].T)
        # scatter3d accepts the (n_points, 3) embedding as-is, so no reshape
        # to a fixed 64x64 grid is needed.
        fig, axs = scatter3d(um, ncols=3, nrows=3, s=0.1, alpha=0.5, figsize=(10, 10))
        fig.suptitle(f"{key}, env {j}")
        plt.show()
        plt.close(fig)  # release the figure; many are created in this loop
    

(245, 4096)


In [None]:
"""
1. Run through each environment
2. Run through each interesting cluster index
3. Identify cells that are implicated in the cluster
4. Perform PCA + UMAP in each environment
5. Plot result

# Note; need to cut out non-stddev cells for each cluster
"""

# Per-cluster embeddings, one entry per (source environment, cluster); each
# entry is itself a list with one item per environment the cluster is embedded in.
fig_titles = []
pca_fits = []
umap_fits = []

# NOTE(review): `interesting_clusters`, `cluster_labels` and `std_mask` are not
# defined in the visible part of this notebook — they must exist before this
# cell is run.

# One entry of `g` per environment; use it as the loop bound since `g[k]` is
# what gets indexed below.
n_envs = len(g)

for i in range(len(ratemaps)):

    cluster_ind = interesting_clusters[i]
    for j in cluster_ind:
        cell_mask = cluster_labels[i] == j

        env_pca_fits = []
        env_umap_fits = []
        env_titles = []

        for k in range(n_envs):
            # Keep the cells passing the std mask of environment i, then the
            # cells belonging to cluster j, and embed them in environment k.
            pca_fit2, u2 = pca_UMAP(g[k][:, std_mask[i]][:, cell_mask])

            title = f"C{j}_from_env{i}_in_env{k}"
            env_titles.append(title)
            env_pca_fits.append(pca_fit2)
            env_umap_fits.append(u2)

        fig_titles.append(env_titles)
        pca_fits.append(env_pca_fits)
        umap_fits.append(env_umap_fits)

In [16]:
# One figure per (source environment, cluster) entry, with one 3D panel per
# environment the cluster was embedded in.
# NOTE(review): `plot_3d_proj` is not defined in the visible part of this
# notebook — it must exist before this cell is run.
env_figs = []

for u, pc, title in zip(umap_fits, pca_fits, fig_titles):
    fig = plt.figure(figsize=(8, 4))
    print(*title)
    n_panels = len(u)  # size the subplot grid from the data, not a fixed 3
    for i in range(n_panels):
        ax = fig.add_subplot(1, n_panels, i + 1, projection="3d")
        # The second argument is the first PCA component of each sample.
        plot_3d_proj(u[i], pc[i][:, 0], axes=ax)
    env_figs.append(fig)

NameError: name 'umap_fits' is not defined