In [1]:
# ! pip install jupyter_dash dash

### Dash で作る、ホバーした点に対応する画像を表示する散布図

Jupyter Notebook 上で動作可能  
ただし、一度起動した後に再度実行する場合はカーネルの再起動が必須（原因不明）

In [2]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

import cv2
import base64
import random
import socket
import pandas as pd
from pathlib import Path
from PIL import Image
from io import BytesIO

import plotly.express as px
from dash import dcc
from dash import html
from dash.dependencies import Input, Output
from jupyter_dash import JupyterDash

def get_host_ip(probe_host="8.8.8.8", probe_port=80):
    """Return the local IPv4 address of the interface used to reach ``probe_host``.

    A UDP socket is "connected" to the probe address. For a datagram socket,
    connect() only records the default peer and lets the OS choose the local
    address; no data is sent. That local address is read back with getsockname().

    Args:
        probe_host: IPv4 address used only to select the outgoing interface.
        probe_port: Port paired with ``probe_host``.

    Returns:
        The local IPv4 address as a string.

    Raises:
        OSError: If the OS cannot route to ``probe_host`` (e.g. no network).
    """
    # `with` guarantees the socket is closed even when connect() raises.
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as probe:
        probe.connect((probe_host, probe_port))
        return probe.getsockname()[0]

def plot_scatter_by_dash(df, n_x, n_y, n_class=None, n_imgPaths="画像のパス", title=None, show_height=500, host=None, port=None):
    """Show a scatter plot that displays the image of the point under the mouse.

    Builds a JupyterDash app with the scatter plot on the left and the image of
    the hovered point on the right. The image's height x width and its path are
    printed below the plot.

    Args:
        df: DataFrame with one row per point.
        n_x: Column name plotted on the x axis.
        n_y: Column name plotted on the y axis.
        n_class: Optional column name used to color (group) the points.
        n_imgPaths: Column name holding each row's image file path.
        title: Optional figure title.
        show_height: Figure width and height in pixels; the image is shown at
            80% of this height.
        host: Host for ``external`` mode. If ``host`` or ``port`` is given, the
            app is served externally instead of inline, and a missing ``host``
            defaults to ``get_host_ip()``.
        port: Port for ``external`` mode; defaults to 20000 when not given.

    Returns:
        None. The app is displayed/served as a side effect.
    """
    def image_path_to_b64(image_path, img_resize_height=150):
        """Load an image, shrink it to `img_resize_height` px tall, and return it as base64-encoded PNG."""
        # `with` closes the file handle that Image.open keeps open.
        with Image.open(image_path) as im_pil:
            # Convert non-RGB modes to RGB before encoding.
            if im_pil.mode != 'RGB':
                im_pil = im_pil.convert('RGB')

            # Downscale (keeping the aspect ratio) to keep the payload sent to the browser small.
            new_width = max(1, int(im_pil.width * img_resize_height / im_pil.height))
            im_pil = im_pil.resize((new_width, img_resize_height))

            buff = BytesIO()
            im_pil.save(buff, format='png')
        return base64.b64encode(buff.getvalue()).decode('utf-8')

    # Plotly Express draws one trace per `n_class` value. A hovered point's
    # `pointIndex` is its position inside that trace, not inside `df`. So each
    # row's position in `df` is attached as custom data and read back on hover.
    row_col = "__row_position__"
    plot_df = df.assign(**{row_col: list(range(len(df)))})

    fig = px.scatter(plot_df, x=n_x, y=n_y, color=n_class, title=title,
                     width=show_height, height=show_height, custom_data=[row_col])

    app = JupyterDash(__name__)

    app.layout = html.Div([
        html.Div([
            dcc.Graph(id="fig1", figure=fig),
            html.Div(id="output_img", style={"margin-left": "10px"}),
        ], style={"display": "flex"}),
        html.Div(id="output_text"),
    ], style={'background-color': 'white'})

    img_height = f'{int(show_height*0.8)}px'
    text_style = {'line-height': '1'}  # tighter line spacing for the info text

    def hovered_row(hover_data):
        """Return the `df` row of the hovered point, or None when nothing is hovered."""
        if not hover_data or not hover_data.get('points'):
            return None
        row_position = hover_data['points'][0]['customdata'][0]
        return df.iloc[int(row_position)]

    @app.callback(
        Output('output_img', 'children'),
        [Input('fig1', 'hoverData')])
    def display_image(hoverData):
        """Show the hovered point's image, or an empty placeholder image."""
        row = hovered_row(hoverData)
        if row is None:
            return html.Div([html.Img(src=None, height=img_height)])
        image_src = 'data:image/png;base64,{}'.format(image_path_to_b64(row[n_imgPaths]))
        return html.Div([html.Img(src=image_src, height=img_height)])

    @app.callback(
        Output('output_text', 'children'),
        [Input('fig1', 'hoverData')])
    def display_image_info(hoverData):
        """Show the hovered image's height x width and its path."""
        row = hovered_row(hoverData)
        if row is None:
            return html.Div([
                html.Div("shape : ", style=text_style),
                html.Div("Path  : No image selected.", style=text_style),
            ])
        image_path = row[n_imgPaths]
        img = cv2.imread(str(image_path))
        # cv2.imread returns None instead of raising when the file cannot be read.
        if img is None:
            shape_text = "shape : (unreadable)"
        else:
            shape_text = f"shape : {img.shape[0]}x{img.shape[1]}"
        return html.Div([
            html.Div(shape_text, style=text_style),
            html.Div(f"Path  : {image_path}", style=text_style),
        ])

    if host is None and port is None:
        # Render the app inside the notebook output cell.
        app.run_server(mode='inline')
    else:
        # Serve on a network address so the app can be opened from another browser.
        # debug stays off: Flask's interactive debugger must not be exposed on a reachable host.
        app.run_server(mode='external',
                       host=get_host_ip() if host is None else host,
                       port=20000 if port is None else port)

    return None


In [3]:
def main(data_dir='./data/sample_coco_train2017/', item_num=100, seed=None):
    """Demo: plot random points, each linked to a randomly chosen sample image.

    Instead of reading a CSV, builds a synthetic DataFrame with the following
    columns, then plots it with ``plot_scatter_by_dash``:
    - random X/Y coordinates in [0, 100];
    - a random class label from A-E;
    - a distinct ``.jpg`` path found recursively under ``data_dir``.

    Args:
        data_dir: Directory searched recursively for ``*.jpg`` images.
        item_num: Number of points (and distinct images) to plot.
        seed: Optional seed for reproducible data; ``None`` gives a different draw each run.

    Raises:
        ValueError: If ``data_dir`` contains fewer than ``item_num`` images.
    """
    rng = random.Random(seed)

    # Sort so that a given seed selects the same images regardless of filesystem order.
    image_paths = sorted(Path(data_dir).glob('**/*.jpg'))
    if len(image_paths) < item_num:
        raise ValueError(
            f"Need at least {item_num} .jpg images under {data_dir!r}, found {len(image_paths)}."
        )

    # Synthetic data frame used in place of a CSV file.
    df = pd.DataFrame({
        'X座標':     [rng.randint(0, 100) for _ in range(item_num)],
        'Y座標':     [rng.randint(0, 100) for _ in range(item_num)],
        '画像のパス':  rng.sample(image_paths, item_num),
        'color':     [rng.choice(["A", "B", "C", "D", "E"]) for _ in range(item_num)],
    })

    plot_scatter_by_dash(df, n_x='X座標', n_y='Y座標', n_class='color', n_imgPaths="画像のパス", title='Scatter Plot')

if __name__ == '__main__':
    main()



JupyterDash is deprecated, use Dash instead.
See https://dash.plotly.com/dash-in-jupyter for more details.

