Skip to content

Commit

Permalink
Merge pull request #58 from chorus-ai/format_vital
Browse files Browse the repository at this point in the history
Tidying Vital format for the purpose of testing
  • Loading branch information
briangow committed May 8, 2024
2 parents c0e457e + 3daa150 commit f397ff4
Showing 1 changed file with 22 additions and 21 deletions.
43 changes: 22 additions & 21 deletions waveform_benchmark/formats/vital.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@ def write_waveforms(self, path, waveforms):
vitalfile.dtend = max(vitalfile.dtend, dtend)
samples = chunk['samples']

for istart in range(0, len(samples), round(srate)):
for istart in range(0, len(samples), 240000):
recs.append({'dt': dtstart + istart / srate,
'val': samples[istart:istart+round(srate)]})
'val': samples[istart:istart+240000]})

gain = max(gain, chunk['gain'])
mindisp = min(mindisp, np.nanmin(samples))
Expand All @@ -52,28 +52,29 @@ def read_waveforms(self, path, start_time, end_time, signal_names):
signal_names = [f"{x}/{x}" for x in signal_names]
file_name = f"{path}.vital"
vitalfile = vitaldb.VitalFile(file_name, track_names=signal_names)
vitalfile.crop(start_time, end_time)
results = {}
for dtname, trk in vitalfile.trks.items():
if dtname.find('/') >= 0:
dtname = dtname.split('/')[-1]
for signal_name in signal_names:
trk = vitalfile.trks[signal_name]
if signal_name.find('/') >= 0:
signal_name = signal_name.split('/')[-1]
sample_length = round((end_time - start_time) * trk.srate)
samples = np.empty(sample_length, dtype=np.float32)
samples[:] = np.nan
for i, rec in enumerate(trk.recs):
dtstart = rec['dt']
dtend = rec['dt'] + len(rec['val']) / trk.srate
if dtstart > end_time:
for rec in trk.recs:
rec_start_time = rec['dt']
rec_end_time = rec['dt'] + len(rec['val']) / trk.srate
if rec_start_time > end_time:
break
if i == 0 and start_time > dtstart:
crop_start = round((start_time - dtstart) * trk.srate)
rec['dt'] = start_time
dtstart = start_time
rec['val'] = rec['val'][crop_start:]
st = round((dtstart - start_time) * trk.srate)
et = min(round((dtend - start_time) * trk.srate),
sample_length)
if et > st:
samples[st:et] = rec['val'][:et-st]
results[dtname] = samples
if start_time <= rec_start_time <= end_time or start_time <= rec_end_time <= end_time:
if start_time > rec_start_time and start_time <= rec_end_time < end_time:
crop_start = round((start_time - rec_start_time) * trk.srate)
rec['dt'] = start_time
rec_start_time = start_time
rec['val'] = rec['val'][crop_start:]
st = round((rec_start_time - start_time) * trk.srate)
et = min(round((rec_end_time - start_time) * trk.srate),
sample_length)
if et > st:
samples[st:et] = rec['val'][:et-st]
results[signal_name] = samples
return results

0 comments on commit f397ff4

Please sign in to comment.