In [None]:
import pickle
import numpy as np
import os

# Reference mouse: pre-processed responses and the stimulus ID of each row.
ref_resps = np.load('../../data/neural/imresps.npy')
ref_stims = np.load('../../data/neural/stimids.npy')
print(f'[ref mouse], resps shape: {ref_resps.shape}, stims shape: {ref_stims.shape}')

# Raw recordings for the other mice, keyed by their mouse/day ID.
other_resps = {}
other_stims = {}
mouse_ids = ('m01-d2', 'm02-d3', 'm03-d4')
for mouse_id in mouse_ids:
    pickle_path = f'../../data/natimg-data-{mouse_id.upper()}.pickle'
    # NOTE(review): pickle.load can execute arbitrary code; only load files from a trusted source.
    with open(pickle_path, 'rb') as handle:
        record = pickle.load(handle)
    trial_resps, trial_stims = record['resps'], record['stims']
    other_resps[mouse_id], other_stims[mouse_id] = trial_resps, trial_stims
    print(f'[natimg-data-{mouse_id.upper()}], resps shape: {trial_resps.shape}, stims shape: {trial_stims.shape}')

# Ref mouse is m03-d4, both have 15363 neurons, but different number of images, because ref has been filtered and reshaped to account for repeated images

[ref mouse], resps shape: (1573, 2, 15363), stims shape: (1573,)
[natimg-data-M01-D2], resps shape: (3445, 17842), stims shape: (3445,)
[natimg-data-M02-D3], resps shape: (3445, 20387), stims shape: (3445,)
[natimg-data-M03-D4], resps shape: (3445, 15363), stims shape: (3445,)


In [None]:
# Rebuild the raw m03-d4 trials into the reference layout (n_images, 2, n_neurons)
# and check the result against the reference mouse.
ref_ids = np.unique(ref_stims)
stims = other_stims['m03-d4']
resps = other_resps['m03-d4']

# Stimulus IDs shown exactly twice (np.unique returns them sorted). The ref-set
# membership filter is applied further down, after reporting the IDs it removes.
unique, counts = np.unique(stims, return_counts=True)
valid_stims = unique[counts == 2]
print(valid_stims.shape)  # still includes twice-shown IDs that the ref set lacks

# Twice-shown IDs absent from the ref set: list them with their trial indices.
extra_ids = np.setdiff1d(valid_stims, ref_ids)
missing = extra_ids
print("Extra image IDs in m03-d4 but not in ref set:", extra_ids)
for stim in missing:
    idxs = np.where(stims == stim)[0]
    print(f"Stim ID {stim}: trial indices {idxs}, shapes {[resps[i].shape for i in idxs]}")

# Drop those extra IDs; boolean masking keeps valid_stims sorted.
valid_stims = valid_stims[np.isin(valid_stims, ref_ids)]
assert valid_stims.size > 0, "no twice-shown m03-d4 stimuli overlap the ref set"

# Reconstruct (n_images, 2, n_neurons). np.where yields ascending trial indices, so
# repeat 0 is always the earlier presentation of each stimulus.
reshaped_resps = np.stack([resps[np.where(stims == stim)[0]] for stim in valid_stims])
print(reshaped_resps.shape)  # expected to match ref_resps.shape

# Check contents of reshaped_resps are the same as ref_resps.
# Rows of reshaped_resps follow valid_stims (sorted, already restricted to the ref set),
# so this is True only if ref_stims lists exactly the same IDs in sorted order. Comparing
# the full arrays (not a prefix of ref_stims) also catches a length mismatch.
reshaped_ids = valid_stims
print(np.array_equal(ref_stims, reshaped_ids))

# Check contents match. np.allclose raises on non-broadcastable shapes, so only
# compare values once the shapes are known to agree.
same_shape = ref_resps.shape == reshaped_resps.shape
same_values = same_shape and np.allclose(ref_resps, reshaped_resps)

print(f"Same shape? {same_shape}")
print(f"Same values (within tolerance)? {same_values}")

# Flatten stimuli x repeats into a single trial axis: (n_trials, n_neurons).
ref_flat = ref_resps.reshape(-1, ref_resps.shape[-1])
new_flat = reshaped_resps.reshape(-1, reshaped_resps.shape[-1])

(1585,)
Extra image IDs in m03-d4 but not in ref set: [  42.   43.   96.  251.  508.  783.  876.  972. 1171. 1218. 1380. 1587.]
Stim ID 42.0: trial indices [1366 3402], shapes [(15363,), (15363,)]
Stim ID 43.0: trial indices [1369 3153], shapes [(15363,), (15363,)]
Stim ID 96.0: trial indices [1363 2509], shapes [(15363,), (15363,)]
Stim ID 251.0: trial indices [1370 2369], shapes [(15363,), (15363,)]
Stim ID 508.0: trial indices [1374 2497], shapes [(15363,), (15363,)]
Stim ID 783.0: trial indices [1365 2187], shapes [(15363,), (15363,)]
Stim ID 876.0: trial indices [1372 2284], shapes [(15363,), (15363,)]
Stim ID 972.0: trial indices [1361 3235], shapes [(15363,), (15363,)]
Stim ID 1171.0: trial indices [1362 2352], shapes [(15363,), (15363,)]
Stim ID 1218.0: trial indices [1368 2647], shapes [(15363,), (15363,)]
Stim ID 1380.0: trial indices [1371 3120], shapes [(15363,), (15363,)]
Stim ID 1587.0: trial indices [1373 2194], shapes [(15363,), (15363,)]
(1573, 2, 15363)
True
Same shap

In [80]:
# Per-neuron agreement between ref and rebuilt responses: for each neuron (column),
# the Spearman correlation across all trials (stimuli x repeats).
from scipy.stats import spearmanr

corrs = np.array([spearmanr(ref_col, new_col).correlation
                  for ref_col, new_col in zip(ref_flat.T, new_flat.T)])
# spearmanr returns NaN for a neuron whose responses are constant (e.g. all zero);
# average over the defined correlations so a single such neuron can't turn the mean into NaN.
n_undefined = int(np.isnan(corrs).sum())
print(f"Mean neuron-wise Spearman correlation: {np.nanmean(corrs):.3f}")
if n_undefined:
    print(f"  ({n_undefined} neurons with constant responses excluded from the mean)")

# Per-neuron mean and std before and after
ref_mean = ref_resps.mean(axis=(0, 1))
new_mean = reshaped_resps.mean(axis=(0, 1))
print("Mean of per-neuron means (ref vs new):", ref_mean.mean(), new_mean.mean())

ref_std = ref_resps.std(axis=(0, 1))
new_std = reshaped_resps.std(axis=(0, 1))
print("Mean of per-neuron stds (ref vs new):", ref_std.mean(), new_std.mean())

Mean neuron-wise Spearman correlation: 0.667
Mean of per-neuron means (ref vs new): 0.3331099231207367 0.026415929451672003
Mean of per-neuron stds (ref vs new): 0.6021887032253692 0.45298689256960945


In [None]:
# Scale each neuron of the rebuilt responses to unit std (over stimuli and repeats) and
# re-compare with the ref. Builds a new array instead of dividing reshaped_resps in
# place, so re-running this cell (or the per-neuron stats cell) gives the same results.
neuron_std = reshaped_resps.std(axis=(0, 1), keepdims=True)
# A neuron with zero std would become NaN under division; leave such neurons unscaled.
neuron_std = np.where(neuron_std > 0, neuron_std, 1.0)
reshaped_resps_norm = reshaped_resps / neuron_std

# Check contents match. np.allclose raises on non-broadcastable shapes, so only
# compare values once the shapes are known to agree.
same_shape = ref_resps.shape == reshaped_resps_norm.shape
same_values = same_shape and np.allclose(ref_resps, reshaped_resps_norm)

print(f"Same shape? {same_shape}")
print(f"Same values (within tolerance)? {same_values}")

Same shape? True
Same values (within tolerance)? False
Neuron 0:
Ref mouse response: [0. 0.]
Other mouse response: [-0.2487094  -0.69744443]

Neuron 1:
Ref mouse response: [0.03512369 0.35232964]
Other mouse response: [-0.45895129  1.64350555]

Neuron 2:
Ref mouse response: [0. 0.]
Other mouse response: [-0.36489253 -0.68108141]

Neuron 3:
Ref mouse response: [0.07362656 0.96354928]
Other mouse response: [-0.27635827  0.32450405]

Neuron 4:
Ref mouse response: [2.89037304 0.        ]
Other mouse response: [ 2.67573886 -0.29235265]



In [None]:
# Diagnostics for the twice-shown m03-d4 stimuli that are absent from the ref set:
# check each repeat for NaN/Inf and compare the mean response of the two repeats.
stims = other_stims['m03-d4']
resps = other_resps['m03-d4']

# Derive the absent IDs from the data rather than hard-coding them, so the list stays
# correct if the inputs change.
stim_ids, n_shown = np.unique(stims, return_counts=True)
missing = np.setdiff1d(stim_ids[n_shown == 2], ref_stims)

for stim in missing:
    idxs = np.where(stims == stim)[0]
    r1, r2 = resps[idxs[0]], resps[idxs[1]]
    nan1, nan2 = np.isnan(r1).any(), np.isnan(r2).any()
    inf1, inf2 = np.isinf(r1).any(), np.isinf(r2).any()
    mean1, mean2 = np.mean(r1), np.mean(r2)
    print(f"Stim {stim:.0f}: NaNs? {nan1 or nan2}, Infs? {inf1 or inf2}, Means: ({mean1:.2f}, {mean2:.2f})")

Stim 42: NaNs? False, Infs? False, Means: (-0.20, 0.00)
Stim 43: NaNs? False, Infs? False, Means: (-0.20, 0.01)
Stim 96: NaNs? False, Infs? False, Means: (-0.20, -0.02)
Stim 251: NaNs? False, Infs? False, Means: (-0.20, 0.07)
Stim 508: NaNs? False, Infs? False, Means: (-0.20, 0.01)
Stim 783: NaNs? False, Infs? False, Means: (-0.20, -0.02)
Stim 876: NaNs? False, Infs? False, Means: (-0.20, 0.09)
Stim 972: NaNs? False, Infs? False, Means: (-0.20, -0.01)
Stim 1171: NaNs? False, Infs? False, Means: (-0.20, 0.02)
Stim 1218: NaNs? False, Infs? False, Means: (-0.20, 0.00)
Stim 1380: NaNs? False, Infs? False, Means: (-0.20, -0.05)
Stim 1587: NaNs? False, Infs? False, Means: (-0.20, 0.07)
Sampled included stimuli:
Stim 39: NaNs? False, Infs? False, Means: (0.07, -0.01)
Stim 401: NaNs? False, Infs? False, Means: (0.04, -0.02)
Stim 1496: NaNs? False, Infs? False, Means: (0.03, 0.08)
Stim 1812: NaNs? False, Infs? False, Means: (0.05, 0.01)
Stim 859: NaNs? False, Infs? False, Means: (0.01, 0.04)


In [76]:
# Baseline: the same diagnostics for a random sample of stimuli the ref set DID keep.
stims = other_stims['m03-d4']
resps = other_resps['m03-d4']

# Twice-shown IDs that are in the ref set. No separate exclusion of the absent IDs is
# needed: any ID in ref_ids is by definition not one of them.
ref_ids = np.unique(ref_stims)
twice_ids, twice_counts = np.unique(stims, return_counts=True)
twice_ids = twice_ids[twice_counts == 2]
included_ids = twice_ids[np.isin(twice_ids, ref_ids)]

# Seeded generator so the sample, and therefore the printed output, is reproducible.
rng = np.random.default_rng(0)
sampled_ids = rng.choice(included_ids, size=5, replace=False)

# Print same stats for comparison
print("Sampled included stimuli:")
for stim in sampled_ids:
    idxs = np.where(stims == stim)[0]
    r1, r2 = resps[idxs[0]], resps[idxs[1]]
    nan1, nan2 = np.isnan(r1).any(), np.isnan(r2).any()
    inf1, inf2 = np.isinf(r1).any(), np.isinf(r2).any()
    mean1, mean2 = np.mean(r1), np.mean(r2)
    print(f"Stim {stim:.0f}: NaNs? {nan1 or nan2}, Infs? {inf1 or inf2}, Means: ({mean1:.2f}, {mean2:.2f})")

Sampled included stimuli:
Stim 1607: NaNs? False, Infs? False, Means: (-0.03, 0.00)
Stim 1103: NaNs? False, Infs? False, Means: (0.07, 0.04)
Stim 1132: NaNs? False, Infs? False, Means: (0.02, -0.01)
Stim 1036: NaNs? False, Infs? False, Means: (0.06, -0.02)
Stim 1439: NaNs? False, Infs? False, Means: (0.02, -0.05)


In [78]:
# Check whether every reference stimulus ID was shown to each of the other mice.
# Loop variables use their own names so they don't overwrite the notebook-wide
# `stims` / `mouse_id` that other cells rely on.
for mid, mouse_stims in other_stims.items():
    covered = np.isin(ref_stims, mouse_stims)
    print(f'Stims for {mid} contains all ref stims: {covered.all()}')
    if not covered.all():
        print(f'  {int((~covered).sum())} of {covered.size} ref stims missing')

Stims for m01-d2 contains all ref stims: False
Stims for m02-d3 contains all ref stims: True
Stims for m03-d4 contains all ref stims: True
