In [None]:
import header

from multiprocessing.pool import ThreadPool
from pathlib import Path

import pandas as pd
import numpy as np
import os
import requests
import tldextract
from v0.ai import embedding_model
from v0.models import Image


# Input TSV read by the next cell: one record per line, caption then URL,
# no header row. (Presumably the Google Conceptual Captions validation split.)
FILE = Path('data') / 'gcc' / 'validation.tsv'

# Fail fast with a clear message rather than a pandas read error later on.
if not FILE.is_file():
    raise FileNotFoundError(f'{FILE} not found')


In [None]:
# Load the raw TSV (caption, url; no header) into its own frame so re-running
# this cell never depends on a previously filtered `images`.
images_raw = pd.read_csv(FILE, sep='\t', header=None, names=['label', 'url'])
print(images_raw.head())

# Filter to URLs we haven't saved yet. The ORM query below is synchronous;
# Django refuses sync ORM calls from an async context (such as a Jupyter
# kernel's event loop) unless this flag is set.
os.environ["DJANGO_ALLOW_ASYNC_UNSAFE"] = "true"
done_urls = set(Image.objects.values_list('url', flat=True))
print(len(done_urls))

# Drop already-saved URLs, and also URLs repeated within the TSV itself:
# duplicates would otherwise be processed concurrently and saved twice.
# reset_index gives a fresh, independent frame so later column assignments
# are not made on a slice.
images = (
    images_raw[~images_raw['url'].isin(done_urls)]
    .drop_duplicates(subset='url')
    .reset_index(drop=True)
)
print(len(images))

In [None]:
# Encode every remaining caption in a single batch call.
label_texts = images['label'].to_list()
raw_vectors = embedding_model.model.encode(label_texts)

# Cast to float32, then store each row's vector as a plain Python list
# in the 'embed' column.
embeds = np.array(raw_vectors).astype(np.float32)
images['embed'] = embeds.tolist()

In [None]:

def process_image(i_image, timeout=1):
    """Build and save one ``Image`` row from a dataset record.

    Shaped for ``pool.map`` over ``DataFrame.iterrows()``: ``i_image`` is an
    ``(index, row)`` pair. The index is ignored, and ``row`` must provide
    ``'url'``, ``'label'`` and ``'embed'``.

    Args:
        i_image: ``(index, row)`` tuple as yielded by ``iterrows()``.
        timeout: seconds to wait for the HEAD request that probes the URL.

    Returns:
        None. Errors are printed, prefixed with the offending URL, instead of
        being raised, so one bad row does not abort the whole pool run. A row
        that errors (including a failed or timed-out HEAD request) is not saved.
    """
    _, image = i_image
    try:
        url = image['url']
        image_model = Image()
        image_model.url = url
        image_model.description = image['label']
        image_model.embedding_all_mpnet_base_v2 = image['embed']
        image_model.domain = tldextract.extract(url).domain
        image_model.provider = Image.providers.GCC_DATASET

        # Probe the URL. requests.head() does NOT follow redirects by default,
        # so without allow_redirects a 301/302 pointing at a live file would be
        # recorded as dead. Set the flag explicitly in both directions instead
        # of relying on the model's default value.
        response = requests.head(url, timeout=timeout, allow_redirects=True)
        image_model.url_alive = response.status_code == 200

        image_model.save()
    except Exception as e:
        # Best-effort: keep the pool going, but say which URL failed.
        print(f"{image.get('url')}: {e!r}")

# Save every remaining row concurrently. The per-row work in process_image is
# network (HEAD request) and database I/O, so threads are sufficient.
N_WORKERS = 16

# The context manager shuts the pool down once map() has returned, so worker
# threads don't linger in the kernel. Keeping map() inside the `with` also stops
# Jupyter from echoing its result (one None per row) as the cell output.
with ThreadPool(processes=N_WORKERS) as pool:
    pool.map(process_image, images.iterrows())