In [1]:
%matplotlib inline
import matplotlib.pyplot as plt
from fastai.vision import *
from fastai.metrics import accuracy
from fastai.basic_data import *
from skimage.util import montage
import pandas as pd
from torch import optim
import re

from utils import *

I take a curriculum approach to training here. I first expose the model to as many different images of whales as quickly as possible (no oversampling) and train on images resized to 224x224.

I would like the conv layers to start picking up on features useful for identifying whales. For that, I want to show the model as rich a dataset as possible.

I then train on images resized to 448x448.

Finally, I train on oversampled data. Here, the model sees some images more often than others, which I hope will help alleviate the class imbalance in the training data.

In [2]:
import fastai
import fastprogress
from fastprogress import force_console_behavior

# Plain console logging instead of notebook progress widgets: switch off the bar
# first, then install the console master/progress bars into fastai's training loop.
fastprogress.fastprogress.NO_BAR = True
master_bar, progress_bar = force_console_behavior()
fastai.basic_train.master_bar = master_bar
fastai.basic_train.progress_bar = progress_bar

In [4]:
# A single image is held out, so (almost) all labelled data is used for training;
# validation metrics computed on this one image carry little signal.
val_fns = {'69823499d.jpg'}

df = pd.read_csv('../input/train.csv')

In [5]:
# Filename -> whale Id lookup used by label_from_func in the databunch cells.
# Built straight from the two columns rather than looping with iterrows(), which is
# slow and materialises a Series per row (later duplicates still win, as before).
fn2label = dict(zip(df.Image, df.Id))
def path2fn(path):
    """Return the trailing ``<name>.jpg`` filename of an item path.

    Accepts a str or any path-like object. Raises AttributeError when the path
    does not end in ``.jpg`` (no regex match).
    """
    # Raw string: the backslash-w in a plain string literal is an invalid escape
    # sequence (DeprecationWarning, SyntaxWarning since Python 3.12).
    return re.search(r'\w*\.jpg$', str(path)).group(0)

In [6]:
# Run name: prefix for saved checkpoints and the submission file / message.
name = 'res50-full-train'

In [7]:
import random

import numpy as np
import torch

SZ = 224          # images are squished to SZ x SZ for the first curriculum stage
BS = 32           # batch size passed to databunch()
NUM_WORKERS = 6   # passed to databunch(num_workers=...)
SEED = 0

# SEED used to be declared but never applied. Seed every RNG involved so that
# augmentation, head initialisation and shuffling are more repeatable across reruns
# (full determinism with GPU kernels / worker processes is not guaranteed).
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [10]:
# Identified whales only: 'new_whale' rows are not a trainable identity here; the
# prediction cells append 'new_whale' as an extra output column instead.
known_whales = df[df.Id != 'new_whale']

data = (ImageList.from_df(known_whales, '../input/train', cols=['Image'])
        .split_by_valid_func(lambda path: path2fn(path) in val_fns)
        .label_from_func(lambda path: fn2label[path2fn(path)])
        .add_test(ImageList.from_folder('../input/test'))
        # do_flip=False: presumably because a mirrored fluke would no longer
        # match the same individual.
        .transform(get_transforms(do_flip=False), size=SZ, resize_method=ResizeMethod.SQUISH)
        .databunch(bs=BS, num_workers=NUM_WORKERS, path='../')
        .normalize(imagenet_stats))

In [12]:
%%time

learn = create_cnn(data, models.resnet50, lin_ftrs=[2048])
learn.clip_grad();

#learn.fit_one_cycle(14, 1e-2)
#learn.save(f'{name}-stage-1')

#learn.unfreeze()

max_lr = 1e-3
lrs = [max_lr/100, max_lr/10, max_lr]

learn.fit_one_cycle(24, lrs)
learn.save(f'{name}-stage-2')

  warn("`create_cnn` is deprecated and is now named `cnn_learner`.")


epoch     train_loss  valid_loss  time    


KeyboardInterrupt: 

In [6]:
import random

import numpy as np
import torch

SZ = 224 * 2      # second curriculum stage: 448 x 448 images
BS = 64 // 4      # smaller batch to fit the larger images
NUM_WORKERS = 12  # passed to databunch(num_workers=...)
SEED = 0

# Apply SEED (previously declared but unused) to every RNG involved in training.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [7]:
# 448px databunch (same split and labels as the 224px one). Uses ImageList, the
# class name the 224px cell already uses; ImageItemList is fastai's deprecated
# old name for it.
data = (
    ImageList
        .from_df(df[df.Id != 'new_whale'], 'data/train', cols=['Image'])
        .split_by_valid_func(lambda path: path2fn(path) in val_fns)
        .label_from_func(lambda path: fn2label[path2fn(path)])
        .add_test(ImageList.from_folder('data/test'))
        .transform(get_transforms(do_flip=False), size=SZ, resize_method=ResizeMethod.SQUISH)
        .databunch(bs=BS, num_workers=NUM_WORKERS, path='data')
        .normalize(imagenet_stats)
)

In [8]:
%%time
learn = create_cnn(data, models.resnet50, lin_ftrs=[2048])
learn.clip_grad();
learn.load(f'{name}-stage-2')
learn.freeze_to(-1)

learn.fit_one_cycle(12, 1e-2 / 4)
learn.save(f'{name}-stage-3')

learn.unfreeze()

max_lr = 1e-3 / 4
lrs = [max_lr/100, max_lr/10, max_lr]

learn.fit_one_cycle(22, lrs)
learn.save(f'{name}-stage-4')

epoch     train_loss  valid_loss
1         1.100031    0.000000    
3         1.335055    0.000000    
4         1.674122    0.000000    
5         1.785136    0.000000    
6         1.717228    0.000000    
7         1.412960    0.000000    
8         1.303269    0.000000    
9         1.008257    0.000000    
10        0.796222    0.000000    
11        0.634087    0.000000    
12        0.487326    0.000000    
epoch     train_loss  valid_loss
1         0.482283    0.000000    
2         0.492100    0.000000    
3         0.563699    0.000000    
4         0.571843    0.000000    
5         0.650438    0.000000    
6         0.695321    0.000000    
7         0.700596    0.000000    
8         0.615317    0.000000    
9         0.678798    0.000000    
10        0.616675    0.000000    
11        0.715437    0.000000    
12        0.628833    0.000000    
13        0.616170    0.000000    
14        0.530670    0.000000    
15        0.458034    0.000000    
16        0.467264    0.

In [9]:
# With oversampling: some images appear more than once in this list. fn2label was
# built from the original train.csv; re-running that cell after this one would
# rebuild it from this file instead.
oversampled_csv = 'data/oversampled_train_and_val.csv'
df = pd.read_csv(oversampled_csv)

In [10]:
# Oversampled 448px databunch. Uses ImageList (the name used by the 224px cell)
# rather than the deprecated ImageItemList alias.
# NOTE(review): unlike the other databunches, df is not filtered for 'new_whale'
# here. Presumably the oversampled CSV already excludes it (the stage-4 checkpoint
# must load into a head with the same class count). Verify.
data = (
    ImageList
        .from_df(df, 'data/train', cols=['Image'])
        .split_by_valid_func(lambda path: path2fn(path) in val_fns)
        .label_from_func(lambda path: fn2label[path2fn(path)])
        .add_test(ImageList.from_folder('data/test'))
        .transform(get_transforms(do_flip=False), size=SZ, resize_method=ResizeMethod.SQUISH)
        .databunch(bs=BS, num_workers=NUM_WORKERS, path='data')
        .normalize(imagenet_stats)
)

In [11]:
%%time
learn = create_cnn(data, models.resnet50, lin_ftrs=[2048])
learn.clip_grad();
learn.load(f'{name}-stage-4')
learn.freeze_to(-1)

learn.fit_one_cycle(2, 1e-2 / 4)
learn.save(f'{name}-stage-5')

learn.unfreeze()

max_lr = 1e-3 / 4
lrs = [max_lr/100, max_lr/10, max_lr]

learn.fit_one_cycle(3, lrs)
learn.save(f'{name}-stage-6')

epoch     train_loss  valid_loss
1         1.626801    0.000010    
2         0.566748    0.000010    
epoch     train_loss  valid_loss
1         0.604931    0.000121    
2         0.531284    0.000026    
3         0.442735    0.000039    
CPU times: user 1h 25min 46s, sys: 38min 1s, total: 2h 3min 48s
Wall time: 2h 3min 59s


## Predict

In [12]:
# The test set is unlabelled: keep the predictions, discard the placeholder targets.
preds, _ = learn.get_preds(ds_type=DatasetType.Test)

In [13]:
# Extra trailing column for the 'new_whale' pseudo-class (its value is set in the
# next cell).
new_whale_col = preds.new_ones((preds.shape[0], 1))
preds = torch.cat([preds, new_whale_col], dim=1)

In [14]:
# Fixed score for the appended 'new_whale' column. Presumably create_submission
# ranks columns by score, so whales scoring below this come after new_whale.
# Index -1 targets the appended column, which is always last. The former
# hard-coded 5004 matched it only when the learner had exactly 5004 classes, and
# would otherwise overwrite a real whale's score.
NEW_WHALE_SCORE = 0.06
preds[:, -1] = NEW_WHALE_SCORE

In [15]:
# Labels aligned with the columns of preds: the learner's classes, then the
# appended new_whale column.
classes = [*learn.data.classes, 'new_whale']

In [16]:
# create_submission comes from the local utils module (star-imported at the top).
# Judging by the next cells, it writes subs/{name}.csv.gz with one row per test
# image and a space-separated list of predicted Ids.
# NOTE(review): the file read back below has an extra 'Unnamed: 0' column (an index
# written to the CSV). Confirm the scorer ignores it.
create_submission(preds, learn.data, name, classes)

In [17]:
submission = pd.read_csv(f'subs/{name}.csv.gz')  # gzip compression inferred from the extension
submission.head()

Unnamed: 0,Image,Id
0,47380533f.jpg,w_6c995fd new_whale w_7206ab2 w_54ea24d w_620dffe
1,1d9de38ba.jpg,w_641df87 new_whale w_e99ed06 w_3e6cee1 w_0b7ce1e
2,b3d4ee916.jpg,new_whale w_23ce00e w_bc7de9f w_71a1a08 w_708c3d2
3,460fd63ae.jpg,new_whale w_0bb71d3 w_9eab46a w_60cf87c w_42388df
4,79738ffc1.jpg,new_whale w_1419d90 w_01976db w_dbf651b w_415dea0


In [18]:
# Fraction of test images whose top-ranked Id is new_whale: a sanity check on the
# new_whale score assigned earlier. Uses the vectorised .str accessor instead of a
# Python-level apply. An empty Id now counts as False instead of raising TypeError.
(pd.read_csv(f'subs/{name}.csv.gz')
   .Id.str.split().str[0]
   .eq('new_whale')
   .mean())

0.48693467336683416

In [19]:
# Upload the gzipped submission through the Kaggle CLI, using the run name as the
# submission message. Requires the kaggle CLI and its API credentials on this machine.
!kaggle competitions submit -c humpback-whale-identification -f subs/{name}.csv.gz -m "{name}"

100%|████████████████████████████████████████| 183k/183k [00:04<00:00, 37.6kB/s]
Successfully submitted to Humpback Whale Identification