# Figure 3 - Model training on whole-brain recordings of spontaneous activity of larval zebrafish

Light-sheet fluorescence microscopy enables simultaneous recording of whole-brain activity in small vertebrates such as larval zebrafish at single-neuron resolution (~40K neurons). The observed patterns of spontaneous activity reflect a random exploration of the neuronal state space, constrained by the underlying assembly organization of neurons.

To analyze these patterns, the cRBM is used to identify physiologically meaningful neural assemblies that combine to form successive brain states. The RTRBM is then trained with transfer learning (its weights are initialized from the trained cRBM) to infer putative temporal connections between these assemblies. Both models accurately replicate the mean activity, pairwise correlations and pairwise time-shifted correlations of the recordings with a limited number of parameters, and these statistics are used to compare the performance of the cRBM and the RTRBM. Approximately 100 to 200 neural assemblies per recording are identified and analyzed with these techniques.
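
The three statistics mentioned above can be computed directly from a binarized spike matrix. The sketch below is a minimal NumPy version, assuming spikes of shape (T, n_neurons) as returned by `load_dataset` and using connected (covariance-style) correlations. `summary_statistics` is a hypothetical helper, not part of PGM or `zebrafish_rtrbm`, and the exact normalization used in the paper may differ (e.g. Pearson correlation instead of covariance).

```python
import numpy as np

def summary_statistics(spikes, lag=1):
    """Mean activity, equal-time and time-shifted connected correlations.

    spikes: binary array of shape (T, n_neurons), one row per time frame.
    Returns (mean, corr, corr_lag):
      mean[i]        = <s_i>
      corr[i, j]     = <s_i s_j> - <s_i><s_j>
      corr_lag[i, j] = <s_i(t) s_j(t + lag)> - <s_i(t)><s_j(t + lag)>
    """
    s = np.asarray(spikes, dtype=float)
    mean = s.mean(axis=0)
    # Equal-time connected correlation (covariance) between all neuron pairs.
    corr = s.T @ s / s.shape[0] - np.outer(mean, mean)
    # Pair frame t with frame t + lag and compute the same quantity.
    s_now, s_next = s[:-lag], s[lag:]
    corr_lag = (s_now.T @ s_next / s_now.shape[0]
                - np.outer(s_now.mean(axis=0), s_next.mean(axis=0)))
    return mean, corr, corr_lag
```

Comparing these three quantities between held-out data and samples generated by a model gives a direct way to score how well the cRBM and RTRBM reproduce the recordings.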

This paper is the source of both the data and the cRBM model:
van der Plas, T., Tubiana, J., Le Goc, G., Migault, G., Kunst, M., Baier, H., Bormuth, V., Englitz, B. & Debregeas, G. (2021) Compositional restricted Boltzmann machines unveil the brain-wide organization of neural assemblies, bioRxiv

## Loading data and packages

In [None]:
# Enter the path to the PGM package directory.
# This package can be downloaded from https://github.com/jertubiana/PGM
# Please follow the installation instructions in README.md
PGM_dir_path = ''  # e.g. 'C:/Users/luukh/OneDrive/Intern/PGM/'

In [None]:
import os, sys
import torch
import numpy as np
import h5py
import seaborn as sns
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt

from numba import set_num_threads
from numba import njit,prange
set_num_threads(8)  # Number of threads numba uses for parallel code (must not exceed numba.config.NUMBA_NUM_THREADS).

sys.path.append(os.path.join(PGM_dir_path, 'source'))
sys.path.append(os.path.join(PGM_dir_path, 'utilities'))
import RBM_utils
import rbm, utilities

from zebrafish_rtrbm.models.RBM import RBM
from zebrafish_rtrbm.models.RTRBM import RTRBM

from zebrafish_rtrbm.utils.data_methods import reshape

## Load data

In [None]:
def load_dataset(path2dataset):
    with h5py.File(path2dataset, 'r') as f:
        brain = f['Data']['Brain']
        Coordinates = brain['Coordinates'][:].T  # Spatial coordinates, one row per neuron
        Labels = brain['Labels'][:].T.astype('bool')  # Zbrain region membership, shape (n_neurons, n_regions)
        Spikes = brain['Analysis']['ThresholdedSpikes'][:].astype('bool')  # Binarized activity, shape (T, n_neurons)

    mask = Labels.any(-1)  # Discard neurons not mapped to the Zbrain atlas.
    Spikes = Spikes[:, mask]
    Coordinates = Coordinates[mask]

    return Spikes, Coordinates
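
The atlas filtering in `load_dataset` keeps only neurons that belong to at least one Zbrain region. A toy illustration with made-up arrays (hypothetical values; the real `Labels` are read from the HDF5 file):

```python
import numpy as np

# Hypothetical stand-ins for the arrays read by load_dataset:
# toy_spikes has shape (T, n_neurons), toy_labels has shape (n_neurons, n_regions).
toy_spikes = np.array([[1, 0, 1],
                       [0, 1, 1]], dtype=bool)
toy_labels = np.array([[True, False],    # neuron 0 lies in region 0
                       [False, False],   # neuron 1 lies in no atlas region
                       [False, True]])   # neuron 2 lies in region 1
toy_coordinates = np.array([[0., 0., 0.], [1., 1., 1.], [2., 2., 2.]])

toy_mask = toy_labels.any(axis=-1)  # same result as toy_labels.max(-1) for boolean arrays
toy_spikes, toy_coordinates = toy_spikes[:, toy_mask], toy_coordinates[toy_mask]
print(toy_mask)          # [ True False  True]
print(toy_spikes.shape)  # (2, 2)
```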

In [None]:
data_dir = '../data/figure3_zebrafish'  # Directory containing the recordings; model files are saved here as well
# data_dir = 'C:/Users/luukh/OneDrive/Intern/RTRBM/data/figure3_zebrafish/'

list_datasets = [
    'fish1_20180706_Run04',
    'fish2_20180911_Run01',
    'fish3_20180912_Run01',
    'fish4_20180913_Run01',
    'fish5_20190109_Run04',
    'fish6_20181206_Run03',
    'fish7_20190102_Run01',
    'fish8_20181206_Run05',
]

In [None]:
dataset_idx = 3  # Index into list_datasets (0-based; 3 -> 'fish4_20180913_Run01').
path2dir = os.getcwd() + '/RTRBM/'

dataset = list_datasets[dataset_idx]

Spikes, Coordinates = load_dataset(data_dir + '/fish%s/rbm_%s.h5'%(dataset_idx+1, dataset))

T, n_v = Spikes.shape
print('Recording has %s time frames and %s neurons' %(T, n_v))


## Data splitting

The train/test split divides the recording of length T into 10 chronological segments of equal length (T // 10 frames each; the remaining T mod 10 frames at the end are discarded). The train set consists of segments 1, 3, 4, 5, 8, 9 and 10, and the test set of segments 2, 6 and 7. Because the recordings come from separate fish that are in different brain states at the start of and during each recording, the same train/test segments can be used for every fish.
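
The segment bookkeeping can be written out explicitly. This sketch uses a hypothetical helper `segment_frames` that returns the frame indices belonging to each split for a recording of T frames, with 1-based segment numbers as in the text above.

```python
import numpy as np

def segment_frames(T, n_segments=10,
                   train_segments=(1, 3, 4, 5, 8, 9, 10),
                   test_segments=(2, 6, 7)):
    """Return (train_frames, test_frames) for 1-based segment numbers."""
    seg_len = T // n_segments  # trailing T % n_segments frames are dropped

    def frames(segments):
        return np.concatenate([np.arange((k - 1) * seg_len, k * seg_len)
                               for k in segments])

    return frames(train_segments), frames(test_segments)

# A hypothetical recording of 1005 frames: segments of 100 frames, last 5 frames unused.
train_frames, test_frames = segment_frames(1005)
print(len(train_frames), len(test_frames))  # 700 300
```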

In [None]:
n_hidden_units = {
    'fish1_20180706_Run04': 200,
    'fish2_20180911_Run01': 200,
    'fish3_20180912_Run01': 200,
    'fish4_20180913_Run01': 200,
    'fish5_20190109_Run04': 100,
    'fish6_20181206_Run03': 200,
    'fish7_20190102_Run01': 100,
    'fish8_20181206_Run05': 100,
}

learning_rates = {
    'fish1_20180706_Run04': 1e-3,
    'fish2_20180911_Run01': 1e-3,
    'fish3_20180912_Run01': 1e-3,
    'fish4_20180913_Run01': 2.5e-4,
    'fish5_20190109_Run04': 1e-4,
    'fish6_20181206_Run03': 2.5e-4,
    'fish7_20190102_Run01': 1e-4,
    'fish8_20181206_Run05': 2.5e-4,
}

batch_sizes = {
    'fish1_20180706_Run04': 400,
    'fish2_20180911_Run01': 100,
    'fish3_20180912_Run01': 100,
    'fish4_20180913_Run01': 100,
    'fish5_20190109_Run04': 400,
    'fish6_20181206_Run03': 100,
    'fish7_20190102_Run01': 135,
    'fish8_20181206_Run05': 100,
}

In [None]:
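def train_test_split(data):
    """Split a (n_neurons, T) tensor into 10 chronological segments of equal length.

    Segments 1, 3, 4, 5, 8, 9, 10 (1-based) form the train set and segments 2, 6, 7
    the test set; the last T % 10 frames are discarded. Returns tensors of shape
    (n_neurons, segment_length, 7) and (n_neurons, segment_length, 3).
    """
    segment_length = data.shape[1] // 10
    train_idx = [0, 2, 3, 4, 7, 8, 9]  # 0-based indices of the train segments
    test_idx = [1, 5, 6]               # 0-based indices of the test segments
    train = torch.zeros(data.shape[0], segment_length, len(train_idx))
    test = torch.zeros(data.shape[0], segment_length, len(test_idx))

    for k, j in enumerate(train_idx):
        train[:, :, k] = data[:, j * segment_length:(j + 1) * segment_length]
    for k, j in enumerate(test_idx):
        test[:, :, k] = data[:, j * segment_length:(j + 1) * segment_length]

    return train, test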

## cRBM training
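
The cRBM below uses `hidden='dReLU'`. As a reminder of what that potential looks like, here is a sketch of the double Rectified Linear Unit potential in the form given by Tubiana, Cocco and Monasson, U(h) = γ₊h₊²/2 + θ₊h₊ + γ₋h₋²/2 + θ₋h₋, with h₊ = max(h, 0) and h₋ = min(h, 0). The function and parameter names are illustrative only; the PGM package parametrizes this potential internally and its variable names may differ. Setting γ₊ = γ₋ and θ₊ = θ₋ recovers the quadratic (Gaussian) potential, the case the code comment below links to the Hopfield model.

```python
import numpy as np

def drelu_potential(h, gamma_plus, theta_plus, gamma_minus, theta_minus):
    """dReLU potential: separate quadratic + linear terms for h >= 0 and h < 0."""
    h = np.asarray(h, dtype=float)
    h_pos = np.maximum(h, 0.0)  # h_+ = max(h, 0)
    h_neg = np.minimum(h, 0.0)  # h_- = min(h, 0)
    return (0.5 * gamma_plus * h_pos**2 + theta_plus * h_pos
            + 0.5 * gamma_minus * h_neg**2 + theta_minus * h_neg)

def gaussian_potential(h, gamma, theta):
    """Quadratic potential of a Gaussian hidden unit."""
    h = np.asarray(h, dtype=float)
    return 0.5 * gamma * h**2 + theta * h

h = np.linspace(-3, 3, 7)
# With identical parameters on both sides, dReLU reduces to the Gaussian potential.
print(np.allclose(drelu_potential(h, 1.0, 0.5, 1.0, 0.5), gaussian_potential(h, 1.0, 0.5)))  # True
```

The asymmetry between the two halves is what lets a dReLU hidden unit behave like a ReLU, a Gaussian, or a bimodal unit depending on the learned parameters.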

In [None]:
data_dir = '../data/figure3_zebrafish'

In [None]:
path2crbm = data_dir + '/fish%s/rbm_%s_test.data'%(dataset_idx+1, dataset)

n_h = n_hidden_units[dataset] # Number of hidden units.
l1 = 0.02 # Sparse regularization strength.
learning_rate = learning_rates[dataset] # Initial learning rate.
batch_size = batch_sizes[dataset] # Batch size / number of MCMC chains.
N_MC = 15 # Number of alternate Gibbs sampling steps performed between each gradient calculation (PCD algorithm).
n_updates = 1 # Total number of gradient descent updates performed. NOTE: with n_updates = 1, n_iter below evaluates to 0 epochs for any realistic recording length; increase it for actual training.
RBM = rbm.RBM(n_v = n_v, # Number of visible units (neurons).
              n_h = n_h, # Number of hidden units.
              visible = 'Bernoulli', # Nature of visible units (Bernoulli = 0/1 values).
              hidden = 'dReLU' # Nature of hidden units: double Rectified Linear Unit (dReLU) potential. hidden='Gaussian' reproduces the Hopfield model.
             )

# Segments 2, 6 and 7 form the test set. Flatten each split and transpose so that rows are time frames.
train, test = train_test_split(torch.tensor(Spikes.T))
train, test = np.array(reshape(train)).T, np.array(reshape(test)).T

n_iter = (n_updates // (train.shape[0] // batch_size))    # Number of epochs
print('Starting fit, %s epochs' % n_iter)

RBM.fit(train != 0,
        l1=l1, # sparse l1 regularization.
        n_iter=n_iter, # Number of epochs.
        learning_rate=learning_rate, # The learning rate
        batch_size=batch_size, # Batch size.
        N_MC=N_MC, # Number of MCMC steps.
        verbose=1,
        vverbose=0,
       )
print('Finished fit, %s epochs' % n_iter)
RBM = RBM_utils.swap_sign_RBM(RBM)
RBM_utils.saveRBM(path2crbm, RBM)
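
PGM's `fit` takes a number of epochs, so the cell above converts a budget of gradient updates into epochs, assuming one update per mini-batch. A sketch of that arithmetic with a hypothetical helper `n_epochs` and hypothetical sizes shows why `n_updates = 1` yields zero epochs:

```python
def n_epochs(n_updates, n_frames, batch_size):
    """Convert a number of gradient updates into whole epochs (rounded down).

    Assumes one gradient update per mini-batch of `batch_size` frames.
    """
    updates_per_epoch = n_frames // batch_size
    return n_updates // updates_per_epoch

# Hypothetical sizes: 7000 training frames, batches of 100 -> 70 updates per epoch.
print(n_epochs(1, 7000, 100))     # 0
print(n_epochs(7000, 7000, 100))  # 100
```

Because of the floor division, the requested number of updates is only a lower bound on the budget rounded to whole epochs.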

## RTRBM training

In [None]:
path2rtrbm = data_dir + '/fish%s/rtrbm_%s_transfer'%(dataset_idx+1, dataset)

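# Segments 2, 6 and 7 form the test set.
train, test = train_test_split(torch.tensor(Spikes.T))
train, test = reshape(train), reshape(test)

# NOTE: this line expects `T`, `n_batches_train` and `n_batches_test` to be per-dataset dictionaries,
# which are not defined in this notebook section (above, `T` is an int: the number of time frames).
train, test = reshape(train, T=T[dataset], n_batches=n_batches_train[dataset]), reshape(test, T=T[dataset], n_batches=n_batches_test[dataset])

# Transfer learning: give the RTRBM as many hidden units as the cRBM and copy the cRBM weights into W.
rtrbm = RTRBM(train, n_hidden=RBM.weights.shape[0])
rtrbm.W = torch.tensor(RBM.weights, dtype=torch.float, device=rtrbm.device)
# Rebuild the parameter list so that it references the new W tensor.
rtrbm.params = [rtrbm.W, rtrbm.U, rtrbm.b_h, rtrbm.b_v, rtrbm.b_init]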
rtrbm.learn(n_epochs=5000, max_lr=1e-3, min_lr=5e-6, lr_schedule='geometric_decay',
            batch_size=batch_sizes[dataset], CDk=15, mom=0.9, wc=0.0002, sp=1e-6, x=2, n=1000)
torch.save(rtrbm, path2rtrbm)
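
For intuition about what the RTRBM adds on top of the cRBM, the sketch below computes the hidden-unit expectations of the standard RTRBM formulation of Sutskever, Hinton and Taylor, in which the temporal weights U carry information from r[t-1] to r[t]. It assumes Bernoulli (sigmoid) hidden units and borrows the parameter names W, U, b_h and b_init from the cell above; the hidden-unit type and internals of `zebrafish_rtrbm`'s `RTRBM` may differ, and `rtrbm_hidden_expectations` is a hypothetical helper.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def rtrbm_hidden_expectations(v, W, U, b_h, b_init):
    """Mean-field hidden expectations of a standard RTRBM.

    v: (T, n_v) binary visible sequence; W: (n_h, n_v); U: (n_h, n_h);
    b_h, b_init: (n_h,). Returns r of shape (T, n_h) with
      r[0] = sigmoid(W v[0] + b_init)
      r[t] = sigmoid(W v[t] + U r[t-1] + b_h)   for t >= 1
    """
    T = v.shape[0]
    r = np.zeros((T, W.shape[0]))
    r[0] = sigmoid(W @ v[0] + b_init)
    for t in range(1, T):
        r[t] = sigmoid(W @ v[t] + U @ r[t - 1] + b_h)
    return r
```

With U = 0 and b_init = b_h, each r[t] depends only on v[t] and the model reduces to a static RBM; the learned U is what encodes the time-shifted couplings between assemblies.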
