In [1]:
from fastai.vision.all import *

from sklearn.model_selection import StratifiedKFold

In [2]:
# --- Data location and transform configuration ---
# Resolve the PETS dataset and list every image file under its `images/` folder.
path = untar_data(URLs.PETS)
fnames = get_image_files(path/'images')

# Filename pattern with one capture group ahead of the trailing `_<digits>`.
# NOTE(review): the other cells shown here spell out their own RegexLabeller
# pattern and never read `pat` — confirm whether it is still needed.
pat = r'(.+)_\d+.jpg$'

batch_size = 64

# Transforms applied to each item before batching.
item_tfms = [
    RandomResizedCrop(460, min_scale=0.75, ratio=(1.,1.)),
    ToTensor(),
]

# Transforms applied to each collated batch.
batch_tfms = [
    IntToFloatTensor(),
    *aug_transforms(size=224, max_warp=0),
    Normalize.from_stats(*imagenet_stats),
]

In [3]:
# Hold out 10% of the images as a test set that no fold ever trains on.
# Seed the RNG first so the shuffle — and therefore the train/test split — is
# the same on every Restart & Run All (for a given file listing order).
random.seed(42)
random.shuffle(fnames)

# Compute the 90% cut point once and slice on both sides of it; `list(...)`
# keeps both halves as plain Python lists.
n_train = int(len(fnames) * .9)
train_fnames = list(fnames[:n_train])
test_fnames = list(fnames[n_train:])

In [4]:
# Class vocabulary: the distinct label names found in the training filenames.
# Mapping the labeller alone yields one (repeated) entry per file, so collapse
# the result to a sorted set to keep exactly one entry per class.
vocab = sorted(set(map(RegexLabeller(pat=r'/([^/]+)_\d+.*'), train_fnames)))

In [5]:
# Filename -> encoded category, in two steps.
pipe = Pipeline(
    [
        RegexLabeller(pat=r'/([^/]+)_\d+.*'),  # pull the label name out of the path
        Categorize(vocab=vocab),               # encode it against the fixed `vocab`
    ]
)

In [6]:
# Encoded label for every training file, in the same order as `train_fnames`.
labels = [pipe(fname) for fname in train_fnames]

In [7]:
# One validation-index splitter per fold of a stratified 10-fold split.
# `random_state` pins the shuffle so the folds are identical on every run.
skf = StratifiedKFold(n_splits=10, shuffle=True, random_state=42)

# Only `labels` drives the stratification; the zeros array is just a
# placeholder for X with the right number of samples.
splits = [
    IndexSplitter(valid_indexes)
    for _, valid_indexes in skf.split(np.zeros(len(labels)), labels)
]

In [8]:
# Accumulators filled by `train`: one validation accuracy and one tensor of
# test-set predictions per fold.
valid_pcts, test_preds = [], []

In [9]:
def train(splitter:IndexSplitter):
    """Train one model on the fold described by `splitter` and record its results.

    `splitter` is called on `train_fnames` to obtain the train/validation
    split. After a single one-cycle epoch, the validation accuracy is appended
    to the module-level `valid_pcts` and the predictions for `test_fnames` are
    appended to the module-level `test_preds`. Nothing is returned.
    """
    dset = Datasets(
        train_fnames,
        tfms = [
            [PILImage.create],
            # Pass the shared `vocab` explicitly so the class-index mapping is
            # the same for every fold and matches `pipe`, which encodes the
            # test labels, rather than letting Categorize infer one per fold.
            [RegexLabeller(pat=r'/([^/]+)_\d+.*'), Categorize(vocab=vocab)]
        ],
        splits = splitter(train_fnames)
    )
    dls = dset.dataloaders(
        bs=batch_size,
        after_item=item_tfms,
        after_batch=batch_tfms
    )
    learn = vision_learner(dls, resnet34, metrics=accuracy)
    learn.fit_one_cycle(1)
    # Index 1 of the validate() result is the `accuracy` metric (index 0 is the loss).
    valid_pcts.append(learn.validate()[1])
    # Score the held-out test files with this fold's model.
    test_dl = learn.dls.test_dl(test_fnames)
    preds, _ = learn.get_preds(dl=test_dl)
    test_preds.append(preds)

In [10]:
# Start from empty accumulators so re-running this cell replaces the previous
# results instead of appending a second set of folds to them.
valid_pcts.clear()
test_preds.clear()

# Train one model per fold; `train` appends to `valid_pcts` and `test_preds`.
for splitter in splits:
    train(splitter)

In [11]:
# Encode the held-out files' ground-truth labels with the same label pipeline
# and stack them into a single tensor.
test_labels = torch.stack(list(map(pipe, test_fnames)))

# Test accuracy of the first fold's model on its own.
accuracy(test_preds[0], test_labels)

TensorBase(0.8863)

In [12]:
# Test accuracy of each fold's model individually, one line per fold.
for preds in test_preds:
    fold_accuracy = accuracy(preds, test_labels)
    print(fold_accuracy)

TensorBase(0.8863)
TensorBase(0.8890)
TensorBase(0.8945)
TensorBase(0.8850)
TensorBase(0.8850)
TensorBase(0.8958)
TensorBase(0.8931)
TensorBase(0.8985)
TensorBase(0.8985)
TensorBase(0.8985)


In [13]:
# Ensemble by averaging the per-fold predictions. Stack along a new last axis
# and take the mean over it, so the divisor always equals the number of folds
# actually collected (the previous hard-coded 5 did not match the 10 folds).
votes = torch.stack(test_preds, dim=-1).mean(-1)

In [14]:
# Test accuracy of the combined (ensembled) predictions across all folds.
accuracy(votes, test_labels)

TensorBase(0.9215)