diff --git a/fastai/learner.py b/fastai/learner.py index dabee8898e..c37b7fb45b 100644 --- a/fastai/learner.py +++ b/fastai/learner.py @@ -594,13 +594,25 @@ def _valid_mets(self): if getattr(self, 'cancel_valid', False): return L() return (L(self.loss) + self.metrics if self.valid_metrics else L()) - def plot_loss(self, skip_start=5, with_valid=True): - plt.plot(list(range(skip_start, len(self.losses))), self.losses[skip_start:], label='train') + def plot_loss(self, skip_start=5, with_valid=True, log=False, show_epochs=False, ax=None): + if not ax: + ax=plt.gca() + if log: + ax.loglog(list(range(skip_start, len(self.losses))), self.losses[skip_start:], label='train') + else: + ax.plot(list(range(skip_start, len(self.losses))), self.losses[skip_start:], label='train') + if show_epochs: + for x in self.iters: + ax.axvline(x, color='grey', ls=':') + ax.set_ylabel('loss') + ax.set_xlabel('steps') + ax.set_title('learning curve') if with_valid: idx = (np.array(self.iters)