In [None]:
import os
import re
import sys
import glob
import pickle
import numpy as np
import matplotlib.pyplot as plt
if '../scripts' not in sys.path:
    sys.path.append('../scripts')
from plots import plot_wmx, plot_wmx_avg, plot_w_distr, plot_weights, save_selected_w

In [None]:
athorny_place_cells_ratio = 0.2
# The archive stores pickled dicts (hence allow_pickle=True and .item()); only load trusted files.
# The `with` block closes the .npz archive once the objects have been unpacked into memory.
with np.load(f'../files/spike_times_a-thorny_place_cells_ratio={athorny_place_cells_ratio:.1f}.npz',
             allow_pickle=True) as data:
    spike_times = data['spike_trains'].item()
    place_cells = data['place_cell'].item()
cell_types = place_cells.keys()
selected_cells = {}
n_selected_cells = 5
# Fixed seed so the fallback random selection below is reproducible across kernel restarts.
rng = np.random.default_rng(0)
for cell_type in cell_types:
    # Indices where place_cells[cell_type] is truthy (presumably a boolean place-cell mask — confirm).
    idx, = np.where(place_cells[cell_type])
    if idx.size > 0:
        # n_selected_cells evenly spaced place cells, excluding both ends of the index range.
        jdx = np.linspace(0, idx.size - 1, n_selected_cells + 2, dtype=np.int32)
        # Integer truncation repeats indices when idx is short; np.unique drops the repeats
        # (for large idx the selection is already sorted and unique, so this is a no-op).
        selected_cells[cell_type] = np.unique(idx[jdx[1:-1]])
    else:
        # No place cells of this type: fall back to a sorted random sample of all its cells.
        selected_cells[cell_type] = np.sort(
            rng.permutation(place_cells[cell_type].size)[:n_selected_cells])

In [None]:
# Load the weight matrices and the simulation config stored in the same archive.
# Object arrays require allow_pickle=True (only load trusted files); the `with` block
# closes the .npz archive once the objects have been unpacked into memory.
with np.load(f'../files/weights_a-thorny_place_cells_ratio={athorny_place_cells_ratio:.1f}.npz',
             allow_pickle=True) as data:
    # Nested dict indexed as weights[pre][post]; values expose .toarray()
    # (presumably scipy.sparse matrices — confirm).
    weights = data['weights'].item()
    config = data['config'].item()
# Pre/post cell-type pairs, read as parallel sequences connections['pre'] / connections['post'].
connections = config['connectivity']

In [None]:
# Figures are left open so the inline backend displays them; add %%capture as the
# first line of this cell to suppress the display and only write the PDFs.
# Cells whose individual weights are shown in the bottom-right panel: for every
# presynaptic type the selection depends only on the postsynaptic type.
selection = {pre_type: {post_type: selected_cells[post_type]
                        for post_type in ('thorny', 'a-thorny')}
             for pre_type in ('thorny', 'a-thorny')}
# Per-connection axis settings, indexed as setting[pre][post]. The weight bounds share
# the x-axis units of `wgts.max() * 1e9` used below.
min_weight = {
    'thorny': {'thorny': 0.5, 'a-thorny': 0.005},
    'a-thorny': {'thorny': 0.005, 'a-thorny': 0.005},
}
# NOTE(review): max_weight is not used in this cell (upper bounds come from the data).
max_weight = {
    'thorny': {'thorny': 8, 'a-thorny': 1},
    'a-thorny': {'thorny': 1, 'a-thorny': 4},
}
max_count = {
    'thorny': {'thorny': 1e6, 'a-thorny': 1e5},
    'a-thorny': {'thorny': 1e5, 'a-thorny': 1e5},
}
for pre, post in zip(connections['pre'], connections['post']):
    wgts = weights[pre][post].toarray()
    # Largest weight, rescaled by 1e9; it is the upper bound of both weight axes.
    w_peak = wgts.max() * 1e9
    xlim = [min_weight[pre][post], w_peak]
    fig = plt.figure(figsize=(10, 10))
    gs = fig.add_gridspec(4, 2)
    # Layout: matrix (top-left), population average (top-right), two distribution
    # panels stacked bottom-left, selected-cell weights bottom-right.
    ax = [fig.add_subplot(spec)
          for spec in (gs[:2, 0], gs[:2, 1], gs[2, 0], gs[3, 0], gs[2:, 1])]
    plot_wmx(wgts, ax=ax[0])
    plot_wmx_avg(wgts, n_pops=100, ax=ax[1])
    plot_w_distr(wgts, bins=50, ax=ax[2:4], xlim=xlim, ylim=[1, max_count[pre][post]])
    plot_weights(save_selected_w(wgts, selection[pre][post]), ax=ax[-1], ylim=[0, w_peak])
    fig.tight_layout()
    fig.savefig(f'weights_{pre}_{post}_a-thorny_place_cells_ratio={athorny_place_cells_ratio:.1f}.pdf')

In [None]:
# Deliberate halt: raising here stops "Run All" so the cells below are not executed
# by default. They presumably belong to an older analysis of the t_max=1205 weight
# files; run them individually if needed. TODO confirm they are still wanted.
raise Exception('stop here')

In [None]:
place_cell_ratio = 0.5
track_type = 'linear'
n_neurons = 8000
t_max = 1205
data_folder = os.path.join('..', 'files', f't_max={t_max:.0f}')
weights_files_pattern = os.path.join(data_folder,
        f'wmx_sym_N={n_neurons}_ratio={place_cell_ratio}_dur=*_{track_type}_sparse.pkl')
weights_files = glob.glob(weights_files_pattern)
durs = np.array([float(re.findall('dur=\d+', f)[0].split('=')[1]) for f in weights_files])
idx = np.argsort(durs)
weights_files = [weights_files[i] for i in idx]
durs = durs[idx]
weights = [pickle.load(open(f, 'rb')) for f in weights_files]

In [None]:
%%capture
weight_max = 13
# Fixed neuron indices whose individual weights are shown in the bottom-right panel;
# they must be valid row/column indices of the n_neurons x n_neurons matrices.
selection = np.array([501, 2400, 4002, 5502, 7015])
assert selection.max() < n_neurons, 'selected neuron index out of range'
fig_folder = os.path.join('..', 'figures', f't_max={t_max:.0f}')
# savefig does not create missing directories.
os.makedirs(fig_folder, exist_ok=True)
for dur, wmx in zip(durs, weights):
    wgts = wmx.toarray()
    fig = plt.figure(figsize=(10, 10))
    gs = fig.add_gridspec(4, 2)
    ax = [fig.add_subplot(gs[:2, 0]),
          fig.add_subplot(gs[:2, 1]),
          fig.add_subplot(gs[2, 0]),
          fig.add_subplot(gs[3, 0]),
          fig.add_subplot(gs[2:, 1])]
    plot_wmx(wgts, ax=ax[0])
    plot_wmx_avg(wgts, n_pops=100, ax=ax[1])
    plot_w_distr(wgts, bins=50, ax=ax[2:4], xlim=[0.5, weight_max], ylim=[1, 10000])
    plot_weights(save_selected_w(wgts, selection), ax=ax[-1], ylim=[0, weight_max])
    fig.tight_layout()
    out_file = os.path.join(fig_folder,
                            f'wmx_sym_N={n_neurons}_ratio={place_cell_ratio}_dur={dur:.0f}_{track_type}_sparse.pdf')
    fig.savefig(out_file)
    # Display is suppressed by %%capture, so release each figure once it is saved;
    # otherwise they accumulate (matplotlib warns after 20 open figures).
    plt.close(fig)