## Initial codebase for implementing a 1D convolutional network on mouse wake+sleep spiking data

## To-do:
1. dataset
2. dataloader
3. 1d-conv
4. Analysis: attributions?

In [None]:
import lightning.pytorch as pl
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
import torchmetrics

import os
import matplotlib.pyplot as plt
import numpy as np
from preprocessing import generate_binned_data, obtain_binned_rhythm, obtain_binned_acceleration

# Use the GPU when CUDA is usable, otherwise stay on the CPU.
_accelerator = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(_accelerator)

# rhythm = obtain_binned_rhythm(rhythm)
# binned_speeds = obtain_rhythm_percentiles(speeds)

# Load the cached arrays from the data directory.
_data_dir = 'sleep_data'
spikes = np.load(os.path.join(_data_dir, 'spikes.npy'))
rhythm = np.load(os.path.join(_data_dir, 'rhythm.npy'))
acceleration = np.load(os.path.join(_data_dir, 'acceleration.npy'))
behavior = np.load(os.path.join(_data_dir, 'behavior.npy'))

# Default configuration: all behaviours kept, six target classes.
# NOTE: `spikes` is rebound to whatever the helper returns alongside the labels.
n_classes = 6
binned_acceleration, spikes = obtain_binned_acceleration(
    behavior,
    acceleration,
    spikes,
    just_running=False,
)

# Alternative configuration (running epochs only, four target classes):
# binned_acceleration, spikes = obtain_binned_acceleration(behavior, acceleration, spikes, just_running=True)
# binned_acceleration = binned_acceleration - 2
# n_classes = 4

In [None]:
# spikes, rhythm, acceleration, behavior = generate_binned_data(data_path='sleep_data')
# spikes = np.array(spikes).transpose()

# np.save(os.path.join('sleep_data', 'spikes.npy'), spikes)
# np.save(os.path.join('sleep_data', 'rhythm.npy'), rhythm)
# np.save(os.path.join('sleep_data', 'acceleration.npy'), acceleration)
# np.save(os.path.join('sleep_data', 'behavior.npy'), behavior)

In [None]:
# np.unique(acceleration)

In [None]:
# Quick visual check on a 5000-sample excerpt of the acceleration trace.
_excerpt = acceleration[400000:405000]
plt.plot(_excerpt, color='orange')
plt.title('Acceleration (difference of front paw speeds)')

In [None]:
# # for behavior
# print(np.unique(behavior, return_counts=True))
# test_behavior = behavior+1
# classes, weights = np.unique(test_behavior, return_counts=True)
# print(classes, weights)
# behavior_encoding = torch.nn.functional.one_hot(torch.tensor(test_behavior), num_classes=3).numpy()
# # classes, weights = np.unique(test_behavior[:,2], return_counts=True)
# # print(classes, weights)
# weights = torch.tensor([weights[i]/spikes.shape[0] for i in classes])
# print(weights)

In [None]:
# plt.plot(rhythm[1010000:1020000])
# plt.show()

In [None]:
# print(f'behavior: {behavior.shape}')
# print(f'neurons: {spikes.shape[0]}')
# for i in range(10):
#     print(f'neuron {i}: {spikes[i].shape}')
# print('...')
# print(f'speeds: {rhythm.shape}')

In [None]:
# subsequence_length = 100
# windowed_spikes = torch.from_numpy(spikes)
# windowed_rhythm = torch.from_numpy(rhythm)
# windowed_behavior = torch.from_numpy(behavior)
# windowed_spikes = torch.stack([windowed_spikes[i:i+subsequence_length] for i in range(0, 10000-subsequence_length, subsequence_length)], dim=0)
# windowed_speeds = torch.stack([windowed_rhythm[i:i+subsequence_length] for i in range(0, 10000-subsequence_length, subsequence_length)], dim=0)
# windowed_behavior = torch.stack([windowed_behavior[i:i+subsequence_length] for i in range(0, 10000-subsequence_length, subsequence_length)], dim=0)
# print(f'windowed spikes: {windowed_spikes.shape}')
# print(f'windowed speeds: {windowed_speeds.shape}')
# print(f'windowed behavior: {windowed_behavior.shape}')

In [None]:
# subsequence_len = 100
# train_seq_len = spikes.shape[0]
# stacked_spikes = torch.from_numpy(spikes)
# stacked_behavior = torch.from_numpy(behavior)
# stacked_spikes = torch.stack([stacked_spikes[i:i+subsequence_len] for i in range(0, train_seq_len-2*subsequence_len, subsequence_len)], dim=0)
# stacked_spikes = torch.transpose(stacked_spikes, 1, 2)
# stacked_behavior = torch.stack([stacked_behavior[i:i+subsequence_len] for i in range(0, train_seq_len-2*subsequence_len, subsequence_len)], dim=0)
# print(f'stacked spikes: {stacked_spikes.shape}; stacked behavior: {stacked_behavior.shape}')

In [None]:
# construct pytorch lightning data module
class SpikeDataModule(pl.LightningDataModule):
    """Lightning data module that cuts a recording into fixed-length windows.

    The spike array and the label array are both cut along their first axis
    into consecutive, non-overlapping windows of ``subsequence_length``
    samples. Labels are one-hot encoded. After ``setup`` both tensors are laid
    out as ``(n_windows, channels, subsequence_length)`` in float32.

    Args:
        spikes: NumPy array indexed by time along axis 0 (the notebook uses a
            second axis for neurons).
        behavior: NumPy array of integer class labels, one per time step.
            The labels are shifted by +1 before one-hot encoding
            (presumably because the raw labels start at -1 -- confirm against
            the preprocessing code).
        num_classes: Number of classes used for the one-hot encoding.
        batch_size: Batch size used by all three dataloaders.
        subsequence_length: Number of time steps per window. This is not
            necessarily the same as the model's receptive field.
    """

    def __init__(self, spikes, behavior, num_classes, batch_size=16, subsequence_length=1000):
        super().__init__()
        # Keep the un-windowed recordings separately so that `setup` always
        # starts from the raw data and can safely be called more than once.
        self._raw_spikes = torch.from_numpy(spikes)
        # one_hot requires int64 indices, so the dtype is enforced here.
        self._raw_behavior = torch.from_numpy(behavior + 1).long()
        # Public attributes hold the raw tensors until `setup` replaces them
        # with the windowed versions.
        self.spikes = self._raw_spikes
        self.behavior = self._raw_behavior
        self.batch_size = batch_size
        self.train_seq_len = spikes.shape[0]
        self.subsequence_len = subsequence_length  # Note: subsequence length is not necessarily the same as the receptive field size
        self.num_classes = num_classes

    def setup(self, stage=None):
        """Window the recording and build the train/val/test datasets.

        Windows are split by position: the first 40% and the last 40% form the
        training set, the 40-50% slice is validation and the 50-60% slice is
        test.

        Raises:
            ValueError: If the window length is not positive, the recording is
                shorter than one window, or there are fewer labels than the
                windowed spike samples.
        """
        raw_spikes, raw_behavior = self._raw_spikes, self._raw_behavior
        print(f'spikes: {raw_spikes.shape}; behavior: {raw_behavior.shape}')

        window = self.subsequence_len
        if window < 1:
            raise ValueError(f'subsequence_length must be positive, got {window}')
        # Every complete window is used, including the last one when the
        # recording length is an exact multiple of the window length.
        n_windows = self.train_seq_len // window
        if n_windows == 0:
            raise ValueError(
                f'recording of length {self.train_seq_len} is shorter than one window of {window} samples'
            )
        used = n_windows * window
        if raw_behavior.shape[0] < used:
            raise ValueError(
                f'behavior has {raw_behavior.shape[0]} samples but {used} are needed to match the spike windows'
            )

        # (time, ...) -> (n_windows, window, ...) -> swap window axis with the
        # first feature axis so channels come before time.
        spikes = raw_spikes[:used].reshape(n_windows, window, *raw_spikes.shape[1:])
        self.spikes = torch.transpose(spikes, 1, 2).to(torch.float32).contiguous()

        one_hot = torch.nn.functional.one_hot(raw_behavior[:used], num_classes=self.num_classes)
        behavior = one_hot.reshape(n_windows, window, *one_hot.shape[1:])
        self.behavior = torch.transpose(behavior, 1, 2).to(torch.float32).contiguous()
        print(f'spikes: {self.spikes.shape}; behavior: {self.behavior.shape}')

        # Split boundaries, expressed as window indices.
        first_train_split_ind = round(n_windows * .4)
        end_val_split_ind = round(n_windows * .5)
        end_test_split_ind = round(n_windows * .6)

        # Training data is everything outside the held-out middle 40-60% band.
        train_spikes = torch.cat(
            (self.spikes[:first_train_split_ind], self.spikes[end_test_split_ind:]), dim=0)
        train_behavior = torch.cat(
            (self.behavior[:first_train_split_ind], self.behavior[end_test_split_ind:]), dim=0)

        self.data_train = TensorDataset(train_spikes, train_behavior)
        self.data_val = TensorDataset(self.spikes[first_train_split_ind:end_val_split_ind],
                                      self.behavior[first_train_split_ind:end_val_split_ind])
        self.data_test = TensorDataset(self.spikes[end_val_split_ind:end_test_split_ind],
                                       self.behavior[end_val_split_ind:end_test_split_ind])

    def train_dataloader(self):
        """Return the shuffled training dataloader."""
        return DataLoader(self.data_train, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        """Return the validation dataloader (fixed order)."""
        return DataLoader(self.data_val, batch_size=self.batch_size, shuffle=False)

    def test_dataloader(self):
        """Return the test dataloader (fixed order)."""
        return DataLoader(self.data_test, batch_size=self.batch_size, shuffle=False)

In [None]:
# construct pytorch lightning module of 3-layer basic 1-D CNN model
class SpikeModel(pl.LightningModule):
    """Three-layer 1-D CNN that predicts a class for every output time step.

    The network applies three un-padded, stride-1 convolutions along the time
    axis, so an input of shape ``(batch, n_neurons, T)`` yields logits of
    shape ``(batch, out_dim, T - 59)``. Targets are expected as one-hot (or
    class-probability) tensors of shape ``(batch, out_dim, T)``; they are
    trimmed to the time steps the network actually predicts.

    Args:
        n_neurons: Number of input channels.
        out_dim: Number of output classes.
        weights: Per-class weight tensor passed to ``nn.CrossEntropyLoss``.
        lr: Learning rate for the Adam optimizer. NOTE: this value is what
            the optimizer uses; pass ``lr=1e-3`` to reproduce runs made while
            the rate was hard-coded in the optimizer setup.
        receptive_field: Number of input steps that feed one output step.
            Must match the convolution kernels (20 + 20 + 20 = 60), otherwise
            predictions and trimmed targets have different lengths.
    """

    # Legend names for the raw-logit plot, keyed by the number of classes.
    _CLASS_NAMES = {
        6: ('still', 'other', 'running 1', 'running 2', 'running 3', 'running 4'),
        4: ('running 1', 'running 2', 'running 3', 'running 4'),
    }

    def __init__(self, n_neurons, out_dim, weights, lr=1e-4, receptive_field=60):
        super().__init__()
        self.out_dim = out_dim
        # Receptive field of the stack: 1 + (20-1) + (21-1) + (21-1) = 60.
        self.conv1 = nn.Conv1d(n_neurons, 64, kernel_size=20, stride=1, padding=0)
        self.conv2 = nn.Conv1d(64, 16, kernel_size=21, stride=1, padding=0)
        self.conv3 = nn.Conv1d(16, out_dim, kernel_size=21, stride=1, padding=0)
        self.receptive_field = receptive_field
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)
        self.loss = nn.CrossEntropyLoss(weight=weights, reduction='mean')
        self.lr = lr
        # Per-step accuracies collected during the current epoch; each list is
        # averaged and emptied in the matching on_*_epoch_end hook.
        self.epoch_train_accuracies = []
        self.epoch_val_accuracies = []
        self.epoch_test_accuracies = []

    def configure_optimizers(self):
        """Return an Adam optimizer using the configured learning rate."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def forward(self, x):
        """Return per-time-step class logits for ``x`` of shape (batch, n_neurons, T)."""
        z1 = self.dropout(self.relu(self.conv1(x)))
        z2 = self.dropout(self.relu(self.conv2(z1)))
        return self.conv3(z2)

    def _shared_step(self, batch, stage, accuracies):
        """Run one batch: compute and log loss and accuracy under ``stage``.

        Returns the inputs, the trimmed targets, the logits and the loss so
        the caller can plot or print them.
        """
        x, y = batch

        # Predict on LAST time step: output step t is aligned with input step
        # t + receptive_field - 1, so the first receptive_field - 1 targets
        # have no prediction.
        y = y[:, :, self.receptive_field - 1:]
        y_hat = self.forward(x)
        if y_hat.shape != y.shape:
            raise ValueError(
                f'prediction shape {tuple(y_hat.shape)} does not match target shape {tuple(y.shape)}; '
                'check that receptive_field matches the convolution kernel sizes'
            )

        loss = self.loss(y_hat, y)
        self.log(f'{stage}_loss', loss)

        # Fraction of (sample, time step) pairs whose arg-max class is correct.
        correct = torch.argmax(y_hat, dim=1) == torch.argmax(y, dim=1)
        step_accuracy = correct.float().mean()
        self.log(f'{stage}_accuracy', step_accuracy)
        accuracies.append(step_accuracy.detach())
        return x, y, y_hat, loss

    def _plot_sample(self, x, y, y_hat, stage_title):
        """Plot prediction vs. truth, the input and the raw logits of sample 0."""
        sample_pred = torch.argmax(y_hat, dim=1)[0].detach().cpu().numpy()
        sample_y = torch.argmax(y, dim=1)[0].detach().cpu().numpy()
        print(f'prediction shape: {sample_pred.shape}')

        plt.figure()
        plt.plot(sample_pred, label='prediction')
        plt.plot(sample_y, label='ground truth')
        plt.title(f'{stage_title} sample prediction')
        plt.legend()
        plt.show()

        plt.figure()
        plt.imshow(x[0].detach().cpu().numpy(), aspect='auto')
        plt.title(f'{stage_title} sample input')
        plt.xlabel('Time')
        plt.ylabel('Neuron')
        plt.colorbar()
        plt.show()

        plt.figure()
        # (classes, time) -> (time, classes) so each column is one class trace.
        raw_sample_pred = y_hat[0].detach().cpu().numpy().transpose()
        class_names = self._CLASS_NAMES.get(self.out_dim)
        for class_idx in range(raw_sample_pred.shape[1]):
            if class_names is not None:
                label = f'raw prediction (class {class_idx} - {class_names[class_idx]})'
            else:
                # No names known for this class count; fall back to the index.
                label = f'raw prediction (class {class_idx})'
            plt.plot(raw_sample_pred[:, class_idx], label=label)
        plt.title(f'{stage_title} prediction raw confidences')
        plt.legend()
        plt.show()

    def _log_epoch_accuracy(self, name, accuracies):
        """Log the mean of this epoch's step accuracies, then reset the list."""
        if accuracies:
            self.log(name, torch.stack(accuracies).mean())
        # Reset so the next epoch's average only covers that epoch.
        accuracies.clear()

    def training_step(self, batch, batch_idx):
        """Compute the training loss; periodically print shapes and plot a sample."""
        x, y, y_hat, loss = self._shared_step(batch, 'train', self.epoch_train_accuracies)
        if batch_idx == 0 and self.current_epoch % 10 == 0:
            print(f'y_hat shape: {y_hat.shape}; y shape: {y.shape}')
        if batch_idx == 0 and self.current_epoch % 50 == 0:
            self._plot_sample(x, y, y_hat, 'Training')
        return loss

    def on_train_epoch_end(self):
        self._log_epoch_accuracy('train_acc_epoch', self.epoch_train_accuracies)

    def validation_step(self, batch, batch_idx):
        """Compute the validation loss; plot a sample every 100 epochs."""
        x, y, y_hat, loss = self._shared_step(batch, 'val', self.epoch_val_accuracies)
        if batch_idx == 0 and self.current_epoch % 100 == 0:
            self._plot_sample(x, y, y_hat, 'Validation')
        return loss

    def on_validation_epoch_end(self):
        self._log_epoch_accuracy('val_acc_epoch', self.epoch_val_accuracies)

    def test_step(self, batch, batch_idx):
        """Compute the test loss; plot a sample from the first batch."""
        x, y, y_hat, loss = self._shared_step(batch, 'test', self.epoch_test_accuracies)
        if batch_idx == 0:
            self._plot_sample(x, y, y_hat, 'Test')
        return loss

    def on_test_epoch_end(self):
        self._log_epoch_accuracy('test_acc_epoch', self.epoch_test_accuracies)

    def on_test_end(self):
        print('Finished testing')
        print(f'Loss: {self.trainer.callback_metrics["test_loss"]}')

In [None]:
# for acceleration
# Class weights for the loss: inverse class frequency, scaled so the largest
# weight (the rarest class that is present) equals 1.
print(np.unique(binned_acceleration, return_counts=True))
# Same +1 label shift that SpikeDataModule applies before one-hot encoding.
test_binned_acceleration = binned_acceleration+1
classes, counts = np.unique(test_binned_acceleration, return_counts=True)
print(classes, counts)
binned_acceleration_encoding = torch.nn.functional.one_hot(torch.tensor(test_binned_acceleration), num_classes=n_classes).numpy()

# Count per class index, padded to n_classes so the weight vector always has
# one entry per model output even when a class never occurs in the data.
class_counts = np.bincount(np.asarray(test_binned_acceleration).ravel(), minlength=n_classes)
present = class_counts > 0

# (1 / frequency) / max(1 / frequency) simplifies to min_count / count.
# Classes that never occur get weight 0 instead of an infinite weight.
class_weights = np.zeros(class_counts.shape, dtype=np.float64)
if present.any():
    class_weights[present] = class_counts[present].min() / class_counts[present]

# float32 to match the dtype of the model's logits.
weights = torch.tensor(class_weights, dtype=torch.float32)
print(f'weights: {weights}')

In [None]:
# logger = pl.loggers.TensorBoardLogger('logs/', name='1D_CNN')
# receptive_field = 60
# Build the windowed datasets and their loaders.
datamodule = SpikeDataModule(spikes, binned_acceleration, num_classes=n_classes)
datamodule.setup()
train_loader = datamodule.train_dataloader()
val_loader = datamodule.val_dataloader()
test_loader = datamodule.test_dataloader()

# The learning rate is passed explicitly (1e-3 is the rate the optimizer has
# been running with) so the run does not depend on the model's default.
spike_model = SpikeModel(n_neurons=spikes.shape[1], out_dim=n_classes, weights=weights, lr=1e-3)
# spike_model = SpikeModel(n_neurons=3, out_dim=3, receptive_field=220)

# logger = pl.loggers.TensorBoardLogger('logs/', name='1D_CNN')
trainer = pl.Trainer(devices=1, max_epochs=2000)  # , logger=logger
trainer.fit(spike_model, train_loader, val_loader)
# trainer.fit(spike_model, datamodule)
# trainer.validate(spike_model, val_loader)
trainer.test(spike_model, test_loader)
# trainer.test(spike_model, datamodule)
# trainer.save_checkpoint(os.path.join('results', 'model.ckpt'))