In [2]:
# Download initial data

import os
import urllib.request
import gzip
import shutil

# Source archive (1% sample of the WIT training set) and the path of the
# extracted TSV that the rest of the notebook expects.
url = 'https://storage.googleapis.com/gresearch/wit/wit_v1.train.all-1percent_sample.tsv.gz'
filename = 'data/wit/data.tsv'

if os.path.exists(filename):
    print("The file exists")
else:
    # The open() calls below fail if the target directory is missing.
    os.makedirs(os.path.dirname(filename), exist_ok=True)

    # Download the data from the URL, streaming it to disk in chunks instead
    # of holding the whole archive in memory.
    with urllib.request.urlopen(url) as response:
        with open(filename + '.gz', 'wb') as f:
            shutil.copyfileobj(response, f)

    # Extract the data from the compressed file. Write to a temporary name and
    # rename at the end, so an interrupted extraction never leaves a partial
    # `filename` that the existence check above would accept as complete.
    with gzip.open(filename + '.gz', 'rb') as f_in:
        with open(filename + '.part', 'wb') as f_out:
            shutil.copyfileobj(f_in, f_out)
    os.replace(filename + '.part', filename)

    print("The file was downloaded")

The file exists


In [14]:
# Create Postgres table with initial data

import os
import psycopg2

# Create the `tsv_data` table from the bundled SQL script and load the TSV
# into it when the table is still empty.
db_connection_string = os.environ.get('DATABASE_URL')
conn = psycopg2.connect(db_connection_string)
try:
    # psycopg2's `with conn` only commits (or rolls back on error); it does
    # NOT close the connection, hence the explicit close() in `finally`.
    with conn:
        with conn.cursor() as cursor:
            with open('data/wit/create_table.sql', 'r') as sql_file:
                sql_script = sql_file.read()
            cursor.execute(sql_script)
            print("Create table")

            count_query = "SELECT COUNT(*) FROM tsv_data"
            cursor.execute(count_query)
            row_count = cursor.fetchone()[0]

            # Only bulk-load when the table is empty so re-running the cell
            # does not load the data a second time.
            if row_count == 0:
                with open('data/wit/copy_data.sql', 'r') as sql_file:
                    sql_script = sql_file.read()
                cursor.execute(sql_script)
                print("Copied data")
            else:
                print("No need to copy data")

            # NOTE(review): this reads `image_url_ai`, while the later cells
            # write `image_url_ai1` / `image_url_ai2`; create_table.sql is not
            # visible here, so confirm the column name. `image_urls` is not
            # used by any other visible cell.
            image_urls_query = "SELECT id, image_url FROM tsv_data WHERE image_url_ai IS NULL LIMIT 10"
            cursor.execute(image_urls_query)
            image_urls = cursor.fetchall()

            conn.commit()
            print("Completed")
finally:
    conn.close()

Create table
Copied data
Completed


In [44]:
# Process image data

import asyncio
import aiohttp
import asyncpg
import torch
import clip
import PIL
import io
import os
import json
from tqdm.asyncio import tqdm as async_tqdm
import nest_asyncio

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# `model` produces the image embeddings; `preprocess` is applied to each PIL
# image before it is stacked into a batch (see process_batch below).
model, preprocess = clip.load('ViT-B/32', device)
# Inference only: switch to eval mode and make sure the model is on `device`.
model.eval()
model.to(device)

# Number of rows handed to process_batch() per call in main().
BATCH_SIZE = 8

async def _fetch_image(session, image_url):
    """Download one image and decode it into an RGB PIL image.

    Raises whatever the HTTP client or PIL raises (connection errors, non-2xx
    status, undecodable bytes); the caller decides how to handle failures.
    """
    # `import PIL` alone does not load the `PIL.Image` submodule; import it
    # explicitly instead of relying on another package having loaded it.
    from PIL import Image

    req_headers = {'User-Agent': 'SelectImages/0.0 (narekg.me; ngalstjan4@gmail.com)'}
    async with session.get(image_url, headers=req_headers) as response:
        # Fail fast on HTTP errors instead of feeding an error page to PIL.
        response.raise_for_status()
        image_bytes = await response.read()
    return Image.open(io.BytesIO(image_bytes)).convert("RGB")


async def process_batch(session, items):
    """Download a batch of images and embed them with the CLIP model.

    Args:
        session: HTTP client session exposing an aiohttp-style ``get``.
        items: iterable of ``(id, image_url)`` pairs.

    Returns:
        A list of ``(embedding_json, embedding_json, id)`` tuples, one per
        image that was downloaded and decoded successfully (the embedding is
        repeated because the caller writes it to two columns), or ``None``
        when no image in the batch could be loaded.
    """
    import logging

    images = []
    ids = []
    for row_id, image_url in items:
        try:
            image = await _fetch_image(session, image_url)
        except Exception as exc:
            # Best effort: a broken URL must not abort the whole run, but the
            # failure is reported instead of being swallowed silently.
            logging.getLogger(__name__).warning(
                "Skipping image %s (id=%s): %s", image_url, row_id, exc
            )
            continue
        images.append(image)
        ids.append(row_id)

    if not images:
        return None

    preprocessed_images = torch.stack([preprocess(image) for image in images]).to(device)
    with torch.no_grad():
        image_embeddings = model.encode_image(preprocessed_images).tolist()
    # Embeddings are serialised as JSON arrays for the SQL UPDATE parameters.
    image_embeddings = [json.dumps(embedding) for embedding in image_embeddings]
    return list(zip(image_embeddings, image_embeddings, ids))

async def main():
    """Embed a random sample of English-row images and store the embeddings.

    Picks up to 1000 random rows with an image URL, runs them through
    process_batch() in groups of BATCH_SIZE, and writes each embedding to both
    ``image_url_ai1`` and ``image_url_ai2``.
    """
    db_connection_string = os.environ.get('DATABASE_URL')
    conn = await asyncpg.connect(db_connection_string)
    try:
        # Read outside of any explicit transaction: the downloads below take
        # minutes and should not keep a transaction open on the server.
        rows = await conn.fetch('''
            SELECT
                id,
                image_url
            FROM
                tsv_data
            WHERE
                language = 'en'
                AND image_url IS NOT NULL
            ORDER BY
                RANDOM()
            LIMIT 1000
        ''')

        results = []
        async with aiohttp.ClientSession() as session:
            batch = []
            for item in async_tqdm(rows, total=len(rows)):
                batch.append(item)
                if len(batch) == BATCH_SIZE:
                    # process_batch() returns None when every image failed.
                    results.extend(await process_batch(session, batch) or [])
                    batch = []

            # Process the last batch if it's not empty
            if batch:
                results.extend(await process_batch(session, batch) or [])

        # Write all embeddings in one short transaction.
        if results:
            update_query = "UPDATE tsv_data SET image_url_ai1 = $1, image_url_ai2 = $2 WHERE id = $3"
            async with conn.transaction():
                await conn.executemany(update_query, results)
    finally:
        # Always release the connection, even when a step above raised.
        await conn.close()

# Patch asyncio before calling asyncio.run() -- presumably because the
# notebook kernel already has an event loop running (confirm for your setup).
nest_asyncio.apply()
asyncio.run(main())

100%|██████████| 1000/1000 [08:17<00:00,  2.01it/s]


In [None]:
# Process text data
import asyncio
import os
import torch
import asyncpg
from tqdm import tqdm
from transformers import DistilBertModel, DistilBertTokenizer
import json
import nest_asyncio
nest_asyncio.apply()
import logging
# Only show errors from the transformers logger.
logging.getLogger("transformers").setLevel(logging.ERROR)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Model and tokenizer are loaded from the same checkpoint name.
model_name = 'distilbert-base-uncased'
model = DistilBertModel.from_pretrained(model_name)
tokenizer = DistilBertTokenizer.from_pretrained(model_name)
# Inference only: switch to eval mode and move the model to `device`.
model.eval()
model.to(device)

def process_batch(batch):
    """Embed a batch of texts with the DistilBERT model.

    Args:
        batch: sequence of ``(id, text)`` pairs.

    Returns:
        A list of ``(embedding_json, embedding_json, id)`` tuples in input
        order (the embedding is repeated because the caller writes it to two
        columns). An empty batch yields an empty list.
    """
    if not batch:
        return []

    # Unpack IDs and texts from the batch
    ids, texts = zip(*batch)

    # Calling the tokenizer directly replaces the deprecated batch_encode_plus.
    inputs = tokenizer(
        list(texts),
        add_special_tokens=True,
        padding='longest',
        truncation=True,
        max_length=128,
        return_tensors='pt'
    )
    inputs = {key: value.to(device) for key, value in inputs.items()}

    with torch.no_grad():
        outputs = model(**inputs)
        hidden = outputs.last_hidden_state
        attention_mask = inputs.get('attention_mask')
        if attention_mask is None:
            pooled = hidden.mean(dim=1)
        else:
            # Average over real tokens only. With padding='longest' a plain
            # mean would also average the padding positions, making a text's
            # embedding depend on the other texts in its batch.
            mask = attention_mask.unsqueeze(-1).to(hidden.dtype)
            pooled = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
        text_embeddings = pooled.tolist()

    # Embeddings are serialised as JSON arrays for the SQL UPDATE parameters.
    text_embeddings = [json.dumps(embedding) for embedding in text_embeddings]

    return list(zip(text_embeddings, text_embeddings, ids))

async def main():
    """Embed a random sample of English page descriptions and store them.

    Picks up to 1000 random rows with a description, embeds them with
    process_batch() in groups of 32, and writes each embedding to both
    ``context_page_description_ai1`` and ``context_page_description_ai2``.
    """
    db_connection_string = os.environ.get('DATABASE_URL')
    conn = await asyncpg.connect(db_connection_string)
    try:
        # Read outside of any explicit transaction: the embedding loop below
        # is slow and should not keep a transaction open on the server.
        rows = await conn.fetch('''
            SELECT
                id,
                context_page_description
            FROM
                tsv_data
            WHERE
                language = 'en'
                AND context_page_description IS NOT NULL
            ORDER BY
                RANDOM()
            LIMIT 1000
        ''')

        batch_size = 32
        results = []
        for i in tqdm(range(0, len(rows), batch_size)):
            batch = rows[i:i+batch_size]
            results.extend(process_batch(batch))

        # Write all embeddings in one short transaction.
        if results:
            update_query = "UPDATE tsv_data SET context_page_description_ai1 = $1, context_page_description_ai2 = $2 WHERE id = $3"
            async with conn.transaction():
                await conn.executemany(update_query, results)
    finally:
        # Always release the connection, even when a step above raised.
        await conn.close()

# Run the text-embedding job; nest_asyncio.apply() was already called in this
# cell's import section.
asyncio.run(main())

 12%|█▎        | 4/32 [00:26<03:02,  6.51s/it]

In [11]:
# Generate experiment data

import os
import time
import numpy as np
from pgvector.psycopg2 import register_vector
import psycopg2

# Connection string shared by every helper in this cell.
db_connection_string = os.environ.get('DATABASE_URL')

# Step between successive test-table sizes (rows) in column_experiments().
n_interval = 1000
num_iterations = 100  # Timed executions per query in measure_throughput()
warm_up_iterations = 10  # Untimed executions run before the timed ones

def measure_throughput(label, query, generate_args):
    """Measure how many times per second `query` can be executed.

    Args:
        label: name of the experiment; not used by the measurement itself.
        query: SQL text; must contain one ``%s`` placeholder when
            `generate_args` is given.
        generate_args: zero-argument callable producing the single query
            parameter for each execution, or ``None`` for a parameterless query.

    Returns:
        Executions per second over ``num_iterations`` timed runs, after
        ``warm_up_iterations`` untimed runs.
    """
    conn = psycopg2.connect(db_connection_string)
    try:
        # `with conn` only commits/rolls back; the connection itself is closed
        # in the `finally` below.
        with conn:
            register_vector(conn)
            with conn.cursor() as cursor:
                # Helper for executing query
                def execute_query():
                    if generate_args is not None:
                        vector = generate_args()
                        cursor.execute(query, (vector,))
                    else:
                        cursor.execute(query)

                # Warm-up
                for _ in range(warm_up_iterations):
                    execute_query()

                # Measured. perf_counter() is monotonic and high-resolution,
                # unlike time.time(), which can jump when the clock is adjusted.
                start_time = time.perf_counter()
                for _ in range(num_iterations):
                    execute_query()
                elapsed_time = time.perf_counter() - start_time
    finally:
        conn.close()

    # Guard against a zero-length interval instead of raising ZeroDivisionError.
    if elapsed_time <= 0:
        return float('inf')
    return num_iterations / elapsed_time

def create_test_table(n, column):
    """(Re)create `test_table` from `n` random rows of `tsv_data`.

    Only rows whose ``{column}1`` value is not NULL are sampled. The table
    gets ivfflat indexes on both ``*_ai2`` vector columns and a regular index
    on ``original_width``.

    Args:
        n: number of rows to copy.
        column: column-name prefix, e.g. ``'image_url_ai'``. It is
            interpolated into the SQL text, so pass trusted values only.
    """
    conn = psycopg2.connect(db_connection_string)
    try:
        # `with conn` commits on success but does not close the connection.
        with conn:
            with conn.cursor() as cursor:
                cursor.execute("DROP TABLE IF EXISTS test_table")
                cursor.execute(f'''
                    CREATE TABLE test_table AS
                    SELECT * FROM tsv_data
                    WHERE {column}1 IS NOT NULL
                    ORDER BY RANDOM()
                    LIMIT {n}
                ''')
                cursor.execute('CREATE INDEX ON test_table USING ivfflat (image_url_ai2 vector_l2_ops) WITH (lists = 100)')
                cursor.execute('CREATE INDEX ON test_table USING ivfflat (context_page_description_ai2 vector_l2_ops) WITH (lists = 100)')
                cursor.execute('CREATE INDEX ON test_table (original_width)')
    finally:
        conn.close()

def drop_test_table():
    """Drop `test_table` if it exists."""
    conn = psycopg2.connect(db_connection_string)
    try:
        # `with conn` commits on success but does not close the connection.
        with conn:
            with conn.cursor() as cursor:
                cursor.execute("DROP TABLE IF EXISTS test_table")
    finally:
        conn.close()

def get_table_count(column):
    """Return the number of `tsv_data` rows whose ``{column}1`` is not NULL.

    Args:
        column: column-name prefix, e.g. ``'image_url_ai'``. It is
            interpolated into the SQL text, so pass trusted values only.
    """
    conn = psycopg2.connect(db_connection_string)
    try:
        # `with conn` commits on success but does not close the connection.
        with conn:
            with conn.cursor() as cursor:
                cursor.execute(f"SELECT COUNT(*) FROM tsv_data WHERE {column}1 IS NOT NULL")
                return cursor.fetchone()[0]
    finally:
        conn.close()

def column_experiments(column, dimensions):
    """Benchmark a fixed set of queries against test tables of growing size.

    For every table size from ``n_interval`` up to the number of available
    rows (in steps of ``n_interval``), a fresh `test_table` is created, each
    query is timed with measure_throughput(), and the table is dropped again.

    Args:
        column: column-name prefix; ``{column}1`` is queried as the
            non-indexed vector column and ``{column}2`` as the indexed one.
        dimensions: length of the random query vectors.

    Returns:
        Dict mapping each experiment label to a list of
        ``(table_size, throughput)`` tuples in increasing table-size order.
    """
    def random_vector():
        return np.random.rand(dimensions)

    # The query set does not depend on the table size, so build it once.
    experiment_inputs = [
        {
            'label': 'select 1',
            'query': 'SELECT 1'
        },
        {
            'label': 'select id',
            'query': "SELECT 1 FROM test_table WHERE id < 100"
        },
        {
            'label': 'select int',
            'query': "SELECT 1 FROM test_table WHERE original_height < 100"
        },
        {
            'label': 'select int (indexed)',
            'query': "SELECT 1 FROM test_table WHERE original_width < 100",
        },
        {
            'label': 'select vector',
            'query': f"SELECT 1 FROM test_table ORDER BY {column}1 <-> %s LIMIT 10",
            'generate_args': random_vector
        },
        {
            # Uses SELECT 1 like every other experiment; the previous
            # `SELECT *` also transferred whole rows, which made it
            # incomparable with the non-indexed vector query.
            'label': 'select vector (indexed)',
            'query': f"SELECT 1 FROM test_table ORDER BY {column}2 <-> %s LIMIT 10",
            'generate_args': random_vector
        },
    ]

    results = {}

    max_count = get_table_count(column)
    start = n_interval
    # +1 so that the largest full multiple of n_interval is included.
    end = max_count - (max_count % n_interval) + 1
    for count in range(start, end, n_interval):
        create_test_table(count, column)
        try:
            for experiment in experiment_inputs:
                label = experiment['label']
                throughput = measure_throughput(
                    label,
                    experiment['query'],
                    generate_args=experiment.get('generate_args'),
                )
                results.setdefault(label, []).append((count, throughput))
        finally:
            # Do not leave the test table behind when a measurement fails.
            drop_test_table()

    return results
    
# Run the benchmark once per embedding column; the second argument is the
# length of the random query vectors. Only results1 is plotted and printed in
# the cells visible below.
results1 = column_experiments('image_url_ai', 512)
results2 = column_experiments('context_page_description_ai', 768)

In [13]:
# Plot experiments

import plotly.graph_objects as go

def generate_plots(results):
    """Show one throughput-vs-table-size line chart for all query types.

    Args:
        results: dict mapping a query label to a list of
            ``(table_size, throughput)`` tuples.
    """
    def build_trace(query_type, query_results):
        # Split the (size, throughput) pairs into the x and y series.
        sizes, throughputs = zip(*query_results)
        return go.Scatter(
            x=sizes,
            y=throughputs,
            mode='lines+markers',
            name=query_type
        )

    # One line per query type
    traces = [
        build_trace(query_type, query_results)
        for query_type, query_results in results.items()
    ]

    # Titles and axis labels
    layout = go.Layout(
        title='Throughput vs. Table Size for Each Query Type',
        xaxis={'title': 'Table Size'},
        yaxis={'title': 'Throughput (ops/s)'},
        showlegend=True
    )

    # Assemble and display the figure
    go.Figure(data=traces, layout=layout).show()

# Plot the results1 measurements produced by the experiment cell above.
generate_plots(results1)

In [14]:
# Print experiment tables
def print_double_line():
    """Print a 40-character rule made of '=' characters."""
    print("=" * 40)


def print_line():
    """Print a 40-character rule made of '-' characters."""
    print("-" * 40)


def print_throughput_results(results):
    """Print one plain-text table per query label.

    Args:
        results: dict mapping a query label to a list of
            ``(table_size, throughput)`` tuples.
    """
    # The column header is identical for every table.
    header = f"{'Table Size':<15}{'Throughput (ops/s)':<20}"

    for label in results:
        rows = results[label]

        # Title section
        print("")
        print_double_line()
        print(f"{label}")
        print_line()

        # Column header
        print(header)
        print_line()

        # One line per table size, throughput with three decimals
        for table_size, throughput in rows:
            print(f"{table_size:<15}{throughput:<20.3f}")

        print_double_line()
        print("")

# Print the results1 measurements as plain-text tables, one per query label.
print_throughput_results(results1)


select 1
----------------------------------------
Table Size     Throughput (ops/s)  
----------------------------------------
1000           21693.928           
2000           16148.087           
3000           14719.955           
4000           16810.838           
5000           17669.899           
6000           18458.408           
7000           13082.262           
8000           13546.618           
9000           11520.281           
10000          12040.488           
11000          12400.012           
12000          11958.442           
13000          16557.335           


select id
----------------------------------------
Table Size     Throughput (ops/s)  
----------------------------------------
1000           7654.258            
2000           4433.303            
3000           3227.108            
4000           2559.188            
5000           2103.578            
6000           1994.922            
7000           1707.994            
8000           1346.19