In [None]:
%load_ext autoreload
%autoreload 2
import matplotlib.pyplot as plt
import numpy as np
import torch

from openretina.insilico.VectorFieldAnalysis.vector_field_analysis import (
    compute_lsta_library,
    get_images_coordinate,
    get_pc_from_pca,
    plot_clean_vectorfield,
    plot_untreated_vectorfield,
    prepare_movies_dataset,
)
from openretina.models.core_readout import load_core_readout_model


# Load model

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Remote loading is currently disabled:
# model = load_core_readout_from_remote(
#     "karamanlis_2024_base", device="cuda" if torch.cuda.is_available() else "cpu"
# )

# TODO: repair the loading function again; until then a local checkpoint is used.
checkpoint_path = "/home/baptiste/Documents/LabPipelines/open-retina/openretina_assets/runs/core_readout_goldin_2022/2025-09-01_10-50-49/checkpoints/epoch=69_val_correlation=0.000.ckpt"
# str(device) is "cuda" or "cpu", i.e. the same string the loader received before.
model = load_core_readout_model(checkpoint_path, device=str(device))
model

In [None]:
# Use the last session listed in the model's readout
# (a fixed choice, not a random draw).
session_ids = list(model.readout.keys())
session_id = session_ids[-1]
print(session_id)

# `outdims` of the session's readout is reported as the neuron count.
n_neurons = model.readout[session_id].outdims
print(f"Number of neurons: {n_neurons}")

# Load natural images

In [None]:
from openretina.data_io.goldin_2022.responses import load_all_responses
from openretina.data_io.goldin_2022.stimuli import load_all_stimuli

# Stimuli and responses are both read from the same data folder.
goldin_data_path = "/home/baptiste/Documents/LabPipelines/open-retina/notebooks/data/omarre_lab/goldin_2022"

# Load the training images. Another dataset can be used instead, provided its
# preprocessing is correct.
movies = load_all_stimuli(base_data_path=goldin_data_path, normalize_stimuli=True, specie='mouse')
responses = load_all_responses(base_data_path=goldin_data_path, specie='mouse')

# Test Goldin 2022 model

In [None]:
# Pull the 'test' entry of the selected session out of both collections.
test_movies, test_responses = (
    movies[session_id].test_dict['test'],
    responses[session_id].test_dict['test'],
)

In [None]:
import h5py


def h5_to_dict(item):
    """Recursively convert an HDF5 object into plain Python containers.

    A group becomes a dict mapping each key to its converted member, a dataset
    is read fully via ``item[()]``, and any other object yields ``None``.
    """
    if isinstance(item, h5py.Group):
        return {key: h5_to_dict(item[key]) for key in item.keys()}
    if isinstance(item, h5py.Dataset):
        return item[()]
    return None


# Path of the session's HDF5 file, relative to the current working directory.
file_path = f'data/omarre_lab/goldin_2022/{session_id}.h5'
with h5py.File(file_path, 'r') as f:
    # Read the whole 'test' group into memory while the file is open.
    full_test = h5_to_dict(f['test'])
    print("Available keys in the file:", list(f.keys()))
    print()

In [None]:
# Split-half reliability per cell: correlate the mean over every second entry
# of the first axis (0, 2, 4, ...) with the mean over the remaining entries
# (1, 3, 5, ...).
# NOTE(review): h5py lists group members in name-sorted order by default, so
# e.g. "10" can come before "2" - confirm that this order matches the neuron
# order of test_responses before pairing the two per index.
reliability_list = np.array([
    np.corrcoef(trials[::2].mean(axis=0), trials[1::2].mean(axis=0))[0, 1]
    for trials in full_test['repeats'].values()
])
print(f"Average reliability across cells: {np.mean(reliability_list)}")


In [None]:
# Sanity check: show the shapes of the test stimuli and the test responses.
test_movies.shape, test_responses.shape

In [None]:
# Put the model in evaluation mode before predicting.
model.eval()

# Pure inference: torch.no_grad() avoids building the autograd graph, which
# would otherwise keep every intermediate activation in memory.
with torch.no_grad():
    # swapaxes(0, 1) exchanges the first two axes of the stimulus and
    # unsqueeze(1) inserts a singleton axis at position 1.
    # NOTE(review): presumably this feeds each test frame as its own batch
    # item - confirm against the model's expected input layout.
    predict_responses = (
        model(torch.tensor(test_movies).swapaxes(0, 1).unsqueeze(1).to(device))
        .cpu()
        .detach()
        .numpy()[:, 0, :]  # keep index 0 of the second output axis
        .swapaxes(0, 1)  # first axis is indexed per cell further down
    )
predict_responses.shape

In [None]:
def get_test_score(cell_id):
    """Return the Pearson correlation between recorded and predicted test responses.

    Reads row ``cell_id`` of the notebook-level ``test_responses`` and
    ``predict_responses`` arrays.
    """
    return np.corrcoef(test_responses[cell_id], predict_responses[cell_id])[0, 1]


all_scores = [get_test_score(i) for i in range(n_neurons)]
# np.corrcoef returns Pearson's r, not r squared, so the label says so.
print(f"Average test correlation (Pearson r) across cells: {np.mean(all_scores)}")

In [None]:
# Noise-corrected score: each cell's test correlation divided by the square
# root of its split-half reliability.
score_arr = np.asarray(all_scores[:n_neurons], dtype=float)
reliability_arr = np.asarray(reliability_list[:n_neurons], dtype=float)

# The square root is undefined for negative reliability and the ratio blows up
# at zero; NaN reliability also fails this test. Such cells are set to NaN so
# that they cannot poison the average.
has_valid_reliability = reliability_arr > 0
corrected_scores = np.full(score_arr.shape, np.nan)
corrected_scores[has_valid_reliability] = (
    score_arr[has_valid_reliability] / np.sqrt(reliability_arr[has_valid_reliability])
)

# The score is a Pearson correlation (see np.corrcoef), not r squared.
print(f"Average corrected test correlation (Pearson r) across cells: {np.nanmean(corrected_scores)}")
if not has_valid_reliability.all():
    n_skipped = int((~has_valid_reliability).sum())
    print(f"Cells excluded for non-positive or undefined reliability: {n_skipped}")
# NOTE(review): this message is printed whatever the value above is - compare
# against the paper's number before relying on it.
print("Even better than reported in the paper :)")

# Prepare natural movie dataset

In [None]:
# HOTFIX: attach a `data_info` dict holding `input_shape` to the model, so the
# original function can be kept intact.
model.data_info = {}
# NOTE(review): `movies` is rebound here from the per-session collection to a
# single array, so this cell cannot be re-run without reloading the stimuli.
movies = movies[session_id].train.swapaxes(0, 1)  # put channels first
# Shape of one item: everything after the first axis.
model.data_info['input_shape'] = movies.shape[1:]

In [None]:
# Build the movie dataset from the image library. `movies` is rebound to the
# result and no extra normalisation is requested.
movies, n_empty_frames = prepare_movies_dataset(
    model,
    session_id,
    device=device,
    image_library=movies,
    normalize_movies=False,
)

In [None]:
# Display the last frame of an example movie. Only the first channel is drawn;
# the second panel is kept for an optional second channel and stays empty.
example_idx = 20  # index of the example movie
last_frame = movies[example_idx, :, -1, :, :]  # all channels, last entry of the third axis

fig, axes = plt.subplots(1, 2, figsize=(12, 5))

# First channel (channel 0); the colour scale spans the whole channel of the dataset.
# 'gray' is the colormap name accepted by every Matplotlib version and matches
# the spelling used for the other image plots in this notebook.
im1 = axes[0].imshow(last_frame[0], cmap='gray', vmin=movies[:, 0, :, :].min(), vmax=movies[:, 0, :, :].max())
axes[0].set_title('First Channel')
plt.colorbar(im1, ax=axes[0])
axes[0].axis('off')
# Optional: second channel (channel 1)
# im2 = axes[1].imshow(last_frame[1], cmap='Purples', vmin=movies[:, 1, :, :].min(), vmax=movies[:, 1, :, :].max())
# axes[1].set_title('Second Channel')
# # plt.colorbar(im2, ax=axes[1])
axes[1].axis('off')
plt.tight_layout()
plt.show()

print(f"First channel range on this image: [{last_frame[0].min():.3f}, {last_frame[0].max():.3f}]")
# print(f"Second channel range: [{last_frame[1].min():.3f}, {last_frame[1].max():.3f}]")

In [None]:
# Cell analysed in the rest of the notebook.
cell_id = 10  # For Goldin 2022, mouse => Use 10 for an example contrast cell

# Look up that cell's reliability and its corrected test score.
cell_reliability, cell_test_score = reliability_list[cell_id], corrected_scores[cell_id]
print(f"Cell {cell_id} reliability: {cell_reliability}")
print(f"Cell {cell_id} test score: {cell_test_score}")

In [None]:
# The slow cell
# NOTE(review): unclear whether "slow" refers to the neuron's dynamics or to
# the runtime of this step - clarify.
lsta_library, response_library = compute_lsta_library(
    model,
    movies,
    session_id,
    cell_id,
    batch_size=64,
    integration_window=(10, 15),
    device=device,
)

In [None]:
# Plot the predicted responses of the first six library entries for the chosen
# cell, to check that the whole response profile is predicted and was not cut
# too short.
for sample_idx in range(6):
    plt.plot(response_library[sample_idx, :, cell_id])
plt.xlabel('Time (frames)')
plt.ylabel('Cell predicted Response (au)')

In [None]:
# Sanity check: show the shape of the computed LSTA library.
lsta_library.shape

In [None]:
# A notebook cell only displays its last expression, so the three summary
# statistics (mean, max, variance) are gathered in one tuple to show them all.
lsta_library.mean(), lsta_library.max(), lsta_library.var()

In [None]:
# Plot an example LSTA next to the image it was computed for.

image = 500  # index of the example image in the library
channel = 0
lsta = lsta_library[image, channel]
# Symmetric colour limits so that zero sits at the centre of the diverging map.
lsta_limit = abs(lsta).max()
fig, axes = plt.subplots(1, 2, figsize=(15, 6))

# Image: last frame of the matching movie, on the same channel as the LSTA.
axes[0].imshow(movies[image, channel, -1], cmap='gray')
axes[0].set_title('Original Image')
axes[0].axis('off')

# LSTA (the acronym stands for local spike-triggered average).
axes[1].imshow(lsta, cmap='bwr', vmin=-lsta_limit, vmax=lsta_limit)
axes[1].set_title('LSTA (Local Spike-Triggered Average)')
axes[1].axis('off')

plt.tight_layout()
plt.show()

# Do PCA on the LSTA library

In [None]:
# Channel of the LSTA library that the PCA is run on.
channel = 0

# Get the first two principal components and the explained variance;
# plot=True also asks the helper for its plot.
PC1, PC2, explained_variance = get_pc_from_pca(
    model,
    channel,
    lsta_library,
    plot=True,
)

In [None]:
# Take the last frame of every movie on the selected channel ...
images = movies[:, channel, -1]
# ... and project these images onto the plane spanned by PC1 and PC2.
images_coordinate = get_images_coordinate(
    images,
    PC1,
    PC2,
    plot=False,
)

images_coordinate

In [None]:
# Sanity check: show the shapes of the projected coordinates and of the LSTA library.
images_coordinate.shape, lsta_library.shape

In [None]:
# Plot the vector field of the LSTA in PCA space.
# plot_untreated_vectorfield is already imported in the notebook's first cell,
# and the channel chosen for the PCA is reused instead of a hard-coded 0.
fig = plot_untreated_vectorfield(
    lsta_library,
    channel=channel,
    PC1=PC1,
    PC2=PC2,
    images_coordinate=images_coordinate,
)

In [None]:
# Draw the cleaned vector field on a 31 x 31 grid of bins.
# plot_clean_vectorfield is already imported in the notebook's first cell.
fig = plot_clean_vectorfield(
    lsta_library,
    channel,
    PC1,
    PC2,
    images,
    images_coordinate,
    explained_variance,
    x_bins=31,
    y_bins=31,
)
# NOTE(review): with a non-GUI (inline) backend Figure.show() may only emit a
# warning - confirm whether plt.show() or a bare `fig` is what is wanted here.
fig.show()

# TODO: add the firing rate alongside the vector field and mention the cell types

# Create the vector-field (arrow) plots for every cell of a session

In [None]:
# Disabled batch run: when uncommented, this loops over every neuron of the
# session, recomputes the LSTA library, the PCA and the image projection, and
# saves one vector-field figure per cell.
# NOTE(review): unlike the single-cell example, compute_lsta_library is called
# here without `integration_window`, and the output folder
# "goldin_2022_vector_fields" is not created by this code - confirm both
# before enabling it.
# from openretina.insilico.VectorFieldAnalysis.vector_field_analysis import plot_clean_vectorfield


# for cell_id in range(n_neurons):
#     print(f"Processing cell {cell_id}...")
#     lsta_library, response_library = compute_lsta_library(model, movies, session_id, cell_id, batch_size=64, device=device)
#     PC1, PC2, explained_variance = get_pc_from_pca(model, channel, lsta_library, plot=True)
#     images_coordinate = get_images_coordinate(images, PC1, PC2, plot=False)
#     fig = plot_clean_vectorfield(
#         lsta_library,
#         channel,
#         PC1,
#         PC2,
#         images,
#         images_coordinate,
#         explained_variance,
#         x_bins=31,
#         y_bins=31,
#     )
#     fig.savefig(f"goldin_2022_vector_fields/cell_{cell_id}_mouse.png", dpi=300, bbox_inches='tight')
#     plt.close(fig)