In [None]:
import pickle
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import seaborn as sns
import h5py
import multiprocessing
from math import pi

from analysis.config import birds, h5_path_dict, pickle_dir
from analysis.ExpData import ExpData
from analysis.utils import popln_overlap, get_fr

In [None]:
def get_site_distance(a, b, n_sites=16):
    """
    Circular distance between two sites arranged on a ring.

    Args:
        a: Integer index of the first site.
        b: Integer index of the second site.
        n_sites: Number of sites on the ring. Defaults to 16.

    Returns:
        The number of steps along the shorter arc between `a` and `b`,
        an integer between 0 and n_sites // 2 inclusive.
    """
    # Wrap the raw separation first so that indices lying more than one lap
    # apart still resolve to the shorter arc.
    separation = abs(a - b) % n_sites
    return min(separation, n_sites - separation)

In [None]:
# Load the cached population patterns. The cells below look sessions up in
# this mapping by the path of their HDF5 file.
patterns_path = pickle_dir / 'population_patterns.p'
with open(patterns_path, 'rb') as patterns_file:
    population_patterns = pickle.load(patterns_file)

# Spatial Dimensionality

### Collect Distance Matrix

In [None]:
def get_distance_mat(fil_string):
    """
    Collect navigation-vs-navigation pattern overlaps for one session.

    Every pair of non-cache/retrieval visits (a visit is also paired with
    itself) is scored by the largest `popln_overlap` value found between the
    first visit's pattern matrix and any single pattern (column) of the
    second visit. Scores are binned by the sites of the two visits.

    Args:
        fil_string: Path of the session's HDF5 file; also the key of the
            session in the module-level `population_patterns` mapping.

    Returns:
        A (16, 16) object-dtype numpy array whose [i][j] entry is the list
        of scores for visit pairs at sites i and j. Each pair of distinct
        visits is stored under both [i][j] and [j][i].
    """
    n_sites = 16
    overlaps_by_site = [[[] for _ in range(n_sites)] for _ in range(n_sites)]

    # Keep the HDF5 file open only while ExpData is being read, and make sure
    # the handle is released even if an error occurs.
    with h5py.File(fil_string, 'r') as f:
        exp_data = ExpData(f)
        noncr_visits = population_patterns[fil_string]['noncr_visits']
        visit_patterns = population_patterns[fil_string]['visit_patterns']
        for i, noncr_visit in enumerate(noncr_visits):
            noncr_site = int(exp_data.visit_wedges[noncr_visit]) - 1
            # A shifted index of 16 has no bin in the 16x16 matrix
            # (presumably a non-site location -- TODO confirm).
            if noncr_site == n_sites:
                continue
            navig_pattern_mat = visit_patterns[noncr_visit]
            for j, noncr2_visit in enumerate(noncr_visits):
                # Only visit each unordered pair once; j == i is kept.
                # NOTE(review): self-pairs presumably score maximal overlap
                # and land on the diagonal -- confirm this is intended.
                if j < i:
                    continue
                noncr2_site = int(exp_data.visit_wedges[noncr2_visit]) - 1
                if noncr2_site == n_sites:
                    continue
                navig2_pattern_mat = visit_patterns[noncr2_visit]
                max_overlap = np.nanmax([
                    np.nanmax(popln_overlap(navig_pattern_mat, navig2_pattern))
                    for navig2_pattern in navig2_pattern_mat.T
                    ])
                overlaps_by_site[noncr_site][noncr2_site].append(max_overlap)
                if i != j:
                    overlaps_by_site[noncr2_site][noncr_site].append(max_overlap)

    # Fill an object array cell by cell. Calling np.array on nested lists of
    # unequal length raises ValueError on NumPy >= 1.24, and lists of equal
    # length would silently produce a 3-D array instead of (16, 16).
    distance_matrix = np.empty((n_sites, n_sites), dtype=object)
    for row in range(n_sites):
        for col in range(n_sites):
            distance_matrix[row, col] = overlaps_by_site[row][col]
    return distance_matrix

In [None]:
%%capture
# Number of worker processes used to spread the per-session computation.
PROCESSES = 5
# One task per session key.
params = list(population_patterns.keys())
with multiprocessing.Pool(PROCESSES) as pool:
    # pool.map yields the results in the same order as `params`.
    distance_matrices = list(pool.map(get_distance_mat, params))

In [None]:
# Cache the per-session matrices on disk so the plotting cells below can be
# run without repeating the computation.
dist_matrix_path = pickle_dir / 'reactivation_place_dist_matrix.p'
with open(dist_matrix_path, 'wb') as out_file:
    pickle.dump(distance_matrices, out_file)

### Plot raw distance matrices

In [None]:
# Reload the cached per-session matrices written by the cell above.
dist_matrix_path = pickle_dir / 'reactivation_place_dist_matrix.p'
with open(dist_matrix_path, 'rb') as in_file:
    distance_matrices = pickle.load(in_file)

In [None]:
# Pool the per-session matrices: for every pair of sites, concatenate the
# overlap scores of all sessions into a single list.
#
# The result is filled into an object array cell by cell. Calling np.array on
# nested lists of unequal length raises ValueError on NumPy >= 1.24, and lists
# of equal length would silently give a 3-D array, whereas the plotting cells
# below rely on a (16, 16) shape.
distance_matrix = np.empty((16, 16), dtype=object)
for i in range(16):
    for j in range(16):
        pooled_scores = []
        for mat in distance_matrices:
            pooled_scores.extend(mat[i][j])
        distance_matrix[i, j] = pooled_scores

In [None]:
# Heat map of the median overlap score in every (site, site) bin.
x = np.zeros(distance_matrix.shape)
n_rows, n_cols = x.shape[0], x.shape[1]
for row in range(n_rows):
    for col in range(n_cols):
        x[row, col] = np.nanmedian(distance_matrix[row, col])

# Tick every other site and label it with its 1-based site number.
tick_locs = np.arange(0, 16, 2)
tick_labels = tick_locs + 1

plt.imshow(x)
plt.title("Median Overlap")
plt.xlabel("Navigation")
plt.ylabel("Navigation")
plt.xticks(tick_locs, tick_labels)
plt.yticks(tick_locs, tick_labels)
plt.colorbar()

In [None]:
# Regroup the navigation/navigation scores by the circular distance between
# the two sites (nine possible distances, 0 through 8), dropping NaN scores.
vals_by_sitedist = [[] for _ in range(9)]
n_rows, n_cols = distance_matrix.shape[0], distance_matrix.shape[1]
for i in range(n_rows):
    for j in range(n_cols):
        scores = np.array(distance_matrix[i,j])
        not_nan = ~np.isnan(scores)
        vals_by_sitedist[get_site_distance(i,j)].extend(scores[not_nan].tolist())

In [None]:
# Median overlap as a function of the circular distance between visit sites.
median_by_sitedist = [np.nanmedian(v) for v in vals_by_sitedist]
plt.plot(np.arange(9), median_by_sitedist, linewidth=2)
plt.ylim(0, 1)
plt.xlabel("Site Distance between\nVisits")
plt.ylabel("Median Overlap")
plt.title("Navigation/Navigation Overlap")
plt.show()

# Cache v Cache Dimensionality

### Collect Distance Matrix

In [None]:
# For every pair of distinct cache events within a session, record the
# largest population overlap between the first cache's pattern matrix and any
# single pattern (column) of the second cache, binned by the two cache sites.
cache_overlaps = [[[] for _ in range(16)] for _ in range(16)]
for fil_string in population_patterns.keys():
    c_visits = population_patterns[fil_string]['c_visits']
    visit_patterns = population_patterns[fil_string]['visit_patterns']
    # Open each session file in a context manager so its handle is released
    # before the next session is processed.
    with h5py.File(fil_string, 'r') as f:
        exp_data = ExpData(f)
        # Sites of the cache events, looked up once per session. They are
        # indexed by position in c_visits, as in the per-pair lookups this
        # replaces.
        cache_sites = exp_data.cr_sites[exp_data.cr_was_cache]
        for i, c_visit in enumerate(c_visits):
            c_site = int(cache_sites[i]) - 1
            cache_pattern_mat = visit_patterns[c_visit]
            for j, c2_visit in enumerate(c_visits):
                # Each unordered pair is scored once; a cache is never
                # compared with itself.
                if j <= i:
                    continue
                c2_site = int(cache_sites[j]) - 1
                cache2_pattern_mat = visit_patterns[c2_visit]
                max_overlap = np.nanmax([
                    np.nanmax(popln_overlap(cache_pattern_mat, cache2_pattern))
                    for cache2_pattern in cache2_pattern_mat.T
                    ])
                # Store under both orderings so the matrix is symmetric.
                cache_overlaps[c_site][c2_site].append(max_overlap)
                cache_overlaps[c2_site][c_site].append(max_overlap)

# Fill an object array cell by cell. Calling np.array on nested lists of
# unequal length raises ValueError on NumPy >= 1.24, and lists of equal length
# would silently give a 3-D array instead of (16, 16).
distance_matrix = np.empty((16, 16), dtype=object)
for i in range(16):
    for j in range(16):
        distance_matrix[i, j] = cache_overlaps[i][j]

### Plot raw distance matrices

In [None]:
# Regroup the cache/cache scores by the circular distance between the two
# sites (nine possible distances, 0 through 8), dropping NaN scores.
vals_by_sitedist = [[] for _ in range(9)]
n_rows, n_cols = distance_matrix.shape[0], distance_matrix.shape[1]
for i in range(n_rows):
    for j in range(n_cols):
        scores = np.array(distance_matrix[i,j])
        not_nan = ~np.isnan(scores)
        vals_by_sitedist[get_site_distance(i,j)].extend(scores[not_nan].tolist())

In [None]:
# Heat map of the median cache/cache overlap in every (site, site) bin.
x = np.zeros(distance_matrix.shape)
for i in range(x.shape[0]):
    for j in range(x.shape[1]):
        x[i,j] = np.nanmedian(distance_matrix[i,j])
plt.imshow(x)
# Both axes index cache sites in this section, so both are labelled "Cache".
plt.ylabel("Cache")
plt.xlabel("Cache")
# Every site gets a tick, labelled with its 1-based site number.
plt.xticks(np.arange(16), np.arange(16)+1)
plt.yticks(np.arange(16), np.arange(16)+1)
plt.title("Median Overlap")
plt.colorbar()

In [None]:
# Median overlap as a function of the circular distance between cache sites.
plt.plot(
    np.arange(9), [np.nanmedian(v) for v in vals_by_sitedist],
    linewidth=2
    )
plt.ylim(0, 1)
plt.title("Cache/Cache Overlap")
# This figure compares caches with caches, so the label names caches only.
plt.xlabel("Site Distance between\nCache/Cache")
plt.ylabel("Median Overlap")
plt.show()

# Cache v Retrieval Dimensionality

### Collect Distance Matrix

In [None]:
# For every (cache, retrieval) pair within a session, record the largest
# population overlap between the cache's pattern matrix and any single
# pattern (column) of the retrieval, binned by cache site (row) and
# retrieval site (column).
cache_retrieval_overlaps = [[[] for _ in range(16)] for _ in range(16)]
for fil_string in population_patterns.keys():
    c_visits = population_patterns[fil_string]['c_visits']
    r_visits = population_patterns[fil_string]['r_visits']
    visit_patterns = population_patterns[fil_string]['visit_patterns']
    # Open each session file in a context manager so its handle is released
    # before the next session is processed.
    with h5py.File(fil_string, 'r') as f:
        exp_data = ExpData(f)
        # Sites of the cache and retrieval events, looked up once per session.
        # They are indexed by position in c_visits / r_visits, as in the
        # per-pair lookups this replaces.
        cache_sites = exp_data.cr_sites[exp_data.cr_was_cache]
        retrieval_sites = exp_data.cr_sites[exp_data.cr_was_retrieval]
        for i, c_visit in enumerate(c_visits):
            c_site = int(cache_sites[i]) - 1
            cache_pattern_mat = visit_patterns[c_visit]
            # Every retrieval is paired with the cache, regardless of their
            # order in time or whether they share a site.
            for j, r_visit in enumerate(r_visits):
                r_site = int(retrieval_sites[j]) - 1
                retriev_pattern_mat = visit_patterns[r_visit]
                max_overlap = np.nanmax([
                    np.nanmax(popln_overlap(cache_pattern_mat, retriev_pattern))
                    for retriev_pattern in retriev_pattern_mat.T
                    ])
                cache_retrieval_overlaps[c_site][r_site].append(max_overlap)

# Fill an object array cell by cell. Calling np.array on nested lists of
# unequal length raises ValueError on NumPy >= 1.24, and lists of equal length
# would silently give a 3-D array instead of (16, 16).
distance_matrix = np.empty((16, 16), dtype=object)
for i in range(16):
    for j in range(16):
        distance_matrix[i, j] = cache_retrieval_overlaps[i][j]

### Plot raw distance matrices

In [None]:
# Regroup the cache/retrieval scores by the circular distance between the two
# sites (nine possible distances, 0 through 8), dropping NaN scores.
vals_by_sitedist = [[] for _ in range(9)]
n_rows, n_cols = distance_matrix.shape[0], distance_matrix.shape[1]
for i in range(n_rows):
    for j in range(n_cols):
        scores = np.array(distance_matrix[i,j])
        not_nan = ~np.isnan(scores)
        vals_by_sitedist[get_site_distance(i,j)].extend(scores[not_nan].tolist())

In [None]:
# Heat map of the median cache/retrieval overlap: rows are cache sites and
# columns are retrieval sites.
x = np.zeros(distance_matrix.shape)
for i in range(x.shape[0]):
    for j in range(x.shape[1]):
        x[i,j] = np.nanmedian(distance_matrix[i,j])
plt.imshow(x)
plt.ylabel("Cache")
plt.xlabel("Retrieval")
# Tick every other site and label it with its 1-based site number.
plt.xticks(np.arange(0,16,2), np.arange(0,16,2)+1)
plt.yticks(np.arange(0,16,2), np.arange(0,16,2)+1)
plt.title("Median Overlap")
plt.colorbar()

In [None]:
# Median overlap as a function of the circular distance between the cache
# site and the retrieval site.
median_by_sitedist = [np.nanmedian(v) for v in vals_by_sitedist]
plt.plot(np.arange(9), median_by_sitedist, linewidth=2)
plt.ylim(0, 1)
plt.xlabel("Site Distance between\nCache/Retrieval")
plt.ylabel("Median Overlap")
plt.title("Cache/Retrieval Overlap")
plt.show()

# Cache v Navigation Dimensionality
(Using Cache "Memory")

### Collect Distance Matrix

In [None]:
# Load the cached episode-cell data. The next cell reads each session's
# 'cr_idx_mat' entry from it.
episode_cells_path = pickle_dir / 'episode_cells_overlap.p'
with open(episode_cells_path, 'rb') as episode_file:
    episode_cells = pickle.load(episode_file)

In [None]:
# For every (cache, navigation visit) pair within a session, record the
# largest population overlap between the cache's episode-cell pattern and any
# single pattern (column) of the navigation visit, binned by cache site (row)
# and visit site (column). `counts` tallies the number of pairs per bin.
cache_nav_overlaps = [[[] for _ in range(16)] for _ in range(16)]
counts = np.zeros((16, 16))
for fil_string in population_patterns.keys():
    c_visits = population_patterns[fil_string]['c_visits']
    noncr_visits = population_patterns[fil_string]['noncr_visits']
    visit_patterns = population_patterns[fil_string]['visit_patterns']
    cr_idx_mat = episode_cells[fil_string]['cr_idx_mat']
    # Open each session file in a context manager so its handle is released
    # before the next session is processed.
    with h5py.File(fil_string, 'r') as f:
        exp_data = ExpData(f)
        # Sites of the cache events, looked up once per session. They are
        # indexed by position in c_visits, as in the per-cache lookup this
        # replaces.
        cache_sites = exp_data.cr_sites[exp_data.cr_was_cache]
        for i, c_visit in enumerate(c_visits):
            c_site = int(cache_sites[i]) - 1
            # Zero the NaNs on a copy so the loaded `episode_cells` data is
            # left untouched for any other cell that reads it.
            cache_pattern = np.array(cr_idx_mat[i])
            cache_pattern[np.isnan(cache_pattern)] = 0
            # A cache whose pattern sums to zero is skipped.
            if np.sum(cache_pattern) == 0:
                continue
            for noncr_visit in noncr_visits:
                noncr_site = int(exp_data.visit_wedges[noncr_visit]) - 1
                # A shifted index of 16 has no bin in the 16x16 matrix
                # (presumably a non-site location -- TODO confirm).
                if noncr_site == 16:
                    continue
                navig_pattern_mat = visit_patterns[noncr_visit]
                max_overlap = np.nanmax([
                    popln_overlap(cache_pattern[:,None], navig_pattern)[0]
                    for navig_pattern in navig_pattern_mat.T
                    ])
                cache_nav_overlaps[c_site][noncr_site].append(max_overlap)
                counts[c_site, noncr_site] += 1

# Fill an object array cell by cell. Calling np.array on nested lists of
# unequal length raises ValueError on NumPy >= 1.24, and lists of equal length
# would silently give a 3-D array instead of (16, 16).
distance_matrix = np.empty((16, 16), dtype=object)
for i in range(16):
    for j in range(16):
        distance_matrix[i, j] = cache_nav_overlaps[i][j]

### Plot raw distance matrices

In [None]:
# Regroup the cache/navigation scores by the circular distance between the
# two sites (nine possible distances, 0 through 8), dropping NaN scores.
vals_by_sitedist = [[] for _ in range(9)]
n_rows, n_cols = distance_matrix.shape[0], distance_matrix.shape[1]
for i in range(n_rows):
    for j in range(n_cols):
        scores = np.array(distance_matrix[i,j])
        not_nan = ~np.isnan(scores)
        vals_by_sitedist[get_site_distance(i,j)].extend(scores[not_nan].tolist())

In [None]:
# Heat map of the median cache/navigation overlap: rows are cache sites and
# columns are the sites of the navigation visits.
x = np.zeros(distance_matrix.shape)
n_rows, n_cols = x.shape[0], x.shape[1]
for row in range(n_rows):
    for col in range(n_cols):
        x[row, col] = np.nanmedian(distance_matrix[row, col])

# Tick every other site and label it with its 1-based site number.
tick_locs = np.arange(0,16,2)
tick_labels = tick_locs + 1

plt.imshow(x)
plt.title("Median Overlap")
plt.xlabel("Arbitrary Visit")
plt.ylabel("Cache")
plt.xticks(tick_locs, tick_labels)
plt.yticks(tick_locs, tick_labels)
# Set the colour limits of the image to 0-0.8 before drawing the colour bar.
plt.clim(0, 0.8)
plt.colorbar()

In [None]:
# Median overlap as a function of the circular distance between the cache
# site and the navigation-visit site.
median_by_sitedist = [np.nanmedian(v) for v in vals_by_sitedist]
plt.plot(np.arange(9), median_by_sitedist, linewidth=2)
plt.ylim(0, 1)
plt.xlabel("Site Distance between\nCache/Navigation")
plt.ylabel("Median Overlap")
plt.title("Cache/Navigation Overlap")
plt.show()