### Import Libraries

In [1]:
import matplotlib.pyplot as plt

from clip_embedder import CLIPEmbedder
from tiny_imagenet_db import Image, load_tiny_imagenet, search_images
from app import generate_app

# Enable inline plotting for Jupyter Notebook
%matplotlib inline

### Configuration

In [2]:
# Notebook-wide configuration: model and dataset settings used by later cells.

# Pre-trained CLIP model; passed to CLIPEmbedder when the pipeline is built
MODEL_NAME = "openai/clip-vit-base-patch32"

# Dataset split to use (train/valid); passed to load_tiny_imagenet
DATA_SPLIT = "valid"

# Table name for LanceDB; passed to Image.create_table
TABLE_NAME = "image_search"

### Load Dataset and Define Pipeline

In [None]:
# Load the Tiny-ImageNet dataset for the split selected by DATA_SPLIT.
# NOTE(review): verbose=True is forwarded to the loader — presumably it prints
# loading details; confirm in tiny_imagenet_db.
dataset = load_tiny_imagenet(DATA_SPLIT, verbose=True)

# Initialize the CLIP pipeline for embedding generation from MODEL_NAME.
# This one object is shared by the later cells: image embedding, text search,
# and the app all receive it.
pipeline = CLIPEmbedder(MODEL_NAME)

### Embed Images and Preprocess Dataset

In [None]:
def map_embed_image(batch: dict) -> dict:
    """Embed one batch of images with the notebook-level ``pipeline``.

    Args:
        batch: Mapping that holds the batch's images under the ``"image"`` key.

    Returns:
        A dict with a single ``"vector"`` entry containing whatever
        ``pipeline.embed_image`` returned for those images.
    """
    images = batch["image"]
    return {"vector": pipeline.embed_image(images)}


# Run map_embed_image over the dataset in batches of 128; the function
# returns its embeddings under the "vector" key.
processed_dataset = dataset.map(
    map_embed_image,
    batched=True,
    batch_size=128,
)

# Print a summary of the processed dataset
print(processed_dataset)

### Create LanceDB Table

In [None]:
# Create a LanceDB table named TABLE_NAME from the processed dataset.
# NOTE(review): whether create_table overwrites or fails when a table with this
# name already exists is decided in tiny_imagenet_db — confirm before re-running.
table = Image.create_table(TABLE_NAME, processed_dataset)

# Display the first few rows of the table, converted with to_pandas() for rich rendering
display(table.head().to_pandas())

### Search and Visualize Images

In [None]:
# Text query used to search the image table
text_query: str = "fish"

# Retrieve the images matching the query from the table
retrieved_images = search_images(pipeline, table, text_query, verbose=True)

# Visualize the retrieved images in a 3x3 grid
fig, axes = plt.subplots(3, 3, figsize=(6, 6))  # Adjust figsize for better display

# Hide frame and ticks on every cell up front, so that cells left unused when
# the search returns fewer than nine images render blank instead of as empty plots.
for ax in axes.ravel():
    ax.axis('off')

# zip() stops at the shorter sequence, so at most nine images are drawn.
for ax, img in zip(axes.ravel(), retrieved_images):
    ax.imshow(img)

plt.tight_layout()
plt.show()

### Launch Gradio App

In [None]:
# Build the app around the shared embedding pipeline and the image table.
# NOTE(review): no launch call is visible in this cell — presumably generate_app
# starts the Gradio server itself; confirm in app.py.
demo = generate_app(pipeline, table)

In [None]:
# Close the app object created by generate_app.
# NOTE(review): presumably this stops the running Gradio server — under
# "Run All" it would execute right after the app is created; confirm this is intended.
demo.close()