In [None]:
import torch
import torch.nn as nn
from torchvision import transforms
from fastai.callback.fp16 import to_fp16
from fastai.vision import *
from fastai.data.all import *
from fastai.vision.all import *

# Function definitions

In [None]:
def dl_accuracy(learner, dl, gpu=False):
    """Return the top-1 accuracy of `learner` on the batches of `dl` as a Python float.

    Args:
        learner: object exposing ``get_preds(dl=...)`` that returns
            ``(activations, targets)`` (e.g. a fastai Learner).
        dl: the DataLoader to predict on; forwarded unchanged to ``get_preds``.
        gpu: unused; kept only so existing calls that pass it keep working.

    Returns:
        The fraction of samples whose arg-max over dim 1 of the activations
        equals the target label.
    """
    activations, targets = learner.get_preds(dl=dl)
    # Highest-scoring class for each row of activations.
    predicted = activations.argmax(dim=1)
    # 1.0 for a hit, 0.0 for a miss, in torch's default floating dtype.
    hits = predicted.eq(targets).to(torch.get_default_dtype())
    return hits.mean().item()

In [None]:
def nosplit(o):
    """Splitter that assigns every item to the training subset.

    Returns a ``(train_idxs, valid_idxs)`` pair: ``train_idxs`` holds every
    index of `o` (0 .. len(o)-1) and ``valid_idxs`` is empty.
    """
    every_idx = L(i for i in range(len(o)))
    return every_idx, L()

# Load data

In [None]:
# Report which device is available. `use_gpu` is bound in both branches;
# previously it was only assigned when CUDA was present, so any later use of it
# on a CPU-only machine raised NameError.
use_gpu = torch.cuda.is_available()
if use_gpu:
    gpu = torch.device('cuda')
    print(torch.cuda.get_device_name(0))
else:
    print('No GPU available')

In [None]:
# Training DataBlock: image files labelled by their parent folder name, with a
# seeded random train/validation split (RandomSplitter(seed=42)).
# Presizing: each item is first resized/cropped to 460x460 (item_tfms); each
# batch is then augmented, random-resized-cropped to 224 and normalised with
# ImageNet statistics (batch_tfms).
train_dblock = DataBlock(
    blocks=(ImageBlock, CategoryBlock),
    get_items=get_image_files,
    splitter=RandomSplitter(seed=42),
    get_y=parent_label,
    item_tfms=Resize(460),
    batch_tfms=[
        *aug_transforms(size=224, min_scale=0.75),
        Normalize.from_stats(*imagenet_stats),
    ],
)

In [None]:
# Test DataBlock: same labelling and transforms as the training block, but the
# `nosplit` splitter puts every image in the training subset (validation is empty).
# NOTE(review): because every test image is in the *training* subset, iterating
# `test_dls.train` applies fastai's training-only randomness (random crop position
# in Resize, random flips/rotation/zoom/lighting/warp from aug_transforms). This
# block also builds its own category vocab from the test folders, which only
# matches the training vocab when both contain exactly the same classes. For a
# deterministic evaluation prefer `train_dls.test_dl(items, with_labels=True)`.
test_dblock = DataBlock(
    blocks=(ImageBlock, CategoryBlock),
    get_items=get_image_files,
    splitter=nosplit,
    get_y=parent_label,
    item_tfms=Resize(460),
    batch_tfms=[
        *aug_transforms(size=224, min_scale=0.75),
        Normalize.from_stats(*imagenet_stats),
    ],
)

In [None]:
# Image folders are laid out as <split>/<class name>/<image>.
TRAIN_DIR = 'austrian_birds_dataset/images/train/'
TEST_DIR = 'austrian_birds_dataset/images/test/'

# Batch size 128 for training, 64 for the test folder.
train_dls = train_dblock.dataloaders(TRAIN_DIR, bs=128)
test_dls = test_dblock.dataloaders(TEST_DIR, bs=64)

# ResNet-50

In [None]:
# Transfer-learning setup: a ResNet-50 backbone (fastai loads pretrained weights
# by default) with a classification head whose output size fastai infers from
# `train_dls`; accuracy is reported after every epoch.
resnet50_learner = vision_learner(
    train_dls,
    resnet50,
    metrics=accuracy,
)

In [None]:
# Run fastai's learning-rate finder on the ResNet-50 learner to pick a base LR.
# Two suggestions are returned: `minimum` (LR at the lowest loss, divided by 10)
# and `steep` (LR where the loss decreases fastest).
# (Previously this called `learn.lr_find`, but no `learn` exists -> NameError.)
lr_min, lr_steep = resnet50_learner.lr_find(suggest_funcs=(minimum, steep))
print(f"Minimum/10: {lr_min:.2e}, steepest point: {lr_steep:.2e}")

In [None]:
# Training with fastai's fine_tune schedule: `freeze_epochs` epochs training
# only the new head, then `epochs` epochs with the whole network unfrozen.
# NOTE(review): base_lr is hard-coded; presumably chosen from the lr_find
# output above — confirm it still matches.
resnet50_learner.fine_tune(
    epochs=40,
    freeze_epochs=3,
    base_lr=2e-2,
)

In [None]:
# Check accuracy on the held-out test images.
# Build the loader with `train_dls.test_dl(..., with_labels=True)` rather than
# using `test_dls.train`. The latter is the *training* subset of a `nosplit`
# DataBlock, so it is shuffled and gets random crops/augmentations, and its
# labels use a vocab built separately from the test folders.
# The test_dl instead:
#   - uses the validation transforms of the training pipeline (deterministic
#     center crop, no random augmentation), and
#   - encodes labels with the training vocab, so class indices line up with
#     the model's outputs.
test_eval_dl = train_dls.test_dl(
    get_image_files('austrian_birds_dataset/images/test/'),
    with_labels=True,
)
dl_accuracy(resnet50_learner, test_eval_dl)

In [None]:
# Export the trained learner (model + transforms, DataLoaders emptied) as a
# pickle for inference with `load_learner`.
# - Learner.export already drops the optimizer state and has no `with_opt`
#   parameter, so passing one raised TypeError.
# - export() saves to `learner.path / fname` without creating parent folders,
#   so make sure `models/` exists first.
(resnet50_learner.path / 'models').mkdir(parents=True, exist_ok=True)
resnet50_learner.export('models/bird_classifier_resnet50.pkl')