In [3]:
# Configuration: dataset location and PostgreSQL connection settings.
import os

DATASET_URL      = "https://thor.robots.ox.ac.uk/datasets/flowers-102/102flowers.tgz"
DATASET_PATH     = "../dataset"

# Connection settings can be supplied through the standard libpq environment
# variables, so credentials do not have to live in the notebook. The fallbacks
# keep the previous local defaults.
# TODO: drop the hard-coded password fallback once the environment is configured.
DATABASE_CONF    = {
    "host": os.environ.get("PGHOST", "192.168.56.101"),
    "user": os.environ.get("PGUSER", "postgres"),
    "password": os.environ.get("PGPASSWORD", "a"),
}
DATABASE_NAME    = "db_flower"
DATABASE_SCHEMA  = "flower_f1"

In [4]:
# Load the pre-trained CLIP model together with the preprocessing transform
# that is applied to images before encoding them.
# Details: https://github.com/openai/CLIP
import torch
import clip

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
model, preprocess = clip.load("ViT-B/32", device=device)

In [None]:
# Download & extract dataset. From: https://www.tensorflow.org/datasets/catalog/oxford_flowers102
import urllib.request
import tarfile
import os

# urlretrieve does not create missing directories, so make sure the target exists.
os.makedirs(DATASET_PATH, exist_ok=True)
__save_path = os.path.join(DATASET_PATH, "102flowers.tgz")

# Reuse an archive from a previous run instead of downloading it again.
if not os.path.exists(__save_path):
    # Download under a temporary name and rename only when complete, so an
    # interrupted download is never mistaken for a finished archive.
    __partial_path = __save_path + ".part"
    urllib.request.urlretrieve(DATASET_URL, __partial_path)
    os.replace(__partial_path, __save_path)

# "r|gz" reads the archive as a gzip stream; the `with` block closes the file
# even if extraction fails part-way.
with tarfile.open(__save_path, mode="r|gz") as file:
    if hasattr(tarfile, "data_filter"):
        # Extraction filters (Python 3.12+ and security backports) reject
        # members with absolute paths, "..", or links escaping DATASET_PATH.
        file.extractall(DATASET_PATH, filter="data")
    else:
        file.extractall(DATASET_PATH)

In [44]:
# Setup Database -- Using postgresql
import psycopg2
from psycopg2 import sql

db_conn = psycopg2.connect(**DATABASE_CONF)

def __create_database(db_conn):
    """Create the database named ``DATABASE_NAME`` unless it already exists.

    ``CREATE DATABASE`` cannot run inside a transaction block, so the statements
    are issued with autocommit enabled. The connection's previous autocommit
    setting is restored afterwards.

    Args:
        db_conn: open psycopg2 connection whose role may create databases.
    """
    previous_autocommit = db_conn.autocommit
    db_conn.autocommit = True
    try:
        with db_conn.cursor() as db_cursor:
            # Bound parameter instead of splicing the name into the SQL text.
            db_cursor.execute(
                "SELECT 1 FROM pg_database WHERE datname = %s", (DATABASE_NAME,)
            )
            if db_cursor.fetchone() is None:
                # sql.Identifier quotes the database name safely.
                db_cursor.execute(
                    sql.SQL("CREATE DATABASE {}").format(sql.Identifier(DATABASE_NAME))
                )
    finally:
        db_conn.autocommit = previous_autocommit

__create_database(db_conn)
db_conn.close()

# Reconnect with new DB
db_conn = psycopg2.connect(**DATABASE_CONF, database = DATABASE_NAME)

In [84]:
# Setup tables

def __write(db_conn, query: str, *args):
    """Run ``query`` with ``args`` as positional parameters, then commit.

    Returns:
        The cursor that executed the query, so callers can still fetch rows
        (for example from an ``INSERT ... RETURNING`` statement).
    """
    cursor = db_conn.cursor()
    cursor.execute(query, list(args))
    db_conn.commit()
    return cursor

# DDL and seed statements; each one is idempotent, so the cell can be re-run.
__queries = [
    f"CREATE SCHEMA IF NOT EXISTS {DATABASE_SCHEMA}",
    f"""CREATE TABLE IF NOT EXISTS {DATABASE_SCHEMA}.flowers_img(
        img_idx SERIAL PRIMARY KEY,
        filename TEXT
    )
    """,
    # Every feature row points at its image row. The foreign key keeps the two
    # tables consistent; it only takes effect when the table is first created.
    f"""CREATE TABLE IF NOT EXISTS {DATABASE_SCHEMA}.flowers_vector(
        img_idx INTEGER REFERENCES {DATABASE_SCHEMA}.flowers_img(img_idx),
        image_features FLOAT[]
    )
    """,
    f"""CREATE TABLE IF NOT EXISTS {DATABASE_SCHEMA}.config (
        key  TEXT NOT NULL PRIMARY KEY,
        data TEXT DEFAULT '{{}}'
    )
    """,
    # Seed the 'base' config row once; `key` is the primary key, so a second
    # run hits the conflict clause and inserts nothing.
    f"""INSERT INTO {DATABASE_SCHEMA}.config (key, data)
        VALUES ('base', '{{}}')
        ON CONFLICT (key) DO NOTHING
    """,
    "CREATE EXTENSION IF NOT EXISTS pg_trgm",
    # Cosine similarity between two vectors stored as double precision arrays.
    # STRICT: a NULL argument yields NULL without running the body.
    # IMMUTABLE: the result depends only on the arguments.
    """
    CREATE OR REPLACE FUNCTION cosine_similarity(a double precision[], b double precision[])
    RETURNS double precision AS $body$
    DECLARE
        dot_product double precision := 0;
        norm_a double precision := 0;
        norm_b double precision := 0;
    BEGIN
        -- Vectors of different dimensionality cannot be compared; fail loudly
        -- instead of silently using only the overlapping prefix.
        IF array_length(a, 1) IS DISTINCT FROM array_length(b, 1) THEN
            RAISE EXCEPTION 'cosine_similarity: input arrays must have the same length';
        END IF;

        -- array_length() is NULL for an empty array; COALESCE turns that into
        -- an empty loop so empty inputs end up in the zero-norm branch below.
        FOR i IN 1..COALESCE(array_length(a, 1), 0) LOOP
            dot_product := dot_product + a[i] * b[i];
            norm_a := norm_a + a[i] * a[i];
            norm_b := norm_b + b[i] * b[i];
        END LOOP;

        norm_a := sqrt(norm_a);
        norm_b := sqrt(norm_b);

        IF norm_a = 0 OR norm_b = 0 THEN
            RETURN 0;
        ELSE
            RETURN dot_product / (norm_a * norm_b);
        END IF;
    END;
    $body$ LANGUAGE plpgsql IMMUTABLE STRICT;
    """
]

for query in __queries:
    __write(db_conn, query)

In [5]:
# Predict/ Extract feature from the whole datasets and store them in the DB.
import os
from PIL import Image
from pathlib import Path

import torch


def __save_img(db_conn, filename, features):
    """Insert one image row and its feature vector in a single transaction.

    Args:
        db_conn: open psycopg2 connection.
        filename: value stored in ``flowers_img.filename``.
        features: sequence of floats stored in ``flowers_vector.image_features``.

    Returns:
        The ``img_idx`` generated for the new ``flowers_img`` row.
    """
    # `with db_conn` commits both inserts together on success and rolls back
    # on error, so an image row is never left without its vector.
    with db_conn:
        with db_conn.cursor() as cursor:
            cursor.execute(
                f"INSERT INTO {DATABASE_SCHEMA}.flowers_img(filename) VALUES (%s) RETURNING img_idx;",
                (filename,),
            )
            img_idx = cursor.fetchone()[0]
            # psycopg2 adapts a Python list of floats to a PostgreSQL array.
            cursor.execute(
                f"INSERT INTO {DATABASE_SCHEMA}.flowers_vector(img_idx, image_features) VALUES (%s, %s)",
                (img_idx, list(features)),
            )
    return img_idx


def __stored_filenames(db_conn):
    """Return the set of filenames already present in ``flowers_img``."""
    with db_conn:
        with db_conn.cursor() as cursor:
            cursor.execute(f"SELECT filename FROM {DATABASE_SCHEMA}.flowers_img")
            return {row[0] for row in cursor.fetchall()}


# Sorted so img_idx values are assigned in a reproducible order across runs.
__pathlist = sorted(Path(DATASET_PATH).glob('**/*.jpg'))
# Images stored by a previous run are skipped, so re-running adds no duplicates.
__stored = __stored_filenames(db_conn)

# Inference only: no_grad avoids building an autograd graph for every image.
with torch.no_grad():
    for path in __pathlist:
        filename = path.name
        if filename in __stored:
            continue
        # The context manager closes the image file once preprocessing is done.
        with Image.open(path) as img:
            image = preprocess(img).unsqueeze(0).to(device)
        image_features = model.encode_image(image).cpu().numpy().tolist()

        __save_img(db_conn, filename, image_features[0])
        __stored.add(filename)


In [82]:
# Test & Close DB
# Clear any transaction left in an aborted state by an earlier failing cell.
db_conn.rollback()

TOP_K = 10  # number of matches to print; set to None to list every image

# Encode the query text with the same CLIP model used for the image features.
input_text = "Passiflora incarnata flower with the green leaf"
text_input = clip.tokenize([input_text]).to(device)
with torch.no_grad():
    text_features_input = model.encode_text(text_input).squeeze().cpu().numpy().tolist()

# Rank images with the cosine_similarity() PL/pgSQL function created in the
# "Setup tables" cell. No extension is involved.
# - The ::float8[] cast matches the function's double precision[] parameter.
# - A NULL LIMIT (TOP_K = None) means "no limit" in PostgreSQL.
query = f"""
SELECT img_idx, cosine_similarity(image_features, %s::float8[]) AS similarity
FROM {DATABASE_SCHEMA}.flowers_vector
ORDER BY similarity DESC
LIMIT %s;
"""
with db_conn.cursor() as db_cursor:
    db_cursor.execute(query, (text_features_input, TOP_K))
    results = db_cursor.fetchall()

# Print results
for idx, similarity in results:
    print(f"Index: {idx}, Similarity Score: {similarity}")

# Close connection
db_conn.close()

[-0.37590086460113525, -0.04755682870745659, 0.1821782886981964, -0.10466793924570084, -0.6141473054885864, 0.08754009008407593, -0.40212705731391907, -0.32866978645324707, -0.41855156421661377, -0.2849782705307007, 0.17707942426204681, 0.24608761072158813, -0.28224003314971924, 0.05725375935435295, 0.03153083473443985, -0.42762166261672974, -0.08872921764850616, -0.15684349834918976, -0.31721365451812744, -0.09965786337852478, 0.2473062127828598, -0.021698379889130592, 0.18524593114852905, 0.3159913718700409, -0.07523111253976822, 0.6008061766624451, -0.15752945840358734, 0.010737289674580097, 0.13770167529582977, 0.43688443303108215, -0.016057290136814117, 0.09804243594408035, -0.21789129078388214, -0.40454578399658203, -0.12875154614448547, -0.46163320541381836, 0.5724146962165833, 0.09110977500677109, -0.1538224220275879, -0.16101771593093872, 0.3638114631175995, 0.29600146412849426, 0.3804880678653717, 0.029980197548866272, 0.21667860448360443, -0.07441286742687225, -0.45396113395