Skip to content

Commit

Permalink
run nbdev_export
Browse files Browse the repository at this point in the history
  • Loading branch information
turbotimon committed Oct 17, 2023
1 parent f121be9 commit a703e4b
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 7 deletions.
2 changes: 1 addition & 1 deletion fastai/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -2691,4 +2691,4 @@
'fastai.vision.widgets._update_children': ( 'vision.widgets.html#_update_children',
'fastai/vision/widgets.py'),
'fastai.vision.widgets.carousel': ('vision.widgets.html#carousel', 'fastai/vision/widgets.py'),
'fastai.vision.widgets.widget': ('vision.widgets.html#widget', 'fastai/vision/widgets.py')}}}
'fastai.vision.widgets.widget': ('vision.widgets.html#widget', 'fastai/vision/widgets.py')}}}
22 changes: 17 additions & 5 deletions fastai/learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -594,13 +594,25 @@ def _valid_mets(self):
if getattr(self, 'cancel_valid', False): return L()
return (L(self.loss) + self.metrics if self.valid_metrics else L())

def plot_loss(self, skip_start=5, with_valid=True):
plt.plot(list(range(skip_start, len(self.losses))), self.losses[skip_start:], label='train')
def plot_loss(self, skip_start=5, with_valid=True, log=False, show_epochs=False, ax=None):
if not ax:
ax=plt.gca()
if log:
ax.loglog(list(range(skip_start, len(self.losses))), self.losses[skip_start:], label='train')
else:
ax.plot(list(range(skip_start, len(self.losses))), self.losses[skip_start:], label='train')
if show_epochs:
for x in self.iters:
ax.axvline(x, color='grey', ls=':')
ax.set_ylabel('loss')
ax.set_xlabel('steps')
ax.set_title('learning curve')
if with_valid:
idx = (np.array(self.iters)<skip_start).sum()
valid_col = self.metric_names.index('valid_loss') - 1
plt.plot(self.iters[idx:], L(self.values[idx:]).itemgot(valid_col), label='valid')
plt.legend()
ax.plot(self.iters[idx:], L(self.values[idx:]).itemgot(valid_col), label='valid')
ax.legend()
return ax

# %% ../nbs/13a_learner.ipynb 136
add_docs(Recorder,
Expand All @@ -610,7 +622,7 @@ def plot_loss(self, skip_start=5, with_valid=True):
after_validate = "Log loss and metric values on the validation set",
after_cancel_train = "Ignore training metrics for this epoch",
after_cancel_validate = "Ignore validation metrics for this epoch",
plot_loss = "Plot the losses from `skip_start` and onward")
plot_loss = "Plot the losses from `skip_start` and onward. Optionally `log=True` for logarithmic axis, `show_epochs=True` for indicate epochs and a matplotlib axis `ax` to plot on.")

if Recorder not in defaults.callbacks: defaults.callbacks.append(Recorder)

Expand Down
2 changes: 1 addition & 1 deletion fastai/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,7 +428,7 @@ def accumulate(self, learn):
c,t = self.get_correct_ngrams(pred, targ, i+1, max_n=self.vocab_sz)
if c == 0:
smooth_mteval *= 2
c = 1 / smooth_mteval # exp smoothing, method 3 from https://aclanthology.org/W14-3346/
c = 1 / smooth_mteval # exp smoothing, method 3 from http://acl2014.org/acl2014/W14-33/pdf/W14-3346.pdf
self.corrects[i] += c
self.counts[i] += t

Expand Down

0 comments on commit a703e4b

Please sign in to comment.