In [2]:
# Download initial data
#
# Fetches the 1% sample of the WIT training set and decompresses it to
# data/wit/data.tsv. If the TSV already exists the cell does nothing, so it is
# safe to re-run.

import os
import urllib.request
import gzip
import shutil

url = 'https://storage.googleapis.com/gresearch/wit/wit_v1.train.all-1percent_sample.tsv.gz'
filename = 'data/wit/data.tsv'

if os.path.exists(filename):
    print("The file exists")

else:
    # The target directory may not exist on a fresh checkout.
    os.makedirs(os.path.dirname(filename), exist_ok=True)
    gz_filename = filename + '.gz'

    # Download the data from the URL, streaming it to disk in chunks instead of
    # holding the whole archive in memory. Write to a temporary name first so a
    # failed download never masquerades as a complete archive.
    with urllib.request.urlopen(url) as response, open(gz_filename + '.part', 'wb') as f:
        shutil.copyfileobj(response, f)
    os.replace(gz_filename + '.part', gz_filename)

    # Extract into a temporary file and only move it into place once it is
    # complete: an interrupted run must not leave a truncated data.tsv, because
    # the existence check above would then skip the download on every re-run.
    with gzip.open(gz_filename, 'rb') as f_in, open(filename + '.part', 'wb') as f_out:
        shutil.copyfileobj(f_in, f_out)
    os.replace(filename + '.part', filename)

    print("The file was downloaded")

The file exists


In [3]:
# Create Postgres table with initial data
#
# Runs data/wit/create_table.sql, then runs data/wit/copy_data.sql only when
# tsv_data is empty, so re-running this cell does not load the rows twice.
# The connection is always closed, even if a statement or file read fails, so
# no connection with an uncommitted transaction is left open in the kernel.

import os
import psycopg2

db_connection_string = os.environ.get('DATABASE_URL')
conn = psycopg2.connect(db_connection_string)
try:
    with conn.cursor() as cursor:
        with open('data/wit/create_table.sql', 'r') as sql_file:
            sql_script = sql_file.read()
        cursor.execute(sql_script)
        print("Create table")

        count_query = "SELECT COUNT(*) FROM tsv_data"
        cursor.execute(count_query)
        row_count = cursor.fetchone()[0]

        # Only load the TSV into an empty table to avoid duplicate rows.
        if row_count == 0:
            with open('data/wit/copy_data.sql', 'r') as sql_file:
                sql_script = sql_file.read()
            cursor.execute(sql_script)
            print("Copied data")
        else:
            print("No need to copy data")

        # Peek at up to 10 rows that do not have an image embedding yet.
        image_urls_query = "SELECT id, image_url FROM tsv_data WHERE image_url_ai IS NULL LIMIT 10"
        cursor.execute(image_urls_query)
        image_urls = cursor.fetchall()

    # One commit covers both table creation and the initial load.
    conn.commit()
finally:
    # Closing without a commit (i.e. after an error) discards the pending transaction.
    conn.close()

print("Completed")

Create table
No need to copy data
Completed


In [None]:
# Compute CLIP embeddings for English rows
#
# Selects up to 10,000 English rows of tsv_data that have both a page
# description and an image URL, downloads each image, embeds the image and the
# description with CLIP ViT-B/32, and writes the two vectors to image_url_ai and
# context_page_description_ai. Rows whose image cannot be downloaded or decoded
# are skipped (their ids are collected in failed_ids) instead of aborting the
# whole run, and each batch is committed as soon as it is written so an
# interrupted run keeps its progress.

import psycopg2
import torch
import clip
import requests
import io
from PIL import Image  # `import PIL` alone does not load the PIL.Image submodule
from tqdm.auto import tqdm
import os

# Set the batch size for database updates
batch_size = 100
# Seconds to wait for an image server before giving up on that row.
request_timeout = 30
# Presumably identifies this client per Wikimedia's User-Agent policy.
req_headers = {'User-Agent': 'SelectImages/0.0 (narekg.me; ngalstjan4@gmail.com)'}

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# clip.load places the model on `device` itself; no extra .to(device) needed.
model, preprocess = clip.load('ViT-B/32', device)
model.eval()

# Open the connection only after the (possibly slow) model load.
db_connection_string = os.environ.get('DATABASE_URL')
conn = psycopg2.connect(db_connection_string)
cursor = conn.cursor()

# Reuse one HTTP session so connections to the image host are pooled.
session = requests.Session()
session.headers.update(req_headers)

# Initialize a buffer for batched updates: (image_embedding, summary_embedding, id)
update_buffer = []
# ids of rows skipped because their image could not be fetched or decoded
failed_ids = []


def update_from_update_buffer():
    """Write all buffered rows to tsv_data, commit, and clear the buffer."""
    cursor.executemany(
        "UPDATE tsv_data SET image_url_ai = %s, context_page_description_ai = %s WHERE id = %s",
        update_buffer
    )
    # Commit per batch so earlier work survives a later failure or interrupt.
    conn.commit()
    update_buffer.clear()


def embed_row(summary, image_url):
    """Download one image and embed it together with its description.

    Returns:
        (image_embedding, summary_embedding), each a plain Python list of floats.

    Raises:
        requests.RequestException: on network errors, timeouts or non-2xx status.
        OSError: (including PIL.UnidentifiedImageError) when the payload is not
            a decodable image.
        PIL.Image.DecompressionBombError: for absurdly large images.
    """
    response = session.get(image_url, timeout=request_timeout)
    # Without this, a 403/404 error page would be handed to PIL as image bytes.
    response.raise_for_status()
    image = Image.open(io.BytesIO(response.content)).convert("RGB")

    # Preprocess the image and summary; unsqueeze(0) adds a batch dimension.
    preprocessed_image = preprocess(image).unsqueeze(0).to(device)
    preprocessed_summary = clip.tokenize(summary, truncate=True).to(device)

    # Generate image and summary embeddings
    with torch.no_grad():
        image_embedding = model.encode_image(preprocessed_image).squeeze()
        summary_embedding = model.encode_text(preprocessed_summary).squeeze()

    return image_embedding.tolist(), summary_embedding.tolist()


try:
    cursor.execute('''
        SELECT
            id,
            context_page_description,
            image_url
        FROM
            tsv_data
        WHERE
            language = 'en'
            AND context_page_description IS NOT NULL
            AND image_url IS NOT NULL
        LIMIT 10000
    ''')
    rows = cursor.fetchall()

    for row_id, summary, image_url in tqdm(rows):
        try:
            image_embedding, summary_embedding = embed_row(summary, image_url)
        except (requests.RequestException, OSError, Image.DecompressionBombError):
            # One broken URL must not abort (and roll back) the whole run.
            failed_ids.append(row_id)
            continue

        update_buffer.append((image_embedding, summary_embedding, row_id))

        # Execute batched updates when the buffer reaches the specified batch size
        if len(update_buffer) >= batch_size:
            update_from_update_buffer()

    # Execute the remaining batched updates in the buffer
    if update_buffer:
        update_from_update_buffer()
finally:
    # Always release the HTTP session, cursor and database connection.
    session.close()
    cursor.close()
    conn.close()

print(f"Embedded {len(rows) - len(failed_ids)} rows; skipped {len(failed_ids)} (see failed_ids)")

  0%|          | 5/10000 [00:02<1:32:36,  1.80it/s]