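## CIFAR-10

Progressive resizing: train a fp16 ResNeXt on CIFAR-10 at 8x8, then 16x16, 24x24 and 32x32 images, saving a checkpoint after each stage.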

In [1]:
# Render plots inline and automatically reload edited library modules.
%matplotlib inline
%reload_ext autoreload
%autoreload 2

In [2]:
# Star import: names used below without an explicit import (os, np, plt,
# BasicModel, ConvLearner, ...) are presumably all provided by it.
from fastai.conv_learner import *
# Root folder for the CIFAR-10 images read by get_data(); created if missing.
PATH = "data/cifar10/"
os.makedirs(PATH,exist_ok=True)

### Downloading CIFAR-10

In [None]:
from fastai import io

In [None]:
import tarfile

In [None]:
def untar_file(file_path, save_path):
    """Extract a .tar.gz / .tgz archive into `save_path`, then delete the archive.

    Files with any other extension are left untouched.

    Args:
        file_path: path of the archive to extract.
        save_path: directory to extract into.

    Returns:
        True if the archive was extracted and removed, False if `file_path`
        does not end in .tar.gz or .tgz (nothing is done in that case).
    """
    import os
    import tarfile

    if not file_path.endswith(('.tar.gz', '.tgz')):
        return False
    # `with` guarantees the archive is closed even if extraction raises.
    with tarfile.open(file_path) as archive:
        # Archives may come straight from the network; where the running Python
        # supports extraction filters, the 'data' filter refuses members that
        # would land outside `save_path` (absolute paths, '..', unsafe links).
        if hasattr(tarfile, 'data_filter'):
            archive.extractall(save_path, filter='data')
        else:
            archive.extractall(save_path)
    os.remove(file_path)
    return True

In [None]:
# Download and extract CIFAR-10 into data/. The archive presumably extracts to
# PATH (data/cifar10/); get_data() reads its 'test' folder as the validation set.
# untar_file() deletes the .tgz after extracting, so without this guard every
# re-run of the notebook would download the whole archive again.
cifar_url = 'http://files.fast.ai/data/cifar10.tgz' # faster download
# cifar_url = 'http://pjreddie.com/media/files/cifar.tgz'
# NOTE(review): plain-HTTP download with no checksum verification.
cifar_tgz = 'data/cifar10.tgz'

if not os.path.isdir(os.path.join(PATH, 'test')):
    io.get_data(cifar_url, cifar_tgz)
    untar_file(cifar_tgz, 'data/')

### Class names and normalization statistics

In [3]:
# Human-readable CIFAR-10 class names (presumably in label-index order;
# confirm against data.classes).
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)
# Per-channel normalisation statistics passed to tfms_from_stats: (mean, std).
stats = (
    np.array([0.4914, 0.48216, 0.44653]),   # mean
    np.array([0.24703, 0.24349, 0.26159]),  # std
)

In [4]:
def get_data(sz, bs):
    """Return fastai ImageClassifierData for CIFAR-10 at image size `sz`.

    Args:
        sz: target image size passed to tfms_from_stats.
        bs: batch size for the data loaders.

    Images are normalised with the module-level `stats` and augmented with
    RandomFlip, using a padding of sz//8 (fastai `pad=` option, presumably
    in pixels). The 'test' folder under PATH is used as the validation set.
    """
    augmentations = [RandomFlip()]
    tfms = tfms_from_stats(stats, sz, aug_tfms=augmentations, pad=sz // 8)
    return ImageClassifierData.from_paths(
        PATH,
        val_name='test',
        tfms=tfms,
        bs=bs,
    )

In [5]:
# Base batch size; the 8x8 and 16x16 stages scale it up (bs*4*4, bs*2).
bs=128

### Look at data

In [None]:
# Tiny batch (4 images at 32x32) just for eyeballing augmented training samples.
data = get_data(32,4)

In [None]:
x,y=next(iter(data.trn_dl))

In [None]:
# denorm presumably reverses the `stats` normalisation so colours display correctly.
plt.imshow(data.trn_ds.denorm(x)[0]);

In [None]:
plt.imshow(data.trn_ds.denorm(x)[1]);

## Initial model (8x8 images)

In [13]:
from fastai.models.cifar10.resnext import resnext29_8_64

# ResNeXt-29 (`resnext29_8_64`), cast to fp16 with .half() and moved to the GPU.
# Because the weights are fp16, each dataset fed to it is converted with
# data.half() as well (see the 8x8 data cell).
m = resnext29_8_64()
# Alternative backbone (swap in to compare): m = resnet50(False)
# The label previously read 'cifar10_resnet50', a copy-paste leftover that did
# not match the model actually being trained.
bm = BasicModel(m.half().cuda(), name='cifar10_resnext29_8_64')

In [14]:
# Stage 1: 8x8 images with a large batch (bs*4*4 = 2048).
data = get_data(8,bs*4*4)

In [15]:
# Does the training loader already produce fp16 batches? (recorded output: False)
data.trn_dl.fp16

False

In [16]:
# Convert the data to fp16 to match the .half() model; the return value is
# unused, so this presumably converts in place.
data.half()

In [None]:
# Debug inspection: pull one batch and look at its type, presumably to confirm
# the fp16 conversion took effect. Safe to delete.
it = iter(data.trn_dl)

In [None]:
x,y = next(it)

In [None]:
type(x)

In [17]:
# Wrap the fp16 model in a ConvLearner and make every layer trainable.
learn = ConvLearner(data, bm)
learn.unfreeze()

In [18]:
# Learning rate and weight decay for the 8x8 stage (wd is passed as wds= to fit).
lr=5e-2; wd=5e-4

In [19]:
# Learning-rate range test. NOTE(review): the recorded output shows
# val_loss = nan, possibly an fp16 numerical issue — TODO investigate.
learn.lr_find()

epoch      trn_loss   val_loss   accuracy                 
    0      5.416119   nan        0.090972  



In [None]:
# Loss vs. learning rate from the lr_find run above.
learn.sched.plot()

In [None]:
# Leftover debugging toggle (turns off the automatic post-mortem debugger);
# not needed for training and safe to delete.
%pdb off

In [20]:
# One epoch with fastai's use_clr=(20,8) schedule. Unlike the later fit calls,
# this one does not pass wds=wd.
%time learn.fit(lr, 1, cycle_len=1, use_clr=(20,8))

epoch      trn_loss   val_loss   accuracy                 
    0      3.641495   2.3375     0.192857  

CPU times: user 1min 6s, sys: 22.1 s, total: 1min 28s
Wall time: 1min 14s


[2.3375, 0.1928572654724121]

In [None]:
%time learn.fit(lr, 2, cycle_len=1)

In [None]:
# Three cycles whose length doubles each time (cycle_mult=2), with weight decay.
%time learn.fit(lr, 3, cycle_len=1, cycle_mult=2, wds=wd)

In [None]:
# Learning-rate schedule used by the last fit.
learn.sched.plot_lr()

In [None]:
# Checkpoint the 8x8 stage; reloaded at the start of the 16x16 section.
learn.save('8x8_8')

## 16x16

In [None]:
# Resume from the 8x8 checkpoint.
learn.load('8x8_8')

In [None]:
# Stage 2: 16x16 images. The model is fp16 (m.half()) and the 8x8 data needed
# data.half(); convert this stage's data too so batches match the weights.
data16 = get_data(16,bs*2)
data16.half()
learn.set_data(data16)

In [None]:
# One epoch at a low learning rate right after the switch to 16x16 images.
%time learn.fit(1e-3, 1, wds=wd)

In [None]:
# Make all layers trainable again (presumably set_data re-froze them).
learn.unfreeze()

In [None]:
# Re-run the learning-rate range test at the new image size.
learn.lr_find()

In [None]:
learn.sched.plot()

In [None]:
# Learning rate chosen from the plot above; also reused by the 24x24 and 32x32 stages.
lr=1e-2

In [None]:
%time learn.fit(lr, 2, cycle_len=1, wds=wd)

In [None]:
%time learn.fit(lr, 3, cycle_len=1, cycle_mult=2, wds=wd)

In [None]:
# Checkpoint the 16x16 stage.
learn.save('16x16_8')

## 24x24

In [None]:
# Resume from the 16x16 checkpoint.
learn.load('16x16_8')

In [None]:
# Stage 3: 24x24 images. The model is fp16 (m.half()) and the 8x8 data needed
# data.half(); convert this stage's data too so batches match the weights.
data24 = get_data(24,bs)
data24.half()
learn.set_data(data24)

In [None]:
# One epoch right after the switch to 24x24 images.
# NOTE(review): the 16x16 stage used 1e-3 here; confirm 1e-2 is intended.
%time learn.fit(1e-2, 1, wds=wd)

In [None]:
# Make all layers trainable again (presumably set_data re-froze them).
learn.unfreeze()

In [None]:
%time learn.fit(lr, 1, cycle_len=1, wds=wd)

In [None]:
%time learn.fit(lr, 3, cycle_len=1, cycle_mult=2, wds=wd)

In [None]:
# Checkpoint the 24x24 stage.
learn.save('24x24_8')

In [None]:
# Test-time augmentation. The predictions are exponentiated and averaged over
# axis 0 (presumably one slice per augmented copy) before scoring.
# The two statements below were previously fused on one line (a SyntaxError).
log_preds,y = learn.TTA()
preds = np.mean(np.exp(log_preds),0)
metrics.log_loss(y,preds), accuracy(preds,y)

## 32x32

In [None]:
# Resume from the 24x24 checkpoint.
learn.load('24x24_8')

In [None]:
# Stage 4: full-resolution 32x32 images. The model is fp16 (m.half()) and the
# 8x8 data needed data.half(); convert this stage's data too so batches match.
data32 = get_data(32,bs)
data32.half()
learn.set_data(data32)

In [None]:
# One epoch right after the switch to 32x32 images.
%time learn.fit(1e-2, 1, wds=wd)

In [None]:
# Make all layers trainable again (presumably set_data re-froze them).
learn.unfreeze()

In [None]:
%time learn.fit(lr, 3, cycle_len=1, cycle_mult=2, wds=wd)

In [None]:
# Three further cycles of 4 epochs each.
%time learn.fit(lr, 3, cycle_len=4, wds=wd)

In [None]:
# Test-time augmentation, scored the same way as the 24x24 stage: exponentiate
# and average over axis 0 (presumably one slice per augmented copy) first.
# Previously the raw stacked log-preds were passed to log_loss / accuracy.
log_preds,y = learn.TTA()
preds = np.mean(np.exp(log_preds),0)
metrics.log_loss(y,preds), accuracy(preds,y)

In [None]:
# Final checkpoint at full 32x32 resolution.
learn.save('32x32_8')