In [6]:
import numpy as np
import matplotlib.pyplot as plt
# Seeded NumPy random generator (seed 42) so that draws made through `rng` are reproducible.
# NOTE(review): surrogate_ranks draws from the global np.random state, not from this generator.
rng = np.random.default_rng(42)

In [2]:
def MI_decrease(ranks, neurons):
    """
    Ranks neurons by how much the MI drops when each one is pruned on its own.
    Input: 
    - ranks           : the ranks of all the trials
    - neurons         : a list of all neurons

    Output:
    - MIs_sorted      : the MI decrease (baseline MI minus MI after pruning) per neuron,
                        ordered from the lowest remaining MI (largest decrease) to the highest
    - neurons_sorted  : the neurons, in the same order as MIs_sorted
    """
    # Reference value: MI of the unpruned ranks
    reference_MI = MI(firing_ranks2occ_matrix(ranks))

    # MI that remains after removing one neuron at a time, always starting from the original ranks
    remaining_MIs = np.array(
        [cut_ranks(ranks, candidate)[1] for candidate in neurons], dtype=float
    )

    # A low remaining MI means a large decrease, so ascending order puts the neurons
    # that are most likely to be part of the sequence first
    decreases = reference_MI - np.sort(remaining_MIs)
    ordered_neurons = np.array(neurons)[np.argsort(remaining_MIs)]

    return decreases, ordered_neurons

def construct_sequence(ranks, n_neurons, n_surrogates, threshold=0.05, mode="max", max_neurons=20): 
    """
    Constructs the sequence, by filtering out insignificant neurons and subsequently ordering on either the mean or mode rank.
    Input:
    - ranks          :  the ranks of all the trials.
    - n_neurons      :  the total number of neurons in the data.
    - n_surrogates   :  the number of surrogates used for testing significance of individual neurons
                        (must be at least 1, it is used as a divisor).
    - threshold      :  value between 0 and 1. used for significance testing and can be interpreted as the p-value.
    - mode           :  "mean" orders on the mean rank; any other value (default "max") orders on the
                        rank with the highest occupancy.
    - max_neurons    :  the maximum sequence length. Values larger than n_neurons are clamped to n_neurons.

    Output:
    - sequence             :  an array containing the ordered sequence
    """

    # Never test more candidates than there are neurons; otherwise the per-candidate
    # indexing below would run past the end of the sorted arrays.
    n_candidates = min(max_neurons, n_neurons)

    # Find MIs with the largest decrease after pruning the neurons individually
    MIs, seq_neurons = MI_decrease(ranks, range(n_neurons))
    MIs = MIs[:n_candidates]
    seq_neurons = seq_neurons[:n_candidates]

    MIs_surrogate = np.zeros([n_surrogates, n_candidates])

    for s in range(n_surrogates):
        for ineuron, neuron in enumerate(seq_neurons):
            # Shuffle only this neuron's spikes, all other neurons keep their order
            surrogate = surrogate_ranks(ranks, neuron)

            # Compute the decrease in MI after pruning the neuron from the surrogate
            MI_before = MI(firing_ranks2occ_matrix(surrogate))
            MI_after = cut_ranks(surrogate, neuron)[1]
            MIs_surrogate[s, ineuron] = MI_before - MI_after

    # A neuron is insignificant when the fraction of surrogates that reach at least
    # the observed MI decrease exceeds the threshold
    insignificant = np.array(
        [
            np.count_nonzero(MIs_surrogate[:, ineuron] >= MIs[ineuron]) / n_surrogates > threshold
            for ineuron in range(n_candidates)
        ],
        dtype=bool,
    )
    significant_idx = np.flatnonzero(~insignificant)

    if significant_idx.size > 0:
        seq_neurons = seq_neurons[significant_idx]
    else:
        print("NO SIGNIFICANT NEURONS FOUND, 2 NEURONS WITH THE HIGHEST MI DECREASE ARE USED.")
        # Makes sure that we find a sequence of at least 2 neurons, even if nothing was tested as significant.
        seq_neurons = seq_neurons[:2]

    occ_matrix = firing_ranks2occ_matrix(ranks)

    # Compute mean and maximum ranks
    mean_ranks = np.zeros(len(seq_neurons))
    max_ranks = np.zeros(len(seq_neurons))
    rank_indices = np.arange(occ_matrix.shape[1])
    for ineuron, neuron in enumerate(seq_neurons):
        # Occupancy-weighted rank index, averaged over all rank columns
        mean_ranks[ineuron] = np.mean(occ_matrix[neuron, :] * rank_indices)
        # Rank column with the highest occupancy
        max_ranks[ineuron] = np.argmax(occ_matrix[neuron, :])

    # Sort the sequence neurons
    if mode == "mean":
        sequence = np.array(seq_neurons)[np.argsort(mean_ranks)]
    else:
        sequence = np.array(seq_neurons)[np.argsort(max_ranks)]

    return sequence

def cut_ranks(ranks, neuron):
    """
    Cuts all entries of the neuron out of the ranks. Also, the MI after cutting is outputted.
    Input:
    - ranks         : 2D array with one row per trial
    - neuron        : the neuron to remove from every trial

    Output:
    - pruned_ranks  : float array with the same shape as ranks; in each trial the remaining
                      entries keep their order and shift forward, the freed positions at the
                      end are filled with NaN
    - MI            : the MI of the occurrence matrix built from pruned_ranks
    """
    ranks = np.asarray(ranks)

    # Start from an all-NaN array so that every freed position is padding by construction
    pruned_ranks = np.full([ranks.shape[0], ranks.shape[1]], np.nan)
    for itrial in range(ranks.shape[0]):
        trial = ranks[itrial]
        # Drop every occurrence of the neuron, not just the first one: a neuron can
        # appear several times in one trial. NaN entries compare unequal and are kept.
        kept = trial[trial != neuron]
        pruned_ranks[itrial, :len(kept)] = kept
    return pruned_ranks, MI(firing_ranks2occ_matrix(pruned_ranks))

def surrogate_ranks(firing_ranks, neuron): 
    """
    Surrogate ranks are created by randomly shuffling the ranks of all spikes of the given neuron. 
    The sequential information of all other neurons remain untouched.
    Input:
    - firing_ranks  : 2D array with one row per trial; NaN entries are treated as empty
                      padding (assumed to sit at the end of each trial)
    - neuron        : the neuron whose spikes are re-inserted at random ranks

    Output:
    - ranks         : float array with the same shape as firing_ranks. Trials that do not
                      contain the neuron are copied unchanged.

    Random ranks are drawn from the global np.random state.
    """
    ranks = np.zeros([firing_ranks.shape[0], firing_ranks.shape[1]])

    for itrial, trial in enumerate(firing_ranks):
        is_spike = trial == neuron
        n_spikes = int(np.count_nonzero(is_spike))
        if n_spikes == 0:
            ranks[itrial, :] = trial
            continue

        # Remove all spikes of the neuron; everything else keeps its relative order
        others = trial[~is_spike]
        trial_list = list(others)

        # Number of non-NaN entries currently in the working list. The insertion slots
        # must be counted on this list (not on the original trial, which still contains
        # the removed spikes), otherwise a spike can land behind the NaN padding or the
        # last rank gets over-sampled.
        n_valid = int(np.count_nonzero(~np.isnan(others)))

        # Insert each spike on a random rank in the trial
        for _ in range(n_spikes):
            # n_valid entries offer n_valid + 1 slots (before, between and after them)
            new_rank = np.random.randint(n_valid + 1)
            trial_list.insert(new_rank, neuron)
            n_valid += 1
        ranks[itrial, :] = np.array(trial_list)
    return ranks