Skip to content

Commit

Permalink
FIX: potential overflow for some test sets. Now explicitly scale th…
Browse files Browse the repository at this point in the history
…e data to the dynamic range of the value representation. Uses ADC resolution (sensitivity), gain (sensitivityCorrection) and baseline (0 offset).
  • Loading branch information
tcpan committed Jun 14, 2024
1 parent c5b65ad commit e96d826
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 28 deletions.
20 changes: 13 additions & 7 deletions waveform_benchmark/formats/dcm_utils/dcm_waveform_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,15 +180,21 @@ def get_multiplex_array(fileobj : BinaryIO,
baseline = float(ch.ChannelBaseline)
sensitivity = float(ch.ChannelSensitivity)
correction = float(ch.ChannelSensitivityCorrectionFactor)
# nominal = v * gain = (encoded * sensitivity - baseline) - see reader.
# v = nominal * correction
adjustment = sensitivity * correction
if (adjustment != 1.0) and (baseline != 0.0):
arr[ch_idx, ...] = np.where(arr[ch_idx, ...] == padding_value, np.nan, arr[ch_idx, ...] * adjustment + baseline)
elif (adjustment != 1.0):
arr[ch_idx, ...] = np.where(arr[ch_idx, ...] == padding_value, np.nan, arr[ch_idx, ...] * adjustment)
elif (baseline != 0.0):
arr[ch_idx, ...] = np.where(arr[ch_idx, ...] == padding_value, np.nan, arr[ch_idx, ...] + baseline)
base = baseline * correction
# print(" reading ", ch_idx, sensitivity, baseline, correction, adjustment, base)
if (adjustment != 1.0):
if (base != 0.0):
arr[ch_idx, ...] = np.where(arr[ch_idx, ...] == padding_value, np.nan, arr[ch_idx, ...] * adjustment - base)
else:
arr[ch_idx, ...] = np.where(arr[ch_idx, ...] == padding_value, np.nan, arr[ch_idx, ...] * adjustment)
else:
arr[ch_idx, ...] = np.where(arr[ch_idx, ...] == padding_value, np.nan, arr[ch_idx, ...])
if (base != 0.0):
arr[ch_idx, ...] = np.where(arr[ch_idx, ...] == padding_value, np.nan, arr[ch_idx, ...] - base)
else:
arr[ch_idx, ...] = np.where(arr[ch_idx, ...] == padding_value, np.nan, arr[ch_idx, ...])


return cast("np.ndarray", arr)
Expand Down
88 changes: 71 additions & 17 deletions waveform_benchmark/formats/dcm_utils/dcm_waveform_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,14 @@

import math
from datetime import datetime
import decimal

import warnings
# warnings.filterwarnings("error")

from waveform_benchmark.formats.base import BaseFormat

# dicom3tools currently does NOT validate the IODs for waveform. IT does validate the referencedSOPClassUIDInFile in DICOMDIR file.

# types of waveforms and constraints:
# https://dicom.nema.org/medical/dicom/current/output/chtml/part03/PS3.3.html
# https://dicom.nema.org/medical/dicom/current/output/chtml/part17/chapter_C.html (data organization, and use cases)
Expand Down Expand Up @@ -80,20 +82,20 @@ class DICOMWaveformVR:
class DICOMWaveform8(DICOMWaveformVR):
    """Value-representation parameters for 8-bit signed waveform samples."""
    WaveformBitsAllocated = 8
    WaveformSampleInterpretation = "SB"  # DICOM: signed 8-bit samples
    PythonDatatype = np.int8
    # Reserve the most negative representable value (-128) as the padding marker,
    # derived from the dtype so it cannot drift out of sync with PythonDatatype.
    PaddingValue = np.iinfo(PythonDatatype).min

class DICOMWaveform16(DICOMWaveformVR):
    """Value-representation parameters for 16-bit signed waveform samples."""
    WaveformBitsAllocated = 16
    WaveformSampleInterpretation = "SS"  # DICOM: signed 16-bit samples
    PythonDatatype = np.int16
    # Padding marker is the dtype minimum (-32768), derived so it tracks PythonDatatype.
    PaddingValue = np.iinfo(PythonDatatype).min

class DICOMWaveform32(DICOMWaveformVR):
    """Value-representation parameters for 32-bit signed waveform samples."""
    WaveformBitsAllocated = 32
    WaveformSampleInterpretation = "SL"  # DICOM: signed 32-bit samples
    PythonDatatype = np.int32
    # Padding marker is the dtype minimum (-2147483648), derived so it tracks PythonDatatype.
    PaddingValue = np.iinfo(PythonDatatype).min


# relevant definitions:
Expand Down Expand Up @@ -448,11 +450,12 @@ def set_waveform_acquisition_info(self, dataset,


# channel_chunks is a list of tuples (channel, chunk).
#
# minmax is a dict of channel to (min, max)
def create_multiplexed_chunk(self, waveforms: dict,
iod: DICOMWaveformIOD,
group: int,
channel_chunk: list,
minmax: dict,
start_time: float,
end_time: float):

Expand Down Expand Up @@ -525,6 +528,11 @@ def create_multiplexed_chunk(self, waveforms: dict,
dtype=iod.VR.PythonDatatype)
# now collect the chunks into an array, generate the multiplex group, and increment the chunk ids.
unprocessed_chunks.sort(key=lambda x: (x[0], x[3], x[4])) # ordered by channel
# using 90% of the dynamic range to avoid overflow
out_min = np.round(float(iod.VR.PaddingValue) * 0.9)
out_max = np.round(float(np.iinfo(iod.VR.PythonDatatype).max) * 0.9)

gains = {}
for channel, chunk_id, start_src, start_target, end_target in unprocessed_chunks:

chunk = waveforms[channel]['chunks'][chunk_id]
Expand All @@ -537,15 +545,39 @@ def create_multiplexed_chunk(self, waveforms: dict,
# values is in original type. nan replaced with PaddingValue then will be multiplied by gain, so divide here to avoid overflow.
# values = np.nan_to_num(np.frombuffer(chunk['samples'][start_src:end_src], dtype=np.dtype(chunk['samples'].dtype)),
# nan = float(iod.VR.PaddingValue) / float(chunk['gain']) )
# per dicom standard
# channelSensitivity is gain * adc resolution
# datavalue * channelsensitivity = nominal value in unit specified. should match input.
# nominal value * sensitivity correction factor = actual value.
#

# get the input values
v = np.frombuffer(chunk['samples'][start_src:end_src], dtype=np.dtype(chunk['samples'].dtype))
gain = float(chunk['gain'])
values = np.where(np.isnan(v), float(iod.VR.PaddingValue), v * gain)


min_v = float(minmax[channel][0])
max_v = float(minmax[channel][1])

# sensitivity, baseline, and scaling factor:
# encoded = (input - min(input)) / (max(input) - min(input)) * (max(output) - min(output)) + min(output)
# = input * scale1 + min(output) - min(input) * scale1
# = input * scale1 + baseline
# where scale1 = (max(output) - min(output)) / (max(input) - min(input),
scale1 = (out_max - out_min) / (max_v - min_v)
# baseline = min(output) - min(input) * scale1
base1 = out_min - min_v * scale1

chan_id = channels[channel]
samples[chan_id][start_target:end_target] = np.where(np.isnan(v), float(iod.VR.PaddingValue), np.round(v * scale1 + base1, decimals=0)).astype(iod.VR.PythonDatatype)
# nominal data = input * gain = (encoded - baseline) / scale1 * gain
# = encoded * scale2 - baseline2
# where scale2 = gain / scale1, baseline2 = baseline * scale2

if channel not in gains.keys():
gains[channel] = set()
gains[channel].add(float(chunk['gain']))

# write out in integer format
# samples[chan_id][start_target:end_target] = np.round(values * float(chunk['gain']), decimals=0).astype(iod.VR.PythonDatatype)
samples[chan_id][start_target:end_target] = np.round(values, decimals=0).astype(iod.VR.PythonDatatype)
# samples[chan_id][start_target:end_target] = np.round(values, decimals=0).astype(iod.VR.PythonDatatype)
# print("chunk shape:", chunk['samples'].shape)
# print("values shape: ", values.shape)
# print("samples shape: ", samples.shape)
Expand Down Expand Up @@ -579,7 +611,21 @@ def create_multiplexed_chunk(self, waveforms: dict,
source.CodingSchemeVersion = "unknown"
source.CodeMeaning = channel

chdef.ChannelSensitivity = 1.0

if len(gains[channel]) > 1:
print("ERROR: Different gains for the same channel is not supported. ", gains[channel])
gain = gains[channel].pop()

# actual data = nominal data / gain.
# baseline: standards def: offset of encoded sample value 0 from actual (nominal) 0 in same unit as nominal
# set as baseline2
# sensitivity = scale2
# sensitivity correction factor would be 1/gain, so we can recover gain.
sens_corr = str(decimal.Decimal(1.0 / gain))

scale1inv = (minmax[channel][1] - minmax[channel][0]) / (out_max - out_min)
sensitivity = str(decimal.Decimal(gain * scale1inv))
chdef.ChannelSensitivity = sensitivity if len(sensitivity) <= 16 else sensitivity[:16] # gain and ADC resolution goes here
chdef.ChannelSensitivityUnitsSequence = [Dataset()]
units = chdef.ChannelSensitivityUnitsSequence[0]

Expand All @@ -590,14 +636,20 @@ def create_multiplexed_chunk(self, waveforms: dict,
units.CodeMeaning = UCUM_ENCODING[unit] # this needs to be fixed.

# multiplier to apply to the encoded value to get back the original input.
ds = str(float(1.0) / float(chunk['gain']))
chdef.ChannelSensitivityCorrectionFactor = ds if len(ds) <= 16 else ds[:16]
chdef.ChannelBaseline = '0'
chdef.ChannelSensitivityCorrectionFactor = sens_corr if len(sens_corr) <= 16 else sens_corr[:16]
# Offset of encoded sample value 0 from actual 0 using the units defined in the Channel Sensitivity Units Sequence (003A,0211).
# baseline2 = baseline * scale2 = baseline * gain * scale1inv = (min(output) - min(input) * scale1) * gain * scale1inv
# = min(output) * gain * scale1inv - min(input) * gain
# = min(output) * sensitivity - min(input) * gain
# nom data = encoded * scale2 - baseline2
baseln = str(decimal.Decimal((out_min * scale1inv - minmax[channel][0]) * gain))
chdef.ChannelBaseline = baseln if len(baseln) <= 16 else baseln[:16]

chdef.WaveformBitsStored = iod.VR.WaveformBitsAllocated
# only for amplifier type of AC
# chdef.FilterLowFrequency = '0.05'
# chdef.FilterHighFrequency = '300'
channeldefs[channel] = chdef
channeldefs[channel] = chdef


wfDS.ChannelDefinitionSequence = channeldefs.values()
Expand All @@ -608,10 +660,12 @@ def create_multiplexed_chunk(self, waveforms: dict,
return wfDS


# minmax is dictionary of channel to (min, max) values
def add_waveform_chunks_multiplexed(self, dataset,
iod: DICOMWaveformIOD,
chunk_info: dict,
waveforms: dict):
waveforms: dict,
minmax: dict):

# dicom waveform -> sequence of multiplex group -> channels with same sampling frequency
# multiplex group can have labels.
Expand All @@ -629,7 +683,7 @@ def add_waveform_chunks_multiplexed(self, dataset,

for group, chanchunks in grouped_channels.items():
# input to this is [(channel, chunk), ... ]
multiplexGroupDS = self.create_multiplexed_chunk(waveforms, iod, group, chanchunks,
multiplexGroupDS = self.create_multiplexed_chunk(waveforms, iod, group, chanchunks, minmax,
start_time=start_time, end_time=end_time)
dataset.WaveformSequence.append(multiplexGroupDS)

Expand Down
20 changes: 16 additions & 4 deletions waveform_benchmark/formats/dicom.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,17 @@ def _pretty_print(self, table: dict):
print(key, ": ", value['start_t'], " ", value['end_t'])
for k, v in value['channel_chunk'].items():
print(" ", k, v)

# get channel min and max values, across chunks.
def _get_waveform_channel_minmax(self, waveforms):
minmax = {}
for channel, wf in waveforms.items():
mins = [ np.nanmin(chunk['samples']) for chunk in wf['chunks'] ]
maxs = [ np.nanmax(chunk['samples']) for chunk in wf['chunks'] ]

minmax[channel] = (np.nanmin(mins), np.nanmax(maxs))
return minmax


def write_waveforms(self, path, waveforms):
fs = FileSet()
Expand Down Expand Up @@ -485,7 +496,8 @@ def write_waveforms(self, path, waveforms):
subchunks1 = self.split_chunks_temporal_merged(channel_table)
# print("merged", len(subchunks1))
# self._pretty_print(subchunks1)


minmax = self._get_waveform_channel_minmax(waveforms)
#========== now write out =============

# count channels belonging to respiratory data this is needed for the iod
Expand Down Expand Up @@ -514,7 +526,7 @@ def write_waveforms(self, path, waveforms):
dicom = self.writer.set_study_info(dicom, studyUID = studyInstanceUID, studyDate = datetime.now())
dicom = self.writer.set_series_info(dicom, iod, seriesUID=seriesInstanceUID)
dicom = self.writer.set_waveform_acquisition_info(dicom, instanceNumber = file_id)
dicom = self.writer.add_waveform_chunks_multiplexed(dicom, iod, chunk_info, waveforms)
dicom = self.writer.add_waveform_chunks_multiplexed(dicom, iod, chunk_info, waveforms, minmax)

# Save DICOM file. write_like_original is required
# these the initial path when added - it points to a temp file.
Expand Down Expand Up @@ -792,13 +804,13 @@ class DICOMHighBitsChunked(DICOMHighBits):
# waveform lead names to dicom IOD mapping. Incomplete.
# avoiding 12 lead ECG because of the limit in number of samples.

chunkSize = 86400.0 # chunk as 1 day.
chunkSize = 3600.0 # chunk as 1 hr.
hifi = True


class DICOMLowBitsChunked(DICOMLowBits):
    """Low-fidelity DICOM writer variant that splits waveforms into fixed-length chunks."""
    # NOTE: the diff artifact showed both the old (86400.0, 1 day) and new values;
    # the committed value is 3600.0 seconds (1 hour) per chunk.
    chunkSize = 3600.0
    hifi = False


Expand Down

0 comments on commit e96d826

Please sign in to comment.