In [1]:
import os
import csv

In [2]:
from app.slow_classificator import Classificator
from transformers import ViTForImageClassification, ViTImageProcessor
classificator = Classificator(
    ViTForImageClassification.from_pretrained("vit_overfit_last_results/checkpoint-684"),
    ViTImageProcessor.from_pretrained("google/vit-base-patch16-224"),
    device="cuda")

In [3]:
path2sub = "/home/user1/test_data_rkn/sample_submission.csv"
path2ds = "/home/user1/test_data_rkn/dataset"

In [4]:
from torch.utils.data import Dataset, DataLoader
from transformers import ViTImageProcessor
from torchvision.transforms import Compose, Normalize, ToTensor, Resize
from PIL import Image

class CustomImageDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        
        self.id2label = {k: v for k, v in enumerate(sorted(os.listdir(root_dir)))}
        self.label2id = {v: k for k, v in self.id2label.items()}
        
        self.image_paths = []
        self.labels = []

        self.improcessor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')
        
        self.size = self.improcessor.size["height"]
        self.normalize = Normalize(
            mean=self.improcessor.image_mean,
            std=self.improcessor.image_std
        )

        self._transforms = Compose([
            Resize((self.size, self.size)),
            ToTensor(),
            self.normalize
        ])

        for cls in self.id2label.values():
            cls_folder = os.path.join(root_dir, cls)
            if os.path.isdir(cls_folder):
                for img_name in os.listdir(cls_folder):
                    img_path = os.path.join(cls_folder, img_name)
                    self.image_paths.append(img_path)
                    self.labels.append(cls)
    
    def __len__(self):
        return len(self.image_paths)
    
    def __getitem__(self, idx):
        return {
            "pixel_values": self.improcessor(
                images=Image.open(self.image_paths[idx]).convert("RGB")).pixel_values[0].squeeze(), # .squeeze()
            "labels": self.label2id[self.labels[idx]]
        }

In [5]:
ds = CustomImageDataset("/home/user1/hack/train_data_rkn/dataset")

In [6]:
embs = {}
for i in range(len(ds)):
    embs.update({i : classificator.predict_result(Image.open(ds.image_paths[i]).convert("RGB"))})


IOPub data rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_data_rate_limit`.

Current values:
ServerApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
ServerApp.rate_limit_window=3.0 (secs)



In [7]:
print(embs[1])

{'class': 1, 'probs_class': array(0.8927962, dtype=float32), 'embedding': tensor([-4.7827e-01,  1.0817e+00, -1.5218e-01,  1.6220e+00, -1.7477e+00,
        -3.5527e-01,  3.9697e-01,  2.2603e-01, -1.1455e+00,  7.8467e-01,
        -1.3803e+00, -1.7426e+00,  7.3839e-01, -5.8287e-01, -7.5116e-01,
        -1.1761e+00, -4.3580e-01,  1.1276e+00,  2.5845e-01,  1.4696e+00,
         6.3200e-01,  4.8337e-01,  6.2374e-01, -1.4967e+00, -7.0406e-01,
        -9.2718e-01,  7.9730e-01,  3.3994e-01, -1.3709e+00, -8.1227e-01,
         9.5888e-01,  1.2425e+00,  4.8060e-01, -1.0815e+00, -7.6121e-01,
        -5.7987e-02, -1.9587e-01,  1.0074e+00,  1.7819e+00, -1.7208e+00,
        -1.0130e+00,  1.5809e+00,  1.7783e+00,  3.5052e-01, -1.6655e+00,
        -1.8997e+00, -4.3875e-01, -9.6757e-02,  1.7067e-01,  7.8400e-01,
        -1.1159e-01,  3.2458e-01, -8.6776e-01, -1.0329e+00, -3.3496e-02,
        -2.5031e-01, -3.4577e-02,  4.5826e-01,  1.1225e-01,  1.0864e+00,
        -6.3294e-01,  1.8279e+00, -9.7146e-02, -1.

In [55]:
from scipy.spatial.distance import cdist
import numpy as np
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

embeddings_array = np.array([embs[i]["embedding"].cpu().numpy() for i in range(len(embs))])
answer = {}

for img_path in ds.image_paths:
    try:
        output = classificator.predict_result(Image.open(img_path))
    except: continue
    
    dist = cdist([output["embedding"].cpu().numpy()], embeddings_array, metric="cosine")[0]
    closest_images = [ds.image_paths[i] for i in np.argsort(dist)[:10]]
    answer.update({r[0] : closest_images})

In [53]:
answer

{'e624b22d-6895-4a8d-91a3-19fd7e0f4c93.jpg': ['/home/user1/hack/train_data_rkn/dataset/Accordion/a76d509b54dc9b80.jpg',
  '/home/user1/hack/train_data_rkn/dataset/Accordion/2a8c3379fe174472.jpg',
  '/home/user1/hack/train_data_rkn/dataset/Accordion/80d19062a39df158.jpg',
  '/home/user1/hack/train_data_rkn/dataset/Accordion/0f08b4f4d5a27625.jpg',
  '/home/user1/hack/train_data_rkn/dataset/Accordion/e77aa4970b360a96.jpg',
  '/home/user1/hack/train_data_rkn/dataset/Accordion/1234140cb6bd44b6.jpg',
  '/home/user1/hack/train_data_rkn/dataset/Accordion/00eaaf3ecceb80a8.jpg',
  '/home/user1/hack/train_data_rkn/dataset/Accordion/8c9c0d298f6de40a.jpg',
  '/home/user1/hack/train_data_rkn/dataset/Accordion/9ed04f0c54c05152.jpg',
  '/home/user1/hack/train_data_rkn/dataset/Accordion/e2eceae8f2e8f55e.jpg']}