
Use an already trained Torch model to predict on lots of data #111

Closed
mrocklin opened this issue Oct 10, 2019 · 9 comments · Fixed by #114
@mrocklin (Member) commented Oct 10, 2019

Extending #35, it would be nice to have an example using Dask and Torch together to parallelize prediction. This should be a simple, embarrassingly parallel use case, but I suspect it would be practical for lots of folks.

The challenge, I think, is constructing a simple example that hopefully doesn't get too deep into Torch or a particular dataset. In my ideal world this would be something like

>>> import torchvision
>>> model = torchvision.get_model("model_name")

>>> dataset = get_canned_dataset()
>>> imshow(dataset[0])  # show an example image

>>> model.predict(dataset[0])
"this is a cat"

>>> ...  # then dask things here

Does anyone have good pointers to such a simple case?

cc @stsievert @TomAugspurger @AlbertDeFusco
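A minimal sketch of the "dask things" part: one delayed prediction task per image. The `DummyModel` below is a stand-in for a trained Torch model; from Dask's perspective, any picklable object with a `predict` method works the same way.

```python
import dask
from dask import delayed

class DummyModel:
    """Stand-in for a trained Torch model (hypothetical)."""
    def predict(self, image):
        # Pretend-classifier: decide by mean pixel value.
        return "cat" if sum(image) / len(image) > 0.5 else "dog"

model = DummyModel()

# Pretend these are images; in practice they would be loaded lazily
# on the workers.
images = [[0.9, 0.8, 0.7], [0.1, 0.2, 0.3], [0.6, 0.6, 0.6]]

# Embarrassingly parallel prediction: one task per image.
tasks = [delayed(model.predict)(img) for img in images]
predictions = dask.compute(*tasks)  # -> ("cat", "dog", "cat")
```

With a real model, `model` would be sent to the workers once (it only needs to pickle cleanly) and each task would run `model.predict` on its own chunk of data.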

@stsievert (Member)

cc @muammar

In addition to this example, I'd also link to skorch, a Scikit-learn wrapper for PyTorch, and Dask-ML's ParallelPostFit.

@TomAugspurger TomAugspurger self-assigned this Oct 15, 2019
@TomAugspurger (Member)

Should have an example ready tomorrow.

> hopefully doesn't get too much into Torch or a dataset.

I think we'll want to go into some detail about torch.utils.data.Dataset, because it's not 100% straightforward how to get the data loaded onto workers. To predict for a directory of images, I had to write the following myself:

import glob
import os

import torch
from PIL import Image


def default_loader(path, fs=None):
    # `fs` is any filesystem-like object with an `.open()` method
    # (e.g. s3fs.S3FileSystem); default to the local filesystem.
    opener = fs.open if fs is not None else open
    with opener(path, 'rb') as f:
        img = Image.open(f).convert("RGB")
        return img


class FileDataset(torch.utils.data.Dataset):
    def __init__(self, files, transform=None, target_transform=None,
                 classes=None,
                 loader=default_loader):
        self.files = files
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader
        if classes is None:
            classes = list(sorted(set(x.split(os.path.sep)[-2] for x in files)))
        else:
            classes = list(classes)
        self.classes = classes

    def __len__(self):
        return len(self.files)
    
    def __getitem__(self, index):
        filename = self.files[index]
        img = self.loader(filename)
        target = self.classes.index(filename.split(os.path.sep)[-2])
        
        if self.transform is not None:
            img = self.transform(img)
            
        if self.target_transform is not None:
            target = self.target_transform(target)
        
        return img, target

and use it as

files = glob.glob("hymenoptera_data/val/*/*.jpg")
dataset = FileDataset(files, transform=data_transforms['val'])

For S3, the usage would be FileDataset(files, ..., loader=functools.partial(default_loader, fs=s3fs.S3FileSystem(...))). As a relative newcomer to PyTorch, writing that wasn't 100% straightforward.
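The functools.partial trick above can be sketched without S3 at all. Here the `MemoryFS` class is a hypothetical in-memory stand-in for s3fs.S3FileSystem; the only thing `default_loader` relies on is that `fs` exposes an `.open()` method (the image-decoding step is dropped so the sketch stays self-contained):

```python
import functools
import io

def default_loader(path, fs=None):
    # `fs` is any object with an .open() method; fall back to the
    # builtin open() for the local filesystem.
    opener = fs.open if fs is not None else open
    with opener(path, "rb") as f:
        return f.read()

class MemoryFS:
    """Toy in-memory 'filesystem' with the same .open() shape as s3fs."""
    def __init__(self, files):
        self.files = files
    def open(self, path, mode="rb"):
        return io.BytesIO(self.files[path])

fs = MemoryFS({"bucket/img0.jpg": b"fake-image-bytes"})
loader = functools.partial(default_loader, fs=fs)
data = loader("bucket/img0.jpg")  # -> b"fake-image-bytes"
```

Swapping `MemoryFS(...)` for `s3fs.S3FileSystem(...)` is the whole change needed to point the same Dataset at S3.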

Things seem to be working out well after that. PyTorch models seem to (de)serialize much better than TensorFlow's did last time I tried.

@mrocklin (Member, Author)

Do we need to use the Torch Dataset API here?

> because it's not 100% straightforward how to get the data loaded onto workers

I guess my hope is that, for image data at least, we could just pass around NumPy arrays. So we might create dask.delayed objects using skimage.io.imread or something similar (maybe like https://blog.dask.org/2019/06/20/load-image-data, but stopping before the dask array bit).
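That idea can be sketched as follows, assuming dask and numpy are available. The `load` function here is a hypothetical stand-in for skimage.io.imread (it returns a fixed-size blank image so the sketch runs without image files):

```python
import numpy as np
import dask
import dask.array as da
from dask import delayed

def load(path):
    # Stand-in for skimage.io.imread: returns an HxWx3 uint8 array.
    return np.zeros((32, 32, 3), dtype=np.uint8)

paths = ["img0.jpg", "img1.jpg", "img2.jpg"]

# One delayed task per image; each becomes one chunk of a dask array.
arrays = [
    da.from_delayed(delayed(load)(p), shape=(32, 32, 3), dtype=np.uint8)
    for p in paths
]
stacked = da.stack(arrays)  # lazy array of shape (3, 32, 32, 3)
images = stacked.compute()
```

This is essentially the approach in the linked blog post; prediction could then map over the chunks, sidestepping torch.utils.data.Dataset entirely.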

@mrocklin (Member, Author)

@TomAugspurger (Member) commented Oct 15, 2019 via email

@mrocklin (Member, Author) commented Oct 15, 2019 via email

@stsievert (Member)

I'm curious about Dask Array inputs and model-prediction outputs too. I think PyTorch Datasets will need to play an intermediate role; at least, that's what skorch uses in Net.predict (net.py#L1150).

Distributed training would also be interesting, of course, but my guess is that that's more of an open problem. It's also mentioned in dask/distributed#2581.

@AlbertDeFusco (Contributor)

Skorch looks interesting to me. Can the wrapper be used after loading the model from disk where the wrapper was not used?

I've practiced applying the Dask-ML ParallelPostFit wrapper to a pre-trained model, and I remember having to do a few manual steps before running predictions. I need to dig up that code.

@stsievert (Member) commented Oct 19, 2019

> Can the wrapper be used after loading the model from disk where the wrapper was not used?

Yup. The underlying model is an attribute (.module_), so it's simple:

import torch
from skorch import NeuralNetClassifier

class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        ...

model = Net()
#  Train model

# Save trained model using PyTorch
torch.save(model.state_dict(), "trained_model.pt")

# Use skorch later (not necessarily the training session)
sk_net = NeuralNetClassifier(Net)
sk_net.initialize()

# Load parameters saved with PyTorch
sk_net.module_.load_state_dict(torch.load("trained_model.pt"))
