In [2]:
# Extract the place-cell indexes from Ann's data

import os
# Make the folder two levels above the working directory importable so that the local
# `tools` and `params` modules imported below resolve (presumably the repository root — confirm).
import sys

parent_directory = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
grandparent_directory = os.path.abspath(os.path.join(parent_directory, os.pardir))
# Guard so that re-running this cell does not append a duplicate entry every time.
if grandparent_directory not in sys.path:
    sys.path.append(grandparent_directory)

import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
import plotly.graph_objects as go
from scipy.stats import pearsonr
import glob
from scipy.io import loadmat
import pandas as pd

from tools.data_formatting import get_smoothed_moving_all_data, get_common_indexes_n_recordings, smooth_tuning_curves_circularly, from_local_to_global_index
from tools.data_manager import get_all_experiments_runs, get_fovs_given_animal
from tools.alignment import procrustes, canoncorr
from params import order_experiments, get_colors_for_each_experiment, animals, root_dir, experiments_to_exclude


def find_path_to_data_folder(animal, fov=None, experiment=None, run=None):
    """Given the animal, fov, experiment and run find the path to the data folder. Made to account the type.
    If experiment and run are not provided, it will return the path of the animal and fov."""
    if (experiment is not None) and (run is not None) and (fov is not None):
        res = glob.glob(f'{root_dir+'data/'}/**/m*/m*/', recursive=True)
        for path in res:
            animal_path = path.split('/')[-2].split('_')[0]
            fov_path = path.split('/')[-2].split('_')[1]
            experiment_path = path.split('/')[-2].split('_')[2].split('-')[0]
            run_path = path.split('/')[-2].split('_')[2].split('-')[1].split('.')[0]
            if (animal==animal_path) and (fov==fov_path) and (experiment==experiment_path) and (run==run_path):
                return path
        print('Not path found for: ', animal, fov, experiment, run)
    else:
        res = glob.glob(f'{root_dir+'data/'}/**/{animal}/', recursive=True)
        return res[0]
    return None

In [31]:
### Load all sessions ###

animal = 'm135'
fov = 'fov1'
sessions = get_all_experiments_runs(animal, fov)
# Remove sessions to exclude
sessions = [session for session in sessions if session not in experiments_to_exclude]
# Sort sessions chronologically by their position in order_experiments
# (raises KeyError if a session is missing from order_experiments)
order_map = {value: index for index, value in enumerate(order_experiments)}
sessions = sorted(sessions, key=lambda x: order_map[x])

# For every session collect: the global cell IDs, the tuning curves, and the positions
# (within the session's `cells`) of the place cells listed in its *_PFmap_output.mat file.
all_cells = []
all_tuning_curves = []
all_pc_idxs = []
for (experiment, run) in sessions:
    _, _, phi, cells, tuning_curves, _ = get_smoothed_moving_all_data(animal, fov, experiment, run)
    # Locate the PC output file; fail with an explicit message instead of an opaque
    # TypeError (folder lookup returned None) or IndexError (no matching file).
    data_folder = find_path_to_data_folder(animal, fov, experiment, run)
    if data_folder is None:
        raise FileNotFoundError(f'No data folder for {animal} {fov} {experiment}-{run}')
    pfmap_files = sorted(glob.glob(os.path.join(data_folder, '*_PFmap_output.mat')))
    if not pfmap_files:
        raise FileNotFoundError(f'No *_PFmap_output.mat file in {data_folder}')
    path = pfmap_files[0]
    pc_mat = loadmat(path)
    pc_ids = pc_mat['hist'][0][0]['SIspk'][0]['pcIdx'][0].squeeze()
    # Positions in `cells` whose local ID appears in pc_ids.
    # NOTE(review): if pcIdx is a 1-based MATLAB index while `cells` is 0-based, this is off by one — confirm.
    pc_idxs = np.where(np.isin(np.array(cells, dtype=float), pc_ids))[0]
    # Get the global IDs of the cells
    cells = from_local_to_global_index(animal, fov, (experiment, run), cells)
    all_cells.append(cells)
    all_tuning_curves.append(tuning_curves)
    all_pc_idxs.append(pc_idxs)

In [35]:
# Try on one session only first:
# remove the place cells and check how well the tuning curves are predicted with the alignment.

# Smooth the tuning curves very little for better alignment
smoothed_tuning_curves = [smooth_tuning_curves_circularly(tuning_curves, 5) for tuning_curves in all_tuning_curves]
# Take the first session as the reference
ref = 0
exp0, run0 = sessions[ref]
ref_pc_idxs = all_pc_idxs[ref]
ref_avg_fr = smoothed_tuning_curves[ref]
ref_cells = all_cells[ref]
# Drop the place cells. ref_pc_idxs are positions in the session's cell list, so this
# assumes the tuning-curve columns (axis=1) are cells.
ref_tc_wo_pc = np.delete(ref_avg_fr, ref_pc_idxs, axis=1)

# Cap the number of components: sklearn's PCA rejects n_components > min(n_samples, n_features),
# which happens when fewer than 20 rows or 20 non-place cells remain.
# random_state pins the result if sklearn selects its randomized SVD solver.
n_pca_components = min(20, *ref_tc_wo_pc.shape)
pca = PCA(n_components=n_pca_components, random_state=0)
pca_ref_tc_wo_pc = pca.fit_transform(ref_tc_wo_pc)



In [36]:
# 3-D scatter of the first three principal components of the reference session
# (place cells removed), coloured by row index with a cyclic colorscale.
fig = go.Figure()
fig.add_trace(go.Scatter3d(
    x=pca_ref_tc_wo_pc[:, 0],
    y=pca_ref_tc_wo_pc[:, 1],
    z=pca_ref_tc_wo_pc[:, 2],
    mode='markers',
    marker=dict(
        size=2,
        color=np.arange(pca_ref_tc_wo_pc.shape[0]),
        colorscale='hsv',
        showscale=True,
        colorbar=dict(title=dict(text='Row (bin) index')),
    ),
))
fig.update_layout(
    title=f'Session {exp0}-{run0}: PCA of tuning curves without place cells',
    scene=dict(xaxis_title='PC1', yaxis_title='PC2', zaxis_title='PC3'),
)
fig.show()