In [None]:
# Pinned to fastai 1.0.41: the notebook uses 1.0.x-era names such as create_cnn and
# random_split_by_pct — presumably the reason for the pin (later releases renamed them).
!pip install fastai==1.0.41 -q

In [None]:
import fastai
from fastai import *
from fastai.vision import *
# Record library versions so results can be tied to a specific environment.
for lib_name, lib in (('fastai', fastai), ('torch', torch)):
    print(f'{lib_name} version: {lib.__version__}')

# Set to True to print extra details (the transform list, the model summary) in later cells.
verbose = False

# import matplotlib.patches.Path
from matplotlib.patches import Rectangle
%matplotlib inline

import json
import warnings
# Silence all Python warnings for the rest of the session.
# NOTE(review): this also hides potentially useful warnings; narrow it if debugging.
warnings.filterwarnings('ignore')

import os
# List the datasets available under ../input.
print(os.listdir("../input"))

In [None]:
# Global settings: data locations.
data_fp = Path('../input')
data_train = data_fp.joinpath('whale-categorization-playground', 'train', 'train')
# Bounding boxes from this kernel: https://www.kaggle.com/martinpiotte/bounding-box-model/output
crop_fp = data_fp.joinpath('cropping_whale2', 'cropping.txt')
# cropping.txt was converted to COCO format so that fastai's get_annotations can read it.
crop_coco = data_fp.joinpath('cropping-whale-coco', 'coco_whale.json')

# Global settings: loading and image size.
bs = 64
# Load in the main process; worker processes failed with
# 'DataLoader worker (pid 56) is killed by signal: Bus error.'
num_workers = 0
sz = 224  # image size passed to the transform step (size=sz)

## Create a DataBunch from the COCO-format annotations

In [None]:
# Parse the COCO json: `images` holds file names, `lbl_bbox` the matching annotations
# (bboxes + labels, as returned by fastai's get_annotations).
images, lbl_bbox = get_annotations(crop_coco)
img2bbox = {img_name: ann for img_name, ann in zip(images, lbl_bbox)}


def get_y_func(o):
    "Return the annotation for item `o`, looked up by its file name."
    return img2bbox[Path(o).name]

In [None]:
# Augmentations tailored to whale tails.
tfm = get_transforms(
    flip_vert=False,  # upside-down tails don't make sense, so no vertical flips
    max_rotate=0.3,   # rotating too much makes the rotated bbox much larger and less accurate
)
if verbose:
    # A bare `tfm` inside an `if` is never echoed by Jupyter (only a top-level
    # final expression is), so print explicitly — same as the summary cell.
    print(tfm)

In [None]:
# Build the DataBunch with fastai's data block API.
data = (ObjectItemList.from_df(pd.DataFrame(data=images), path=data_train)
        # How to split in train/valid? -> randomly, with the default 20% in valid
        .random_split_by_pct(seed=52)
        # How to find the labels? -> look up each image's bbox via get_y_func
        .label_from_func(get_y_func)
        # Data augmentation? -> the whale-specific `tfm` from the previous cell
        # (no vertical flip, small rotation), applied to the bbox too (tfm_y=True).
        # Passing get_transforms() here would silently discard those choices.
        .transform(tfm,
                   tfm_y=True,
                   size=sz,
                   resize_method=ResizeMethod.SQUISH,
                   padding_mode='border')
        # Finally convert to a DataBunch, using bb_pad_collate to batch the bbox targets
        .databunch(bs=bs, collate_fn=bb_pad_collate, num_workers=num_workers)
        .normalize(imagenet_stats))

### Show how data augmentation affects the image and its bounding box

In [None]:
# Fetch the same training item for each of the 3x3 panels. The item is augmented
# anew on every retrieval, so each panel shows a different transform of image + bbox.
idx = 65
fig, axes = plt.subplots(3, 3, figsize=(9, 9))
for ax in axes.flat:
    sample = data.train_ds[idx]
    sample[0].show(y=sample[1], ax=ax)

In [None]:
# Preview samples from the DataBunch (rows=2 sets the grid size).
data.show_batch(rows=2)

## Training 

In [None]:
# L1 loss is used instead of MSE because MSE penalizes large mistakes more
# heavily than we want for box regression.
def loss_func(preds, targs, class_idx, **kwargs):
    """Mean absolute error between predicted and target bounding boxes.

    preds: box coordinates from the regression head, (bs, 4).
    targs: padded box targets from bb_pad_collate; assumed (bs, 1, 4) since each
        image has a single box — TODO confirm if multi-box images are added.
    class_idx: class labels supplied alongside the boxes; unused here.
    """
    # squeeze(1), not squeeze(): a batch holding a single image must keep its batch
    # dimension, otherwise the target shape no longer matches `preds`.
    return nn.functional.l1_loss(preds, targs.squeeze(1))

In [None]:
# Regression head: flatten the backbone's feature map and map it to 4 box coordinates.
# Assumes the resnet18 body outputs 512 channels at 1/32 of the input size (rounded up),
# so the flattened size follows from `sz`: 512 * 7 * 7 = 25088 for sz = 224.
# NOTE(review): update the 512 / 32 factors if `arch` is changed.
n_head_in = 512 * ((sz + 31) // 32) ** 2
head_reg4 = nn.Sequential(Flatten(), nn.Linear(n_head_in, 4))
learn = create_cnn(data=data, arch=models.resnet18, pretrained=True, custom_head=head_reg4,
                   model_dir='/tmp/models')
learn.loss_func = loss_func

In [None]:
# Optional: print the layer-by-layer model summary.
if verbose:
    print(learn.summary())

In [None]:
# Run fastai's learning-rate finder and plot the recorded curve to pick a learning rate.
# NOTE(review): the LR chosen from this plot is not passed to the fit() calls below.
learn.lr_find()
learn.recorder.plot()

In [None]:
# Train for 25 epochs. No lr is passed, so fit's default is used. The pretrained body
# is presumably frozen at this point (it is unfrozen further down).
learn.fit(25)

In [None]:
# Plot the train/validation losses recorded during the 25-epoch fit.
learn.recorder.plot_losses()

In [None]:
# Make all layers trainable, so the pretrained body is fine-tuned too.
learn.unfreeze()

In [None]:
# Fine-tune the whole network for 10 more epochs.
# NOTE(review): every layer gets the same default lr; discriminative rates
# (e.g. lr=slice(...)) are commonly used when fine-tuning a pretrained body.
learn.fit(10)

In [None]:
# Plot the losses recorded by learn.recorder after fine-tuning.
learn.recorder.plot_losses()

## Check results on the validation set

In [None]:
# Predictions and targets for the whole validation set. The plotting cells below
# assume both are aligned with data.valid_ds order (they index all three with `i`).
preds, targs = learn.get_preds(ds_type=DatasetType.Valid)
# fastai expects multiple objects per image, but we have only one box: drop that
# dimension. squeeze(1), not squeeze(), keeps the sample dimension even when the
# validation set holds a single image.
targs = targs.squeeze(1)
# Optional: keep predicted coordinates within the picture ([-1, 1]):
# preds = torch.clamp(preds, -1, 1)

In [None]:
def show_valid_preds(seed, n=10, figsize=(15, 20)):
    """Plot the target and predicted bbox side by side for `n` random validation images.

    Sampling uses np.random.seed(seed) + randint (with replacement), so a given seed
    shows the same images as before. Prints the sampled indices and each
    (target, prediction) pair. Relies on the notebook globals `data`, `preds` and `targs`.
    """
    np.random.seed(seed)
    idxs = np.random.randint(0, len(data.valid_ds), size=n)
    print(idxs)
    # squeeze=False keeps `axes` 2-D, so each row unpacks into two Axes even when n == 1.
    _, axes = plt.subplots(nrows=n, ncols=2, figsize=figsize, squeeze=False)
    for i, row in zip(idxs, axes):
        # The image from valid_ds is already resized by the transforms, not the original size.
        img = data.valid_ds[i][0].data
        img_name = Path(data.valid_ds.items[i]).name
        img_size = img.shape[1:]
        targ, pred = targs[i], preds[i]
        print(targ, pred)
        for label, box, ax in zip(['Target', 'Prediction'], [targ, pred], row):
            Image(img).show(ax=ax,
                            y=ImageBBox.create(*img_size,
                                               bboxes=box.unsqueeze(0),
                                               scale=False),
                            title=label + ":" + img_name)


show_valid_preds(seed=24)
        
## Notably, the x-axis coordinates of the predicted bboxes are poor: they are consistently outside the range -1 to 1.

In [None]:
# Compare target and predicted boxes on another random set of validation images.
np.random.seed(45)
n = 10  # number of validation images to plot
idxs = np.random.randint(0, len(data.valid_ds), size=n)
print(idxs)
_, axes = plt.subplots(nrows=n, ncols=2, figsize=(15, 20))
for i, row in zip(idxs, axes):
    # valid_ds yields the transformed (resized) image, not the original file size
    img_t = data.valid_ds[i][0].data
    name = Path(data.valid_ds.items[i]).name
    print(targs[i], preds[i])
    panels = (('Target', targs[i]), ('Prediction', preds[i]))
    for ax, (label, box) in zip(row, panels):
        bbox = ImageBBox.create(*img_t.shape[1:], bboxes=box.unsqueeze(0), scale=False)
        Image(img_t).show(ax=ax, y=bbox, title=label + ":" + name)

In [None]:
# Save the validation predictions with one row per image, indexed by file name so
# rows can be matched back to images. Assumes preds rows follow data.valid_ds order,
# as the plotting cells do.
valid_names = [Path(o).name for o in data.valid_ds.items]
pd.DataFrame(data=preds.numpy(), index=valid_names).to_csv('testing.csv', index_label='image')