In [None]:
# Download initial data

import os
import urllib.request
import gzip
import shutil

url = 'https://storage.googleapis.com/gresearch/wit/wit_v1.train.all-1percent_sample.tsv.gz'
filename = 'data/wit/data.tsv'

if os.path.exists(filename):
    print("The file exists")

else:
    # Make sure the target directory exists on a fresh checkout.
    os.makedirs(os.path.dirname(filename), exist_ok=True)

    # Download the data from the URL. Stream straight to disk instead of
    # buffering the whole archive in memory, and only move it to its final
    # name once the transfer has finished.
    with urllib.request.urlopen(url) as response:
        with open(filename + '.gz.part', 'wb') as f:
            shutil.copyfileobj(response, f)
    os.replace(filename + '.gz.part', filename + '.gz')

    # Extract the data from the compressed file. The existence check above only
    # looks at `filename`, so it must appear atomically: a half-written TSV from
    # an interrupted run would otherwise be treated as complete next time.
    with gzip.open(filename + '.gz', 'rb') as f_in:
        with open(filename + '.part', 'wb') as f_out:
            shutil.copyfileobj(f_in, f_out)
    os.replace(filename + '.part', filename)

    print("The file was downloaded")

In [None]:
# Create Postgres table with initial data

import os
import psycopg2

db_connection_string = os.environ.get('DATABASE_URL')
conn = psycopg2.connect(db_connection_string)
try:
    # `with conn` scopes a transaction (commit on success, rollback on error)
    # but does NOT close the connection -- hence the surrounding try/finally.
    with conn, conn.cursor() as cursor:
        with open('data/wit/create_table.sql', 'r') as sql_file:
            sql_script = sql_file.read()
        cursor.execute(sql_script)
        print("Create table")

        # Only load data when the table is still empty, so re-running this
        # cell does not load the rows a second time.
        count_query = "SELECT COUNT(*) FROM tsv_data"
        cursor.execute(count_query)
        row_count = cursor.fetchone()[0]

        if row_count == 0:
            with open('data/wit/copy_data.sql', 'r') as sql_file:
                sql_script = sql_file.read()
            cursor.execute(sql_script)
            print("Copied data")
        else:
            print("No need to copy data")

        # NOTE(review): the embedding cells below write to `image_url_ai1` /
        # `image_url_ai2`; confirm `image_url_ai` exists in create_table.sql,
        # otherwise this query raises. `image_urls` is not used by any cell
        # visible here.
        image_urls_query = "SELECT id, image_url FROM tsv_data WHERE image_url_ai IS NULL LIMIT 10"
        cursor.execute(image_urls_query)
        image_urls = cursor.fetchall()

        conn.commit()
        print("Completed")
finally:
    conn.close()

In [None]:
# Process image data

import asyncio
import aiohttp
import asyncpg
import torch
import clip
import PIL
import io
import os
import json
from tqdm.asyncio import tqdm as async_tqdm
import nest_asyncio

BATCH_SIZE = 8

# Run on the GPU when torch can see one.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Load CLIP ViT-B/32 and the `preprocess` transform returned alongside it,
# then switch the model to evaluation mode on `device`.
model, preprocess = clip.load('ViT-B/32', device)
model.eval().to(device)

async def process_batch(session, items):
    """Download a batch of images and embed them with the CLIP image encoder.

    Args:
        session: aiohttp client session used for the HTTP requests.
        items: iterable of ``(id, image_url)`` pairs (e.g. asyncpg records).

    Returns:
        A list of ``(embedding_json, embedding_json, id)`` tuples for the images
        that were downloaded and decoded successfully (the same JSON-encoded
        embedding fills both update parameters), or ``None`` if none were.
    """
    # A bare `import PIL` does not load the `PIL.Image` submodule; if nothing
    # else had imported it, every `PIL.Image.open` would raise AttributeError,
    # which the best-effort handler below would silently swallow.
    import logging
    from PIL import Image

    req_headers = {'User-Agent': 'SelectImages/0.0 (narekg.me; ngalstjan4@gmail.com)'}
    images = []
    ids = []
    for item_id, image_url in items:
        try:
            async with session.get(image_url, headers=req_headers) as response:
                # Treat HTTP errors as failures up front rather than handing an
                # error page to PIL.
                response.raise_for_status()
                image_bytes = await response.read()
            image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        except Exception as exc:
            # Best effort: skip images that cannot be fetched or decoded, but
            # record why so systematic failures are visible.
            logging.getLogger(__name__).warning(
                "Skipping image %s (%s): %r", item_id, image_url, exc
            )
            continue
        images.append(image)
        ids.append(item_id)

    if not images:
        return None

    preprocessed_images = torch.stack([preprocess(image) for image in images]).to(device)
    with torch.no_grad():
        image_embeddings = model.encode_image(preprocessed_images).tolist()
    image_embeddings = [json.dumps(embedding) for embedding in image_embeddings]
    return list(zip(image_embeddings, image_embeddings, ids))

async def main():
    """Embed a random sample of English WIT images with CLIP and store them.

    Picks up to 1000 random ``tsv_data`` rows with ``language = 'en'`` and a
    non-null ``image_url``, embeds them in batches of ``BATCH_SIZE`` via
    ``process_batch``, and writes each embedding (JSON text) into both
    ``image_url_ai1`` and ``image_url_ai2``.
    """
    select_query = '''
            SELECT
                id,
                image_url
            FROM
                tsv_data
            WHERE
                language = 'en'
                AND image_url IS NOT NULL
            ORDER BY
                RANDOM()
            LIMIT 1000
        '''
    update_query = "UPDATE tsv_data SET image_url_ai1 = $1, image_url_ai2 = $2 WHERE id = $3"

    db_connection_string = os.environ.get('DATABASE_URL')
    conn = await asyncpg.connect(db_connection_string)
    try:
        rows = await conn.fetch(select_query)

        # The slow download/embedding phase runs outside any transaction so the
        # connection is not left idle-in-transaction for its whole duration.
        results = []
        async with aiohttp.ClientSession() as session:
            batch = []
            for item in async_tqdm(rows, total=len(rows)):
                batch.append(item)
                if len(batch) == BATCH_SIZE:
                    batch_result = await process_batch(session, batch)
                    if batch_result is not None:
                        results.extend(batch_result)
                    batch = []

            # Process the last batch if it's not empty
            if batch:
                batch_result = await process_batch(session, batch)
                if batch_result is not None:
                    results.extend(batch_result)

        # Write every embedding in a single transaction: all rows or none.
        if results:
            async with conn.transaction():
                await conn.executemany(update_query, results)
    finally:
        # Release the connection even if downloading, embedding or the update fails.
        await conn.close()

# A notebook kernel typically already has an event loop running, and
# asyncio.run() refuses to start inside a running loop; nest_asyncio patches
# asyncio so the call below can nest.
nest_asyncio.apply()
asyncio.run(main())

In [None]:
# Process text data
import asyncio
import os
import torch
import asyncpg
from tqdm import tqdm
from transformers import DistilBertModel, DistilBertTokenizer
import json
import nest_asyncio
import logging

nest_asyncio.apply()
# Only let errors from the Hugging Face `transformers` loggers through.
logging.getLogger("transformers").setLevel(logging.ERROR)

# Run on the GPU when torch can see one.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Pretrained uncased DistilBERT: tokenizer plus encoder, the latter switched to
# evaluation mode and moved onto `device`.
model_name = 'distilbert-base-uncased'
tokenizer = DistilBertTokenizer.from_pretrained(model_name)
model = DistilBertModel.from_pretrained(model_name)
model.eval().to(device)

def process_batch(batch):
    """Embed a batch of texts with DistilBERT via attention-masked mean pooling.

    Args:
        batch: sequence of ``(id, text)`` pairs (e.g. asyncpg records).

    Returns:
        A list of ``(embedding_json, embedding_json, id)`` tuples in input order;
        the same JSON-encoded embedding is used for both update parameters.
        An empty batch yields an empty list.
    """
    if not batch:
        # zip(*[]) yields nothing to unpack into (ids, texts).
        return []

    # Unpack IDs and texts from the batch
    ids, texts = zip(*batch)

    inputs = tokenizer.batch_encode_plus(
        texts,
        add_special_tokens=True,
        padding='longest',
        truncation=True,
        max_length=128,
        return_tensors='pt'
    )
    inputs = {key: value.to(device) for key, value in inputs.items()}

    with torch.no_grad():
        outputs = model(**inputs)
        hidden_states = outputs.last_hidden_state
        # Average only over real tokens. With padding='longest' shorter texts
        # are padded up to the longest one in the batch, so a plain mean over
        # dim=1 would blend pad-token states into their embeddings and make a
        # text's embedding depend on its batch-mates.
        # NOTE: rows embedded with the plain (unmasked) mean are not directly
        # comparable with these; re-embed them if that matters.
        mask = inputs['attention_mask'].unsqueeze(-1).to(hidden_states.dtype)
        token_counts = mask.sum(dim=1).clamp(min=1)
        text_embeddings = ((hidden_states * mask).sum(dim=1) / token_counts).tolist()

    text_embeddings = [json.dumps(embedding) for embedding in text_embeddings]

    return list(zip(text_embeddings, text_embeddings, ids))

async def main():
    """Embed a random sample of English page descriptions and store them.

    Picks up to 10000 random ``tsv_data`` rows with ``language = 'en'`` and a
    non-null ``context_page_description``, embeds them in batches of 32 via
    ``process_batch``, and writes each embedding (JSON text) into both
    ``context_page_description_ai1`` and ``context_page_description_ai2``.
    Results are flushed every ``checkpoint_interval`` batches, each flush in
    its own transaction, so work finished before an interruption is kept.
    """
    select_query = '''
            SELECT
                id,
                context_page_description
            FROM
                tsv_data
            WHERE
                language = 'en'
                AND context_page_description IS NOT NULL
            ORDER BY
                RANDOM()
            LIMIT 10000
        '''
    # Bound before the loop: the final flush needs it even when fewer than
    # `checkpoint_interval` batches ran.
    update_query = "UPDATE tsv_data SET context_page_description_ai1 = $1, context_page_description_ai2 = $2 WHERE id = $3"
    batch_size = 32
    checkpoint_interval = 5  # Save data after processing every 5 batches

    db_connection_string = os.environ.get('DATABASE_URL')
    conn = await asyncpg.connect(db_connection_string)

    async def save(pending):
        # Each checkpoint commits on its own; if all checkpoints shared one
        # transaction, an interruption would roll every one of them back.
        async with conn.transaction():
            await conn.executemany(update_query, pending)

    try:
        rows = await conn.fetch(select_query)

        results = []
        batch_starts = range(0, len(rows), batch_size)
        for batch_number, start in enumerate(tqdm(batch_starts), start=1):
            results.extend(process_batch(rows[start:start + batch_size]))

            if batch_number % checkpoint_interval == 0:
                await save(results)
                results = []

        # Save any remaining data after the last checkpoint
        if results:
            await save(results)
    finally:
        # Release the connection even if embedding or an update fails.
        await conn.close()

# Run the text-embedding job. This relies on nest_asyncio having been applied,
# since a notebook kernel typically already has an event loop running.
asyncio.run(main())

 89%|████████▉ | 278/313 [43:53<03:42,  6.36s/it] 