# Create toy data

We will create a simulation of brain waves. Such data might arise when using multiple electrodes to record electrical brain waves over a period of time. Specifically, we simulate two different conditions, each of which produces a different pattern of wave propagation across the electrodes.


In [1]:
import numpy as np
import matplotlib.pyplot as plt

T = 5  # trial length in seconds
fs = 20  # sampling rate in Hz
f = 1  # wave frequency in Hz
n_channels = 10
train_frac = 0.8  # fraction of each condition (its earliest samples) used for training
rng = np.random.default_rng(0)  # fixed seed so Restart & Run All reproduces the same noise

# Time axis and a unit-amplitude complex oscillation sampled on it.
t = np.linspace(0, T, int(round(T * fs)), endpoint=False)
wave = np.exp(1j * 2 * np.pi * f * t)

# Phase delays across electrodes: in pattern 1 the phase increases with channel index,
# in pattern 2 it decreases. Both are normalized to unit length so projections are comparable.
p1 = np.exp(1j * np.linspace(0, np.pi * 2, n_channels, endpoint=False))
p2 = np.exp(1j * np.linspace(0, -np.pi * 2, n_channels, endpoint=False))
p1 = p1 / np.linalg.norm(p1)
p2 = p2 / np.linalg.norm(p2)

# Stack condition 1 (pattern 1) above condition 2 (pattern 2): rows are time samples,
# columns are channels. Then add complex Gaussian noise (std 0.1 per real/imag component).
data = np.vstack((np.outer(wave, p1), np.outer(wave, p2)))
data = data + (rng.standard_normal(data.shape) + 1j * rng.standard_normal(data.shape)) / 10
labels = np.hstack((np.zeros(t.size), np.ones(t.size)))  # label samples of each condition 0 and 1

# Train on the first `train_frac` of each condition and test on the remainder.
in_train = t < train_frac * T
inds_train = np.hstack((in_train, in_train))
inds_test = ~inds_train

# Project the data onto each pattern (conjugate the pattern for the complex inner product).
proj1 = data @ p1.conj()
proj2 = data @ p2.conj()

fig, axs = plt.subplots(3, 1, figsize=[10, 5])

# Top panel: real part of the raw multichannel data (time steps on x, channels on y).
axs[0].imshow(data.real.T, aspect='auto')
axs[0].set_ylabel('Channel #')
axs[0].set_xticks([])
axs[0].set_yticks([0, n_channels - 1])
axs[0].set_title('Data')

# Lower panels: the data projected onto each pattern. Both panels get the same
# real/imag/magnitude traces and the same legend.
for ax, proj, title in ((axs[1], proj1, 'Pattern 1 projection'),
                        (axs[2], proj2, 'Pattern 2 projection')):
    ax.plot(proj.real, 'b')
    ax.plot(proj.imag, 'b--')
    ax.plot(np.abs(proj), 'r')
    ax.autoscale(enable=True, axis='both', tight=True)
    ax.legend(['Real', 'Imag', 'Abs'], loc='right')
    ax.set_title(title)

axs[1].set_xticks([])
axs[2].set_xlabel('Time step')
fig.tight_layout()  # keep the stacked panel titles from overlapping


# Install TIMBRE

Next, we install TIMBRE, a package that builds a neural network which learns the multi-channel patterns in the data that best predict the label associated with each sample.

In [6]:
# Install the local TIMBRE package (editable mode) into the kernel's own environment;
# `%pip`, unlike `!pip`, targets the interpreter that is running this notebook.
# NOTE(review): the recorded run failed because the package pins numpy==1.26.4, which has
# no release for the Python 3.7 (cp37) interpreter used there -- run on a newer Python.
%pip install -e TIMBRE

Obtaining file:///C:/Users/infin/Box/college/TIMBRE/TIMBRE
Collecting tensorflow>=2.0.0 (from beatLab==0.1)
  Downloading https://files.pythonhosted.org/packages/55/d1/a3631a36859ee324e1767fa7554fdf7af17965571d8537b20b311b76bcfe/tensorflow-2.11.0-cp37-cp37m-win_amd64.whl
Collecting hdf5storage (from beatLab==0.1)
  Downloading https://files.pythonhosted.org/packages/ec/29/ed9f2df3e77400b5312787b4ade31791e8eca91a39a7ccd80677490f4ea5/hdf5storage-0.1.19-py2.py3-none-any.whl (53kB)
Collecting numpy==1.26.4 (from beatLab==0.1)


  ERROR: Could not find a version that satisfies the requirement numpy==1.26.4 (from beatLab==0.1) (from versions: 1.3.0, 1.4.1, 1.5.0, 1.5.1, 1.6.0, 1.6.1, 1.6.2, 1.7.0, 1.7.1, 1.7.2, 1.8.0, 1.8.1, 1.8.2, 1.9.0, 1.9.1, 1.9.2, 1.9.3, 1.10.0.post2, 1.10.1, 1.10.2, 1.10.4, 1.11.0, 1.11.1, 1.11.2, 1.11.3, 1.12.0, 1.12.1, 1.13.0, 1.13.1, 1.13.3, 1.14.0, 1.14.1, 1.14.2, 1.14.3, 1.14.4, 1.14.5, 1.14.6, 1.15.0, 1.15.1, 1.15.2, 1.15.3, 1.15.4, 1.16.0, 1.16.1, 1.16.2, 1.16.3, 1.16.4, 1.16.5, 1.16.6, 1.17.0, 1.17.1, 1.17.2, 1.17.3, 1.17.4, 1.17.5, 1.18.0, 1.18.1, 1.18.2, 1.18.3, 1.18.4, 1.18.5, 1.19.0, 1.19.1, 1.19.2, 1.19.3, 1.19.4, 1.19.5, 1.20.0, 1.20.1, 1.20.2, 1.20.3, 1.21.0, 1.21.1, 1.21.2, 1.21.3, 1.21.4, 1.21.5, 1.21.6)
ERROR: No matching distribution found for numpy==1.26.4 (from beatLab==0.1)


# Run TIMBRE on simulated data

We will train the network. By default, the network has one node per class. Since there are two classes, there are two nodes.

In [3]:
from TIMBRE.TIMBRE import TIMBRE

# Train TIMBRE on the toy data. Of the three returned values, `m` is used later as the
# model (its `.layers` are inspected) and `fm` as the fit history (`fm.history[...]`);
# the third is discarded.
# NOTE(review): the test indices are passed *before* the train indices -- confirm this
# matches the TIMBRE(...) signature.
m, fm, _ = TIMBRE(data, labels, inds_test, inds_train)  # train neural network without hidden layer

ModuleNotFoundError: No module named 'keras'

# Examine performance

We will observe how the network's performance improves during training.

In [None]:
from TIMBRE.helpers import layer_output

# Training curves: one panel per metric, each comparing the training set with the
# held-out test set across epochs.
curve_specs = [
    ('accuracy', 'val_accuracy', 'Accuracy'),
    ('loss', 'val_loss', 'Loss'),
]
fig, axs = plt.subplots(1, 2, figsize=(10, 5))
for ax, (train_key, test_key, panel_title) in zip(axs, curve_specs):
    ax.plot(fm.history[train_key])
    ax.plot(fm.history[test_key])
    ax.legend(['Train', 'Test'])
    ax.set_title(panel_title)
    ax.set_xlabel('Training epoch')

# Visualize network activity

Finally, we visualize the response of the trained model's layers to the input:
1.  **Complex-valued projection of the input.** Note that each of the two nodes learns one of the two patterns present in the data.
2.  **Amplitude of the projection.** This discards the phase of the projection, so we know how much of each pattern is present.
3.  **Softmax of the amplitude.** This converts the response to a probability distribution that sums to 1.

In [None]:
# One panel per layer of the trained model `m`, showing that layer's response to `data`.
titles = ['Projection', 'Amplitude', 'Softmax', 'Softmax 2']
styles = ['b', 'r', 'b--', 'r--']
n_layers = len(m.layers)
fig1, axs1 = plt.subplots(1, n_layers, figsize=(20, 5))
axs1 = np.atleast_1d(axs1)  # a single-panel figure returns a bare Axes; make it indexable
# NOTE(review): X is never used below -- layer_output is called with the raw complex `data`.
# Presumably layer_output performs this real/imag split itself; confirm, then drop X.
X = np.concatenate((np.real(data), np.imag(data)), axis=1)  # preprocess
for i in range(n_layers):  # plot the output of each layer in the network
    pr = layer_output(data, m, i)
    n_cols = pr.shape[1]
    for j in range(n_cols):
        axs1[i].plot(pr[:, j], styles[j])
    if i == 0:
        # Projection layer: assumed column order is the real part of every node,
        # then the imaginary part of every node (the convention the styles above follow).
        half = n_cols // 2
        legend_labels = ([f'Node {k + 1} (real)' for k in range(half)]
                         + [f'Node {k + 1} (imag)' for k in range(half)])
    else:
        legend_labels = [f'Node {k + 1}' for k in range(n_cols)]
    # Derive each legend from the layer's own width so every panel gets one and
    # models with fewer (or more) layers do not index past the available axes.
    axs1[i].legend(legend_labels)
    axs1[i].set_title(titles[i] if i < len(titles) else f'Layer {i + 1}')
    axs1[i].set_xlabel('Time step')
    axs1[i].autoscale(enable=True, axis='both', tight=True)