In [3]:
import torch
import numpy as np

from pathlib import Path
import scipy.io

In [5]:
# Load the recording for animal 01 from the project-relative data directory.
data_dir = Path('data')
mat = scipy.io.loadmat(data_dir / '01.mat')

# Spike trains of all neurons, laid out as neurons x image x trials x milliseconds.
spike_train_all = mat['resp_train']

# 'images' is presumably a MATLAB cell array (object ndarray) -- TODO confirm;
# squeeze it to 1-D and stack the per-image arrays into a single tensor.
images_all = torch.stack(list(map(torch.tensor, mat['images'].squeeze())))

In [6]:
# Keep only the well-centered channels: INDCENT marks them with 1, and the
# mask selects along the neuron axis (axis 0) of spike_train_all.
indcent = mat['INDCENT'].squeeze()
is_centered = indcent == 1
spike_train_cent = torch.as_tensor(spike_train_all[is_centered], dtype=torch.float32)

In [7]:
# Boolean masks over the image axis for each stimulus category.
n_total_images = spike_train_all.shape[1]
n_natural = 540  # images [0, 540) are natural images; images from 540 on are gratings

# Natural images alternate: even positions are small images, odd positions are big ones.
idx_small_nat_images = torch.zeros(n_total_images, dtype=torch.bool)
idx_small_nat_images[0:n_natural:2] = True

idx_big_nat_images = torch.zeros(n_total_images, dtype=torch.bool)
idx_big_nat_images[1:n_natural:2] = True

# Everything after the natural-image block is a grating.
idx_gratings = torch.zeros(n_total_images, dtype=torch.bool)
idx_gratings[n_natural:] = True

In [8]:
# Restrict spikes (image axis 1) and images (image axis 0) to the grating stimuli.
# NOTE(review): this rebinds the same names, so re-running this cell on its own
# fails (the mask length no longer matches). Re-run the loading and
# channel-filtering cells first.
spike_train_cent = spike_train_cent[:, idx_gratings]
images_all = images_all[idx_gratings]

In [9]:
# Number of grating images and the fraction of them assigned to each split.
n_images = images_all.shape[0]
train_frac, val_frac, test_frac = 0.8, 0.1, 0.1

In [10]:
# Generate shuffled image indices with a fixed seed. The unseeded global RNG
# produced a different train/val/test split on every kernel restart.
# NOTE: this yields a different split than files saved by an unseeded run;
# re-run the save cell so the saved splits match.
SPLIT_SEED = 0
split_rng = np.random.default_rng(SPLIT_SEED)
indices = split_rng.permutation(n_images)

In [11]:
# Compute split sizes. The test split is every image not assigned to train/val
# (see the slicing of `indices` below), so test_size is derived from the
# remainder. int(test_frac * n_images) undercounts it because of truncation
# (e.g. 41 instead of the actual 43).
train_size = int(train_frac * n_images)
val_size = int(val_frac * n_images)
test_size = n_images - train_size - val_size
assert test_size >= 0, "train_frac + val_frac must not exceed 1"

In [12]:
# Cut the shuffled indices into consecutive train / val / test chunks;
# the test chunk is whatever remains after train and val.
train_indices, val_indices, test_indices = np.split(
    indices, [train_size, train_size + val_size]
)

In [13]:
# One boolean mask per split over the n_images positions; True where the
# position belongs to that split.
image_positions = np.arange(n_images)
train_mask = np.isin(image_positions, train_indices)
val_mask = np.isin(image_positions, val_indices)
test_mask = np.isin(image_positions, test_indices)

In [14]:
# Build the splits. The same mask is applied to images (image axis 0) and to
# spikes (image axis 1), so every image stays paired with its responses.
split_masks = (train_mask, val_mask, test_mask)

train_images, val_images, test_images = (
    images_all[m, :, :] for m in split_masks
)
train_spikes, val_spikes, test_spikes = (
    spike_train_cent[:, m, :, :] for m in split_masks
)

In [15]:
# Persist every split tensor under data/, one .pt file each.
save_dir = Path('data')

tensors_to_save = {
    'train_images_gratings.pt': train_images,
    'val_images_gratings.pt': val_images,
    'test_images_gratings.pt': test_images,
    'train_spikes_gratings.pt': train_spikes,
    'val_spikes_gratings.pt': val_spikes,
    'test_spikes_gratings.pt': test_spikes,
}
for file_name, split_tensor in tensors_to_save.items():
    torch.save(split_tensor, save_dir / file_name)

In [4]:
# Reload the saved training split from the project-relative data directory
# (the same location the save cell writes to) instead of a machine-specific
# absolute path, so the notebook runs on any checkout.
data_dir = Path('data')
train_spikes = torch.load(data_dir / 'train_spikes_gratings.pt')
train_images = torch.load(data_dir / 'train_images_gratings.pt')

print("Train spikes shape:", train_spikes.shape)  # [26, 332, 20, 106]
print("Train images shape:", train_images.shape)  # [332, 320, 320]

# For the first few images: the neurons firing above the population mean, and
# the centroid of the non-background pixels.
for i in range(5):
    # Spike patterns for this image: [26 neurons, 20 trials, 106 timepoints]
    spikes = train_spikes[:, i, :, :]

    # Corresponding image: [320, 320]
    image = train_images[i]

    # Average firing rate per neuron, across trials and time
    firing_rates = spikes.mean(dim=(1, 2))

    print(f"\nImage {i}:")
    # The "threshold" is the mean firing rate across all neurons for this image.
    print("Active neurons (firing rate > threshold):",
          torch.where(firing_rates > firing_rates.mean())[0].tolist())

    # The top-left pixel is treated as background; every differing pixel is the stimulus.
    # NOTE(review): in the captured output the centroid is (159.0, 159.0) for
    # every image, so it carries no per-image information for these gratings.
    bar_pixels = torch.where(image != image[0, 0])
    if len(bar_pixels[0]):
        center_y = bar_pixels[0].float().mean()
        center_x = bar_pixels[1].float().mean()
        print(f"Bar center location: ({center_x:.1f}, {center_y:.1f})")

Train spikes shape: torch.Size([26, 332, 20, 106])
Train images shape: torch.Size([332, 320, 320])

Image 0:
Active neurons (firing rate > threshold): [1, 5, 11, 12, 17, 19, 22]
Bar center location: (159.0, 159.0)

Image 1:
Active neurons (firing rate > threshold): [0, 1, 5, 6, 7, 12, 16]
Bar center location: (159.0, 159.0)

Image 2:
Active neurons (firing rate > threshold): [0, 1, 5, 6, 12, 14, 16]
Bar center location: (159.0, 159.0)

Image 3:
Active neurons (firing rate > threshold): [1, 8, 11, 12, 16, 19]
Bar center location: (159.0, 159.0)

Image 4:
Active neurons (firing rate > threshold): [1, 2, 6, 11, 12]
Bar center location: (159.0, 159.0)


In [5]:
# Ground-truth images: distinct pixel values, plus a count of pixels that
# differ from the top-left (assumed background) pixel.
print("First few images:")
for img_idx in range(3):
    img = train_images[img_idx]
    print(f"\nImage {img_idx} unique values:", img.unique().tolist())

    corner_value = img[0, 0]  # assumed to be background
    n_foreground = (img != corner_value).sum()
    print("Number of non-background pixels:", n_foreground.item())

# Spike patterns: for the first two images, take the neurons whose mean rate
# exceeds the population mean. For up to two of them, list the five time bins
# with the highest trial-averaged activity.
print("\nSpike patterns:")
for img_idx in range(2):
    print(f"\nImage {img_idx}:")
    image_spikes = train_spikes[:, img_idx]  # [26 neurons, 20 trials, 106 timepoints]
    rates = image_spikes.mean(dim=(1, 2))
    above_mean = torch.nonzero(rates > rates.mean(), as_tuple=True)[0].tolist()

    for neuron in above_mean[:2]:
        trial_avg = image_spikes[neuron].float().mean(dim=0)
        top_bins = trial_avg.argsort(descending=True)[:5]
        print(f"Neuron {neuron} peak spike times:", top_bins.tolist())

First few images:

Image 0 unique values: [59, 61, 63, 65, 67, 69, 71, 74, 76, 78, 80, 83, 85, 87, 90, 92, 95, 97, 100, 102, 105, 107, 110, 112, 115, 116, 117, 120, 122, 125, 127, 130, 132, 135, 137, 140, 142, 145, 147, 149, 152, 154, 156, 158, 161, 163, 165, 167, 169, 171]
Number of non-background pixels: 1941

Image 1 unique values: [116, 174, 176, 178, 179, 181, 182, 184, 185, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197]
Number of non-background pixels: 1941

Image 2 unique values: [61, 63, 65, 67, 69, 71, 74, 76, 78, 80, 83, 85, 87, 90, 92, 95, 97, 100, 102, 105, 107, 110, 112, 115, 116, 117, 120, 122, 125, 127, 130, 132, 135, 137, 140, 142, 145, 147, 149, 152, 154, 156, 158, 161, 163, 165, 167, 169, 171, 173]
Number of non-background pixels: 1941

Spike patterns:

Image 0:
Neuron 1 peak spike times: [99, 103, 102, 101, 100]
Neuron 5 peak spike times: [104, 101, 98, 97, 59]

Image 1:
Neuron 0 peak spike times: [101, 98, 80, 3, 4]
Neuron 1 peak spike times: [4, 95, 93, 90,

In [11]:
# Reload the validation/test ground-truth images and the decoder output for the
# validation split, using project-relative paths. Each tensor gets its own
# name; a chained `test_images = val_images = ...` would overwrite val_images
# with the test images.
data_dir = Path('data')
val_images = torch.load(data_dir / 'val_images_gratings.pt')
print(val_images.shape)
decoded_images = torch.load(Path('outputs') / 'decoded_image_val.pt')
print(decoded_images.shape)
test_images = torch.load(data_dir / 'test_images_gratings.pt')
print(test_images.shape)

# Decoded validation images are expected to pair one-to-one with the
# validation ground truth (both shapes match in the captured output).
assert decoded_images.shape == val_images.shape, (decoded_images.shape, val_images.shape)

torch.Size([41, 320, 320])
torch.Size([41, 320, 320])
torch.Size([43, 320, 320])
