In [None]:
import os
import sys

# Make the local `iris` package importable when it is not pip-installed.
# Set the IRIS_ROOT environment variable to point at another checkout; the original
# developer path is kept as the fallback. The membership check keeps the cell
# idempotent, so re-running it does not append duplicate sys.path entries.
IRIS_ROOT = os.environ.get('IRIS_ROOT', r'c:\Users\ice\projects\iris')
if IRIS_ROOT not in sys.path:
    sys.path.append(IRIS_ROOT)

from iris.config.data_pipeline_config_manager import DataPipelineConfigManager
from iris.data_pipeline.qdrant_manager import QdrantManager
import numpy as np
from sklearn.manifold import TSNE
from PIL import Image
import base64
from io import BytesIO

In [None]:
# Load the data-pipeline configuration once, pull out its Qdrant section,
# and build the manager used by the fetch cell below.
pipeline_config = DataPipelineConfigManager()
qdrant_config = pipeline_config.qdrant_config
qdrant_manager = QdrantManager(qdrant_config)

In [None]:
# Fetch every point (payload + vector) from the image collection.
# scroll() returns one page at a time as (records, next_page_offset). Keep following
# the offset until it comes back None, so collections larger than a single page are
# not silently truncated.
SCROLL_PAGE_SIZE = 1000  # points per request; trades memory per call vs. round-trips

points = []
next_offset = None
with qdrant_manager as qm:
    while True:
        page, next_offset = qm._client.scroll(
            collection_name=qdrant_config.image_collection,
            limit=SCROLL_PAGE_SIZE,
            offset=next_offset,
            with_payload=True,
            with_vectors=True,
        )
        points.extend(page)
        if next_offset is None:
            break

assert points, f"No points found in collection {qdrant_config.image_collection!r}"

# Stack the embeddings into one matrix for t-SNE.
# NOTE(review): assumes each point has a single unnamed dense vector of equal length.
vectors = np.array([point.vector for point in points])

# Marker colour per point type. Types not listed here fall back to grey
# instead of raising KeyError.
type_colors = {
    'full_image': '#1f77b4',  # Blue
    'localization': '#ff7f0e',  # Orange
}
UNKNOWN_TYPE_COLOR = '#7f7f7f'  # Grey

node_colors = [
    type_colors.get(point.payload.get('type'), UNKNOWN_TYPE_COLOR)
    for point in points
]

In [None]:
# Project the embeddings to 2D with t-SNE for plotting.
# scikit-learn rejects perplexity >= n_samples (its default perplexity is 30), so clamp
# it for small collections. For 31 or more points this is identical to the default.
TSNE_PERPLEXITY = 30.0
TSNE_SEED = 42

n_samples = len(vectors)
assert n_samples >= 2, f"t-SNE needs at least 2 points, got {n_samples}"
tsne = TSNE(perplexity=min(TSNE_PERPLEXITY, n_samples - 1), random_state=TSNE_SEED)
vectors_2d = tsne.fit_transform(vectors)

# Sanity-check the projection: one 2D row per input point.
assert vectors_2d.shape == (n_samples, 2)
print(f"Original shape: {vectors.shape}")
print(f"Transformed shape: {vectors_2d.shape}")

In [None]:
from dash import Dash, html, dcc, Input, Output, State, no_update, callback
import plotly.graph_objects as go
from PIL import Image
import base64
from io import BytesIO

# Dash application hosting the interactive vector-space explorer.
app = Dash(__name__)

# One marker per Qdrant point at its t-SNE position, coloured by point type.
# Plotly's built-in hover label is disabled so the dcc.Tooltip below can render
# an image preview instead.
scatter_trace = go.Scatter(
    x=vectors_2d[:, 0],
    y=vectors_2d[:, 1],
    mode='markers',
    marker=dict(size=8, color=node_colors),
    hovertemplate=None,
    hoverinfo='none',
)
scatter_layout = go.Layout(hovermode='closest', title='2D Vector Space')

# Scatter plot plus the floating tooltip filled in by the hover callback.
plot_panel = html.Div([
    dcc.Graph(
        id='vector-space-plot',
        figure={'data': [scatter_trace], 'layout': scatter_layout},
        clear_on_unhover=True,
    ),
    dcc.Tooltip(id='graph-tooltip'),
])

# Area filled in by the click callback: title, full-size image, and metadata text.
details_panel = html.Div(
    [
        html.H3("Selected Point Details", id='point-details'),
        html.Img(id='selected-image', style={'maxWidth': '300px'}),
        html.Pre(id='point-info'),
    ],
    style={'margin': '20px'},
)

app.layout = html.Div([
    html.H1("Vector Space Visualization"),
    plot_panel,
    details_panel,
])

# Edge length (px) of the hover-tooltip thumbnail.
TOOLTIP_THUMBNAIL_PX = 150


def _point_image_data_uri(point, max_size=None):
    """Render the image behind a Qdrant point as a base64 JPEG data URI.

    Shared by the hover and click callbacks so both crop and encode identically.

    Args:
        point: A scrolled Qdrant record whose payload has 'local_path' and 'type',
            plus 'bbox' when type == 'localization'.
        max_size: Optional (width, height). When given, the image is shrunk with
            PIL's thumbnail() to fit inside it.

    Returns:
        A 'data:image/jpeg;base64,...' string usable as an html.Img src.

    Raises:
        Whatever PIL raises for a missing or unreadable file (e.g. OSError), or
        KeyError/IndexError for a missing or short payload field.
    """
    payload = point.payload
    with Image.open(payload['local_path']) as img:
        if payload['type'] == 'localization':
            # bbox is read as (x, y, w, h) and scaled by the image size, i.e. treated
            # as fractions of the full image. TODO confirm against the pipeline.
            bbox = payload['bbox']
            width, height = img.size
            img = img.crop((
                bbox[0] * width,
                bbox[1] * height,
                (bbox[0] + bbox[2]) * width,
                (bbox[1] + bbox[3]) * height,
            ))
        if max_size is not None:
            img.thumbnail(max_size)
        # JPEG cannot store alpha or palette images: saving an RGBA/P-mode PNG raises
        # OSError ("cannot write mode RGBA as JPEG"), so normalise to RGB first.
        if img.mode not in ('RGB', 'L'):
            img = img.convert('RGB')
        buffered = BytesIO()
        img.save(buffered, format="JPEG")
    encoded = base64.b64encode(buffered.getvalue()).decode()
    return f'data:image/jpeg;base64,{encoded}'


# Hover callback: show a thumbnail of the hovered point in the floating tooltip.
@callback(
    Output('graph-tooltip', 'show'),
    Output('graph-tooltip', 'bbox'),
    Output('graph-tooltip', 'children'),
    Input('vector-space-plot', 'hoverData'),
)
def display_hover(hoverData):
    """Return (show, bbox, children) for the dcc.Tooltip of the hovered point.

    If the preview cannot be rendered, the tooltip still opens and shows the error
    text instead of silently staying hidden.
    """
    if hoverData is None:
        return False, no_update, no_update

    hovered = hoverData['points'][0]
    point = points[hovered['pointIndex']]

    try:
        img_src = _point_image_data_uri(
            point, max_size=(TOOLTIP_THUMBNAIL_PX, TOOLTIP_THUMBNAIL_PX)
        )
        preview = html.Img(
            src=img_src,
            style={'width': f'{TOOLTIP_THUMBNAIL_PX}px', 'display': 'block', 'margin': '0 auto'}
        )
    except Exception as e:
        # Best-effort preview: surface the failure rather than swallowing it.
        preview = html.P(
            f"Could not load image: {e}",
            style={'text-align': 'center', 'margin': '5px', 'color': 'crimson'}
        )

    caption = html.P(
        f"ID: {point.id} ({point.payload.get('type', 'unknown')})",
        style={'text-align': 'center', 'margin': '5px'}
    )
    return True, hovered['bbox'], [html.Div([preview, caption])]


# Click callback: show the full-size image and metadata of the clicked point.
@callback(
    [Output('selected-image', 'src'),
     Output('point-details', 'children'),
     Output('point-info', 'children')],
    [Input('vector-space-plot', 'clickData')],
    prevent_initial_call=True
)
def update_image(clickData):
    """Return (image src, title, info text) for the clicked point.

    On any load or payload error, returns no image plus the error message.
    """
    if not clickData:
        return None, "Select a point", "No point selected"

    point = points[clickData['points'][0]['pointIndex']]

    try:
        img_src = _point_image_data_uri(point)

        point_type = point.payload['type']
        title = f"Point {point.id} ({point_type})"
        info = f"Type: {point_type}\nID: {point.id}"
        if point_type == 'localization':
            info += f"\nParent Image: {point.payload['parent_image_hash']}"

        return img_src, title, info

    except Exception as e:
        return None, "Error loading image", str(e)


In [None]:
# Launch the Dash development server for the app built in the layout cell.
# The auto-reloader is switched off because it restarts the process, which is not
# suitable inside a notebook kernel. debug=True keeps Dash's dev tools enabled.
if __name__ == '__main__':
    app.run(use_reloader=False, debug=True)