Import the libraries we will need.

In [None]:
%matplotlib inline
import scipy.io
from scipy.stats import stats
import numpy as np
import brainiak.funcalign.srm.SRM as SRM
import brainiak.funcalign.rsrm.RSRM as RSRM
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt

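Experiment setup: data dimensions, number of subjects, number of shared features, and the signal-to-noise ratio (SNR) of the synthetic data.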

In [None]:
voxels = 100    # rows of each subject's data matrix
samples = 200   # columns of each subject's data matrix
subjects = 10   # number of synthetic subjects
features = 3    # dimensionality of the shared response (the synthetic curve is 3-D)
snr = 20        # desired signal-to-noise ratio, in dB

# Seed NumPy's global RNG so the random loadings, the noise, and therefore the
# fitted models and the figure are reproducible under Restart & Run All.
seed = 0
np.random.seed(seed)

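Now we create some synthetic data. The shared response is a 3-D spiral. Each subject sees it through its own random matrix with orthonormal columns, plus Gaussian noise scaled to the requested SNR.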

In [None]:
# Create a shared response R with K = 3: a spiral around the z axis whose radius
# grows as z**2 + 1. Rows are the three coordinates, columns are samples.
theta = np.linspace(-4 * np.pi, 4 * np.pi, samples)
z = np.linspace(-2, 2, samples)
r = z ** 2 + 1
x = r * np.sin(theta)
y = r * np.cos(theta)
curve = np.vstack((x, y, z))
# The curve is 3-D by construction, so `features` must match it.
assert curve.shape == (features, samples), "the synthetic curve is 3-D; set features = 3"
print('Print curve max, min values:', np.max(curve), np.min(curve))

# Create the subjects' data: data[s] = W[s] @ R, where W[s] has orthonormal columns
# (reduced QR of a Gaussian matrix), so ||W[s] @ R||_F == ||R||_F for every subject.
R = curve
data = [None] * subjects
W = [None] * subjects
signal_energy = 0.0
for s in range(subjects):
    W[s], _ = np.linalg.qr(np.random.randn(voxels, features))
    data[s] = W[s].dot(R)
    signal_energy += np.sum(data[s] ** 2)

# Per-element noise std such that total signal power / total noise power
# equals 10 ** (snr / 10), i.e. `snr` dB.
noise_level = np.sqrt(signal_energy / 10 ** (snr / 10) / (subjects * voxels * samples))
print(noise_level)

for s in range(subjects):
    signal = data[s]
    n = noise_level * np.random.randn(voxels, samples)
    data[s] = signal + n
    # Empirical SNR in dB: power of the clean signal over power of the added noise
    # (should be close to `snr`).
    print(10 * np.log10(np.sum(signal ** 2) / np.sum(n ** 2)))

Now we fit both algorithms to the synthetic data: SRM (Shared Response Model) and RSRM (Robust SRM).

In [None]:
# Both models estimate a `features`-dimensional shared response from the list of
# per-subject matrices in `data`.
n_iter = 20   # training iterations for both models
gamma = 0.35  # RSRM regularization parameter (see the brainiak RSRM docstring); value used in this demo

srm = SRM(features=features, n_iter=n_iter)
srm.fit(data)

rsrm = RSRM(features=features, gamma=gamma, n_iter=n_iter)
rsrm.fit(data);  # trailing ';' suppresses the estimator repr in the cell output

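The estimated shared response need not be in the same coordinates as the original curve. The following function therefore finds the orthogonal transform that best aligns the estimate to the curve (the orthogonal Procrustes problem) and returns the aligned response.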

In [None]:
def find_orthogonal_transform(shared_response, target=None):
    """Align a shared response to a target by an orthogonal (Procrustes) transform.

    Solves  Q* = argmax_Q trace(Q.T @ shared_response @ target.T)  over matrices Q
    with orthonormal columns. The solution is Q* = U @ Vt, where
    U, S, Vt = svd(shared_response @ target.T). For square inputs this is the
    classic orthogonal Procrustes solution.

    Parameters
    ----------
    shared_response : ndarray, shape (k, n_samples)
        Shared response estimated by a model (e.g. ``srm.s_`` or ``rsrm.r_``).
    target : ndarray, shape (k_target, n_samples), optional
        Response to align to. Defaults to the global ``curve`` built in the
        data-generation cell (the original behaviour).

    Returns
    -------
    ndarray, shape (k_target, n_samples)
        ``Q*.T @ shared_response``, i.e. the shared response in the target's coordinates.
    """
    if target is None:
        target = curve
    # Thin SVD: identical to the full SVD when the cross-product is square, and
    # still well-defined when the two responses have different numbers of rows.
    u, _, vt = np.linalg.svd(shared_response.dot(target.T), full_matrices=False)
    q = u.dot(vt)
    aligned_curve = q.T.dot(shared_response)
    return aligned_curve

Plot the original curve together with the SRM and RSRM shared responses, each aligned to the curve.

In [None]:
fig = plt.figure()
# Figure.gca() no longer accepts projection kwargs (removed in Matplotlib 3.6);
# add_subplot(111, projection='3d') works across versions.
ax = fig.add_subplot(111, projection='3d')
# Line2D property names are case-sensitive: use `linewidth`, not `lineWidth`.
ax.plot(curve[0, :], curve[1, :], curve[2, :], '-g', label='original', linewidth=5)

# Overlay each model's shared response after rotating it onto the original curve.
for estimate, name, style in ((srm.s_, 'SRM', '-b'), (rsrm.r_, 'RSRM', '-r')):
    proj = find_orthogonal_transform(estimate)
    ax.plot(proj[0, :], proj[1, :], proj[2, :], style, label=name, linewidth=3)

ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_zlabel('z')
ax.set_title('Original vs. recovered shared response (after orthogonal alignment)')
ax.legend()
plt.show()

