In [6]:
import pickle as pkl

import matplotlib.pyplot as plt
import naplib as nl
import numpy as np
import yaml
from naplib.visualization import strf_plot
from sklearnex import patch_sklearn

# Patch before importing from sklearn so the imports below resolve to the
# scikit-learn-intelex accelerated estimators.
patch_sklearn()

from sklearn.linear_model import Ridge

Extension for Scikit-learn* enabled (https://github.com/uxlfoundation/scikit-learn-intelex)


In [7]:
# Load the preprocessed response array and stimulus spectrogram, then the shared run parameters.
small_response = np.load('prep_response.npy')
spec = np.load('prep_spec.npy')

# NOTE(review): yaml.Loader can construct arbitrary Python objects; if code_params.yaml holds
# only plain scalars/lists/mappings, yaml.safe_load would be sufficient — confirm and switch.
with open("code_params.yaml", "r") as params_file:
    code_params = yaml.load(params_file, yaml.Loader)

In [None]:
# apply STRF

tmin = 0
tmax = 0.3

strf_model = nl.encoding.TRF(
    tmin, tmax, code_params["sr_response"], estimator=Ridge(10), show_progress=True
)

try:
	with open('test_coefs_full.pkl', 'rb') as file:
		coef_ridge = pkl.load(file)
except FileNotFoundError:    
    print(spec.shape)
    print(small_response.shape)

    strf_model.fit(X=spec, y=small_response)
    coef_ridge = strf_model.coef_
    with open('test_coefs_full.pkl', 'wb') as file:
        pkl.dump(coef_ridge, file)

(50, 15, 129)
(50, 15, 25840)


  0%|          | 0/25840 [00:00<?, ?it/s]

In [None]:
# Plot the fitted STRF of a few example pixels side by side. Each coef_ridge[pixel] is one
# channel's coefficient matrix, drawn with naplib's strf_plot over the [tmin, tmax] lag window.
pixels = [1, 120, 150, 240]

# One panel per pixel; squeeze=False keeps `axes` 2-D so a single-pixel list still works.
fig, axes = plt.subplots(1, len(pixels), figsize=(1.5 * len(pixels), 2.5), squeeze=False)
for ax, pixel in zip(axes[0], pixels):
    strf_plot(coef_ridge[pixel], tmin=tmin, tmax=tmax, ax=ax)
    ax.set_title(f"Ridge, Pixel {pixel}")
fig.tight_layout()
plt.show()

In [None]:
# Spatial map of mean STRF weight: collapse each pixel's coefficient matrix to one mean value,
# lay the pixels back out on the image grid and rescale for a symmetric colormap.
# NOTE(review): small_height / small_width are not defined in the cells shown here —
# presumably set earlier; their product must equal the number of pixels in coef_ridge.
print(coef_ridge.shape)
coef_array = np.asarray(coef_ridge)
# Mean over every non-pixel axis, i.e. the same as np.mean(coef) for each coef in coef_ridge.
avg_strf_list = coef_array.reshape(len(coef_array), -1).mean(axis=1)
# NOTE(review): avg_response is not used by the cells shown here; kept in case later cells need it.
avg_response = np.mean(small_response, axis=(0, 1))
avg_strf_list = avg_strf_list.reshape((small_height, small_width))

# Centre on the mean and divide by the range so values lie in [-1, 1] (white = map mean,
# not zero). A constant map would otherwise give 0/0 = NaN, so fall back to dividing by 1.
strf_range = np.max(avg_strf_list) - np.min(avg_strf_list)
avg_strf_list = (avg_strf_list - np.mean(avg_strf_list)) / (strf_range if strf_range > 0 else 1.0)

fig, ax = plt.subplots()
image = ax.imshow(avg_strf_list, cmap="bwr", vmin=-1, vmax=1)
ax.set_title("Mean STRF coefficient per pixel (centred, range-normalised)")
fig.colorbar(image, ax=ax)
# fig.savefig(f"../thesis_text/Imgs/m{animal_id}_r{recording_id}_t{trial_id}/avg_strf.png")
plt.show()

In [None]:
from IPython.display import HTML, Video, display
from matplotlib import animation, cm

# Animate the STRF across lags: average each pixel's coefficients over axis 1, normalise,
# and show one image-grid frame per remaining index.
# NOTE(review): assumes coef_ridge is (n_pixels, n_features, n_lags), as the per-pixel
# indexing coef_ridge[pixel] suggests — confirm.
temporal_avg_strf_list = np.mean(coef_ridge, axis=1)
temporal_range = np.max(temporal_avg_strf_list) - np.min(temporal_avg_strf_list)
# Centre on the mean and divide by the range so values lie in [-1, 1]; a constant array
# would otherwise produce 0/0 = NaN.
temporal_avg_strf_list = (temporal_avg_strf_list - np.mean(temporal_avg_strf_list)) / (
    temporal_range if temporal_range > 0 else 1.0
)
# (n_pixels, n_lags) -> (n_lags, small_height, small_width): one frame per lag.
temporal_avg_strf_list = np.transpose(temporal_avg_strf_list)
temporal_avg_strf_list = temporal_avg_strf_list.reshape(-1, small_height, small_width)

fig, ax = plt.subplots()
frames = [
    [ax.imshow(frame, cmap="bwr", vmin=-1, vmax=1, animated=True)]
    for frame in temporal_avg_strf_list
]
ani = animation.ArtistAnimation(fig, frames, interval=50, blit=True, repeat_delay=1000)
# Under the inline backend plt.show() only emits a static PNG, so the animation never plays.
# Render it as an interactive JS/HTML player instead and close the static figure.
plt.close(fig)
display(HTML(ani.to_jshtml()))