In [3]:
import torchvision.models as models

model = models.efficientnet_b4(models.EfficientNet_B4_Weights.DEFAULT)

Downloading: "https://download.pytorch.org/models/efficientnet_b4_rwightman-23ab8bcd.pth" to C:\Users\mauri/.cache\torch\hub\checkpoints\efficientnet_b4_rwightman-23ab8bcd.pth
100%|██████████| 74.5M/74.5M [00:09<00:00, 8.11MB/s]


In [14]:
import torch

def get_emb(model, img):
    x = model.features(img)
    x = model.avgpool(x)
    return torch.flatten(x, 1).squeeze(0).tolist()

In [6]:
models.EfficientNet_B4_Weights.DEFAULT.transforms

functools.partial(<class 'torchvision.transforms._presets.ImageClassification'>, crop_size=380, resize_size=384, interpolation=<InterpolationMode.BICUBIC: 'bicubic'>)

In [19]:
import json
from PIL import Image
import torch
from torchvision import transforms
import os 
import numpy as np
from tqdm import tqdm

dir_str = "images"
directory = os.fsencode(dir_str)
embeddings = list()

for file in tqdm(os.listdir(directory)):
    filename = os.fsdecode(file)
    img = transforms.Compose([
        transforms.Resize((384, 384), Image.BICUBIC),
        transforms.CenterCrop(380),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])(Image.open('/'.join((dir_str,filename)))).unsqueeze(0)
    # print(img.shape) # torch.Size([1, 3, 384, 384])

    with torch.no_grad():
        outputs = get_emb(model, img)

    embeddings.append(outputs)

100%|██████████| 41/41 [00:20<00:00,  2.01it/s]


In [18]:
len(embeddings[0])

1792

In [20]:
embeddings = np.array(embeddings)
embeddings.shape

(41, 1792)

In [26]:
base = embeddings[35]

cos_sims = list()

for idx, img in enumerate(embeddings):
    cos_sim = np.dot(base, img)/(np.linalg.norm(base)*np.linalg.norm(img))
    cos_sims.append((cos_sim, idx))

cos_sims.sort(reverse=True)

In [27]:
cos_sims

[(1.0, 35),
 (0.8868286420362609, 6),
 (0.8264174662778809, 11),
 (0.7878606883208698, 33),
 (0.7700424473057442, 19),
 (0.769071105898622, 12),
 (0.7633499864789764, 1),
 (0.7607359268288509, 4),
 (0.743370488381418, 10),
 (0.7430450052114606, 32),
 (0.7284250071864187, 13),
 (0.7159573363484643, 31),
 (0.71497390019577, 21),
 (0.6817514279012727, 38),
 (0.6420223216477583, 34),
 (0.628414056483935, 3),
 (0.6103807483675272, 23),
 (0.6095492616837259, 17),
 (0.6011928739028061, 39),
 (0.5869402824824432, 22),
 (0.5779961887313033, 16),
 (0.5556637374567358, 8),
 (0.5444095932209682, 7),
 (0.528101382790358, 15),
 (0.5094797516735364, 25),
 (0.5043835469240293, 0),
 (0.5033128592476008, 30),
 (0.5015538902980392, 37),
 (0.49744471357147474, 20),
 (0.4695109853765545, 26),
 (0.43682679176403605, 5),
 (0.4311139129250266, 2),
 (0.42934942979144747, 29),
 (0.4168917021425288, 27),
 (0.4135452006091975, 14),
 (0.398632075550831, 24),
 (0.39284851171931223, 9),
 (0.38620548978237407, 40),
 

In [25]:
for idx, file in enumerate(os.listdir(directory)):
    filename = os.fsdecode(file)
    print(idx, filename)

0 img_0_0.jpg
1 img_0_1.jpg
2 img_0_2.jpg
3 img_10_0.jpg
4 img_10_1.jpg
5 img_10_2.jpg
6 img_11_0.jpg
7 img_11_1.jpg
8 img_11_2.jpg
9 img_12_0.jpg
10 img_12_1.jpg
11 img_12_2.jpg
12 img_13_0.jpg
13 img_13_1.jpg
14 img_13_2.jpg
15 img_14_0.jpg
16 img_1_0.jpg
17 img_1_1.jpg
18 img_1_2.jpg
19 img_2_0.jpg
20 img_2_1.jpg
21 img_2_2.jpg
22 img_3_0.jpg
23 img_3_1.jpg
24 img_3_2.jpg
25 img_4_0.jpg
26 img_4_1.jpg
27 img_4_2.jpg
28 img_5_0.jpg
29 img_5_1.jpg
30 img_5_2.jpg
31 img_6_0.jpg
32 img_6_1.jpg
33 img_6_2.jpg
34 img_7_0.jpg
35 img_7_1.jpg
36 img_8_0.jpg
37 img_8_1.jpg
38 img_8_2.jpg
39 img_9_0.jpg
40 img_9_1.jpg
