In [1]:
%matplotlib inline
import matplotlib.pyplot as plt
from fastai.vision import *
from fastai.metrics import accuracy
from fastai.basic_data import *
from skimage.util import montage
import pandas as pd
from torch import optim
import re

from utils import *

In [2]:
import fastai
from fastprogress import force_console_behavior
import fastprogress
fastprogress.fastprogress.NO_BAR = True
master_bar, progress_bar = force_console_behavior()
fastai.basic_train.master_bar, fastai.basic_train.progress_bar = master_bar, progress_bar

df = pd.read_csv('../input/train.csv')
val_fns = {'69823499d.jpg'}

In [3]:
fn2label = {row[1].Image: row[1].Id for row in df.iterrows()}
path2fn = lambda path: re.search('\w*\.jpg$', path).group(0)

In [11]:
name = f'res50-full-train'

SZ = 224
BS = 32
NUM_WORKERS = 0
SEED=0

In [12]:
data = (
    ImageItemList
        .from_df(df[df.Id != 'new_whale'], '../input/train', cols=['Image'])
        .split_by_valid_func(lambda path: path2fn(path) in val_fns)
        .label_from_func(lambda path: fn2label[path2fn(path)])
        .add_test(ImageItemList.from_folder('../input/test'))
        .transform(get_transforms(do_flip=False), size=SZ, resize_method=ResizeMethod.SQUISH)
        .databunch(bs=BS, num_workers=NUM_WORKERS, path='../input')
        .normalize(imagenet_stats)
)

In [13]:
%%time
MODEL_PATH = "/tmp/model/"
learn = create_cnn(data, models.resnet50, lin_ftrs=[2048], model_dir=MODEL_PATH)
learn.clip_grad();

CPU times: user 1.04 s, sys: 320 ms, total: 1.36 s
Wall time: 1.29 s


In [16]:
%%time
learn.fit_one_cycle(14, 1e-2)

epoch     train_loss  valid_loss


KeyboardInterrupt: 

In [17]:
max_lr = 1e-3
lrs = [max_lr/100, max_lr/10, max_lr]

learn.fit_one_cycle(24, lrs)#24
learn.save(f'{name}-stage-2')

epoch     train_loss  valid_loss
1         5.282892    0.029420    


In [18]:
SZ = 224 * 2
BS = 64 // 4
NUM_WORKERS = 0
SEED=0

In [19]:
data = (
    ImageItemList
        .from_df(df[df.Id != 'new_whale'], '../input/train', cols=['Image'])
        .split_by_valid_func(lambda path: path2fn(path) in val_fns)
        .label_from_func(lambda path: fn2label[path2fn(path)])
        .add_test(ImageItemList.from_folder('../input/test'))
        .transform(get_transforms(do_flip=False), size=SZ, resize_method=ResizeMethod.SQUISH)
        .databunch(bs=BS, num_workers=NUM_WORKERS, path='../input')
        .normalize(imagenet_stats)
)

In [20]:
%%time
learn = create_cnn(data, models.resnet50, lin_ftrs=[2048], model_dir=MODEL_PATH)
learn.clip_grad();
learn.load(f'{name}-stage-2')
learn.freeze_to(-1)

learn.fit_one_cycle(12, 1e-2 / 4)#12
learn.save(f'{name}-stage-3')

learn.unfreeze()

max_lr = 1e-3 / 4
lrs = [max_lr/100, max_lr/10, max_lr]

learn.fit_one_cycle(24, lrs)#22
learn.save(f'{name}-stage-4')

epoch     train_loss  valid_loss
1         6.013852    0.005453    
epoch     train_loss  valid_loss
1         5.170578    0.001835    
CPU times: user 44min 11s, sys: 14min 56s, total: 59min 8s
Wall time: 58min 56s


In [21]:
preds, _ = learn.get_preds(DatasetType.Test)

In [24]:
!git clone https://github.com/radekosmulski/whale

Cloning into 'whale'...
remote: Enumerating objects: 22, done.[K
remote: Counting objects: 100% (22/22), done.[K
remote: Compressing objects: 100% (14/14), done.[K
remote: Total 22 (delta 7), reused 19 (delta 6), pack-reused 0[K
Unpacking objects: 100% (22/22), done.


In [25]:
from whale.utils import *

def create_submission(preds, data, name, classes=None):
    if not classes: classes = data.classes
    sub = pd.DataFrame({'Image': [path.name for path in data.test_ds.x.items]})
    sub['Id'] = top_5_pred_labels(preds, classes)
    sub.to_csv(f'{name}.csv', index=False) # compression='gzip'
    
create_submission(preds, learn.data, name)

pd.read_csv(f'{name}.csv').head()


Unnamed: 0,Image,Id
0,f28e2a7e7.jpg,w_a6703dd w_d0bfef3 w_23a388d w_0369a5c w_d0475b2
1,f1a620ed9.jpg,w_9438119 w_242fb46 w_01f14e1 w_025911c w_b950c88
2,1613db994.jpg,w_1f0cf0a w_8a1786f w_6c3ec2d w_5e8e218 w_eaed433
3,dac7f10b4.jpg,w_67a9841 w_5d6ba39 w_04003e9 w_f602022 w_fccccec
4,777e2025a.jpg,w_89f6097 w_51fc1fc w_9b5109b w_31b5dd8 w_8e88f4f


credit: [https://github.com/radekosmulski/whale](https://github.com/radekosmulski/whale)