# 1 - Imports

In [None]:
import warnings
warnings.filterwarnings("ignore")

# Load variables
import os
from dotenv import load_dotenv
load_dotenv()

# Snowpark Imports
from snowflake.snowpark.session import Session
from snowflake.snowpark import functions as F
from snowflake.snowpark import types as T
from snowflake.snowpark.window import Window

# Snowflake Python API
from snowflake.core import Root
from snowflake.core.database import Database
from snowflake.core.warehouse import Warehouse
from snowflake.core.service import Service, ServiceSpecStageFile
from snowflake.core.compute_pool import ComputePool
from snowflake.core.image_repository import ImageRepository

# Other
from PIL import Image
from pprint import pprint
from datasets import load_dataset
import concurrent.futures
import io
import itertools
import threading
from plotting.image_plotting import plot_similar_images, plot_image_grid
from plotting.image_cluster_plotting import visualize_image_clusters
from api_calls.embedding_service import get_embedding

# 2 - Connect to Snowflake

In [None]:
# Connection option -> name of the environment variable that provides its value
_connection_env_vars = {
    "ACCOUNT": 'SF_ACCOUNT',
    "USER": 'SF_USER',
    "ROLE": 'SF_ROLE',
    "PASSWORD": 'SF_PASSWORD',
}
snowflake_connection_cfg = {
    option: os.getenv(env_var)
    for option, env_var in _connection_env_vars.items()
}

# Open the Snowpark session from the collected connection options
session_builder = Session.builder.configs(snowflake_connection_cfg)
session = session_builder.create()

# 3 - Set Up Environment

In [None]:
# Database: created only when it does not exist yet
root = Root(session)
demo_db = root.databases.create(Database(name="REVERSE_IMAGE_SEARCH"), mode='if_not_exists')

# Warehouse: created only when it does not exist yet
warehouses = root.warehouses
wh = warehouses.create(
    Warehouse(name="COMPUTE_WH", warehouse_size="XSMALL", auto_suspend=600, auto_resume='true'),
    mode='if_not_exists',
)

# Point the session at the demo schema and warehouse
session.use_schema('REVERSE_IMAGE_SEARCH.PUBLIC')
session.use_warehouse('COMPUTE_WH')

# Stage that holds the image files
session.sql("""CREATE STAGE IF NOT EXISTS IMAGES
                DIRECTORY = (ENABLE = TRUE AUTO_REFRESH = FALSE) 
                ENCRYPTION = (TYPE = 'SNOWFLAKE_SSE') 
                COMMENT='Stage to store Image Files'""").collect()

# Stage that holds the container files (spec files & models)
session.sql("""CREATE STAGE IF NOT EXISTS CONTAINER_FILES
                DIRECTORY = (ENABLE = TRUE AUTO_REFRESH = FALSE) 
                ENCRYPTION = (TYPE = 'SNOWFLAKE_SSE') 
                COMMENT='Stage to store Container Files'""").collect()

# Egress network rule for the Hugging Face hosts (used to download models)
session.sql("""CREATE OR REPLACE NETWORK RULE hf_rule
                MODE= 'EGRESS'
                TYPE = 'HOST_PORT'
                VALUE_LIST = (
                    'huggingface.co',
                    'cdn-lfs-us-1.huggingface.co',
                    'cdn-lfs.huggingface.co')""").collect()

# External access integration built on top of that network rule
session.sql("""CREATE OR REPLACE EXTERNAL ACCESS INTEGRATION hf_integration
                ALLOWED_NETWORK_RULES = (hf_rule)
                ENABLED = true""").collect()

# Image repository in REVERSE_IMAGE_SEARCH.PUBLIC
new_image_repository = ImageRepository(name="image_repository")
public_schema = root.databases["REVERSE_IMAGE_SEARCH"].schemas["PUBLIC"]
image_repositories = public_schema.image_repositories
my_image_repo = image_repositories.create(new_image_repository, mode='if_not_exists')

# Single-node GPU compute pool for the image embedding model
compute_pool_def = ComputePool(name="GPU_POOL", instance_family="GPU_NV_S", min_nodes=1, max_nodes=1)
my_compute_pool = root.compute_pools.create(compute_pool_def, mode='if_not_exists')

# Single-node CPU compute pool for the Streamlit app
compute_pool_def2 = ComputePool(name="CPU_POOL", instance_family="CPU_X64_XS", min_nodes=1, max_nodes=1)
my_compute_pool2 = root.compute_pools.create(compute_pool_def2, mode='if_not_exists')

# 4 - Create the Embedding Service

In [None]:
# Upload the service specification to the container-files stage
session.file.put(
    'spcs/embedding_service/container_spec.yml',
    stage_location='@CONTAINER_FILES',
    auto_compress=False,
    overwrite=True,
)

# Describe the service: one instance on the GPU pool, spec read from the stage
embedding_spec = ServiceSpecStageFile(stage='CONTAINER_FILES', spec_file='container_spec.yml')
service_def = Service(
    name="EMBEDDING_SERVICE",
    compute_pool="GPU_POOL",
    spec=embedding_spec,
    min_instances=1,
    max_instances=1,
    external_access_integrations=['HF_INTEGRATION'],
)

# Create the service unless it already exists
public_schema_services = demo_db.schemas['PUBLIC'].services
embedding_service = public_schema_services.create(service_def, mode='if_not_exists')

# SQL function that forwards an image URL to the service's API endpoint
session.sql("""CREATE OR REPLACE FUNCTION GENERATE_IMAGE_EMBEDDING(IMAGE_URL TEXT)
                RETURNS ARRAY
                SERVICE = EMBEDDING_SERVICE
                ENDPOINT=API
                AS '/generate-embeddings'""").collect()

# Show the current service status
print('SERVICE STATUS:')
service_status = embedding_service.get_service_status()
pprint(service_status)

# Show the logs of the first instance of the embedding container
print('\nSERVICE LOGS:')
service_logs = embedding_service.get_service_logs(
    instance_id='0',
    container_name='dinov2-base-service-container',
)
print(service_logs)

# 5 - Upload Images to Snowflake

We will use the `ceyda/fashion-products-small` dataset from Hugging Face, which contains over 42,000 images.  
For demo purposes, we sample 10% of the dataset and keep only the images whose `masterCategory` is `Footwear`.  
The following cell downloads the dataset, samples and filters it, and then uploads the files to Snowflake in parallel.

In [None]:
dataset = load_dataset("ceyda/fashion-products-small", split='train')
# Sample 10% of the dataset; the fixed seed makes the sample identical on every run
sampled_dataset = dataset.train_test_split(test_size=0.1, seed=42)['test']
# Filter to only include Footwear images
sampled_dataset = sampled_dataset.filter(lambda example: example["masterCategory"] == 'Footwear')
print(f'Sampled {sampled_dataset.num_rows} Images.')

# Progress counter (starts at 1) and the lock that guards it: the counter is
# only ever advanced while holding print_lock, which is what makes it safe to
# share between the upload threads.
upload_counter = itertools.count(1)
print_lock = threading.Lock()

# Build the per-thread connection config as a COPY of the notebook-wide config.
# A plain assignment would alias the same dict, so adding the context keys
# below would silently modify snowflake_connection_cfg as well.
thread_snowflake_connection_cfg = {
    **snowflake_connection_cfg,
    'DATABASE': 'REVERSE_IMAGE_SEARCH',
    'SCHEMA': 'PUBLIC',
    'WAREHOUSE': 'COMPUTE_WH',
}

# Function to upload data using a provided session
def upload_data(row, session):
    """Encode one dataset row's image as JPEG and upload it to the IMAGES stage.

    Args:
        row: Mapping with a 'filename' entry (name of the staged file) and an
            'image' entry (expected to be a Pillow image - it is saved with
            ``format='JPEG'``).
        session: Snowpark session used for the upload.

    Returns:
        The file name the image was uploaded under.
    """
    file_name = row['filename']
    image = row['image']
    # Pillow's JPEG writer only accepts these modes; anything else (e.g. RGBA
    # or palette images) would raise on save, so convert those to RGB first.
    if image.mode not in ('1', 'L', 'RGB', 'RGBX', 'CMYK', 'YCbCr'):
        image = image.convert('RGB')
    img_byte_arr = io.BytesIO()
    image.save(img_byte_arr, format='JPEG')
    # save() leaves the position at the end of the buffer; rewind so the upload
    # reads from the first byte regardless of how the client handles the stream.
    img_byte_arr.seek(0)
    session.file.put_stream(input_stream=img_byte_arr, stage_location=f'@IMAGES/{file_name}', auto_compress=False, overwrite=False)

    # Counter increment and progress output happen under the lock so that
    # concurrent uploads neither skip a count nor interleave their messages.
    with print_lock:
        current_count = next(upload_counter)
        if current_count % 100 == 0:
            print(f'{current_count} images from {sampled_dataset.num_rows} uploaded ...')
    return file_name

# Snowpark sessions are not thread-safe, so every worker thread of the upload
# pool owns exactly one session: it is created when the thread starts, reused
# for all rows that thread processes, and closed once the pool has finished.
_upload_worker_state = threading.local()


def upload_with_session(row):
    """Upload a single dataset row using a session owned by the calling thread.

    When the calling thread is a worker started by ``parallel_upload`` its
    pre-opened session is reused. Otherwise a temporary session is opened for
    this one call and closed again afterwards.

    Args:
        row: Dataset row handed on to ``upload_data``.

    Returns:
        Whatever ``upload_data`` returns for the row (the uploaded file name).
    """
    worker_session = getattr(_upload_worker_state, 'session', None)
    if worker_session is not None:
        return upload_data(row, worker_session)
    with Session.builder.configs(thread_snowflake_connection_cfg).create() as temporary_session:
        return upload_data(row, temporary_session)


def parallel_upload(dataset):
    """Upload every row of ``dataset`` with a pool of 10 worker threads.

    Args:
        dataset: Iterable of rows accepted by ``upload_with_session``.

    Returns:
        List with one result per row, in the order of ``dataset``.
    """
    opened_sessions = []
    opened_sessions_lock = threading.Lock()

    def _open_worker_session():
        # Runs once in each worker thread before it handles its first row.
        worker_session = Session.builder.configs(thread_snowflake_connection_cfg).create()
        with opened_sessions_lock:
            opened_sessions.append(worker_session)
        _upload_worker_state.session = worker_session

    try:
        with concurrent.futures.ThreadPoolExecutor(max_workers=10, initializer=_open_worker_session) as executor:
            # Map upload_with_session to each row of the dataset
            results = list(executor.map(upload_with_session, dataset))
    finally:
        # The pool is shut down at this point, so no worker still uses its session.
        for worker_session in opened_sessions:
            worker_session.close()
    return results

# Upload all sampled images through the worker pool
upload_results = parallel_upload(sampled_dataset)

# Refresh the stage's directory table so the new files are registered
refresh_stage_df = session.sql('ALTER STAGE IMAGES REFRESH')
refresh_stage_df.show()

uploaded_file_count = len(upload_results)
print(f'Uploaded {uploaded_file_count} files.')

# 6 - Generate Image Embeddings

In [None]:
# Start from the directory table of the IMAGES stage (one row per staged file)
stage_files_df = session.sql("SELECT * FROM DIRECTORY('@IMAGES')")

# Column expressions: a presigned URL per file, and the embedding computed
# from that URL, cast to a 768-dimensional float vector
presigned_url_col = F.call_builtin('GET_PRESIGNED_URL', '@IMAGES', F.col('RELATIVE_PATH'))
image_embedding_col = F.call_builtin('GENERATE_IMAGE_EMBEDDING', F.col('PRESIGNED_URL')).cast(T.VectorType(float,768))

images_df = (
    stage_files_df
    .with_column('PRESIGNED_URL', presigned_url_col)
    .with_column('IMAGE_EMBEDDING', image_embedding_col)
)

# Persist the result, then continue working from the stored table
images_df.write.save_as_table('FASHION_IMAGES_EMBEDDINGS', mode='overwrite')
images_df = session.table('FASHION_IMAGES_EMBEDDINGS')
images_df.show()

In [None]:
# Pull a random sample of 50 images into pandas and hand it to the project's
# cluster-visualization helper
cluster_columns = ['PRESIGNED_URL','RELATIVE_PATH','IMAGE_EMBEDDING']
cluster_sample_pdf = images_df[cluster_columns].sample(n=50).to_pandas()
visualize_image_clusters(cluster_sample_pdf)

# 7 - Calculate the Similarity between all Images

In [None]:
RETURN_TOP_N = 5 # number of similar images to return per image

# Rank the right-hand images of each left-hand image by descending similarity
window = Window.partition_by(['RELATIVE_PATH_LEFT']).order_by(F.col('COSINE_SIMILARITY').desc())
pairwise_similarity = F.call_builtin(
    'VECTOR_COSINE_SIMILARITY',
    F.col('IMAGE_EMBEDDING_LEFT'),
    F.col('IMAGE_EMBEDDING_RIGHT'),
)

crossjoin_images_df = (
    # Pair every image with every image
    images_df.cross_join(images_df, lsuffix='_LEFT', rsuffix='_RIGHT')
    .with_column('COSINE_SIMILARITY', pairwise_similarity)
    .select('RELATIVE_PATH_LEFT','RELATIVE_PATH_RIGHT','COSINE_SIMILARITY','PRESIGNED_URL_RIGHT','PRESIGNED_URL_LEFT')
    # Keep only the RETURN_TOP_N best-ranked pairs per left-hand image
    .with_column('ROW', F.row_number().over(window))
    .filter(F.col('ROW')<=RETURN_TOP_N)
    # Round for display after the ranking has been applied
    .with_column('COSINE_SIMILARITY', F.round('COSINE_SIMILARITY', 2))
)
crossjoin_images_df.show()

In [None]:
# Number of query images to visualize
SAMPLE_SIZE = 5

# Randomly pick SAMPLE_SIZE distinct left-hand image paths
distinct_left_paths = crossjoin_images_df.select('RELATIVE_PATH_LEFT').distinct()
sample_images = []
for path_row in distinct_left_paths.sample(n=SAMPLE_SIZE).collect():
    sample_images.append(path_row['RELATIVE_PATH_LEFT'])

# Keep only the similarity rows that belong to the picked images
viz_images = crossjoin_images_df.filter(F.col('RELATIVE_PATH_LEFT').in_(sample_images))

# Regenerate the presigned URLs on both sides in case the stored ones are no longer valid
for side in ('LEFT', 'RIGHT'):
    fresh_url = F.call_builtin('GET_PRESIGNED_URL', '@IMAGES', F.col(f'RELATIVE_PATH_{side}'))
    viz_images = viz_images.with_column(f'PRESIGNED_URL_{side}', fresh_url)

# Bring the rows into pandas and plot the similar images
df = viz_images.to_pandas()
plot_similar_images(df, RETURN_TOP_N=RETURN_TOP_N)

# 8 - Query with given image (Reverse Image Search)

In [None]:
filename = 'sample_images/sneaker1.jpg'
print('Query Image:')
display(Image.open(filename).resize((100,100)))
# Retrieve the embedding of the query image via the project helper
embedding = get_embedding(session=session, filename=filename)
# Score every stored image against the query embedding
search_df = images_df.with_column('QUERY_EMBEDDING', F.lit(embedding[0].tolist()).cast(T.VectorType(float,768)))
search_df = search_df.with_column('COSINE_SIMILARITY', F.call_builtin('VECTOR_COSINE_SIMILARITY', F.col('QUERY_EMBEDDING'), F.col('IMAGE_EMBEDDING')))
search_df = search_df.with_column('COSINE_SIMILARITY', F.round('COSINE_SIMILARITY', 2))
# Materialize scores and fresh URLs once so show() and the plot read the same data
search_df = search_df.with_column('PRESIGNED_URL', F.call_builtin('GET_PRESIGNED_URL', '@IMAGES', F.col('RELATIVE_PATH'))).cache_result()
# Sort AFTER caching: rows read back from the cached result carry no ordering
# guarantee, so an ORDER BY applied before cache_result() does not reliably
# reach show(5) / limit(10). Sorting here makes both return the best matches.
# (Images whose rounded similarity is equal come back in arbitrary order.)
search_df = search_df.order_by(F.col('COSINE_SIMILARITY').desc())
search_df[['RELATIVE_PATH','COSINE_SIMILARITY']].show(5)
# Visualize the similar images
plot_image_grid(search_df.limit(10).to_pandas())

# 9 - Create a Search Service App

In [None]:
# Remove a previous deployment so the service is recreated from the current spec.
# IF EXISTS keeps this cell runnable on a fresh account where the service has
# never been created; a plain DROP SERVICE would fail there.
session.sql('DROP SERVICE IF EXISTS SEARCH_SERVICE').show()

# Upload the Streamlit service specification to the container-files stage
session.file.put('spcs/streamlit_app/streamlit_container_spec.yml', stage_location='@CONTAINER_FILES', overwrite=True, auto_compress=False)

# Describe the service: one instance on the CPU pool, spec read from the stage
search_service_def = Service(
    name="SEARCH_SERVICE",
    compute_pool="CPU_POOL",
    spec=ServiceSpecStageFile(spec_file='streamlit_container_spec.yml', stage='CONTAINER_FILES'),
    min_instances=1,
    max_instances=1
)

search_service = demo_db.schemas['PUBLIC'].services.create(search_service_def, mode='if_not_exists')

In [None]:
# Read the ingress URL of the first endpoint listed for the search service
service_endpoints = session.sql('SHOW ENDPOINTS IN SERVICE SEARCH_SERVICE').collect()
endpoint = service_endpoints[0]['ingress_url']

# While provisioning is still running the column holds a status message
# instead of a host name, so print it unchanged in that case
still_provisioning = endpoint.startswith('Endpoints provisioning in progress...')
if not still_provisioning:
    print('URL to Streamlit App:')
    print(f"https://{endpoint}")
else:
    print(endpoint)