In [1]:
# http://www.cs.utoronto.ca/~gkoch/files/msc-thesis.pdf
# https://www.cs.cmu.edu/~rsalakhu/papers/oneshot1.pdf

In [2]:
%matplotlib inline
import matplotlib.pyplot as plt
from fastai.vision import *
from fastai.metrics import accuracy_thresh
from fastai.basic_data import *
from torch.utils.data import DataLoader, Dataset
from torch import nn
from fastai.callbacks.hooks import num_features_model, model_sizes
from fastai.layers import BCEWithLogitsFlat
from fastai.basic_train import Learner
from skimage.util import montage
import pandas as pd
from torch import optim
import re

from utils import *

In [3]:
# import fastai
# from fastprogress import force_console_behavior
# import fastprogress
# fastprogress.fastprogress.NO_BAR = True
# master_bar, progress_bar = force_console_behavior()
# fastai.basic_train.master_bar, fastai.basic_train.progress_bar = master_bar, progress_bar

Posing the problem as a classification task is probably not ideal. We are asking our NN to learn to recognize a whale out of 5004 possible candidates based on what it has learned about the whales. That is a tall order.

Instead, here we will try to pose the problem as a verification task. When presented with two images of whale flukes, we will ask the network - are the images of the same whale or of different whales? In particular, we will try to teach our network to learn features that can be useful in determining the similarity between whale images (hence the name of this approach - feature learning).

This seems like a much easier task, at least in theory. Either way, no need to start with a relatively big CNN like resnet50. Let's see what mileage we can get out of resnet18.

In [4]:
root_path = Path('../input')
train_path = root_path/'train'
test_path = root_path/'test'

In [5]:
# new architecture calls for a new validation set, this time our validation set will consist of all whales that have exactly two images
df = pd.read_csv(root_path/'train.csv')
im_count = df[df.Id != 'new_whale'].Id.value_counts()
im_count.name = 'sighting_count'
df = df.join(im_count, on='Id')
val_fns = set(df[df.sighting_count == 2].Image)

In [6]:
len(val_fns)

2570

In [7]:
fn2label = {row[1].Image: row[1].Id for row in df.iterrows()}
path2fn = lambda path: re.search('\w*\.jpg$', path).group(0)

name = f'res18-siamese'

In [8]:
SZ = 224
BS = 64
NUM_WORKERS = 6
SEED=0
train_path = root_path/f'train_{SZ}'
test_path = root_path/f'test_{SZ}'

In [9]:
# data_block api creates categories based on classes it sees in the train set and
# our val set contains whales whose ids do not appear in the train set
classes = df.Id.unique()

In [10]:
data = (
    ImageItemList
        .from_df(df[df.Id != 'new_whale'], train_path, cols=['Image'])
        .split_by_valid_func(lambda path: path2fn(path) in val_fns)
        .label_from_func(lambda path: fn2label[path2fn(path)], classes=classes)
        .add_test(ImageItemList.from_folder(test_path))
        .transform(get_transforms(do_flip=False), size=SZ, resize_method=ResizeMethod.SQUISH)
#         .databunch(bs=BS, num_workers=NUM_WORKERS, path='data')
#         .normalize(imagenet_stats)
)

I am still using the ImageItemList even though I will create my own datasets. Why? Because I want to reuse the functionality that is already there (creating datasets from files, augmentations, resizing, etc).

I realize the code is neither clean nor elegant but for the time being I am happy with this approach.

In [11]:
def is_even(num): return num % 2 == 0

class TwoImDataset(Dataset):
    def __init__(self, ds):
        self.ds = ds
        self.whale_ids = ds.y.items
    def __len__(self):
        return 2 * len(self.ds)
    def __getitem__(self, idx):
        if is_even(idx):
            return self.sample_same(idx // 2)
        else: return self.sample_different((idx-1) // 2)
    def sample_same(self, idx):
        whale_id = self.whale_ids[idx]        
        candidates = list(np.where(self.whale_ids == whale_id)[0])
        candidates.remove(idx) # dropping our current whale - we don't want to compare against an identical image!
        
        if len(candidates) == 0: # oops, there is only a single whale with this id in the dataset
            return self.sample_different(idx)
        
        np.random.shuffle(candidates)
        return self.construct_example(self.ds[idx][0], self.ds[candidates[0]][0], 1)
    def sample_different(self, idx):
        whale_id = self.whale_ids[idx]
        candidates = list(np.where(self.whale_ids != whale_id)[0])
        np.random.shuffle(candidates)
        return self.construct_example(self.ds[idx][0], self.ds[candidates[0]][0], 0)
    
    def construct_example(self, im_A, im_B, class_idx):
        return [im_A, im_B], class_idx

In [12]:
train_dl = DataLoader(
    TwoImDataset(data.train),
    batch_size=BS,
    shuffle=True,
    num_workers=NUM_WORKERS
)
valid_dl = DataLoader(
    TwoImDataset(data.valid),
    batch_size=BS,
    shuffle=False,
    num_workers=NUM_WORKERS
)

data_bunch = ImageDataBunch(train_dl, valid_dl)

In [13]:
def normalize_batch(batch):
    stat_tensors = [torch.tensor(l).cuda() for l in imagenet_stats]
    return [normalize(batch[0][0], *stat_tensors), normalize(batch[0][1], *stat_tensors)], batch[1]

In [14]:
data_bunch.add_tfm(normalize_batch)

In [15]:
from functional import seq

class SiameseNetwork1(nn.Module):
    def __init__(self, arch=models.resnet18):
        super().__init__() 
        self.cnn = create_body(arch)
        self.head = nn.Linear(num_features_model(self.cnn), 1)
        
    def forward(self, im_A, im_B):
        # dl - distance layer
        x1, x2 = seq(im_A, im_B).map(self.cnn).map(self.process_features)
        dl = self.calculate_distance(x1, x2)
        out = self.head(dl)
        return out
    
    def process_features(self, x): return x.reshape(*x.shape[:2], -1).max(-1)[0]
    def calculate_distance(self, x1, x2): return (x1 - x2).abs()
    

In [16]:
class SiameseNetwork(nn.Module):
    def __init__(self, arch=models.resnet18):
        super().__init__()
        self.cnn = create_body(arch)
        self.fc = nn.Linear(num_features_model(self.cnn), 1)

    def im2emb(self, batch):
        x = self.cnn(batch)
        x = self.process_features(x)
        return x

    def forward1(self, im1, im2):
        x1 = self.cnn(im1)
        x1 = self.process_features(x1)
        x2 = self.cnn(im2)
        x2 = self.process_features(x2)
        dl = self.distance(x1, x2)
        out = self.fc(dl)
        return out

    def forward(self, im1, im2):
        x1 = self.cnn(im1)
        x1 = self.process_features(x1)
        x2 = self.cnn(im2)
        x2 = self.process_features(x2)
        dl = self.distance(x1, x2).mean(dim=1)
        return dl

    def process_features(self, x):
        return x.reshape(*x.shape[:2], -1).max(-1)[0]

    def distance(self, x1, x2):
        return (x1 - x2).abs()

    def similarity(self, x1, x2):
        dl = self.distance(x1, x2)
        logit = self.fc(dl)
        return torch.sigmoid(logit)



In [17]:
class ContrastiveLoss(nn.Module):
    """
    Contrastive loss
    Takes embeddings of two samples and a target label == 1 if samples are from the same class and label == 0 otherwise
    """
    def __init__(self, margin):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin
        self.eps = 1e-9

    def forward(self, distances, target, size_average=True):
        losses = target.float() * distances + (1 - target).float() * torch.relu(self.margin - distances)
        return losses.mean() if size_average else losses.sum()
    

Below I include two slightly different siamese networks. I leave the code commented out and choose to use the one above.

In [18]:
# from functional import seq

# def cnn_activations_count(model):
#     _, ch, h, w = model_sizes(create_body(models.resnet18), (SZ, SZ))[-1]
#     return ch * h * w

# class SiameseNetwork(nn.Module):
#     def __init__(self, lin_ftrs=2048, arch=models.resnet18):
#         super().__init__() 
#         self.cnn = create_body(arch)
#         self.fc1 = nn.Linear(cnn_activations_count(self.cnn), lin_ftrs)
#         self.fc2 = nn.Linear(lin_ftrs, 1)
        
#     def forward(self, im_A, im_B):
#         x1, x2 = seq(im_A, im_B).map(self.cnn).map(self.process_features).map(self.fc1)
#         dl = self.calculate_distance(x1.sigmoid(), x2.sigmoid())
#         out = self.fc2(dl)
#         return out
    
#     def calculate_distance(self, x1, x2): return (x1 - x2).abs_()
#     def process_features(self, x): return x.reshape(x.shape[0], -1)

In [19]:
# from functional import seq

# def cnn_activations_count(model):
#     _, ch, h, w = model_sizes(create_body(models.resnet18), (SZ, SZ))[-1]
#     return ch * h * w

# class SiameseNetwork(nn.Module):
#     def __init__(self, lin_ftrs=2048, pool_to=3, arch=models.resnet18, pooling_layer=nn.AdaptiveMaxPool2d):
#         super().__init__() 
#         self.cnn = create_body(arch)
#         self.pool = pooling_layer(pool_to)
#         self.fc1 = nn.Linear(num_features_model(self.cnn) * pool_to**2, lin_ftrs)
#         self.fc2 = nn.Linear(lin_ftrs, 1)
        
#     def forward(self, im_A, im_B):
#         x1, x2 = seq(im_A, im_B).map(self.cnn).map(self.pool).map(self.process_features).map(self.fc1)
#         dl = self.calculate_distance(x1.sigmoid(), x2.sigmoid())
#         out = self.fc2(dl)
#         return out
    
#     def calculate_distance(self, x1, x2): return (x1 - x2).abs_()
#     def process_features(self, x): return x.reshape(x.shape[0], -1)

In [20]:
learn = Learner(data_bunch, SiameseNetwork(), loss_func=BCEWithLogitsFlat(), metrics=[lambda preds, targs: accuracy_thresh(preds.squeeze(), targs, sigmoid=False)])
#learn = Learner(data_bunch, SiameseNetwork(), loss_func=ContrastiveLoss(margin=1.0), metrics=[lambda preds, targs: accuracy_thresh(preds.squeeze(), targs, sigmoid=False)])
#learn = Learner(data_bunch, SiameseNet(), loss_func=BCEWithLogitsFlat(), metrics=[lambda preds, targs: accuracy_thresh(preds.squeeze(), targs, sigmoid=False)])

In [21]:
learn.split([learn.model.cnn[:6], learn.model.cnn[6:], learn.model.fc])

In [None]:
learn.freeze_to(-1)

In [None]:
learn.lr_find()

In [None]:
learn.recorder.plot()

In [None]:
learn.fit_one_cycle(4, 1e-2)

In [None]:
learn.save(f'{name}-stage-1')

In [22]:
learn.unfreeze()

In [None]:
max_lr = 5e-4
lrs = [max_lr/100, max_lr/10, max_lr]

In [None]:
learn.fit_one_cycle(10, lrs)

In [None]:
learn.save(f'{name}-stage-2')

In [None]:
learn.recorder.plot_losses()

The model is not doing that well - out of presented pairs it gets roughly 10% of examples wrong. I also did a cursory error analysis (not shown here for the sake of brevity) and the model is not doing that great at all.

How can this be? Maybe the nearly absolute positional invariance through the use of global max pooling is not working that well. Maybe there is a bug somewhere? Maybe the model has not been trained for long enough or lacks capacity?

If I do continue to work on this I will definitely take a closer look at each of the angles I list above. For the time being, let's try to predict on the validation set and finish off with making a submission.

The predicting part is where the code gets really messy. That is good enough for now though.

In [23]:
learn.load(f'{name}-stage-2');

In [24]:
new_whale_fns = set(df[df.Id == 'new_whale'].sample(frac=1).Image.iloc[:1000])

In [25]:
data = (
    ImageItemList
        .from_df(df, train_path, cols=['Image'])
        .split_by_valid_func(lambda path: path2fn(path) in val_fns.union(new_whale_fns))
        .label_from_func(lambda path: fn2label[path2fn(path)], classes=classes)
        .add_test(ImageItemList.from_folder(test_path))
        .transform(get_transforms(do_flip=False), size=SZ, resize_method=ResizeMethod.SQUISH)
        .databunch(bs=BS, num_workers=NUM_WORKERS, path=root_path)
        .normalize(imagenet_stats)
)

In [26]:
len(data.valid_ds)

3570

In [27]:
len(data.train_ds)

21791

In [34]:
data.train_ds.y

CategoryList (21791 items)
[Category w_f48451c, Category w_c3d896a, Category w_20df2c5, Category new_whale, Category new_whale]...
Path: ../input/train_224

In [30]:
from utils import siamese_validate

In [37]:
map5, pos_dist_max, neg_dist_min = siamese_validate(data.valid_dl, learn.model, data.train_dl,
                                                            pos_mask=[0], ref_idx2class=data.train_ds.y,
                                                            target_idx2class=data.valid_ds.y)

dist_pos_max = 0.7686152458190918, dist_neg_min = 0.21202769875526428


IndexError: index 3572 is out of bounds for axis 0 with size 3570

In [28]:
%%time
targs = []
feats = []
learn.model.eval()
for ims, ts in data.valid_dl:
    feats.append(learn.model.im2emb(ims).detach().cpu())
    targs.append(ts)

CPU times: user 2.58 s, sys: 992 ms, total: 3.57 s
Wall time: 4.35 s


In [29]:
feats = torch.cat(feats)

In [None]:
feats.shape

In [None]:
%%time
sims = []
for feat in feats:
    dists = learn.model.distance(feats, feat.unsqueeze(0).repeat(3570, 1))
    dists = dists.detach().cpu()
    #predicted_similarity = learn.model.head(dists.cuda()).sigmoid_()
    sims.append(dists)

In [None]:
len(sims[0])

In [None]:
new_whale_idx = np.where(classes == 'new_whale')[0][0]

In [None]:
sims[0].argsort(descending=True).shape

In [None]:
%%time
top_5s = []
for i, sim in enumerate(sims):
    idxs = sim.argsort(descending=True)
    probs = sim[idxs]
    top_5 = []
    for j, p in zip(idxs, probs):
        if len(top_5) == 5: break
        if j == i: continue
        predicted_class = data.valid_ds.y.items[j]
        if j == predicted_class: continue
        if predicted_class not in top_5: top_5.append(predicted_class)
    top_5s.append(top_5)

In [None]:
# without predicting new_whale
mapk(data.valid_ds.y.items.reshape(-1,1), np.stack(top_5s), 5)

In [None]:
%%time

for thresh in np.linspace(0.98, 1, 10):
    top_5s = []
    for i, sim in enumerate(sims):
        idxs = sim.argsort(descending=True)
        probs = sim[idxs]
        top_5 = []
        for j, p in zip(idxs, probs):
            if new_whale_idx not in top_5 and p < thresh and len(top_5) < 5: top_5.append(new_whale_idx)
            if len(top_5) == 5: break
            if j == new_whale_idx or j == i: continue
            predicted_class = data.valid_ds.y.items[j]
            if predicted_class not in top_5: top_5.append(predicted_class)
        top_5s.append(top_5)
    print(thresh, mapk(data.valid_ds.y.items.reshape(-1,1), np.stack(top_5s), 5))

There are many reasons why the best threshold here might not carry over to what would make sense on the test set. It is some indication though of how our model is doing and a useful data point.

## Predict

In [26]:
len(data.test_ds)

7960

In [27]:
data = (
    ImageItemList
        .from_df(df, train_path, cols=['Image'])
        .split_by_valid_func(lambda path: path2fn(path) in {'69823499d.jpg'}) # in newer version of the fastai library there is .no_split that could be used here
        .label_from_func(lambda path: fn2label[path2fn(path)], classes=classes)
        .add_test(ImageItemList.from_folder(test_path))
        .transform(None, size=SZ, resize_method=ResizeMethod.SQUISH)
        .databunch(bs=BS, num_workers=NUM_WORKERS, path=root_path)
        .normalize(imagenet_stats)
)

In [28]:
%%time
test_feats = []
learn.model.eval()
for ims, _ in data.test_dl:
    test_feats.append(learn.model.process_features(learn.model.cnn(ims)).detach().cpu())

CPU times: user 5.68 s, sys: 1.76 s, total: 7.44 s
Wall time: 8.03 s


In [29]:
%%time
train_feats = []
train_class_idxs = []
learn.model.eval()
for ims, t in data.train_dl:
    train_feats.append(learn.model.process_features(learn.model.cnn(ims)).detach().cpu())
    train_class_idxs.append(t)

CPU times: user 18 s, sys: 4.76 s, total: 22.7 s
Wall time: 23 s


In [30]:
train_class_idxs = torch.cat(train_class_idxs)
train_feats = torch.cat(train_feats)

In [31]:
test_feats = torch.cat(test_feats)

In [32]:
train_feats.shape

torch.Size([25344, 512])

In [33]:
%%time
test_feats = []
learn.model.eval()
for ims, _ in data.test_dl:
    test_emb = (learn.model.im2emb(ims))
    break
    

CPU times: user 36 ms, sys: 192 ms, total: 228 ms
Wall time: 457 ms


In [35]:
test_emb.shape

torch.Size([64, 512])

In [None]:
%%time
sims = []
for feat in test_feats:
    dists = learn.model.distance(train_feats, feat.unsqueeze(0).repeat(25344, 1))
    #predicted_similarity = learn.model.head(dists.cuda()).sigmoid_()
    sims.append(predicted_similarity.squeeze().detach().cpu())

In [None]:
%%time
thresh = 1

top_5s = []
for sim in sims:
    idxs = sim.argsort(descending=True)
    probs = sim[idxs]
    top_5 = []
    for i, p in zip(idxs, probs):
        if new_whale_idx not in top_5 and p < thresh and len(top_5) < 5: top_5.append(new_whale_idx)
        if len(top_5) == 5: break
        if i == new_whale_idx: continue
        predicted_class = train_class_idxs[i]
        if predicted_class not in top_5: top_5.append(predicted_class)
    top_5s.append(top_5)

In [None]:
top_5_classes = []
for top_5 in top_5s:
    top_5_classes.append(' '.join([classes[t] for t in top_5]))

In [None]:
top_5_classes[:5]

In [None]:
sub = pd.DataFrame({'Image': [path.name for path in data.test_ds.x.items]})
sub['Id'] = top_5_classes
sub.to_csv(f'../submission/{name}.csv.gz', index=False, compression='gzip')

In [None]:
pd.read_csv(f'../submission/{name}.csv.gz').head()

In [None]:
pd.read_csv(f'../submission/{name}.csv.gz').Id.str.split().apply(lambda x: x[0] == 'new_whale').mean()

In [None]:
!kaggle competitions submit -c humpback-whale-identification -f subs/{name}.csv.gz -m "{name}"