# Building unsplash-25k dataset for image search with Vector SQL

## Create tables and insert data

In [None]:
from toml import load as pload
from clickhouse_connect import get_client
from clickhouse_connect.driver.tools import insert_file

# Connection settings live in the Streamlit secrets file one directory up.
setting = pload('../.streamlit/secrets.toml')

MYSCALE_USER = setting['MYSCALE_USER']
MYSCALE_PASSWORD = setting['MYSCALE_PASSWORD']
MYSCALE_HOST = setting['MYSCALE_HOST']
MYSCALE_PORT = setting['MYSCALE_PORT']
# The OpenAI settings are not used by the data-loading cells; they are read
# here together with the MyScale credentials.
OPENAI_API_BASE = setting['OPENAI_API_BASE']
OPENAI_API_KEY = setting['OPENAI_API_KEY']

# Shared client used by the table-creation / insertion cells below.
client = get_client(host=MYSCALE_HOST, port=MYSCALE_PORT, user=MYSCALE_USER, password=MYSCALE_PASSWORD)

In [None]:
# Download the unsplash-25k photo + vector dump (parquet) into ./data.
# `mkdir -p` keeps the cell re-runnable; `-O` fixes the local file name that
# the next cell reads.
!mkdir -p data
!wget https://myscale-demo.s3.ap-southeast-1.amazonaws.com/visual-dataset-explorer/unsplash_25k_clip_indexer.pq -O data/photos.parquet

## Insert photos with their vectors and build a vector index on them

In [2]:
import pandas as pd
# Recreate the photos table from scratch (DROP first) so the cell is re-runnable.
client.command('''CREATE DATABASE IF NOT EXISTS unsplash''')
client.command('''DROP TABLE IF EXISTS unsplash.photos''')
# The CHECK constraint makes inserts fail for any vector that is not exactly
# 512 floats long.
client.command('''
CREATE TABLE IF NOT EXISTS unsplash.photos (
  `photo_id` String,
  `photo_url` String,
  `photo_vector` Array(Float32),
  CONSTRAINT constraint_vec_length CHECK length(photo_vector) = 512
) ENGINE = MergeTree
ORDER BY
  photo_id SETTINGS index_granularity = 8192
''')
# Map the dump's id/url/vector columns onto the table's column names and
# re-export only those three, so the parquet file matches the table schema.
df = pd.read_parquet('data/photos.parquet')
df['photo_id'] = df['id']
df['photo_url'] = df['url']
df['photo_vector'] = df['vector']
df[['photo_id', 'photo_url', 'photo_vector']].to_parquet('data/exported_photos.parquet')
insert_file(client, 'photos', 'data/exported_photos.parquet', fmt='Parquet', database='unsplash')
# Add the annoy() index on photo_vector after the bulk insert has finished.
client.command('ALTER TABLE unsplash.photos ADD INDEX vindx photo_vector TYPE annoy() GRANULARITY 8192')

['0|chi-msc-1decbcc9-msc-1decbcc9-0-0', '0', '', '0', '0']

## Create the attribute table for photos

In [None]:
!https_proxy=http://localhost:7890 wget -c https://unsplash-datasets.s3.amazonaws.com/lite/latest/unsplash-research-dataset-lite-latest.zip - data/attributes.zip
!unzip -o data/attributes.zip -d data


In [None]:
# attribute table

import numpy as np
import pandas as pd

# Columns exported to parquet; the same names and order as the
# unsplash.photos_attributes table created in the next cell.
key_list = [
    'photo_id',
    'photo_featured',
    'photo_width',
    'photo_height',
    'photo_aspect_ratio',
    'photographer_username',
    'exif_camera_make',
    'exif_camera_model',
    'photo_location_country',
    'photo_location_city',
]

img_attr = pd.read_csv('data/photos.tsv000', delimiter='\t')
# Featured photos are flagged with the string 't'; store a real boolean instead.
img_attr['photo_featured'] = img_attr['photo_featured'].eq('t')
img_attr[key_list].to_parquet('data/photos_attr.parquet')

In [None]:
# Recreate the per-photo attribute table (dropped first so the cell can be re-run).
client.command('''DROP TABLE IF EXISTS unsplash.photos_attributes''')
# EXIF and location columns are Nullable, presumably because those fields can
# be empty in the source TSV.
client.command('''
CREATE TABLE IF NOT EXISTS unsplash.photos_attributes (
  `photo_id` String,
  `photo_featured` Bool,
  `photo_width` Int64,
  `photo_height` Int64,
  `photo_aspect_ratio` Double,
  `photographer_username` String,
  `exif_camera_make` Nullable(String),
  `exif_camera_model` Nullable(String),
  `photo_location_country` Nullable(String),
  `photo_location_city` Nullable(String)
) ENGINE = MergeTree
ORDER BY
  photo_id SETTINGS index_granularity = 8192
''')
# NOTE(review): `df` is not used below (insert_file reads the parquet file by
# its path), so this read only costs time and memory.
df = pd.read_parquet('data/photos_attr.parquet')
_ = insert_file(client, 'photos_attributes', 'data/photos_attr.parquet', fmt='Parquet', database='unsplash')

In [None]:
# conversion table

from datetime import datetime as dt
from tqdm import tqdm
import pandas as pd

# Columns exported for the unsplash.photo_conversions table.
key_list = [
    'converted_at',
    'conversion_type',
    'photo_id',
    'anonymous_user_id',
    'conversion_country',
]
img_c = pd.read_csv('data/conversions.tsv000', delimiter='\t')
# Raw timestamps look like "2020-07-29 00:08:04.221". Everything after the
# first '.' (the fractional seconds) is discarded before parsing, which gives
# second-resolution values for the DateTime column.
img_c['converted_at'] = [
    dt.strptime(raw.partition('.')[0], '%Y-%m-%d %H:%M:%S')
    for raw in tqdm(img_c['converted_at'])
]
img_c[key_list].to_parquet('data/photos_conversions.parquet')

In [None]:
from tqdm import tqdm
from multiprocessing.pool import ThreadPool
# Reload the cleaned conversions and recreate the target table.
img_c = pd.read_parquet('data/photos_conversions.parquet')
# Number of rows sent per insert call by the threaded insertion below.
batch_size = 1024

client.command('''DROP TABLE IF EXISTS unsplash.photo_conversions''')
client.command('''
CREATE TABLE IF NOT EXISTS unsplash.photo_conversions (
  `converted_at` DateTime,
  `conversion_type` String,	
  `photo_id` String,
  `anonymous_user_id` String,
  `conversion_country` Nullable(String)
) ENGINE = MergeTree
ORDER BY
  photo_id SETTINGS index_granularity = 8192
''')

# Batch-wise insertion, useful when the dataframe is large.
def single(n):
    """Insert the batch of ``img_c`` rows starting at row offset *n*.

    Each call opens its own client, because the driver cell runs this function
    on many threads at once. The client is always closed, even when the insert
    raises.

    Args:
        n: Positional row offset of the batch. The batch covers rows
            ``[n, n + batch_size)``, clipped at the end of ``img_c``.
    """
    t_client = get_client(host=MYSCALE_HOST, port=MYSCALE_PORT, user=MYSCALE_USER, password=MYSCALE_PASSWORD)
    try:
        # .iloc slicing already stops at len(img_c), so no min() is needed.
        t_client.insert_df('unsplash.photo_conversions', img_c.iloc[n:n + batch_size])
    finally:
        t_client.close()

# Use a pool of 64 threads to run the batch inserts concurrently; single()
# opens a separate client per batch. imap_unordered yields as batches finish
# (in any order), so the tqdm bar advances as work completes.
with ThreadPool(64) as p:
    # Row offsets of each batch: 0, batch_size, 2 * batch_size, ...
    batches = list(range(0, len(img_c), batch_size))
    for _ in tqdm(p.imap_unordered(single, batches), total=len(batches)):
        pass

In [None]:
# Sanity check: number of rows that landed in unsplash.photo_conversions.
list(client.query('SELECT COUNT(*) FROM unsplash.photo_conversions').named_results())

In [3]:
from os import environ
from typing import Dict, Any
from langchain import OpenAI
from langchain import PromptTemplate
from sqlalchemy import create_engine, Column, MetaData
from clickhouse_sqlalchemy import (
    Table, make_session, get_declarative_base, types, engines
)
import sys
sys.path.append('..')
from chains.unsplash_chains import UnsplashSQLChain
from prompts.unsplash_prompt import _DEFAULT_TEMPLATE

from toml import load as pload

from urllib.parse import quote

# Re-read the secrets so this cell does not rely on variables from the first
# cell (presumably so it can be run on its own).
setting = pload('../.streamlit/secrets.toml')

MYSCALE_USER = setting['MYSCALE_USER']
MYSCALE_PASSWORD = setting['MYSCALE_PASSWORD']
MYSCALE_HOST = setting['MYSCALE_HOST']
MYSCALE_PORT = setting['MYSCALE_PORT']
OPENAI_API_BASE = setting['OPENAI_API_BASE']
OPENAI_API_KEY = setting['OPENAI_API_KEY']

# SQLAlchemy engine for the `unsplash` database over HTTPS. User and password
# are percent-escaped (safe='' also escapes '/'): characters such as '@', ':',
# '/' or '#' in a password would otherwise corrupt the connection URL.
engine = create_engine(
    f'clickhouse://{quote(str(MYSCALE_USER), safe="")}:{quote(str(MYSCALE_PASSWORD), safe="")}'
    f'@{MYSCALE_HOST}:{MYSCALE_PORT}/unsplash?protocol=https')
# NOTE(review): MetaData's `bind` argument is deprecated in SQLAlchemy 1.4 and
# removed in 2.0, so this line needs SQLAlchemy < 2.0.
metadata = MetaData(bind=engine)

# Text-to-SQL prompt template, filled with these three input variables.
PROMPT = PromptTemplate(input_variables=['top_k', 'table_info', 'input'],
                        template=_DEFAULT_TEMPLATE)

def get_key(path='key.txt'):
    """Return the first API key (a line starting with ``sk-``) found in *path*.

    Surrounding whitespace is stripped before the ``sk-`` check, so stray
    spaces or tabs around the key neither hide it nor end up in the result.

    Args:
        path: Key file to read; defaults to ``key.txt`` in the working directory.

    Returns:
        The stripped key string.

    Raises:
        IndexError: if no line starts with ``sk-``. This is the same exception
            type the original ``keys[0]`` lookup raised.
    """
    with open(path) as f:
        for line in f:
            key = line.strip()
            if key.startswith('sk-'):
                return key
    raise IndexError(f"no line starting with 'sk-' found in {path!r}")

# Export the OpenAI settings as environment variables. OpenAI(...) is built
# later without an explicit key and presumably reads them from here.
environ.update(OPENAI_API_KEY=OPENAI_API_KEY, OPENAI_API_BASE=OPENAI_API_BASE)

  metadata = MetaData(bind=engine)


In [None]:
from os import environ
from typing import List
from langchain import SQLDatabase, OpenAI
from langchain.chains.sql_database.base import SQLDatabaseChain
from langchain.chains.sql_database.parser import VectorSQLOutputParser
from langchain.embeddings.base import Embeddings

from transformers import CLIPTokenizerFast, CLIPModel
class EmbModel(Embeddings):
    """CLIP text encoder exposed through langchain's ``Embeddings`` interface.

    Passed to the vector-SQL output parser below to turn text into CLIP text
    features.
    """

    def __init__(self, model_name="openai/clip-vit-base-patch32") -> None:
        """Load the CLIP tokenizer and model named *model_name*.

        The original overwrote *model_name* with the default, so the argument
        was silently ignored; it is now honoured.
        """
        self.tokenizer = CLIPTokenizerFast.from_pretrained(model_name)
        self.clip = CLIPModel.from_pretrained(model_name)

    def embed_query(self, prompt: str, tokenizer=None, clip=None) -> List[float]:
        """Return the CLIP text embedding of *prompt* as a list of floats.

        *tokenizer* and *clip* are optional overrides that default to the
        instance's own. This lets ``embed_query(text)`` work as the
        ``Embeddings`` interface expects, while three-argument callers keep
        working.
        """
        tokenizer = self.tokenizer if tokenizer is None else tokenizer
        clip = self.clip if clip is None else clip
        inputs = tokenizer(prompt, return_tensors='pt')
        out = clip.get_text_features(**inputs)
        # Drop the leading batch dimension of the single-prompt batch and
        # detach from autograd before converting to plain Python floats.
        return out.squeeze(0).cpu().detach().numpy().tolist()

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed each text in *texts* independently with :meth:`embed_query`."""
        # The original delegated to the abstract base-class embed_query with
        # the whole list, which does not produce embeddings.
        return [self.embed_query(text) for text in texts]

# Text-to-SQL chain: the LLM writes SQL for the `unsplash` database reached via
# `engine`. The vector part of the SQL is handled by VectorSQLOutputParser
# using the CLIP EmbModel. Per the langchain docs, return_direct=True returns
# the query result instead of an LLM-written answer.
# SQLDatabase(engine, None, metadata): the positional None is presumably the
# schema argument — confirm against the installed langchain version.
chain = SQLDatabaseChain.from_llm(llm=OpenAI(temperature=0), prompt=PROMPT, verbose=True, 
                                  db=SQLDatabase(engine, None, metadata), 
                                  sql_cmd_parser=VectorSQLOutputParser(
                                      model=EmbModel(model_name="openai/clip-vit-base-patch32")),
                                  return_direct=True)
# Redundant: verbose=True is already passed to from_llm above.
chain.verbose = True


In [None]:
import functools


@functools.lru_cache(maxsize=1)
def _clip_embedder():
    """Build the CLIP ``EmbModel`` once and reuse it for every natural_sql call."""
    return EmbModel(model_name="openai/clip-vit-base-patch32")


def _embed_with_clip(text):
    """Return the CLIP text embedding of *text* using the cached ``EmbModel``."""
    emb = _clip_embedder()
    # Tokenizer and model are passed explicitly, matching embed_query's
    # (prompt, tokenizer, clip) signature.
    return emb.embed_query(text, emb.tokenizer, emb.clip)


def natural_sql(question, embed=None):
    """Turn *question* into SQL with ``chain`` and run it on MyScale via ``client``.

    If the generated SQL contains ``NeuralArray(<entity>)``, the entity text is
    embedded, that call is replaced by a literal vector, and ``DISTANCE`` is
    rewritten to ``cosineDistance``.

    Args:
        question: Natural-language question for the text-to-SQL chain.
        embed: Optional callable mapping a string to a sequence of numbers.
            Defaults to a cached CLIP ``EmbModel`` (the original called an
            undefined ``prompt2vec`` here).

    Returns:
        The list of rows from ``named_results()``. Returns ``'cannot-parse'``
        when the chain rejected the question or the ``NeuralArray(`` call is
        unterminated, and ``'cannot-execute'`` when the query fails.
    """
    marker = 'NeuralArray('
    sql = chain.run(question)
    if 'This is not a valid question.' in sql:
        return 'cannot-parse'

    start = sql.find(marker)
    # find() returns -1 when the marker is absent. 0 is a legitimate position,
    # so compare against -1 instead of testing `start > 0`.
    if start != -1:
        entity_start = start + len(marker)
        close = sql.find(')', entity_start)
        if close == -1:
            # Unterminated "NeuralArray(": nothing sensible to substitute.
            return 'cannot-parse'
        entity = sql[entity_start:close]
        call = sql[start:close + 1]  # the full "NeuralArray(<entity>)" text
        vector = embed(entity) if embed is not None else _embed_with_clip(entity)
        vecs_str = '[' + ','.join(map(str, vector)) + ']'
        sql = sql.replace('DISTANCE', 'cosineDistance').replace(call, vecs_str)

    # Strip trailing whitespace and semicolons from every query, not only the
    # vector ones. The client may append clauses (e.g. FORMAT) to the
    # statement, and a trailing ';' would break them. This also avoids the
    # old sql[-1] IndexError on an empty string.
    sql = sql.rstrip().rstrip(';')
    try:
        rows = client.query(sql).named_results()
    except Exception:
        # Deliberately best-effort: any driver or server error is reported as
        # the 'cannot-execute' sentinel that the test-all cells count.
        return 'cannot-execute'
    return list(rows)


## General SQL

In [None]:
# Plain (non-vector) question, sent straight to the text-to-SQL chain.
chain("what is the photo that has the most downloads which was taken by davidclode?")

## Conditioned Vector Search SQL

In [None]:
# Vector similarity search combined with a scalar filter (country).
natural_sql("what is the most-5 similar photos's url to dog which was shot in Australia?")

## Complicated entity SQL

In [None]:
# Multi-word quoted entity; the download counts presumably come from the
# photo_conversions table.
natural_sql("what is the most-10 similar photo to an entity called 'a lake by a house'? And what are their numbers of download?")

## Complicated entity SQL with implicit condition

In [None]:
# Entity search with an implicit condition: "square" presumably has to be
# translated into an aspect-ratio / width-height filter.
natural_sql("what is the most-10 similar photo to an entity called 'a lake by a house' which is a square photo?")

## Group-By clause

In [None]:
# Aggregation question; the generated SQL is expected to use GROUP BY.
natural_sql("what are the most popular authors?")

## Test all logged questions

In [None]:
from glob import glob
import json
def load_question(path):
    """Return the ``'question'`` field of the JSON log file at *path*.

    Uses ``with`` so each file handle is closed immediately. The original
    ``json.load(open(f))`` left every handle open until garbage collection.
    """
    with open(path) as fh:
        return json.load(fh)['question']


# Re-run every logged question (in sorted file order, so reruns are
# reproducible) and keep natural_sql's outcome per question.
q = [load_question(f) for f in sorted(glob('log/*.json'))]
result = {}
for _q in q:
    result[_q] = natural_sql(_q)
    print(result[_q])

In [None]:
# How many logged questions produced SQL that failed to execute?
sum(1 for v in result.values() if v == 'cannot-execute')

## Some insights

In [None]:
from langchain import LLMChain
# what does the prompt look like?
# Render the prompt the chain would send, reusing the chain's LLM and prompt.
llm_chain = LLMChain(llm=chain.llm, prompt=chain.prompt)
# NOTE(review): this sample question mentions bounding boxes, which none of the
# unsplash tables created above have; it only serves to render the prompt.
llm_inputs = {
    "input": "what is the closet id to dog whose bounding box width and height is smaller than 0.5?",
    "top_k": chain.top_k,
    "dialect": chain.database.dialect,
    "table_info": chain.database.get_table_info(table_names=None),
    "stop": ["\nSQLQuery:"],
}
# prep_prompts presumably returns (prompts, stop); print the first prompt's text.
print(llm_chain.prep_prompts([llm_inputs])[0][0].text)

In [None]:
# replace neural array with embeddings
import re
from lexer import Lexer, Rule, Token

text_sql = " SELECT obj_id FROM vector_database WHERE box_w < 0.5 AND box_h < 0.5 ORDER BY DISTANCE(prelogit, NeuralArray(dog)) LIMIT 5"

# Toy lexer for the vector-SQL shape above: one rule per clause, with nested
# `next` rules for the ORDER BY DISTANCE(col, NeuralArray(entity)) expression.
# Patterns are raw strings. In plain literals, '\s', '\w', '\(' ... are invalid
# escape sequences (DeprecationWarning; SyntaxWarning since Python 3.12). The
# compiled patterns are unchanged.
lexer = Lexer(
    rules=[
        Rule('SELECT', re.compile(r'\s*(SELECT|select)\s+\w+\s+')),
        Rule('FROM', re.compile(r'(FROM|from)\s+\w+\s+')),
        Rule('WHERE', re.compile(r"(WHERE|where)\s+\w+\s*((\!\=)|[\>\<\=])\s*[\w\d\.']+(\s+(AND|and)\s+\w+\s*((\!\=)|[\>\<\=])\s*[\w\d\.']+\s*)*\s+")),
        Rule('ORDERBY', re.compile(r'(ORDER\s+BY|order\s+by)\s+[\w]*\(*[\w\,\s\(\)]+\)\s+'),
             next=[
                 Rule('clause', re.compile(r'(ORDER\s+BY|order\s+by)\s+')),
                 Rule('expr', re.compile(r'\w+\([\w\s\,\(\)]+\)'), next=[
                     Rule('op', re.compile(r'\w+\(')),
                     Rule('col', re.compile(r'\w+\,')),
                     Rule('Narr', re.compile(r'\s*NeuralArray\(')),
                     Rule('entity', re.compile(r'\s*\w+')),
                     Rule(')', re.compile(r'\)*')),
                     ])
                 ]),
        # \d+ rather than \d, so multi-digit limits such as "LIMIT 10" are
        # consumed whole instead of leaving the trailing digits unlexed.
        Rule('LIMIT', re.compile(r'(LIMIT|limit)\s+\d+'))
    ]
)

def token2lrtree(tokens):
    """Convert a (possibly nested) sequence of lexer tokens into nested dicts.

    Each token contributes ``token.identifier -> value``, where ``value`` is
    the token's ``content``. If ``content`` is itself a list of tokens, it is
    converted recursively. When several tokens share an identifier, the last
    one wins.
    """
    return {
        t.identifier: token2lrtree(t.content) if isinstance(t.content, list) else t.content
        for t in tokens
    }

# Lex the sample query and print its token tree.
# NOTE(review): assumes lex() returns a sequence whose first item is the token
# list — confirm against the `lexer` package.
t = lexer.lex(text_sql)
d = token2lrtree(t[0])
print(d)

# for n in t[0]:
#     if n.identifier == 'ORDERBY':
        
#         print(n.content)