# A comparison on how the cRBM and the cRTRBM explain the underlying structure and dynamics in zebrafish data

In [3]:
%matplotlib inline
import torch
import numpy as np
import matplotlib.pyplot as plt
import pickle
from matplotlib.pyplot import figure
import seaborn as sns
from tqdm import tqdm
import h5py
plt.rcParams['figure.figsize'] = [8, 5]

import sys
sys.path.append(r'D:\OneDrive\RU\Intern\rtrbm_master')

from utils.plots import *
from boltzmann_machines.RTRBM import RTRBM
from boltzmann_machines.RBM import RBM
from utils.funcs import *
from utils.visualize_hidden_network import *
from utils.create_param_class import Parameters
from utils.reshape_data import *


import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
path = os.path.dirname(os.getcwd())

num_data_sets = 18
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Using {} device'.format(device))
torch.cuda.set_device(1)
torch.cuda.current_device()


ModuleNotFoundError: No module named 'utils.plots'; 'utils' is not a package

## Import zebrafish surrogate data, train cRBM and cRTRBM and save the parameters


In [None]:
# initialize class to save parameters
cRBM_parameters = Parameters()
cRTRBM_parameters = Parameters()

# define number of hiddens and epochs  
N_H = 40
n_epochs = 300
num_data_sets = 18

for i in tqdm(range(1,num_data_sets+1)):
    
    # import data
    data = h5py.File('/mnt/data/zebrafish/chen2018/subject_' +str(i)+'/Deconvolved/subject_'+str(i)+'_reconv_spikes.h5', 'r')

    # define coordinates and spikes
    spikes = torch.tensor(data['Data']['spikes'])
    xyz = torch.tensor(data['Data']['coords'])

    # disregard neurons that dont fire 
    xyz = xyz[torch.sum(spikes,1)!=0, :]
    spikes = spikes[torch.sum(spikes,1)!=0, :]

    # reduce dataset by taking voxels
    [voxel_spike, voxel_xyz] =  make_voxel_xyz(n = 25, spikes = spikes, xyz = xyz , mode = 1, fraction = 0.1, disable_tqdm = True)

    # make the spiking behaviour binairy
    spike_thres = torch.sort(voxel_spike.ravel(), descending=True)[0][int(np.ceil(0.15*(voxel_spike.shape[0]*voxel_spike.shape[1])))]
    voxel_spike[voxel_spike<=spike_thres] = 0
    voxel_spike[voxel_spike>spike_thres] = 1
    N_V, T = voxel_spike.shape

    # reshape data in train and test batches
    train_data, test_data = generate_train_test(voxel_spike, train_data_ratio=0.75, mode=1)

    # transfer dataset to cudo GPU:1
    torch.cuda.set_device(1)
    device='cuda:1'
    train_data = torch.tensor(train_data, device=device)
    test_data = torch.tensor(test_data, device=device)
    
    # define cRBM and train on GPU:1
    cRBM = RBM(train_data, N_H=N_H, device = device)
    cRBM.learn(n_epochs=n_epochs, lr=1e-4, sp=3e-4, x=1, batchsize= 1, disable_tqdm = True)
    cRBM.add_test_data_to_class(test_data)
    cRBM.add_xyz_to_class(voxel_xyz)
    
    # CRBM
    pickle.dump(cRBM, open(path + '/Results/cRBM_40HU_lr1e-4_sp3e-4_x1_subject_'+str(i), 'wb'))

    # transfer dataset in batches to cudo GPU:0
    torch.cuda.set_device(0)
    device='cuda:0'
    train_data = torch.tensor(train_data, device=device)
    test_data = torch.tensor(test_data, device=device)

    # define cRTRBM and train on GPU:0
    cRTRBM = RTRBM(train_data, N_H=N_H, device = device)
    cRTRBM.learn(n_epochs=n_epochs, lr=1e-4, sp=3e-5, x=2, batchsize=1, disable_tqdm = True)
    cRTRBM.add_test_data_to_class(test_data)
    cRTRBM.add_xyz_to_class(voxel_xyz)

    # Safe cRTRBM class
    pickle.dump(cRTRBM, open(path + '/Results/cRTRBM_40HU_lr1e-4_sp3e-5_x2_subject_'+str(i), 'wb'))


  train_data = torch.tensor(train_data, device=device)
  test_data = torch.tensor(test_data, device=device)
  train_data = torch.tensor(train_data, device=device)
  test_data = torch.tensor(test_data, device=device)
  6%|▌         | 1/18 [21:31<6:05:49, 1291.14s/it]

##  VH, spikes grouped by strongest connecting HU and hidden unit activity of the cRBM

In [None]:
for i in range(1,num_data_sets+1):
    cRBM = pickle.load(open(path + '/Results/cRBM_40HU_lr1e-4_sp3e-4_x1_subject_'+str(i), 'rb'))
    T = cRBM.data.shape[1]
    rt = torch.zeros(cRBM.N_H, T)
    for t in range(T):
        rt[:, t], _ = cRBM.visible_to_hidden(cRBM.data[:,t])
    

    plot_spikes_grouped_by_HU(VH=cRBM.W.cpu(), V=cRBM.data.cpu(), H=rt)

##  VH, spikes grouped by strongest connecting HU and hidden unit activity of the cRTRBM

In [None]:
for i in range(1,num_data_sets+1):
    device='cuda:0'
    cRTRBM = pickle.load(open(path + '/Results/cRTRBM_40HU_lr1e-4_sp3e-5_x2_subject_'+str(i), 'rb'))
    a = cRTRBM.V.shape
    rt = torch.zeros([cRTRBM.N_H, a[1]*a[2]])
    V = torch.zeros([a[0], a[1]*a[2]])
    for j in range(a[2]):
        rt[:,a[1]*j:a[1]*(j+1)] = cRTRBM.visible_to_expected_hidden(cRTRBM.V[:,:,j])
        V[:,a[1]*j:a[1]*(j+1)] = cRTRBM.V[:,:,j]
    plot_spikes_grouped_by_HU(VH=cRTRBM.W.cpu(), V=V.cpu(), H=rt.cpu())

## Receptive field cRBM

In [None]:
%matplotlib inline
import matplotlib as mpl

for subject in range(1, num_data_sets+1):
    cRBM = pickle.load(open(path + '/Results/cRBM_40HU_lr1e-4_sp3e-4_x1_subject_'+str(subject), 'rb'))

    VH = cRBM.W.cpu().clone()
    coordinates = cRBM.xyz
    
    VH[VH<0] = 0

    rf = get_hidden_mean_receptive_fields(VH, coordinates)
    fig, ax = plt.subplots(figsize=(6,6))

    ax.scatter(coordinates[:, 0], coordinates[:, 1], s =20, edgecolors = 'b')
    ax.scatter(rf[:, 0], rf[:, 1], s=20, edgecolors = 'r')

    ax.set_xlabel('x')
    ax.set_ylabel('y')
    plt.show()


In [None]:
import matplotlib as mpl

for subject in range(1, num_data_sets+1):
    cRBM = pickle.load(open(path + '/Results/cRBM_40HU_lr1e-4_sp3e-4_x1_subject_'+str(subject), 'rb'))

    VH = cRBM.W.cpu().detach().clone()
    coordinates = cRBM.xyz.detach().clone()
    fig, ax = plt.subplots(16, 4, figsize=(22, 70))
    max_hidden_connection = torch.max(VH, 0)[1]
    strong_thresh = 0.5 * torch.std(VH)
    h = 0
    for i in range(16):
        for j in range(4):
            idx_p = (max_hidden_connection==h)*(VH[h,:] >  strong_thresh)
            idx_m = (max_hidden_connection==h)*(VH[h,:] < -strong_thresh)
            #idx_not_p = (max_hidden_connection==h)*((VH[h,:] <=  strong_thresh) & (VH[h,:]>0))
            #idx_not_m = (max_hidden_connection==h)*((VH[h,:] >= -strong_thresh) & (VH[h,:]<0))
            
            ax[i,j].scatter(coordinates[:,0], coordinates[:,1], s=15, color = 'blue', alpha=0.01)

            ax[i,j].scatter(coordinates[idx_p,0], coordinates[idx_p,1], s=25, color = 'green', marker='^')
            ax[i,j].scatter(coordinates[idx_m,0], coordinates[idx_m,1], s=25, color = 'red', marker="v")

            #ax[i,j].scatter(coordinates[idx_not_p,0], coordinates[idx_not_p,1], s=25, color = 'red', marker='^', alpha=0.3)
            #ax[i,j].scatter(coordinates[idx_not_m,0], coordinates[idx_not_m,1], s=25, color = 'red', marker="v", alpha=0.3)
                
            h +=1
            
    #mpl.style.use('seaborn')

## Receptive field cRTRBM

In [None]:
%matplotlib inline
import matplotlib as mpl

for i in range(1, num_data_sets+1):
    cRTRBM = pickle.load(open(path + '/Results/cRTRBM_40HU_lr1e-4_sp3e-5_x2_subject_'+str(i), 'rb'))
    VH = cRTRBM.W.cpu().clone()
    coordinates = cRTRBM.xyz.cpu().clone()

    VH[VH<0] = 0

    rf = get_hidden_mean_receptive_fields(VH, coordinates)
    fig, ax = plt.subplots(figsize=(8,8))

    ax.scatter(coordinates[:, 0], coordinates[:, 1], s =20, edgecolors = 'b')
    ax.scatter(rf[:, 0], rf[:, 1], s=20, edgecolors = 'r')

    ax.set_xlabel('x')
    ax.set_ylabel('y')
    plt.show()



In [None]:
for subject in range(1, num_data_sets+1):
    cRTRBM = pickle.load(open(path + '/Results/cRTRBM_40HU_lr1e-4_sp3e-5_x2_subject_'+str(i), 'rb'))
    VH = cRTRBM.W.cpu().clone()
    coordinates = cRTRBM.xyz.cpu().clone()
    fig, ax = plt.subplots(16, 4, figsize=(22, 70))
    max_hidden_connection = torch.max(VH, 0)[1]
    strong_thresh = 0.5 * torch.std(VH)
    h = 0
    for i in range(16):
        for j in range(4):
            idx_p = (max_hidden_connection==h)*(VH[h,:] >  strong_thresh)
            idx_m = (max_hidden_connection==h)*(VH[h,:] < -strong_thresh)
            #idx_not_p = (max_hidden_connection==h)*((VH[h,:] <=  strong_thresh) & (VH[h,:]>0))
            #idx_not_m = (max_hidden_connection==h)*((VH[h,:] >= -strong_thresh) & (VH[h,:]<0))
            
            ax[i,j].scatter(coordinates[:,0], coordinates[:,1], s=15, color = 'blue', alpha=0.01)

            ax[i,j].scatter(coordinates[idx_p,0], coordinates[idx_p,1], s=25, color = 'green', marker='^')
            ax[i,j].scatter(coordinates[idx_m,0], coordinates[idx_m,1], s=25, color = 'red', marker="v")

            #ax[i,j].scatter(coordinates[idx_not_p,0], coordinates[idx_not_p,1], s=25, color = 'red', marker='^', alpha=0.3)
            #ax[i,j].scatter(coordinates[idx_not_m,0], coordinates[idx_not_m,1], s=25, color = 'red', marker="v", alpha=0.3)
                
            h +=1
            
    #mpl.style.use('seaborn')

In [None]:
#del create_plot, line_between_two_neurons
#from utils.visualize_hidden_network import create_plot

#create_plot(crtrbm.W, crtrbm.W_acc, rf, coordinates, dy=0.1, markersize_visibles=50, hiddens_radius=0.04)

## Compare the moments of the cRBM and the cRTRBM

In [None]:
for i in range(1, num_data_sets+1):
    cRTRBM = pickle.load(open(path + '/Results/cRTRBM_40HU_lr1e-4_sp3e-5_x2_subject_'+str(i), 'rb'))
    cRBM = pickle.load(open(path + '/Results/cRBM_40HU_lr1e-4_sp3e-4_x1_subject_'+str(i), 'rb'))
    idx = np.random.randint(0, cRTRBM.V.shape[2]+1)
    
    train_data = cRTRBM.V[:,:,idx].detach().clone()
    train_data = train_data[:,:,None]
    test_data = cRTRBM.test_data[:,:,idx].detach().clone()
    test_data = test_data[:,:,None]
    
    n_batches = train_data.shape[2]
    plot_compare_moments(cRBM, cRTRBM, train_data, test_data, MC_chains=n_batches, \
                         chain=50, pre_gibbs_k=10, gibbs_k=20, config_mode=2)