In [None]:
import torch
import torch.nn as nn
import torchvision.models as models
from torchvision import transforms
from PIL import Image
from torchvision.datasets import ImageFolder
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
import pickle



# Define the CNN model to extract visual features
class VisualENcoder_resnet(nn.Module):
    """ResNet-18 backbone followed by a ReLU and a linear projection to 50 values.

    The backbone's final fully connected layer is replaced by a new
    ``nn.Linear`` producing ``encoding_dim`` outputs; ``linear1`` then maps
    those to the 50-dimensional vector returned by ``forward``.

    Args:
        encoding_dim: output width of the replaced backbone ``fc`` layer and
            input width of ``linear1``.
    """

    def __init__(self, encoding_dim):
        super().__init__()
        # Build the backbone first, then swap its classification head.
        backbone = models.resnet18(pretrained=True)
        backbone.fc = nn.Linear(backbone.fc.in_features, encoding_dim)
        self.resnet = backbone
        self.relu1 = nn.ReLU()
        self.linear1 = nn.Linear(encoding_dim, 50)

    def forward(self, x):
        """Return the 50-dimensional projection of the backbone output for ``x``."""
        hidden = self.relu1(self.resnet(x))
        return self.linear1(hidden)


encoding_dim = 4096
# The projection layers stacked on the backbone are randomly initialised, so pin
# the seed to make the extracted features (and the saved similarities) repeatable.
torch.manual_seed(0)
visual_model = VisualENcoder_resnet(encoding_dim)

# Load the dataset and pre-process the images
transform1 = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    # torchvision's ImageNet-pretrained ResNet weights expect inputs normalised
    # with these per-channel statistics.
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

image_dir = 'images_here/'
dataset = ImageFolder(image_dir, transform=transform1)
# shuffle=False is required: batches must arrive in dataset.imgs order so that
# feature row i can be matched with dataset.imgs[i] below.
dataloader = torch.utils.data.DataLoader(dataset, batch_size=128, shuffle=False)

# Precompute the visual features of the items using the CNN model.
# ImageFolder yields (image, class_index) pairs where class_index identifies the
# sub-folder, not the individual image. Keying features by it collapses every
# image of a folder into a single entry, so each image is keyed by its position
# in dataset.imgs instead.
item_features = {}
visual_model.eval()
next_row = 0
with torch.no_grad():
    for batch_idx, (images, _folder_ids) in enumerate(dataloader):
        print(batch_idx)
        features = visual_model(images).cpu().numpy()
        for feature_vector in features:
            item_features[next_row] = feature_vector
            next_row += 1

# Convert the item features dictionary to a numpy array (row i <-> dataset.imgs[i])
num_items = len(item_features)
assert num_items == len(dataset.imgs), "expected one feature vector per image"
feature_dim = visual_model.linear1.out_features
item_features_array_resnet = np.zeros((num_items, feature_dim), dtype=np.float32)
for row in range(num_items):
    item_features_array_resnet[row] = item_features[row]
similarities_item_features_array_resnet = cosine_similarity(item_features_array_resnet)
np.save("resnet_sim.npy", similarities_item_features_array_resnet)

# Persist which image each row/column of the similarity matrix belongs to.
# For a one-image-per-folder layout the row index equals the ImageFolder class
# index, so the pickle content is unchanged in that case.
dict_mapping_path_class = {}
for row, (path, _folder_id) in enumerate(dataset.imgs):
    dict_mapping_path_class[path] = row
with open('resnet_mapping_path_classid.pickle', 'wb') as f:
    pickle.dump(dict_mapping_path_class, f)

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 204MB/s]


0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52


In [None]:
!unzip zipped_index_based_images.zip -d images_here 

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
  inflating: images_here/index_based_images/13449.jpg  
  inflating: images_here/index_based_images/2407.jpg  
  inflating: images_here/index_based_images/38814.jpg  
  inflating: images_here/index_based_images/2361.jpg  
  inflating: images_here/index_based_images/12031.jpg  
  inflating: images_here/index_based_images/34784.jpg  
  inflating: images_here/index_based_images/25689.jpg  
  inflating: images_here/index_based_images/20929.jpg  
  inflating: images_here/index_based_images/2375.jpg  
  inflating: images_here/index_based_images/4062.jpg  
  inflating: images_here/index_based_images/6675.jpg  
  inflating: images_here/index_based_images/6885.jpg  
  inflating: images_here/index_based_images/1132.jpg  
  inflating: images_here/index_based_images/25879.jpg  
  inflating: images_here/index_based_images/1654.jpg  
  inflating: images_here/index_based_images/3043.jpg  
  inflating: images_here/index_based_images/2349

In [None]:
!pip3 install timm

In [None]:
from sklearn.metrics.pairwise import cosine_similarity
import timm

# Width of the hidden projection layer (fc1 output / fc2 input) of the ViT encoder.
encoding_dim = 4096

# Define the Vision Transformer (ViT) model to extract visual features
class VisualENcoder_Vit(nn.Module):
    """timm ``vit_base_patch16_224`` followed by a two-layer projection to 50 values.

    The 1000-wide output of the pretrained model is passed through
    ``fc1`` (1000 -> ``encoding_dim``), a ReLU, and ``fc2``
    (``encoding_dim`` -> 50).

    Args:
        encoding_dim: width of the hidden layer between ``fc1`` and ``fc2``.
    """

    def __init__(self, encoding_dim):
        super().__init__()
        self.model = timm.create_model('vit_base_patch16_224', pretrained=True)
        self.fc1 = nn.Linear(1000, encoding_dim)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(encoding_dim, 50)

    def forward(self, x):
        """Return the 50-dimensional projection of the ViT output for ``x``."""
        hidden = self.relu1(self.fc1(self.model(x)))
        return self.fc2(hidden)



encoding_dim = 4096
# fc1/fc2 are randomly initialised, so pin the seed to make the extracted
# features (and the saved similarities) repeatable.
torch.manual_seed(0)
visual_model = VisualENcoder_Vit(encoding_dim)

# timm stores the input normalisation the pretrained weights were trained with
# in the model's config; fall back to 0.5/0.5 if the entry is missing.
vit_cfg = getattr(visual_model.model, 'default_cfg', None) or {}
norm_mean = vit_cfg.get('mean', (0.5, 0.5, 0.5))
norm_std = vit_cfg.get('std', (0.5, 0.5, 0.5))

# Load the dataset and pre-process the images
transform1 = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=norm_mean, std=norm_std),
])

image_dir = 'images_here/'
dataset = ImageFolder(image_dir, transform=transform1)
# shuffle=False is required: batches must arrive in dataset.imgs order so that
# feature row i can be matched with dataset.imgs[i] below.
dataloader = torch.utils.data.DataLoader(dataset, batch_size=128, shuffle=False)

# Precompute the visual features of the items using the ViT model.
# ImageFolder yields (image, class_index) pairs where class_index identifies the
# sub-folder, not the individual image. Keying features by it collapses every
# image of a folder into a single entry, so each image is keyed by its position
# in dataset.imgs instead.
item_features = {}
visual_model.eval()
next_row = 0
with torch.no_grad():
    for batch_idx, (images, _folder_ids) in enumerate(dataloader):
        print(batch_idx)
        features = visual_model(images).cpu().numpy()
        for feature_vector in features:
            item_features[next_row] = feature_vector
            next_row += 1

# Convert the item features dictionary to a numpy array (row i <-> dataset.imgs[i])
num_items = len(item_features)
assert num_items == len(dataset.imgs), "expected one feature vector per image"
feature_dim = visual_model.fc2.out_features
item_features_array_vit = np.zeros((num_items, feature_dim), dtype=np.float32)
for row in range(num_items):
    item_features_array_vit[row] = item_features[row]
similarities_item_features_array_vit = cosine_similarity(item_features_array_vit)
np.save("vit_sim.npy", similarities_item_features_array_vit)

# Persist which image each row/column of the similarity matrix belongs to.
# For a one-image-per-folder layout the row index equals the ImageFolder class
# index, so the pickle content is unchanged in that case.
dict_mapping_path_class = {}
for row, (path, _folder_id) in enumerate(dataset.imgs):
    dict_mapping_path_class[path] = row
with open('vit_mapping_path_classid.pickle', 'wb') as f:
    pickle.dump(dict_mapping_path_class, f)

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
