In [332]:
### Analysis steps

# 1. For each neuron, calculate the signal-related variance (SRV): the fraction of its variance explained by the stimulus across repeats. Plot the SRV distribution for all neurons.

# 2. Shuffle responses (e.g., 100 random shuffles) to compute a baseline (null) SRV distribution.

# 3. Keep neurons whose SRV is at or above the 90th percentile of the shuffle distribution.

# 4. Apply PCA to SimCLR representations and/or raw image data

# 5. Choose a subset of principal components that explain a large proportion of the variance.

# 6. Ensure the number of features is not vastly greater than the number of neurons to reduce overfitting.

# 7. 80% training, 10% validation, 10% testing

# 8. Training: Learn weights for regression

# 9. Validation: Optimize the regularization parameter

# 10. Testing: Evaluate the final model

# 11. Predict the response of each neuron using SimCLR features as input.

# 12. Train one model per neuron (input: SimCLR features, e.g., 512 features for the final layer; output: the neural response, a scalar value for that neuron).

# 13. Aggregate results to evaluate overall prediction accuracy

In [None]:
### Load data

import numpy as np
from os import path

# imresps.npy holds an array of shape (1573, 2, 15363):
# 1573 images, 2 repeats per image, 15363 recorded neurons.
# stimids.npy holds, for each stimulus number, the id of the image that was shown
# (matching the image dataset ~selection1866~), so the image presented for
# imresps[502] is the one identified by stimids[502].

PATH_TO_DATA = '../../data/neural'

imresps, stimids = (
    np.load(path.join(PATH_TO_DATA, 'imresps.npy')),
    np.load(path.join(PATH_TO_DATA, 'stimids.npy')),
)

# Sanity-print the loaded shapes
print(imresps.shape)  # expected: (1573, 2, 15363)
print(stimids.shape)  # expected: (1573,)

In [334]:
def compute_signal_related_variance(resp_a, resp_b, mean_center=True):
    """
    Compute the fraction of signal-related variance (SRV) for each neuron,
    as per Stringer et al Nature 2019. Cross-validated by splitting
    responses into two halves. Note, this only is "correct" if resp_a
    and resp_b are *not* averages of many trials.

    Args:
        resp_a (ndarray): n_stimuli, n_cells (extra leading stimulus
            dimensions are flattened into one)
        resp_b (ndarray): same shape as resp_a; row i must be the response
            to the same stimulus as row i of resp_a
        mean_center (bool): subtract each cell's mean (per half) before
            computing the covariance and variances

    Returns:
        fraction_of_stimulus_variance: (n_cells,) 0 is non-stimulus-caring,
            1 is only-stimulus-caring neurons. Noise can push the estimate
            below 0. NaN for cells with zero total variance, where the ratio
            is undefined.
        stim_to_noise_ratio: (n_cells,) ratio of the stim-related variance to
            all other variance; inf where the fraction is exactly 1, NaN
            where the fraction is NaN.

    Raises:
        ValueError: if the (flattened) halves are not 2-D or their shapes
            differ.
    """
    if len(resp_a.shape) > 2:
        # if the stimulus is multi-dimensional, flatten across all stimuli
        resp_a = resp_a.reshape(-1, resp_a.shape[-1])
        resp_b = resp_b.reshape(-1, resp_b.shape[-1])

    # Reject mismatched halves explicitly: e.g. (n, 1) vs (n, c) would
    # otherwise broadcast silently and produce meaningless per-cell values.
    if len(resp_a.shape) != 2 or tuple(resp_a.shape) != tuple(resp_b.shape):
        raise ValueError(
            "resp_a and resp_b must both be (n_stimuli, n_cells) with identical "
            f"shapes, got {tuple(resp_a.shape)} and {tuple(resp_b.shape)}"
        )
    ns = resp_a.shape[0]

    if mean_center:
        # mean-center the activity of each cell
        resp_a = resp_a - resp_a.mean(axis=0)
        resp_b = resp_b - resp_b.mean(axis=0)

    # compute the cross-trial stimulus covariance of each cell:
    # dot-product each cell's (n_stim, ) vector from one half
    # with its own (n_stim, ) vector on the other half
    covariance = (resp_a * resp_b).sum(axis=0) / ns

    # compute the variance of each cell, averaged across both halves
    resp_a_variance = (resp_a**2).sum(axis=0) / ns
    resp_b_variance = (resp_b**2).sum(axis=0) / ns
    total_variance = (resp_a_variance + resp_b_variance) / 2

    # The two ratios below are legitimately undefined/infinite for degenerate
    # cells (zero variance -> 0/0, SRV == 1 -> x/0). Keep the NaN/inf results
    # but do not emit floating-point warnings for these documented cases.
    with np.errstate(divide="ignore", invalid="ignore"):
        # fraction of the total variance that is captured in the covariance
        fraction_of_stimulus_variance = covariance / total_variance

        # signal-to-noise ratio derived from the same fraction
        stim_to_noise_ratio = fraction_of_stimulus_variance / (
            1 - fraction_of_stimulus_variance
        )

    return fraction_of_stimulus_variance, stim_to_noise_ratio

In [None]:
### Compute SRV and filter most reliable neurons

import matplotlib.pyplot as plt

# For each stimulus, randomly assign its two repeats to split A or split B.
# A dedicated seeded generator makes the split -- and therefore the set of
# selected neurons -- identical every time the notebook is re-run.
split_rng = np.random.default_rng(0)
swap_repeats = split_rng.integers(0, 2, size=imresps.shape[0]).astype(bool)[:, None]
# Where swap_repeats is True, repeat 1 goes to split A and repeat 0 to split B;
# otherwise repeat 0 goes to split A and repeat 1 to split B.
split_A = np.where(swap_repeats, imresps[:, 1, :], imresps[:, 0, :])
split_B = np.where(swap_repeats, imresps[:, 0, :], imresps[:, 1, :])

fraction_of_stimulus_variance, stim_to_noise_ratio = compute_signal_related_variance(split_A, split_B)
print(fraction_of_stimulus_variance)

print('Image responses:', imresps.shape) # (1573, 2, 15363)
print('Image responses split A:', split_A.shape) # (1573, 15363)
print('Image responses split B:', split_B.shape) # (1573, 15363)
print('SRV:', fraction_of_stimulus_variance.shape) # (15363,)
print('SNR:', stim_to_noise_ratio.shape) # (15363,)

# Plot SRV distribution
plt.hist(fraction_of_stimulus_variance, bins=50, color='blue', alpha=0.7)
plt.xlabel("Fraction of Stimulus-Related Variance (SRV)")
plt.ylabel("Number of Neurons")
plt.title("Distribution of SRV Across Neurons")
plt.show()

# Keep the neurons whose SRV is at or above the 90th percentile of the SRV
# distribution across neurons (i.e. the ~10% most reliable neurons).
top_90th_percentile = np.percentile(fraction_of_stimulus_variance, 90)

# reliable_neurons holds integer positions into the neuron axis of imresps /
# fraction_of_stimulus_variance, so the link to the original data is preserved.
reliable_neurons = np.where(fraction_of_stimulus_variance >= top_90th_percentile)[0]

print('Filtered neurons in top 90th percentile:', len(reliable_neurons)) # 1537

# SRV distribution restricted to the selected neurons
srv_reliable = fraction_of_stimulus_variance[reliable_neurons]
plt.hist(srv_reliable, bins=50, color='blue', alpha=0.7)
plt.xlabel("Fraction of Stimulus-Related Variance (SRV)")
plt.ylabel("Number of Neurons")
plt.title("Distribution of SRV Across Reliable (>90th percentile) Neurons")
plt.show()

In [None]:
### Compute a null distribution of SRV values

# Image responses: (1573, 2, 15363)
num_stimuli = imresps.shape[0] # 1573
num_repeats = imresps.shape[1] # 2
num_neurons = imresps.shape[2] # 15363
n_shuffles = 100
srv_null_distribution = []  # one (num_neurons,) SRV array per shuffle

# Swapping the two repeats *within* a stimulus keeps every stimulus paired with
# itself, so the SRV barely changes and does not form a null distribution.
# Instead, keep repeat 0 in its original order and permute the stimulus order
# of repeat 1: this breaks the stimulus correspondence between the two halves
# while preserving each neuron's response distribution.
shuffle_rng = np.random.default_rng(0)  # seeded so the null is reproducible
first_repeat = imresps[:, 0, :]
for _ in range(n_shuffles):
    shuffled_order = shuffle_rng.permutation(num_stimuli)
    second_repeat_shuffled = imresps[shuffled_order, 1, :]

    # Use dedicated names so the real SRV / real splits computed earlier
    # are not overwritten by shuffled values.
    null_srv, _ = compute_signal_related_variance(first_repeat, second_repeat_shuffled)
    srv_null_distribution.append(null_srv)

# Plot the null distribution of a SINGLE NEURON across all shuffles
neuron_index = 0
plt.hist([srv[neuron_index] for srv in srv_null_distribution], bins=100, color='blue', alpha=0.7)
plt.xlabel("Fraction of Stimulus-Related Variance (SRV)")
plt.ylabel("Number of Shuffles")
plt.title("Null Distribution of SRV Across Shuffles")
plt.show()

In [None]:
### Compute null distribution for all neurons, and plot for a single neuron

import matplotlib.pyplot as plt

# 1. First compute the real SRV for each neuron

# For each stimulus, randomly assign its two repeats to split A or split B.
# The generator is seeded so the real SRV, the null distribution and the
# resulting neuron selection are reproducible across runs.
srv_rng = np.random.default_rng(0)
swap_repeats = srv_rng.integers(0, 2, size=imresps.shape[0]).astype(bool)[:, None]
split_A = np.where(swap_repeats, imresps[:, 1, :], imresps[:, 0, :])
split_B = np.where(swap_repeats, imresps[:, 0, :], imresps[:, 1, :])

real_srv_all_neurons, stim_to_noise_ratio = compute_signal_related_variance(split_A, split_B)
print(real_srv_all_neurons)

print('Image responses:', imresps.shape) # (1573, 2, 15363)
print('Image responses split A:', split_A.shape) # (1573, 15363)
print('Image responses split B:', split_B.shape) # (1573, 15363)
print('SRV:', real_srv_all_neurons.shape) # (15363,)
print('SNR:', stim_to_noise_ratio.shape) # (15363,)

# Plot SRV distribution
plt.hist(real_srv_all_neurons, bins=50, color='blue', alpha=0.7)
plt.xlabel("Fraction of Stimulus-Related Variance (SRV)")
plt.ylabel("Number of Neurons")
plt.title("Real SRV: SRV Across Neurons")
plt.show()

# 2. Compute null distribution of SRV values for all neurons
# Image responses: (1573, 2, 15363)
num_stimuli = imresps.shape[0] # 1573
num_repeats = imresps.shape[1] # 2
num_neurons = imresps.shape[2] # 15363
n_shuffles = 100

# SRV is a sum over stimuli, so applying the SAME permutation to both repeats
# leaves it unchanged. To get a real null, keep repeat 0 in its original order
# and permute the stimulus order of repeat 1 only, which breaks the stimulus
# correspondence between the two halves.
first_repeat = imresps[:, 0, :]
null_srv_per_shuffle = []
for _ in range(n_shuffles):
    shuffled_indices = srv_rng.permutation(num_stimuli)
    second_repeat_shuffled = imresps[shuffled_indices, 1, :]

    # SRV of the mismatched halves for each neuron - shape (15363,).
    # Dedicated names keep the real split_A / split_B and any previously
    # computed real SRV from being overwritten by shuffled values.
    null_srv, _ = compute_signal_related_variance(first_repeat, second_repeat_shuffled)
    null_srv_per_shuffle.append(null_srv)

# Stack into an array for easier indexing
# shape (n_shuffles, num_neurons) - (100, 15363) - each value is the SRV for a neuron in a shuffle
null_srv_all_neurons = np.array(null_srv_per_shuffle)
assert null_srv_all_neurons.shape == (n_shuffles, num_neurons)
print(null_srv_all_neurons.shape)

# e.g. if neuron_index = 0, it will plot the null SRV values of neuron 0 across all shuffles
neuron_index = 0
plt.hist(null_srv_all_neurons[:, neuron_index], bins=100, color='blue', alpha=0.7)
plt.xlabel("Fraction of Stimulus-Related Variance (SRV)")
plt.ylabel("Number of Shuffles")
plt.title(f"Null Distribution of SRV for Neuron {neuron_index}")
plt.show()

# 3. Now keep the neurons whose real SRV is at or above the 90th percentile of their own null distribution

# Per-neuron threshold: 90th percentile across shuffles, shape (num_neurons,)
top_90th_percentile_null = np.percentile(null_srv_all_neurons, 90, axis=0)

# reliable_neurons contains the indices (into the neuron axis) of neurons whose
# real SRV reaches that per-neuron null threshold
reliable_neurons = np.where(real_srv_all_neurons >= top_90th_percentile_null)[0]
print('Filtered neurons in top 90th percentile of null distribution:', len(reliable_neurons))