In [None]:
import glob
import pathlib

import torch
import clip
import tqdm
from PIL import Image
import cv2
import numpy as np
import os
import pandas as pd
import duckdb

from PIL import ImageFile
# Let Pillow load truncated / partially-written image files instead of raising
# an error on them (global Pillow setting, affects every Image.open below).
ImageFile.LOAD_TRUNCATED_IMAGES = True


def batch_images_generator(path_to_folder, bs, preprocessor, resume_from_batch=None, list_of_images=None,
                           target_device=None):
    """Yield preprocessed image batches read from ``path_to_folder``.

    Parameters
    ----------
    path_to_folder : str
        Directory containing the images; joined with each filename via ``os.path.join``.
    bs : int
        Number of filenames per batch (before dropping images that fail to load).
    preprocessor : callable
        Called as ``preprocessor(images=<PIL image>, return_tensors="pt")`` and indexed with
        ``['pixel_values']`` (e.g. a Hugging Face image processor).
    resume_from_batch : int or None
        Index of the first batch to produce; ``None`` or 0 starts from the beginning.
    list_of_images : list of str or None
        Explicit filenames relative to ``path_to_folder``. ``None`` means every non-empty
        ``*.jpg`` in the folder, sorted. An empty list yields nothing (it does NOT fall back
        to scanning the folder).
    target_device : torch.device, str or None
        Device the batch tensor is moved to; ``None`` uses the module-level ``device``.

    Yields
    ------
    tuple (int, torch.Tensor, list of str)
        Offset of the batch's first filename in the image list, the concatenated
        ``pixel_values`` and the filenames that were preprocessed successfully (same order
        as the tensor rows). Batches in which every image failed are skipped.
    """
    # `is not None` so an explicitly empty list (e.g. "nothing missing") is honoured.
    if list_of_images is not None:
        images = list_of_images
    else:
        images = sorted(glob.glob('*.jpg', root_dir=path_to_folder))
        images = [image for image in images
                  if os.path.getsize(os.path.join(path_to_folder, image)) > 0]
    print(len(images))
    start = (resume_from_batch or 0) * bs
    print(f'start from {start}')
    for i in tqdm.tqdm(range(start, len(images), bs)):
        pixel_batches = []
        valid_images = []
        for image in images[i:i + bs]:
            try:
                # Context manager closes the file handle; preprocessing happens while it is open.
                with Image.open(os.path.join(path_to_folder, image)) as img:
                    pixel_batches.append(
                        preprocessor(images=img, return_tensors="pt")['pixel_values']
                    )
                valid_images.append(image)
            except Exception as e:
                # Best effort: report and skip unreadable images, keep the rest of the batch.
                print(e)
                print(image)

        if not pixel_batches:
            # Every image in this slice failed; torch.cat([]) would raise, so skip the batch.
            continue

        torch_batch = torch.cat(pixel_batches).to(device if target_device is None else target_device)

        yield i, torch_batch, valid_images


def clip_embeddings(model, batch):
    """Encode a batch of images with a model exposing ``encode_image``.

    Autograd is disabled for the call; whatever ``model.encode_image(batch)``
    returns is handed back unchanged.
    """
    encode = model.encode_image
    with torch.no_grad():
        return encode(batch)


def hf_embeddings(model, batch):
    """Return ``model.get_image_features(batch, normalize=True)`` computed without autograd.

    Intended for Hugging Face-style models loaded with ``trust_remote_code`` whose
    ``get_image_features`` accepts a ``normalize`` flag.
    """
    get_features = model.get_image_features
    with torch.no_grad():
        return get_features(batch, normalize=True)


def get_missing_images(path_to_parquet, path_to_valid_images):
    """Return the valid image filenames that are not yet present in the embeddings parquet.

    Parameters
    ----------
    path_to_parquet : str
        Anything ``pd.read_parquet`` accepts; must have a ``filename`` column listing the
        images that already have embeddings.
    path_to_valid_images : str
        UTF-8 text file with one image filename per line. Surrounding whitespace is stripped
        and blank lines are ignored.

    Returns
    -------
    list of str
        De-duplicated filenames present in the text file but not in the parquet, sorted so
        the order (and therefore the batch indices used by ``resume_from_batch``) is the same
        on every run. A plain ``list(set)`` order varies between interpreter runs because of
        string hash randomisation.
    """
    with open(path_to_valid_images, 'r', encoding='utf-8') as f:
        valid_images = {line.strip() for line in f}
    valid_images.discard('')  # blank lines would otherwise become a bogus '' filename

    embedded_images = set(pd.read_parquet(path_to_parquet, columns=['filename']).filename)

    return sorted(valid_images - embedded_images)

In [None]:
from transformers import AutoModel, AutoProcessor

# --- configuration ----------------------------------------------------------------
model_name = 'Marqo/marqo-fashionSigLIP'  # alternative: Marqo/marqo-ecommerce-embeddings-L
device = "cuda" if torch.cuda.is_available() else "cpu"
batch_size = 512
output_dir = 'parquets'
splits = ['train', 'test']

model = AutoModel.from_pretrained(model_name, trust_remote_code=True).to(device)
processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)

for t in splits:
    # One batch directory per split: otherwise test's batch_N.parquet overwrites train's,
    # and the merge below would mix both splits.
    split_dir = os.path.join(output_dir, t)
    os.makedirs(split_dir, exist_ok=True)

    images_dir = f'avito/images/{t}/images/'
    # Both inputs are per split; previously the train files were used for every split.
    # NOTE(review): confirm the per-split files exist for 'test', and whether the absolute
    # '/avito/...' parquet path (vs. relative images_dir) is intended.
    images_list = get_missing_images(f'/avito/images/embeddings/final_embeddings_fashion_clip_{t}.parquet',
                                     f'./valid_{t}_images.txt')
    if not images_list:
        print(f'{t}: nothing to embed')
        continue

    for i, batch, fname in batch_images_generator(images_dir,
                                                  bs=batch_size,
                                                  preprocessor=processor,
                                                  resume_from_batch=0,
                                                  list_of_images=images_list):

        embedded = hf_embeddings(model, batch).tolist()

        df = pd.DataFrame({
            "filename": fname,
            "embedding": embedded
        })

        df.to_parquet(
            os.path.join(split_dir, f"batch_{i // batch_size}.parquet"),
            index=False
        )

    # Merge only batch files, and write the result outside split_dir so a re-run never
    # reads the merged file back in as input.
    # NOTE(review): stale batch_*.parquet files left by an earlier, longer run of this split
    # are merged as well; clear split_dir first if that matters.
    if not glob.glob(os.path.join(split_dir, 'batch_*.parquet')):
        print(f'{t}: no batches written')
        continue
    duckdb.sql(rf"""
        COPY (
            SELECT * FROM '{split_dir}/batch_*.parquet'
        )
        TO '{output_dir}/final_embeddings_{t}.parquet' (FORMAT PARQUET)
    """)