In [1]:
import requests

from fastai.vision.all import *

In [2]:
# Download/extract the fastai Pets dataset and work directly inside its `images` folder.
path = untar_data(URLs.PETS).joinpath('images')

In [3]:
# `path` already points at the `images` folder (cell above), so list it directly.
# The previous `path/'images'` pointed at a non-existent `images/images` directory
# and silently produced an empty list.
fnames = get_image_files(path)
# Breed label = file name before the trailing `_<number>.jpg`. Anchored on the last
# `/` (like the RegexLabeller patterns below) so the directory part of a full path is
# not captured; `\.` makes the extension dot literal.
pat = r'/([^/]+)_\d+\.jpg$'
# Per-item presizing: square (ratio 1:1) random crop resized to 460px.
item_tfms = RandomResizedCrop(460, min_scale=0.75, ratio=(1.,1.))
# Per-batch augmentation down to 224px (max_warp=0: no perspective warp), then
# ImageNet normalization to match the pretrained resnet34 used later.
batch_tfms = [*aug_transforms(size=224, max_warp=0), Normalize.from_stats(*imagenet_stats)]
# NOTE(review): `bs` is not used by the visible cells; they pass bs=32 explicitly.
bs=64

In [4]:
# Single-label baseline: exactly one breed per image, labelled from the file name.
# NOTE(review): `pets` is not used by the visible cells; kept for comparison with `multi_pets`.
pets = DataBlock(
    blocks=(ImageBlock, CategoryBlock),
    get_items=get_image_files,
    get_y=RegexLabeller(pat=r'/([^/]+)_\d+.*'),
    splitter=RandomSplitter(),
    item_tfms=item_tfms,
    batch_tfms=batch_tfms,
)

In [5]:
def label_to_list(o):
    "Wrap the single label `o` in a one-element list, the form fed to MultiCategorize."
    return list((o,))

In [6]:
# Multi-label variant of `pets`: same images and label regex, but each breed is wrapped
# in a one-element list so MultiCategoryBlock sees a (single-member) set of labels.
# Presumably this is so the model may predict *no* label for images outside the vocab
# (see the donkey test at the end) — a single-label model must always pick one.
multi_pets = DataBlock(
    blocks=(ImageBlock, MultiCategoryBlock),
    get_items=get_image_files,
    get_y=Pipeline([
        RegexLabeller(pat=r'/([^/]+)_\d+.*'),  # full path -> breed name
        label_to_list,                          # 'breed' -> ['breed']
    ]),
    splitter=RandomSplitter(),
    item_tfms=item_tfms,
    batch_tfms=batch_tfms,
)

In [7]:
# Build train/valid DataLoaders from the images under `path`, 32 images per batch.
dls = multi_pets.dataloaders(path, bs=32)

In [8]:
# Sanity-check one batch: images together with their decoded labels.
dls.show_batch()

In [9]:
# Fixed seed so the train/valid split (and therefore the validation metric) is the same
# on every Restart & Run All.
# NOTE(review): these are *indices* into the list returned by get_image_files(path); the
# Datasets cell re-lists the directory, so the split is only meaningful if both calls
# return files in the same order.
train_idxs, valid_idxs = RandomSplitter(seed=42)(get_image_files(path))

In [10]:
# Low-level equivalent of `multi_pets`: one transform list per Datasets output.
#   x: file path -> PILImage
#   y: file path -> breed name -> [breed] -> vocab indices -> one-hot vector
# NOTE(review): the vocab is borrowed from the `dls` built by `multi_pets` above. A later
# cell rebinds `dls`, so re-running this cell out of order reads that object's vocab
# instead — presumably identical, but confirm.
tfms = [
    [PILImage.create],
    [
        RegexLabeller(pat = r'/([^/]+)_\d+.*'),
        label_to_list,
        MultiCategorize(vocab=list(dls.vocab)),
        OneHotEncode(len(dls.vocab))
    ]
]

In [11]:
# Run every image path through the x/y pipelines in `tfms`, split by the precomputed indices.
dsets = Datasets(
    items=get_image_files(path),
    tfms=tfms,
    splits=[train_idxs, valid_idxs],
)

In [12]:
# Inspect item 0 of the full dataset: expect an (image, one-hot target) pair.
dsets[0]

In [13]:
# Hand-built DataLoaders reproducing the DataBlock recipe from the config cell:
#   after_item  - per item: to tensor + square (ratio 1:1) random crop to 460px,
#                 matching `item_tfms` (previously missing ratio=(1.,1.), so crops were
#                 non-square and got distorted when resized to 460x460);
#   after_batch - per batch: bytes -> floats, augmentation to 224px, ImageNet stats.
# NOTE: fastai pipelines order transforms by their `order` attribute, not list position.
dls = dsets.dataloaders(
    after_item=[ToTensor(), RandomResizedCrop(460, min_scale=.75, ratio=(1.,1.))],
    after_batch=[IntToFloatTensor(), *aug_transforms(size=224, max_warp=0), Normalize.from_stats(*imagenet_stats)],
    bs=32
)

In [14]:
# Confirm the hand-built pipeline renders the same kind of batch as the DataBlock version.
dls.show_batch()

In [15]:
# Pretrained resnet34 with a head sized from `dls`. Multi-label accuracy counts a class as
# predicted only when its (sigmoid) score exceeds 0.95 — a high bar so only confident
# labels count. The same 0.95 is reused for `learn.loss_func.thresh` further down.
learn = vision_learner(dls, resnet34, metrics=[partial(accuracy_multi, thresh=0.95)])

In [16]:
# Transfer learning: fastai's fine_tune trains the new head with the body frozen first,
# then unfreezes and trains 4 epochs with base learning rate 2e-3.
learn.fine_tune(4, 2e-3)

In [17]:
# Threshold used when decoding predictions (e.g. in learn.predict below); set after
# training, so it does not change the learned weights.
# NOTE(review): assumes the inferred loss is fastai's multi-label BCE loss, whose `thresh`
# drives label decoding — confirm for the installed fastai version.
learn.loss_func.thresh = 0.95

In [18]:
# Out-of-sample test image (a Persian cat) from a third-party CDN — remote URLs can vanish.
PERSIAN_CAT_URL = "https://azure.wgp-cdn.co.uk/app-yourcat/posts/iStock-174776419-1.jpg"

In [19]:
# Fetch the test image. The timeout avoids hanging the notebook on an unresponsive host,
# and raise_for_status() surfaces a clear HTTP error instead of PIL failing to decode an
# HTML error page.
response = requests.get(PERSIAN_CAT_URL, timeout=30)
response.raise_for_status()
im = PILImage.create(response.content)

In [20]:
# Display the downloaded image; the trailing `;` suppresses the returned object's repr.
im.show();

In [21]:
# learn.predict returns a tuple; [0] keeps just the decoded label(s).
learn.predict(im)[0]

In [22]:
# Out-of-distribution check: a donkey is presumably not in `dls.vocab`. With multi-label
# decoding at thresh=0.95 the model is allowed to return no label at all.
DONKEY_URL = "https://cdn.britannica.com/68/143568-050-5246474F/Donkey.jpg"
# Separate names so the cat cell's `response` is not silently overwritten (hidden state),
# and the same timeout/status handling as the cat download.
donkey_response = requests.get(DONKEY_URL, timeout=30)
donkey_response.raise_for_status()
# Decode to a PILImage explicitly, consistent with the cat example, rather than passing raw bytes.
donkey_im = PILImage.create(donkey_response.content)
learn.predict(donkey_im)[0]