Skip to content

Commit

Permalink
allow for single-dimension data to be plotted via imshow
Browse files Browse the repository at this point in the history
  • Loading branch information
mtlam committed Apr 21, 2021
1 parent 5c707a8 commit a62c627
Showing 1 changed file with 28 additions and 12 deletions.
40 changes: 28 additions & 12 deletions pypulse/archive.py
Original file line number Diff line number Diff line change
Expand Up @@ -1357,8 +1357,12 @@ def imshow(self, ax=None, cbar=False, mask=None, show=True,
"""
data = self.getData(setnan=setnan)
shape = self.shape(squeeze=False)
if len(np.shape(data)) == 2:

if ax is None:
fig = plt.figure()
ax = fig.add_subplot(111)

if len(np.shape(data)) == 2:
extent = None
if shape[0] == 1 and shape[1] == 1:
mode = "freq-phase"
Expand All @@ -1377,9 +1381,6 @@ def imshow(self, ax=None, cbar=False, mask=None, show=True,
else:
raise IndexError("Unimplemented shape for imshow()")

if ax is None:
fig = plt.figure()
ax = fig.add_subplot(111)

#cmap.set_bad(color='k', alpha=1.0)

Expand All @@ -1402,7 +1403,7 @@ def imshow(self, ax=None, cbar=False, mask=None, show=True,
ax.set_ylabel("Time (%s)"%unit)
ax2 = ax.twinx()
ax2.set_ylim(0, self.getNsubint())
ax2.set_ylabel("Subintegration Numer")
ax2.set_ylabel("Subintegration Number")
elif mode == "freq-time":
unit = u.unitchanger(self.getTimeUnit())
ax.set_xlabel("Time (%s)"%unit)
Expand All @@ -1414,17 +1415,32 @@ def imshow(self, ax=None, cbar=False, mask=None, show=True,
ax_time = ax.twiny()
ax_time.set_xlim(0, self.getNsubint())
ax_time.set_xlabel("Subintegration Number")


if cbar:
plt.colorbar()
if filename is not None:
plt.savefig(filename)
if show:
plt.show()
elif len(np.shape(data)) == 1:
# Trust that this is a single subintegration in time?
Tedges = self.getAxis('T', edges=True) #is this true?
extent = [0, 1, Tedges[0], Tedges[-1]]

u.imshow(data[np.newaxis, ...], ax=ax, extent=extent, cmap=cmap, **kwargs)

ax.set_xlabel("Pulse Phase")
unit = u.unitchanger(self.getTimeUnit())
ax.set_ylabel("Time (%s)"%unit)
ax2 = ax.twinx()
ax2.set_ylim(0, self.getNsubint())
ax2.set_ylabel("Subintegration Number")

else:
raise IndexError("Invalid dimensions for plotting")


if cbar:
plt.colorbar()
if filename is not None:
plt.savefig(filename)
if show:
plt.show()

return ax

def pavplot(self, ax=None, mode="GTpd", show=True, wcfreq=True):
Expand Down

0 comments on commit a62c627

Please sign in to comment.