In [None]:
import matplotlib.pyplot as plt
from matplotlib import colormaps as cm
import numpy as np
import cv2
from PIL import Image
import os

from sklearn.linear_model import Ridge, ElasticNet

import naplib as nl
from naplib.visualization import strf_plot
import thesis

In [None]:
# thesis.utils.convert_to_npy("response")

In [None]:
# Load the raw recordings from the thesis package.
raw_response = thesis.load_data.load_response()
raw_stimulus = thesis.load_data.load_stimulus()
# Axes 2 and 3 of the response are the frame height and width.
height, width = map(int, (raw_response.shape[2], raw_response.shape[3]))

In [None]:
from math import gcd

# Audio sampling rate (Hz); also used below to express the stimulus window in seconds.
SR_AUDIO = 250000

# Run configuration (dumped to YAML before the dSTRF fit).
code_params = {
    # Block size passed to thesis.preprocessing.smallify_response.
    # gcd(width, height) would give the largest square block that tiles a frame exactly;
    # a fixed 10x10 block is used instead.
    "block_size": (10, 10),
    # dtype for the ridge STRF design/target arrays.
    # NOTE(review): float16 saves memory but keeps only ~3 significant digits — confirm adequate.
    "strf_fit_dtype": np.float16,
    "strf_fit_batch_size": 10,
    "sr_audio": SR_AUDIO,
    "sr_response": 100,
    "max_lag": 0.05,
    # Stimulus window in seconds: the first 37500 audio samples.
    "end_stim": 37500 / SR_AUDIO,
    "max_epochs": 150,
}

In [None]:
# # make movie
# frame_size = test_data.shape[1:3]

# out_lossless = cv2.VideoWriter('test_video.mkv',cv2.VideoWriter_fourcc(*'FFV1'), 100, (frame_size[1], frame_size[0]))
# out_lossy = cv2.VideoWriter('test_video_lossy.mkv',cv2.VideoWriter_fourcc(*'VP90'), 100, (frame_size[1], frame_size[0]))

# cm_test_data = np.copy(test_data)

# bwr_cm = cm.get_cmap('bwr')
# cm_test_data = bwr_cm((test_data - np.min(test_data)) / (np.max(test_data) - np.min(test_data)))

# cm_test_data = (cm_test_data[:, :, :, :3]*255).astype(np.uint8)

# example_images = 100, 250, 400

# for index, frame in enumerate(cm_test_data):
#     if index in example_images:
#         cv2.imwrite(f"images/frame_{index}.png", frame)
#     out_lossless.write(frame)
#     out_lossy.write(frame)
# out_lossless.release()
# out_lossy.release()

In [None]:
# Stimulus onset in the raw response, in response frames.
# NOTE(review): hard-coded 200 frames (2 s at sr_response=100 Hz), presumably the stimulus
# onset — confirm against the recording metadata.
response_onset = 200
# round() rather than int(): int() truncates, so a product such as 0.29 * 100
# (== 28.999999999999996) would silently lose a frame/sample.
delay = round(code_params["sr_response"] * code_params["max_lag"])
n_stim_samples = round(code_params["end_stim"] * code_params["sr_audio"])
n_response_frames = round(code_params["end_stim"] * code_params["sr_response"])

stimulus = raw_stimulus[:, :n_stim_samples]
# The response window starts `delay` frames after onset and spans the same duration as the stimulus.
response_start = response_onset + delay
response = raw_response[:, response_start : response_start + n_response_frames, :, :]
print(response.shape)
small_response = thesis.preprocessing.smallify_response(
    response, code_params["block_size"]
)
small_height, small_width = small_response.shape[2], small_response.shape[3]
# Flatten the (small_height, small_width) grid into a single pixel axis.
small_response = small_response.reshape(
    small_response.shape[0], small_response.shape[1], -1
)

In [None]:
import scipy as sp
import scipy.signal  # noqa: F401 -- makes `sp.signal` available on SciPy releases without lazy submodule loading

# Spectrogram of the stimulus summed over its first axis, for visual inspection only
# (the STRF/dSTRF models are fit on `spec` from thesis.generate.generate_spectrogram).
f, t, Sxx = sp.signal.spectrogram(np.sum(stimulus, axis=0), fs=code_params["sr_audio"])
fig, ax = plt.subplots()
mesh = ax.pcolormesh(t, f, Sxx, cmap="bwr")
fig.colorbar(mesh, ax=ax, label="Power spectral density")
ax.set(xlabel="Time (s)", ylabel="Frequency (Hz)", title="Stimulus spectrogram")
plt.show()

In [None]:
# Resample the stimulus into the spectrogram representation used as model input.
audio_rate = code_params["sr_audio"]
spec = thesis.generate.generate_spectrogram(stimulus, response, sr_audio=audio_rate)
print(spec.shape)

In [None]:
# apply STRF
import pickle as pkl

# Lag window of the linear STRF (presumably in seconds, given the sampling rate passed alongside).
tmin, tmax = 0, 0.3

# Ridge regression with alpha=10 as the TRF estimator.
ridge_estimator = Ridge(alpha=10)
strf_model = nl.encoding.TRF(
    tmin,
    tmax,
    code_params["sr_response"],
    estimator=ridge_estimator,
    show_progress=True,
)
# resample
print(f"Size of raw stimulus is: {stimulus[0].shape}")
print(f"Size of audio spectrogram is: {Sxx.shape}")
print(f"After resampling: {spec.shape}")


def batch(X, y, batch_size):
    """Lazily yield aligned ``(X_chunk, y_chunk)`` slices of at most ``batch_size`` items.

    Chunks are taken by slicing along the first axis; the last chunk may be
    shorter. The chunk count is driven by ``len(X)`` only, so ``y`` is
    expected to be the same length (this is not checked).
    """
    offsets = range(0, len(X), batch_size)
    for lo in offsets:
        hi = lo + batch_size
        yield X[lo:hi], y[lo:hi]


# Ridge STRF coefficients are cached on disk because the fit is slow.
# NOTE: pickle.load can execute arbitrary code -- only load a cache this notebook wrote.
coef_cache_path = "test_coefs.pkl"
coef_ridge = None
try:
    with open(coef_cache_path, "rb") as file:
        coef_ridge = pkl.load(file)
except FileNotFoundError:
    pass  # no cache yet: fit below
else:
    # The cache is not keyed on block_size, so a resolution change would otherwise
    # silently reuse coefficients fit for a different pixel grid.
    n_pixels = small_response.shape[2]
    if coef_ridge.shape[0] != n_pixels:
        print(
            f"Cached coefficients cover {coef_ridge.shape[0]} pixels but the response "
            f"has {n_pixels}; refitting."
        )
        coef_ridge = None

if coef_ridge is None:
    print(spec.shape)
    print(small_response.shape)

    strf_X = np.array(spec, code_params["strf_fit_dtype"])
    strf_y = np.array(small_response, code_params["strf_fit_dtype"])

    strf_model.fit(X=strf_X, y=strf_y)
    coef_ridge = strf_model.coef_
    with open(coef_cache_path, "wb") as file:
        pkl.dump(coef_ridge, file)

In [None]:
# Flattened indices (into the block-averaged pixel grid) of the STRFs shown side by side.
pixels = [1, 120, 150, 240]

# The indices are hard-coded while block_size is configurable; fail clearly if they no longer fit.
n_pixels = coef_ridge.shape[0]
assert all(0 <= p < n_pixels for p in pixels), (
    f"pixel indices must lie in [0, {n_pixels}) for a {small_height}x{small_width} grid"
)

fig, axes = plt.subplots(1, len(pixels), figsize=(1.5 * len(pixels), 2.5), squeeze=False)
for ax, pixel in zip(axes[0], pixels):
    strf_plot(coef_ridge[pixel], tmin=tmin, tmax=tmax, ax=ax)
    ax.set_title(f"Ridge, Pixel {pixel}")
fig.tight_layout()
plt.show()

In [None]:
# Collapse each pixel's ridge STRF to a single number: the mean of all its coefficients.
avg_strf_list = np.array([np.mean(coef) for coef in coef_ridge])
print(coef_ridge.shape)
# Back onto the block-averaged image grid.
avg_strf_list = avg_strf_list.reshape((small_height, small_width))
# Centre on the mean and scale by the peak-to-peak range so every value lies in [-1, 1];
# a perfectly flat map would otherwise divide by zero.
strf_range = np.ptp(avg_strf_list)
avg_strf_list = (avg_strf_list - np.mean(avg_strf_list)) / (
    strf_range if strf_range > 0 else 1.0
)

fig, ax = plt.subplots()
image = ax.imshow(avg_strf_list, cmap="bwr", vmin=-1, vmax=1)
fig.colorbar(image, ax=ax)
ax.set_title("Mean ridge STRF weight per pixel (normalized)")
# fig.savefig(f"../thesis_text/Imgs/m{animal_id}_r{recording_id}_t{trial_id}/avg_strf.png")
plt.show()

In [None]:
from IPython.display import HTML, display
from matplotlib import animation

# Average each pixel's STRF over axis 1 (presumably the spectrogram-frequency axis),
# leaving one value per (pixel, lag).
temporal_avg_strf_list = np.mean(coef_ridge, axis=1)
# Centre and scale into [-1, 1]; guard against a flat array (zero range).
strf_range = np.ptp(temporal_avg_strf_list)
temporal_avg_strf_list = (temporal_avg_strf_list - np.mean(temporal_avg_strf_list)) / (
    strf_range if strf_range > 0 else 1.0
)
# (pixels, lags) -> (lags, small_height, small_width): one image per lag.
temporal_avg_strf_list = np.transpose(temporal_avg_strf_list).reshape(
    -1, small_height, small_width
)

fig, ax = plt.subplots()
frames = [
    [ax.imshow(frame, cmap="bwr", vmin=-1, vmax=1, animated=True)]
    for frame in temporal_avg_strf_list
]
ani = animation.ArtistAnimation(fig, frames, interval=50, blit=True, repeat_delay=1000)
# The inline backend only renders a static figure and never runs the animation,
# so embed it as an HTML/JS player and close the figure to avoid a duplicate static copy.
plt.close(fig)
display(HTML(ani.to_jshtml()))

## Dynamic STRF (dSTRF)

In [None]:
from scipy.io import loadmat
import matplotlib.pyplot as plt
from matplotlib import colormaps as cm
import numpy as np
import cv2
from PIL import Image

from scipy.signal import resample, chirp
from sklearn.linear_model import Ridge, ElasticNet

import naplib as nl
from naplib.visualization import strf_plot

In [None]:
import dynamic_strf as dstrf

In [None]:
import IPython.display as ipd
from hdf5storage import loadmat

import torch
import torchaudio
import pytorch_lightning as plc

# Use the GPU when torch can see one; to force CPU, set device = torch.device("cpu").
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
torch.set_float32_matmul_precision("medium")
print(device)

In [None]:
# small_response (pixels already block-averaged and flattened by the windowing cell) is the dSTRF target.
print(tuple(small_response.shape))

In [None]:
# define builder and trainer
from pytorch_lightning.loggers import TensorBoardLogger
from pl_bolts.callbacks.printing import PrintTableMetricsCallback

### data naming
# Flags encoded into the run's output-directory name.
crossval = True
jackknife = False
reduced = True
# NOTE(review): `logger` is never read; trainer() always logs to tb_logger.
logger = False
res = f"{small_height}x{small_width}"

name_parts = [
    "jackk" if jackknife else "",
    "crossval" if crossval else "",
    "reduced" if reduced else "",
    res,
]
desc_param = "".join(name_parts)

log_dir = f"output/{desc_param}/logs"
tb_logger = TensorBoardLogger(log_dir, name=f"{desc_param}", log_graph=False)


def trainer():
    """Return a fresh Lightning Trainer for fitting the DeepEncoder.

    16-bit mixed precision, gradient clipping at 10, ``max_epochs`` from
    ``code_params``, and the shared ``tb_logger`` defined above. A new
    PrintTableMetricsCallback is created per call.
    """
    settings = {
        "accelerator": "auto",
        "precision": "16-mixed",
        "gradient_clip_val": 10.0,
        "max_epochs": code_params["max_epochs"],
        "logger": tb_logger,
        "log_every_n_steps": 1,
        "detect_anomaly": False,
        "enable_model_summary": False,
        "enable_progress_bar": True,
        "enable_checkpointing": True,
    }
    return plc.Trainer(**settings, callbacks=[PrintTableMetricsCallback()])


def builder():
    """Build a fresh DeepEncoder sized to the current spectrogram and pixel count, moved to ``device``."""
    encoder = dstrf.modeling.DeepEncoder(
        input_size=spec.shape[2],
        hidden_size=64,
        channels=small_response.shape[2],
    )
    return encoder.to(device)

In [None]:
# Keep a single source of truth: the fit/test/dSTRF cells read code_params["crossval"],
# while the output-directory name (desc_param) is built from `crossval`.
code_params["crossval"] = crossval

In [None]:
from torch.utils.data import DataLoader, TensorDataset

# Wrap the numpy arrays as torch tensors (shared memory, no copy) for the dSTRF API.
dstrf_X, dstrf_y = (torch.from_numpy(array) for array in (spec, small_response))
for tensor in (dstrf_X, dstrf_y):
    print(tensor.shape)

In [None]:
import os
import subprocess
import warnings

import yaml

code_params["output_prefix"] = f"output/{desc_param}"
os.makedirs(code_params["output_prefix"], exist_ok=True)

# Record the run configuration. yaml.dump (full Dumper) is required because code_params
# holds a tuple and a numpy dtype class, which safe_dump cannot represent.
with open("code_params.yaml", "w") as file:
    yaml.dump(code_params, file)

# Best-effort snapshot of this notebook's source next to the run outputs, for provenance.
# NOTE(review): the notebook filename is hard-coded; update it if the notebook is renamed.
try:
    snapshot_status = subprocess.run(
        [
            "jupyter",
            "nbconvert",
            "--to",
            "script",
            "thesis.ipynb",
            "--output-dir",
            code_params["output_prefix"],
        ],
        check=False,
    ).returncode
except FileNotFoundError:  # `jupyter` not on PATH
    snapshot_status = None
if snapshot_status != 0:
    warnings.warn(
        f"Notebook snapshot via nbconvert failed (status {snapshot_status}); continuing without it."
    )

dstrf.modeling.fit_multiple(
    builder=builder,
    data=(dstrf_X, dstrf_y),
    batch_size=10,
    crossval=code_params["crossval"],
    jackknife=jackknife,
    trainer=trainer,
    save_dir=f"{code_params['output_prefix']}/model",
    verbose=1,
)

In [None]:
# import glob, os
# checkpoints = sorted(glob.glob(os.path.join('output/5x128-jackknife-cv', 'model-*.pt')))
# print(checkpoints)

In [None]:
# Score the saved models on the first entry along axis 0 of the data.
# NOTE(review): presumably the first trial only -- confirm this is the intended evaluation set.
first_trial = slice(0, 1)
scores = dstrf.modeling.test_multiple(
    model=builder(),
    checkpoints=f"{code_params['output_prefix']}/model",
    data=(dstrf_X[first_trial], dstrf_y[first_trial]),
    crossval=code_params["crossval"],
    jackknife_mode="pred",
)
scores.numpy()

In [None]:
# Per-pixel scores back onto the block-averaged image grid.
scores = np.asarray(scores).reshape((small_height, small_width))
# Count undefined scores before zero-filling so the figure says how many pixels were affected.
n_nan_pixels = int(np.isnan(scores).sum())
scores = np.nan_to_num(scores)

fig, ax = plt.subplots()
image = ax.imshow(scores, cmap="bwr", vmin=-1, vmax=1)
fig.colorbar(image, ax=ax, label="score")
ax.set_title(f"Per-pixel test score ({n_nan_pixels} NaN pixels shown as 0)")
plt.show()

In [None]:
# Estimate dSTRFs from every saved checkpoint on the first data entry and write them to disk.
checkpoint_dir = f"{code_params['output_prefix']}/model"
dstrf_save_dir = f"{code_params['output_prefix']}/dstrf"
dstrf.estimate.dSTRF_multiple(
    model=builder(),
    checkpoints=checkpoint_dir,
    data=dstrf_X[:1],
    crossval=code_params["crossval"],
    save_dir=dstrf_save_dir,
    chunk_size=10,
)

In [None]:
dstrf_path = f"{code_params['output_prefix']}/dstrf/dSTRF-000.pt"
model_path = f"{code_params['output_prefix']}/model/model-000.pt"
# map_location="cpu": tensors saved from a GPU run would otherwise fail to load on a CPU-only kernel;
# these objects are only inspected here.
# NOTE: torch.load unpickles -- only load files produced by this pipeline.
dstrf_model = torch.load(dstrf_path, map_location="cpu")
cnn_model = torch.load(model_path, map_location="cpu")

In [None]:
# Quick look at the loaded dSTRF object.
print(str(dstrf_model))

In [None]:
# NOTE(review): 65 is assumed to be the number of lag frames shown on the x-axis -- confirm against the model.
n_lag_frames = 65
frame_ms = 1000 / code_params["sr_response"]  # duration of one response frame in ms
dstrf.visualize.dSTRF(
    f"{code_params['output_prefix']}/dstrf/dSTRF-000.pt",
    channels=slice(0, None, 1),
    time_range=slice(0, None, 1),
    output_prefix=f"{code_params['output_prefix']}",
    vcodec="libx264",
    xlabel="Time lag (ms)",
    xticks=[0, n_lag_frames],
    # Oldest lag on the left: n_lag_frames frames before the present.
    xtick_labels=[-round(n_lag_frames * frame_ms), 0],
    ylabel="Frequency",
    # NOTE(review): assumes 64 frequency bins spanning 20 Hz to 60 kHz -- confirm against generate_spectrogram.
    yticks=[0, 64],
    ytick_labels=["20Hz", "60kHz"],
)

In [None]:
# Play the rendered dSTRF video for the first channel inline.
video_path = f"{code_params['output_prefix']}/video/channel-0000.mkv"
ipd.display(ipd.Video(video_path, height=300))