In [None]:
from fastai.vision.all import *

from sklearn.model_selection import StratifiedKFold

In [None]:
# Fetch the Oxford-IIIT Pets archive via fastai and list every image under `images/`.
path = untar_data(URLs.PETS)
fnames = get_image_files(path/'images')
# Class name = everything before the trailing "_<number>.jpg". The dot is escaped so
# it matches only a literal "." instead of any character.
# NOTE(review): `pat` is not referenced by the cells visible here, which spell out
# their own regex; confirm whether it is still needed.
pat = r'(.+)_\d+\.jpg$'
# Transforms applied per item (used as `after_item` when the DataLoaders are built).
item_tfms = [RandomResizedCrop(460, min_scale=0.75, ratio=(1.,1.)), ToTensor()]
# Transforms applied per batch (used as `after_batch` when the DataLoaders are built).
batch_tfms = [IntToFloatTensor(), *aug_transforms(size=224, max_warp=0), Normalize.from_stats(*imagenet_stats)]
batch_size = 64

In [None]:
# Hold out 10% of the images as a test set; the remaining 90% feed the k-fold training.
# Sorting first and shuffling a copy with a dedicated, seeded RNG makes the split
# reproducible across kernels and keeps this cell idempotent: `fnames` itself is no
# longer reshuffled every time the cell is re-run.
shuffled_fnames = sorted(fnames)
random.Random(42).shuffle(shuffled_fnames)

n_train = int(len(shuffled_fnames) * .9)
train_fnames = shuffled_fnames[:n_train]
test_fnames = shuffled_fnames[n_train:]

In [None]:
# Class names parsed from the training file paths, reduced to one sorted entry per
# class (previously the list held one entry per file, i.e. thousands of duplicates).
vocab = sorted(set(map(RegexLabeller(pat=r'/([^/]+)_\d+.*'), train_fnames)))

In [None]:
# Two-step labelling pipeline: parse the class name out of a file path, then encode
# it as a category using the shared `vocab`.
label_from_path = RegexLabeller(pat=r'/([^/]+)_\d+.*')
encode_label = Categorize(vocab=vocab)
pipe = Pipeline([label_from_path, encode_label])

In [None]:
# Encoded label for every training file; used below to stratify the folds.
labels = [pipe(fname) for fname in train_fnames]

In [None]:
# One splitter per fold. Each splitter marks that fold's validation indices; a fixed
# `random_state` makes the shuffled folds reproducible from run to run.
# The feature matrix is a zero placeholder because only `labels` drive stratification.
skf = StratifiedKFold(n_splits=10, shuffle=True, random_state=42)
splits = [
    IndexSplitter(valid_indexes)
    for _, valid_indexes in skf.split(np.zeros(len(labels)), labels)
]

In [None]:
# Accumulators filled by `train`: one validation score and one tensor of test-set
# predictions per fold.
valid_pcts, test_preds = [], []

In [None]:
def train(splitter:IndexSplitter):
    """Train one model on the fold described by `splitter` and record its results.

    `splitter` is a callable (as produced by `IndexSplitter`) that is applied to
    `train_fnames` to obtain the train/valid split for this fold.

    Side effects (nothing is returned):
      * appends the second value of `learn.validate()` (the `accuracy` metric)
        to the module-level `valid_pcts`;
      * appends the model's predictions for `test_fnames` to the module-level
        `test_preds`.
    """
    dset = Datasets(
        train_fnames,
        tfms = [
            [PILImage.create],
            # Use the shared `vocab` rather than one inferred from this fold's training
            # subset, so class indices are identical for every fold and match the
            # labels produced by `pipe` for the test set.
            [RegexLabeller(pat=r'/([^/]+)_\d+.*'), Categorize(vocab=vocab)]
        ],
        splits = splitter(train_fnames)
    )
    dls = dset.dataloaders(
        bs=batch_size,
        after_item=item_tfms,
        after_batch=batch_tfms
    )
    learn = vision_learner(dls, resnet34, metrics=accuracy)
    learn.fit_one_cycle(1)
    # validate() returns [loss, *metrics]; index 1 is the single metric, accuracy.
    valid_pcts.append(learn.validate()[1])
    # Score the held-out test files with this fold's model.
    dl = learn.dls.test_dl(test_fnames)
    preds, _ = learn.get_preds(dl=dl)
    test_preds.append(preds)

In [None]:
# Start from empty accumulators so that re-running this cell does not stack a second
# set of fold results on top of the previous run (which would skew the ensemble).
valid_pcts.clear()
test_preds.clear()

# Train one model per fold; `train` appends each fold's results to the accumulators.
for splitter in splits:
    train(splitter)

epoch,train_loss,valid_loss,accuracy,time
0,1.201077,0.350736,0.884384,00:32


epoch,train_loss,valid_loss,accuracy,time
0,1.131611,0.360337,0.885714,00:31


epoch,train_loss,valid_loss,accuracy,time
0,1.140909,0.401147,0.881203,00:31


epoch,train_loss,valid_loss,accuracy,time
0,1.164429,0.397282,0.87218,00:32


epoch,train_loss,valid_loss,accuracy,time
0,1.178342,0.423457,0.879699,00:31


epoch,train_loss,valid_loss,accuracy,time
0,1.143618,0.354578,0.890226,00:31


epoch,train_loss,valid_loss,accuracy,time
0,1.187801,0.380514,0.87218,00:32


epoch,train_loss,valid_loss,accuracy,time
0,1.17831,0.334488,0.890226,00:32


epoch,train_loss,valid_loss,accuracy,time
0,1.172838,0.354474,0.885714,00:32


epoch,train_loss,valid_loss,accuracy,time
0,1.221658,0.352071,0.897744,00:32


In [None]:
# Encode the held-out files' labels with the same pipeline used for training labels.
test_labels = torch.stack(list(map(pipe, test_fnames)))
# Test accuracy of the first fold's model on its own.
accuracy(test_preds[0], test_labels)

TensorBase(0.8877)

In [None]:
# Test accuracy of each fold's model, one line per fold.
for fold_preds in test_preds:
    fold_accuracy = accuracy(fold_preds, test_labels)
    print(fold_accuracy)

TensorBase(0.8877)
TensorBase(0.8904)
TensorBase(0.9039)
TensorBase(0.8904)
TensorBase(0.8796)
TensorBase(0.9039)
TensorBase(0.8945)
TensorBase(0.8863)
TensorBase(0.9012)
TensorBase(0.8904)


In [None]:
# Ensemble by averaging the folds' predictions. Divide by the actual number of fold
# models instead of a hard-coded 5 (ten folds are trained), so `votes` is a true mean.
# The arg-max class, and therefore the accuracy, is unaffected by the scale.
votes = torch.stack(test_preds, dim=-1).sum(-1) / len(test_preds)

In [None]:
# Test accuracy of the combined (ensembled) predictions from all folds.
accuracy(votes, test_labels)

TensorBase(0.9215)