In [58]:
import os
import pickle
from copy import deepcopy
from typing import List, Optional, Tuple

import h5py as h5
import matplotlib.pyplot as plt
import numpy as np
import torch
from openretina.constants import CLIP_LENGTH, NUM_CLIPS, NUM_VAL_CLIPS
from openretina.dataloaders import get_movie_dataloader
from openretina.misc import CustomPrettyPrinter

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [62]:
# Root folder of the exported data. NOTE(review): this is an absolute, machine-specific
# path; adjust it (or make it configurable) when running elsewhere.
base_folder = "/Data/fd_export"
# Pickle holding the per-session neuron data (loaded below as `data`).
data_path = os.path.join(base_folder, "2024-01-11_neuron_data_stim_8c18928_responses_99c71a0.pkl")
# Pickle holding the movie dictionary (loaded below as `movies`; later indexed with
# "train", "test" and "random_sequences").
movies_path = os.path.join(base_folder, "2024-01-11_movies_dict_8c18928.pkl")
# Pickle loaded below as `old_movies` — presumably an earlier movie export; TODO confirm.
old_movies_path = os.path.join(base_folder, "movies_8c18928.pkl")

In [63]:
# NOTE: unpickling can execute arbitrary code, so only load files from a trusted source.
# Each file is opened in a context manager so its handle is closed right after loading.
with open(data_path, "rb") as data_file:
    data = pickle.load(data_file)
with open(movies_path, "rb") as movies_file:
    movies = pickle.load(movies_file)

with open(old_movies_path, "rb") as old_movies_file:
    old_movies = pickle.load(old_movies_file)

In [4]:
# Print a nested overview of both dictionaries. In the output below, arrays and
# tensors are rendered by type and shape rather than by value.
pp = CustomPrettyPrinter(indent=4)
pp.pprint(data)
pp.pprint(movies)

{   '1_ventral1_20210929': {   'eye': 'left',
                               'group_assignment': numpy.ndarray(shape=(86,)),
                               'key': {   'date': '2021-09-29',
                                          'exp_num': 1,
                                          'experimenter': 'Szatko',
                                          'field_id': 1,
                                          'stim_id': 5},
                               'responses_final': numpy.ndarray(shape=(86, 18450)),
                               'roi_coords': torch.Tensor(shape=[86, 2]),
                               'roi_ids': numpy.ndarray(shape=(86,)),
                               'scan_sequence_idx': 8,
                               'stim_id': 5,
                               'traces': numpy.ndarray(shape=(104,)),
                               'tracestimes': numpy.ndarray(shape=(104,))},
    '1_ventral1_20210930': {   'eye': 'left',
                               'group_assignment': nu

In [74]:
def get_all_movie_combinations(
    movie_train,
    movie_test,
    random_sequences: np.ndarray,
    val_clip_idx: Optional[List[int]] = None,
    seed=1000,
):
    """
    Generates combinations of movie data for 'left' and 'right' perspectives and
    prepares training, validation, and test datasets. It reorders the training
    movies based on random sequences and flips the movies along the last axis
    for the 'left' perspective.

    Parameters:
    - movie_train: Training movie of shape (channels, frames, px_y, px_x); array-like or torch.Tensor.
    - movie_test: Test movie; array-like or torch.Tensor.
    - random_sequences: Integer numpy array of shape (clips, sequences); column i lists the
      clip order of training sequence i.
    - val_clip_idx: List of indices for validation clips. Needs to be between 0 and the number of clips.
      If None, NUM_VAL_CLIPS indices are drawn without replacement from range(NUM_CLIPS).
    - seed: Seed for the random number generator; only used if val_clip_idx is None.

    Returns:
    - movies: Dictionary with processed movies for 'left' and 'right' perspectives, each
      containing 'train' (a dict keyed by sequence index), 'validation', and 'test'.
      The validation clip indices are stored under the extra key 'val_clip_idx'.
    """

    def _as_float_copy(movie):
        # torch.tensor() on an existing tensor emits a UserWarning; copy tensors
        # explicitly instead. Either way the result never aliases the caller's data.
        if isinstance(movie, torch.Tensor):
            return movie.detach().clone().to(torch.float)
        return torch.tensor(movie, dtype=torch.float)

    if val_clip_idx is None:
        rnd = np.random.RandomState(seed)
        val_clip_idx = list(rnd.choice(NUM_CLIPS, NUM_VAL_CLIPS, replace=False))

    # Convert movie data to float tensors
    movie_train = _as_float_copy(movie_train)
    movie_test = _as_float_copy(movie_test)

    channels, train_length, px_y, px_x = movie_train.shape
    # One row of random_sequences per clip, so this is the number of frames per clip.
    clip_length = train_length // random_sequences.shape[0]

    # Prepare validation movie data: the validation clips back to back, in the order of val_clip_idx
    movie_val = torch.zeros(
        (channels, len(val_clip_idx) * clip_length, px_y, px_x), dtype=torch.float
    )
    for i, ind in enumerate(val_clip_idx):
        movie_val[:, i * clip_length : (i + 1) * clip_length] = movie_train[
            :, ind * clip_length : (ind + 1) * clip_length
        ]

    # Initialize movie dictionaries; 'left' holds the 'right' movies flipped along the last axis
    movies = {
        "left": {
            "train": {},
            "validation": torch.flip(movie_val, [-1]),
            "test": torch.flip(movie_test, [-1]),
        },
        "right": {"train": {}, "validation": movie_val, "test": movie_test},
    }

    # Process training movies for each random sequence
    for i in range(random_sequences.shape[1]):
        reordered_movie = torch.zeros_like(movie_train)
        for k, ind in enumerate(random_sequences[:, i]):
            # Slot k of sequence i shows clip `ind` of the original training movie.
            reordered_movie[:, k * clip_length : (k + 1) * clip_length] = movie_train[
                :, ind * clip_length : (ind + 1) * clip_length
            ]

        movies["right"]["train"][i] = reordered_movie
        movies["left"]["train"][i] = torch.flip(reordered_movie, [-1])

    movies["val_clip_idx"] = val_clip_idx

    return movies


def gen_start_indices(
    random_sequences, val_clip_idx, clip_length, chunk_size, num_clips
):  # 108 x 20 integer
    """
    Compute the frame indices at which training chunks may start so that no
    training chunk overlaps a validation clip.

    The clip ordering of every sequence is processed in two halves (split at
    ``num_clips // 2``); a run of consecutive training clips never extends
    across that split.

    Args:
        random_sequences (np.ndarray): Integer array of shape (108, 20) giving the ordering of the
                                       108 training clips for the 20 different sequences.
        val_clip_idx (list): List of integers indicating the 15 clips to be used for validation.
        clip_length (int): Clip length in frames (5s * 30 frames/s = 150 frames).
        chunk_size (int): Temporal chunk size per sample in frames (50).
        num_clips (int): Total number of training clips (108).

    Returns:
        dict: A dictionary with keys "train", "validation", and "test". "train" maps each
        sequence index to its list of chunk start indices; "validation" holds one start
        index per validation clip; "test" is ``[0]``.
    """
    n_val = len(val_clip_idx)
    # One evenly spaced start per validation clip: 0, clip_length, 2 * clip_length, ...
    validation_starts = list(
        np.linspace(0, clip_length * (n_val - 1), n_val, dtype=int)
    )
    result = {"train": {}, "validation": validation_starts, "test": [0]}
    half = num_clips // 2

    for seq in range(random_sequences.shape[1]):  # one pass per movie permutation
        ordering = random_sequences[:, seq]
        runs = []  # (first frame, number of frames) of each validation-free stretch
        run_start = 0
        cursor = 0
        for part in (ordering[:half], ordering[half:]):
            for clip in part:
                if clip in val_clip_idx:
                    # A validation clip ends the current run; the next one begins after it.
                    if cursor - run_start > 0:
                        runs.append((run_start, cursor - run_start))
                    run_start = cursor + clip_length
                cursor += clip_length
            # Close whatever run is still open at the end of this half.
            if cursor - run_start > 0:
                runs.append((run_start, cursor - run_start))
            run_start = cursor

        chunk_starts = []
        for first_frame, n_frames in runs:
            candidates = np.arange(
                first_frame, first_frame + n_frames - chunk_size + 1, chunk_size
            )
            # The last candidate of every run is left out.
            chunk_starts.extend(candidates[:-1])
        result["train"][seq] = chunk_starts
    return result


def optimized_gen_start_indices(
    random_sequences, val_clip_idx, clip_length, chunk_size, num_clips
):
    """
    Single-pass variant of ``gen_start_indices``: generates the start indices of
    training chunks while excluding validation clips.

    As in ``gen_start_indices``, a run of consecutive training clips is closed at
    position ``num_clips // 2`` of every sequence, so that no chunk spans the two
    halves of the clip ordering. Chunk starts are returned as plain Python ints.

    :param random_sequences: int np array; 108 x 20, giving the ordering of the
                             108 training clips for the 20 different sequences
    :param val_clip_idx:     list of integers indicating the 15 clips to be used
                             for validation
    :param clip_length:      clip length in frames (5s*30frames/s = 150 frames)
    :param chunk_size:       temporal chunk size per sample in frames (50)
    :param num_clips:        total number of training clips (108)
    :return: dict; with keys train, validation, and test, and index list as
             values ("train" is a dict mapping sequence index to its index list)
    """
    val_clip_set = set(val_clip_idx)  # O(1) membership tests inside the loop
    val_start_idx = list(
        np.linspace(
            0, clip_length * (len(val_clip_idx) - 1), len(val_clip_idx), dtype=int
        )
    )

    start_idx_dict = {"train": {}, "validation": val_start_idx, "test": [0]}
    half = num_clips // 2

    for sequence_index in range(random_sequences.shape[1]):
        start_idx = 0
        current_idx = 0
        seq_start_idx = []
        seq_length = []

        for position, clip_index in enumerate(random_sequences[:, sequence_index]):
            if position == half:
                # Boundary between the two halves: close the open run here.
                length = current_idx - start_idx
                if length > 0:
                    seq_start_idx.append(start_idx)
                    seq_length.append(length)
                start_idx = current_idx
            if clip_index in val_clip_set:
                # A validation clip ends the current run; the next run starts after it.
                length = current_idx - start_idx
                if length > 0:
                    seq_start_idx.append(start_idx)
                    seq_length.append(length)
                start_idx = current_idx + clip_length
            current_idx += clip_length

        # Handling the last segment
        length = current_idx - start_idx
        if length > 0:
            seq_start_idx.append(start_idx)
            seq_length.append(length)

        # As in gen_start_indices, the last candidate start of every run is left out.
        chunk_start_idx = [
            idx
            for start, length in zip(seq_start_idx, seq_length)
            for idx in range(start, start + length - chunk_size + 1, chunk_size)[:-1]
        ]
        start_idx_dict["train"][sequence_index] = chunk_start_idx

    return start_idx_dict

In [75]:
# Build the left/right movie dictionary. Passing None as val_clip_idx makes the
# function draw the validation clips itself, using its default seed.
all_movies_dict = get_all_movie_combinations(movies["train"], movies["test"], movies["random_sequences"], None)

In [76]:
# Start indices of the training chunks (chunk size: 50 frames) for every clip
# sequence, reusing the validation clips selected in the previous cell.
start_indices = gen_start_indices(
    movies["random_sequences"],
    all_movies_dict["val_clip_idx"],
    CLIP_LENGTH,
    50,
    NUM_CLIPS,
)

In [33]:
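# Disabled cross-check of the single-pass variant (here with a chunk size of 150 frames).
# `random_sequences` is not defined at notebook scope, so the call uses movies["random_sequences"].
# start_indices_opt = optimized_gen_start_indices(movies["random_sequences"], all_movies_dict["val_clip_idx"], CLIP_LENGTH, 150, NUM_CLIPS)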

In [None]:
class NeuronData:
    """
    Container for the data of one recording session, plus helpers that split the
    recorded responses into train / validation / test sets and that map ROI
    masks to readout coordinates.
    """

    def __init__(
        self,
        eye,
        group_assignment,
        key,
        responses_final,
        roi_coords,
        roi_ids,
        scan_sequence_idx,
        stim_id,
        traces,
        tracestimes,
        random_sequences,
        val_clip_idx,
        num_clips,
        clip_length,
    ):
        """
        Boilerplate class to store neuron data. Added for backwards compatibility with Hoefling et al., 2022.

        All arguments are stored unchanged as attributes of the same name.
        `responses_final` is indexed as (neurons, frames) by `response_dict`, and
        `scan_sequence_idx` selects this session's column of `random_sequences`.
        """

        self.eye = eye
        self.group_assignment = group_assignment
        self.key = key
        self.responses_final = responses_final
        self.roi_coords = roi_coords
        self.roi_ids = roi_ids
        self.scan_sequence_idx = scan_sequence_idx
        self.stim_id = stim_id
        self.traces = traces
        self.tracestimes = tracestimes
        self.clip_length = clip_length
        self.num_clips = num_clips
        self.random_sequences = random_sequences
        self.val_clip_idx = val_clip_idx

    #! this has to become a regular method in the future
    @property
    def response_dict(self):
        """
        Split `responses_final` into training, validation and test responses.

        The split is recomputed on every access, so callers that need several
        folds should store the returned dictionary once.

        Returns:
            dict with the keys
            - "train": float tensor of shape (frames, neurons).
            - "validation": float tensor with the responses to the validation clips;
              shape (n_val_clips, clip_length, neurons) if `stim_id == 0`, otherwise
              (n_val_clips * clip_length, neurons).
            - "test": dict with "avg" (float tensor of shape (frames, neurons)) and
              "by_trial" (tensor of shape (neurons, 3, 5 * clip_length), or None if
              `stim_id == 0`, where no repeated test trials are extracted).
        """
        num_neurons = self.responses_final.shape[0]
        # Clip order shown in this session; identity order if no sequences are given.
        movie_ordering = (
            np.arange(self.num_clips)
            if len(self.random_sequences) == 0
            else self.random_sequences[:, self.scan_sequence_idx]
        )

        if self.stim_id == 0:
            # The first 10 clips are the test responses, everything after is training.
            responses_test = self.responses_final[:, : 10 * self.clip_length].T
            responses_train = self.responses_final[:, 10 * self.clip_length :].T
            test_responses_by_trial = None
        else:
            # Layout in clips: 5 test | 54 train | 5 test | 54 train | 5 test.
            responses_test = np.zeros((5 * self.clip_length, num_neurons))
            responses_train = np.zeros((self.num_clips * self.clip_length, num_neurons))
            test_responses_by_trial = []
            for roi in range(num_neurons):
                # Stack the three test repeats of this ROI as rows.
                tmp = np.vstack(
                    (
                        self.responses_final[roi, : 5 * self.clip_length],
                        self.responses_final[
                            roi, 59 * self.clip_length : 64 * self.clip_length
                        ],
                        self.responses_final[roi, 118 * self.clip_length :],
                    )
                )
                test_responses_by_trial.append(tmp)
                responses_test[:, roi] = np.mean(tmp, 0)
                responses_train[:, roi] = np.concatenate(
                    (
                        self.responses_final[
                            roi, 5 * self.clip_length : 59 * self.clip_length
                        ],
                        self.responses_final[
                            roi, 64 * self.clip_length : 118 * self.clip_length
                        ],
                    )
                )
            test_responses_by_trial = np.asarray(test_responses_by_trial)

        if self.stim_id == 0:
            responses_val = np.zeros(
                [len(self.val_clip_idx), self.clip_length, num_neurons]
            )
            for i, ind in enumerate(self.val_clip_idx):
                responses_val[i] = responses_train[
                    ind * self.clip_length : (ind + 1) * self.clip_length, :
                ]
        else:
            responses_val = np.zeros(
                [len(self.val_clip_idx) * self.clip_length, num_neurons]
            )
            # inv_order[c] is the position at which clip c was shown in this session.
            inv_order = np.argsort(movie_ordering)
            for i, ind1 in enumerate(self.val_clip_idx):
                ind2 = inv_order[ind1]
                responses_val[
                    i * self.clip_length : (i + 1) * self.clip_length, :
                ] = responses_train[
                    ind2 * self.clip_length : (ind2 + 1) * self.clip_length, :
                ]

        # torch.tensor(None) raises, so pass None through when there are no trials.
        by_trial = (
            None
            if test_responses_by_trial is None
            else torch.tensor(test_responses_by_trial)
        )

        response_dict = {
            "train": torch.tensor(responses_train).to(torch.float),
            "validation": torch.tensor(responses_val).to(torch.float),
            "test": {
                "avg": torch.tensor(responses_test).to(torch.float),
                "by_trial": by_trial,
            },
        }

        return response_dict

    def transform_roi_mask(self, roi_mask):
        """
        Compute readout coordinates for every ROI in `roi_ids`.

        :param roi_mask: array in which the pixels of ROI `roi_id` carry the value `-roi_id`
        :return: float array of shape (len(roi_ids), 2), one `roi2readout` result per ROI
        """
        roi_coords = np.zeros((len(self.roi_ids), 2))
        for i, roi_id in enumerate(self.roi_ids):
            single_roi_mask = np.zeros_like(roi_mask)
            single_roi_mask[roi_mask == -roi_id] = 1
            roi_coords[i] = self.roi2readout(single_roi_mask)
        return roi_coords

    def roi2readout(
        self,
        single_roi_mask,
        roi_mask_pixelsize=2,
        readout_mask_pixelsize=50,
        x_offset=2.75,
        y_offset=2.75,
    ):
        """
        Maps a roi mask of a single roi from recording coordinates to model
        readout coordinates
        :param single_roi_mask: 2d array with nonzero values indicating the pixels
                of the current roi
        :param roi_mask_pixelsize: size of a pixel in the roi mask in um
        :param readout_mask_pixelsize: size of a pixel in the readout mask in um
        :param x_offset: x offset indicating the start of the recording field in readout mask
        :param y_offset: y offset indicating the start of the recording field in readout mask
        :return: float32 array (y, x) with the mean position of the roi pixels,
                rescaled by `map_to_range` with max=8
        """
        pixel_factor = readout_mask_pixelsize / roi_mask_pixelsize
        y, x = np.nonzero(single_roi_mask)
        # Rescale pixel indices to readout-mask pixels and shift by the field offset.
        y_trans, x_trans = y / pixel_factor, x / pixel_factor
        y_trans += y_offset
        x_trans += x_offset
        # Collapse the roi to the mean position of its pixels.
        x_trans = x_trans.mean()
        y_trans = y_trans.mean()
        coords = np.asarray(
            [
                self.map_to_range(max=8, val=y_trans),
                self.map_to_range(max=8, val=x_trans),
            ],
            dtype=np.float32,
        )
        return coords

    def map_to_range(self, max, val):
        """Linearly rescale `val` so that 0 maps to -1 and `max` maps to 1."""
        val = val / max
        val = val - 0.5
        val = val * 2
        return val

In [None]:
def natmov_dataloaders_v2(
    movies_dictionary,
    neuron_data_dictionary,
    train_chunk_size: int = 50,
    batch_size: int = 32,
    seed: int = 42,
):
    """
    Build train / validation / test dataloaders for every session.

    Args:
        movies_dictionary: dict with the keys "train", "test" and "random_sequences",
            passed on to `get_all_movie_combinations`.
        neuron_data_dictionary: dict mapping a session key to the keyword arguments of
            `NeuronData` (everything except random_sequences, val_clip_idx, num_clips
            and clip_length, which are supplied here).
        train_chunk_size: temporal chunk size in frames of a training sample.
        batch_size: batch size handed to `get_movie_dataloader`.
        seed: seed used to draw the validation clips.

    Returns:
        dict with the keys "train", "validation" and "test", each mapping session keys
        to the object returned by `get_movie_dataloader`. Sessions without neurons
        are reported and left out.
    """
    # make sure movies and responses arrive as torch tensors!!!
    rnd = np.random.RandomState(seed)  # make sure whether we want the validation set to depend on the seed

    num_clips, clip_length = NUM_CLIPS, CLIP_LENGTH
    val_clip_idx = list(rnd.choice(NUM_CLIPS, NUM_VAL_CLIPS, replace=False))

    clip_chunk_sizes = {
        "train": train_chunk_size,
        "validation": clip_length,
        "test": 5 * clip_length,
    }
    dataloaders = {"train": {}, "validation": {}, "test": {}}
    # draw validation indices so that a validation movie can be returned!
    random_sequences = movies_dictionary["random_sequences"]
    movies = get_all_movie_combinations(
        movies_dictionary["train"], movies_dictionary["test"], random_sequences, val_clip_idx=val_clip_idx
    )
    start_indices = gen_start_indices(random_sequences, val_clip_idx, clip_length, train_chunk_size, num_clips)
    for session_key, session_data in neuron_data_dictionary.items():
        neuron_data = NeuronData(
            **session_data,
            random_sequences=random_sequences,  # Used together with the validation index to get the validation response in the corresponding dict
            val_clip_idx=val_clip_idx,
            num_clips=num_clips,
            clip_length=clip_length,
        )

        # NeuronData has no `responses_train` attribute. The number of neurons is the
        # first axis of `responses_final` (the last axis of the training responses).
        if neuron_data.responses_final.shape[0] == 0:
            print("skipped: {}".format(session_key))
            # Skip only this session; the remaining sessions still get dataloaders.
            continue

        if not (hasattr(neuron_data, "roi_coords")):
            neuron_data.roi_mask = []

        # `response_dict` is a property that recomputes the split on every access,
        # so evaluate it once per session instead of once per fold.
        response_dict = neuron_data.response_dict
        for fold in ["train", "validation", "test"]:
            dataloaders[fold][session_key] = get_movie_dataloader(
                movies[neuron_data.eye][fold],
                response_dict[fold],
                neuron_data.roi_ids,
                neuron_data.roi_coords,
                neuron_data.group_assignment,
                neuron_data.scan_sequence_idx,
                fold,
                clip_chunk_sizes[fold],
                start_indices[fold],
                batch_size,
            )

    return dataloaders