In [None]:
import pickle
import numpy as np
import os
import matplotlib.pyplot as plt

from model import Model
from decoder import LinearAccDecoder
import utils
from early_stopping import EarlyStopping

import torch
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

In [None]:
# Enable IPython's autoreload extension in mode 2, so imported project modules are reloaded
# before each cell executes and edits to them are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Load the experiment configuration through the project helper.
config = utils.read_config()
# Seed the random number generators with the configured seed
# (which generators are covered is decided inside utils.set_seeds, not visible here).
utils.set_seeds(config['seed'])

In [None]:
# Load the behaviour and spike recordings through the project helper.
behaviour_data, spikes = utils.load_dataset(config)
# Optional crop, currently disabled: keep only the bins from t = -1 onwards.
# time_from = int(1/bin_len)
# behaviour_data, spikes = [x[time_from:, :] for x in behaviour_data], [x[time_from:, :] for x in spikes]
# The stacked spikes must be 3-D; the three sizes are unpacked as (trials, time bins, emission dimension).
spikes_shape = np.asarray(spikes).shape
num_trials, time_bins, emissions_dim = spikes_shape

In [None]:
# Columns of each trial's behaviour matrix that hold the stimulus and choice labels.
stim_idx, choice_idx = 6, 3
# Both labels are read from the first time bin of every trial.
stim = [trial[0, stim_idx] for trial in behaviour_data]
choice = [trial[0, choice_idx] for trial in behaviour_data]
# Per-bin sum over the four columns -9..-6 of the behaviour matrix
# (presumably the individual contact channels — TODO confirm the column layout).
num_contacts = [trial[:, -9:-5].sum(axis=1) for trial in behaviour_data]
# NOTE: `behaviour_data` is rebound here to a (trials, 2) array of [stim, choice], so this cell
# cannot be re-run without first re-running the dataset-loading cell.
behaviour_data = np.stack((stim, choice), axis=1)

In [None]:
# convert to torch tensors
behaviour_data = torch.tensor(behaviour_data, dtype=torch.float32)
# Stack `spikes` into a single ndarray first: when it is a sequence of per-trial arrays,
# handing it to torch.tensor directly is very slow and makes PyTorch emit a warning.
spikes = torch.tensor(np.asarray(spikes), dtype=torch.float32)

In [None]:
# Split the trials into training and held-out sets (30% held out).
# random_state is fixed, so the same trials end up in each split on every run.
behaviour_data_train, behaviour_data_test, spikes_train, spikes_test = train_test_split(
    behaviour_data, spikes, test_size=0.3, random_state=42)

# Wrap each split as (behaviour, spikes) pairs.
train_dataset = TensorDataset(behaviour_data_train, spikes_train)
test_dataset = TensorDataset(behaviour_data_test, spikes_test)

batch_size = config['batch_size']
# Training batches are reshuffled every epoch.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# Evaluation batches keep a fixed order: test() averages per-batch losses, so reshuffling would change
# which trials fall into the (possibly smaller) last batch and make the reported loss differ between calls.
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [None]:
# Mean of the stimulus (column 0) and choice (column 1) labels in the train and test splits.
train_labels = behaviour_data_train.numpy()
test_labels = behaviour_data_test.numpy()
train_stim_mean, train_choice_mean = np.mean(train_labels[:, 0]), np.mean(train_labels[:, 1])
test_stim_mean, test_choice_mean = np.mean(test_labels[:, 0]), np.mean(test_labels[:, 1])
print("Train distribution of Stimulus: {}, Choice: {}".format(train_stim_mean, train_choice_mean))
print("Test distribution of Stimulus: {}, Choice: {}".format(test_stim_mean, test_choice_mean))

In [None]:
# Average the training spikes over the first (trial) dimension.
neuron_bias = spikes_train.mean(dim=0)

In [None]:
# Read the configuration again (presumably so edits to it apply without re-running the data cells).
config = utils.read_config()
# Optimisation hyper-parameters used by the training loop.
num_epochs = config['epochs']
learning_rate = config['lr']
# Build the model on the emission dimension of the spikes, with 2-dimensional z and x,
# and attach an Adam optimiser to its parameters.
model = Model(config, input_dim=emissions_dim, z_dim=2, x_dim=2)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [None]:
# # check if mps is available
# device = torch.device('mps' if torch.backends.mps.is_built() else 'cpu')
# print(device)
# model = model.to(device)
# spikes = spikes.to(device)

In [None]:
def test(model, test_loader):
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for _, (behavior_batch, spikes_batch) in enumerate(test_loader):
            y_recon, (mu, A), (z, x), behavior_batch_pred = model(spikes_batch)
            _, loss_l = model.loss(spikes_batch, y_recon, mu, A, z, x, behavior_batch_pred, behavior_batch)
            test_loss += np.array(loss_l)
    # divide loss by total number of samples in dataloader    
    return test_loss/len(test_loader)

In [None]:
def train(model, val_loader):
    """Train ``model`` on the notebook-level ``train_loader`` with periodic evaluation and early stopping.

    Besides its arguments, this function reads the notebook globals ``config``, ``num_epochs``,
    ``train_loader`` and ``optimizer``; they must be defined (and ``optimizer`` must still be the
    optimiser built for ``model``) when it is called.

    Args:
        model: network to optimise; called on a spikes batch and providing ``loss(...)`` that
            returns ``(loss, loss_l)``.
        val_loader: loader of ``(behaviour, spikes)`` batches passed to ``test()`` for the
            periodic evaluation that drives early stopping.

    Returns:
        The smallest summed evaluation loss observed, or ``np.inf`` if training finished before
        any evaluation was run (e.g. fewer epochs than ``config['test_every']``).
    """
    train_losses, test_losses = [], []
    test_every = config['test_every']
    early_stop = EarlyStopping(patience=config['early_stop']['patience'], delta=config['early_stop']['delta'],
                               trace_func=print)
    save_model = True
    for epoch in range(num_epochs):
        # test() leaves the model in eval mode, so switch back to train mode once per epoch
        model.train()
        epoch_loss = 0
        for behavior_batch, spikes_batch in train_loader:
            # forward pass
            y_recon, (mu, A), (z, x), behavior_pred = model(spikes_batch)
            loss, loss_l = model.loss(spikes_batch, y_recon, mu, A, z, x, behavior_pred, behavior_batch)
            # backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            epoch_loss += np.array(loss_l)
        # mean of the per-batch loss components for this epoch
        train_losses.append(epoch_loss/len(train_loader))
        # periodic evaluation
        if (epoch+1) % test_every == 0:
            test_loss = test(model, val_loader)
            test_losses.append(test_loss)
            # the early-stopping helper tracks the summed loss and saves the model under the 'best' prefix
            early_stop(np.sum(test_loss), model, save_model=save_model, save_prefix='best')
            print('Epoch [{}/{}], Train Loss: {}, Test Loss: {}'.format(epoch+1, num_epochs, train_losses[-1], test_losses[-1]))
            # switch the evaluation interval while the early-stopping helper reports a slow-down
            if early_stop.slow_down:
                test_every = config['early_stop']['test_every_new']
            else:
                test_every = config['test_every']
            if early_stop.early_stop:
                print("Early stopping")
                break
    # np.min raises on an empty sequence, so handle the "never evaluated" case explicitly
    if not test_losses:
        return np.inf
    # compute min test loss and return it
    return np.min([np.sum(x) for x in test_losses])


# train model
# NOTE: the held-out `test_loader` is passed as the validation loader, so early stopping (and the
# saved 'best' model) is selected on the same trials that later cells report test metrics on.
min_test_loss = train(model, test_loader)

In [None]:
# Compare the two weight matrices model.vae.c1 and model.vae.c2:
# c1^T . c2 scaled by the product of their norms.
c1 = model.vae.c1.weight.detach().numpy()
c2 = model.vae.c2.weight.detach().numpy()
norm_product = np.linalg.norm(c1) * np.linalg.norm(c2)
print(np.dot(c1.T, c2) / norm_product)

In [None]:
# Reconstruct the spikes with the trained model, without tracking gradients.
with torch.no_grad():
    model.eval()
    # all trials
    y_recon, (mu, A), _, _ = model.forward(spikes)
    # held-out trials only
    y_recon_test, (mu_test, A_test), _, _ = model.forward(spikes_test)

# convert to numpy
y_recon_np = y_recon.detach().numpy()
spikes_np = spikes.detach().numpy()
y_recon_test_np = y_recon_test.detach().numpy()
spikes_test_np = spikes_test.detach().numpy()
# compute bits/spike with the project helper (the result is indexed per neuron further below)
bits_per_spike_all = utils.bits_per_spike(y_recon_np, spikes_np)
bits_per_spike_test = utils.bits_per_spike(y_recon_test_np, spikes_test_np)
# show distribution of bits per spike: the histogram puts the values on x and the counts on y
plt.hist(bits_per_spike_all, bins=50)
plt.xlabel('Bits/spike')
plt.ylabel('Frequency')
plt.show()
# summed bits/spike over all trials and over the held-out trials
print("Bits per spike all: {}, test: {}".format(np.sum(bits_per_spike_all), np.sum(bits_per_spike_test)))

In [None]:
# Average A over trials at time bin 10, then multiply the result element-wise with its transpose.
# NOTE(review): `*` / np.multiply is element-wise; if `A` is meant to be a covariance factor, the
# covariance would be the matrix product `a_t @ a_t.T` — confirm which one is intended.
a_t = A.numpy()[:, 10, :, :].mean(axis=0)
cov = np.multiply(a_t, a_t.T)
# print(cov.shape, spikes_np.shape)
plt.imshow(cov)
plt.colorbar()

In [None]:
# plot PSTH (trial-averaged activity) of reconstructed and original data, one panel per neuron
# both averages use the numpy copies so the two traces are computed the same way
averaged_recon, averaged_original = y_recon_np.mean(axis=0), spikes_np.mean(axis=0)
# neuron indices to highlight: important for both, for stimulus only, for choice only
common = [12, 14, 4, 31]
stim_neurons = [15, 11, 33, 30]
choice_neurons = [16, 2, 6, 8]
# plot each in a 5x7 grid
fig, axs = plt.subplots(5, 7, figsize=(12, 9))
# set title of figure
fig.suptitle('yellow: choice, green: stimulus, pink: common')
num_neurons = averaged_original.shape[1]
# axs.flat walks the grid row by row, so the panel position matches the neuron index
for neuron_idx, ax in enumerate(axs.flat):
    if neuron_idx >= num_neurons:
        # fewer neurons than panels: hide the unused panel instead of indexing out of range
        ax.axis('off')
        continue
    ax.plot(averaged_recon[:, neuron_idx], label='recon', color='red')
    ax.plot(averaged_original[:, neuron_idx], label='original', color='blue')
    # no ticks
    ax.set_xticks([])
    ax.set_yticks([])
    # title: neuron index and its bits/spike with 4 decimal places
    ax.set_title('{}: {:.4f}'.format(neuron_idx, bits_per_spike_all[neuron_idx]))
    # yellow background for choice neurons
    if neuron_idx in choice_neurons:
        ax.set_facecolor('yellow')
    # green background for stimulus neurons
    if neuron_idx in stim_neurons:
        ax.set_facecolor('green')
    # pink background for neurons in both groups (applied last, so it wins)
    if neuron_idx in common:
        ax.set_facecolor('pink')
axs[0, 0].legend()

In [None]:
# plot trial averaged latent space
# the first z_dim entries of mu go through a sigmoid and are plotted as z; the remaining entries are x
z, x = torch.sigmoid(mu[:, :, :model.vae.z_dim]).numpy(), mu[:, :, model.vae.z_dim:].numpy()
# z_std, x_std = np.std(z, axis=0), np.std(x, axis=0)
fig, axs = plt.subplots(1, 2, figsize=(12, 6))
z_avg, x_avg = np.mean(z, axis=0), np.mean(x, axis=0)
# Time axis from -1 up to (not including) 0.5, with one point per time bin.
# `bin_len` is not defined anywhere in this notebook, so the axis is built from the number of bins
# instead; this also guarantees that `t` has exactly as many points as the plotted traces.
# NOTE(review): assumes the trials span t in [-1, 0.5) — confirm against the dataset.
t = np.linspace(-1, 0.5, z_avg.shape[0], endpoint=False)
axs[0].plot(t, z_avg[:, 0], label='z0')
axs[0].plot(t, z_avg[:, 1], label='z1')
axs[1].plot(t, x_avg[:, 0], label='x0')
axs[1].plot(t, x_avg[:, 1], label='x1')
axs[0].set_title('z')
axs[1].set_title('x')
axs[0].legend()
axs[1].legend()

In [None]:
# Trial-average the x latents separately for each stimulus value and for each choice value.
stim, choice = behaviour_data[:, 0].numpy(), behaviour_data[:, 1].numpy()
# trials labelled 1 are plotted as "left", trials labelled 0 as "right"
x_stim_left, x_stim_right = x[stim == 1].mean(axis=0), x[stim == 0].mean(axis=0)
x_choice_left, x_choice_right = x[choice == 1].mean(axis=0), x[choice == 0].mean(axis=0)

fig, axs = plt.subplots(1, 2, figsize=(12, 6))
# (axis, trace, legend label, colour, extra line options); x0 is drawn dashed, x1 with the default style
grouped_traces = [
    (axs[0], x_stim_left[:, 0], 'stim left (x0)', 'red', {'linestyle': '--'}),
    (axs[0], x_stim_left[:, 1], 'stim left (x1)', 'red', {}),
    (axs[0], x_stim_right[:, 0], 'stim right (x0)', 'blue', {'linestyle': '--'}),
    (axs[0], x_stim_right[:, 1], 'stim right (x1)', 'blue', {}),
    (axs[1], x_choice_left[:, 0], 'choice left (x0)', 'red', {'linestyle': '--'}),
    (axs[1], x_choice_left[:, 1], 'choice left (x1)', 'red', {}),
    (axs[1], x_choice_right[:, 0], 'choice right (x0)', 'blue', {'linestyle': '--'}),
    (axs[1], x_choice_right[:, 1], 'choice right (x1)', 'blue', {}),
]
for ax, trace, label, color, line_opts in grouped_traces:
    ax.plot(t, trace, label=label, color=color, **line_opts)
axs[0].set_title('x grouped by stimulus')
axs[1].set_title('x grouped by choice')
axs[0].legend()
axs[1].legend()


In [None]:
# Run the trained model on the train and test spikes, without gradients, to obtain the latents
# and behaviour predictions that the decoder cells below work with.
with torch.no_grad():
    model.eval()
    train_outputs = model.forward(spikes_train)
    test_outputs = model.forward(spikes_test)
    _, (mu_train, A_train), (z_train, x_train), behavior_pred_train = train_outputs
    _, (mu_test, A_test), (z_test, x_test), behavior_pred_test = test_outputs

In [None]:
# Train a linear decoder that predicts behaviour from the latents computed in the previous cell.
# The latents were produced under torch.no_grad(), so only the decoder's parameters are optimised.
linear_decoder = LinearAccDecoder(input_dim=2)
# Separate name on purpose: `optimizer` is the VAE optimiser that train() reads as a global,
# and must not be replaced by the decoder's optimiser.
decoder_optimizer = torch.optim.Adam(linear_decoder.parameters(), lr=0.01)
decoder_epochs, decoder_test_every = 5000, 500
decoder_train_l, decoder_test_l = [], []
for epoch in range(decoder_epochs):
    # forward pass on the full training set (one step per epoch, no mini-batches)
    linear_decoder.train()
    behavior_pred = linear_decoder(x_train, z_train)
    # alternative input: linear_decoder(mu_train[:, :, :2], mu_train[:, :, 2:])
    loss = linear_decoder.loss(behavior_pred, behaviour_data_train)
    # backward pass
    decoder_optimizer.zero_grad()
    loss.backward()
    decoder_optimizer.step()
    # the loss already covers the whole training set, so it is recorded as-is
    decoder_train_l.append(loss.item())
    # periodically evaluate the decoder itself on the held-out latents
    if (epoch+1) % decoder_test_every == 0:
        linear_decoder.eval()
        with torch.no_grad():
            test_loss = linear_decoder.loss(linear_decoder(x_test, z_test), behaviour_data_test).item()
        decoder_test_l.append(test_loss)
        print('Epoch [{}/{}], Train Loss: {:.4f}, Test Loss: {:.4f}'.format(epoch+1, decoder_epochs, decoder_train_l[-1], decoder_test_l[-1]))

# final predictions of the trained decoder, in eval mode and without gradients
linear_decoder.eval()
with torch.no_grad():
    behavior_pred_train = linear_decoder(x_train, z_train).detach()
    behavior_pred_test = linear_decoder(x_test, z_test).detach()

In [None]:
# Turn the decoder outputs into binary predictions by thresholding at 0.
pred_train = behavior_pred_train.numpy() > 0
pred_test = behavior_pred_test.numpy() > 0
# true labels as numpy arrays: column 0 is the stimulus, column 1 the choice
labels_train = behaviour_data_train.numpy()
labels_test = behaviour_data_test.numpy()
# stimulus accuracy
accuracy_train_stim = accuracy_score(labels_train[:, 0], pred_train[:, 0])
accuracy_test_stim = accuracy_score(labels_test[:, 0], pred_test[:, 0])
print('Stimulus Accuracy - train: {:.4f}, test: {:.4f}'.format(accuracy_train_stim, accuracy_test_stim))
# choice accuracy
accuracy_train_choice = accuracy_score(labels_train[:, 1], pred_train[:, 1])
accuracy_test_choice = accuracy_score(labels_test[:, 1], pred_test[:, 1])
print('Choice Accuracy - train: {:.4f}, test: {:.4f}'.format(accuracy_train_choice, accuracy_test_choice))


In [None]:
# Pick one trial at random and show its z latents, its contact trace and its x latents.
trial_idx = np.random.randint(num_trials)
fig, axs = plt.subplots(3, 1, figsize=(5, 8))
z_ax, contacts_ax, x_ax = axs
# top panel: z, limited to [0, 1]
for dim, name in enumerate(('z0', 'z1')):
    z_ax.plot(t, z[trial_idx, :, dim], label=name)
z_ax.set_title('z')
z_ax.set_ylim(0, 1)
z_ax.legend()
z_ax.set_xticks([])
# middle panel: summed contacts per time bin for this trial
contacts_ax.plot(t, num_contacts[trial_idx])
contacts_ax.set_title('num contacts')
contacts_ax.set_xticks([])
# bottom panel: x, with the trial's stimulus and choice labels in the title
for dim, name in enumerate(('x0', 'x1')):
    x_ax.plot(t, x[trial_idx, :, dim], label=name)
trial_stim, trial_choice = stim[trial_idx].astype(int), choice[trial_idx].astype(int)
x_ax.set_title('x, stimulus: {}, choice: {}'.format(trial_stim, trial_choice))
x_ax.set_ylim(-2, 2)
x_ax.legend()