# 本ノートブックの説明
今回のコンペのデータに関するEDAと、ConvNeXt Baseモデルを使用して画像特徴量を抽出し、faissを使用してコサイン類似度の高い画像を検索して提出ファイルを作成するまでの流れを紹介します。
過去の類似コンペとして[AI×商標：イメージサーチコンペティション（類似商標画像の検出）](https://competition.nishika.com/competitions/patent/summary)がございます。必要に応じてそちらもご確認ください。



## 前提
- 本ノートブックは以下のディレクトリ構成を想定して実装されています。こちらを活用する場合はその点にご留意ください。また、GPU環境（Tesla T4）で動作確認を行っております。


```
/content: Google Colaboratory実行時のカレントディレクトリ
┣ cite_images: 事前作成不要。Google Drive上のcite_images.zip展開後に作成されます。
┣ query_images: 事前作成不要。Google Drive上のquery_images.zip展開後に作成されます
┗ /drive/MyDrive/cpt-sake/: ベースとなる作業ディレクトリ
　　　　　┣ data: 本コンペで提供されているデータを格納するディレクトリ
   ┃ ┣ train.csv: 訓練データ
   ┃ ┣ cite.csv: 引用データ
   ┃ ┣ test.csv: 評価データ
   ┃ ┣ cite_images.zip: 引用画像
   ┃ ┣ query_images.zip: クエリ画像（訓練データ、評価データの画像）
   ┃ ┣ test.csv: 評価データ
   ┃ ┗ sample_submission.csv: 提出用データのサンプル
   ┣ features: 画像特徴量を保存するディレクトリ
   ┣ index: 索引を保存するディレクトリ
   ┣ model: モデルを保存するディレクトリ
   ┗ output: 本コンペで提出するファイルを格納するディレクトリ
```

## 注意事項
- インデックス作成処理については１時間ほど要します。

## 設定
- Google Driveのマウント
- 画像ファイルの展開
- 追加で必要なライブラリ(faiss-gpu, japanize_matplotlib, timm)のインストール

In [None]:
!nvidia-smi


In [None]:
import collections
import math
import os
import random
import typing
from pathlib import Path

import faiss
import japanize_matplotlib
import matplotlib.pyplot as plt
import missingno as msno
import numpy as np
import pandas as pd
import seaborn as sns
import timm
import torch
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from albumentations import Compose, Normalize, Resize
from albumentations.pytorch.transforms import ToTensorV2
from PIL import Image
from sklearn.model_selection import StratifiedKFold
from sklearn.preprocessing import LabelEncoder
from torch import nn
from torch.nn import CrossEntropyLoss, Parameter
from torch.utils.data import DataLoader, Dataset
from tqdm.notebook import tqdm


In [None]:
# Root of the working tree; every directory below is a child of it.
# NOTE(review): the directory layout described in the notebook introduction
# uses a Google Drive path, while this points at a local mount — confirm which
# one applies to the environment the notebook is run in.
BASE_DIR = Path("/media/data/sake_brand_image_search")
DATA_DIR = BASE_DIR.joinpath("data")  # competition CSV files are read from here
FEATURES_DIR = BASE_DIR.joinpath("features")  # extracted embeddings (.npy)
OUT_DIR = BASE_DIR.joinpath("output")  # submission CSV
INDEX_DIR = BASE_DIR.joinpath("index")  # serialized faiss index
MODEL_DIR = BASE_DIR.joinpath("model")  # model state_dict (.pth)
CITE_IMG_DIR = BASE_DIR.joinpath("cite_images")  # images listed in cite.csv
QUERY_IMG_DIR = BASE_DIR.joinpath("query_images")  # images listed in train.csv / test.csv
# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Tag embedded in the file names of the model, embedding, index and submission files.
EXP_NAME = "cpt-sake-tutorial"


class CFG:
    # Settings shared by the model definition and the data loaders.
    # NOTE(review): img_size is not referenced in the visible cells;
    # get_transforms() is always called with its own default — confirm.
    img_size = 224
    model_name = "convnext_base"  # architecture name handed to timm.create_model
    in_channels = 3  # forwarded to timm as in_chans
    embedding_dim = 128  # output width of the linear layer in SakeNet
    pretrained = True  # forwarded to timm.create_model(pretrained=...)
    batch_size = 128  # DataLoader batch size
    n_workers = 0  # DataLoader num_workers
    seed = 0  # passed to seed_torch


In [None]:
def seed_torch(seed=42):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy and PyTorch (CPU and every CUDA
    device) and configures cuDNN for deterministic behaviour.

    Args:
        seed: Integer seed applied to all generators.
    """
    random.seed(seed)
    # Exported for child processes; the running interpreter's hash seed was
    # already fixed at start-up.
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed all CUDA devices, not only the current one (no-op without CUDA).
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # The cuDNN auto-tuner may pick different kernels from run to run, so it is
    # explicitly disabled alongside the deterministic flag.
    torch.backends.cudnn.benchmark = False


seed_torch(CFG.seed)


# データ

In [None]:
# Locate each competition CSV under DATA_DIR and load it into a DataFrame.
cite_filepath = DATA_DIR / "cite.csv"
df_cite = pd.read_csv(cite_filepath)

train_filepath = DATA_DIR / "train.csv"
df_train = pd.read_csv(train_filepath)

test_filepath = DATA_DIR / "test.csv"
df_test = pd.read_csv(test_filepath)

sub_filepath = DATA_DIR / "sample_submission.csv"
df_sub = pd.read_csv(sub_filepath)


## データの形式、概要確認

In [None]:
df_cite.shape


In [None]:
df_cite.head()


In [None]:
df_cite.tail()


In [None]:
df_cite.isna().sum()


In [None]:
df_train.shape


In [None]:
df_train.head()


In [None]:
df_train.tail()


In [None]:
df_train.isna().sum()


In [None]:
df_train[df_train["meigara"].isna()]


In [None]:
df_test.shape


In [None]:
df_test.head()


In [None]:
df_test.tail()

In [None]:
df_test.isna().sum()


## 提出ファイルの形式確認

In [None]:
df_sub.head()


In [None]:
df_sub.dtypes


引用画像ID(gid)をスペースでつないだ文字列を予測結果とします。

In [None]:
pred_sample = df_sub["cite_gid"].values[0]
print(pred_sample)


## 画像ファイルのパス情報追加

In [None]:
# Collect the file names stored in each frame ...
cite_filenames = df_cite["cite_filename"].to_list()
train_filenames = df_train["filename"].to_list()
test_filenames = df_test["filename"].to_list()

# ... and attach the full image path of every row as a "path" column.
# Train and test images share the query image directory.
df_cite["path"] = [str(CITE_IMG_DIR / name) for name in cite_filenames]
df_train["path"] = [str(QUERY_IMG_DIR / name) for name in train_filenames]
df_test["path"] = [str(QUERY_IMG_DIR / name) for name in test_filenames]


In [None]:
# Fit a label encoder on the brand names to see how many distinct classes exist.
from sklearn import preprocessing
le = preprocessing.LabelEncoder()
le.fit(df_train["meigara"])
# NOTE(review): the encoded labels returned here are not stored anywhere;
# only the fitted encoder is kept.
le.transform(df_train["meigara"])
# Cell output: shape of the array of classes found by the encoder.
le.classes_.shape

## 訓練データ brand_id, meigara

酒のブランドを識別するbrand_id

In [None]:
# Plot how often the `topn` most frequent brand_id values occur in the
# training data.
topn = 20
print("訓練データ内のbrand_idの種類数：{:4d}".format(df_train["brand_id"].nunique()))
# Series.value_counts replaces the deprecated top-level pd.value_counts and
# yields the same ordering (most frequent first).
top_brand_ids = df_train["brand_id"].value_counts().iloc[:topn].index
ax = sns.countplot(x=df_train["brand_id"], order=top_brand_ids)
ax.set_xticklabels(ax.get_xticklabels(), rotation=30)
plt.show()


In [None]:
# Plot how often the `topn` most frequent meigara (brand name) values occur
# in the training data.
topn = 20
print("訓練データ内のmeigaraの種類数：{:4d}".format(df_train["meigara"].nunique()))
# Series.value_counts replaces the deprecated top-level pd.value_counts and
# yields the same ordering (most frequent first).
top_meigara = df_train["meigara"].value_counts().iloc[:topn].index
ax = sns.countplot(x=df_train["meigara"], order=top_meigara)
ax.set_xticklabels(ax.get_xticklabels(), rotation=30)
plt.show()


## meigaraが同じ画像サンプルを確認
- meigaraとbrand_idは多くの場合1:1で対応しますが、1:Nの場合も一部存在します。別の酒蔵が同じ銘柄の商品を出している場合、それぞれ別のブランドとして扱われています。
- 今回の画像検索の難しい点として、以下の点が確認できます。
 - お酒にフォーカスして撮影された写真もあれば、料理と一緒に撮影された写真もあります。
 - 同じブランドであってもお酒の容器のラベルレイアウトに複数の種類がある場合があります。


In [None]:
# Count the distinct brand_id values that appear under each meigara.
df_brand = df_train.groupby("meigara")["brand_id"].nunique().to_frame()
# Cell output: only the meigara that map to more than one brand_id.
df_brand[df_brand["brand_id"] > 1]


BLACK JACK

In [None]:
df_train[df_train["meigara"] == "BLACK JACK"].head()


In [None]:
# @title
def show_same_meigara(meigara: str) -> None:
    """Show every training image of one brand name in a two-column grid.

    Rows are selected from the module-level ``df_train`` frame, which must
    provide ``meigara``, ``path`` and ``brand_id`` columns. Each image is
    titled with its ``brand_id``.

    Unlike the earlier version, all matching images are shown: an odd count
    leaves the last grid cell blank instead of dropping the final image, and
    brand names with fewer than four images no longer raise.

    Args:
        meigara: Value of the ``meigara`` column to display.
    """
    matches = df_train.loc[df_train["meigara"] == meigara, ["path", "brand_id"]]
    if matches.empty:
        # plt.subplots cannot build a grid with zero rows.
        print(f"no training images found for meigara: {meigara}")
        return

    n_cols = 2
    n_rows = math.ceil(len(matches) / n_cols)
    # squeeze=False keeps `axs` two-dimensional even when there is one row.
    _, axs = plt.subplots(n_rows, n_cols,
                          figsize=(n_cols * 5, n_rows * 5), squeeze=False)
    cells = axs.ravel()
    for ax in cells:
        ax.axis("off")  # hide the axes, including any unused trailing cell
    for ax, path, brand_id in zip(cells, matches["path"], matches["brand_id"]):
        # Close the file handle as soon as the pixels have been handed over.
        with Image.open(path) as image:
            ax.imshow(image)
        ax.set_title(f"brand_id: {brand_id}")
    plt.show()


In [None]:
show_same_meigara(meigara="BLACK JACK")


亀の尾

In [None]:
df_train[df_train["meigara"] == "亀の尾"].head()


In [None]:
show_same_meigara(meigara="亀の尾")


初桜

In [None]:
df_train[df_train["meigara"] == "初桜"].head()


In [None]:
show_same_meigara(meigara="初桜")


## データセット

In [None]:
IMAGENET_MEAN = [0.485, 0.456, 0.406]  # RGB
IMAGENET_STD = [0.229, 0.224, 0.225]  # RGB


class SakeDataset(Dataset):
    """Dataset that reads images from disk as RGB arrays.

    Every item is a dict with an ``"image"`` entry and, when labels were
    supplied at construction time, a ``"label"`` entry.
    """

    def __init__(self, image_filepaths: list,
                 labels: typing.Optional[list] = None,
                 transform: typing.Optional[typing.Callable[..., typing.Any]] = None) -> None:
        """Store the sample paths, optional labels and optional transform.

        Args:
            image_filepaths: Paths of the image files, one per sample.
            labels: Optional integer labels aligned with ``image_filepaths``.
            transform: Optional callable invoked as ``transform(image=array)``
                whose result is indexed with ``"image"``.
        """
        self.image_filepaths = image_filepaths
        self.labels = labels
        self.transform = transform

    def __len__(self) -> int:
        """Return the number of samples."""
        return len(self.image_filepaths)

    def __getitem__(self, idx: int) -> typing.Dict[str, typing.Any]:
        """Load sample ``idx``.

        Returns:
            A dict holding ``"image"`` (the transformed image, or the raw RGB
            array when no transform is configured) and, if labels exist,
            ``"label"`` as a ``torch.long`` scalar tensor.
        """
        image = self.__read_image(self.image_filepaths[idx])
        if self.transform is not None:
            image = self.transform(image=image)["image"]
        # The image is stored whether or not a transform ran; previously the
        # key was only set inside the transform branch, so datasets built
        # without a transform returned items with no image at all.
        item = {"image": image}

        if self.labels is not None:
            item["label"] = torch.tensor(self.labels[idx], dtype=torch.long)
        return item

    def __read_image(self, path: str) -> np.ndarray:
        """Read the file at ``path`` and return it as an RGB array."""
        # convert() is called while the file is still open so the pixel data
        # is decoded before the handle is closed.
        with open(path, "rb") as f:
            image_rgb = Image.open(f).convert("RGB")
        return np.array(image_rgb)


def get_transforms(img_size: int = 224) -> Compose:
    """Build the preprocessing pipeline applied to every image.

    Args:
        img_size: Height and width passed to ``Resize``.

    Returns:
        A ``Compose`` object chaining ``Resize``, ``Normalize`` (using
        ``IMAGENET_MEAN`` / ``IMAGENET_STD``) and ``ToTensorV2``.
    """
    return Compose(
        [
            Resize(img_size, img_size),
            Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
            ToTensorV2()
        ]
    )


In [None]:
df_train["brand_id"]


In [None]:
train_dataset = SakeDataset(
    image_filepaths=df_train["path"].to_list(),
    labels=df_train["brand_id"].to_list(),
    transform=get_transforms()
)


In [None]:
train_dataset.__getitem__(0)["label"]


# ConvNeXt Baseモデルの画像特徴量で検索

In [None]:
class SakeNet(nn.Module):
    """timm backbone followed by a linear projection to the embedding size."""

    def __init__(
        self,
        cfg
    ):
        """Create the backbone and the projection layer.

        Args:
            cfg: Object exposing ``model_name``, ``pretrained``,
                ``in_channels`` and ``embedding_dim``.

        Raises:
            NotImplementedError: If ``cfg.model_name`` is not an attribute of
                ``timm.models``.
        """
        super().__init__()
        self.cfg = cfg
        if not hasattr(timm.models, cfg.model_name):
            raise NotImplementedError(
                f"unsupported model_name: {cfg.model_name}")
        # NOTE(review): num_classes=0 presumably makes timm return features
        # instead of class logits — confirm against the timm documentation.
        self.backbone = timm.create_model(
            cfg.model_name, num_classes=0, pretrained=cfg.pretrained, in_chans=cfg.in_channels)
        print("load imagenet model_name:", cfg.model_name)
        print("load imagenet pretrained:", cfg.pretrained)
        self.in_features = self.backbone.num_features
        # Freshly initialised layer mapping backbone features to the embedding.
        self.fc = nn.Linear(self.in_features, cfg.embedding_dim)

    def get_embedding(self, image: torch.Tensor) -> torch.Tensor:
        """Return the embedding of a batch of images.

        Args:
            image: Input batch forwarded to the backbone.

        Returns:
            Output of the projection layer applied to the backbone features.
        """
        return self.fc(self.backbone(image))

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        """Alias of ``get_embedding`` so that ``model(image)`` works.

        Without it, calling the module raised because ``nn.Module`` has no
        default ``forward``.
        """
        return self.get_embedding(image)


In [None]:
# Build the network, move it to the target device and persist its weights.
model = SakeNet(cfg=CFG)
model = model.to(DEVICE)
model_path = MODEL_DIR.joinpath(f"{EXP_NAME}.pth")
# No training step runs before this save in the visible cells, so the file
# holds the weights exactly as they were at construction time.
torch.save(model.state_dict(), model_path)


## インデックス作成

In [None]:
def infer(data_loader: DataLoader, model: nn.Module) -> np.ndarray:
    """Compute embeddings for every batch produced by a loader.

    The model is switched to eval mode and run without gradient tracking on
    the module-level ``DEVICE``.

    Args:
        data_loader: Loader whose batches are dicts with an ``"image"`` tensor.
        model: Network exposing ``get_embedding``.

    Returns:
        The per-batch embeddings concatenated along the first axis, in the
        order the loader yielded them.
    """
    model.eval()
    embeddings = []
    # Inference only: one no-grad context around the whole loop instead of
    # re-entering it for every batch.
    with torch.no_grad():
        for batch in tqdm(data_loader):
            images = batch["image"].to(DEVICE, non_blocking=True).float()
            output = model.get_embedding(images)
            embeddings.append(output.detach().cpu().numpy())
    return np.concatenate(embeddings)


In [None]:
# Embed every cite image, in df_cite row order (shuffle is disabled).
cite_dataset = SakeDataset(
    image_filepaths=df_cite["path"].to_list(),
    transform=get_transforms()
)
cite_loader = DataLoader(
    cite_dataset,
    batch_size=CFG.batch_size,
    shuffle=False,
    num_workers=CFG.n_workers,
    pin_memory=True
)
# Re-create the network and restore the weights saved to model_path.
model = SakeNet(cfg=CFG)
model = model.to(DEVICE)
# map_location keeps the load working when the checkpoint was written on a
# different device than the one in use now (e.g. saved on GPU, loaded on CPU).
model.load_state_dict(torch.load(model_path, map_location=DEVICE))
cite_embedding = infer(cite_loader, model)


In [None]:
np.save(FEATURES_DIR.joinpath(
    f"cite_embedding_{EXP_NAME}.npy"), cite_embedding)


In [None]:
cite_embedding = np.load(FEATURES_DIR.joinpath(
    f"cite_embedding_{EXP_NAME}.npy"))
cite_embedding.shape


In [None]:
class FaissKNeighbors:
    """Exact cosine-similarity nearest-neighbour search backed by faiss.

    Vectors are L2-normalised before they are added to, or searched in, an
    inner-product index, so the scores returned by ``predict`` are cosine
    similarities. Previously the raw vectors were indexed, which ranked
    neighbours by unnormalised inner product.
    """

    def __init__(self, model_name: str, index_name: str, k: int = 20) -> None:
        """Configure the searcher; no index exists until ``fit``/``read_index``.

        Args:
            model_name: Name of the embedding model (stored only).
            index_name: Base name of the index file inside ``INDEX_DIR``.
            k: Number of neighbours returned per query.
        """
        self.index = None
        self.d = None
        self.k = k
        self.model_name = model_name
        self.index_name = str(INDEX_DIR.joinpath(f"{index_name}.index"))

    @staticmethod
    def _as_unit_rows(X: np.ndarray) -> np.ndarray:
        """Return a C-contiguous float32 copy of ``X`` with unit-length rows."""
        rows = np.array(X, dtype=np.float32, order="C")
        norms = np.linalg.norm(rows, axis=1, keepdims=True)
        # All-zero rows are left untouched instead of being divided by zero.
        np.divide(rows, norms, out=rows, where=norms > 0)
        return rows

    def fit(self, X: np.ndarray) -> None:
        """Build a new index from the rows of ``X``.

        Args:
            X: Two-dimensional array with one vector per row.
        """
        unit_rows = self._as_unit_rows(X)
        self.d = unit_rows.shape[1]
        # Inner product of unit vectors equals their cosine similarity.
        self.index = faiss.IndexFlatIP(self.d)
        self.index.add(unit_rows)

    def save_index(self) -> None:
        """Write the current index to ``self.index_name``."""
        faiss.write_index(self.index, self.index_name)
        print(f"{self.index_name} saved.")

    def read_index(self) -> None:
        """Load the index stored at ``self.index_name``.

        The stored vectors are used as-is: an index file that was built from
        non-normalised vectors yields inner-product scores, not cosine ones.
        """
        self.index = faiss.read_index(self.index_name)
        self.d = self.index.d
        print(f"{self.index_name} read.")

    def predict(self, X: np.ndarray) -> typing.Tuple[np.ndarray, np.ndarray]:
        """Search the ``k`` most similar indexed vectors for each query.

        Args:
            X: One query vector or a batch; reshaped to ``(-1, d)``.

        Returns:
            ``(scores, indices)``. For a single query both are one-dimensional
            arrays of length ``k``; for a batch both have one row per query.

        Raises:
            RuntimeError: If neither ``fit`` nor ``read_index`` has been called.
        """
        if self.index is None:
            raise RuntimeError(
                "index is not built; call fit() or read_index() first")
        queries = self._as_unit_rows(np.reshape(X, (-1, self.d)))
        distances, indices = self.index.search(queries, k=self.k)
        if queries.shape[0] == 1:
            return distances[0], indices[0]
        return distances, indices


In [None]:
knn = FaissKNeighbors(model_name=CFG.model_name, index_name=EXP_NAME, k=20)
knn.fit(cite_embedding)
knn.save_index()


In [None]:
df_cite.head()


In [None]:
# The search returns row positions of the indexed embeddings, so the lookup is
# keyed by position in df_cite rather than by its index labels; the two only
# coincide while df_cite keeps its default 0..n-1 index.
idx2cite_gid = dict(enumerate(df_cite["cite_gid"]))


## クエリ画像の検索

In [None]:
# Embed the test query images with the same transform and loader settings
# that were used for the cite images.
test_dataset = SakeDataset(
    image_filepaths=df_test["path"].to_list(),
    transform=get_transforms()
)
test_loader = DataLoader(
    test_dataset,
    batch_size=CFG.batch_size,
    shuffle=False,
    num_workers=CFG.n_workers,
    pin_memory=True
)

# `model` is the instance created in an earlier cell.
query_embedding = infer(test_loader, model)
# Persist the query embeddings next to the cite embeddings.
np.save(FEATURES_DIR.joinpath(
    f"query_embedding_{EXP_NAME}.npy"), query_embedding)


In [None]:
query_embedding = np.load(FEATURES_DIR.joinpath(
    f"query_embedding_{EXP_NAME}.npy"))
query_embedding.shape


In [None]:
# Search the index one query at a time and keep each result as a
# space-separated string of cite gids, best match first.
cite_gids = []
for query_vector in tqdm(query_embedding):
    scores, neighbor_rows = knn.predict(query_vector)
    cite_gids.append(" ".join(str(idx2cite_gid[row]) for row in neighbor_rows))

# Attach the predictions to the test frame and write the submission file.
df_test["cite_gid"] = cite_gids
submission_path = OUT_DIR / f"submission_{EXP_NAME}.csv"
df_test[["gid", "cite_gid"]].to_csv(submission_path, index=False)


In [None]:
# @title
def view_result_bygid(df_test: pd.DataFrame, gid: int) -> None:
    """Plot one query image followed by its retrieved cite images.

    The grid has three columns and as many rows as the number of images
    requires, so it no longer depends on exactly 20 results being present.

    Args:
        df_test: Frame with ``gid``, ``path`` and ``cite_gid`` columns, where
            ``cite_gid`` is a whitespace-separated string of cite image ids.
        gid: Query id to display; the first matching row is used.

    Raises:
        IndexError: If no row of ``df_test`` has the requested ``gid``.
    """
    query_row = df_test.loc[df_test["gid"] == gid].iloc[0]
    pred_gids = query_row["cite_gid"].split()

    # NOTE(review): assumes every cite image is stored as "<cite_gid>.jpg"
    # under CITE_IMG_DIR — confirm against the file names in cite.csv.
    paths = [query_row["path"]]
    paths.extend(str(CITE_IMG_DIR.joinpath(cite_gid + ".jpg"))
                 for cite_gid in pred_gids)

    n_cols = 3
    n_rows = math.ceil(len(paths) / n_cols)
    # 20 / 7 keeps the original 10x20 figure when there are 7 rows;
    # squeeze=False keeps `axs` two-dimensional for a single row.
    _, axs = plt.subplots(nrows=n_rows, ncols=n_cols,
                          figsize=(10, n_rows * 20 / 7), squeeze=False)
    cells = axs.ravel()
    for ax in cells:
        # Applied to every cell so unused trailing cells stay blank.
        ax.grid(False)
        ax.axis("off")

    for rank, (ax, path) in enumerate(zip(cells, paths)):
        with Image.open(path) as img:
            ax.imshow(img)
        # Separate name so the `gid` argument is not overwritten in the loop.
        image_gid = Path(path).stem
        if rank == 0:
            title = f"query data gid:{image_gid}"
            color = "red"
        else:
            title = f"rank: {rank}, cite_gid:{image_gid}"
            color = "black"
        ax.set_title(title, color=color)
    plt.show()


## 検索結果の確認

In [None]:
df_test.head()


In [None]:
view_result_bygid(df_test, gid=200108162)


# EOF