# User defined filter function

This notebooks demos how to create a user defined filter function (can be generalized to any function).

Here we are trying to create a dataloader for a subset of the Experiment data. This subset is defined either by a list of `frame id intervals`.

In [1]:
import json
from tqdm import tqdm
import numpy as np

from experanto.intervals import (
    TimeInterval,
    find_complement_of_interval_array,
    uniquefy_interval_array,
)

valid_keys = ['00003','00005','00006','00009',] # Sample

In [2]:
from experanto.datasets import register_callable
from experanto.interpolators import Interpolator

@register_callable("filter2")
def id2interval(dataset=None, id_list=[], complement=False):
    '''Convert a list of IDs to intervals.
    Args:
        id_list (list): The list of IDs.
        complement (bool): If True, return the complement of the intervals.
    Returns:
        list: A list of intervals.'''

    def implementation(device_: Interpolator, 
                       id_list=id_list, 
                       dataset=dataset, 
                       complement=complement):

        if not id_list:
            return []
        
        id_list = sorted(id_list)
    
        meta_path = f"/data/test_upsampling_without_hamming_30.0Hz/{dataset}/screen/combined_meta.json"
        with open(meta_path, 'rb') as f:
            meta = json.load(f)

        if complement:
            all_ids = set(meta.keys())
            used_ids = set(id_list)
            complement_ids = sorted(all_ids - used_ids)
            return id2interval(dataset=dataset, id_list=complement_ids)(device_)

        timestamps = np.load(f"/data/test_upsampling_without_hamming_30.0Hz/{dataset}/screen/timestamps.npy")
        
        intervals = []
        for i in range(1, len(id_list)):
            start = meta[id_list[i]]['first_frame_idx']
            end = start + meta[id_list[i]]['num_frames']
            intervals.append(TimeInterval(timestamps[start], timestamps[end]))  # start inclusive, end exclusive
        
        valid_intervals = uniquefy_interval_array(intervals)
        
        return valid_intervals
        
    return implementation

In [3]:
id_list = valid_keys
dataset = "dynamic29515-10-12-Video-021a75e56847d574b9acbcc06c675055_30hz"
id2interval(id_list=id_list, dataset=dataset)(None)

[TimeInterval(start=1682534632.6021829, end=1682534652.623929),
 TimeInterval(start=1682534673.1462207, end=1682534683.1570945)]

### ToDo: Use the time intervals to dataloader function to complete the demo

In [4]:
# Jupyter notebook setup
%load_ext autoreload
%autoreload 2
%matplotlib inline

# Standard imports
import sys
import os
from pathlib import Path

# Add project root to path if needed
project_root = Path.cwd().parent
if str(project_root) not in sys.path:
    sys.path.append(str(project_root))

In [5]:
import matplotlib.pyplot as plt
from os import path

from tqdm import tqdm
import torch
from omegaconf import OmegaConf, open_dict

from experanto.datasets import ChunkDataset
from experanto.dataloaders import get_multisession_dataloader

In [6]:
from experanto.configs import DEFAULT_CONFIG as cfg

print(OmegaConf.to_yaml(cfg))

dataset:
  global_sampling_rate: null
  global_chunk_size: null
  add_behavior_as_channels: false
  replace_nans_with_means: false
  cache_data: false
  out_keys:
  - screen
  - responses
  - eye_tracker
  - treadmill
  - timestamps
  normalize_timestamps: true
  modality_config:
    screen:
      keep_nans: false
      sampling_rate: 30
      chunk_size: 60
      valid_condition:
        tier: train
      offset: 0
      sample_stride: 1
      include_blanks: true
      transforms:
        normalization: normalize
        Resize:
          _target_: torchvision.transforms.v2.Resize
          size:
          - 144
          - 256
      interpolation:
        rescale: true
        rescale_size:
        - 144
        - 256
    responses:
      keep_nans: false
      sampling_rate: 8
      chunk_size: 16
      offset: 0.0
      transforms:
        normalization: standardize
      interpolation:
        interpolation_mode: nearest_neighbor
      filters:
        nan_filter:
          __tar

In [7]:
cfg.dataset.modality_config.screen.include_blanks = True
cfg.dataset.modality_config.screen.valid_condition = {"tier": "train"}
cfg.dataloader.num_workers = 8

In [8]:
# Example of defining a new function in a different file './my_functions/common_filters.py'
cfg.dataset.modality_config.treadmill.filters.nan_filter = {"__key__": "filter1", "vicinity": 0.05}
cfg.dataset.modality_config.eye_tracker.filters.nan_filter = {"__key__": "filter1", "vicinity": 0.05}
cfg.dataset.modality_config.responses.filters.nan_filter = {"__key__": "filter1", "vicinity": 0.05}

# Example of defining a new function in a the jupyter notebook
# cfg.dataset.modality_config.treadmill.filters.nan_filter = {"__key__": "filter2", "id_list": id_list, "dataset": dataset}
# cfg.dataset.modality_config.eye_tracker.filters.nan_filter = {"__key__": "filter2", "id_list": id_list, "dataset": dataset}
# cfg.dataset.modality_config.responses.filters.nan_filter = {"__key__": "filter2", "id_list": id_list, "dataset": dataset}

In [9]:
print(OmegaConf.to_yaml(cfg))

dataset:
  global_sampling_rate: null
  global_chunk_size: null
  add_behavior_as_channels: false
  replace_nans_with_means: false
  cache_data: false
  out_keys:
  - screen
  - responses
  - eye_tracker
  - treadmill
  - timestamps
  normalize_timestamps: true
  modality_config:
    screen:
      keep_nans: false
      sampling_rate: 30
      chunk_size: 60
      valid_condition:
        tier: train
      offset: 0
      sample_stride: 1
      include_blanks: true
      transforms:
        normalization: normalize
        Resize:
          _target_: torchvision.transforms.v2.Resize
          size:
          - 144
          - 256
      interpolation:
        rescale: true
        rescale_size:
        - 144
        - 256
    responses:
      keep_nans: false
      sampling_rate: 8
      chunk_size: 16
      offset: 0.0
      transforms:
        normalization: standardize
      interpolation:
        interpolation_mode: nearest_neighbor
      filters:
        nan_filter:
          __key

In [10]:
from experanto.dataloaders import get_multisession_dataloader
from my_functions import common_filters

paths = ["/data/test_upsampling_without_hamming_30.0Hz/dynamic29515-10-12-Video-021a75e56847d574b9acbcc06c675055_30hz"]
train_dl = get_multisession_dataloader(paths, cfg)

