diff --git a/examples/scripts/write_gre.py b/examples/scripts/write_gre.py index 8085c2f8..47500dfb 100644 --- a/examples/scripts/write_gre.py +++ b/examples/scripts/write_gre.py @@ -5,7 +5,7 @@ import pypulseq as pp -def main(plot: bool = False, write_seq: bool = False, seq_filename: str = 'gre_pypulseq.seq'): +def main(plot: bool = False, write_seq: bool = False, seq_filename: str = 'gre_pypulseq.seq', paper_plot: bool = False): # ====== # SETUP # ====== @@ -121,7 +121,10 @@ def main(plot: bool = False, write_seq: bool = False, seq_filename: str = 'gre_p # VISUALIZATION # ====== if plot: - seq.plot() + if paper_plot: + seq.paper_plot() + else: + seq.plot() seq.calculate_kspace() @@ -144,4 +147,4 @@ def main(plot: bool = False, write_seq: bool = False, seq_filename: str = 'gre_p if __name__ == '__main__': - main(plot=False, write_seq=True) + seq = main(plot=True, paper_plot=True, write_seq=False) diff --git a/pyproject.toml b/pyproject.toml index 448386ff..448cb54b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,6 +28,7 @@ dependencies = [ [project.optional-dependencies] sigpy = ["sigpy>=0.1.26"] +mplcursors = ["mplcursors"] test = [ "coverage", "codecov", diff --git a/src/pypulseq/Sequence/block.py b/src/pypulseq/Sequence/block.py index 9077d0d5..87635a1c 100644 --- a/src/pypulseq/Sequence/block.py +++ b/src/pypulseq/Sequence/block.py @@ -272,7 +272,51 @@ def set_block(self, block_index: int, *args: SimpleNamespace) -> None: self.block_durations[block_index] = float(duration) -# TODO: refactor to get_raw_block_content_id + get_block +def get_raw_block_content_IDs(self, block_index: int) -> SimpleNamespace: + """ + Returns PyPulseq block content IDs at `block_index` position in `self.block_events`. + + No block events are created, only the IDs of the objects are returned. + + Parameters + ---------- + block_index : int + Index of PyPulseq block to be retrieved from `self.block_events`. + + Returns + ------- + block : SimpleNamespace + PyPulseq block content IDs at 'block_index' position in `self.block_events`. + """ + raw_block = SimpleNamespace(block_duration=0, rf=0, gx=0, gy=0, gz=0, adc=0, ext=[]) + event_ind = self.block_events[block_index] + + # Extensions + if event_ind[6] > 0: + next_ext_id = event_ind[6] + while next_ext_id != 0: + ext_data = self.extensions_library.data[next_ext_id] + raw_block.ext.append(ext_data[:2]) + next_ext_id = ext_data[2] + raw_block.ext = np.stack(raw_block.ext, axis=-1) + + # RF + if event_ind[1] > 0: + raw_block.rf = event_ind[1] + + # Gradients + grad_channels = ['gx', 'gy', 'gz'] + for i in range(len(grad_channels)): + if event_ind[2 + i] > 0: + setattr(raw_block, grad_channels[i], event_ind[2 + i]) + + # ADC + if event_ind[5] > 0: + raw_block.adc = event_ind[5] + + return raw_block + + def get_block(self, block_index: int) -> SimpleNamespace: """ Returns PyPulseq block at `block_index` position in `self.block_events`. diff --git a/src/pypulseq/Sequence/sequence.py b/src/pypulseq/Sequence/sequence.py index 614088d3..77a8bc3a 100644 --- a/src/pypulseq/Sequence/sequence.py +++ b/src/pypulseq/Sequence/sequence.py @@ -1,4 +1,3 @@ -import itertools import math from collections import OrderedDict from copy import deepcopy @@ -13,9 +12,7 @@ Self = TypeVar('Self', bound='Sequence') -import matplotlib as mpl import numpy as np -from matplotlib import pyplot as plt from scipy.interpolate import PPoly from pypulseq import __version__, eps @@ -25,15 +22,16 @@ from pypulseq.decompress_shape import decompress_shape from pypulseq.event_lib import EventLibrary from pypulseq.opts import Opts -from pypulseq.Sequence import block, parula +from pypulseq.Sequence import block from pypulseq.Sequence.calc_grad_spectrum import calculate_gradient_spectrum from pypulseq.Sequence.calc_pns import calc_pns from pypulseq.Sequence.ext_test_report import ext_test_report from pypulseq.Sequence.install import detect_scanner from pypulseq.Sequence.read_seq import read from pypulseq.Sequence.write_seq import write as write_seq -from pypulseq.supported_labels_rf_use import get_supported_labels from pypulseq.utils.cumsum import cumsum +from pypulseq.utils.paper_plot import paper_plot as ext_paper_plot +from pypulseq.utils.seq_plot import SeqPlot from pypulseq.utils.tracing import format_trace, trace, trace_enabled major, minor, revision = __version__.split('.')[:3] @@ -606,6 +604,33 @@ def evaluate_labels(self, init: Union[dict, None] = None, evolution: str = 'none return labels + import numpy as np + + def find_block_by_time(self, t: float) -> int: + """ + Find the index of the block containing time `t`. + + Parameters + ---------- + t : float + Time (in seconds) to locate within the sequence. + + Returns + ------- + int or None + Index of the block that contains the given time, or None if out of range. + """ + cumsum_durations = np.cumsum(list(self.block_durations.values())) + block_index = np.searchsorted(cumsum_durations, t, side='right').item() + + if block_index >= len(self.block_durations): + return None + + if self.block_durations[block_index] <= 0: + raise ValueError('Block duration cannot be negative') + + return block_index + def flip_grad_axis(self, axis: str) -> None: """ Invert all gradients along the corresponding axis/channel. The function acts on all gradient objects already @@ -618,6 +643,28 @@ def flip_grad_axis(self, axis: str) -> None: """ self.mod_grad_axis(axis, modifier=-1) + def get_raw_block_content_IDs(self, block_index: int) -> SimpleNamespace: + """ + Returns PyPulseq block content IDs at `block_index` position in `self.block_events`. + + No block events are created, only the IDs of the objects are returned. + + See Also + -------- + - `pypulseq.Sequence.sequence.Sequence.get_block()`. + + Parameters + ---------- + block_index : int + Index of block to be retrieved from `Sequence`. + + Returns + ------- + SimpleNamespace + PyPulseq block content IDs at 'block_index' position in `self.block_events`. + """ + return block.get_raw_block_content_IDs(self, block_index) + def get_block(self, block_index: int) -> SimpleNamespace: """ Return a block of the sequence specified by the index. The block is created from the sequence data with all @@ -883,6 +930,43 @@ def mod_grad_axis(self, axis: str, modifier: int) -> None: self.grad_library.data[selected_events[i]][3] *= modifier self.grad_library.data[selected_events[i]][4] *= modifier + def paper_plot( + self, + time_range: Tuple[float] = (0, np.inf), + line_width: float = 1.2, + axes_color: Tuple[float] = (0.5, 0.5, 0.5), + rf_color: str = 'black', + gx_color: str = 'blue', + gy_color: str = 'red', + gz_color: Tuple[float] = (0, 0.5, 0.3), + rf_plot: str = 'abs', + ): + """ + Plot sequence using paper-style formatting (minimalist, high-contrast layout). + + Parameters + ---------- + time_range : iterable, default=(0, np.inf) + Time range (x-axis limits) for plotting the sequence. + Default is 0 to infinity (entire sequence). + line_width : float, default=1.2 + Line width used in plots. + axes_color : color, default=(0.5, 0.5, 0.5) + Color of horizontal zero axes (e.g., gray). + rf_color : color, default='black' + Color for RF and ADC events. + gx_color : color, default='blue' + Color for gradient X waveform. + gy_color : color, default='red' + Color for gradient Y waveform. + gz_color : color, default=(0, 0.5, 0.3) + Color for gradient Z waveform. + rf_plot : {'abs', 'real', 'imag'}, default='abs' + Determines how to plot RF waveforms (magnitude, real or imaginary part). + + """ + ext_paper_plot(self, time_range, line_width, axes_color, rf_color, gx_color, gy_color, gz_color, rf_plot) + def plot( self, label: str = str(), @@ -892,7 +976,7 @@ def plot( time_disp: str = 's', grad_disp: str = 'kHz/m', plot_now: bool = True, - ) -> None: + ) -> SeqPlot: """ Plot `Sequence`. @@ -917,221 +1001,13 @@ def plot( If false, plots are shown when plt.show() is called. Useful if plots are to be modified. plot_type : str, default='Gradient' Gradients display type, must be one of either 'Gradient' or 'Kspace'. - """ - mpl.rcParams['lines.linewidth'] = 0.75 # Set default Matplotlib linewidth - valid_time_units = ['s', 'ms', 'us'] - valid_grad_units = ['kHz/m', 'mT/m'] - valid_labels = get_supported_labels() - if not all(isinstance(x, (int, float)) for x in time_range) or len(time_range) != 2: - raise ValueError('Invalid time range') - if time_disp not in valid_time_units: - raise ValueError('Unsupported time unit') - - if grad_disp not in valid_grad_units: - raise ValueError('Unsupported gradient unit. Supported gradient units are: ' + str(valid_grad_units)) - - fig1, fig2 = plt.figure(), plt.figure() - sp11 = fig1.add_subplot(311) - sp12 = fig1.add_subplot(312, sharex=sp11) - sp13 = fig1.add_subplot(313, sharex=sp11) - fig2_subplots = [ - fig2.add_subplot(311, sharex=sp11), - fig2.add_subplot(312, sharex=sp11), - fig2.add_subplot(313, sharex=sp11), - ] - - t_factor_list = [1, 1e3, 1e6] - t_factor = t_factor_list[valid_time_units.index(time_disp)] - - g_factor_list = [1e-3, 1e3 / self.system.gamma] - g_factor = g_factor_list[valid_grad_units.index(grad_disp)] - - t0 = 0 - label_defined = False - label_idx_to_plot = [] - label_legend_to_plot = [] - label_store = {} - for i in range(len(valid_labels)): - label_store[valid_labels[i]] = 0 - if valid_labels[i] in label.upper(): - label_idx_to_plot.append(i) - label_legend_to_plot.append(valid_labels[i]) - - if len(label_idx_to_plot) != 0: - p = parula.main(len(label_idx_to_plot) + 1) - label_colors_to_plot = p(np.arange(len(label_idx_to_plot))) - cycler = mpl.cycler(color=label_colors_to_plot) - sp11.set_prop_cycle(cycler) - - # Block timings - block_edges = np.cumsum([0] + [x[1] for x in sorted(self.block_durations.items())]) - block_edges_in_range = block_edges[(block_edges >= time_range[0]) * (block_edges <= time_range[1])] - if show_blocks: - for sp in [sp11, sp12, sp13, *fig2_subplots]: - sp.set_xticks(t_factor * block_edges_in_range) - sp.set_xticklabels(sp.get_xticklabels(), rotation=90) - - for block_counter in self.block_events: - block = self.get_block(block_counter) - is_valid = time_range[0] <= t0 + self.block_durations[block_counter] and t0 <= time_range[1] - if is_valid: - if getattr(block, 'label', None) is not None: - for i in range(len(block.label)): - if block.label[i].type == 'labelinc': - label_store[block.label[i].label] += block.label[i].value - else: - label_store[block.label[i].label] = block.label[i].value - label_defined = True - - if getattr(block, 'adc', None) is not None: # ADC - adc = block.adc - # From Pulseq: According to the information from Klaus Scheffler and indirectly from Siemens this - # is the present convention - the samples are shifted by 0.5 dwell - t = adc.delay + (np.arange(int(adc.num_samples)) + 0.5) * adc.dwell - sp11.plot(t_factor * (t0 + t), np.zeros(len(t)), 'rx') - sp13.plot( - t_factor * (t0 + t), - np.angle(np.exp(1j * adc.phase_offset) * np.exp(1j * 2 * np.pi * t * adc.freq_offset)), - 'b.', - markersize=0.25, - ) - - if label_defined and len(label_idx_to_plot) != 0: - arr_label_store = list(label_store.values()) - lbl_vals = np.take(arr_label_store, label_idx_to_plot) - t = t0 + adc.delay + (adc.num_samples - 1) / 2 * adc.dwell - _t = [t_factor * t] * len(lbl_vals) - # Plot each label individually to retrieve each corresponding Line2D object - p = itertools.chain.from_iterable( - [sp11.plot(__t, _lbl_vals, '.') for __t, _lbl_vals in zip(_t, lbl_vals)] - ) - if len(label_legend_to_plot) != 0: - sp11.legend(list(p), label_legend_to_plot, loc='upper left') - label_legend_to_plot = [] - - if getattr(block, 'rf', None) is not None: # RF - rf = block.rf - time_center, index_center = calc_rf_center(rf) - time = rf.t - signal = rf.signal - - if signal.shape[0] == 2 and rf.freq_offset != 0: - num_samples = min(int(abs(rf.freq_offset)), 256) - time = np.linspace(time[0], time[-1], num_samples) - signal = np.linspace(signal[0], signal[-1], num_samples) - - if abs(signal[0]) != 0: - signal = np.concatenate(([0], signal)) - time = np.concatenate(([time[0]], time)) - index_center += 1 - - if abs(signal[-1]) != 0: - signal = np.concatenate((signal, [0])) - time = np.concatenate((time, [time[-1]])) - - signal_is_real = max(np.abs(np.imag(signal))) / max(np.abs(np.real(signal))) < 1e-6 - - # Compute time vector with delay applied - time_with_delay = t_factor * (t0 + time + rf.delay) - time_center_with_delay = t_factor * (t0 + time_center + rf.delay) - - # Choose plot behavior based on realness of signal - if signal_is_real: - # Plot real part of signal - sp12.plot(time_with_delay, np.real(signal)) - - # Include sign(real(signal)) factor like MATLAB - phase_corrected = ( - signal - * np.sign(np.real(signal)) - * np.exp(1j * rf.phase_offset) - * np.exp(1j * 2 * math.pi * time * rf.freq_offset) - ) - sc_corrected = ( - signal[index_center] - * np.exp(1j * rf.phase_offset) - * np.exp(1j * 2 * math.pi * time[index_center] * rf.freq_offset) - ) - - sp13.plot( - time_with_delay, - np.angle(phase_corrected), - time_center_with_delay, - np.angle(sc_corrected), - 'xb', - ) - else: - # Plot magnitude of complex signal - sp12.plot(time_with_delay, np.abs(signal)) - - # Plot angle of complex signal - phase_corrected = ( - signal * np.exp(1j * rf.phase_offset) * np.exp(1j * 2 * math.pi * time * rf.freq_offset) - ) - sc_corrected = ( - signal[index_center] - * np.exp(1j * rf.phase_offset) - * np.exp(1j * 2 * math.pi * time[index_center] * rf.freq_offset) - ) - - sp13.plot( - time_with_delay, - np.angle(phase_corrected), - time_center_with_delay, - np.angle(sc_corrected), - 'xb', - ) - - grad_channels = ['gx', 'gy', 'gz'] - for x in range(len(grad_channels)): # Gradients - if getattr(block, grad_channels[x], None) is not None: - grad = getattr(block, grad_channels[x]) - if grad.type == 'grad': - # We extend the shape by adding the first and the last points in an effort of making the - # display a bit less confusing... - time = grad.delay + np.array([0, *grad.tt, grad.shape_dur]) - waveform = g_factor * np.array((grad.first, *grad.waveform, grad.last)) - else: - time = np.array( - cumsum( - 0, - grad.delay, - grad.rise_time, - grad.flat_time, - grad.fall_time, - ) - ) - waveform = g_factor * grad.amplitude * np.array([0, 0, 1, 1, 0]) - fig2_subplots[x].plot(t_factor * (t0 + time), waveform) - t0 += self.block_durations[block_counter] - - grad_plot_labels = ['x', 'y', 'z'] - sp11.set_ylabel('ADC') - sp12.set_ylabel('RF mag (Hz)') - sp13.set_ylabel('RF/ADC phase (rad)') - sp13.set_xlabel(f't ({time_disp})') - for x in range(3): - _label = grad_plot_labels[x] - fig2_subplots[x].set_ylabel(f'G{_label} ({grad_disp})') - fig2_subplots[-1].set_xlabel(f't ({time_disp})') - - # Setting display limits - disp_range = t_factor * np.array([time_range[0], min(t0, time_range[1])]) - [x.set_xlim(disp_range) for x in [sp11, sp12, sp13, *fig2_subplots]] - - # Grid on - for sp in [sp11, sp12, sp13, *fig2_subplots]: - sp.grid() - - fig1.tight_layout() - fig2.tight_layout() - if save: - fig1.savefig('seq_plot1.jpg') - fig2.savefig('seq_plot2.jpg') - - if plot_now: - plt.show() + Returns + ------- + SeqPlot + SeqPlot handle. + """ + return SeqPlot(self, label, show_blocks, save, time_range, time_disp, grad_disp, plot_now) def read(self, file_path: str, detect_rf_use: bool = False, remove_duplicates: bool = True) -> None: """ @@ -1263,7 +1139,7 @@ def rf_from_lib_data(self, lib_data: list, use: str = str()) -> SimpleNamespace: compressed.num_samples = shape_data[0] compressed.data = shape_data[1:] phase = decompress_shape(compressed) - rf.signal = amplitude * mag * np.exp(1j * 2 * np.pi * phase) + rf.signal = amplitude * mag * np.exp(1j * 2 * math.pi * phase) time_shape = lib_data[3] if time_shape > 0: shape_data = self.shape_library.data[time_shape] @@ -1315,13 +1191,6 @@ def rf_times( fp_refocusing : np.ndarray Contains frequency and phase offsets of the excitation RF pulses """ - # tc = calc_rf_center(rf) - # t = rf.delay + tc - # if hasattr(rf,'use') is False or rf.use == 'excitation' or rf.use =='undefined': - # tfp_excitation(:,end+1) = [curr_dur+t; full_freq_offset; full_phase_offset + 2* pi * full_freq_offset * tc] - # elif rf.use =='refocusing': - # tfp_refocusing(:,end+1) = [curr_dur+t; full_freq_offset; full_phase_offset + 2 * pi * full_freq_offset * tc] - # Collect RF timing data t_excitation = [] fp_excitation = [] @@ -1358,7 +1227,7 @@ def rf_times( full_freq_offset = rf.freq_offset + rf.freq_ppm * 1e-6 * self.system.gamma * self.system.B0 full_phase_offset = rf.phase_offset + rf.phase_ppm * 1e-6 * self.system.gamma * self.system.B0 - full_phase_offset = full_phase_offset + 2 * np.pi * full_freq_offset * tc + full_phase_offset = full_phase_offset + 2 * math.pi * full_freq_offset * tc if not hasattr(rf, 'use') or block.rf.use in [ 'excitation', @@ -1587,7 +1456,7 @@ def waveforms(self, append_RF: bool = False, time_range: Union[List[float], None rf_piece = np.array( [ curr_dur + rf.delay + rf.t, - rf.signal * np.exp(1j * (full_phase_offset + 2 * np.pi * full_freq_offset * rf.t)), + rf.signal * np.exp(1j * (full_phase_offset + 2 * math.pi * full_freq_offset * rf.t)), ] ) out_len[-1] += len(rf.t) @@ -1724,16 +1593,32 @@ def waveforms_export(self, time_range=(0, np.inf)) -> dict: # is the present convention - the samples are shifted by 0.5 dwell t = adc.delay + (np.arange(int(adc.num_samples)) + 0.5) * adc.dwell adc_t = t0 + t - adc_signal = np.exp(1j * adc.phase_offset) * np.exp(1j * 2 * np.pi * t * adc.freq_offset) + + if adc.phase_modulation is None or len(adc.phase_modulation) == 0: + phase_modulation = 0 + else: + phase_modulation = adc.phase_modulation + + full_freq_offset = np.atleast_1d(adc.freq_offset + adc.freq_ppm * 1e-6 * self.system.B0) + full_phase_offset = np.atleast_1d( + adc.phase_offset + adc.phase_offset * 1e-6 * self.system.B0 + phase_modulation + ) + + adc_signal = np.exp(1j * full_phase_offset) * np.exp(1j * 2 * math.pi * t * full_phase_offset) adc_t_all = np.concatenate((adc_t_all, adc_t)) adc_signal_all = np.concatenate((adc_signal_all, adc_signal)) if block.rf is not None: rf = block.rf + tc, ic = calc_rf_center(rf) t = rf.t + rf.delay tc = tc + rf.delay + full_freq_offset = rf.freq_offset + rf.freq_ppm * 1e-6 * self.system.gamma * self.system.B0 + full_phase_offset = rf.phase_offset + rf.phase_ppm * 1e-6 * self.system.gamma * self.system.B0 + full_phase_offset = full_phase_offset + 2 * math.pi * full_freq_offset * tc + # Debug - visualize # sp12.plot(t_factor * (t0 + t), np.abs(rf.signal)) # sp13.plot(t_factor * (t0 + t), np.angle(rf.signal * np.exp(1j * rf.phase_offset) @@ -1743,7 +1628,7 @@ def waveforms_export(self, time_range=(0, np.inf)) -> dict: # 'xb') rf_t = t0 + t - rf = rf.signal * np.exp(1j * rf.phase_offset) * np.exp(1j * 2 * math.pi * rf.t * rf.freq_offset) + rf = rf.signal * np.exp(1j * (full_phase_offset + 2 * math.pi * full_freq_offset * rf.t)) rf_t_all = np.concatenate((rf_t_all, rf_t)) rf_signal_all = np.concatenate((rf_signal_all, rf)) rf_t_centers = np.concatenate((rf_t_centers, [rf_t[ic]])) diff --git a/src/pypulseq/utils/paper_plot.py b/src/pypulseq/utils/paper_plot.py new file mode 100644 index 00000000..a83aa8c7 --- /dev/null +++ b/src/pypulseq/utils/paper_plot.py @@ -0,0 +1,121 @@ +from typing import Tuple + +import matplotlib.pyplot as plt +import numpy as np + + +def paper_plot( + seq, + time_range: Tuple[float] = (0, np.inf), + line_width: float = 1.2, + axes_color: Tuple[float] = (0.5, 0.5, 0.5), + rf_color: str = 'black', + gx_color: str = 'blue', + gy_color: str = 'red', + gz_color: Tuple[float] = (0, 0.5, 0.3), + rf_plot: str = 'abs', +): + """ + Plot sequence using paper-style formatting (minimalist, high-contrast layout). + + Parameters + ---------- + seq : Sequence + The Pulseq sequence object to plot. + time_range : iterable, default=(0, np.inf) + Time range (x-axis limits) for plotting the sequence. + Default is 0 to infinity (entire sequence). + line_width : float, default=1.2 + Line width used in plots. + axes_color : color, default=(0.5, 0.5, 0.5) + Color of horizontal zero axes (e.g., gray). + rf_color : color, default='black' + Color for RF and ADC events. + gx_color : color, default='blue' + Color for gradient X waveform. + gy_color : color, default='red' + Color for gradient Y waveform. + gz_color : color, default=(0, 0.5, 0.3) + Color for gradient Z waveform. + rf_plot : {'abs', 'real', 'imag'}, default='abs' + Determines how to plot RF waveforms (magnitude, real or imaginary part). + + """ + # Get waveform data + wave_data, _, _, t_adc, _ = seq.waveforms_and_times(append_RF=True, time_range=time_range) + + # Max amplitudes for scaling + gwm = np.max(np.abs(np.concatenate(wave_data[:3], axis=1)), axis=1) + gwm[0] = max(gwm[0], t_adc[-1]) + rfm = np.max(np.abs(wave_data[3]), axis=1) + + # Handle complex RF + if rf_plot == 'real': + rf_waveform = np.real(wave_data[3][1]) + elif rf_plot == 'imag': + rf_waveform = np.imag(wave_data[3][1]) + else: + rf_waveform = np.abs(wave_data[3][1]) + wave_data[3] = np.stack((wave_data[3][0].real, rf_waveform), axis=0) + + # Clean waveforms by inserting NaNs between zero plateaus + for i in range(4): + data = wave_data[i] + j = data.shape[1] - 1 + while j > 0: + if data[1, j] == 0 and data[1, j - 1] == 0: + midpoint = 0.5 * (data[0, j] + data[0, j - 1]) + data = np.hstack([data[:, :j], np.array([[midpoint], [np.nan]]), data[:, j:]]) + wave_data[i] = data + j -= 1 + + # Create figure + fig = plt.figure(figsize=(12, 10), constrained_layout=True) + fig.patch.set_facecolor('white') + spec = fig.add_gridspec(nrows=4, ncols=1, hspace=0.0) + axes = [] + + def format_axis(ax, xlim, ylim): + ax.set_xlim(xlim) + ax.set_ylim(ylim) + ax.set_xticks([]) + ax.set_yticks([]) + ax.set_facecolor('white') + ax.spines[:].set_visible(False) + + # ADC + ax = fig.add_subplot(spec[0]) + ax.vlines(t_adc, ymin=0, ymax=rfm[1] / 5, color=rf_color, lw=line_width / 4, zorder=2.5) + + # RF + ax.plot([-0.01 * gwm[0], 1.01 * gwm[0]], [0, 0], color=axes_color, lw=line_width / 5) + ax.plot(wave_data[3][0], wave_data[3][1], color=rf_color, lw=line_width) + + # Format RF + ADC + format_axis(ax, [-0.03 * gwm[0], 1.03 * gwm[0]], [-1.03 * rfm[1], 1.03 * rfm[1]]) + axes.append(ax) + + # Gradient Z + ax = fig.add_subplot(spec[1]) + ax.plot([-0.01 * gwm[0], 1.01 * gwm[0]], [0, 0], color=axes_color, lw=line_width / 5) + ax.plot(wave_data[2][0], wave_data[2][1], color=gz_color, lw=line_width) + format_axis(ax, [-0.03 * gwm[0], 1.03 * gwm[0]], [-1.03 * gwm[1], 1.03 * gwm[1]]) + axes.append(ax) + + # Gradient Y + ax = fig.add_subplot(spec[2]) + ax.plot([-0.01 * gwm[0], 1.01 * gwm[0]], [0, 0], color=axes_color, lw=line_width / 5) + ax.plot(wave_data[1][0], wave_data[1][1], color=gy_color, lw=line_width) + format_axis(ax, [-0.03 * gwm[0], 1.03 * gwm[0]], [-1.03 * gwm[1], 1.03 * gwm[1]]) + axes.append(ax) + + # Gradient X + ax = fig.add_subplot(spec[3]) + ax.plot([-0.01 * gwm[0], 1.01 * gwm[0]], [0, 0], color=axes_color, lw=line_width / 5) + ax.plot(wave_data[0][0], wave_data[0][1], color=gx_color, lw=line_width) + format_axis(ax, [-0.03 * gwm[0], 1.03 * gwm[0]], [-1.03 * gwm[1], 1.03 * gwm[1]]) + axes.append(ax) + + # Link X-axes (time axis) + for ax in axes[1:]: + ax.sharex(axes[0]) diff --git a/src/pypulseq/utils/seq_plot.py b/src/pypulseq/utils/seq_plot.py new file mode 100644 index 00000000..6d411e4c --- /dev/null +++ b/src/pypulseq/utils/seq_plot.py @@ -0,0 +1,399 @@ +import itertools +import math + +import matplotlib as mpl +import numpy as np +from matplotlib import pyplot as plt + +from pypulseq.calc_rf_center import calc_rf_center +from pypulseq.Sequence import parula +from pypulseq.supported_labels_rf_use import get_supported_labels +from pypulseq.utils.cumsum import cumsum + +try: + import mplcursors + + __MPLCURSORS_AVAILABLE__ = True +except ImportError: + __MPLCURSORS_AVAILABLE__ = False + + +class SeqPlot: + """ + Interactive plotter for a Pulseq `Sequence` object. + + Parameters + ---------- + seq : Sequence + The Pulseq sequence object to plot. + label : str, default=str() + Plot label values for ADC events: in this example for LIN and REP labels; other valid labes are accepted as + a comma-separated list. + save : bool, default=False + Boolean flag indicating if plots should be saved. The two figures will be saved as JPG with numerical + suffixes to the filename 'seq_plot'. + show_blocks : bool, default=False + Boolean flag to indicate if grid and tick labels at the block boundaries are to be plotted. + time_range : iterable, default=(0, np.inf) + Time range (x-axis limits) for plotting the sequence. Default is 0 to infinity (entire sequence). + time_disp : str, default='s' + Time display type, must be one of `s`, `ms` or `us`. + grad_disp : str, default='s' + Gradient display unit, must be one of `kHz/m` or `mT/m`. + plot_now : bool, default=True + If true, function immediately shows the plots, blocking the rest of the code until plots are exited. + If false, plots are shown when plt.show() is called. Useful if plots are to be modified. + plot_type : str, default='Gradient' + Gradients display type, must be one of either 'Gradient' or 'Kspace'. + + Attributes + ---------- + fig1 : matplotlib.figure.Figure + Figure containing RF and ADC channels. + fig2 : matplotlib.figure.Figure + Figure containing Gradient or K-space channels. + ax1 : matplotlib.axes.Axes + Axes for fig1. + ax2 : matplotlib.axes.Axes + Axes for fig2. + """ + + def __init__( + self, + seq, + label: str = str(), + show_blocks: bool = False, + save: bool = False, + time_range=(0, np.inf), + time_disp: str = 's', + grad_disp: str = 'kHz/m', + plot_now: bool = True, + ): + self.seq = seq + self._cursors = [] + + self.fig1, self.fig2 = _seq_plot( + seq, + label=label, + save=save, + show_blocks=show_blocks, + time_range=time_range, + time_disp=time_disp, + grad_disp=grad_disp, + ) + self.ax1 = self.fig1.axes[0] + self.ax2 = self.fig2.axes[0] + + if __MPLCURSORS_AVAILABLE__: + self._setup_cursor(self.fig1) + self._setup_cursor(self.fig2) + + if plot_now: + self.show() + + def show(self): + plt.show() + + def _setup_cursor(self, fig): + for ax in fig.axes: + lines = ax.get_lines() + for line in lines: + cursor = mplcursors.cursor(line, multiple=True) + cursor.connect('add', lambda sel: self._on_datatip(sel)) + self._cursors.append(cursor) + + def _on_datatip(self, sel): + artist = sel.artist + ax = artist.axes + x, y = sel.target + ylabel = ax.get_ylabel().lower() + + if ylabel.startswith('adc') or ( + ylabel.startswith('rf/adc') and artist.get_linestyle() == 'none' and artist.get_marker() == '.' + ): + field = 'adc' + else: + field = ylabel[:2] + + t0 = artist.get_xdata()[0] if hasattr(artist, 'get_xdata') else x + block_index = self.seq.find_block_by_time(t0) + rb = self.seq.get_raw_block_content_IDs(block_index) + + lines_txt = [f't: {x:.3f}', f'Y: {y:.3f}'] + + val = getattr(rb, field, None) + if val is not None: + try: + if field[0] == 'a': + name = self.seq.adc_id2name_map[val] + elif field[0] == 'r': + name = self.seq.rf_id2name_map[val] + else: + name = self.seq.grad_id2name_map[val] + + lines_txt.append(f"blk: {block_index} {field}_id: {val} '{name}'") + except Exception: + lines_txt.append(f'blk: {block_index} {field}_id: {val}') + else: + lines_txt.append(f'blk: {block_index}') + + sel.annotation.set_text('\n'.join(lines_txt)) + self._update_guides() + + def _update_guides(self): + for ax in (self.ax1, self.ax2): + ax.relim() + ax.autoscale_view() + + for fig in (self.fig1, self.fig2): + fig.canvas.draw_idle() + + +def _seq_plot( + seq, + label, + show_blocks, + save, + time_range, + time_disp, + grad_disp, +): + mpl.rcParams['lines.linewidth'] = 0.75 # Set default Matplotlib linewidth + + valid_time_units = ['s', 'ms', 'us'] + valid_grad_units = ['kHz/m', 'mT/m'] + valid_labels = get_supported_labels() + if not all(isinstance(x, (int, float)) for x in time_range) or len(time_range) != 2: + raise ValueError('Invalid time range') + if time_disp not in valid_time_units: + raise ValueError('Unsupported time unit') + + if grad_disp not in valid_grad_units: + raise ValueError('Unsupported gradient unit. Supported gradient units are: ' + str(valid_grad_units)) + + fig1, fig2 = plt.figure(), plt.figure() + sp11 = fig1.add_subplot(311) + sp12 = fig1.add_subplot(312, sharex=sp11) + sp13 = fig1.add_subplot(313, sharex=sp11) + fig2_subplots = [ + fig2.add_subplot(311, sharex=sp11), + fig2.add_subplot(312, sharex=sp11), + fig2.add_subplot(313, sharex=sp11), + ] + + t_factor_list = [1, 1e3, 1e6] + t_factor = t_factor_list[valid_time_units.index(time_disp)] + + g_factor_list = [1e-3, 1e3 / seq.system.gamma] + g_factor = g_factor_list[valid_grad_units.index(grad_disp)] + + t0 = 0 + label_defined = False + label_idx_to_plot = [] + label_legend_to_plot = [] + label_store = {} + for i in range(len(valid_labels)): + label_store[valid_labels[i]] = 0 + if valid_labels[i] in label.upper(): + label_idx_to_plot.append(i) + label_legend_to_plot.append(valid_labels[i]) + + if len(label_idx_to_plot) != 0: + p = parula.main(len(label_idx_to_plot) + 1) + label_colors_to_plot = p(np.arange(len(label_idx_to_plot))) + cycler = mpl.cycler(color=label_colors_to_plot) + sp11.set_prop_cycle(cycler) + + # Block timings + block_edges = np.cumsum([0] + [x[1] for x in sorted(seq.block_durations.items())]) + block_edges_in_range = block_edges[(block_edges >= time_range[0]) * (block_edges <= time_range[1])] + if show_blocks: + for sp in [sp11, sp12, sp13, *fig2_subplots]: + sp.set_xticks(t_factor * block_edges_in_range) + sp.set_xticklabels(sp.get_xticklabels(), rotation=90) + + for block_counter in seq.block_events: + block = seq.get_block(block_counter) + is_valid = time_range[0] <= t0 + seq.block_durations[block_counter] and t0 <= time_range[1] + if is_valid: + if getattr(block, 'label', None) is not None: + for i in range(len(block.label)): + if block.label[i].type == 'labelinc': + label_store[block.label[i].label] += block.label[i].value + else: + label_store[block.label[i].label] = block.label[i].value + label_defined = True + + if getattr(block, 'adc', None) is not None: # ADC + adc = block.adc + # From Pulseq: According to the information from Klaus Scheffler and indirectly from Siemens this + # is the present convention - the samples are shifted by 0.5 dwell + t = adc.delay + (np.arange(int(adc.num_samples)) + 0.5) * adc.dwell + sp11.plot(t_factor * (t0 + t), np.zeros(len(t)), 'rx') + + if adc.phase_modulation is None or len(adc.phase_modulation) == 0: + phase_modulation = 0 + else: + phase_modulation = adc.phase_modulation + + full_freq_offset = np.atleast_1d(adc.freq_offset + adc.freq_ppm * 1e-6 * seq.system.B0) + full_phase_offset = np.atleast_1d( + adc.phase_offset + adc.phase_offset * 1e-6 * seq.system.B0 + phase_modulation + ) + + sp13.plot( + t_factor * (t0 + t), + np.angle(np.exp(1j * full_phase_offset) * np.exp(1j * 2 * math.pi * t * full_freq_offset)), + 'b.', + markersize=0.25, + ) + + if label_defined and len(label_idx_to_plot) != 0: + arr_label_store = list(label_store.values()) + lbl_vals = np.take(arr_label_store, label_idx_to_plot) + t = t0 + adc.delay + (adc.num_samples - 1) / 2 * adc.dwell + _t = [t_factor * t] * len(lbl_vals) + # Plot each label individually to retrieve each corresponding Line2D object + p = itertools.chain.from_iterable( + [sp11.plot(__t, _lbl_vals, '.') for __t, _lbl_vals in zip(_t, lbl_vals)] + ) + if len(label_legend_to_plot) != 0: + sp11.legend(list(p), label_legend_to_plot, loc='upper left') + label_legend_to_plot = [] + + if getattr(block, 'rf', None) is not None: # RF + rf = block.rf + time_center, index_center = calc_rf_center(rf) + time = rf.t + signal = rf.signal + + if signal.shape[0] == 2 and rf.freq_offset != 0: + num_samples = min(int(abs(rf.freq_offset)), 256) + time = np.linspace(time[0], time[-1], num_samples) + signal = np.linspace(signal[0], signal[-1], num_samples) + + if abs(signal[0]) != 0: + signal = np.concatenate(([0], signal)) + time = np.concatenate(([time[0]], time)) + index_center += 1 + + if abs(signal[-1]) != 0: + signal = np.concatenate((signal, [0])) + time = np.concatenate((time, [time[-1]])) + + signal_is_real = max(np.abs(np.imag(signal))) / max(np.abs(np.real(signal))) < 1e-6 + + full_freq_offset = rf.freq_offset + rf.freq_ppm * 1e-6 * seq.system.B0 + full_phase_offset = rf.phase_offset + rf.phase_ppm * 1e-6 * seq.system.B0 + + # If off-resonant and rectangular (2 samples), interpolate the pulse + if len(signal) == 2 and full_freq_offset != 0: + num_interp = min(int(abs(full_freq_offset)), 256) + time = np.linspace(time[0], time[-1], num_interp) + signal = np.linspace(signal[0], signal[-1], num_interp) + if abs(signal[0]) != 0: # fix strangely looking phase / amplitude in the beginning + signal = np.concatenate([[0], signal]) + time = np.concatenate([[time[0]], time]) + if abs(signal[-1]) != 0: # fix strangely looking phase / amplitude at the end + signal = np.concatenate([signal, [0]]) + time = np.concatenate([time, [time[-1]]]) + + # Compute time vector with delay applied + time_with_delay = t_factor * (t0 + time + rf.delay) + time_center_with_delay = t_factor * (t0 + time_center + rf.delay) + + # Choose plot behavior based on realness of signal + if signal_is_real: + # Plot real part of signal + sp12.plot(time_with_delay, np.real(signal)) + + # Include sign(real(signal)) factor like MATLAB + phase_corrected = ( + signal + * np.sign(np.real(signal)) + * np.exp(1j * full_phase_offset) + * np.exp(1j * 2 * math.pi * time * full_freq_offset) + ) + sc_corrected = ( + signal[index_center] + * np.exp(1j * full_phase_offset) + * np.exp(1j * 2 * math.pi * time[index_center] * full_freq_offset) + ) + + sp13.plot( + time_with_delay, + np.angle(phase_corrected), + time_center_with_delay, + np.angle(sc_corrected), + 'xb', + ) + else: + # Plot magnitude of complex signal + sp12.plot(time_with_delay, np.abs(signal)) + + # Plot angle of complex signal + phase_corrected = ( + signal * np.exp(1j * full_phase_offset) * np.exp(1j * 2 * math.pi * time * full_freq_offset) + ) + sc_corrected = ( + signal[index_center] + * np.exp(1j * full_phase_offset) + * np.exp(1j * 2 * math.pi * time[index_center] * full_freq_offset) + ) + + sp13.plot( + time_with_delay, + np.angle(phase_corrected), + time_center_with_delay, + np.angle(sc_corrected), + 'xb', + ) + + grad_channels = ['gx', 'gy', 'gz'] + for x in range(len(grad_channels)): # Gradients + if getattr(block, grad_channels[x], None) is not None: + grad = getattr(block, grad_channels[x]) + if grad.type == 'grad': + # We extend the shape by adding the first and the last points in an effort of making the + # display a bit less confusing... + time = grad.delay + np.array([0, *grad.tt, grad.shape_dur]) + waveform = g_factor * np.array((grad.first, *grad.waveform, grad.last)) + else: + time = np.array( + cumsum( + 0, + grad.delay, + grad.rise_time, + grad.flat_time, + grad.fall_time, + ) + ) + waveform = g_factor * grad.amplitude * np.array([0, 0, 1, 1, 0]) + fig2_subplots[x].plot(t_factor * (t0 + time), waveform) + t0 += seq.block_durations[block_counter] + + grad_plot_labels = ['x', 'y', 'z'] + sp11.set_ylabel('ADC') + sp12.set_ylabel('RF mag (Hz)') + sp13.set_ylabel('RF/ADC phase (rad)') + sp13.set_xlabel(f't ({time_disp})') + for x in range(3): + _label = grad_plot_labels[x] + fig2_subplots[x].set_ylabel(f'G{_label} ({grad_disp})') + fig2_subplots[-1].set_xlabel(f't ({time_disp})') + + # Setting display limits + disp_range = t_factor * np.array([time_range[0], min(t0, time_range[1])]) + [x.set_xlim(disp_range) for x in [sp11, sp12, sp13, *fig2_subplots]] + + # Grid on + for sp in [sp11, sp12, sp13, *fig2_subplots]: + sp.grid() + + fig1.tight_layout() + fig2.tight_layout() + if save: + fig1.savefig('seq_plot1.jpg') + fig2.savefig('seq_plot2.jpg') + + return fig1, fig2