In [2]:
import argparse
import builtins
import math
import os
import random
import shutil
import time
import warnings
from pathlib import Path

import h5py
import numpy as np
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.optim
import torch.multiprocessing as mp
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models
from PIL import Image, ImageFilter
from tqdm import tqdm

import timm

In [3]:
timm.list_models('*nfnet*')

['dm_nfnet_f0',
 'dm_nfnet_f1',
 'dm_nfnet_f2',
 'dm_nfnet_f3',
 'dm_nfnet_f4',
 'dm_nfnet_f5',
 'dm_nfnet_f6',
 'eca_nfnet_l0',
 'eca_nfnet_l1',
 'nfnet_f0',
 'nfnet_f0s',
 'nfnet_f1',
 'nfnet_f1s',
 'nfnet_f2',
 'nfnet_f2s',
 'nfnet_f3',
 'nfnet_f3s',
 'nfnet_f4',
 'nfnet_f4s',
 'nfnet_f5',
 'nfnet_f5s',
 'nfnet_f6',
 'nfnet_f6s',
 'nfnet_f7',
 'nfnet_f7s',
 'nfnet_l0']

In [13]:
model = timm.create_model('ig_resnext101_32x8d', features_only=True, pretrained=True)

import torch.nn.functional as F
def gem(x, p=3, eps=1e-6):
    return F.avg_pool2d(x.clamp(min=eps).pow(p), (x.size(-2), x.size(-1))).pow(1./p)

class ISCDataset(torch.utils.data.Dataset):
    def __init__(
        self,
        paths,
        transforms,
    ):
        self.paths = paths
        self.transforms = transforms

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, i):
        image = Image.open(self.paths[i])
        image = self.transforms(image)
        return image

In [18]:
from types import SimpleNamespace
args = SimpleNamespace()

args.data = '../input/'
args.batch_size = 256
args.workers = os.cpu_count()

In [None]:
query_paths = sorted(Path(args.data).glob('query_images/**/*.jpg'))
query_ids = np.array([p.stem for p in query_paths], dtype='S6')

reference_paths = sorted(Path(args.data).glob('reference_images/**/*.jpg'))
reference_ids = np.array([p.stem for p in reference_paths], dtype='S7')

model.eval().cuda()

cudnn.benchmark = True

In [53]:
preprocesses = [
    transforms.Resize(model.default_cfg['input_size'][1:]),
    # transforms.Resize(model.default_cfg['input_size'] + 32),
    # transforms.CenterCrop(model.default_cfg['input_size']),
    transforms.ToTensor(),
    transforms.Normalize(mean=model.default_cfg['mean'], std=model.default_cfg['std']),
]

datasets = {
    'query': ISCDataset(query_paths, transforms.Compose(preprocesses)),
    'reference': ISCDataset(reference_paths, transforms.Compose(preprocesses)),
}
loader_kwargs = dict(batch_size=args.batch_size, shuffle=False, num_workers=args.workers, drop_last=False)
data_loaders = {
    'query': torch.utils.data.DataLoader(datasets['query'], **loader_kwargs),
    'reference': torch.utils.data.DataLoader(datasets['reference'], **loader_kwargs),
}

def calc_feats(loader):
    feats = []
    for image in tqdm(loader, total=len(loader)):
        x = image.cuda()
        with torch.no_grad():
            y = model(x)[-1]
            y = gem(y).squeeze(-1).squeeze(-1)
        feats.append(y.cpu().numpy())
    feats = np.concatenate(feats, axis=0)
    feats /= np.linalg.norm(feats, 2, axis=1, keepdims=True)
    return feats.astype(np.float32)

query_feats = calc_feats(data_loaders['query'])
reference_feats = calc_feats(data_loaders['reference'])

out = f'fb-isc-submission.h5'
with h5py.File(out, 'w') as f:
    f.create_dataset('query', data=query_feats)
    f.create_dataset('reference', data=reference_feats)
    f.create_dataset('query_ids', data=query_ids)
    f.create_dataset('reference_ids', data=reference_ids)


  0%|          | 0/391 [00:00<?, ?it/s]Exception ignored in: 
<function _MultiProcessingDataLoaderIter.__del__ at 0x7f2c64e8db00>Traceback (most recent call last):
  File "/home/shuhei.yokoo/anaconda3/envs/fbisc/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1328, in __del__
    self._shutdown_workers()
  File "/home/shuhei.yokoo/anaconda3/envs/fbisc/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1320, in _shutdown_workers
    if w.is_alive():Exception ignored in: 
  File "/home/shuhei.yokoo/anaconda3/envs/fbisc/lib/python3.7/multiprocessing/process.py", line 151, in is_alive
<function _MultiProcessingDataLoaderIter.__del__ at 0x7f2c64e8db00>    assert self._parent_pid == os.getpid(), 'can only test a child process'

Traceback (most recent call last):
AssertionError:   File "/home/shuhei.yokoo/anaconda3/envs/fbisc/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1328, in __del__
can only test a child process
    self._shutdown_work

TypeError: Caught TypeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/shuhei.yokoo/anaconda3/envs/fbisc/lib/python3.7/site-packages/torch/utils/data/_utils/worker.py", line 287, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/shuhei.yokoo/anaconda3/envs/fbisc/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/shuhei.yokoo/anaconda3/envs/fbisc/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "<ipython-input-13-4c648425b482>", line 15, in __getitem__
    image = self.transforms(image)
  File "/home/shuhei.yokoo/anaconda3/envs/fbisc/lib/python3.7/site-packages/torchvision/transforms/transforms.py", line 60, in __call__
    img = t(img)
  File "/home/shuhei.yokoo/anaconda3/envs/fbisc/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/shuhei.yokoo/anaconda3/envs/fbisc/lib/python3.7/site-packages/torchvision/transforms/transforms.py", line 297, in forward
    return F.resize(img, self.size, self.interpolation, self.max_size, self.antialias)
  File "/home/shuhei.yokoo/anaconda3/envs/fbisc/lib/python3.7/site-packages/torchvision/transforms/functional.py", line 401, in resize
    return F_pil.resize(img, size=size, interpolation=pil_interpolation, max_size=max_size)
  File "/home/shuhei.yokoo/anaconda3/envs/fbisc/lib/python3.7/site-packages/torchvision/transforms/functional_pil.py", line 241, in resize
    return img.resize(size[::-1], interpolation)
  File "/home/shuhei.yokoo/anaconda3/envs/fbisc/lib/python3.7/site-packages/PIL/Image.py", line 1943, in resize
    return self._new(self.im.resize(size, resample, box))
TypeError: an integer is required (got type tuple)
