In [None]:
import os
import random

# Pin the notebook to a single GPU. Environment variables like this are read
# when CUDA is first initialised, so keep this in the very first cell.
GPU_id = 0
os.environ['CUDA_VISIBLE_DEVICES'] = str(GPU_id)

# Seed every RNG the training pipeline may draw from (e.g. data shuffling,
# random augmentation, new-layer initialisation) so Restart & Run All gives
# comparable results. NOTE: cuDNN kernels can still be non-deterministic on GPU,
# so expect small run-to-run differences rather than bitwise equality.
SEED = 42
import numpy as np  # imported after CUDA_VISIBLE_DEVICES is set, on purpose
import torch
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
import warnings
# NOTE(review): this silences *every* warning for the rest of the kernel
# session (including deprecation warnings from fastai/torch); consider a
# narrower filter so real problems are not hidden.
warnings.filterwarnings("ignore")

import fastai
# Record the installed fastai version; the API used below (ImageList,
# cnn_learner, fastai.callbacks) presumably targets fastai v1 — confirm.
print(fastai.__version__)
# Star import: supplies the names used later without an explicit import
# (Path, ImageList, get_transforms, imagenet_stats, cnn_learner, models,
# accuracy, plot_multi, np).
from fastai.vision import *
from fastai.callbacks import SaveModelCallback
import time  # NOTE(review): appears unused in the cells shown (`%%time` is a magic)

### Create a Path instance

In [None]:
# Dataset root, relative to the current working directory; the `train/` and
# `val/` sub-folders used below live directly under it.
path = Path('../input') / 'hymenoptera-data' / 'hymenoptera_data'
print(type(path))
path.ls()

In [None]:
# Peek at what the training folder contains.
path.joinpath('train').ls()

### Create an ImageList instance

In [None]:
# Collect the image files found under `path` (this must include both train/
# and val/, since the list is split by those folders below).
il = ImageList.from_folder(path)
# Show the first collected item (presumably its file path — confirm).
il.items[0]

In [None]:
# Rich-display the ImageList to see what was collected.
il

In [None]:
# Load and display the first image in the list.
il[0].show()

### Split the item list into training and validation sets

In [None]:
# Assign items to training/validation by parent folder: `train/` -> train,
# `val/` -> valid.
sd = il.split_by_folder(train='train', valid='val')
sd

### Create a label list

In [None]:
# Label every image from the folder it sits in (presumably one sub-folder per
# class under train/ and val/ — confirm with the listing above).
ll = sd.label_from_folder()
ll

### Show an image with label

In [None]:
%%time
# Inspect the first training sample before any transforms are attached:
# `y` is its label, `x.shape` its tensor shape (presumably C x H x W).
x,y = ll.train[0]
x.show()
print(y,x.shape)

### Apply transformations

In [None]:
# Standard fastai augmentation set, with a larger rotation limit than default.
MAX_ROTATE = 25
tfms = get_transforms(max_rotate=MAX_ROTATE)
# Presumably a (train_tfms, valid_tfms) pair — confirm from the displayed length.
len(tfms)

In [None]:
# Attach the augmentations and the target image size to the labelled lists.
# NOTE(review): this rebinds `ll`; re-running the cell operates on the
# already-transformed object rather than the one built in the labelling cell.
IMG_SIZE = 224
ll = ll.transform(tfms, size=IMG_SIZE)

In [None]:
%%time
# Same inspection as before, now with the transforms attached: the printed
# shape should reflect the requested size. Per the section below, each access
# presumably draws a fresh random augmentation.
x,y = ll.train[0]
x.show()
print(y,x.shape)

### Create a DataBunch

In [None]:
%%time
bs = 32
data = ll.databunch(bs=bs).normalize(imagenet_stats)

In [None]:
# Indexing the DataBunch's training dataset yields an (image, label) pair.
x,y = data.train_ds[0]
x.show()
print(y)

### Show random transformations of the same image

In [None]:
def _plot(i, j, ax):
    """Draw one view of the first training image on `ax`.

    `plot_multi` supplies the grid position (`i`, `j`); it is unused because
    every panel shows the same source image. `data.train_ds[0]` is re-read on
    each call so that, per this section's heading, each panel shows a
    different random transformation.
    """
    sample_img, _ = data.train_ds[0]
    sample_img.show(ax)


plot_multi(_plot, 3, 3, figsize=(8, 8))

### Show a batch of images with their labels

In [None]:
# Pull a single batch to check tensor shapes (presumably `xb` = images,
# `yb` = labels), then render a grid of labelled samples.
xb,yb = data.one_batch()
print(xb.shape,yb.shape)
data.show_batch(rows=3, figsize=(10,8))

### Create a CNN learner

In [None]:
%%time
# CNN classifier on a resnet18 backbone, tracking the `accuracy` metric.
learn = cnn_learner(data, models.resnet18, metrics=accuracy)
# Save models under /kaggle/working rather than the default location
# (presumably relative to the dataset path, which is likely read-only on
# Kaggle — confirm).
learn.model_dir = '/kaggle/working/models'

### Find a suitable learning rate

In [None]:
# Working directory that the relative dataset path ('../input/...') resolves
# against. Pure-Python replacement for the `!pwd` shell escape: it is portable
# and its value can be reused in code.
Path.cwd()

In [None]:
# Run the learning-rate range test; its results are read back from
# `learn.recorder` in the next cell.
learn.lr_find()

In [None]:
# Plot the range-test results (presumably loss vs. learning rate) to choose max_lr.
learn.recorder.plot()

### Train the model

In [None]:
# One-cycle training. A stop-only `slice` asks fastai for per-layer-group
# learning rates (in fastai v1, presumably stop/10 for earlier groups and
# `stop` for the head — confirm against `Learner.lr_range`).
N_EPOCHS = 10
MAX_LR = slice(0.007)
# Checkpoint whenever the monitored `accuracy` improves.
best_model_saver = SaveModelCallback(learn, every='improvement', monitor='accuracy')
learn.fit_one_cycle(N_EPOCHS, max_lr=MAX_LR, callbacks=[best_model_saver])

In [None]:
# Predictions and targets (presumably for the validation set, the default
# dataset for get_preds — confirm). The next cell treats `pred` as per-class
# scores (argmax over axis 1) and `truth` as class indices.
pred, truth = learn.get_preds()

In [None]:
# Recompute accuracy by hand from the raw predictions, as a cross-check of the
# `accuracy` metric reported during training.
# New names keep the cell idempotent: re-running it must not call `.numpy()`
# on a value that is already a numpy array, which is what the earlier
# `pred = pred.numpy()` rebinding caused.
pred_np = pred.numpy()
truth_np = truth.numpy()
acc = np.mean(np.argmax(pred_np, axis=1) == truth_np)
print('Validation Accuracy %.4f' % acc)