# Dataloader
This object takes a neural dataset object and a feature extractor object and provides high-level access to DNN features and neural spikes. Since extracting hidden-layer features from all the layers of the DNN for all the stimuli takes time, the dataloader implements a feature-caching mechanism. It is recommended to call the get-features method (e.g. `get_resampled_DNN_features`) once at the start; this makes sure the features are saved to the cache, so that subsequent access is faster.

In [2]:
from auditory_cortex.utils import set_up_logging
set_up_logging()

from auditory_cortex.neural_data import create_neural_dataset
from auditory_cortex.dnn_feature_extractor import create_feature_extractor
from auditory_cortex.dataloader2 import DataLoader

# Build the neural dataset object from a dataset name and a session ID.
dataset_name = 'ucdavis'
session_id = 3
neural_dataset = create_neural_dataset(dataset_name, session_id)

# Build the DNN feature extractor for the named model.
model_name = 'whisper_tiny'
feature_extractor = create_feature_extractor(model_name)

# The DataLoader combines both objects and is the single entry point used
# below for neural spikes and DNN features.
dataloader = DataLoader(neural_dataset, feature_extractor)


    If you do not have SoX, proceed here:
     - - - http://sox.sourceforge.net/ - - -

    If you do (or think that you should) have SoX, double-check your
    path variables.
    


/bin/sh: line 1: sox: command not found
  torchaudio.set_audio_backend("sox_io")


INFO:Changing convolution kernels for: whisper_tiny


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [18]:
# Fetch the spikes of the loaded session at the chosen bin width.
# NOTE(review): bin_width is presumably in milliseconds — confirm units.
bin_width = 50
# NOTE(review): `repeated` and `mVocs` presumably select which stimulus set
# is returned (repeated-trial stimuli / monkey vocalizations) — confirm
# against DataLoader.get_session_spikes.
repeated=False
mVocs=False
spikes = dataloader.get_session_spikes(
    bin_width, repeated=repeated, mVocs=mVocs
)

In [19]:
# The top-level keys of `spikes` are stimulus IDs; report how many there are.
stim_ids = list(spikes.keys())
print(f"Number of stimulus IDs: {len(stim_ids)}")

Number of stimulus IDs: 451


In [20]:
# Under each stimulus ID, `spikes` holds a second mapping keyed by channel ID.
# List the channel IDs found under the first stimulus.
stim_ids = list(spikes.keys())
channel_ids = list(spikes[stim_ids[0]].keys())
print(channel_ids)

[1001, 1002, 201, 202, 2001, 301, 3001, 4001, 4002]


In [21]:
# Print the shape of the spike array stored for the first
# (stimulus, channel) pair.
# NOTE(review): the two axes are presumably (repeats, time bins) — confirm.
stim_ids = list(spikes.keys())
channel_ids = list(spikes[stim_ids[0]].keys())
print(spikes[stim_ids[0]][channel_ids[0]].shape)

(1, 41)


#### DNN features

In [3]:
# Fetch the DNN features resampled at the given bin width.
# The first call can be slow because features are extracted and then cached;
# later calls read from the cache.
# NOTE(review): bin_width is presumably in milliseconds and should match the
# value used for the spikes — confirm.
bin_width=50
mVocs=False
features = dataloader.get_resampled_DNN_features(
    bin_width, mVocs=mVocs
)

INFO:Reading features for model: whisper_tiny
INFO:Resamping ANN features at bin-width: 50


In [5]:
# Show the top-level keys of `features` (used as layer IDs).
features.keys()

dict_keys([0, 1, 2, 3, 4, 5])

In [6]:
# Under each layer ID, `features` holds a second mapping keyed by stimulus ID.
# Show the stimulus keys of the first layer.
layer_ids = list(features.keys())
features[layer_ids[0]].keys()

dict_keys(['319-mjsw0_si1640.wfm', '117-fpls0_si1590.wfm', '269-mgsl0_si534.wfm', '225-mdrd0_si1382.wfm', '126-fsbk0_si1699.wfm', '403-mreb0_si745.wfm', '286-mjdg0_si1705.wfm', '175-mbom0_si2274.wfm', '191-mcre0_si1725.wfm', '285-mjde0_si463.wfm', '274-milb0_si2163.wfm', '405-mrem0_si961.wfm', '40-fdrd1_si1544.wfm', '219-mdma0_si1238.wfm', '458-mtab0_si2202.wfm', '377-mmws1_si1701.wfm', '176-mbpm0_si1577.wfm', '390-mprd0_si2061.wfm', '158-marc0_si1818.wfm', '316-mjrk0_si1662.wfm', '26-fcmr0_si1735.wfm', '157-marc0_si1188.wfm', '301-mjlb0_si2246.wfm', '310-mjpg0_si1191.wfm', '93-flmc0_si2002.wfm', '122-frjb0_si1794.wfm', '90-fljg0_si2241.wfm', '309-mjmp0_si1535.wfm', '493-mwew0_si731.wfm', '428-mrms0_si2100.wfm', '476-mtqc0_si480.wfm', '106-fmkc0_si1702.wfm', '384-mpeb0_si1034.wfm', '238-mdwh0_si1925.wfm', '162-mbbr0_si1685.wfm', '217-mdls0_si2258.wfm', '395-mrab0_si594.wfm', '146-mabc0_si1620.wfm', '59-fjhk0_si1652.wfm', '19-fcjs0_si2237.wfm', '349-mmab1_si2124.wfm', '154-majp0_si1074.

In [7]:
# Inspect the feature array stored for the first (layer, stimulus) pair.
layer_ids = list(features.keys())
stim_ids = list(features[layer_ids[0]].keys())
features[layer_ids[0]][stim_ids[0]].shape    # shape is (time, feature_dim)

(33, 384)