Have interp use ds_idx, add tests #3332

Merged
merged 2 commits into fastai:master on Apr 30, 2021
Conversation

@muellerzr (Contributor) commented Apr 22, 2021

Summary of Changes

The Interpretation class now calls .new(shuffled=False, drop_last=False) when generating the DataLoader if none was passed in. This mirrors what is done during get_preds, and is needed because shuffling and dropping the last batch cause Interpretation to complain about missing indices (see #3248).

Closes: #3051 and #3248
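
For reference, here's a minimal sketch of the logic described above (my simplification under a hypothetical name, not the verbatim diff):

# Sketch only: `from_learner_sketch` is a hypothetical stand-in for the
# relevant logic inside Interpretation.from_learner. When no DataLoader is
# supplied, rebuild dls[ds_idx] deterministically, mirroring get_preds, so
# the collected preds/targs/losses stay aligned with the dataset's indices.
def from_learner_sketch(learn, ds_idx=1, dl=None):
    if dl is None:
        dl = learn.dls[ds_idx].new(shuffled=False, drop_last=False)
    return learn.get_preds(dl=dl, with_input=True, with_loss=True, with_decoded=True)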

What This Means for End Users

Allows them to specify ds_idx=0 without any errors being raised; for example, see the snippet below.
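
For example, assuming a trained classification learner learn, the following now runs cleanly:

from fastai.interpret import ClassificationInterpretation

# ds_idx=0 targets the training set; previously the shuffled, drop_last
# training DataLoader caused missing-index errors here
interp = ClassificationInterpretation.from_learner(learn, ds_idx=0)
interp.plot_top_losses(9)  # inspect the nine highest-loss training examples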

Added Tests

I've changed the tests for Interpretation because the existing tests were not extensive enough.

Before, the tests were:

from fastai.test_utils import synth_learner
from fastai.interpret import Interpretation
from fastcore.test import test_eq

learn = synth_learner()
interp = Interpretation.from_learner(learn)
# synth_learner's model computes y = a*x + b, so the expected predictions
# and MSE losses can be recomputed directly from its parameters
x,y = learn.dls.valid_ds.tensors
test_eq(interp.inputs, x)
test_eq(interp.targs, y)
out = learn.model.a * x + learn.model.b
test_eq(interp.preds, out)
test_eq(interp.losses, (out-y)[:,0]**2)

While they get the job done, they aren't flexible enough to verify that both the training and validation datasets work, and that decoding (along with the rest of the pipeline) behaves correctly.

As a result, I've instead switched to using a subsample of the MNIST_SAMPLE dataset:

#hide
from fastai.vision.all import *
# Build a real DataBlock pipeline; RandomSubsetSplitter keeps 10% of the
# items for training and 10% for validation so the test stays fast
mnist = DataBlock(blocks=(ImageBlock(cls=PILImageBW), CategoryBlock), 
                  get_items=get_image_files, 
                  splitter=RandomSubsetSplitter(.1, .1, seed=42),
                  get_y=parent_label)
test_dls = mnist.dataloaders(untar_data(URLs.MNIST_SAMPLE), bs=8)
test_learner = cnn_learner(test_dls, resnet18)

Now here are the associated tests:

#hide
from fastcore.test import test_eq, test_close
interp = Interpretation.from_learner(test_learner)
# Rebuild inputs, targets, and predictions by hand from the validation
# DataLoader, then check they match what Interpretation collected
x, y, out = [], [], []
for batch in test_learner.dls.valid:
    x += batch[0]
    y += batch[1]
    out += test_learner.model(batch[0])
x,y,out = torch.stack(x), torch.stack(y, dim=0), torch.stack(out, dim=0)
test_eq(interp.inputs, x)
test_eq(interp.targs, y)
# Compare per-item losses with test_close to allow floating-point tolerance
losses = torch.stack([test_learner.loss_func(p,t) for p,t in zip(out,y)], dim=0)
test_close(interp.losses, losses)
#hide
# Dummy test to ensure we can run on the training set: same check as above,
# but with ds_idx=0 and a deterministic copy of the training DataLoader
# (shuffle=False, drop_last=False) as the reference
interp = Interpretation.from_learner(test_learner, ds_idx=0)
x, y, out = [], [], []
for batch in test_learner.dls.train.new(drop_last=False, shuffle=False):
    x += batch[0]
    y += batch[1]
    out += test_learner.model(batch[0])
x,y,out = torch.stack(x), torch.stack(y, dim=0), torch.stack(out, dim=0)
test_eq(interp.inputs, x)
test_eq(interp.targs, y)
losses = torch.stack([test_learner.loss_func(p,t) for p,t in zip(out,y)], dim=0)
test_close(interp.losses, losses)

It's a bit longer, but it covers everything the previous test did while also exercising the true DataBlock API, and it adds a test for ds_idx=0.

Let me know if there are any suggestions, such as keeping the first test in addition to these, or marking these with #slow.

cc @jph00 and @hamelsmu

@muellerzr requested a review from jph00 as a code owner on April 22, 2021

@hamelsmu (Member) commented:

I'll let Jeremy review this one

@jph00 (Member) commented Apr 30, 2021:

Many thanks @muellerzr !

@jph00 merged commit 71a35a3 into fastai:master on Apr 30, 2021
@hamelsmu changed the title from "Have interp safely use ds_idx, add tests" to "Have interp use ds_idx, add tests" on Apr 30, 2021
Successfully merging this pull request may close these issues:

ClassificationInterpretation.from_learner not working on training set