In [None]:
import os
import torch
import struct
import scipy
from scipy import io
import numpy as np
import matplotlib.pyplot as plt

from auditory_cortex.dnn_feature_extractor import create_feature_extractor
from auditory_cortex.neural_data import UCDavisDataset, UCSFDataset, create_neural_dataset
from auditory_cortex.dataloader2 import DataLoader
from auditory_cortex.data_assembler import STRFDataAssembler, DNNDataAssembler
from auditory_cortex.encoding import TRF
import numpy as np
%matplotlib inline
import logging

# Configure the logging system so INFO-level messages are emitted.
logging.basicConfig(
    level=logging.INFO,  # Set the logging level
    # format="%(name)s - %(message)s",
    # format="%(name)s - %(levelname)s - %(message)s",
)

# The original call was indented at top level (an IndentationError) and used
# `torchaudio` without importing it, so the whole cell could not run.
import torchaudio

# NOTE(review): `set_audio_backend` is deprecated in newer torchaudio releases
# and may be missing there, so the call is guarded. Confirm against the
# installed version.
if hasattr(torchaudio, "set_audio_backend"):
    torchaudio.set_audio_backend("sox_io")


In [16]:
# Build the neural dataset object for the 'ucdavis' dataset, session 2.
dataset_name, session_id = 'ucdavis', 2
dataset = create_neural_dataset(dataset_name, session_id)

In [17]:
# Display the `data` attribute of the dataset object created in the previous cell.
dataset.data

{'__header__': b'MATLAB 5.0 MAT-file, Platform: PCWIN64, Created on: Fri Nov 22 16:10:34 2024',
 '__version__': '1.0',
 '__globals__': [],
 'TrialStimData': array([<scipy.io.matlab.mio5_params.mat_struct object at 0x2b360cae7f70>,
        <scipy.io.matlab.mio5_params.mat_struct object at 0x2b360cae7fa0>],
       dtype=object),
 'WM_1': <scipy.io.matlab.mio5_params.mat_struct at 0x2b360dc8f040>,
 'WM_2': <scipy.io.matlab.mio5_params.mat_struct at 0x2b360dc8f070>,
 'WM_3': <scipy.io.matlab.mio5_params.mat_struct at 0x2b360dc8f0a0>,
 'WM_4': <scipy.io.matlab.mio5_params.mat_struct at 0x2b360dc8f0d0>}

In [18]:
# Number of entries stored under the 'TrialStimData' key.
dataset.data['TrialStimData'].size

2

In [19]:
# List the field names of the first 'TrialStimData' entry.
dataset.data['TrialStimData'][0]._fieldnames # the field of interest here is StimulusType

['StimulusName',
 'StimulusTimeOn',
 'StimulusIsSimultaneous',
 'StimulusPWA',
 'StimulusType',
 'StimulusWhichTrial',
 'StimulusSpeaker',
 'StimulusDirection',
 'StimulusTrialAtten',
 'StimulusTargetdB',
 'StimulusDuration',
 'StimulusCarrierLow',
 'StimulusCarrierHigh',
 'StimulusSeed',
 'StimulusAMPhaseDeg',
 'StimulusAMDepth',
 'StimulusAMHz',
 'StimulusAMHz2',
 'StimulusTonePhaseDeg',
 'StimulusFMSweepStart',
 'StimulusFMSweepEnd',
 'StimulusCarrierType',
 'StimulusCombType',
 'StimulusCombToothBW',
 'StimulusCombToothSpacing',
 'StimulusGapInt',
 'StimulusGapMS',
 'StimulusCompleted',
 'StimulusTimeDone',
 'StimulusCountPerTrial',
 'StimulusResult',
 'StimulusAudiogramPreviousdB',
 'StimulusTimeOff',
 'TrialResult',
 'TrialCorrect',
 'TrialStartBlink',
 'TrialResponseTime',
 'InitRewardMS',
 'InitRewardTime',
 'InitRewardTrial',
 'HoldRewardMS',
 'HoldRewardTime',
 'HoldRewardTrial',
 'RewardMS',
 'RewardTime',
 'RewardTrial',
 'Error',
 'ErrorTime']

In [20]:
# Distinct StimulusType values in the first 'TrialStimData' entry.
np.unique(dataset.data['TrialStimData'][0].StimulusType)

array(['BMM3'], dtype=object)

In [21]:
# Distinct StimulusType values in the second 'TrialStimData' entry.
np.unique(dataset.data['TrialStimData'][1].StimulusType)

array(['BMT3'], dtype=object)

In [56]:
# Draw `num_trials` trial indices, with replacement, from
# range(total_trial_repeats).
# `num_trials` was previously read before any cell defined it, so the cell only
# worked with leftover kernel state; it is now defined here. The declared
# `total_trial_repeats` also replaces the hard-coded literal 3.
total_trial_repeats = 3
num_trials = 3
trial_ids = np.random.choice(total_trial_repeats, size=num_trials, replace=True)
trial_ids

array([1, 0, 2])

In [57]:
# Sample `num_trials` indices from range(max_num_trials); repeats are allowed.
max_num_trials = num_trials = 3
trial_ids = np.random.choice(a=max_num_trials, size=num_trials, replace=True)
trial_ids

array([1, 0, 0])

In [None]:
# Choose which trial repeats to use.
# The original body kept function-level tab indentation after a column-0 first
# line, which is an IndentationError in a cell; it is dedented here.
# NOTE(review): `y_all_trials` and `test_trial` are not defined anywhere in the
# visible notebook and must exist in the kernel before this cell runs.
# assumes axis 0 of `y_all_trials` indexes trial repeats — TODO confirm
total_trial_repeats = y_all_trials.shape[0]
trial_ids = np.arange(total_trial_repeats)
if test_trial is not None:
    # Shuffle in place, then keep the first `test_trial` shuffled indices.
    np.random.shuffle(trial_ids)
    trial_ids = trial_ids[:test_trial]

### exploring data...

In [2]:
# Build the neural dataset object for the 'ucsf' dataset, session 200206.
dataset_name, session_id = 'ucsf', 200206
dataset_obj = create_neural_dataset(dataset_name, session_id)

# Optional check, left disabled:
# print(f"Experiments: {dataset_obj.exp_tr_data.keys()}")


INFO:auditory_cortex.neural_data.ucsf_data.ucsf_dataset:NeuralData:  Creating object for session: 200206 ... 


INFO:auditory_cortex.neural_data.ucsf_data.ucsf_dataset:Done.


In [5]:
mVocs = False
stim_ids = dataset_obj.get_training_stim_ids(mVocs)

# Add up the duration reported for each training stimulus ID.
# The explicit loop is kept so `stim_id` stays bound to the last ID afterwards.
duration = 0
for stim_id in stim_ids:
    duration = duration + dataset_obj.get_stim_duration(stim_id)

print(f"Total stimuli duration: {duration:.2f}")

# Disabled alternative using the per-category totals:
# stim_duration = dataset_obj.total_stimuli_duration(mVocs)
# print(f"Unique stimuli duration: {stim_duration['unique']:.2f}")
# print(f"Repeated stimuli duration: {stim_duration['repeated']:.2f}")

Total stimuli duration: 1005.54


#### timit stimuli...

In [30]:
# Report the total stimulus duration per category with mVocs disabled.
mVocs = False
stim_duration = dataset_obj.total_stimuli_duration(mVocs)

unique_total = stim_duration['unique']
print(f"Unique stimuli duration: {unique_total:.2f}")
repeated_total = stim_duration['repeated']
print(f"Repeated stimuli duration: {repeated_total:.2f}")

Unique stimuli duration: 762.45
Repeated stimuli duration: 84.83


In [20]:
mVocs = False
repeated = True
spikes = dataset_obj.extract_spikes(repeated=repeated, mVocs=mVocs)

# Summarise the nested spikes mapping (stimulus ID -> channel -> array),
# using the first stimulus and its first channel as the representative entry.
stimuli_type = 'repeated' if repeated else 'unique'
print(f"For '{stimuli_type}' stimuli: ")

stim_ids = list(spikes.keys())
print(f"Number of stumiuli: {len(stim_ids)}")

first_stim_spikes = spikes[stim_ids[0]]
channels = list(first_stim_spikes.keys())
print(f"Number of channels: {len(channels)}")
print(f"Channel IDs: {channels}")

first_channel_spikes = first_stim_spikes[channels[0]]
print(f"Shape of spikes: {first_channel_spikes.shape}")
# The first axis of the array is reported as the repeat count.
print(f"Number of repeats: {first_channel_spikes.shape[0]}")

For 'repeated' stimuli: 
Number of stumiuli: 46
Number of channels: 4
Channel IDs: ['WM_1', 'WM_2', 'WM_3', 'WM_4']
Shape of spikes: (3, 41)
Number of repeats: 3


In [21]:
mVocs = False
repeated = False
spikes = dataset_obj.extract_spikes(repeated=repeated, mVocs=mVocs)

# Summarise the nested spikes mapping (stimulus ID -> channel -> array),
# using the first stimulus and its first channel as the representative entry.
stimuli_type = 'repeated' if repeated else 'unique'
print(f"For '{stimuli_type}' stimuli: ")

stim_ids = list(spikes.keys())
print(f"Number of stumiuli: {len(stim_ids)}")

first_stim_spikes = spikes[stim_ids[0]]
channels = list(first_stim_spikes.keys())
print(f"Number of channels: {len(channels)}")
print(f"Channel IDs: {channels}")

first_channel_spikes = first_stim_spikes[channels[0]]
print(f"Shape of spikes: {first_channel_spikes.shape}")
# The first axis of the array is reported as the repeat count.
print(f"Number of repeats: {first_channel_spikes.shape[0]}")

For 'unique' stimuli: 
Number of stumiuli: 451
Number of channels: 4
Channel IDs: ['WM_1', 'WM_2', 'WM_3', 'WM_4']
Shape of spikes: (1, 39)
Number of repeats: 1


#### mVocs stimuli...

In [10]:
# Report the total stimulus duration per category with mVocs enabled.
mVocs = True
stim_duration = dataset_obj.total_stimuli_duration(mVocs)

unique_total = stim_duration['unique']
print(f"Unique stimuli duration: {unique_total:.2f}")
repeated_total = stim_duration['repeated']
print(f"Repeated stimuli duration: {repeated_total:.2f}")

Unique stimuli duration: 900.25
Repeated stimuli duration: 100.07


In [8]:
mVocs = True
repeated = True
spikes = dataset_obj.extract_spikes(repeated=repeated, mVocs=mVocs)

# Summarise the nested spikes mapping (stimulus ID -> channel -> array),
# using the first stimulus and its first channel as the representative entry.
stimuli_type = 'repeated' if repeated else 'unique'
print(f"For '{stimuli_type}' stimuli: ")

stim_ids = list(spikes.keys())
print(f"Number of stumiuli: {len(stim_ids)}")

first_stim_spikes = spikes[stim_ids[0]]
channels = list(first_stim_spikes.keys())
print(f"Number of channels: {len(channels)}")
print(f"Channel IDs: {channels}")

first_channel_spikes = first_stim_spikes[channels[0]]
print(f"Shape of spikes: {first_channel_spikes.shape}")
# The first axis of the array is reported as the repeat count.
print(f"Number of repeats: {first_channel_spikes.shape[0]}")

For 'repeated' stimuli: 
Number of stumiuli: 153
Number of channels: 4
Channel IDs: ['WM_1', 'WM_2', 'WM_3', 'WM_4']
Shape of spikes: (3, 21)
Number of repeats: 3


In [10]:
mVocs = True
repeated = False
spikes = dataset_obj.extract_spikes(repeated=repeated, mVocs=mVocs)

# Summarise the nested spikes mapping (stimulus ID -> channel -> array),
# using the first stimulus and its first channel as the representative entry.
stimuli_type = 'repeated' if repeated else 'unique'
print(f"For '{stimuli_type}' stimuli: ")

stim_ids = list(spikes.keys())
print(f"Number of stumiuli: {len(stim_ids)}")

first_stim_spikes = spikes[stim_ids[0]]
channels = list(first_stim_spikes.keys())
print(f"Number of channels: {len(channels)}")
print(f"Channel IDs: {channels}")

first_channel_spikes = first_stim_spikes[channels[0]]
print(f"Shape of spikes: {first_channel_spikes.shape}")
# The first axis of the array is reported as the repeat count.
print(f"Number of repeats: {first_channel_spikes.shape[0]}")

For 'unique' stimuli: 
Number of stumiuli: 1415
Number of channels: 4
Channel IDs: ['WM_1', 'WM_2', 'WM_3', 'WM_4']
Shape of spikes: (1, 15)
Number of repeats: 1


In [4]:
# Take the 'unique' entry of the stimulus-ID mapping.
# `mVocs` is not set in this cell; it must already exist from an earlier cell.
unique_stim_ids = dataset_obj.get_stim_ids(mVocs)['unique']

In [5]:
# Scan the unique stimuli for the first one whose entry in `mVocs_dur_dict`
# is NaN. After the `break`, the loop variable `stim_id` stays bound to that
# stimulus (or to the last ID if no NaN is found), so it can be inspected next.
for stim_id in unique_stim_ids:
	stim_dur = dataset_obj.metadata.mVocs_dur_dict[stim_id]
	if np.isnan(stim_dur):
		print(f"Stimulus ID: {stim_id} is NaN")
		break

In [6]:
# Show the current value of `stim_id` (left over from the preceding loop).
stim_id

'43-Cue-03-Grunt.wfm'

In [None]:
# Report whether `stim_id` appears in `mVocs_wfm_files`.
# NOTE(review): `mVocs_wfm_files` is only assigned in a cell further down the
# notebook, so this cell fails on a fresh top-to-bottom run.
# The messages are plain strings now; the original f-strings had no placeholders.
if stim_id in mVocs_wfm_files:
    print("stim file exists")
else:
    print("stim file does not exist")

In [15]:
# List the field names of the 'MSL' entry in the mVocs metadata.
dataset_obj.metadata.mVocs_meta['MSL']._fieldnames

['SoundID',
 'SourceNum',
 'CueNum',
 'Dur',
 'VocType',
 'useThisSound',
 'jeffVocType',
 'actualDur',
 'WFMname',
 'indUsed3',
 'indUsed12',
 'indUsed',
 'attenuationSpeaker1']

In [17]:
# Shape of the `Dur` field of the 'MSL' metadata entry.
dataset_obj.metadata.mVocs_meta['MSL'].Dur.shape

(3046,)

In [19]:
# Shape of the `actualDur` field of the 'MSL' metadata entry.
dataset_obj.metadata.mVocs_meta['MSL'].actualDur.shape

(3046,)

In [20]:
# Shape of the `WFMname` field of the 'MSL' metadata entry.
dataset_obj.metadata.mVocs_meta['MSL'].WFMname.shape

(3046,)

In [23]:
# Find the index (or indices) at which `WFMname` equals the current `stim_id`.
np.where(dataset_obj.metadata.mVocs_meta['MSL'].WFMname == stim_id)

(array([534]),)

In [31]:
# `Dur` value at index 534, the position found by the `WFMname` lookup for `stim_id`.
dataset_obj.metadata.mVocs_meta['MSL'].Dur[534]

nan

In [32]:
# `actualDur` value at the same index 534, for comparison with `Dur`.
dataset_obj.metadata.mVocs_meta['MSL'].actualDur[534]

0.792875

In [29]:
# Fetch the audio for `stim_id` through the metadata object, with mVocs=True.
# presumably returns the waveform samples as an array — TODO confirm
aud = dataset_obj.metadata.get_stim_audio(stim_id, mVocs=True)

In [30]:
# Element count of `aud` divided by 48000.
# NOTE(review): 48000 is presumably the sampling rate in Hz, which would make
# this the duration in seconds — confirm.
aud.size/48000

0.792875

In [10]:
# Compare how many mVocs waveform files sit in the data directory with how
# many mVocs IDs the metadata object lists.
mvocs_wfm_dir = os.path.join(dataset_obj.data_dir, 'NIMH_Mvoc_WFM')
mVocs_wfm_files = np.array(os.listdir(mvocs_wfm_dir)).astype(str)
print(f"mVocs waveforms in directory: {mVocs_wfm_files.size}")

mVocs_ids_used = np.array(dataset_obj.metadata.mVocs_ids).astype(str)
print(f"mVocs stimuli used (as per metadata): {mVocs_ids_used.size}")

mVocs waveforms in directory: 3040
mVocs stimuli used (as per metadata): 3046


In [4]:
mVocs = True
repeated = True
spikes = dataset_obj.extract_spikes(repeated=repeated, mVocs=mVocs)

# Summarise the nested spikes mapping (stimulus ID -> channel -> array),
# using the first stimulus and its first channel as the representative entry.
stimuli_type = 'repeated' if repeated else 'unique'
print(f"For '{stimuli_type}' stimuli: ")

stim_ids = list(spikes.keys())
print(f"Number of stumiuli: {len(stim_ids)}")

first_stim_spikes = spikes[stim_ids[0]]
channels = list(first_stim_spikes.keys())
print(f"Number of channels: {len(channels)}")
print(f"Channel IDs: {channels}")

first_channel_spikes = first_stim_spikes[channels[0]]
print(f"Shape of spikes: {first_channel_spikes.shape}")
# The first axis of the array is reported as the repeat count.
print(f"Number of repeats: {first_channel_spikes.shape[0]}")

For 'repeated' stimuli: 
Number of stumiuli: 153
Number of channels: 4
Channel IDs: ['WM_1', 'WM_2', 'WM_3', 'WM_4']
Shape of spikes: (3, 21)
Number of repeats: 3


In [None]:
mVocs = True
repeated = False
spikes = dataset_obj.extract_spikes(repeated=repeated, mVocs=mVocs)

# Summarise the nested spikes mapping (stimulus ID -> channel -> array),
# using the first stimulus and its first channel as the representative entry.
stimuli_type = 'repeated' if repeated else 'unique'
print(f"For '{stimuli_type}' stimuli: ")

stim_ids = list(spikes.keys())
print(f"Number of stumiuli: {len(stim_ids)}")

first_stim_spikes = spikes[stim_ids[0]]
channels = list(first_stim_spikes.keys())
print(f"Number of channels: {len(channels)}")
print(f"Channel IDs: {channels}")

first_channel_spikes = first_stim_spikes[channels[0]]
print(f"Shape of spikes: {first_channel_spikes.shape}")
# The first axis of the array is reported as the repeat count.
print(f"Number of repeats: {first_channel_spikes.shape[0]}")

#### mVocs stimuli discrepancy...

In [48]:
# Compare how many TIMIT waveform files sit in the data directory with how
# many TIMIT IDs the metadata object lists.
timit_wfm_dir = os.path.join(dataset_obj.data_dir, 'TIMIT_48000')
timit_wfm_files = os.listdir(timit_wfm_dir)
print(f"TIMIT waveforms in directory: {len(timit_wfm_files)}")

timit_ids_used = dataset_obj.metadata.timit_ids
print(f"Timit stimuli used (as per metadata): {len(timit_ids_used)}")

TIMIT waveforms in directory: 499
Timit stimuli used (as per metadata): 497


In [3]:
# Compare how many mVocs waveform files sit in the data directory with how
# many mVocs IDs the metadata object lists.
mvocs_wfm_dir = os.path.join(dataset_obj.data_dir, 'NIMH_Mvoc_WFM')
mVocs_wfm_files = np.array(os.listdir(mvocs_wfm_dir)).astype(str)
print(f"mVocs waveforms in directory: {mVocs_wfm_files.size}")

mVocs_ids_used = np.array(dataset_obj.metadata.mVocs_ids).astype(str)
print(f"mVocs stimuli used (as per metadata): {mVocs_ids_used.size}")

mVocs waveforms in directory: 3040
mVocs stimuli used (as per metadata): 2198


In [4]:
# Collect every stimulus ID recorded for the 'BMM3' experiment (unique + repeated).
bmm3_stim_ids = dataset_obj.exp_stim_ids['BMM3']
mVocs_ids_rec_data = np.concatenate(
    [bmm3_stim_ids['unique'], bmm3_stim_ids['repeated']]
).astype(str)
print(f"mVocs stimuli used (as per recording): {mVocs_ids_rec_data.size}")

# Split the recorded IDs by membership in `mVocs_ids_used`.
# "Available" here means the ID is listed in `mVocs_ids_used`; it does not
# check the waveform directory listing.
in_ids_used = np.isin(mVocs_ids_rec_data, mVocs_ids_used)
files_available = mVocs_ids_rec_data[in_ids_used]
files_not_available = mVocs_ids_rec_data[~in_ids_used]
print(f"Files available: {files_available.size}")
print(f"Files not available: {files_not_available.size}")

mVocs stimuli used (as per recording): 1568
Files available: 1116
Files not available: 452


In [68]:
# Check one specific waveform name against `mVocs_wfm_files`.
probe_file = '03-Cue-263-Coo.wfm'
print("File available" if probe_file in mVocs_wfm_files else "File not available")

File available


In [59]:
# Display the recorded stimulus IDs that were not found in `mVocs_ids_used`.
files_not_available

array(['20-Cue-16-Scream.wfm', '10-Cue-02-Coo.wfm',
       '35-Cue-208-Scream.wfm', '03-Cue-263-Coo.wfm',
       '38-Cue-04-Scream.wfm', '47-Cue-262-Grunt.wfm',
       '12-Cue-57-Grunt.wfm', '26-Cue-65-Coo.wfm', '46-Cue-11-Grunt.wfm',
       '8-Cue-117-Coo.wfm', '42-Cue-04-Grunt.wfm', '22-Cue-136-Coo.wfm',
       '8-Cue-70-Coo.wfm', '25-Cue-23-Scream.wfm', '35-Cue-50-Grunt.wfm',
       '20-Cue-04-Scream.wfm', '38-Cue-02-Scream.wfm',
       '03-Cue-412-Scream.wfm', '15-Cue-152-Coo.wfm', '15-Cue-54-Coo.wfm',
       '47-Cue-81-Coo.wfm', '03-Cue-116-Grunt.wfm', '35-Cue-19-Grunt.wfm',
       '12-Cue-133-Scream.wfm', '03-Cue-255-Coo.wfm',
       '03-Cue-257-Coo.wfm', '04-Cue-16-Coo.wfm', '34-Cue-14-Coo.wfm',
       '28-Cue-17-Coo.wfm', '21-Cue-56-Grunt.wfm', '35-Cue-107-Grunt.wfm',
       '03-Cue-524-Grunt.wfm', '03-Cue-188-Coo.wfm',
       '47-Cue-293-Grunt.wfm', '15-Cue-37-Coo.wfm',
       '12-Cue-23-Scream.wfm', '34-Cue-61-Scream.wfm',
       '8-Cue-110-Coo.wfm', '03-Cue-184-Grunt.wfm',
 

In [55]:
# Shape of the recorded IDs that are also present in `mVocs_ids_used`.
# `.shape` is restored: the recorded output of this cell is a shape tuple, and
# the companion cell for the inverted mask also reports `.shape`.
mVocs_ids_rec_data[np.isin(mVocs_ids_rec_data, mVocs_ids_used, invert=False)].shape

(1116,)

In [56]:
# Shape of the recorded IDs that are NOT present in `mVocs_ids_used`.
mVocs_ids_rec_data[np.isin(mVocs_ids_rec_data, mVocs_ids_used, invert=True)].shape

(452,)

In [43]:
# Display the per-experiment stimulus IDs.
# A stray trailing `total_stim_ids` line was removed: that name is not defined
# in the visible notebook, and as the last expression it hid this value, which
# is what the recorded output shows.
dataset_obj.exp_stim_ids

{'BMM3': {'unique': array(['20-Cue-16-Scream.wfm', '10-Cue-02-Coo.wfm',
         '35-Cue-208-Scream.wfm', ..., '40-Cue-16-Grunt.wfm',
         '03-Cue-488-Coo.wfm', '43-Cue-03-Grunt.wfm'], dtype='<U21'),
  'repeated': array(['03-Cue-04-Coo.wfm', '03-Cue-11-Grunt.wfm', '03-Cue-145-Grunt.wfm',
         '03-Cue-198-Grunt.wfm', '03-Cue-210-Grunt.wfm',
         '03-Cue-214-Coo.wfm', '03-Cue-222-Grunt.wfm', '03-Cue-245-Coo.wfm',
         '03-Cue-254-Coo.wfm', '03-Cue-262-Coo.wfm',
         '03-Cue-288-Scream.wfm', '03-Cue-291-Scream.wfm',
         '03-Cue-30-Grunt.wfm', '03-Cue-303-Coo.wfm',
         '03-Cue-358-Scream.wfm', '03-Cue-381-Scream.wfm',
         '03-Cue-392-Coo.wfm', '03-Cue-393-Scream.wfm',
         '03-Cue-400-Scream.wfm', '03-Cue-418-Scream.wfm',
         '03-Cue-433-Scream.wfm', '03-Cue-446-Coo.wfm',
         '03-Cue-471-Scream.wfm', '03-Cue-501-Scream.wfm',
         '03-Cue-52-Grunt.wfm', '03-Cue-72-Grunt.wfm',
         '03-Cue-75-Grunt.wfm', '04-Cue-01-Coo.wfm', '04-Cue-06

In [None]:
mVocs = True
repeated = True
spikes = dataset_obj.extract_spikes(repeated=repeated, mVocs=mVocs)

# Summarise the nested spikes mapping (stimulus ID -> channel -> array),
# using the first stimulus and its first channel as the representative entry.
stimuli_type = 'repeated' if repeated else 'unique'
print(f"For '{stimuli_type}' stimuli: ")

stim_ids = list(spikes.keys())
print(f"Number of stumiuli: {len(stim_ids)}")

first_stim_spikes = spikes[stim_ids[0]]
channels = list(first_stim_spikes.keys())
print(f"Number of channels: {len(channels)}")
print(f"Channel IDs: {channels}")

first_channel_spikes = first_stim_spikes[channels[0]]
print(f"Shape of spikes: {first_channel_spikes.shape}")
# The first axis of the array is reported as the repeat count.
print(f"Number of repeats: {first_channel_spikes.shape[0]}")

In [13]:
# Report the total stimulus duration per category for whatever value `mVocs`
# currently holds in the kernel (it is not set in this cell).
stim_duration = dataset_obj.total_stimuli_duration(mVocs)
unique_total = stim_duration['unique']
print(f"Unique stimuli duration: {unique_total:.2f}")
repeated_total = stim_duration['repeated']
print(f"Repeated stimuli duration: {repeated_total:.2f}")

Unique stimuli duration: 762.45
Repeated stimuli duration: 84.83


### setting up the dataset...