In [2]:
import cbm_pack.cbm_pack as cbm_pack
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from matplotlib.colors import colorConverter
from matplotlib.collections import LineCollection
from mpl_toolkits.mplot3d import Axes3D

In [11]:
# simulator output file extensions, keyed by "<cell type>_<data kind>"
extensions = dict(
    mf_raster=".mfr",
    gr_raster=".grr",
    go_raster=".gor",
    pc_raster=".pcr",
    bc_raster=".bcr",
    sc_raster=".scr",
    nc_raster=".ncr",
    io_raster=".ior",
    mf_psth=".mfp",
    gr_psth=".grp",
    go_psth=".gop",
    pc_psth=".pcp",
    bc_psth=".bcp",
    sc_psth=".scp",
    nc_psth=".ncp",
    io_psth=".iop",
    pfpc_weights=".pfpcw",
    mfnc_weights=".mfncw",
)
# number of simulated cells of each type (used to reshape raster / psth / weight files)
cell_nums = dict(
    mf=4096,
    gr=2 ** 20,
    go=4096,
    pc=32,
    bc=128,
    sc=512,
    nc=8,
    io=4,
)

In [None]:
MAX_CR_AMP = 6  # y-axis ceiling (MAX_CR_AMP + 1) of the CR amplitude plots
PRE_CS_COLLECT_TS = 400
POST_CS_COLLECT_TS = 400
ISIS = np.array([150, 200, 250, 500, 750, 1000, 1250, 1500, 1750, 2000])
ISI = 1500
# acquisition runs for ISIs above 750 ms used 1000 trials, shorter ISIs used 500
# (same rule as the plotting functions); derive it so changing ISI keeps this consistent
NUM_COLLECT_TRIALS = 1000 if ISI > 750 else 500
NUM_COLLECT_TS = PRE_CS_COLLECT_TS + ISI + POST_CS_COLLECT_TS
INPUT_DIR = "./data/inputs/Forgetting/"
OUTPUT_DIR = "./data/outputs/Forgetting/"
INPUT_BASE = f"acq_no_probe_ISI_{ISI}"
INPUT_PC_FILE = INPUT_DIR + INPUT_BASE + "/" + INPUT_BASE + extensions["pc_raster"]
INPUT_NC_FILE = INPUT_DIR + INPUT_BASE + "/" + INPUT_BASE + extensions["nc_raster"]

# Acquisition Analysis

## Acquisition from Data Collected on all trials

- for ISIs below 1000 ms, 500 paired trials were run; for ISIs of 1000 ms and above, 1000 paired trials were run
- pc and nc raster data were collected during every trial

## Analysis Outline

- compute the CRs for every isi, every trial, every ts
- plot the time series (across trials) of the maximum CR for every ISI
- to avoid cluttering the plot, plot only every 10th trial

In [None]:
def plot_acq_all_isi_from_nc_use_cr_amp(isis: np.ndarray, method="mike"):
    """Plot the per-trial maximum CR amplitude, computed from NC rasters, across acquisition.

    One line per ISI; only every 10th trial is plotted to keep the figure readable.

    Parameters
    ----------
    isis : np.ndarray
        ISIs (ms) to plot. NC rasters are read from
        ``INPUT_DIR/acq_no_probe_ISI_{isi}/acq_no_probe_ISI_{isi}.ncr``.
    method : str
        CR extraction method: ``"mike"`` (``cbm_pack.ncs_to_cr_mike``) or
        ``"sean"`` (``cbm_pack.ncs_to_cr_sean``).

    Raises
    ------
    ValueError
        If ``method`` is neither ``"mike"`` nor ``"sean"`` (checked before any
        figure is created or file is read).
    """
    if method not in ("mike", "sean"):
        raise ValueError(f"expected parameter 'method' to be either 'mike' or 'sean'. Got '{method}'")

    fig = plt.figure()
    fig.suptitle("Acquisition: CR Amplitude", fontsize=14)
    fig.patch.set_facecolor('white')
    ax = plt.subplot(111)
    ax.set_xlabel('trial num', fontsize=12)
    ax.set_ylabel('eyelid closure (mm)', fontsize=12)

    for isi in isis:
        num_collect_ts = PRE_CS_COLLECT_TS + isi + POST_CS_COLLECT_TS
        # ISIs above 750 ms were trained for 1000 trials, the rest for 500; decided per ISI
        # so the result does not depend on the order of `isis`
        num_collect_trials = 1000 if isi > 750 else 500
        input_base = f"acq_no_probe_ISI_{isi}"
        input_nc_file = INPUT_DIR + input_base + "/" + input_base + extensions["nc_raster"]
        nc_rasters = cbm_pack.np_arr_from_file(input_nc_file, np.uint8)
        nc_rasters = cbm_pack.reshape_raster(
                nc_rasters,
                cell_nums["nc"],
                num_collect_trials,
                num_collect_ts)
        if method == "mike":
            nc_crs = cbm_pack.ncs_to_cr_mike(nc_rasters)
        else:
            nc_crs = cbm_pack.ncs_to_cr_sean(nc_rasters,
                PRE_CS_COLLECT_TS,
                POST_CS_COLLECT_TS,
                isi)
        # every 10th trial; CR arrays are presumably (trials, ts), so max over axis 1 is the per-trial peak
        trials_to_plot = np.arange(0, num_collect_trials, 10)
        crs_to_plot = np.max(nc_crs[trials_to_plot], axis=1)
        if isi <= 1250:
            plt.plot(trials_to_plot, crs_to_plot, 'o-', markersize=2.0, linewidth=1.0, label=f'{isi}')
        else:
            plt.plot(trials_to_plot, crs_to_plot, 'o', markersize=2.0, label=f'{isi}')
    ax.legend(loc='center right', bbox_to_anchor=(1.20, 0.5))
    plt.ylim((0, MAX_CR_AMP+1))
    plt.tight_layout()
    plt.show()
    plt.close(fig)

def plot_acq_all_isi_from_pc_use_cr_amp(isis: np.ndarray):
    """Plot the per-trial maximum CR amplitude, computed from PC rasters, across acquisition.

    One line per ISI; only every 10th trial is plotted. CRs come from
    ``cbm_pack.pcs_to_crs`` applied to
    ``INPUT_DIR/acq_no_probe_ISI_{isi}/acq_no_probe_ISI_{isi}.pcr``.

    Parameters
    ----------
    isis : np.ndarray
        ISIs (ms) to plot.
    """
    fig = plt.figure()
    fig.suptitle("Acquisition: CR Amplitude", fontsize=14)
    fig.patch.set_facecolor('white')
    ax = plt.subplot(111)
    ax.set_xlabel('trial num', fontsize=12)
    ax.set_ylabel('eyelid closure (mm)', fontsize=12)

    for isi in isis:
        num_collect_ts = PRE_CS_COLLECT_TS + isi + POST_CS_COLLECT_TS
        # ISIs above 750 ms were trained for 1000 trials, the rest for 500; decided per ISI
        # so the result does not depend on the order of `isis`
        num_collect_trials = 1000 if isi > 750 else 500
        input_base = f"acq_no_probe_ISI_{isi}"
        input_pc_file = INPUT_DIR + input_base + "/" + input_base + extensions["pc_raster"]
        pc_rasters = cbm_pack.np_arr_from_file(input_pc_file, np.uint8)
        pc_rasters = cbm_pack.reshape_raster(
                pc_rasters,
                cell_nums["pc"],
                num_collect_trials,
                num_collect_ts)
        pc_crs = cbm_pack.pcs_to_crs(pc_rasters, PRE_CS_COLLECT_TS, POST_CS_COLLECT_TS, isi)
        # every 10th trial; max over axis 1 gives each plotted trial's peak CR
        trials_to_plot = np.arange(0, num_collect_trials, 10)
        crs_to_plot = np.max(pc_crs[trials_to_plot], axis=1)
        if isi <= 1250:
            plt.plot(trials_to_plot, crs_to_plot, 'o-', markersize=2.0, linewidth=1.0, label=f'{isi}')
        else:
            plt.plot(trials_to_plot, crs_to_plot, 'o', markersize=2.0, label=f'{isi}')
    ax.legend(loc='center right', bbox_to_anchor=(1.20, 0.5))
    plt.ylim((0, MAX_CR_AMP+1))
    plt.tight_layout()
    plt.show()
    plt.close(fig)


def plot_acq_all_isi_from_pc_use_cr_amp_avg_probe(isis: np.ndarray):
    """Plot block-averaged probe-trial CR amplitude (from PC rasters) across acquisition.

    Data are read from ``INPUT_DIR/acq_10_train_10_probe_ISI_{isi}/...pcr``, which holds
    only the probe trials. Each block's value is the mean, over its probe trials, of the
    per-trial maximum CR.

    Parameters
    ----------
    isis : np.ndarray
        ISIs (ms) to plot. ISIs above 750 ms use 1000 train / 1000 probe trials in
        100 blocks; shorter ISIs use 500 / 500 in 50 blocks.
    """
    fig = plt.figure()
    fig.suptitle("Acquisition: CR Amplitude", fontsize=14)
    fig.patch.set_facecolor('white')
    ax = plt.subplot(111)
    ax.set_xlabel('trial num', fontsize=12)
    ax.set_ylabel('eyelid closure (mm)', fontsize=12)

    for isi in isis:
        num_collect_ts = PRE_CS_COLLECT_TS + isi + POST_CS_COLLECT_TS
        # session layout is decided per ISI so the result does not depend on the order of `isis`
        if isi > 750:
            num_train_trials, num_collect_trials, num_blocks_per_session = 1000, 1000, 100
        else:
            num_train_trials, num_collect_trials, num_blocks_per_session = 500, 500, 50
        num_train_trials_per_b = num_train_trials // num_blocks_per_session
        num_collect_trials_per_b = num_collect_trials // num_blocks_per_session

        input_base = f"acq_10_train_10_probe_ISI_{isi}"
        input_pc_file = INPUT_DIR + input_base + "/" + input_base + extensions["pc_raster"]
        pc_rasters = cbm_pack.np_arr_from_file(input_pc_file, np.uint8)
        pc_rasters = cbm_pack.reshape_raster(
                pc_rasters,
                cell_nums["pc"],
                num_collect_trials,
                num_collect_ts)
        pc_crs = cbm_pack.pcs_to_crs(pc_rasters, PRE_CS_COLLECT_TS, POST_CS_COLLECT_TS, isi)

        # x position of each block: cumulative number of train trials run before its probe
        # trials (assumes every block starts with its train trials)
        trials_to_plot = np.arange(num_train_trials_per_b,
                                   num_train_trials + num_train_trials_per_b,
                                   num_train_trials_per_b)
        mean_crs_blocks = np.zeros(num_blocks_per_session)
        for block in range(num_blocks_per_session):
            start = block * num_collect_trials_per_b
            block_crs = pc_crs[start:start + num_collect_trials_per_b]
            # peak CR of each probe trial, averaged over the block
            mean_crs_blocks[block] = np.mean(np.max(block_crs, axis=1))

        if isi <= 1250:
            plt.plot(trials_to_plot, mean_crs_blocks, 'o-', markersize=2.0, linewidth=1.0, label=f'{isi}')
        else:
            plt.plot(trials_to_plot, mean_crs_blocks, 'o', markersize=2.0, label=f'{isi}')
    ax.legend(loc='center right', bbox_to_anchor=(1.20, 0.5))
    plt.ylim((0, MAX_CR_AMP+1))
    plt.tight_layout()
    plt.show()
    plt.close(fig)

## Method Sean
- Criterion: DCN fr goes above baseline + 10Hz
- baseline is computed by taking middle half of the pre-cs period
    - i.e. if the pre-CS collection period is 400 ts, the baseline is computed as
      the average over timesteps 100 to 300
- amplitude is computed by taking the maximum within the ISI (i.e., before US onset) and subtracting the criterion

- NOTE: gives CRs for every trial because our criterion (10Hz above base) is satisfied even
  from the first trial for most ISIs

In [None]:
# CR amplitude from NC rasters, criterion = baseline + 10Hz
plot_acq_all_isi_from_nc_use_cr_amp(isis=ISIS, method="sean")

## Method Mike

- there is no criterion: cr amplitude is computed from nc via:
    - integrating all nc activity into an excitatory Red nucleus cell's membrane potential
    - that membrane potential is then scaled, small fluctuations close to zero are removed, and the result is scaled again
- need to threshold: CR only if final result is greater than 0.3

In [None]:
# CR amplitude from NC rasters via the red-nucleus integration model
plot_acq_all_isi_from_nc_use_cr_amp(isis=ISIS, method="mike")

## Method Joe (ie CR Amp from Purkinje Cells)

- Criterion: mean, smoothed PC cell firing rates fall below 80% baseline firing rates
- not every trial gives CRs, ISIs 150 - 750 learn quite robustly, 1000 and 1250 learn
  but with widely varying amplitudes, 1500 and 1750 somewhat learn, 2000 does not learn
  essentially at all

In [None]:
# short ISIs only (150-500 ms)
plot_acq_all_isi_from_pc_use_cr_amp(isis=np.array([150, 200, 250, 500]))

## Acquisition from Data Collected on Probe Trials alone

- for ISIs below 1000 ms, 500 paired trials and 250 probe trials were run; for 1000 ms and above, 1000 paired trials and 500 probe trials were run
- training was run in blocks of 100 paired CS-US trials and 50 probe trials (CS-Alone, pfpc plasticity off)
- pc and nc raster data were collected only during probe trials

## Analysis Outline

- compute the CRs for every isi, every trial, every ts
- create an array of the per-trial CR maxima for the probe trials within a block (here 50 trials)
- average these maxima to produce a single mean maximum across the 50 probe trials
- plot each of these "mean maxima" at the point at which each group of 50 probe trials began (so at paired trials 100, 200, 300, etc.)

UPDATE 07/07/2023: retrained from the same equilibrated simulations as all previous runs, but with alternating batches of 10 training and 10 probe trials

In [None]:
# probe-trial averages for the short ISIs only (150-500 ms)
plot_acq_all_isi_from_pc_use_cr_amp_avg_probe(isis=np.array([150, 200, 250, 500]))

# Acquisition Onset Times

In [None]:
def plot_onset_times_all_isi_from(isis: np.ndarray, input_type: str):
    """Plot CR onset times across acquisition trials for every ISI.

    Every 20th trial is plotted. Rasters are read from
    ``INPUT_DIR/acq_no_probe_ISI_{isi}/acq_no_probe_ISI_{isi}.{pcr|ncr}``.

    Parameters
    ----------
    isis : np.ndarray
        ISIs (ms) to plot, in any order.
    input_type : str
        ``"pc"`` to compute onsets with ``cbm_pack.calc_cr_onsets_from_pc`` or
        ``"nc"`` for ``cbm_pack.calc_cr_onsets_from_nc``.

    Raises
    ------
    ValueError
        If ``input_type`` is neither ``"pc"`` nor ``"nc"`` (checked before any
        figure is created or file is read).
    """
    if input_type not in ("pc", "nc"):
        raise ValueError(f"expected positional parameter 'input_type' to be either 'pc' or 'nc'. Got '{input_type}'")
    calc_cr_onsets = cbm_pack.calc_cr_onsets_from_pc if input_type == "pc" else cbm_pack.calc_cr_onsets_from_nc
    raster_ext = extensions[f"{input_type}_raster"]

    fig = plt.figure()
    fig.suptitle("CR Onset Times", fontsize=14)
    fig.patch.set_facecolor('white')
    ax = plt.subplot(111)
    ax.set_xlabel('trial num', fontsize=12)
    ax.set_ylabel('onset time (ms)', fontsize=12)

    for isi in isis:
        num_collect_ts = PRE_CS_COLLECT_TS + isi + POST_CS_COLLECT_TS
        # ISIs above 750 ms were trained for 1000 trials, the rest for 500; decided per ISI
        num_collect_trials = 1000 if isi > 750 else 500
        input_base = f"acq_no_probe_ISI_{isi}"
        input_file = INPUT_DIR + input_base + "/" + input_base + raster_ext
        rasters = cbm_pack.np_arr_from_file(input_file, np.uint8)
        rasters = cbm_pack.reshape_raster(
                rasters,
                cell_nums[input_type],
                num_collect_trials,
                num_collect_ts)
        onset_times = calc_cr_onsets(rasters, PRE_CS_COLLECT_TS, isi)
        trials_to_plot = np.arange(0, num_collect_trials, 20)
        # shifted by PRE_CS_COLLECT_TS, presumably to turn CS-relative onsets into trial time (TODO confirm)
        onset_times_to_plot = onset_times[trials_to_plot] + PRE_CS_COLLECT_TS
        if isi <= 750:
            plt.plot(trials_to_plot, onset_times_to_plot, 'o-', markersize=2.0, linewidth=1.0, label=f'{isi}')
        else:
            plt.plot(trials_to_plot, onset_times_to_plot, 'o', markersize=2.0, label=f'{isi}')

    ax.legend(loc='center right', bbox_to_anchor=(1.20, 0.5))
    # y-range covers the longest trial regardless of the order of `isis`
    plt.ylim((0, PRE_CS_COLLECT_TS + np.max(isis) + POST_CS_COLLECT_TS))
    plt.tight_layout()
    plt.show()
    plt.close(fig)

In [None]:
# onset times derived from Purkinje cell rasters
plot_onset_times_all_isi_from(isis=ISIS, input_type="pc")

In [None]:
# latest CR onset times from nc computed with threshold 40Hz above base
plot_onset_times_all_isi_from(isis=ISIS, input_type="nc")

# Acquisition CR Probabilities

In [None]:
def plot_acq_all_isi_from_pc_use_cr_prob(isis: np.ndarray):
    """Plot acquisition CR probabilities (from PC rasters) for every ISI.

    Probabilities come from ``cbm_pack.calc_cr_probs_from_pc`` with a window of
    ``num_avg_over`` = 10 trials; one point is drawn per window.

    Parameters
    ----------
    isis : np.ndarray
        ISIs (ms) to plot. Rasters are read from
        ``INPUT_DIR/acq_no_probe_ISI_{isi}/acq_no_probe_ISI_{isi}.pcr``.
    """
    num_avg_over = 10
    fig = plt.figure()
    fig.suptitle("Acquisition: CR Probabilities", fontsize=14)
    fig.patch.set_facecolor('white')
    ax = plt.subplot(111)
    ax.set_xlabel('trial num', fontsize=12)
    ax.set_ylabel('% CR', fontsize=12)

    for isi in isis:
        num_collect_ts = PRE_CS_COLLECT_TS + isi + POST_CS_COLLECT_TS
        # ISIs above 750 ms were trained for 1000 trials, the rest for 500; decided per ISI
        # so the result does not depend on the order of `isis`
        num_collect_trials = 1000 if isi > 750 else 500
        input_base = f"acq_no_probe_ISI_{isi}"
        input_pc_file = INPUT_DIR + input_base + "/" + input_base + extensions["pc_raster"]
        pc_rasters = cbm_pack.np_arr_from_file(input_pc_file, np.uint8)
        pc_rasters = cbm_pack.reshape_raster(
                pc_rasters,
                cell_nums["pc"],
                num_collect_trials,
                num_collect_ts)
        cr_probs = cbm_pack.calc_cr_probs_from_pc(pc_rasters,
            PRE_CS_COLLECT_TS,
            isi,
            num_avg_over)
        # presumably one probability per window of num_avg_over trials, drawn at each window's first trial
        trials_to_plot = np.arange(0, num_collect_trials, num_avg_over)
        if isi <= 1250:
            plt.plot(trials_to_plot, cr_probs, 'o-', markersize=2.0, linewidth=1.0, label=f'{isi}')
        else:
            plt.plot(trials_to_plot, cr_probs, 'o', markersize=2.0, label=f'{isi}')
    ax.legend(loc='center right', bbox_to_anchor=(1.20, 0.5))
    plt.ylim((0, 1.25))
    plt.tight_layout()
    plt.show()
    plt.close(fig)

In [None]:

# CR probabilities from PC rasters for the reference ISI (ISI / NUM_COLLECT_TRIALS from the config cell)
INPUT_PC_FILE = INPUT_DIR + INPUT_BASE + "/" + INPUT_BASE + extensions["pc_raster"]
pc_rasters = cbm_pack.np_arr_from_file(INPUT_PC_FILE, np.uint8)
pc_rasters = cbm_pack.reshape_raster(
        pc_rasters,
        cell_nums["pc"],
        NUM_COLLECT_TRIALS,
        NUM_COLLECT_TS)
cr_probs = cbm_pack.calc_cr_probs_from_pc(pc_rasters,
    PRE_CS_COLLECT_TS,
    ISI,
    10)

# Smoothed, population-mean NC firing rates. The NC debugging cells further down read
# `smooth_mean_inst_frs`, which was previously only defined in commented-out code and so
# raised NameError on a fresh kernel.
nc_rasters = cbm_pack.np_arr_from_file(INPUT_NC_FILE, np.uint8)
nc_rasters = cbm_pack.reshape_raster(
        nc_rasters,
        cell_nums["nc"],
        NUM_COLLECT_TRIALS,
        NUM_COLLECT_TS)
inst_frs = cbm_pack.calc_inst_fire_rates_from(nc_rasters)
mean_inst_frs = np.mean(inst_frs, axis=0)
smooth_mean_inst_frs = cbm_pack.calc_smooth_mean_frs(mean_inst_frs, kernel_type="gaussian")
# NC-based alternative for the CR probabilities (see "Debugging % CR" below for why PCs are used):
# cbm_pack.calc_cr_probs_from_nc(nc_rasters, PRE_CS_COLLECT_TS, ISI, 10)


In [None]:
# per-window CR probabilities for the reference ISI
print(f"{cr_probs}")

In [None]:
# CR probabilities from PCs for every ISI
plot_acq_all_isi_from_pc_use_cr_prob(isis=ISIS)

### Debugging % CR and onset times computed from NCs

- What was wrong?
    - Every trial constituted a CR because the MF has a strong influence on the CS period
      fr of the DCN neurons, and so mean, smoothed fr was reaching criterion (baseline + 10Hz) immediately
    - so onset times were essentially all at CS onset, and CR probabilities were all 1 (i.e., 100%) from the start
- conclusion:
    - for our current NC population fr distribution (wrt time), PC's are a better measure of CR %, esp given
      we have not averaged percentages over multiple runs at the same ISI yet

In [None]:
def plot_mean_smooth_fr_at_trial(smooth_fr_data, trial, isi=None, max_rate=160):
    """Plot the mean smoothed firing rate of one trial with the CS period shaded.

    Parameters
    ----------
    smooth_fr_data : array-like
        Indexed as ``smooth_fr_data[trial]`` to obtain one rate trace over time steps
        (presumably shaped (trials, ts) -- TODO confirm against the caller).
    trial : int
        Trial to plot.
    isi : int, optional
        CS-US interval (ms) used for the shaded CS region and the x-range. Defaults to
        the global ``ISI`` at call time, which was the previous hard-wired behavior.
    max_rate : float, optional
        Upper y-limit and height of the shaded CS region (default 160).
    """
    if isi is None:
        isi = ISI
    num_collect_ts = PRE_CS_COLLECT_TS + isi + POST_CS_COLLECT_TS

    fig = plt.figure()
    fig.suptitle(f"mean smoothed fr at trial {trial}", fontsize=14)
    fig.patch.set_facecolor('white')
    ax_1 = plt.subplot(111)
    # shade the CS period [PRE_CS_COLLECT_TS, PRE_CS_COLLECT_TS + isi)
    ax_1.add_patch(Rectangle((PRE_CS_COLLECT_TS, 0), isi, max_rate, color='black', alpha=0.1, edgecolor=None))

    plt.plot(smooth_fr_data[trial], color='black', linewidth=1.5)
    # trim 50 ts at either end of the trial
    plt.xlim((50, num_collect_ts - 50))
    plt.ylim((0, max_rate))
    ax_1.set_xlabel('trial time (ms)', fontsize=12)
    ax_1.set_ylabel('rate', fontsize=12)
    plt.show()
    plt.close(fig)

In [None]:
trial = 506
# baseline = middle half of the pre-CS period ("Method Sean"): 100:300 for a 400-ts pre-CS window
baseline_start, baseline_end = PRE_CS_COLLECT_TS // 4, 3 * PRE_CS_COLLECT_TS // 4
cs_start, cs_end = PRE_CS_COLLECT_TS, PRE_CS_COLLECT_TS + ISI
print(f"trial {trial} baseline fr: {np.mean(smooth_mean_inst_frs[trial, baseline_start:baseline_end])}")
print(f"trial {trial} max CS-period fr: {np.max(smooth_mean_inst_frs[trial, cs_start:cs_end])}")
plot_mean_smooth_fr_at_trial(smooth_mean_inst_frs, trial)

# Background Analysis

## Outline
- for each ISI:
    - for each block of num_background_trials and num_probe_trials:
        - compute the average CR amplitude over the block's num_probe_trials
        - keep track of the CR amplitudes computed so far
    - plot all the CR amplitudes, where the x-axis is the trial number just before the probe trials began

In [None]:
def plot_forget_all_isi_from_pc_use_cr_amp(isis: np.ndarray, num_bkgd_trials: int, num_bkgd_trials_per_b: int,
                                           pre_cs_collect_ts: int, post_cs_collect_ts: int,
                                           input_dir: str, input_base: str, num_probe_trials: int, num_probe_trials_per_b: int,
                                           num_blocks_per_session: int, max_cr_amp: int):
    """Plot block-averaged probe-trial CR amplitude over a forgetting session, one line per ISI.

    For each ISI the PC rasters in ``{input_dir}{input_base}_{isi}/{input_base}_{isi}.pcr``
    (``num_probe_trials`` trials) are converted to CRs with ``cbm_pack.pcs_to_crs``.
    Block ``b`` covers probe trials ``[b * num_probe_trials_per_b, (b + 1) * num_probe_trials_per_b)``
    and is summarized as the mean of each trial's maximum CR.

    X positions are ``np.arange(0, num_bkgd_trials, num_bkgd_trials_per_b)``; that array must
    have ``num_blocks_per_session`` entries for ``plt.plot`` to accept it.
    """
    fig = plt.figure()
    fig.patch.set_facecolor('white')
    fig.suptitle("Forgetting", fontsize=14)
    ax = plt.subplot(111)
    ax.set_xlabel('trial num', fontsize=12)
    ax.set_ylabel('eyelid closure (mm)', fontsize=12)

    block_x = np.arange(0, num_bkgd_trials, num_bkgd_trials_per_b)
    pc_ext = extensions["pc_raster"]
    for isi in isis:
        isi_base = f"{input_base}_{isi}"
        pc_file = f"{input_dir}{isi_base}/{isi_base}{pc_ext}"
        ts_per_trial = pre_cs_collect_ts + isi + post_cs_collect_ts

        raw_rasters = cbm_pack.np_arr_from_file(pc_file, np.uint8)
        rasters = cbm_pack.reshape_raster(raw_rasters, cell_nums["pc"], num_probe_trials, ts_per_trial)
        crs = cbm_pack.pcs_to_crs(rasters, pre_cs_collect_ts, post_cs_collect_ts, isi)

        # per block: peak CR of each probe trial, averaged across the block
        block_means = np.array(
            [np.mean(np.max(crs[b * num_probe_trials_per_b:(b + 1) * num_probe_trials_per_b], axis=1))
             for b in range(num_blocks_per_session)],
            dtype=float,
        )

        if isi > 1250:
            plt.plot(block_x, block_means, 'o', markersize=2.0, label=f'{isi}')
        else:
            plt.plot(block_x, block_means, 'o-', markersize=2.0, linewidth=1.0, label=f'{isi}')

    ax.legend(loc='center right', bbox_to_anchor=(1.20, 0.5))
    plt.ylim((0, max_cr_amp+1))
    plt.tight_layout()
    plt.show()
    plt.close(fig)


### Initial Run with 5000 Trials

In [None]:
# 50 blocks of 100 background trials + 50 probe trials -> 5000 background, 2500 probe trials
NUM_BLOCKS_PER_SESSION = 50
NUM_BKGD_TRIALS_PER_B = 100
NUM_PROBE_TRIALS_PER_B = 50
NUM_BKGD_TRIALS = NUM_BKGD_TRIALS_PER_B * NUM_BLOCKS_PER_SESSION
NUM_PROBE_TRIALS = NUM_PROBE_TRIALS_PER_B * NUM_BLOCKS_PER_SESSION

input_base = "forget_est_ISI"
# NUM_BKGD_TRIALS_PER_B was previously misspelled (NUM_BKGD_TRIAlS_PER_B), raising NameError
plot_forget_all_isi_from_pc_use_cr_amp(ISIS, NUM_BKGD_TRIALS, NUM_BKGD_TRIALS_PER_B,
                                       PRE_CS_COLLECT_TS, POST_CS_COLLECT_TS,
                                       INPUT_DIR, input_base, NUM_PROBE_TRIALS, NUM_PROBE_TRIALS_PER_B,
                                       NUM_BLOCKS_PER_SESSION, MAX_CR_AMP)  # run for 5k trials

### Results (from above Figure)

- ISIs 1500, 1750, and 2000 'forget' (loosely defined; not yet measured against an amplitude cutoff) within 5000 trials
- every other ISI does not
- keep in mind that 5000 trials (each lasting 5000 ms) correspond to about 6.9 hours

In [None]:
from matplotlib.colors import colorConverter
from matplotlib.collections import LineCollection
from mpl_toolkits.mplot3d import Axes3D

def waterfall_plot(crs: np.ndarray, isi: int = None, trial_step: int = None):
    """3D waterfall plot of CR traces, one trace every ``trial_step`` probe trials.

    Parameters
    ----------
    crs : np.ndarray
        2D array of shape (num_probe_trials, num_ts_per_trial).
    isi : int, optional
        ISI shown in the title; defaults to the global ``ISI`` at call time.
    trial_step : int, optional
        Stride between plotted trials; defaults to the global ``NUM_BKGD_TRIALS_PER_B``
        (the previous hard-wired behavior).
    """
    if isi is None:
        isi = ISI
    if trial_step is None:
        trial_step = NUM_BKGD_TRIALS_PER_B

    num_probe_trials, num_ts_per_trial = crs.shape
    ts = np.arange(num_ts_per_trial)
    probe_trials = np.arange(0, num_probe_trials, trial_step)
    # one (ts, cr) polyline per selected trial, stacked along the y axis by trial index
    verts = [list(zip(ts, crs[trial, :])) for trial in probe_trials]
    poly = LineCollection(verts, linewidths=0.5, edgecolor=[colorConverter.to_rgba('k', alpha=1)])

    fig = plt.figure()
    ax = fig.add_subplot(projection='3d')
    ax.add_collection3d(poly, zs=probe_trials, zdir='y')
    ax.set_xlim(50, num_ts_per_trial - 50)
    ax.set_ylim(0, num_probe_trials)
    ax.set_zlim(0, MAX_CR_AMP+0.5)
    ax.set_title(f'Waterfall plot of CRs for isi {isi}')
    ax.set_xlabel('ts (ms)')
    ax.set_ylabel('trials')
    ax.set_zlabel('eyelid closure (mm)')
    ax.view_init(elev=30, azim=-100, roll=0)
    plt.show()
    plt.close(fig)

In [None]:
# pc_forget_crs is local to plot_forget_all_isi_from_pc_use_cr_amp, so rebuild it here for ISI
waterfall_base = f"forget_est_ISI_{ISI}"
waterfall_pc_file = INPUT_DIR + waterfall_base + "/" + waterfall_base + extensions["pc_raster"]
waterfall_num_ts = PRE_CS_COLLECT_TS + ISI + POST_CS_COLLECT_TS
pc_forget_rasters = cbm_pack.reshape_raster(
        cbm_pack.np_arr_from_file(waterfall_pc_file, np.uint8),
        cell_nums["pc"],
        NUM_PROBE_TRIALS,
        waterfall_num_ts)
pc_forget_crs = cbm_pack.pcs_to_crs(pc_forget_rasters, PRE_CS_COLLECT_TS, POST_CS_COLLECT_TS, ISI)
waterfall_plot(pc_forget_crs)

### Next Run with 20000 Trials

In [None]:
# 200 blocks of 100 background trials + 50 probe trials -> 20000 background, 10000 probe trials
NUM_BLOCKS_PER_SESSION = 200
NUM_BKGD_TRIALS_PER_B = 100
NUM_PROBE_TRIALS_PER_B = 50
NUM_BKGD_TRIALS = NUM_BKGD_TRIALS_PER_B * NUM_BLOCKS_PER_SESSION
NUM_PROBE_TRIALS = NUM_PROBE_TRIALS_PER_B * NUM_BLOCKS_PER_SESSION

input_base = "forget_est_20k_ISI"
# NUM_BKGD_TRIALS_PER_B was previously misspelled (NUM_BKGD_TRIAlS_PER_B), raising NameError
plot_forget_all_isi_from_pc_use_cr_amp(ISIS, NUM_BKGD_TRIALS, NUM_BKGD_TRIALS_PER_B,
                                       PRE_CS_COLLECT_TS, POST_CS_COLLECT_TS,
                                       INPUT_DIR, input_base, NUM_PROBE_TRIALS, NUM_PROBE_TRIALS_PER_B,
                                       NUM_BLOCKS_PER_SESSION, MAX_CR_AMP)  # run for 20k trials

### Results (from above figure)

- ISIs 1000-2000 forget after 2000 trials
- every ISI below 1000 is on its way, but not completely forgotten
- 20000 trials corresponds roughly to 27 hours in real-time

## Resetting Weights

### Outline
- ran 50 probe trials for several reset weight values
- only ran for ISI 500 so far to gauge number of reset weights necessary for forgetting

In [None]:

NUM_PROBE_TRIALS = 20  # previously 50
ISI = 500
PRE_CS_COLLECT_TS = 400
POST_CS_COLLECT_TS = 400
# computed after the pre/post windows are (re)assigned so all three values agree
NUM_COLLECT_TS = PRE_CS_COLLECT_TS + ISI + POST_CS_COLLECT_TS

In [None]:
def plot_reset_weights_for_isi(isis: np.ndarray, reset_nums: np.ndarray = None, num_probe_trials: int = None):
    """Plot mean probe-trial CR amplitude against the number of reset pf-pc weights.

    For each ISI and reset count ``num``, PC rasters are read from
    ``INPUT_DIR/probe_isi_{isi}_reset_{num}_bugaloo/probe_isi_{isi}_reset_{num}_bugaloo.pcr``,
    converted to CRs with ``cbm_pack.pcs_to_crs``, reduced to each trial's maximum and
    averaged across trials. A NaN average is recorded as 0.0. Each value is also printed.

    Parameters
    ----------
    isis : np.ndarray
        ISIs (ms) to plot.
    reset_nums : np.ndarray, optional
        Reset counts to load; defaults to ``np.arange(0, 25500, 500)`` (the original sweep).
    num_probe_trials : int, optional
        Probe trials stored in each file; defaults to the global ``NUM_PROBE_TRIALS``
        at call time.
    """
    if reset_nums is None:
        # 15500 is when I get nans, so expect anything greater than that to be 0
        reset_nums = np.arange(0, 25500, 500)
    if num_probe_trials is None:
        num_probe_trials = NUM_PROBE_TRIALS

    fig = plt.figure()
    fig.suptitle("Reset Weights", fontsize=14)
    fig.patch.set_facecolor('white')
    ax = plt.subplot(111)
    ax.set_xlabel('Number Reset', fontsize=12)
    ax.set_ylabel('eyelid closure (mm)', fontsize=12)
    for isi in isis:
        num_collect_ts = PRE_CS_COLLECT_TS + isi + POST_CS_COLLECT_TS
        isi_reset_crs = []
        for num in reset_nums:
            reset_base = f"probe_isi_{isi}_reset_{num}_bugaloo"
            reset_pc_file = f"{INPUT_DIR}{reset_base}/{reset_base}{extensions['pc_raster']}"
            pc_rasters = cbm_pack.np_arr_from_file(reset_pc_file, np.uint8)
            pc_rasters = cbm_pack.reshape_raster(
                    pc_rasters,
                    cell_nums["pc"],
                    num_probe_trials,
                    num_collect_ts)
            pc_reset_crs = cbm_pack.pcs_to_crs(pc_rasters, PRE_CS_COLLECT_TS, POST_CS_COLLECT_TS, isi)

            computed_cr = np.mean(np.max(pc_reset_crs, axis=1))
            if np.isnan(computed_cr):
                computed_cr = 0.0
            isi_reset_crs.append(computed_cr)
            print(f"ISI {isi}: computed cr for {num} reset: {computed_cr}")

        plt.plot(reset_nums, isi_reset_crs, 'o', markersize=2.0, label=f'{isi}')

    ax.legend(loc='center right', bbox_to_anchor=(1.20, 0.5))
    plt.ylim((0, MAX_CR_AMP+1))
    plt.tight_layout()
    plt.show()
    plt.close(fig)

### Issues

- clearly not forgetting. Needs to be debugged

### Attempt at a Solution

- checked the distributions of weights:
  - after equilibrium
  - after acquisition
  - after reset (using Satvik's script)
- found that what was being reset appeared to be closer to the center of the peak of the distribution, rather than the tails

- re-created Satvik's script
  - visually inspected the distribution for different amounts of lowest weights reset to equilibrium
  - checked for the number that was after the lower peak (weakest learned synapses) had vanished
  - that number was just before 50000
- then ran 20 probe trials for number of reset weights in numpy.arange(0, 50500, 500)

### Second Result

- also failed to forget -> need a control
- TODO: let the overwrite file be the equilibrium weights file: if that file leads to forgetting of CRs, question implementation

### Third Result

- Success! -> check your implementation: I mixed up my flags

## Freeze Weight Analysis

### Outline

- proof of concept, so running only on ISI 500
- identified weights from resetting (for ISI 500 it was approx. the lowest 15000) were frozen
- 5000 forgetting trials were administered, with an additional 2500 probe trials
- forgetting trials and probe trials were interleaved in blocks of 100 forgetting trials and 50 probe trials

In [None]:
# freeze-weight proof of concept: ISI 500 only, 50 blocks of 100 forgetting + 50 probe trials
ISI = 500
ISIS = np.array([ISI])
NUM_BLOCKS_PER_SESSION = 50
NUM_BKGD_TRIALS_PER_B, NUM_PROBE_TRIALS_PER_B = 100, 50
NUM_BKGD_TRIALS = NUM_BLOCKS_PER_SESSION * NUM_BKGD_TRIALS_PER_B
NUM_PROBE_TRIALS = NUM_BLOCKS_PER_SESSION * NUM_PROBE_TRIALS_PER_B

In [None]:
input_base = "forget_nec_suf_ISI"
input_base_isi = f"{input_base}_{ISI}"
num_collect_ts = PRE_CS_COLLECT_TS + ISI + POST_CS_COLLECT_TS

# PSTH file is read as uint8 and stored ts-major: reshape to (ts, cells), then transpose to (cells, ts)
input_pc_file = f"{INPUT_DIR}{input_base_isi}/{input_base_isi}{extensions['pc_psth']}"
pc_psths = cbm_pack.np_arr_from_file(input_pc_file, np.uint8).reshape((num_collect_ts, cell_nums["pc"])).T
# NOTE(review): num_trials=1000.0 is hard-coded; confirm it matches the number of trials summed into the PSTH
pc_inst_frs = cbm_pack.calc_inst_fire_rates_from(pc_psths, data_type="psth", num_trials=1000.0)

In [None]:
pc_id = 31
fig, ax = plt.subplots()
fig.suptitle(f"15K Frozen Weights, PC Cell {pc_id} psth", fontsize=14)
fig.patch.set_facecolor('white')
ax.plot(pc_psths[pc_id], '-', linewidth=1.0)
ax.set_xlabel('ts (ms)', fontsize=12)
ax.set_ylabel('spike count', fontsize=12)
ax.set_ylim((0, 260))
fig.tight_layout()
plt.show()
plt.close(fig)

In [None]:
pc_id = 31
ts_start, ts_end = 800, 875
fig, ax = plt.subplots()
fig.suptitle(f"15K Frozen Weights, PC Cell {pc_id} psth, ts {ts_start}:{ts_end}", fontsize=14)
fig.patch.set_facecolor('white')
psth_window = pc_psths[pc_id][ts_start:ts_end]
# plot against the true time steps so the x-axis matches its 'ts (ms)' label
# (previously the slice was drawn at indices 0..74)
ax.plot(ts_start + np.arange(len(psth_window)), psth_window, '-', linewidth=1.0)
ax.set_xlabel('ts (ms)', fontsize=12)
ax.set_ylabel('spike count', fontsize=12)
ax.set_ylim((0, 260))
fig.tight_layout()
plt.show()
plt.close(fig)

In [None]:
input_nc_file = f"{INPUT_DIR}{input_base_isi}/{input_base_isi}{extensions['nc_psth']}"
# same on-disk layout as the PC psth: (ts, cells) -> transpose to (cells, ts)
nc_psths = cbm_pack.np_arr_from_file(input_nc_file, np.uint8).reshape((num_collect_ts, cell_nums["nc"])).T

In [None]:
nc_id = 2
fig, ax = plt.subplots()
fig.patch.set_facecolor('white')
fig.suptitle(f"15K Frozen Weights, NC Cell {nc_id} psth", fontsize=14)
ax.plot(nc_psths[nc_id], '-', linewidth=1.0)
ax.set_xlabel('ts (ms)', fontsize=12)
ax.set_ylabel('spike count', fontsize=12)
ax.set_ylim((0, 260))
fig.tight_layout()
plt.show()
plt.close(fig)

### Issues

- What you see is integer overflow: the data types of the psths are 8 bit, unsigned integer
  - we never ran into this problem because we never really ran psths with so many trials
- something else is probably going on
  - why would the pc cells even *have* that many spikes during the marked decrease in firing?
  - why do these periods look oscillatory?
  - what is the population of GR that connect to the frozen weights doing?
  - what is the population of GR that connect to the non-frozen weights doing?
  - what is the overlap in the gr cell population?

### Freeze Raster Analysis

- Verification of frozen weights not changing before and after run
- Verification of non-frozen weights changing before and after run

In [None]:
def verify_frozen_non_frozen_weights(acq_weight_path: str, post_freeze_forget_weight_path: str, mask_path: str, num_reset: int):
    """Compare PF-PC weights before and after a freeze run, split by a freeze mask.

    Loads the post-acquisition weights, the post-forgetting weights and the freeze
    mask. It then prints whether:
      * the frozen weights (mask == 0) are unchanged, and
      * the remaining (mask != 0) weights are all unchanged or at least one changed.

    Parameters
    ----------
    acq_weight_path : str
        Path to the float32 weight file saved after acquisition.
    post_freeze_forget_weight_path : str
        Path to the float32 weight file saved after the forgetting run.
    mask_path : str
        Path to the uint8 freeze mask; a value of 0 marks a frozen weight.
    num_reset : int
        Number of frozen weights; only used in the printed messages.

    Raises
    ------
    AssertionError
        If either weight file or the mask does not hold one entry per granule cell.
    """
    acq_weights = cbm_pack.np_arr_from_file(acq_weight_path, datatype=np.single)
    assert len(acq_weights) == cell_nums["gr"]

    post_freeze_forget_weights = cbm_pack.np_arr_from_file(post_freeze_forget_weight_path, datatype=np.single)
    assert len(post_freeze_forget_weights) == cell_nums["gr"]

    freeze_mask = cbm_pack.np_arr_from_file(mask_path, datatype=np.ubyte)
    # a truncated/wrong mask file would otherwise surface as an opaque boolean-index IndexError
    assert len(freeze_mask) == cell_nums["gr"], \
        f"mask has {len(freeze_mask)} entries, expected {cell_nums['gr']}"

    # Compare against 0 directly. `1 - freeze_mask` on a uint8 array wraps for values > 1
    # (1 - 2 == 255), so such entries would be counted as both frozen and non-frozen.
    frozen = freeze_mask == 0
    unfrozen = ~frozen

    # equal_nan=True mirrors np.testing.assert_array_equal, which treats matching NaNs as equal
    if np.array_equal(acq_weights[frozen], post_freeze_forget_weights[frozen], equal_nan=True):
        print(f"lowest {num_reset} weights are equal")
    else:
        print(f"lowest {num_reset} weights are not equal")

    if np.array_equal(acq_weights[unfrozen], post_freeze_forget_weights[unfrozen], equal_nan=True):
        print("all of the rest of the weights are the same")
    else:
        print("some of the rest of the weights are different")

In [None]:
# Weight files to compare: PF-PC weights saved at acquisition trial 749 and at forgetting
# trial 7499 (frozen-weight run), plus the freeze mask that was applied.
num_reset = 15000
acq_input_base = f"acq_ISI_{ISI}"
acq_weight_path = f"{INPUT_DIR}{acq_input_base}/{acq_input_base}_TRIAL_749.pfpcw"

forget_nec_suf_base = f"forget_nec_suf_ISI_{ISI}"
post_freeze_forget_weight_path = f"{INPUT_DIR}{forget_nec_suf_base}/{forget_nec_suf_base}_TRIAL_7499.pfpcw"

# Build the mask name from ISI and num_reset so it always matches the weights being compared,
# rather than hard-coding ISI 500 / 15000 independently of ISI.
freeze_mask_path = f"./Tests/freeze_isi_{ISI}_{num_reset}.mask"

verify_frozen_non_frozen_weights(acq_weight_path, post_freeze_forget_weight_path, freeze_mask_path, num_reset)

In [None]:
# 5k-trial forgetting session: 50 blocks x (100 background + 50 probe) trials.
NUM_BLOCKS_PER_SESSION = 50
NUM_BKGD_TRIALS_PER_B = 100
NUM_PROBE_TRIALS_PER_B = 50
NUM_BKGD_TRIALS = NUM_BKGD_TRIALS_PER_B * NUM_BLOCKS_PER_SESSION
NUM_PROBE_TRIALS = NUM_PROBE_TRIALS_PER_B * NUM_BLOCKS_PER_SESSION

input_base = "forget_nec_suf_ISI"
plot_forget_all_isi_from_pc_use_cr_amp(ISIS, NUM_BKGD_TRIALS, NUM_BKGD_TRIALS_PER_B,
                                       PRE_CS_COLLECT_TS, POST_CS_COLLECT_TS,
                                       INPUT_DIR, input_base, NUM_PROBE_TRIALS, NUM_PROBE_TRIALS_PER_B,
                                       NUM_BLOCKS_PER_SESSION, MAX_CR_AMP)  # run for 5k trials (50 blocks x 100)

### Issues

- 1) post-acq weights should not be the same as post-freeze weights
- 2) if the post-acq weights are the same as the post-freeze weights, then the cr amps make sense
- 3) after a little debugging, found that I wasn't giving the filestream the full path
      to the weight file -> assuming the fs object fails silently, what is the value
      of the mask? ZERO. And what meaning did we assign to ZERO? -> FREEZE. So nothing
      was updating, which explains 1) and 2) above. The psths, though, I'm still not sure about...

In [None]:
# Re-run of the frozen/non-frozen weight check after the path fix described in Issues 3) above.
# Compares PF-PC weights saved at acquisition trial 749 and at forgetting trial 7499.
num_reset = 15000
acq_input_base = f"acq_ISI_{ISI}"
acq_weight_path = f"{INPUT_DIR}{acq_input_base}/{acq_input_base}_TRIAL_749.pfpcw"

forget_nec_suf_base = f"forget_nec_suf_ISI_{ISI}"
post_freeze_forget_weight_path = f"{INPUT_DIR}{forget_nec_suf_base}/{forget_nec_suf_base}_TRIAL_7499.pfpcw"

# Build the mask name from ISI and num_reset so it always matches the weights being compared,
# rather than hard-coding ISI 500 / 15000 independently of ISI.
freeze_mask_path = f"./Tests/freeze_isi_{ISI}_{num_reset}.mask"

verify_frozen_non_frozen_weights(acq_weight_path, post_freeze_forget_weight_path, freeze_mask_path, num_reset)

In [None]:
# 5k-trial forgetting session: 50 blocks x (100 background + 50 probe) trials.
NUM_BLOCKS_PER_SESSION = 50
NUM_BKGD_TRIALS_PER_B = 100
NUM_PROBE_TRIALS_PER_B = 50
NUM_BKGD_TRIALS = NUM_BLOCKS_PER_SESSION * NUM_BKGD_TRIALS_PER_B
NUM_PROBE_TRIALS = NUM_BLOCKS_PER_SESSION * NUM_PROBE_TRIALS_PER_B

forget_nec_suf_input_base = "forget_nec_suf_ISI"
plot_forget_all_isi_from_pc_use_cr_amp(
    ISIS,
    NUM_BKGD_TRIALS,
    NUM_BKGD_TRIALS_PER_B,
    PRE_CS_COLLECT_TS,
    POST_CS_COLLECT_TS,
    INPUT_DIR,
    forget_nec_suf_input_base,
    NUM_PROBE_TRIALS,
    NUM_PROBE_TRIALS_PER_B,
    NUM_BLOCKS_PER_SESSION,
    MAX_CR_AMP,
)  # run for 5k trials

### Results

- frozen weights did not change
- non-frozen weights did change
- got a little bit of forgetting, but not nearly as much as when these weights were allowed to drift

### CR Views

- mean CRs over the 50-trial probe block administered at the end of the given forgetting (background) trial block

In [None]:
bkgd_trial_to_plot = 4900

# --- load CRs for the 5k forgetting run ---------------------------------------------------
# Converting PC rasters to CRs is slow, so results are cached per raster file for the life of
# the kernel. Keying by file keeps the 5k and 20k runs apart, instead of this cell silently
# reading whatever `pc_forget_crs` was left in the namespace (fails on a fresh kernel).
crs_input_base_isi = f"forget_nec_suf_ISI_{ISI}"
crs_input_pc_file = INPUT_DIR + crs_input_base_isi + "/" + crs_input_base_isi + extensions["pc_raster"]
pc_forget_crs_cache = globals().get("pc_forget_crs_cache", {})
if crs_input_pc_file not in pc_forget_crs_cache:
    forget_pc_rasters = cbm_pack.reshape_raster(
        cbm_pack.np_arr_from_file(crs_input_pc_file, np.uint8),
        cell_nums["pc"],
        NUM_PROBE_TRIALS,
        PRE_CS_COLLECT_TS + ISI + POST_CS_COLLECT_TS)
    pc_forget_crs_cache[crs_input_pc_file] = cbm_pack.pcs_to_crs(
        forget_pc_rasters, PRE_CS_COLLECT_TS, POST_CS_COLLECT_TS, ISI)
    del forget_pc_rasters  # raw rasters are large and not needed once CRs exist
# NOTE(review): assumes pcs_to_crs returns (probe trial, ts) — axis 0 is sliced by trial below
pc_forget_crs = pc_forget_crs_cache[crs_input_pc_file]

# --- mean probe CR for the block containing bkgd_trial_to_plot -----------------------------
# Integer block arithmetic (background block k -> probe block k, as the original proportional
# formula implied) so the window always starts on a probe-block boundary.
trial_start = (bkgd_trial_to_plot // NUM_BKGD_TRIALS_PER_B) * NUM_PROBE_TRIALS_PER_B
trial_end = trial_start + NUM_PROBE_TRIALS_PER_B
print(trial_start, trial_end)
probe_crs = pc_forget_crs[trial_start:trial_end]
mean_probe_cr_at_trial = np.mean(probe_crs, axis=0)
# 95% CI of the mean; divide by the actual trial count in case the slice is truncated
ci = 1.96 * np.std(probe_crs, axis=0) / np.sqrt(len(probe_crs))

fig, ax = plt.subplots()
fig.suptitle(f"mean probe CR after {bkgd_trial_to_plot} trials with 15000 'CR' Weights frozen", fontsize=14)
fig.patch.set_facecolor('white')
ax.set_xlabel('ts', fontsize=12)
ax.set_ylabel('eyelid closure (mm)', fontsize=12)

# x-axis from the CR length itself: the global NUM_COLLECT_TS is derived from the
# config-cell ISI and need not match this run's collection window
ts_axis = np.arange(len(mean_probe_cr_at_trial))
ax.plot(ts_axis, mean_probe_cr_at_trial, '-', color='black', alpha=.8, linewidth=1.0, label=f'{ISI}')
ax.fill_between(ts_axis, mean_probe_cr_at_trial - ci, mean_probe_cr_at_trial + ci, color='b', alpha=.1)
ax.axvline(x=PRE_CS_COLLECT_TS, color='green', linestyle='dotted', label='CS Onset')
ax.axvline(x=PRE_CS_COLLECT_TS + ISI, color='red', linestyle='dotted', label='US Onset')

## final setup stuff
ax.legend(loc='center right', bbox_to_anchor=(1.20, 0.5))
ax.set_ylim((0, MAX_CR_AMP + 1))
fig.tight_layout()
plt.show()
plt.close(fig)

### 20K Trials

In [None]:
# 20k-trial forgetting session: 200 blocks x (100 background + 50 probe) trials.
NUM_BLOCKS_PER_SESSION = 200
NUM_BKGD_TRIALS_PER_B = 100
NUM_PROBE_TRIALS_PER_B = 50
NUM_BKGD_TRIALS = NUM_BKGD_TRIALS_PER_B * NUM_BLOCKS_PER_SESSION
NUM_PROBE_TRIALS = NUM_PROBE_TRIALS_PER_B * NUM_BLOCKS_PER_SESSION

forget_nec_suf_input_base = "forget_nec_suf_20k_ISI"
plot_forget_all_isi_from_pc_use_cr_amp(ISIS, NUM_BKGD_TRIALS, NUM_BKGD_TRIALS_PER_B,
                                       PRE_CS_COLLECT_TS, POST_CS_COLLECT_TS,
                                       INPUT_DIR, forget_nec_suf_input_base, NUM_PROBE_TRIALS, NUM_PROBE_TRIALS_PER_B,
                                       NUM_BLOCKS_PER_SESSION, MAX_CR_AMP)  # run for 20k trials (200 blocks x 100)

### CRs

In [None]:
bkgd_trials_to_plot = [900, 5000, 10000, 15000, 19900]

# --- load CRs for the 20k forgetting run --------------------------------------------------
# Converting PC rasters to CRs is slow, so results are cached per raster file for the life of
# the kernel. Keying by file keeps the 5k and 20k runs apart, instead of this cell silently
# reading whatever `pc_forget_crs` was left in the namespace (fails on a fresh kernel).
crs_input_base_isi = f"forget_nec_suf_20k_ISI_{ISI}"
crs_input_pc_file = INPUT_DIR + crs_input_base_isi + "/" + crs_input_base_isi + extensions["pc_raster"]
pc_forget_crs_cache = globals().get("pc_forget_crs_cache", {})
if crs_input_pc_file not in pc_forget_crs_cache:
    forget_pc_rasters = cbm_pack.reshape_raster(
        cbm_pack.np_arr_from_file(crs_input_pc_file, np.uint8),
        cell_nums["pc"],
        NUM_PROBE_TRIALS,
        PRE_CS_COLLECT_TS + ISI + POST_CS_COLLECT_TS)
    pc_forget_crs_cache[crs_input_pc_file] = cbm_pack.pcs_to_crs(
        forget_pc_rasters, PRE_CS_COLLECT_TS, POST_CS_COLLECT_TS, ISI)
    del forget_pc_rasters  # raw rasters are large and not needed once CRs exist
# NOTE(review): assumes pcs_to_crs returns (probe trial, ts) — axis 0 is sliced by trial below
pc_forget_crs = pc_forget_crs_cache[crs_input_pc_file]

fig, ax = plt.subplots()
fig.suptitle(f"mean probe CRs at trials with 15000 'CR' Weights frozen (ISI {ISI})", fontsize=14)
fig.patch.set_facecolor('white')
ax.set_xlabel('ts', fontsize=12)
ax.set_ylabel('eyelid closure (mm)', fontsize=12)

### compute mean crs from probe trials
for trial in bkgd_trials_to_plot:
    # Integer block arithmetic (background block k -> probe block k) so each window
    # starts on a probe-block boundary.
    trial_start = (trial // NUM_BKGD_TRIALS_PER_B) * NUM_PROBE_TRIALS_PER_B
    trial_end = trial_start + NUM_PROBE_TRIALS_PER_B
    print(trial_start, trial_end)
    probe_crs = pc_forget_crs[trial_start:trial_end]
    if len(probe_crs) == 0:
        print(f"no probe trials recorded for background trial {trial}; skipping")
        continue
    mean_probe_cr_at_trial = np.mean(probe_crs, axis=0)
    # 95% CI of the mean; divide by the actual trial count in case the slice is truncated
    ci = 1.96 * np.std(probe_crs, axis=0) / np.sqrt(len(probe_crs))

    # x-axis from the CR length, not the config-derived global NUM_COLLECT_TS
    ts_axis = np.arange(len(mean_probe_cr_at_trial))
    line, = ax.plot(ts_axis, mean_probe_cr_at_trial, '-', alpha=.8, linewidth=1.0, label=f'{trial}')
    # shade each CI band in its own line's colour so the bands can be told apart
    ax.fill_between(ts_axis, mean_probe_cr_at_trial - ci, mean_probe_cr_at_trial + ci,
                    color=line.get_color(), alpha=.1)

## final setup stuff
ax.axvline(x=PRE_CS_COLLECT_TS, color='green', linestyle='dotted', label='CS Onset')
ax.axvline(x=PRE_CS_COLLECT_TS + ISI, color='red', linestyle='dotted', label='US Onset')
ax.legend(loc='center right', bbox_to_anchor=(1.30, 0.5))
ax.set_ylim((0, MAX_CR_AMP + 1))
fig.tight_layout()
plt.show()
plt.close(fig)

### Results

- Still forgetting, despite freezing lowest 15000 weights
- One thing we have not considered is the variability between runs:
  - could run a couple of reset weights trials to find a mean cut-off point

### Next Steps

- Try another Forgetting run of 20k Trials with *20000* reset weights