In [None]:
from datetime import datetime, date
from uuid import uuid4
from dateutil.tz import tzlocal
from ast import literal_eval

import numpy as np
import pandas as pd
import os
import warnings
import librosa
import pickle

from pynwb import NWBHDF5IO, NWBFile, TimeSeries
from pynwb.file import Subject
from pynwb.behavior import (
    BehavioralEpochs,
    BehavioralEvents,
    BehavioralTimeSeries,
    CompassDirection,
    EyeTracking,
    Position,
    PupilTracking,
    SpatialSeries,
)
from pynwb.epoch import TimeIntervals
from pynwb.image import ImageSeries

from scipy.io import wavfile
import wave

from ndx_manoli_meta import AssayMetadata
from nwb_utils import *
from vak_utils import *
from frequency_stats_utils import *

In [None]:
# ----- USER PARAMETERS -----
# NWB metadata
cohort_tag = 'ScnVoc'
GT = 'Scn2a'
sdate = date(2024, 3, 14)  # 20240314
session_description = f'Vocal recordings from cohort: {cohort_tag} with genotype: {GT} starting on date: {sdate}.'
lab = 'Manoli @ UCSF'

# snippet windowing parameters
winlen = 2
overlap = 0.25

# set parameters for making toml files
defaults_fname = 'vak_defaults.pkl'     # TODO: move this to a central location
vak_defaults_dir = 'M:\\vocalizations'  # directory that holds defaults_fname
# toml_name = f'{this_rec}-predict.toml'

# set parameters for pretending annotations are full boxes
minfreq = 15000
maxfreq = 65000

# set up pathing: one directory per pipeline stage, created if missing
rootpath = 'M:\\vocalizations\\scn2a'
nwbfile_path = os.path.join(rootpath, f'{cohort_tag}_nwb')
analysis_path = os.path.join(rootpath, f'{cohort_tag}_analysis')
acq_path = os.path.join(rootpath, f'{cohort_tag}_acquisition')
pred_path = os.path.join(rootpath, f'{cohort_tag}_predictions')
paths = [nwbfile_path, analysis_path, acq_path, pred_path]
for pathi in paths:
    os.makedirs(pathi, exist_ok=True)

# load up stored defaults for writing toml files.
# The context manager closes the handle (the original left it open for the kernel's lifetime).
# NOTE: pickle.load can execute arbitrary code -- only ever point this at a trusted file.
with open(os.path.join(vak_defaults_dir, defaults_fname), 'rb') as defaults_file:
    defaults = pickle.load(defaults_file)

# set up metadata table
# metafile = f'metadata_{cohort_tag}.csv'
metafile = 'metadata_ScnVoc_20241208.csv'

# -------- frequency calculation parameters ----------
# times for start and end pad of the snippets
spad = 0  # in seconds
epad = 0

# parameters for Welch PSD
NFFT = 512
noverl = 400
pad = 0.05
cmap = 'jet'
cmin = -60
cmax = 30
n_mfcc = 256
n_mels = n_mfcc
fmin = 20000
fmax = 80000
psdwin = 'hann'

# parameters for mel spec
mel_nfft = 1024
mel_winlen = 900

# false positive thresholding
pwrmin = 0.9

# peak identification parameters
rel_height_param = 0.98  # point on the peak at which to calculate the bandwidth
stdmx = 1
prominence_factor = 2.5  # used to calculate how high frequency peaks need to be

# contour thresholding parameters
thresh = 0.25
grndmeanfac = 1.5

# spectrogram smoothing parameters
wiener_kernel = (4, 4)

# contour fit smoothing parameters
# NOTE(review): `sm` is never imported explicitly in this notebook; presumably it arrives
# via one of the star imports (e.g. statsmodels.api as sm) -- confirm and import it explicitly.
lowess = sm.nonparametric.lowess
fr = 0.2
it = 4

# mel frequencies
mf = librosa.mel_frequencies(n_mels=n_mels, fmin=fmin, fmax=fmax)

# names and descriptions of stats, in order (descriptions[k] documents allstats[k])
allstats = ['callBool', 'startTime', 'duration', 'f0', 'numberFreqBands', 'numberPeaks', 'startFreq', 'endFreq',
            'minFreq', 'maxFreq', 'absMinFreq', 'absMaxFreq', 'sinuosity', 'timeToMaxPower', 'frequencyAtMaxPower',
            'totalContourLength']
# An orphan entry ('number of signal peaks identified in contours') previously sat after the
# numberPeaks description and shifted every later description one stat out of alignment.
descriptions = ['use call: true for include in analysis, false for false positive',
                'time of call start',
                'duration of call in s',
                'peak frequency of lowest frequency band identified by Welch PSD',
                'number of frequency peaks identified from Welch PSD',
                'number of peaks identified in contours',
                'frequency at beginning of longest contour',
                'frequency at end of longest contour',
                'lowest contour frequency reached in longest contour',
                'highest contour frequency reached in longest contour',
                'lowest contour frequency reached in call',
                'highest contour frequency reached in call',
                'sinuosity of longest contour',
                'time into call of highest power in s',
                'frequency at max power in Hz',
                'total number of bins with above threshold contour', ]
assert len(allstats) == len(descriptions), 'allstats and descriptions must be the same length'

# which call time data to use from the NWB file
use_annos = 'calls_vak_merge_5ms'  # corresponds to all merged calls but not adjusted for time

In [None]:
# ----- Generate metadata object based on metadata read in from table -----

# Per-assay settings that vary between assay types: (description, animals-divided flag).
# Every other AssayMetadata field is filled identically for all assay types.
_ASSAY_PRESETS = {
    'introduction': ('Standard introduction (vocal series).', False),
    'divided': ('Vocal recording in home cage with animals separated by a barrier.', True),
    'mating': ('Standard timed mating (vocal series).', False),
    'finteract1': ('Vocal recording in home cage with animals freely interacting.', False),
    'finteract2': ('Vocal recording in home cage with animals freely interacting.', False),
    'dividerhalf': ('Vocal recording in home cage with animals divided for 15 mins then freely interacting for 15 mins.', True),
    'PPT': ('Standard PPT (vocal series).', False),
}


def make_metadata_object(atype, excl, dur, room, timeline, etho, exp, timeline_comp,
                         colors, pID, pGT, dpostpair, sID=None, sGT=None, lane=None, chamber=None):
    """Build an ``AssayMetadata`` lab-metadata object for one recording session.

    Parameters
    ----------
    atype : str
        Assay type; must be a key of ``_ASSAY_PRESETS`` (e.g. 'introduction', 'PPT').
        It selects the stored description and the ``divided`` flag.
    excl, dur, room, timeline, etho, exp, timeline_comp, colors :
        Passed through unchanged as ``exclude_flag``, ``duration``, ``room``, ``timeline``,
        ``ethogram``, ``experimenter``, ``timeline_complete`` and ``colors``.
    pID, pGT :
        Partner ID and genotype.
    dpostpair :
        Days post pairing.
    sID, sGT, lane, chamber : optional
        Stranger ID/genotype, PPT lane and partner chamber. These are only written for
        ``atype == 'PPT'`` and are ignored for every other assay type.

    Returns
    -------
    AssayMetadata

    Raises
    ------
    ValueError
        If ``atype`` is not a known assay type. Previously this surfaced as an
        ``UnboundLocalError`` from the fall-through ``match``.
    """
    try:
        description, divided = _ASSAY_PRESETS[atype]
    except KeyError:
        raise ValueError(
            f'Unknown assay type {atype!r}; expected one of {sorted(_ASSAY_PRESETS)}.'
        ) from None

    fields = dict(
        assay_type=atype,
        exclude_flag=excl,
        duration=dur,
        room=room,
        timeline=timeline,
        ethogram=etho,
        experimenter=exp,
        timeline_complete=timeline_comp,
        colors=colors,
        assay_type__partner_ID=pID,
        assay_type__partner_GT=pGT,
        assay_type__description=description,
        assay_type__days_post_pairing=dpostpair,
        assay_type__divided=divided,
    )

    # Only the partner preference test carries stranger / chamber information.
    if atype == 'PPT':
        fields.update(
            assay_type__stranger_ID=sID,
            assay_type__stranger_GT=sGT,
            assay_type__PPT_lane=lane,
            assay_type__partner_chamber=chamber,
        )

    return AssayMetadata(**fields)

In [None]:
# ----- Loop over metadata table and write files -----

# whether to write NWB files to disk yet
write_NWB_to_disk = True

# load metadata
meta = pd.read_csv(os.path.join(analysis_path, metafile), sep=',')
# FocalColor is stored as a stringified Python literal; parse it back into a real object
meta.FocalColor = meta.FocalColor.apply(literal_eval)

# -- loop over metadata: one NWB file per row (pair x assay)
for i, ptag in enumerate(meta.PairTag):
    assay_type = meta.AssayType[i]
    nwbfilename = f'{ptag}_{assay_type}.nwb'
    print(nwbfilename)

    # skip sessions whose NWB file has already been written
    wfullpath = os.path.join(nwbfile_path, nwbfilename)
    if os.path.exists(wfullpath):
        continue

    # figure out partner info; a malformed row is skipped rather than written with stale
    # partner data left over from the previous iteration
    focal_sex = meta.FocalSex[i]
    if focal_sex == 'F':
        pID = meta.MaleID[i]
        pGT = meta.MaleGT[i]
    elif focal_sex == 'M':
        pID = meta.FemaleID[i]
        pGT = meta.FemaleGT[i]  # partner of a male focal is the female (was MaleGT: copy-paste bug)
    else:
        warnings.warn(f'Focal sex is neither F nor M; something is wrong with {ptag}. Skipping.')
        continue

    # get session specific metadata (dates are YYYYMMDD)
    thisdate = str(meta.RecDate[i])
    pairdate = str(meta.PairDate[i])

    # set up recording time... it would be nice to get actual video data for the times
    datepieces = get_date_from_block(thisdate)
    timepieces = meta.RecTime[i].split(':')
    sess_start = datetime(datepieces[0], datepieces[1], datepieces[2],
                          int(timepieces[0]), int(timepieces[1]), 0, 0, tzlocal())

    session_description = f'Behavioral annotations from pair {ptag} in a(n) {assay_type} assay.'

    # calculate days post pairing
    rdate = date(int(thisdate[0:4]), int(thisdate[4:6]), int(thisdate[6:]))
    pdate = date(int(pairdate[0:4]), int(pairdate[4:6]), int(pairdate[6:]))
    dpp = rdate - pdate

    # make NWB file
    nwbfile = NWBFile(
        session_description=session_description,
        identifier=str(uuid4()),
        session_start_time=sess_start,
        lab=lab,
        experimenter=meta.RanBy[i],
        session_id=os.path.splitext(nwbfilename)[0],  # file name without the .nwb extension
    )

    # add subject info
    nwbfile.subject = Subject(
        subject_id=meta.FocalID[i],
        species='Microtus ochrogaster',
        sex=focal_sex,
        genotype=meta.FocalGT[i],
    )

    # build the lab metadata object; PPT assays carry extra stranger/chamber fields
    duration = float(meta.AssayDuration[i])
    common_args = (assay_type, False, duration, meta.AssayRoom[i], meta.Timeline[i],
                   meta.Ethogram[i], meta.RanBy[i], meta.FullTimeline[i], meta.FocalColor[i],
                   pID, pGT, dpp.days)
    if assay_type == 'PPT':
        metaObj = make_metadata_object(*common_args,
                                       sID=meta.StrangerID[i], sGT=meta.StrangerGT[i],
                                       lane=int(meta.PPTlane[i]), chamber=meta.PartnerChamber[i])
    else:
        metaObj = make_metadata_object(*common_args)
    nwbfile.add_lab_meta_data(lab_meta_data=metaObj)

    # add the video as an external file, referenced relative to the NWB output directory
    vid_path = os.path.join(meta.VideoPath[i], meta.VideoFile[i])
    vid_rel_path = os.path.relpath(vid_path, nwbfile_path)
    video_ext_file = ImageSeries(
        name='behaviorVideo',
        description='Raw original video.',
        unit='n.a.',
        external_file=[vid_rel_path],
        format='external',
        starting_time=meta.VideoAssayStart[i],
        rate=25.0,
    )
    nwbfile.add_acquisition(video_ext_file)

    # add session specific audio, if any (a missing AudioFile cell is read as NaN, not str)
    if isinstance(meta.AudioFile[i], str):
        aud_path = os.path.join(meta.AudioPath[i], meta.AudioFile[i])
        rel_path = os.path.relpath(aud_path, nwbfile_path)
        with wave.open(aud_path, "rb") as wave_file:  # only the sample rate is read from the wav header
            Fs = float(wave_file.getframerate())

        # NOTE(review): audio is registered as an external ImageSeries; confirm this is the
        # intended container for the downstream readers.
        aud_ext_file = ImageSeries(
            name='behaviorAudio',
            description='Raw freefield audio',
            unit='n.a.',
            external_file=[rel_path],
            format='external',
            starting_time=meta.AudioAssayStart[i],
            rate=Fs,
        )
        nwbfile.add_acquisition(aud_ext_file)

    # debugging handle on one in-memory file for inspection in later cells -- TODO remove when unused
    if i == 12:
        testfile = nwbfile

    # write file to disk
    if write_NWB_to_disk:
        with NWBHDF5IO(wfullpath, "w") as io:
            io.write(nwbfile)