Imports

In [1]:
import matplotlib.pyplot as plt
import torch
from scipy.stats import pearsonr
from tqdm import tqdm

from data.reshape_data import reshape, reshape_from_batches
from data.poisson_data_v import PoissonTimeShiftedData
from utils.funcs import pairwise_moments, get_reconstruction_mean_pairwise_correlations
from boltzmann_machines.cp_rtrbm import RTRBM

Simulation parameters

In [2]:
# Network size. The data generator below receives n_h populations of
# n_v // n_h neurons each; n_h also sets the size of the true temporal
# connection matrix (n_h x n_h).
n_v, n_h = 100, 10

# Settings forwarded to PoissonTimeShiftedData.
delay = 1  # temporal delay used by the simulated dynamics
frequency_range = [5, 10]
amplitude_range = [0.4, 0.5]
phase_range = [0, torch.pi]

# RTRBM hidden-layer sizes to sweep over: 2, 4, ..., 20.
hiddens_range = list(range(2, 22, 2))
# Number of independently trained RTRBMs per hidden-layer size.
N = 5

In [None]:
# Sweep over RTRBM hidden-layer sizes. For every size in `hiddens_range`,
# train N RTRBMs, each on freshly simulated data, and record per run:
#   rec_errors : MSE between inferred and true test activity on the first half
#                of every test batch (the part given to `rtrbm.infer`),
#   rec_corr, mean_corr, pw_corr : the three values returned by
#                get_reconstruction_mean_pairwise_correlations, evaluated on
#                the second half of every test batch.
# All result tensors are indexed as [hidden-size index, repetition].
rtrbm_list = []
rec_errors = torch.zeros(len(hiddens_range), N)
rec_corr = torch.zeros(len(hiddens_range), N)
mean_corr = torch.zeros(len(hiddens_range), N)
pw_corr = torch.zeros(len(hiddens_range), N)

# `total` must be given explicitly: tqdm cannot take len() of an enumerate
# object, so without it the bar shows only a raw count ("0it [...]") with no
# progress fraction or ETA.
for i, n_hidden in tqdm(enumerate(hiddens_range), total=len(hiddens_range)):

    for n in range(N):

        # Random coupling between the n_h simulated populations, scaled by 1/n_h.
        temporal_connections = torch.randn(n_h, n_h) / n_h

        # Simulate one long recording of 35000 time steps.
        gaus = PoissonTimeShiftedData(
            neurons_per_population=n_v//n_h, n_populations=n_h, n_batches=1, time_steps_per_batch=35000,
            fr_mode='gaussian', delay=delay, temporal_connections=temporal_connections, norm=1,
            frequency_range=frequency_range, amplitude_range=amplitude_range, phase_range=phase_range
        )

        # Cut the recording into 350 batches of T=100 steps (batch index is the
        # last axis); the first 280 batches train, the remaining 70 test.
        data = reshape(gaus.data[..., 0], T=100, n_batches=350)
        train, test = data[..., :280], data[..., 280:]

        # Initialize and train an RTRBM with `n_hidden` hidden units.
        rtrbm = RTRBM(train, N_H=n_hidden, device="cpu", debug_mode=False)
        rtrbm.learn(batch_size=10, n_epochs=2, max_lr=1e-3, min_lr=8e-4, lr_schedule='geometric_decay', CDk=10, mom=0.6, wc=0.0002, sp=0, x=1, disable_tqdm=True)

        # Sanity-check plot for the very first run only: training error curve,
        # the true temporal connections, and the learned U and W matrices.
        if i == 0 and n == 0:
            fig, ax = plt.subplots(1, 4, figsize=(16, 4))
            ax[0].plot(rtrbm.errors)
            ax[0].set_xlabel('Epochs')
            ax[0].set_ylabel('RMSE')
            ax[1].imshow(gaus.temporal_connections, aspect='auto', cmap='RdYlGn')
            ax[1].set_title('True hidden connections')
            cm = ax[2].imshow(rtrbm.U, aspect='auto', cmap='RdYlGn')
            fig.colorbar(cm, ax=ax[2])
            ax[2].set_title('rtrbm.U')
            cm = ax[3].imshow(rtrbm.W, aspect='auto', cmap='RdYlGn')
            fig.colorbar(cm, ax=ax[3])
            ax[3].set_title('rtrbm.W')

        # Run inference on every test batch, conditioning on the first half.
        # Size the buffer from the data itself rather than the global `n_v`:
        # the generator produces (n_v // n_h) * n_h neurons, which differs from
        # n_v whenever n_v is not divisible by n_h.
        # NOTE(review): assumes rtrbm.infer returns a visible tensor covering
        # all T steps of the batch (i.e. it extends the conditioned half) —
        # TODO confirm against cp_rtrbm.RTRBM.infer.
        T, n_batches = test.shape[1], test.shape[2]
        vs = torch.zeros(test.shape[0], T, n_batches)
        for batch in range(n_batches):
            vs[:, :, batch], _ = rtrbm.infer(test[:, :T//2, batch], mode=1, pre_gibbs_k=100, gibbs_k=100, disable_tqdm=True)

        # Reconstruction error on the conditioned half; correlation metrics on
        # the second half.
        rec_errors[i, n] = torch.mean((vs[:, :T//2, :] - test[:, :T//2, :])**2)
        rec_corr[i, n], mean_corr[i, n], pw_corr[i, n] = get_reconstruction_mean_pairwise_correlations(test[:, T//2:, :], vs[:, T//2:, :])

        # Keep every trained model (len(hiddens_range) * N in total) for later
        # inspection; note they all stay in memory.
        rtrbm_list.append(rtrbm)

0it [00:00, ?it/s]