In [None]:
import os

import numpy as np
import torch
import torchvision
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm.auto import tqdm

from datasets.utils import PreprocessingDataset
from models.utils import get_model_by_name
from utils.environment import modified_environ

# Create image embeddings

In [None]:
# Dataset whose images will be embedded (must be one of the supported names)
DATASET = "Wikimedia"
assert DATASET in ("UGallery", "Wikimedia")

In [None]:
# Parameters
# Batch size and worker count handed to the DataLoader.
BATCH_SIZE = 8
NUM_WORKERS = 4
# Glob patterns of the image files to load (passed to the dataset as `extensions`).
IMAGES_EXT = ["*.gif", "*.jpg", "*.jpeg", "*.png", "*.webp"]
# Request the GPU; the CPU is used instead when CUDA is not available.
USE_GPU = True

# Model
MODEL = "resnet50"
# Layer whose output is extracted. When left empty, the output of the last
# layer before the classification is extracted.
LAYER = ""
assert MODEL in ("alexnet", "vgg16", "resnet50")

# Images path (stays None when DATASET matches none of the known names)
IMAGES_DIR = {
    "Wikimedia": os.path.join("/", "mnt", "data2", "wikimedia", "mini-images-224-224-v2"),
    "UGallery": os.path.join("/", "mnt", "workspace", "Ugallery", "mini-images-224-224-v2"),
}.get(DATASET)


In [None]:
# Paths (output)
# Suffix naming the extracted layer in the file name; empty when LAYER is unset.
if LAYER:
    LAYERED_OUTPUT = f"-{LAYER}"
else:
    LAYERED_OUTPUT = ""
# The embedding is stored at data/<DATASET>/embedding-<MODEL>[-<LAYER>].npy
OUTPUT_EMBEDDING_PATH = os.path.join(
    "data",
    DATASET,
    "embedding-{}{}.npy".format(MODEL, LAYERED_OUTPUT),
)


In [None]:
import PIL
from PIL import ImageFile


# The Wikimedia dataset contains some "broken" (truncated) image files;
# this flag lets Pillow load them anyway.
ImageFile.LOAD_TRUNCATED_IMAGES = True
# Some Wikimedia images are very large, so Pillow's pixel-count limit is raised.
# NOTE(review): this relaxes Pillow's decompression-bomb safeguard - only
# acceptable because the image source is trusted.
PIL.Image.MAX_IMAGE_PIXELS = 3_000_000_000


In [None]:
%%time
# Pick the torch device: the first CUDA GPU when it is both requested
# (USE_GPU) and available, the CPU otherwise.
print("\nCreating device...")
cuda_available = torch.cuda.is_available()
if cuda_available and USE_GPU:
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Tell the user when the request and the hardware disagree; in both cases
# the run ends up on the CPU.
if cuda_available != USE_GPU:
    print(
        f"\nNotice: Not using GPU - "
        f"Cuda available ({cuda_available}) "
        f"does not match USE_GPU ({USE_GPU})"
    )

# Load the feature-extraction model.
# TORCH_HOME is overridden to "." while the model is created - presumably so
# the pretrained weights are cached in the working directory (TODO confirm
# against `modified_environ` / `get_model_by_name`).
print("\nDownloading model...")
with modified_environ(TORCH_HOME="."):
    print(f"Model: {MODEL} (pretrained on imagenet)")
    model = get_model_by_name(MODEL, output_layer=LAYER)
    # Switch to evaluation mode, then move the model to the selected device
    model = model.eval()
    model = model.to(device)

# Preprocessing pipeline, dataset and dataloader
print("\nSetting up transforms and dataset...")
images_transforms = transforms.Compose(
    [
        # Resize, then take a 224x224 center crop
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        # Per-channel normalization (the usual ImageNet statistics)
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
image_dataset = PreprocessingDataset(IMAGES_DIR, extensions=IMAGES_EXT, transform=images_transforms)
# No shuffling: items are visited in dataset order
image_dataloader = DataLoader(
    image_dataset,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    shuffle=False,
    pin_memory=True,
)
print(f">> Images dataset: {len(image_dataset)}")

# Calculate embedding dimension size
# A dummy batch of one image (same shape as the first dataset item) is pushed
# through the model to discover the output shape. No gradients are needed, so
# the forward pass runs under no_grad to avoid building an autograd graph.
with torch.no_grad():
    dummy_input = torch.ones(1, *image_dataset[0]["image"].size()).to(device)
    dummy_output = model(dummy_input)
# The extraction loop squeezes the two trailing dimensions of every batch output
# before storing it; the per-item shape is measured after the same squeeze so
# the pre-allocated buffer matches what is actually written into it.
emb_dim = dummy_output.squeeze(-1).squeeze(-1).shape[1:] if LAYER else dummy_output.size(1)
print(f">> Embedding dimension size: {emb_dim}")

# Feature extraction phase
print("\nFeature extraction...")
# One slot per dataset item: its id, and its embedding
output_ids = np.empty(len(image_dataset), dtype=object)
# The full-dataset buffer lives in host memory: only the current batch needs to
# be on the compute device, and the result is converted to NumPy afterwards.
# Keeping it off the GPU avoids exhausting GPU memory on large datasets.
if LAYER:
    output_embedding = torch.zeros((len(image_dataset), *emb_dim), dtype=torch.float32)
else:
    output_embedding = torch.zeros((len(image_dataset), emb_dim), dtype=torch.float32)

with torch.no_grad():
    for sample in tqdm(image_dataloader, desc="Feature extraction"):
        # Each batch provides "image", "idx" and "id" entries
        # (presumably idx is the item's position in the dataset - verify
        # against PreprocessingDataset)
        item_image = sample["image"].to(device)
        item_idx = sample["idx"]
        output_ids[[*item_idx]] = sample["id"]
        # Drop trailing singleton dimensions, then copy the batch result to the host
        batch_embedding = model(item_image).squeeze(-1).squeeze(-1)
        output_embedding[item_idx] = batch_embedding.cpu()

output_embedding = output_embedding.numpy()

# Fill output embedding
# Object array with one row per image: column 0 holds the id, column 1 holds
# the embedding vector.
embedding = np.empty((len(image_dataset), 2), dtype=object)
for i in range(len(image_dataset)):
    # Assign the two cells directly. Building the row with
    # np.asarray([id, vector]) creates a ragged array, which NumPy >= 1.24
    # rejects with a ValueError.
    embedding[i, 0] = output_ids[i]
    embedding[i, 1] = output_embedding[i]
print(f">> Embedding shape: {embedding.shape}")

# Save embedding to file
print(f"\nSaving embedding to file... ({OUTPUT_EMBEDDING_PATH})")
# Make sure the destination directory exists, so the extraction work is not
# lost to a FileNotFoundError at the very last step.
os.makedirs(os.path.dirname(OUTPUT_EMBEDDING_PATH) or ".", exist_ok=True)
# The array has dtype=object, so it is stored with pickle; loading it back
# requires np.load(..., allow_pickle=True).
np.save(OUTPUT_EMBEDDING_PATH, embedding, allow_pickle=True)

# Free some memory
# Check the device actually in use rather than the USE_GPU request: when CUDA
# was unavailable the run happened on the CPU and there is no GPU cache to clean.
if device.type == "cuda":
    print("\nCleaning GPU cache...")
    model = model.to(torch.device("cpu"))
    torch.cuda.empty_cache()

# Finished
print("\nDone")
