diff --git a/fastai/_modidx.py b/fastai/_modidx.py index f86fda8c9f..2bea67ec8a 100644 --- a/fastai/_modidx.py +++ b/fastai/_modidx.py @@ -2691,4 +2691,4 @@ 'fastai.vision.widgets._update_children': ( 'vision.widgets.html#_update_children', 'fastai/vision/widgets.py'), 'fastai.vision.widgets.carousel': ('vision.widgets.html#carousel', 'fastai/vision/widgets.py'), - 'fastai.vision.widgets.widget': ('vision.widgets.html#widget', 'fastai/vision/widgets.py')}}} + 'fastai.vision.widgets.widget': ('vision.widgets.html#widget', 'fastai/vision/widgets.py')}}} \ No newline at end of file diff --git a/fastai/learner.py b/fastai/learner.py index dabee8898e..c37b7fb45b 100644 --- a/fastai/learner.py +++ b/fastai/learner.py @@ -594,13 +594,25 @@ def _valid_mets(self): if getattr(self, 'cancel_valid', False): return L() return (L(self.loss) + self.metrics if self.valid_metrics else L()) - def plot_loss(self, skip_start=5, with_valid=True): - plt.plot(list(range(skip_start, len(self.losses))), self.losses[skip_start:], label='train') + def plot_loss(self, skip_start=5, with_valid=True, log=False, show_epochs=False, ax=None): + if not ax: + ax=plt.gca() + if log: + ax.loglog(list(range(skip_start, len(self.losses))), self.losses[skip_start:], label='train') + else: + ax.plot(list(range(skip_start, len(self.losses))), self.losses[skip_start:], label='train') + if show_epochs: + for x in self.iters: + ax.axvline(x, color='grey', ls=':') + ax.set_ylabel('loss') + ax.set_xlabel('steps') + ax.set_title('learning curve') if with_valid: idx = (np.array(self.iters)