From a62c6275da36787f6220b244bf4e2ef9c53ce8dc Mon Sep 17 00:00:00 2001 From: Michael Lam Date: Wed, 21 Apr 2021 00:40:43 -0400 Subject: [PATCH] allow for single-dimension data to be plotted via imshow --- pypulse/archive.py | 40 ++++++++++++++++++++++++++++------------ 1 file changed, 28 insertions(+), 12 deletions(-) diff --git a/pypulse/archive.py b/pypulse/archive.py index 9ba1da5..26c5f5d 100644 --- a/pypulse/archive.py +++ b/pypulse/archive.py @@ -1357,8 +1357,12 @@ def imshow(self, ax=None, cbar=False, mask=None, show=True, """ data = self.getData(setnan=setnan) shape = self.shape(squeeze=False) - if len(np.shape(data)) == 2: + if ax is None: + fig = plt.figure() + ax = fig.add_subplot(111) + + if len(np.shape(data)) == 2: extent = None if shape[0] == 1 and shape[1] == 1: mode = "freq-phase" @@ -1377,9 +1381,6 @@ def imshow(self, ax=None, cbar=False, mask=None, show=True, else: raise IndexError("Unimplemented shape for imshow()") - if ax is None: - fig = plt.figure() - ax = fig.add_subplot(111) #cmap.set_bad(color='k', alpha=1.0) @@ -1402,7 +1403,7 @@ def imshow(self, ax=None, cbar=False, mask=None, show=True, ax.set_ylabel("Time (%s)"%unit) ax2 = ax.twinx() ax2.set_ylim(0, self.getNsubint()) - ax2.set_ylabel("Subintegration Numer") + ax2.set_ylabel("Subintegration Number") elif mode == "freq-time": unit = u.unitchanger(self.getTimeUnit()) ax.set_xlabel("Time (%s)"%unit) @@ -1414,17 +1415,32 @@ def imshow(self, ax=None, cbar=False, mask=None, show=True, ax_time = ax.twiny() ax_time.set_xlim(0, self.getNsubint()) ax_time.set_xlabel("Subintegration Number") - - if cbar: - plt.colorbar() - if filename is not None: - plt.savefig(filename) - if show: - plt.show() + elif len(np.shape(data)) == 1: + # Trust that this is a single subintegration in time? + Tedges = self.getAxis('T', edges=True) #is this true? + extent = [0, 1, Tedges[0], Tedges[-1]] + u.imshow(data[np.newaxis, ...], ax=ax, extent=extent, cmap=cmap, **kwargs) + + ax.set_xlabel("Pulse Phase") + unit = u.unitchanger(self.getTimeUnit()) + ax.set_ylabel("Time (%s)"%unit) + ax2 = ax.twinx() + ax2.set_ylim(0, self.getNsubint()) + ax2.set_ylabel("Subintegration Number") + else: raise IndexError("Invalid dimensions for plotting") + + + if cbar: + plt.colorbar() + if filename is not None: + plt.savefig(filename) + if show: + plt.show() + return ax def pavplot(self, ax=None, mode="GTpd", show=True, wcfreq=True):