# Image search

指定された画像と似ている画像を検索する。

Step 1. Modify Resnet model
* 画像のencoder として Resnet を修正する。

Step 2. Lightning Module
* encoder を pytorchlightning 用に定義する。

Step 3. Training Encoder
* encoder を訓練する。
* 訓練済みモデルをファイルに保存する。

Step 4. Embedding
* 検索対象画像と検索画像を特徴ベクトルへ変換する。
* 特徴ベクトルをファイルに保存する。

Step 5. Faiss indexes
* Step 4 で作成した検索対象画像の特徴ベクトルを検索するためのインデックスを生成する。
* インデックスをファイルに保存する。

Step 6. Image similarity search
* Step 4 で作成した検索画像の特徴ベクトルを使用して、Step 5 のインデックスを検索する。

-----

pytorch:
* https://pytorch.org/

pytorchlightning:
* https://www.pytorchlightning.ai/

faiss:
* https://github.com/facebookresearch/faiss
* https://faiss.ai/

Cifar10:
* https://pytorch.org/vision/stable/datasets.html

PyTorch Lightning CIFAR10 ~94% Baseline Tutorial

* https://github.com/PyTorchLightning/lightning-tutorials/blob/3321b468e78167aaf056894e92ed6d649c76e89e/.notebooks/lightning_examples/cifar10-baseline.ipynb

## Setup for Notebook

In [None]:
# Detect the notebook runtime: 'colab', 'kaggle', or 'local' (the default).
import sys

runtime = 'local'
if 'google.colab' in sys.modules:
    # The google.colab module has already been loaded into this interpreter.
    runtime = 'colab'
# NOTE(review): _dh is not defined in this file; presumably it is the
# directory-history list injected by IPython -- confirm before running this
# cell outside a notebook, where it would raise NameError.
elif _dh == ['/kaggle/working']:
    runtime = 'kaggle'
# Bare expression so the notebook displays the detected value.
runtime

In [None]:
if runtime == 'colab':
    from google.colab import drive
    drive.mount('/content/drive')

In [None]:
# Root directory of the project: Google Drive on Colab, the local home
# directory otherwise.
home_path = (
    '/content/drive/MyDrive/image_similarity_search'
    if runtime == 'colab'
    else '/home/jovyan/image_similarity_search'
)

# Sub-directories used throughout the notebook, all rooted at home_path.
nbs_path = home_path + '/nbs'
datasets_path = home_path + '/datasets'
models_path = home_path + '/models'
figs_path = home_path + '/figs'
logs_path = home_path + '/logs'

In [None]:
%cd {nbs_path}

In [None]:
!pip install -q pytorch_lightning

In [None]:
!pip install -q faiss-gpu

In [None]:
!pip install -q nb-clean

## Setup

In [None]:
import os
import time
import numpy as np
from tqdm import tqdm_notebook as tqdm
import matplotlib.pyplot as plt
import pickle
import shutil
from PIL import Image
import glob

In [None]:
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from torchsummary import summary

import torchvision
from torchvision import transforms

In [None]:
import pytorch_lightning as pl
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger

In [None]:
from pytorch_lightning.callbacks import LearningRateMonitor
from torch.optim.lr_scheduler import OneCycleLR
from torchmetrics.functional import accuracy

In [None]:
# Use at most one GPU; pick the larger batch size only when a GPU is present.
AVAIL_GPUS = 1 if torch.cuda.device_count() >= 1 else 0
BATCH_SIZE = 64 if AVAIL_GPUS == 0 else 256

In [None]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

In [None]:
pl.seed_everything(42)

## Dataset

In [None]:
from torchvision.datasets import CIFAR10

In [None]:
dataset_name = 'cifar10'

In [None]:
def clear_datasets():
    """Delete the downloaded CIFAR-10 files stored under ``datasets_path``.

    Removes the ``cifar-10-batches-py`` directory and the
    ``cifar-10-python.tar.gz`` archive. Entries that do not exist are skipped.
    """
    extracted_dir = os.path.join(datasets_path, 'cifar-10-batches-py')
    archive_file = os.path.join(datasets_path, 'cifar-10-python.tar.gz')

    # Pair each path with the call able to delete that kind of entry.
    removers = ((extracted_dir, shutil.rmtree), (archive_file, os.remove))
    for target, remove in removers:
        if os.path.exists(target):
            remove(target)

In [None]:
#clear_datasets()

In [None]:
def cifar10_normalization():
    """Build a ``transforms.Normalize`` from per-channel statistics.

    The mean and std constants are written on the 0-255 scale and divided by
    255 before being handed to ``Normalize``.
    """
    channel_means = [125.3, 123.0, 113.9]
    channel_stds = [63.0, 62.1, 66.7]
    return transforms.Normalize(
        mean=[value / 255.0 for value in channel_means],
        std=[value / 255.0 for value in channel_stds],
    )

In [None]:
def train_transforms():
    """Return the preprocessing pipeline used for training images.

    Steps, in order: random 32x32 crop with 4 pixels of padding, random
    horizontal flip, conversion to a tensor, and per-channel normalization
    with mean 0.5 and std 0.5.
    """
    steps = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        # cifar10_normalization() is the alternative normalization kept in
        # this notebook; the fixed 0.5 / 0.5 statistics are the ones in use.
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
    return transforms.Compose(steps)

In [None]:
def test_transforms():
    """Return the preprocessing pipeline used for evaluation images.

    Only converts to a tensor and normalizes each channel with mean 0.5 and
    std 0.5; no random augmentation is applied.
    """
    steps = [
        transforms.ToTensor(),
        # cifar10_normalization() is the alternative normalization kept in
        # this notebook; the fixed 0.5 / 0.5 statistics are the ones in use.
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
    return transforms.Compose(steps)

In [None]:
class DataModule(pl.LightningDataModule):
    """CIFAR-10 data module providing train, validation and test loaders.

    The official training split is divided 80% / 20% into training and
    validation subsets. Training samples go through ``train_transforms()``
    (random augmentation); validation and test samples go through
    ``test_transforms()`` so evaluation metrics are not perturbed by random
    crops and flips.

    Args:
        data_dir: Directory where the CIFAR-10 files are downloaded and read.
        batch_size: Batch size used by every dataloader.
    """

    def __init__(self, data_dir: str, batch_size: int = 32):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        # os.cpu_count() returns None when the count cannot be determined;
        # fall back to a single CPU (which yields 0 workers) in that case.
        self.num_workers = (os.cpu_count() or 1) // 2
        self.train_transform = train_transforms()
        self.test_transform = test_transforms()
        self.classes = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
        self.num_classes = len(self.classes)
        self.dims = (3, 32, 32)  # channels, height, width

    def prepare_data(self):
        """Download both CIFAR-10 splits into ``data_dir``."""
        CIFAR10(self.data_dir, train=True, download=True)
        CIFAR10(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        """Create the train / validation / test datasets.

        The training split is opened twice, once per transform, and the same
        random 80/20 index split is applied to both views. The validation
        subset therefore holds samples disjoint from the training subset but
        is preprocessed without augmentation.
        """
        augmented = CIFAR10(self.data_dir, train=True, transform=self.train_transform)
        plain = CIFAR10(self.data_dir, train=True, transform=self.test_transform)

        n_train = int(len(augmented) * 0.8)
        n_val = len(augmented) - n_train
        train_subset, val_subset = torch.utils.data.random_split(augmented, [n_train, n_val])

        self.ds_train = train_subset
        # Reuse the validation indices on the non-augmented view of the data.
        self.ds_val = torch.utils.data.Subset(plain, val_subset.indices)
        self.ds_test = CIFAR10(self.data_dir, train=False, transform=self.test_transform)

    def train_dataloader(self):
        """Shuffled training loader; the last incomplete batch is dropped."""
        return DataLoader(self.ds_train, shuffle=True, drop_last=True, batch_size=self.batch_size, num_workers=self.num_workers)

    def val_dataloader(self):
        """Validation loader in fixed order."""
        return DataLoader(self.ds_val, shuffle=False, batch_size=self.batch_size, num_workers=self.num_workers)

    def test_dataloader(self):
        """Test loader in fixed order."""
        return DataLoader(self.ds_test, shuffle=False, batch_size=self.batch_size, num_workers=self.num_workers)

In [None]:
datamodule = DataModule(datasets_path, BATCH_SIZE)
datamodule.prepare_data()
datamodule.setup()

In [None]:
# check data size
train_dataloader = iter(datamodule.train_dataloader())
images, labels = next(train_dataloader)
images.shape, labels.shape

## Step 1. Modify Resnet model
Modify the pre-existing Resnet architecture from TorchVision. The pre-existing architecture is based on ImageNet images (224x224) as input. So we need to modify it for CIFAR10 images (32x32).

In [None]:
torchvision.models.resnet18(pretrained=False, num_classes=10)

In [None]:
def create_model():
    """Return a randomly initialized ResNet-18 adapted to 32x32 inputs.

    Two layers of the torchvision network are swapped out:

    * ``conv1``: the stock 7x7 / stride-2 / padding-3 convolution becomes a
      3x3 / stride-1 / padding-1 convolution.
    * ``maxpool``: the stock 3x3 / stride-2 max-pool becomes an identity.

    The classifier head has 10 outputs.
    """
    net = torchvision.models.resnet18(pretrained=False, num_classes=10)
    net.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
    net.maxpool = nn.Identity()
    return net

In [None]:
# 特徴ベクトル
m = create_model()
m.fc = nn.Identity()
v = m(torch.randn(32, 3, 32, 32))
v.shape

## Step 2. Lightning Module

* encoder を pytorchlightning 用に定義する。

In [None]:
encoder_name = 'resnet'

In [None]:
class LitResnet(pl.LightningModule):
    """Lightning wrapper that trains the network from ``create_model()``.

    The wrapped network lives in ``self.model``. ``forward`` returns
    log-probabilities (log-softmax over the class dimension) and the loss is
    the negative log-likelihood of the target class.

    Args:
        lr: Learning rate passed to the SGD optimizer; stored in ``hparams``.
    """

    def __init__(self, lr=0.05):
        super().__init__()

        self.save_hyperparameters()
        self.model = create_model()

    def forward(self, x):
        """Return per-class log-probabilities for the batch ``x``."""
        return F.log_softmax(self.model(x), dim=1)

    def training_step(self, batch, batch_idx):
        """Compute and log the NLL training loss for one batch."""
        x, y = batch
        log_probs = self(x)
        loss = F.nll_loss(log_probs, y)
        self.log('train_loss', loss)
        return loss

    def evaluate(self, batch, stage=None):
        """Compute loss and accuracy for a batch.

        The metrics are logged as ``{stage}_loss`` and ``{stage}_acc`` only
        when ``stage`` is given.
        """
        x, y = batch
        log_probs = self(x)
        loss = F.nll_loss(log_probs, y)
        preds = torch.argmax(log_probs, dim=1)
        acc = accuracy(preds, y)

        if stage:
            self.log(f'{stage}_loss', loss, prog_bar=True)
            self.log(f'{stage}_acc', acc, prog_bar=True)

    def validation_step(self, batch, batch_idx):
        self.evaluate(batch, 'val')

    def test_step(self, batch, batch_idx):
        self.evaluate(batch, 'test')

    def _steps_per_epoch(self):
        """Return the number of optimizer steps in one training epoch.

        The count is read from the attached datamodule's training loader so
        the LR schedule matches the data actually used. When no datamodule is
        attached, or its loader has no length, the previous fixed estimate of
        ``45000 // BATCH_SIZE`` is used.
        """
        datamodule = getattr(self.trainer, 'datamodule', None)
        if datamodule is not None:
            try:
                return len(datamodule.train_dataloader())
            except TypeError:
                # The loader's dataset does not define a length.
                pass
        return 45000 // BATCH_SIZE

    def configure_optimizers(self):
        """Return SGD with a OneCycleLR schedule that is stepped every batch."""
        optimizer = torch.optim.SGD(self.parameters(), lr=self.hparams.lr, momentum=0.9, weight_decay=5e-4)
        scheduler_dict = {
            'scheduler': OneCycleLR(optimizer, 0.1, epochs=self.trainer.max_epochs,
                                    steps_per_epoch=self._steps_per_epoch()),
            'interval': 'step',
        }
        return {'optimizer': optimizer, 'lr_scheduler': scheduler_dict}

In [None]:
# model
encoder = LitResnet(lr=0.05)
encoder = encoder.to(device)  #for gpu
summary(encoder, (3, 32, 32))
print(encoder)

## Step 3. Training Encoder

* encoder を訓練する。
* 訓練済みモデルをファイルに保存する。

In [None]:
path_encoder = f'{models_path}/{dataset_name}_{encoder_name}.ckpt'

In [None]:
def clear_encoder():
    """Delete the saved encoder checkpoint at ``path_encoder`` if it exists."""
    if not os.path.exists(path_encoder):
        return
    os.remove(path_encoder)

In [None]:
clear_encoder()

In [None]:
# Early stopping watches the validation loss logged by the encoder.
callbacks = [EarlyStopping(monitor="val_loss")]
#callbacks = [LearningRateMonitor(logging_interval='step')],

# gpus follows AVAIL_GPUS so the cell also runs on a machine without a GPU.
# The logger name is a real f-string so runs are grouped per dataset/encoder.
trainer = pl.Trainer(gpus=AVAIL_GPUS, callbacks=callbacks,
                logger=TensorBoardLogger(f'{logs_path}/lightning_logs/', name=f'{dataset_name}_{encoder_name}'))

if not os.path.exists(path_encoder):
    # No checkpoint yet: train the encoder and save it.
    #trainer = pl.Trainer(max_epochs=10, gpus=1)
    trainer.fit(encoder, datamodule)
    trainer.save_checkpoint(path_encoder)
else:
    # Reuse the saved checkpoint; load through the class, not the instance.
    encoder = LitResnet.load_from_checkpoint(path_encoder)

In [None]:
encoder = encoder.to(device)  #for gpu
encoder.freeze()
encoder.eval()

In [None]:
# Test
results = trainer.test(encoder, datamodule)

In [None]:
# functions to show an image
def imshow(img, file=None, title=None):
    """Display an image tensor and optionally save the figure.

    Args:
        img: Channel-first image tensor normalized with mean 0.5 / std 0.5.
            It may live on any device; it is copied to the CPU for plotting.
        file: Output path without extension. When given, the figure is saved
            to ``file + '.png'``.
        title: Optional figure title.
    """
    img = img / 2 + 0.5     # unnormalize (inverse of mean 0.5 / std 0.5)
    # .cpu() is required before .numpy() when the tensor is on a GPU.
    npimg = img.detach().cpu().numpy()
    plt.figure(figsize=(20, 10))
    # matplotlib expects channel-last data, so move the channel axis to the end.
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    if title:
        plt.title(title)
    if file:
        plt.savefig(file + '.png')
    plt.show()

In [None]:
# Original train images
dataiter = iter(datamodule.train_dataloader())
# Use the built-in next(): the loader iterator is not guaranteed to expose a
# .next() method, and the earlier data-size check already uses next().
images, labels = next(dataiter)
imshow(torchvision.utils.make_grid(images, nrow=32), f'{figs_path}/{dataset_name}_{encoder_name}_train_original', title='train images')
# Print the class names of the first 8 images, right-aligned to width 5.
print(' '.join(f'{datamodule.classes[labels[j]]:>5}' for j in range(8)))

In [None]:
# Start tensorboard.
#%reload_ext tensorboard
#%tensorboard --logdir lightning_logs/

## Step 4. Embedding

* 検索対象画像と検索画像を特徴ベクトルへ変換する。
* 特徴ベクトルをファイルに保存する。

In [None]:
path_embeded_train = f'{models_path}/{dataset_name}_{encoder_name}_embeded_train.pickle'
path_embeded_test = f'{models_path}/{dataset_name}_{encoder_name}_embeded_test.pickle'

In [None]:
def clear_embedding():
    """Delete the pickled train and test embedding files if they exist."""
    for pickle_path in (path_embeded_train, path_embeded_test):
        if os.path.exists(pickle_path):
            os.remove(pickle_path)

In [None]:
clear_embedding()

In [None]:
train_dataset = CIFAR10(datasets_path, train=True, download=True)
test_dataset = CIFAR10(datasets_path, train=False, download=True)

In [None]:
len(train_dataset), len(test_dataset)

In [None]:
# with preprocessing (without data augmentation)
train_dataloader = DataLoader(CIFAR10(datasets_path, train=True, download=True, transform=test_transforms()),
                               shuffle=False, batch_size=32, num_workers=0)
test_dataloader = DataLoader(CIFAR10(datasets_path, train=False, download=True, transform=test_transforms()),
                               shuffle=False, batch_size=32, num_workers=0)

In [None]:
len(train_dataloader), len(test_dataloader)

In [None]:
def create_embedder_model(model):
    """Return a feature extractor derived from a trained encoder.

    The network stored in ``model.model`` is deep-copied and the copy's final
    ``fc`` layer is replaced by an identity, so the returned module outputs
    the features that were fed to the classifier head. The encoder passed in
    is left untouched and keeps its classifier.

    Args:
        model: Object exposing the trained network as ``model.model``.

    Returns:
        A copy of ``model.model`` whose ``fc`` layer is ``nn.Identity``.
    """
    import copy

    # Copy first so that replacing fc does not strip the caller's classifier.
    embedder = copy.deepcopy(model.model)
    embedder.fc = nn.Identity()
    return embedder

In [None]:
embedder = create_embedder_model(encoder)
embedder

In [None]:
img_random = torch.randn(32, 3, 32, 32)
img_emb = embedder(img_random)
img_emb.shape

In [None]:
def get_embeded_vector(embedder, dataloader):
    """Embed every image yielded by ``dataloader``.

    Args:
        embedder: Module mapping a batch of images to one feature row per
            image.
        dataloader: Iterable of ``(images, labels)`` batches; labels are
            ignored.

    Returns:
        A list with one 1-D numpy array per image, in dataloader order.
    """
    # Run each batch on the device that holds the embedder's weights; a module
    # without parameters is treated as living on the CPU.
    first_param = next(embedder.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device('cpu')

    vectors = []
    with torch.no_grad():
        for images, _labels in tqdm(dataloader):
            features = embedder(images.to(target_device))
            # Flatten everything except the batch axis. Unlike squeeze(), this
            # keeps a batch of size one as a single row instead of collapsing
            # it into a 1-D tensor that extend() would add element by element.
            features = features.reshape(features.shape[0], -1).cpu()
            vectors.extend(features.numpy())

    return vectors

In [None]:
if not os.path.exists(path_embeded_train):
    train_vectors = get_embeded_vector(embedder, train_dataloader)
    print(len(train_vectors), train_vectors[0].shape)
    with open(path_embeded_train, mode='wb') as f:
        pickle.dump(train_vectors, f)

In [None]:
if not os.path.exists(path_embeded_test):
    test_vectors = get_embeded_vector(embedder, test_dataloader)
    print(len(test_vectors))
    with open(path_embeded_test, mode='wb') as f:
        pickle.dump(test_vectors, f)

## Step 5. Faiss indexes

* Step 4 で作成した検索対象画像の特徴ベクトルを検索するためのインデックスを生成する。
* インデックスをファイルに保存する。

In [None]:
import faiss

In [None]:
path_indexer = f'{models_path}/{dataset_name}_{encoder_name}_indexer.faiss'

In [None]:
def clear_indexer():
    """Delete the saved faiss index file at ``path_indexer`` if it exists."""
    if not os.path.exists(path_indexer):
        return
    os.remove(path_indexer)

In [None]:
clear_indexer()

In [None]:
with open(path_embeded_train, mode='rb') as f:
    train_vectors = np.array(pickle.load(f))

train_vectors.shape

In [None]:
class FlatIndexer(object):
    """GPU-backed faiss index that is either loaded from disk or built fresh.

    When ``path`` points to an existing file, the index is read from it and
    copied to all GPUs. Otherwise an untrained ``IndexIVFFlat`` with an L2
    metric is created on GPU 0 and must be filled with ``index_data``.

    Args:
        vector_sz: Dimensionality of the indexed vectors.
        nlist: Number of inverted lists for a newly created index.
        path: Optional path of a previously saved index.
    """

    def __init__(self, vector_sz: int, nlist=10, path=None):
        if path and os.path.exists(path):
            index_cpu = faiss.read_index(path)
            self.indexer = faiss.index_cpu_to_all_gpus(index_cpu)
        else:
            #index_cpu = faiss.IndexFlatIP(vector_sz) # Not Work
            quantizer = faiss.IndexFlatL2(vector_sz)
            index_cpu = faiss.IndexIVFFlat(quantizer, vector_sz, nlist, faiss.METRIC_L2)
            res = faiss.StandardGpuResources()
            self.indexer = faiss.index_cpu_to_gpu(res, 0, index_cpu)

    @staticmethod
    def _as_float32(vectors):
        """Return ``vectors`` as a C-contiguous float32 array, as faiss expects."""
        return np.ascontiguousarray(vectors, dtype=np.float32)

    def index_data(self, vectors):
        """Train the index on ``vectors`` and then add them to it."""
        vectors = self._as_float32(vectors)
        self.indexer.train(vectors)
        self.indexer.add(vectors)

    def search_knn(self, query_vectors: np.ndarray, top_docs: int):
        """Search the ``top_docs`` nearest indexed vectors for each query.

        Returns:
            ``(scores, indexes)`` exactly as returned by the faiss ``search``
            call, one row per query.
        """
        scores, indexes = self.indexer.search(self._as_float32(query_vectors), top_docs)
        return scores, indexes

    def save_index(self, path):
        """Copy the index back to the CPU and write it to ``path``."""
        index_cpu = faiss.index_gpu_to_cpu(self.indexer)
        faiss.write_index(index_cpu, path)

In [None]:
if not os.path.exists(path_indexer):
    indexer = FlatIndexer(512, nlist=20)
    indexer.index_data(train_vectors)
    indexer.save_index(path_indexer)
else:
    indexer = FlatIndexer(512, path=path_indexer)

## Step 6. Image similarity search

In [None]:
with open(path_embeded_test, mode='rb') as f:
    test_vectors = np.array(pickle.load(f))

test_vectors.shape

In [None]:
scores, indexes = indexer.search_knn(test_vectors, 10)

In [None]:
scores.shape, indexes.shape

In [None]:
for i in range(0, 20):
    print(i, indexes[i])

In [None]:
scores[0]

In [None]:
def show_search_images(train_dataset, test_dataset, indexes, i, dataset_name, encoder_name):
    """Plot query image ``i`` next to the training images retrieved for it.

    The first panel shows ``test_dataset[i]`` titled ``Q[i]``; each following
    panel shows ``train_dataset[idx]`` titled ``A[idx]`` for every ``idx`` in
    ``indexes[i]``. The figure is saved under ``figs_path`` as
    ``{dataset_name}_{encoder_name}_search_images_{i}`` and then displayed.

    Args:
        train_dataset: Indexable dataset of ``(image, label)`` search targets.
        test_dataset: Indexable dataset of ``(image, label)`` queries.
        indexes: Per-query rows of retrieved training indices.
        i: Index of the query to show.
        dataset_name: Dataset tag used in the output file name.
        encoder_name: Encoder tag used in the output file name.
    """
    neighbor_ids = indexes[i]
    # One column for the query plus one per retrieved neighbour. squeeze=False
    # keeps the axes as a 2-D array whatever the number of columns.
    fig, axes = plt.subplots(1, len(neighbor_ids) + 1, figsize=(15, 5), squeeze=False)
    axes = axes[0]

    query_image, _ = test_dataset[i]
    axes[0].imshow(query_image)
    axes[0].set_title(f'Q[{i}]')

    for ax, idx in zip(axes[1:], neighbor_ids):
        ax.set_title(f'A[{idx}]')
        # faiss reports -1 when fewer neighbours than requested were found;
        # leave that panel empty instead of indexing from the end of the data.
        if idx >= 0:
            neighbor_image, _ = train_dataset[idx]
            ax.imshow(neighbor_image)

    for ax in axes:
        ax.set_xticks([])
        ax.set_yticks([])

    # Save before showing so the file never depends on how the backend treats
    # a figure that has already been displayed.
    fig.savefig(f'{figs_path}/{dataset_name}_{encoder_name}_search_images_{i}')
    plt.show()

In [None]:
 for i in range(10): 
    show_search_images(train_dataset, test_dataset, indexes, i, dataset_name, encoder_name)

In [None]:
def merge_search_images(figs_path, dataset_name, encoder_name):
    """Stack the saved per-query search figures vertically into one image.

    Every ``{dataset_name}_{encoder_name}_search_images_*.png`` file under
    ``figs_path`` is loaded in sorted file-name order, cropped by fixed
    margins, and concatenated top to bottom. The result is written to
    ``{dataset_name}_{encoder_name}_search_images.png``.

    Returns:
        The merged ``PIL.Image``.

    Raises:
        FileNotFoundError: If no figure file matches the pattern.
    """
    pattern = f"{figs_path}/{dataset_name}_{encoder_name}_search_images_*.png"
    files = sorted(glob.glob(pattern))
    if not files:
        raise FileNotFoundError(f"no search-image figures match {pattern}")

    strips = []
    for file in files:
        # Close each file as soon as its pixels have been copied out.
        with Image.open(file) as picture:
            pixels = np.array(picture.convert('RGB'))
        height, width, _ = pixels.shape
        strips.append(pixels[120:height - 130, 110:width - 80, :])  # trim

    # Concatenate once instead of re-copying the growing array per file.
    merged = Image.fromarray(np.concatenate(strips, axis=0))
    merged.save(f"{figs_path}/{dataset_name}_{encoder_name}_search_images.png")
    return merged

In [None]:
img = merge_search_images(figs_path, dataset_name, encoder_name)
img

![](../figs/cifar10_resnet_search_images.png)