In [131]:
!pip install faiss-gpu



In [132]:
!pip install faiss-cpu



In [133]:
import os
import timm
import json
import faiss
import torch
import numpy as np
import pandas as pd
from torch import nn
from tqdm import tqdm
from PIL import Image
from pathlib import Path
import torch.optim as optim
from torchvision import transforms
from matplotlib import pyplot as plt
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import average_precision_score
from sklearn.model_selection import train_test_split

In [134]:
GLOBAL_SEED=123
np.random.seed(GLOBAL_SEED)

In [135]:
kaggle_root = Path('/kaggle/input')
dataset_root = kaggle_root / 'imaterialist-fashion-2020-fgvc7'
models_root = kaggle_root / 'imaterialist-comparison-models'
train_df_path = dataset_root / 'train.csv'
# images_folder = dataset_root / 'train'
images_folder = kaggle_root / 'imaterialist-train-224x224' / 'train_224'
desc_path = dataset_root / 'label_descriptions.json'

In [136]:
full_dataset_df = pd.read_csv(train_df_path)

In [137]:
full_dataset_groupped = full_dataset_df[
    ['ImageId', 'ClassId']
].groupby('ImageId')[
    'ClassId'
].agg([('ClassIds', lambda xs: set(xs))])

In [138]:
class_descriptions = json.load(open(desc_path))
class_descriptions = list(map(lambda x: x[1], sorted(list(map(
    lambda x: (x['id'], x['name']),
    class_descriptions['categories']
)), key=lambda x: x[0])))
class_descriptions = np.array(class_descriptions)

In [139]:
train_df, test_df = train_test_split(
    full_dataset_groupped,
    test_size=0.2,
    random_state=GLOBAL_SEED,
)

In [140]:
N_CLASSES=len(set(full_dataset_df['ClassId']))
print("Number of classes: ", N_CLASSES)

Number of classes:  46


In [141]:
# model_path = models_root / 'small_swin_metric_ep14_best9737.pth'
# model_path = models_root / 'small_swin_from_scratch_1280.pth'
# model_path = models_root / 'small_swin_transfer_5536.pth'

In [142]:
model = timm.create_model('swin_small_patch4_window7_224.ms_in22k_ft_in1k', pretrained=True)

In [143]:
# model.head.fc = nn.Linear(model.head.fc.in_features, N_CLASSES)
model.head.fc = nn.Identity()

In [144]:
import gc
torch.cuda.empty_cache()
gc.collect()

50

In [145]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)

In [146]:
# model.load_state_dict(torch.load(model_path, map_location=device))

In [147]:
# model.head.fc = nn.Identity()

In [148]:
class FashionDataset(Dataset):
    def __init__(
            self,
            dataframe,
            image_folder,
            n_classes=N_CLASSES,
            transform=None,
        ):
        self.dataframe: pd.DataFrame = dataframe
        self.image_folder: Path = image_folder
        self.transform: nn.Module = transform
        self.n_classes: int = n_classes

    def __len__(self):
        return len(self.dataframe)
    
    def load_transform(self, image_id):
        image = Image.open(self.image_folder / f"{image_id}.jpg")
        if image.mode == 'L':
            image = image.convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image

    def __getitem__(self, idx):
        return self.load_transform(self.dataframe.iloc[idx].name)

In [149]:
input_size = 224
test_transform = transforms.Compose([
    transforms.Resize((input_size, input_size)),
    transforms.CenterCrop(input_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

In [150]:
mem_dataset = FashionDataset(train_df, images_folder, transform=test_transform)
query_dataset = FashionDataset(test_df, images_folder, transform=test_transform)

In [151]:
BATCH_SIZE = 32
mem_loader = DataLoader(mem_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=os.cpu_count(), pin_memory=True)
query_loader = DataLoader(query_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=os.cpu_count(), pin_memory=True)

In [152]:
def is_match(label_match: set, label_query: set, positive_match_percent=0.8):
    """ Detects if positive_match_percent of the query labels are in the match """
    # % present
    p = 1 - len(label_query - label_match) / len(label_query)
    return p >= positive_match_percent

In [153]:
mem_embs = []
with torch.no_grad():
    for img in tqdm(mem_loader):
        mem_embs.append(model(img.to(device)))

100%|██████████| 1141/1141 [02:37<00:00,  7.24it/s]


In [154]:
mem_embs = torch.cat(mem_embs)

In [155]:
query_embs = []
with torch.no_grad():
    for img in tqdm(query_loader):
        query_embs.append(model(img.to(device)))

100%|██████████| 286/286 [00:39<00:00,  7.18it/s]


In [156]:
query_embs = torch.cat(query_embs)

In [157]:
index = faiss.IndexFlatL2(768)
index.add(mem_embs.cpu())

_, indices = index.search(query_embs.cpu(), 20)

In [158]:
top_1, top_5, top_20 = 0, 0, 0

for query_i, match_is in enumerate(indices):
    query_label = test_df.iloc[query_i]["ClassIds"]
    match_mask = [
        is_match(train_df.iloc[match_i]["ClassIds"], query_label)
        for match_i in match_is
    ]

    if any(match_mask[:1]):
        top_1 += 1
    if any(match_mask[:5]):
        top_5 += 1
    if any(match_mask[:20]):
        top_20 += 1

In [159]:
print(f"Top-1 Accuracy: {top_1 / len(query_embs) * 100:.2f}%")
print(f"Top-5 Accuracy: {top_5 / len(query_embs) * 100:.2f}%")
print(f"Top-20 Accuracy: {top_20 / len(query_embs) * 100:.2f}%")

Top-1 Accuracy: 35.98%
Top-5 Accuracy: 58.78%
Top-20 Accuracy: 75.66%
