-
Notifications
You must be signed in to change notification settings - Fork 3
/
generate_embeddings.py
77 lines (58 loc) · 1.95 KB
/
generate_embeddings.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import os
import json
import numpy as np
import torch
import tqdm
from PIL import Image
from .models import create_and_load_from_hub
# Key used to pick one entry out of each image's metadata ``urls`` mapping
# when recording the URL that accompanies an embedding.
IMAGE_SIZE_OUTPUT = 'small'
# Root folder of the dataset; the script entry point reads from its ``images``
# subfolder and writes to its ``embeddings`` subfolder.
base_folder = './datasets/unsplash'
def list_files(dir, image_type='jpg'):
    """Recursively collect the paths of files under *dir* with a given extension.

    Args:
        dir: Root directory to walk. The parameter name is kept for backward
            compatibility even though it shadows the builtin.
        image_type: File extension to keep, without the leading dot.

    Returns:
        A list of ``os.path.join(parent, filename)`` paths in ``os.walk``
        order. The list is empty when nothing matches or *dir* does not exist.
    """
    suffix = f'.{image_type}'
    return [
        os.path.join(parent_path, filename)
        for parent_path, _, filenames in os.walk(dir)
        for filename in filenames
        # Match the extension only: a substring test would also accept names
        # such as ``photo.jpg.json`` or ``photo.jpg.bak``.
        if filename.endswith(suffix)
    ]
def load_images(files):
    """Open every path in *files* with ``Image.open`` and map path -> image.

    Args:
        files: Iterable of image file paths.

    Returns:
        A dict keyed by the given paths whose values are the opened images.
        The images are not closed here; that is left to the caller.
    """
    opened = {}
    for path in files:
        opened[path] = Image.open(path)
    return opened
def calculate_image_embeddings(model, images, device='cpu'):
    """Encode a batch of images into float32 embeddings with the model's vision encoder.

    Args:
        model: Object exposing ``image_transform`` (applied to each image and
            expected to return a tensor) and ``vision_encoder`` (applied to the
            stacked batch).
        images: Iterable of images accepted by ``model.image_transform``. It
            must be non-empty, because ``torch.stack`` rejects an empty list.
        device: Device (string or ``torch.device``) on which the inputs and the
            returned embeddings are placed.

    Returns:
        The encoder output cast to float32 and placed on *device*, computed
        without gradient tracking.
    """
    from contextlib import nullcontext

    images_input = torch.stack([model.image_transform(img).to(device) for img in images])

    # CUDA autocast only influences CUDA ops. Entering it for any other device
    # changes nothing numerically and makes PyTorch warn on hosts without
    # CUDA, so it is enabled only when the target device really is CUDA.
    on_cuda = torch.device(device).type == 'cuda'
    autocast_ctx = torch.autocast(device_type='cuda') if on_cuda else nullcontext()

    with autocast_ctx, torch.no_grad():
        images_embeddings = model.vision_encoder(images_input).float().to(device)
    return images_embeddings
def generate_unsplash_embeddings(input_folder, output_folder, model=None, device='cpu'):
    """Embed every ``.jpg`` under *input_folder* and save embeddings plus image URLs.

    For each image returned by ``list_files``, the sibling metadata file (same
    path with the extension swapped for ``.json``) is read and its
    ``urls[IMAGE_SIZE_OUTPUT]`` entry is recorded. Two files are written:
    ``embeddings.npy`` and ``urls.json`` inside *output_folder*. Entry *i* of
    both files refers to the same image.

    Args:
        input_folder: Directory searched recursively for images.
        output_folder: Directory receiving the output files; created if missing.
        model: Model passed to ``calculate_image_embeddings``. When ``None``, one
            is created with ``create_and_load_from_hub`` and moved to *device*.
        device: Device used for inference.

    Raises:
        FileNotFoundError: If an image has no sibling ``.json`` metadata file.
        KeyError: If the metadata lacks ``urls`` or the configured size key.
    """
    # Compare against None explicitly: truthiness of a model object is not a
    # reliable "was it provided" signal.
    if model is None:
        model = create_and_load_from_hub()
        # Inputs are moved to `device`, so a model created here must follow.
        model.to(device)

    all_embeddings = []
    urls = []

    for file in tqdm.tqdm(list_files(input_folder)):
        # Swap only the final extension; str.replace would also rewrite any
        # '.jpg' occurring earlier in the path (e.g. in a directory name).
        metadata_path = os.path.splitext(file)[0] + '.json'
        with open(metadata_path, 'r', encoding='utf-8') as f:
            image_data = json.load(f)
        url = image_data['urls'][IMAGE_SIZE_OUTPUT]

        # Close the image file deterministically once it has been embedded.
        with Image.open(file) as image:
            embeddings = calculate_image_embeddings(model, [image], device=device)

        # Append both together so the two output lists stay aligned.
        urls.append(url)
        all_embeddings.append(embeddings[0].cpu().numpy())

    # Save embeddings
    os.makedirs(output_folder, exist_ok=True)
    np.save(os.path.join(output_folder, 'embeddings.npy'), all_embeddings)
    with open(os.path.join(output_folder, 'urls.json'), 'w') as f:
        json.dump(urls, f, indent=4)
if __name__ == '__main__':
    # NOTE(review): this file uses a relative import (`from .models import ...`),
    # so it presumably has to be launched as a module of its package
    # (`python -m ...`) rather than as a bare script -- confirm.

    # Target device for inference: hpu, cuda, or cpu.
    device = 'cpu'

    # Build the model and place it on the chosen device.
    model = create_and_load_from_hub()
    model.to(device)

    # Read images from the dataset folder and write embeddings next to them.
    images_folder = f'{base_folder}/images'
    embeddings_folder = f'{base_folder}/embeddings'
    generate_unsplash_embeddings(
        input_folder=images_folder,
        output_folder=embeddings_folder,
        model=model,
        device=device,
    )