## Initial codebase for implementing a 1D convolutional network on mouse wake+sleep spiking data

## To-do:
1. dataset
2. dataloader
3. 1d-conv
4. Analysis: attributions?

In [16]:
import lightning.pytorch as pl
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

import numpy as np
from generate_binned_data import generate_binned_data

# Use the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Load the binned recording from the 'sleep_data' directory.
spikes, speeds = generate_binned_data(
    data_path='sleep_data',
)
# Convert to an ndarray so the cells below can use `.shape` and slicing.
spikes = np.array(spikes)

AttributeError: type object 'torch.device' has no attribute 'is_cuda_available'

In [12]:
# Summarise what was loaded: the neuron count, the shape of the first few
# per-neuron spike arrays, and the shape of the speed trace.
print(f'neurons: {spikes.shape[0]}')
# Slice instead of indexing with a hard-coded range(10) so that a recording
# with fewer than ten neurons prints what it has instead of raising IndexError.
for i, neuron_spikes in enumerate(spikes[:10]):
    print(f'neuron {i}: {neuron_spikes.shape}')
print('...')
print(f'speeds: {speeds.shape}')

neurons: 143
neuron 0: (2912723,)
neuron 1: (2912723,)
neuron 2: (2912723,)
neuron 3: (2912723,)
neuron 4: (2912723,)
neuron 5: (2912723,)
neuron 6: (2912723,)
neuron 7: (2912723,)
neuron 8: (2912723,)
neuron 9: (2912723,)
...
speeds: (2912725,)


In [14]:
# construct pytorch lightning data module
class SpikeDataModule(pl.LightningDataModule):
    """Serve sliding windows of binned spikes paired with a speed target.

    `spikes` is expected as a 2-D array laid out (n_neurons, n_bins) and
    `speeds` as a 1-D array over the same time bins. Each sample is a
    (n_neurons, window_size) slice of `spikes`, which is the (channels,
    length) layout `nn.Conv1d` consumes, paired with one speed value.

    Args:
        spikes: numpy array of shape (n_neurons, n_bins).
        speeds: numpy array of shape (n_bins,) or slightly longer/shorter.
        batch_size: number of windows per batch.
        window_size: number of time bins per sample. Should equal the
            receptive field of the 1d-CNN (7 for three kernel-3 layers) so
            the network emits exactly one value per window.

    Raises:
        ValueError: if `spikes` is not 2-D, `window_size` is not positive, or
            the recording is shorter than one window (the latter is raised
            from `setup`).
    """

    def __init__(self, spikes, speeds, batch_size=32, window_size=7):
        super().__init__()
        self.spikes = torch.from_numpy(spikes)
        self.speeds = torch.from_numpy(speeds)
        if self.spikes.ndim != 2:
            raise ValueError(
                f'spikes must be 2-D (n_neurons, n_bins), got shape {tuple(self.spikes.shape)}'
            )
        if window_size < 1:
            raise ValueError(f'window_size must be positive, got {window_size}')
        self.batch_size = batch_size
        # note: window_size should equal receptive field size of 1d-CNN
        self.window_size = window_size
        # Time is the LAST axis of spikes; axis 0 is neurons. The two arrays
        # can differ in length by a few bins, so only the common prefix is used.
        # NOTE(review): assumes spikes and speeds start at the same bin -- confirm
        # against generate_binned_data.
        self.train_seq_len = min(self.spikes.shape[-1], self.speeds.shape[0])

    def setup(self, stage=None):
        """Build the windowed train/val datasets (first 80% / last 20% in time)."""
        n_bins = self.train_seq_len
        if n_bins < self.window_size:
            raise ValueError(
                f'need at least {self.window_size} time bins, got {n_bins}'
            )

        # (n_neurons, n_bins) -> (n_windows, n_neurons, window_size).
        # unfold/permute return strided views, so no windowed copy is made here.
        windows = (
            self.spikes[:, :n_bins]
            .unfold(1, self.window_size, 1)
            .permute(1, 0, 2)
        )
        n_windows = windows.shape[0]

        # One target per window: the speed at the window's centre bin.
        # NOTE(review): centre alignment is an assumption; switch the offset to
        # `self.window_size - 1` for a causal (last-bin) target.
        offset = self.window_size // 2
        targets = self.speeds[offset:offset + n_windows]

        # split dataset (speeds, spikes) into train, val along the TIME axis
        train_test_split_ind = n_windows // 5 * 4
        self.data_train = TensorDataset(
            windows[:train_test_split_ind], targets[:train_test_split_ind]
        )
        self.data_val = TensorDataset(
            windows[train_test_split_ind:], targets[train_test_split_ind:]
        )
        # self.data_test = TensorDataset(windows[train_test_split_ind:], targets[train_test_split_ind:])

    def train_dataloader(self):
        """Return a shuffled loader over the training windows."""
        return DataLoader(self.data_train, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        """Return an in-order loader over the validation windows."""
        return DataLoader(self.data_val, batch_size=self.batch_size, shuffle=False)

In [21]:
# construct pytorch lightning module of 3-layer basic 1-D CNN model
class SpikeModel(pl.LightningModule):
    """3-layer 1-D CNN that regresses a scalar trace from multi-neuron spikes.

    Input is (batch, n_neurons, length). All three convolutions use kernel
    size 3, stride 1 and no padding, so the output is (batch, 1, length - 6).

    Args:
        n_neurons: number of input channels of the first convolution.
    """

    def __init__(self, n_neurons):
        super().__init__()
        # Note: receptive field is currently 1 + 2*L = 7
        self.conv1 = nn.Conv1d(n_neurons, 64, kernel_size=3, stride=1, padding=0)
        self.conv2 = nn.Conv1d(64, 16, kernel_size=3, stride=1, padding=0)
        self.conv3 = nn.Conv1d(16, 1, kernel_size=3, stride=1, padding=0)
        self.relu = nn.ReLU()
        # self.dropout = nn.Dropout(0.2)
        self.loss = nn.MSELoss()

    def forward(self, x):
        """Return the network prediction for `x`.

        `x` is cast to the dtype of the convolution weights, because arrays
        that come from numpy are often float64 or integer and would otherwise
        fail inside `Conv1d`.
        """
        x = x.to(self.conv1.weight.dtype)
        z1 = self.relu(self.conv1(x))
        z2 = self.relu(self.conv2(z1))
        # The final ReLU clamps predictions to be non-negative.
        y_hat = self.relu(self.conv3(z2))
        # Return the prediction (the original returned the untouched input).
        return y_hat

    def _shared_step(self, batch, log_name):
        """Compute and log the MSE loss for one batch under `log_name`."""
        x, y = batch
        y_hat = self(x)
        # Match prediction and target shapes explicitly: MSELoss would
        # otherwise broadcast e.g. (B, 1, 1) against (B,) into (B, 1, B) and
        # return a meaningless loss. reshape raises if element counts differ.
        loss = self.loss(y_hat.reshape(y.shape), y.to(y_hat.dtype))
        self.log(log_name, loss)
        return loss

    def training_step(self, batch, batch_idx):
        """Return the training loss for `batch`; logged as 'train_loss'."""
        return self._shared_step(batch, 'train_loss')

    def validation_step(self, batch, batch_idx):
        """Return the validation loss for `batch`; logged as 'val_loss'."""
        return self._shared_step(batch, 'val_loss')

    def configure_optimizers(self):
        """Use Adam over all parameters with a learning rate of 1e-3."""
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

In [22]:
# Instantiate the data pipeline and the network. The first axis of `spikes`
# indexes neurons, so its length is the model's input channel count.
datamodule = SpikeDataModule(spikes=spikes, speeds=speeds)
spike_model = SpikeModel(n_neurons=len(spikes))


In [None]:
# # 3-layer basic 1-D CNN model
# import torch.nn as nn
# import torch.nn.functional as F
# class Net(nn.Module):
#     def __init__(self):
#         super(Net, self).__init__()
        
#         # define convolutional layers
#         self.conv1 = nn.Conv1d(in_channels=1, out_channels=1, kernel_size=3, stride=1, padding=0)
#         self.conv2 = nn.Conv1d(in_channels=1, out_channels=1, kernel_size=3, stride=1, padding=0)
#         self.conv3 = nn.Conv1d(in_channels=1, out_channels=1, kernel_size=3, stride=1, padding=0)

#         # define pooling layer
#         # self.pool = nn.MaxPool1d(kernel_size=2, stride=2, padding=0)

#         # define relu layer
#         self.relu = nn.ReLU()


#     def forward(self, x):
#         x = self.conv1(x)
#         x = self.relu(x)
#         x = self.conv2(x)
#         x = self.relu(x)
#         # x = self.pool(x)
#         x = self.conv3(x)
#         x = self.relu(x)
#         return x


In [None]:
# TODO: add in loading from above into the init function
import torch
from torch.utils.data import Dataset, DataLoader
# construct pytorch dataset class
# 1. get all spikes in a 1 second window
# 2. get all power in a 1 second window
class SpikeDataset(Dataset):
    """Dataset of consecutive, non-overlapping spike / paw-power windows.

    The constructor arguments are stored on the instance, but the arrays
    that are actually served are placeholder toy data (100 ones each) until
    the power computation is implemented. `step_size` is stored but is not
    used by `__len__` or `__getitem__`.
    """

    def __init__(self, ttl_times, spikes, front_left_paw_speed, front_right_paw_speed, window_size=10, step_size=1):
        # note: window_size should equal receptive field size of 1d-CNN
        self.window_size = window_size
        self.step_size = step_size

        self.ttl_times = ttl_times
        self.spikes = spikes
        self.front_left_paw_speed = front_left_paw_speed
        self.front_right_paw_speed = front_right_paw_speed

        # calculate power here

        # toy data to move forward (replaces the `spikes` argument stored above)
        self.paw_power = np.ones(100)
        self.spikes = np.ones(100)

    def __len__(self):
        """Number of complete windows that fit in the spike array."""
        # TODO: check this
        full_windows, _leftover = divmod(len(self.spikes), self.window_size)
        return full_windows

    def __getitem__(self, idx):
        """Return the `idx`-th (spikes, power) window pair."""
        # for one 'subsequence', pull out binned spikes and power
        start = idx * self.window_size
        window = slice(start, start + self.window_size)
        return self.spikes[window], self.paw_power[window]