In [1]:
import cv2
import numpy as np
from tqdm import tqdm
from functools import partial
from itertools import compress
from scipy.interpolate import InterpolatedUnivariateSpline, interp1d

import datajoint as dj
import static_nda as nda
from stimulus import stimulus
from pipeline import meso, fuse,treadmill
# DataJoint virtual modules for schemas used below (module name -> database schema)
pupil, anatomy, neuro_data = (
    dj.create_virtual_module(module_name, schema_name)
    for module_name, schema_name in (
        ("pupil", "pipeline_eye"),
        ("anatomy", "pipeline_anatomy"),
        ("neuro_data", "neurodata_static"),
    )
)

Loading local settings from pipeline_config.json
Connecting pfahey@at-database.ad.bcm.edu:3306


This notebook consolidates the data-processing steps used to build the static release dataset, which were previously spread across nexport subfunctions, into a single notebook. All references to tiering have been removed, because they depend on the ConditionTier table in the at-fabee.ad.bcm.edu database.

__Please note that the cells in this notebook are interdependent and may reuse variables, and should be run in order.__

In [3]:
scan_key = {'animal_id': 21067, 'session': 10, 'scan_idx': 18}

# ---- configuration used for the static release dataset ----

# stimulus condition types whose trials belong to this export
_valid_types = [stimulus.Frame, stimulus.MonetFrame, stimulus.TrippyFrame, stimulus.ColorFrameProjector]

# target image size as (height, width); cv2.resize below receives it reversed
imgsize = (144, 256)

# response extraction window and its offset from stimulus onset, both in seconds
response_integration_window = 0.5
response_integration_offset = 0.05

# step (in seconds) between the samples averaged inside each boxcar window
trace_tolerance = 0.01
behavior_filter = 'boxcar'
neuro_trace_filter = 'boxcar'

# behavior handling
drop_invalid_behavior = True
include_behavior = True

# exporter options that are off for this release
stack_coordinates = False
gamma_correction = False

## Stimulus Image Preprocessing

Images are extracted from `stimulus.StaticImage.Image`, where they are stored as 144x256 uint8 arrays. They are resized to `imgsize` (also 144x256, so their size is unchanged) with `cv2.INTER_AREA` interpolation and cast to float32.
https://github.com/sinzlab/nexport/blob/80d66912bcb6785f8c54ec596e70de4656ac964a/nexport/exporters/utils/frames.py#L17

In [19]:
# Load every stimulus frame shown during this scan, keyed by condition hash.
# (The stored images are documented above as 144x256 uint8 arrays.)
all_conditions = (stimulus.Condition & (stimulus.Trial & scan_key & _valid_types)).fetch("KEY")

frames, frame_types = {}, {}
# fetch("KEY") returns a list of key dicts, so the progress-bar total is the number of conditions
# (len(all_conditions[0]) would be the number of key attributes of a single condition)
for cond_key in tqdm(all_conditions, total=len(all_conditions), desc="Loading frames"):
    # NOTE(review): only stimulus.Frame conditions are handled here; fetch1 requires exactly one
    # matching image -- confirm every condition of this scan is a stimulus.Frame
    frame = (stimulus.StaticImage.Image & (stimulus.Frame & cond_key)).fetch1("image")
    frame_type = "stimulus.Frame"

    condition_hash = cond_key["condition_hash"]
    # cv2.resize takes dsize as (width, height), hence the reversed imgsize
    frames[condition_hash] = cv2.resize(frame, imgsize[::-1], interpolation=cv2.INTER_AREA).astype(np.float32)
    frame_types[condition_hash] = frame_type
    

Loading frames: 5100it [00:55, 91.44it/s]            


## Utility functions
 https://github.com/sinzlab/nexport/blob/80d66912bcb6785f8c54ec596e70de4656ac964a/nexport/exporters/utils/preprocessing.py#L4
 
https://github.com/sinzlab/nexport/blob/80d66912bcb6785f8c54ec596e70de4656ac964a/nexport/exporters/utils/traces.py#L13


In [3]:
# 1d linear interpolation to fill nan values in a trace; runs of nans that are at least `preserve_gap`
# samples long are restored to nan afterwards (by default no gaps are preserved)
def _long_nan_runs(nans, min_len):
    """Return a boolean mask marking every position inside a run of True values at least `min_len` long."""
    # pad with zeros on both sides so every run has a detectable start (+1) and stop (-1) edge
    edges = np.diff(np.concatenate(([0], np.asarray(nans, dtype=np.int8), [0])))
    starts = np.flatnonzero(edges == 1)
    stops = np.flatnonzero(edges == -1)
    mask = np.zeros(len(nans), dtype=bool)
    for start, stop in zip(starts, stops):
        if stop - start >= min_len:
            mask[start:stop] = True
    return mask


def fill_nans(x, preserve_gap=None):
    """
    Linearly interpolate over nan values of a 1D array, in place.

    :param x: 1D float array. It is modified in place and also returned.
    :param preserve_gap: optional odd int. Every run of consecutive nans whose length is
        >= preserve_gap is left as nan; shorter runs are interpolated. None interpolates all nans.
    :return: the same array with nans interpolated (values beyond the first/last non-nan sample
        take the nearest non-nan value; an all-nan array becomes all zeros unless preserved).
    """
    if preserve_gap is not None:
        # retained from the original convolution-based implementation's contract
        assert preserve_gap % 2 == 1, "can only efficiently preserve odd gaps"
        # positions belonging to nan runs of length >= preserve_gap; computed from run lengths directly,
        # so arrays shorter than preserve_gap (and empty arrays) are handled correctly
        keep = _long_nan_runs(np.isnan(x), preserve_gap)
    else:
        # otherwise, keep vector set to all False
        keep = np.zeros(len(x), dtype=bool)

    # identify nan positions
    nans = np.isnan(x)

    # if all nans, set all trace values to zero
    # otherwise linearly interpolate values at all nan positions from surrounding non-nan positions
    # (interpolation is over sample index, i.e. samples are treated as evenly spaced)
    x[nans] = (
        0
        if nans.all()
        else np.interp(nans.nonzero()[0], (~nans).nonzero()[0], x[~nans])
    )

    # reset trace values at preserved gap positions to nan
    x[keep] = np.nan

    return x

def adjust_trace_len(traces, frame_times):
    """
    Cut neural traces and their frame times to a common number of samples.

    Args:
        traces:       2D np.array, one trace per row
        frame_times:  2D np.array of frame times, one row per trace
    Returns:
        (traces, frame_times), with the second axis of whichever is longer truncated to the
        length of the shorter one
    """
    common_len = min(traces.shape[1], frame_times.shape[1])
    return traces[:, :common_len], frame_times[:, :common_len]


## Get neuronal traces
https://github.com/sinzlab/nexport/blob/80d66912bcb6785f8c54ec596e70de4656ac964a/nexport/exporters/utils/traces.py#L115

In [None]:
# downsample from scan depth times (nframes * ndepths) to scan frame times
ndepth = len(dj.U('z') & (meso.ScanInfo.Field & scan_key))
frame_times = (stimulus.Sync & scan_key).fetch1("frame_times").squeeze()[::ndepth]

# restrict to masks classified as soma
soma = meso.MaskClassification.Type() & dict(type="soma")

# fetch traces, delays, and keys
spikes = (dj.U("field", "channel") * meso.Activity.Trace * meso.ScanSet.UnitInfo * meso.ScanSet.Unit & soma & scan_key)

traces, ms_delay, trace_keys = spikes.fetch("trace", "ms_delay", dj.key,
                                            order_by="animal_id, session, scan_idx, unit_id")

# stack traces, linearly interpolate across all nans (no nan gap preservation)
traces = np.vstack([fill_nans(tr.astype(np.float32)).squeeze() for tr in traces])

# convert delay from ms to s; builtin `float` (float64) replaces the `np.float` alias,
# which was removed in NumPy 1.24
delay = np.fromiter(ms_delay / 1000, dtype=float)

# create frame times vector for each trace, including estimated delay per trace
frame_times_mat = delay[:, None] + frame_times[None, :]

# clip to same trace length
traces, frame_times_mat = adjust_trace_len(traces, frame_times_mat)

# find beginning and end of scan frame times
ftmin, ftmax = frame_times_mat.min(), frame_times_mat.max()

# create complete trace keys for all traces
trace_keys = [dict(scan_key, **trace_key) for trace_key in trace_keys]


## Get neuron anatomical info
https://github.com/sinzlab/nexport/blob/80d66912bcb6785f8c54ec596e70de4656ac964a/nexport/exporters/utils/traces.py#L131

In [10]:
# unit info joined with its anatomical area and layer assignment, restricted to this scan
rel = (
    fuse.Activity.Trace.proj("animal_id", "session", "scan_idx", "unit_id")
    * anatomy.AreaMembership
    * anatomy.LayerMembership
    * meso.ScanSet.UnitInfo
    & scan_key
)

info_attrs = ("brain_area", "layer", "animal_id", "session", "scan_idx", "unit_id", "um_x", "um_y", "um_z")

# one query per trace, so rows come out in the same order as trace_keys
infos = [(rel & trace_key).fetch(*info_attrs) for trace_key in tqdm(trace_keys, desc="Get unit info")]

# transpose the per-unit rows into one array per attribute
areas, layers, animal_ids, sessions, scan_idxs, unit_ids, x, y, z = (np.array(column) for column in zip(*infos))

cell_motor_coordinates = np.c_[x, y, z]
assert len(np.unique(unit_ids)) == len(unit_ids), "unit_ids are not unique"

neuron_info = {
    "unit_ids": unit_ids.astype(np.uint16).squeeze(),
    "animal_ids": animal_ids.astype(np.uint16).squeeze(),
    "sessions": sessions.astype(np.uint8).squeeze(),
    "scan_idx": scan_idxs.astype(np.uint8).squeeze(),
    "layer": layers.astype(str).squeeze(),
    "area": areas.astype(str).squeeze(),
    "cell_motor_coordinates": cell_motor_coordinates.squeeze(),
}

Get unit info: 100%|██████████| 8684/8684 [00:22<00:00, 392.33it/s]


## get stimulus onsets

This cell follows the neuron trace fetch so that trials with flip times outside the range of the scan frame times can be removed.

https://github.com/sinzlab/nexport/blob/80d66912bcb6785f8c54ec596e70de4656ac964a/nexport/exporters/utils/stimulus.py#L129

In [11]:
# find trials affiliated with this scan
# ExcludedTrial contains trials with flip number other than 3, which is redundant with below assertion
targets = (stimulus.Condition & scan_key) * (stimulus.Trial & scan_key) - neuro_data.ExcludedTrial
flip_times, trial_keys = targets.fetch("flip_times", "KEY", order_by="condition_hash")
flip_times = [ft.squeeze() for ft in flip_times]

# mark trials occurring before the beginning of the scan or after the end of the scan as invalid
valid_trials = np.array([ft.min() >= ftmin and ft.max() <= ftmax for ft in flip_times], dtype=bool)
if not np.all(valid_trials):
    non_valid = (~valid_trials).sum()
    print(f"Dropping {non_valid} trials with dropped frames or flips outside the recording interval")

# remove flip times and trial keys from invalid trials
flip_times = list(compress(flip_times, valid_trials))
trial_keys = [dict(scan_key, **trial_key) for trial_key in compress(trial_keys, valid_trials)]

# assert predictable number of flips per trial (expected 2-3, expected all one size)
n_ft = np.unique([ft.size for ft in flip_times])
assert len(n_ft) == 1, "Found inconsistent number of fliptimes"
# index explicitly: int() on a size-1 array is deprecated since NumPy 1.25
n_ft = int(n_ft[0])
assert n_ft in (2, 3), "Cannot deal with {} flip times".format(n_ft)

# stack flip times such that trials are rows; column 0 is the clear flip, column 1 the onset flip
stimulus_onset = np.vstack(flip_times)

# Trials are intentionally NOT re-sorted by flip time: rows stay in the condition_hash order of the
# fetch above so that they remain aligned with trial_keys.

# reduce to only the column containing the second flip, indicating the blank end and
# onset of imagenet image presentation
stimulus_onsets = stimulus_onset[:, 1]


## Custom spline for interpolating traces that contain NaNs

https://github.com/sinzlab/nexport/blob/80d66912bcb6785f8c54ec596e70de4656ac964a/nexport/exporters/utils/splines.py#L7

In [5]:
class NaNSpline(InterpolatedUnivariateSpline):
    """
    Interpolating spline that tolerates NaNs in its sample data.

    Samples where either ``x`` or ``y`` is NaN are excluded from the spline fit. A companion
    linear interpolator over the NaN mask remembers where those samples were, so evaluating the
    spline at, or between neighbours of, a NaN sample returns NaN instead of a value fitted
    across the gap.

    Args:
        x: 1D array-like of sample positions. NaN positions are filled by linear interpolation
           over sample index before fitting.
        y: 1D array-like of sample values, same length as ``x``.
        **kwargs: forwarded to ``InterpolatedUnivariateSpline`` (e.g. ``k`` = spline degree,
           ``ext`` = extrapolation mode beyond the boundaries).
    """

    def __init__(self, x, y, **kwargs):
        # find nans in x or y
        xnan = np.isnan(x)
        ynan = np.isnan(y)

        # True wherever x or y is nan; those samples are left out of the fit
        w = xnan | ynan

        # copy into arrays so the caller's data is never modified
        x, y = map(np.array, [x, y])

        # set y at ynans to zero (they are excluded from the fit anyway)
        y[ynan] = 0

        # set x at xnans to interpolate between known x values
        x[xnan] = np.interp(np.where(xnan)[0], np.where(~xnan)[0], x[~xnan])

        # fit the spline only on positions that are not nan in x or y
        super().__init__(x[~w], y[~w], **kwargs)

        # linear interpolator of the nan mask (1 = nan) at any x position
        self.nans = interp1d(x, 1 * w, kind='linear')

    def __call__(self, x, **kwargs):
        """
        Evaluate the spline at ``x``; returns NaN where ``x`` is NaN or where the interpolated
        nan mask is > 0 (i.e. at or next to a NaN sample). ``kwargs`` go to the spline call.
        """
        # work on a floating array: integer or list inputs would otherwise make the output
        # buffers integer, so NaN could not be stored and fractional mask values would truncate
        x = np.asarray(x)
        if not np.issubdtype(x.dtype, np.floating):
            x = x.astype(float)

        ret = np.zeros_like(x)
        nan_weight = np.zeros_like(x)

        # query positions that are themselves nan
        query_nans = np.isnan(x)
        nan_weight[query_nans] = 1

        # interpolate the nan mask at the remaining positions
        nan_weight[~query_nans] = self.nans(x[~query_nans])

        # any position that is nan or adjacent to a nan sample becomes nan
        idx = nan_weight > 0
        ret[idx] = np.nan

        # all other positions come from the spline (kwargs passed through)
        ret[~idx] = super().__call__(x[~idx], **kwargs)
        return ret

## Pupil Preprocessing

https://github.com/sinzlab/nexport/blob/80d66912bcb6785f8c54ec596e70de4656ac964a/nexport/exporters/utils/behavior.py#L64

In [15]:
# convert loading parameters to function internal variable names
duration = response_integration_window
filter_type = behavior_filter
DEEP_LAB_CUT = 2
tracking_method = DEEP_LAB_CUT
tolerance = trace_tolerance

# only the boxcar filter is implemented below
assert filter_type == "boxcar", f"Unsupported behavior filter: {filter_type}"

# This cell rebinds `pupil` to the extracted pupil-radius signal (used by later cells), which shadows
# the `pupil` virtual module created in the import cell. Bind the schema under its own name so that
# the cell can be re-run.
pupil_schema = dj.create_virtual_module("pupil", "pipeline_eye")

# sample point is midpoint of trial extraction window (takes offset into account)
sample_point = (response_integration_offset + response_integration_window / 2)

# find sample point for each stimulus trial by adding to onsets
stimulus_mid_points = stimulus_onsets + sample_point

# load radius and xy centers of fitted circles to the pupil using Deep Lab Cut
r, center = (pupil_schema.FittedPupil.Circle &
             {**scan_key, "tracking_method": tracking_method}).fetch("radius", "center",
                                                                     order_by="frame_id")
# identify frames missing a fitted circle (indicated by nans)
detectedFrames = ~np.isnan(r)

# reformat centers from: [(x,y) or nan] to: [(x,y) or (nan,nan)]
xy = np.full((len(r), 2), np.nan)
xy[detectedFrames, :] = np.vstack(center[detectedFrames])

# linearly interpolate across small nan gaps in x and y traces independently,
# preserving nan gaps with length >= 3.
# A list (not a `map`) is passed: np.vstack rejects non-sequence iterables in recent NumPy.
xy = np.vstack([fill_nans(coord, preserve_gap=3) for coord in xy.T])
if np.any(np.isnan(xy)):
    print("Keeping some nans in the pupil location trace")

# linearly interpolate across small nan gaps in radius trace, preserving nan gaps with length >= 3
pupil_radius = fill_nans(r.squeeze(), preserve_gap=3)
if np.any(np.isnan(pupil_radius)):
    print("Keeping some nans in the pupil radius trace")
radius = pupil_radius

# fetch the frame times of the eye camera in the behavior clock
eye_time = (pupil_schema.Eye() & scan_key).fetch1("eye_time").squeeze()

# fetch the scan frame times in the behavior clock
behavior_clock = (stimulus.BehaviorSync & scan_key).fetch1('frame_times').squeeze()[::ndepth]

# trim scan frame times in stimulus clock (frame times) and scan frame times in behavior clock (behavior_clock)
# to be the same length, allowing interpolation between the two clocks.
if len(frame_times) - len(behavior_clock) != 0:
    assert (
        abs(len(frame_times) - len(behavior_clock)) < 2
    ), "Difference bigger than 2 time points"
    l = min(len(frame_times), len(behavior_clock))
    frame_times = frame_times[:l]
    behavior_clock = behavior_clock[:l]

# interpolation object for moving from frame times in stimulus clock to frame times in behavior clock
# k = 1, degree of spline
# ext = 3 -> extrapolation mode 3 'constant', return the boundary value
fr2beh = NaNSpline(frame_times, behavior_clock, k=1, ext=3)


# create boxcar trace for all stimulus mid points
# -duration/2 to duration/2, step length = tolerance
print("Extracting pupil signal using boxcar")
assert len(stimulus_mid_points.shape) == 1, "stimulus_points need to be a 1d array"
dt = np.arange(-duration / 2, duration / 2, tolerance)[:, None]
T = dt + stimulus_mid_points[None, :]

# convert boxcar windows from stimulus clock to behavior clock
T = fr2beh(T.ravel()).reshape(T.shape)


def boxcar_mean(signal, signal_times, windows):
    """Average `signal` (sampled at `signal_times`) over each column of `windows`, using nearest-sample lookup."""
    upsampler = interp1d(signal_times, signal, kind="nearest")
    return upsampler(windows).mean(axis=0)


# extract the mean signal for all boxcar windows from the pupil radius, delta pupil radius,
# and pupil center traces
pupil = boxcar_mean(radius, eye_time, T)
dpupil = boxcar_mean(np.gradient(radius), eye_time, T)
center = np.vstack([boxcar_mean(coord, eye_time, T) for coord in xy])

# for any trial containing a nan in one of the three pupil traces
# set the values of that trial to -1 in all three traces
# mark that trial as invalid in valid_eye vector
# note, this is after linearly interpolating across nan gaps with length <3
# and will include nans inherited from nearest neighboring positions due to interp1d
valid = ~np.isnan(pupil + dpupil + center.sum(axis=0))
if not np.all(valid):
    print("Found {} NaN trials. Setting to -1".format((~valid).sum()))
    pupil[~valid] = -1
    dpupil[~valid] = -1
    center[:, ~valid] = -1
pupil_center = center.T
valid_eye = valid


Keeping some nans in the pupil location trace
Keeping some nans in the pupil location trace


  


## Treadmill Preprocessing
https://github.com/sinzlab/nexport/blob/80d66912bcb6785f8c54ec596e70de4656ac964a/nexport/exporters/utils/behavior.py#L137

This cell reuses the stimulus-to-behavior clock interpolator (`fr2beh`) and the sample points (`stimulus_mid_points`) computed in the pupil preprocessing above.

In [None]:
# fetch treadmill time and velocity for scan
treadmill_time, treadmill_signal = (treadmill.Treadmill() & scan_key).fetch1("treadmill_time", "treadmill_vel")
treadmill_time, treadmill_signal = treadmill_time.squeeze(), treadmill_signal.squeeze()

# create boxcar trace for all stimulus mid points
# -duration/2 to duration/2, step length = tolerance
print("Extracting treadmill signal using boxcar")
assert len(stimulus_mid_points.shape) == 1, "stimulus_points need to be a 1d array"
dt = np.arange(-duration / 2, duration / 2, tolerance)[:, None]
T = dt + stimulus_mid_points[None, :]

# convert boxcar windows from stimulus clock to behavior clock
T = fr2beh(T.ravel()).reshape(T.shape)

# interpolate boxcar windows into treadmill_signal using nearest interpolation
# then take mean across boxcar period
upsampler = interp1d(treadmill_time, treadmill_signal, kind="nearest")
tm = upsampler(T).mean(axis=0)

# for any trial containing a nan in the mean treadmill value
# set the value of that trial to -1
# mark this trial as invalid in valid_treadmill vector
# note that treadmill values are not preprocessed by nan interpolation utility fill_nans
valid = ~np.isnan(tm)
if not np.all(valid):
    # report with print, like the pupil cell (`warn` is never imported in this notebook)
    print("Found {} NaN trials. Setting to -1".format((~valid).sum()))
    tm[~valid] = -1
valid_treadmill = valid

# rename tread_mill_values
tread_mill_values = tm



## Trial Exclusion for Behavior

https://github.com/sinzlab/nexport/blob/80d66912bcb6785f8c54ec596e70de4656ac964a/nexport/exporters/imagenet.py#L200

In [12]:
# create vector, default all trials included
valid_signals = np.ones(len(stimulus_onsets), dtype=bool)

# join with valid trials due to nan detection in either pupil or treadmill traces
valid_signals &= valid_eye
valid_signals &= valid_treadmill

# join behavior as a single matrix
behavior = np.c_[pupil, dpupil, tread_mill_values]

# if any invalid trials, and flag for dropping invalid trials due to behavior = True
if not np.all(valid_signals) and drop_invalid_behavior:
    non_valid = (~valid_signals).sum()
    print(f"Dropping {non_valid} trials because of missing signals in behavioral recordings")
    
    # remove invalid trials from stimulus onsets, trial_keys, behavior traces, pupil center traces, and 
    # valid trial tracker 
    stimulus_onsets = stimulus_onsets[valid_signals]
    trial_keys = list(compress(trial_keys, valid_signals))
    if include_behavior:
        behavior = behavior[valid_signals]
        pupil_center = pupil_center[valid_signals]
        valid_signals = valid_signals[valid_signals]


## Complete Trace Preprocessing

https://github.com/sinzlab/nexport/blob/80d66912bcb6785f8c54ec596e70de4656ac964a/nexport/exporters/utils/traces.py#L74

This step runs after the behavior preprocessing, so trials with NaNs in the behavior traces have already been removed from `stimulus_onsets` (when `drop_invalid_behavior` is set).

In [None]:
# only the boxcar filter is implemented below
assert neuro_trace_filter == "boxcar", f"Unsupported neural trace filter: {neuro_trace_filter}"

# sample point is midpoint of trial extraction window (takes offset into account)
sample_point = (response_integration_offset + response_integration_window / 2)

# find sample point for each stimulus trial by adding to onsets
stimulus_points = stimulus_onsets + sample_point

sampling_period = response_integration_window
print(f"Sampling traces with a {sampling_period}s boxcar window")

assert len(stimulus_points.shape) == 1, "stimulus_points need to be a 1d array"

# Upsampling and convolution with boxcar is too slow. Instead we "upsample" time at the size of the
# boxcar symmetrically around zero, add that window to all sampling points, extract the nearest
# interpolation of the original signal around those points and average across each window. This is orders
# of magnitude faster.
# The boxcar window (-sampling_period/2 to sampling_period/2, step trace_tolerance) and the resulting
# sample times are identical for every trace, so they are built once outside the loop.
dt = np.arange(-sampling_period / 2, sampling_period / 2, trace_tolerance)[:, None]
T = dt + stimulus_points[None, :]

responses = []
for frame_time, trace in tqdm(zip(frame_times_mat, traces), desc="Sampling traces", total=len(traces)):
    # interpolate boxcar windows into the trace (each trace has its own delayed frame times)
    # using nearest interpolation, then take the mean across each boxcar period
    upsampler = interp1d(frame_time, trace, kind="nearest")
    responses.append(upsampler(T).mean(axis=0))
# rows = trials, columns = neurons
responses = np.vstack(responses).T

In [None]:
# Exporter configuration of the static release dataset; the values match the config cell at the top of
# this notebook. NOTE(review): drop_invalid_behavior and tracking_method are not listed here -- confirm
# whether the release relied on the exporter defaults for them.
{'img_size': (144, 256),
  'response_integration_window': 0.5,
  'response_integration_offset': 0.05,
  'include_behavior': True,
  'neuro_trace_filter': 'boxcar',
  'behavior_filter': 'boxcar',
  'stack_coordinates': False,
  'gamma_correction': False,
  'trace_tolerance': 0.01}

In [None]:
import shutil
from itertools import compress
from operator import itemgetter
from pathlib import Path

import numpy as np
from tqdm import tqdm
from .. import logger as log
from ..schemas.bcm.stim_info import ConditionTier
from . import Exporter
from .utils import DEEP_LAB_CUT
from .utils.behavior import extract_pupil, extract_treadmill
from .utils.frames import load_frame_and_type, rescale_frame, stack_frames
from .utils.statistics import run_stats
from .utils.stimulus import (
    get_trialkeys_and_stimulus_onsets,
    get_extra_stimulus_info,
    get_album,
)
from .utils.storage import save_dict_to_hdf5, save_to_folder, zip_dir
from .utils.traces import (
    extract_traces,
    get_neuron_info,
    get_stack_coordinates,
    get_xmatches,
    extract_neural_responses,
)
from .utils.verification import (
    pupil_present,
    treadmill_present,
    album_present,
    behavior_synced,
    area_and_layers_consistent,
    has_motor_coordinates,
    has_stack_coordinates,
    has_xmatch_coordinates,
    has_traces,
    has_gamma_correction,
    soma_classiafication_present,
    has_condition_tier_assigned,
)
from ..cajal import stimulus, multi_mei


class ImageNet(Exporter):
    """
    Exporter for static image presentations of a single scan.

    Collects the presented frames, per-trial neural responses, neuron metadata, summary
    statistics and (optionally) behavioral signals, and writes them to HDF5 or to a
    (zipped) folder.
    """

    # stimulus passed to `album_present` / `get_album`; None for this exporter
    _base_stimulus = None

    # stimulus condition types whose trials are exported
    _valid_types = [
        stimulus.Frame,
        stimulus.MonetFrame,
        stimulus.TrippyFrame,
        stimulus.ColorFrameProjector,
    ]

    def __init__(
        self,
        img_size=(36, 64),
        response_integration_window=0.5,  # in seconds
        response_integration_offset=0.05,  # in seconds
        include_behavior=True,
        tracking_method=DEEP_LAB_CUT,  # TODO: replace with more concrete values/string
        neuro_trace_filter="hamming",
        behavior_filter="hamming",
        stack_coordinates=False,
        gamma_correction=False,
        trace_tolerance=0.01,
        drop_invalid_behavior=True,
    ):
        """
        Args:
            img_size: target image size passed to `rescale_frame`
            response_integration_window: length of the response extraction window, in seconds
            response_integration_offset: offset added to stimulus onsets before centering the window, in seconds
            include_behavior: whether to extract and export pupil and treadmill signals
            tracking_method: pupil tracking method passed to `extract_pupil`
            neuro_trace_filter: filter type passed to `extract_neural_responses`
            behavior_filter: filter type passed to `extract_pupil` / `extract_treadmill`
            stack_coordinates: whether to add stack coordinates to the neuron info
            gamma_correction: whether to gamma-correct the images (currently not functional, see `export`)
            trace_tolerance: tolerance passed to the trace and behavior extractors
            drop_invalid_behavior: whether to drop trials with missing behavioral signals
        """
        self.img_size = img_size
        self.response_integration_window = response_integration_window
        self.response_integration_offset = response_integration_offset
        self.include_behavior = include_behavior
        self.tracking_method = tracking_method
        self.behavior_filter = behavior_filter
        self.stack_coordinates = stack_coordinates
        self.gamma_correction = gamma_correction
        self.neuro_trace_filter = neuro_trace_filter
        self.trace_tolerance = trace_tolerance
        self.drop_invalid_behavior = drop_invalid_behavior
        self._cache = {}

    def filename(self, scan_key, confighash):
        """
        Get the filename corresponding to the particular scan and exporter config as indicated by the `confighash`
        Args:
            scan_key (dict): A dictionary corresponding to DataJoint key for a single Scan table entry
            confighash (str): A string to uniquely identify a specific exporter configuration
        Returns:
            str: The filename to adequately identify the particular scan to be exported with this particular exporter
        """
        return "static{animal_id}-{session}-{scan_idx}-{exporter_name}-{confighash}".format(
            exporter_name=self.__class__.__name__, confighash=confighash, **scan_key
        )

    def clean_cache(self):
        """Reset the per-export cache."""
        self._cache = {}

    def to_hdf5(self, data, scan_key, confighash, overwrite=False, base_path="."):
        """
        Save `data` to `<base_path>/<filename>.h5`.

        Raises:
            FileExistsError: if the file exists and `overwrite` is False
        """
        filename = self.filename(scan_key, confighash) + ".h5"
        full_path = Path(base_path) / filename
        if full_path.exists() and not overwrite:
            raise FileExistsError(f"File {full_path} already exists")
        save_dict_to_hdf5(data, full_path)

    def to_folder(
        self,
        data,
        scan_key,
        confighash,
        overwrite=False,
        zip=True,
        base_path=".",
    ):
        """
        Save `data` into `<base_path>/<filename>/`; when `zip` is True the folder is zipped
        to `<filename>.zip` and the folder is removed.

        Raises:
            FileExistsError: if the folder exists and `overwrite` is False
        """
        dirname = self.filename(scan_key, confighash)
        full_path = Path(base_path) / dirname
        if full_path.exists() and not overwrite:
            raise FileExistsError(f"Path {full_path} already exists")
        else:
            full_path.mkdir(exist_ok=True, parents=True)
        save_to_folder(
            data, full_path, overwrite=overwrite, include_behavior=self.include_behavior
        )
        if zip:
            zip_dir(full_path.with_suffix(".zip"), full_path)
            shutil.rmtree(full_path)

    def verify(self, scan_key):
        """
        Run the pre-export checks for `scan_key`.
        NOTE(review): the verification helpers are presumably expected to raise on failure -- confirm.
        """
        if self.include_behavior:
            behavior_synced(scan_key)
            pupil_present(scan_key)
            treadmill_present(scan_key)

        has_condition_tier_assigned(scan_key, valid_conditions=self._valid_types)
        # Adapt to work with all valid stimulus types
        album_present(scan_key, self._base_stimulus)
        has_motor_coordinates(scan_key)
        area_and_layers_consistent(scan_key)
        has_traces(scan_key)
        # call the helper by the name imported at the top of this cell (`soma_classiafication_present`, sic);
        # the spelling `soma_classification_present` is not in scope and raised NameError
        soma_classiafication_present(scan_key)

        if self.stack_coordinates:
            has_stack_coordinates(scan_key)

        if self.gamma_correction:
            has_gamma_correction(scan_key)

    def export(self, scan_key):
        """
        Assemble the export dataset for `scan_key`.

        Returns:
            dict with images, trial_idxs, responses, hashes, tiers, types, item_info, neurons and
            statistics; plus behavior, pupil_center and valid_behavior when `include_behavior` is set.
        """
        self.clean_cache()
        final_dataset = {}
        # --- Preprocess frames
        frames = {}
        frame_types = {}
        frame_tiers = {}

        all_conditions = (
            stimulus.Condition * ConditionTier
            & (stimulus.Trial & scan_key & self._valid_types)
        ).fetch("KEY", "tier")
        for cond_key, tier in tqdm(
            zip(*all_conditions),
            total=len(all_conditions[0]),
            desc="Loading frames",
        ):
            frame, frame_type = load_frame_and_type(cond_key)

            condition_hash = cond_key["condition_hash"]
            frames[condition_hash] = rescale_frame(frame, self.img_size)
            frame_tiers[condition_hash] = tier
            frame_types[condition_hash] = frame_type

        # --- Preprocess responses
        # get middle of extraction window
        sample_point = (
            self.response_integration_offset + self.response_integration_window / 2
        )
        log.info(
            f"Sampling neural responses at {self.response_integration_window}s intervals"
        )

        # get interpolation of neurotraces along with key for each trace and min/max frametimes
        traces, trace_keys, frame_times, ftmin, ftmax = extract_traces(scan_key)
        trace_keys = [dict(scan_key, **trace_key) for trace_key in trace_keys]
        self._cache["trace_keys"] = trace_keys

        neuron_info = get_neuron_info(scan_key, trace_keys)
        if self.stack_coordinates:
            log.info("Adding multi-unit matches and coordinates")
            stack_info = get_stack_coordinates(scan_key, trace_keys)
            neuron_info.update(**stack_info)

        # get stimulus fliptimes and keys for single trials
        trial_keys, stimulus_onsets = get_trialkeys_and_stimulus_onsets(
            scan_key=scan_key, ftmin=ftmin, ftmax=ftmax, restr=ConditionTier
        )

        # --- extract eye values
        valid_signals = np.ones(len(stimulus_onsets), dtype=bool)
        if self.include_behavior:
            pupil, dpupil, pupil_center, valid_eye = extract_pupil(
                scan_key,
                stimulus_onsets + sample_point,
                self.response_integration_window,
                self.behavior_filter,
                self.tracking_method,
                tolerance=self.trace_tolerance,
            )
            valid_signals &= valid_eye

            # --- extract treadmill values
            tread_mill_values, valid_treadmill = extract_treadmill(
                scan_key,
                stimulus_onsets + sample_point,
                self.response_integration_window,
                self.behavior_filter,
                tolerance=self.trace_tolerance,
            )
            valid_signals &= valid_treadmill
            behavior = np.c_[pupil, dpupil, tread_mill_values]

        if not np.all(valid_signals) and self.drop_invalid_behavior:
            non_valid = (~valid_signals).sum()
            log.info(
                f"Dropping {non_valid} trials because of missing signals in behavioral recordings"
            )
            stimulus_onsets = stimulus_onsets[valid_signals]
            trial_keys = list(compress(trial_keys, valid_signals))
            if self.include_behavior:
                behavior = behavior[valid_signals]
                pupil_center = pupil_center[valid_signals]
                valid_signals = valid_signals[valid_signals]

        hashes = np.array(
            list(map(itemgetter("condition_hash"), trial_keys)), dtype=str
        )
        tiers = np.array([frame_tiers[h] for h in hashes]).astype(str)
        types = np.array([frame_types[h] for h in hashes]).astype(str)
        trial_idxs = np.array(list(map(itemgetter("trial_idx"), trial_keys)), dtype=int)

        responses = extract_neural_responses(
            stimulus_onsets + sample_point,
            frame_times,
            traces,
            self.response_integration_window,
            filter_type=self.neuro_trace_filter,
            tolerance=self.trace_tolerance,
        )

        images = stack_frames([frames[ch] for ch in hashes])
        row_info = get_extra_stimulus_info(scan_key, hashes, trial_idxs, types)

        row_info["album"] = get_album(scan_key, trial_keys, self._base_stimulus)

        # gamma correction
        if self.gamma_correction:
            log.info("Applying gamma correction")
            # This won't work as `get_fs` is not defined anywhere
            # TODO: fix this
            f, f_inv = (multi_mei.ClosestCalibration & scan_key).get_fs()
            images = f(images)

        # --- compute statistics
        # statistics are computed on the "train" tier, or on the "probe" tier if no train trials exist
        stats_ix = tiers == "train" if "train" in tiers else tiers == "probe"
        response_statistics = run_stats(
            selector=lambda ix: responses[ix],
            types=types,
            ix=stats_ix,
            axis=0,
        )
        input_statistics = run_stats(
            selector=lambda ix: images[ix],
            types=types,
            ix=stats_ix,
        )

        statistics = dict(images=input_statistics, responses=response_statistics)

        if self.include_behavior:
            # ---- include statistics
            behavior_statistics = run_stats(
                selector=lambda ix: behavior[ix],
                types=types,
                ix=stats_ix,
                axis=0,
            )
            eye_statistics = run_stats(
                selector=lambda ix: pupil_center[ix],
                types=types,
                ix=stats_ix,
                axis=0,
            )

            statistics["behavior"] = behavior_statistics
            statistics["pupil_center"] = eye_statistics

        # TODO: Gamma correction
        # --- assemble final dataset
        final_dataset.update(
            images=images,
            trial_idxs=trial_idxs,
            responses=responses,
            hashes=hashes,
            tiers=tiers,
            types=types,
            item_info=row_info,
            neurons=neuron_info,
            statistics=statistics,
        )
        if self.include_behavior:
            final_dataset["behavior"] = behavior
            final_dataset["pupil_center"] = pupil_center
            final_dataset["valid_behavior"] = valid_signals

        return final_dataset