In [20]:
import json
from PIL import Image
import torch
from torchvision import transforms
from pytorch_pretrained_vit import ViT

model = ViT('B_16_imagenet1k', image_size=768, pretrained=True)
model.eval()
delattr(model, 'fc')

Resized positional embeddings from torch.Size([1, 577, 768]) to torch.Size([1, 2305, 768])
Loaded pretrained weights.


In [22]:
import os 
import numpy as np
from tqdm import tqdm

dir_str = "images"
directory = os.fsencode(dir_str)
embeddings = list()

for file in tqdm(os.listdir(directory)):
    filename = os.fsdecode(file)
    img = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(0.5, 0.5),
    ])(Image.open('/'.join((dir_str,filename)))).unsqueeze(0)
    # print(img.shape) # torch.Size([1, 3, 384, 384])

    with torch.no_grad():
        outputs = model(img)[:,0].squeeze(0).tolist()

    embeddings.append(outputs)

100%|██████████| 41/41 [07:33<00:00, 11.05s/it]


In [23]:
embeddings

[[tensor(7.0143),
  tensor(-2.7055),
  tensor(-2.1418),
  tensor(13.1341),
  tensor(1.2271),
  tensor(-6.3945),
  tensor(10.9924),
  tensor(-0.1560),
  tensor(5.0019),
  tensor(-1.0378),
  tensor(8.6506),
  tensor(-10.2951),
  tensor(8.7375),
  tensor(-3.6895),
  tensor(0.7186),
  tensor(-6.4601),
  tensor(-4.6152),
  tensor(4.3083),
  tensor(2.2352),
  tensor(-9.9109),
  tensor(1.3594),
  tensor(-1.9578),
  tensor(0.4157),
  tensor(-3.3598),
  tensor(-2.4260),
  tensor(2.6514),
  tensor(0.8050),
  tensor(5.9400),
  tensor(8.0926),
  tensor(-5.9861),
  tensor(0.3052),
  tensor(7.6694),
  tensor(-5.3975),
  tensor(-2.3021),
  tensor(3.8280),
  tensor(-6.5270),
  tensor(-6.3145),
  tensor(4.1091),
  tensor(0.9636),
  tensor(-1.1994),
  tensor(3.7440),
  tensor(0.7334),
  tensor(3.4859),
  tensor(6.3741),
  tensor(-1.3310),
  tensor(11.7292),
  tensor(-10.9846),
  tensor(-4.5327),
  tensor(-0.4406),
  tensor(0.2407),
  tensor(-0.0981),
  tensor(-1.8241),
  tensor(-9.9842),
  tensor(4.1608

In [25]:
print(len(embeddings))
print(len(embeddings[0]))

41
768


In [26]:
embs = np.zeros((41,768))

for i,l in enumerate(embeddings):
    for j in range(768):
        embs[i,j] = l[j].numpy()[()]

In [27]:
embs

array([[  7.01434326,  -2.70548153,  -2.1418426 , ...,   2.42300963,
         -7.88299799,   5.05245399],
       [  7.24234295,  -0.77021581,  -3.60442257, ...,   2.58113313,
         -7.81171513,   5.00962639],
       [  4.29610109,  -2.53905582,  -0.52308047, ...,   0.55602205,
         -6.08861113,   6.94561481],
       ...,
       [-13.12815285,  -8.38469124,   5.31107521, ...,  -3.98978424,
         -5.6503315 ,  -4.78499603],
       [  0.72535276,  -2.60728073,  17.71445084, ...,  -0.95993841,
        -12.25497246,   3.08994246],
       [  1.88694263,  -4.88599205,   5.29348946, ...,  -8.24641705,
         -9.26574516,   6.27256203]])

In [33]:
base = embs[35]

cos_sims = list()

for idx, img in enumerate(embs):
    cos_sim = np.dot(base, img)/(np.linalg.norm(base)*np.linalg.norm(img))
    cos_sims.append((cos_sim, idx+1))

cos_sims.sort(reverse=True)

In [34]:
cos_sims

[(0.9537418195637598, 1),
 (0.8518883541618214, 2),
 (0.5439786257601165, 4),
 (0.4930890323658644, 3),
 (0.4881598416141021, 5),
 (0.3759238627318705, 32),
 (0.3640617735033462, 31),
 (0.3315114984982245, 33),
 (0.31628115360415593, 19),
 (0.2970782841194456, 40),
 (0.2839620340764554, 16),
 (0.2810167614585272, 17),
 (0.2567501764901122, 39),
 (0.2521454373046276, 9),
 (0.24425394288134003, 10),
 (0.24233853280616563, 11),
 (0.23317455959463296, 18),
 (0.22836924810883286, 25),
 (0.22834409844460032, 21),
 (0.22401766702117637, 20),
 (0.21735706907820554, 37),
 (0.21146940704683276, 29),
 (0.20983203690780555, 28),
 (0.20836833389361312, 30),
 (0.1942673461754875, 23),
 (0.1914241129633079, 35),
 (0.18792679396223805, 26),
 (0.18661329543260183, 27),
 (0.17489777195911163, 24),
 (0.17281627491008716, 34),
 (0.1717747895859106, 15),
 (0.16462163878409478, 7),
 (0.16225991696525285, 22),
 (0.15441629457065967, 36),
 (0.14557977088964785, 6),
 (0.13195632227756496, 13),
 (0.114151828991

In [35]:
for idx, file in enumerate(os.listdir(directory)):
    filename = os.fsdecode(file)
    print(idx, filename)

0 img_0_0.jpg
1 img_0_1.jpg
2 img_0_2.jpg
3 img_10_0.jpg
4 img_10_1.jpg
5 img_10_2.jpg
6 img_11_0.jpg
7 img_11_1.jpg
8 img_11_2.jpg
9 img_12_0.jpg
10 img_12_1.jpg
11 img_12_2.jpg
12 img_13_0.jpg
13 img_13_1.jpg
14 img_13_2.jpg
15 img_14_0.jpg
16 img_1_0.jpg
17 img_1_1.jpg
18 img_1_2.jpg
19 img_2_0.jpg
20 img_2_1.jpg
21 img_2_2.jpg
22 img_3_0.jpg
23 img_3_1.jpg
24 img_3_2.jpg
25 img_4_0.jpg
26 img_4_1.jpg
27 img_4_2.jpg
28 img_5_0.jpg
29 img_5_1.jpg
30 img_5_2.jpg
31 img_6_0.jpg
32 img_6_1.jpg
33 img_6_2.jpg
34 img_7_0.jpg
35 img_7_1.jpg
36 img_8_0.jpg
37 img_8_1.jpg
38 img_8_2.jpg
39 img_9_0.jpg
40 img_9_1.jpg
