# Run `ProSVD`

In [82]:
import numpy as np
import matplotlib.pylab as plt
import scipy.signal as signal
import mat73 # for loading .mat files
import scipy.io as sio

# from bubblewrap import Bubblewrap
from proSVD import proSVD
from tqdm import tqdm

In [80]:
data = sio.loadmat(r'spont_M150824_MP019_20160405.mat', squeeze_me=True, variable_names='Fsp')['Fsp']
data = data.T

framerate = 3 # Hz
bin_size_ms = 1/framerate * 1000

In [98]:
#%%
# proSVD params
k = 4 # reduced dimension
trueSVD = False # whether proSVD should track true SVD basis (a little slower)
l1 = 100 # columns used to initialize
decay = 1 # 1 = effective window is all of data
num_iters = np.floor((data.shape[0] - l1 - l)/l).astype('int')
update_times = np.arange(l1, num_iters*l, l) # index of when updates happen (not including init)

In [99]:
# smoothing params
kern_sd = 500
smooth_filt = signal.gaussian(int(6 * kern_sd / bin_size_ms), int(kern_sd / bin_size_ms), sym=True)
smooth_filt /=  np.sum(smooth_filt)
data_init_smooth = np.apply_along_axis(lambda x, filt: np.convolve(x, filt, 'same'),
                                       0, data[:l1, :], filt=smooth_filt)





# initialize proSVD

data_init = data[:l1, :]

pro = proSVD(k=k, decay_alpha=decay, trueSVD=trueSVD, history=0)
pro.initialize(data_init_smooth.T)

# storing dimension-reduced data
data_red = np.zeros((data.shape[0], k))
data_red[:l1, :] = data_init_smooth @ pro.Q

# run online
# pro_end = int(data.shape[0]/8) # when proSVD will stop updating and bubblewrap starts
pro_diffs = []
smooth_window = data[l1-len(smooth_filt):l1, :]
for i, t in enumerate(tqdm(update_times)):
    start, end = t, t+pro.w_len
    dat_curr = data[start:end, :]
    smooth_window[:-1, :] = smooth_window[1:, :]
    smooth_window[-1, :] = dat_curr
    dat_smooth = smooth_filt @ smooth_window

    # proSVD updates
    # if t < pro_end:  # only update for first n steps
    pro.preupdate()
    pro.updateSVD(dat_smooth[:,None])
    pro.postupdate()
    pro_diffs.append(np.linalg.norm(pro.Q-pro.Q_prev, axis=0))

    # getting projected data
    data_red[start:end, :] = dat_smooth @ pro.Q

100%|███████████████████████████████████| 20854/20854 [00:14<00:00, 1417.41it/s]


In [33]:
# # smoothing params
# kern_sd = 50
# smooth_filt = signal.gaussian(int(6 * kern_sd / bin_size_ms), int(kern_sd / bin_size_ms), sym=True)
# smooth_filt /=  np.sum(smooth_filt)
# data_init_smooth = np.apply_along_axis(lambda x, filt: np.convolve(x, filt, 'same'),
#                                        0, data[:l1, :], filt=smooth_filt)

# initialize proSVD


data_init = data[:l1, :]

pro = proSVD(k=k, decay_alpha=decay, trueSVD=trueSVD, history=0)
pro.initialize(data_init.T)

# storing dimension-reduced data
data_red = np.zeros((data.shape[0], k))
data_red[:l1, :] = data_init @ pro.Q

# run online
# pro_end = int(data.shape[0]/8) # when proSVD will stop updating and bubblewrap starts
pro_diffs = []
# smooth_window = data[l1-len(smooth_filt):l1, :]
for i, t in enumerate(tqdm(update_times)):
    start, end = t, t+pro.w_len
    dat_curr = data[start:end, :]
    # smooth_window[:-1, :] = smooth_window[1:, :]
    # smooth_window[-1, :] = dat_curr
    # dat_smooth = smooth_filt @ smooth_window

    # proSVD updates
    # if t < pro_end:  # only update for first n steps
    pro.preupdate()
    pro.updateSVD(dat_curr.T)
    pro.postupdate()
    pro_diffs.append(np.linalg.norm(pro.Q-pro.Q_prev, axis=0))

    # getting projected data
    data_red[start:end, :] = dat_curr @ pro.Q

100%|███████████████████████████████████| 20854/20854 [00:17<00:00, 1167.94it/s]


In [94]:
%matplotlib qt
fig, ax = plt.subplots()
pro_diffs = np.array(pro_diffs)
ax.plot(np.arange(pro_diffs.shape[0])/framerate, pro_diffs)
ax.set( #title='proSVD stabilizes with <1 min of data',
       xlabel='seconds of data seen',
       ylabel=r'$|\Delta\mathbf{Q}|$');

In [54]:
beh = sio.loadmat(r'spont_M150824_MP019_20160405.mat', squeeze_me=True, struct_as_record=False, variable_names='beh')['beh']

In [101]:
obs = np.hstack([beh.runSpeed[:,None], beh.pupil.area[:,None], beh.face.motionSVD[:,:10]])

In [102]:
# y: (1, n, d)
# x: (n, d2)
np.savez("mouse_stringer_smoothed_4d.npz", x=obs, y=data_red[None,:,:])

In [75]:
data_red.shape

(21055, 10)

In [78]:
u,s,vh = np.linalg.svd(data, full_matrices=False)

In [79]:
plt.plot(s)

[<matplotlib.lines.Line2D at 0x7f01b3f89940>]