In [None]:
from sqlalchemy import select, text, create_engine
from sqlalchemy.orm import Session
from gorillatracker.ssl_pipeline.models import TrackingFrameFeature, Tracking, Video, Camera
from gorillatracker.ssl_pipeline.dataset import GorillaDatasetKISZ
import importlib
from typing import List, Literal, Optional, Tuple

from PIL import Image
from torch import Tensor
from torch.utils.data import Dataset
from torchvision import transforms

import gorillatracker.type_helper as gtypes
from gorillatracker.transform_utils import SquarePad
from gorillatracker.type_helper import Id, Label
from gorillatracker.utils.labelencoder import LabelEncoder

from gorillatracker.data.nlet import FlatNletBuilder
from gorillatracker.data.nlet import NletDataModule, SupervisedDataset, build_onelet
from gorillatracker.utils.embedding_generator import generate_embeddings, df_from_predictions
from gorillatracker.utils.wandb_loader import get_model_for_run_url, load_model
from gorillatracker.model import VisionTransformerWrapper

from pathlib import Path
import gorillatracker.utils.embedding_generator as eg
import pandas as pd
from io import BytesIO
import base64
from sklearn.manifold import TSNE
import numpy as np
from bokeh.plotting import figure, show
from bokeh.models import ColumnDataSource, LinearColorMapper, HoverTool, WheelZoomTool
from bokeh.io import output_notebook
from bokeh.resources import INLINE
import colorcet as cc
from sklearn.neighbors import KNeighborsClassifier
import matplotlib.pyplot as plt
import ipywidgets as widgets
from bokeh.models import Slider, CustomJS, Div
from bokeh.layouts import column
from bokeh.models import Text
from sqlalchemy.orm import sessionmaker
import geopy.distance

# Route Bokeh output to the notebook, using the INLINE resources object.
output_notebook(INLINE)

# Upper bound on the number of rows kept per tracking_id by the query below (`rn <= sample`).
sample = 10
# Raw SQL selecting the ids of up to `sample` 'face_45' features per tracking_id whose
# bounding box is at least 184 x 184. Rows are ranked with ORDER BY RANDOM(), so every
# execution picks a different sample.
# NOTE(review): `sample` is interpolated with an f-string; that is fine for this integer
# constant, but switch to a bind parameter if the value ever comes from outside the notebook.
query = text(
    f"""WITH ranked_features AS (
    SELECT
        tracking_id,
        tracking_frame_feature_id,
        bbox_width,
        bbox_height,
        frame_nr,
        feature_type,
        ROW_NUMBER() OVER (PARTITION BY tracking_id ORDER BY RANDOM()) AS rn
    FROM tracking_frame_feature
    WHERE feature_type = 'face_45'
        AND bbox_width >= 184
        AND bbox_height >= 184
        AND tracking_id IS NOT NULL
)
SELECT
    tracking_frame_feature_id
FROM ranked_features
WHERE rn <= {sample}
"""
)

# Engine shared by the ad-hoc queries in later cells.
engine = create_engine(GorillaDatasetKISZ.DB_URI)

In [None]:
# devcontainer.mount add "source=/mnt/vast-gorilla/cropped-images,target=/workspaces/gorillatracker/cropped-images,type=bind,ro"
# Root directory of the cropped images; read by HackDataset.__getitem__ and by the
# image-loading cell further down.
base_path = "../video_data/cropped-images/2024-04-18"

def cast_label_to_int(labels: List[str]) -> List[int]:
    """Encode string labels as integers by delegating to ``LabelEncoder.encode_list``."""
    return LabelEncoder.encode_list(labels)


class HackDataset(Dataset[Tuple[Id, Tensor, Label]]):
    """Ad-hoc dataset over the face crops selected by the module-level ``query``.

    Every sample is ``(tracking_frame_feature_id, image, tracking_id)``; the tracking id
    takes the place of the label.
    """

    def __init__(
        self, data_dir: str, partition: Literal["train", "val", "test"], nlet_builder: FlatNletBuilder, transform: Optional[gtypes.Transform] = None
    ):
        # ``data_dir`` and ``nlet_builder`` are accepted but not used: the rows come from
        # the database and the images from the module-level ``base_path``.
        self.tffs = self.get_tffs()
        self.transform = transform
        self.partition = partition
        self.video_ids = [row.video_id for row in self.tffs]
        self.tracking_ids = [row.tracking_id for row in self.tffs]

    def get_tffs(self) -> list[TrackingFrameFeature]:
        """Query the sampled features whose video started on or after 2023-01-01.

        NOTE(review): the statement selects three columns, so each returned element is a
        result row ``(TrackingFrameFeature, video_id, tracking_id)`` rather than a bare
        ``TrackingFrameFeature`` as the annotation says; ``__getitem__`` unpacks element 0.
        """
        in_sample = TrackingFrameFeature.tracking_frame_feature_id.in_(query)
        started_recently = Video.start_time >= "2023-01-01 00:00:00"
        stmt = (
            select(TrackingFrameFeature, Video.video_id, Tracking.tracking_id)
            .join(Tracking, Tracking.tracking_id == TrackingFrameFeature.tracking_id)
            .join(Video, Tracking.video_id == Video.video_id)
            .where(in_sample)
            .where(started_recently)
        )

        db_engine = create_engine(GorillaDatasetKISZ.DB_URI)
        with Session(db_engine) as session:
            rows = session.execute(stmt).all()
        return rows

    def __len__(self) -> int:
        """Number of sampled rows."""
        return len(self.tffs)

    def __getitem__(self, idx: int) -> Tuple[Id, Tensor, Label]:
        """Return ``(feature id, image, tracking id)``; tracklets serve as labels for now."""
        feature = self.tffs[idx][0]

        image = Image.open(feature.cache_path(base_path))
        if self.transform:
            image = self.transform(image)
        assert feature.tracking_id is not None
        return feature.tracking_frame_feature_id, image, feature.tracking_id

    @classmethod
    def get_transforms(cls) -> gtypes.Transform:
        """Default preprocessing: square-pad, resize to 256, convert to tensor."""
        steps = [
            SquarePad(),
            # Uniform input, you may choose higher/lower sizes.
            transforms.Resize(256),
            transforms.ToTensor(),
        ]
        return transforms.Compose(steps)

In [None]:
# python3 train.py --config_path cfgs/swinv2_cxl.yml

# gorillas/Embedding-SwinV2Base-CXL-Open/model-yl2lx567:v12
# Flip to True to pull the model from the W&B run instead of the local checkpoint.
load_from_wandb_run = False
if not load_from_wandb_run:
    model = load_model(VisionTransformerWrapper, "/workspaces/gorillatracker/models/vit_large_dinov2_baseline.ckpt")
else:
    run_url = "https://wandb.ai/gorillas/Embedding-SwinV2Large-CXL-Open/runs/olql2abq?nw=nwuseremirhan404"
    model = get_model_for_run_url(run_url)

# Preprocessing handed to the data modules: square-pad, resize to 224, tensor, normalise.
normalize_mean = [0.485, 0.456, 0.406]
normalize_std = [0.229, 0.224, 0.225]
model_transforms = transforms.Compose(
    [
        SquarePad(),
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize(normalize_mean, normalize_std),
    ]
)

In [None]:
# Re-import the embedding helper module so that edits to it are picked up in this kernel.
importlib.reload(eg)
# python3 train.py --config_path cfgs/swinv2_cxl.yml

# gorillas/Embedding-SwinV2Base-CXL-Open/model-yl2lx567:v12
# Directory handed to the data module for the supervised dataset.
data_dir = "/workspaces/gorillatracker/data/supervised/splits/cxl_faces_square_all-in-val"
data_module = NletDataModule(
    data_dir=Path(data_dir),
    dataset_class=SupervisedDataset,
    nlet_builder=build_onelet,
    batch_size=32,
    workers=10,
    model_transforms=model_transforms,
    training_transforms=lambda x: x,  # identity: no additional training transform
    dataset_names=["Inference"],
)
# The "testN" prints are progress markers between the stages of this cell.
print("test1")
data_module.setup("validate")
print("test2")
dataloader = data_module.val_dataloader()
print("test3")
predictions = eg.generate_embeddings(model, dataloader)
print("test4")
# Lengths of the first three prediction components, printed as a sanity check.
print(len(predictions[0]), len(predictions[1]), len(predictions[2]))
df_sup = eg.df_from_predictions(predictions)
print("test5")
# Cached on disk; a later cell reloads "embeddings.pkl" into df_sup.
df_sup.to_pickle("embeddings.pkl")

In [None]:
# Same pipeline as the supervised cell above, but with HackDataset as the dataset class.
# NOTE(review): HackDataset ignores `data_dir` and reads images from `base_path` instead.
data_dir = "/workspaces/gorillatracker/video_data/cropped-images/2024-04-18"

data_module = NletDataModule(
    data_dir=Path(data_dir),
    dataset_class=HackDataset,
    nlet_builder=build_onelet,
    batch_size=32,
    workers=10,
    model_transforms=model_transforms,
    training_transforms=lambda x: x,  # identity: no additional training transform
    dataset_names=["Inference"],
)
# The "testN" prints are progress markers between the stages of this cell.
print("test1")
data_module.setup("validate")
print("test2")
dataloader = data_module.val_dataloader()
print("test3")
predictions = generate_embeddings(model, dataloader)
print("test4")
df = df_from_predictions(predictions)
print("test5")
# Cached on disk; the next cell reloads "ssl_embeddings.pkl" into df.
df.to_pickle("ssl_embeddings.pkl")

In [None]:
# Reload the embeddings cached by the two cells above, so the analysis below can run
# without recomputing them.
# NOTE(review): unpickling executes arbitrary code; only load files this notebook wrote itself.
df = pd.read_pickle("ssl_embeddings.pkl")
display(df.head())

df_sup = pd.read_pickle("embeddings.pkl")
display(df_sup.head())

In [None]:
tff_ids = df["id"].tolist()

# Ids may be 0-d tensors/arrays or already plain ints; normalise to Python scalars either way
tff_ids = [t.item() if hasattr(t, "item") else t for t in tff_ids]

# Position of every id in the embedding frame, used to restore that order after the query
id_order = {tff_id: position for position, tff_id in enumerate(tff_ids)}

# Show a summary instead of dumping the whole id list into the notebook output
print(f"{len(tff_ids)} ids, first 10: {tff_ids[:10]}")

# Look up the video and tracking each feature belongs to
stmt = (
    select(TrackingFrameFeature.tracking_frame_feature_id, TrackingFrameFeature.video_id, TrackingFrameFeature.tracking_id)
    .where(TrackingFrameFeature.tracking_frame_feature_id.in_(tff_ids))
)

with Session(engine) as session:
    result = session.execute(stmt).all()

# The database returns rows in arbitrary order; sort them back into embedding order
ordered_result = sorted(result, key=lambda row: id_order[row[0]])

# video_ids / tracking_ids are later indexed in parallel with the embeddings, so a missing or
# duplicated id would silently shift every following entry. Fail loudly instead.
if [row[0] for row in ordered_result] != tff_ids:
    raise ValueError(
        f"Database rows do not line up with the embedding ids: "
        f"{len(ordered_result)} rows for {len(tff_ids)} ids (missing or duplicated ids)"
    )

# Unzip the ordered rows into parallel tuples
tff_ids2, video_ids, tracking_ids = zip(*ordered_result)

print(f"{len(tff_ids2)} rows matched, first 10: {tff_ids2[:10]}")

In [None]:
def _to_base64_jpeg(image):
    """Encode a PIL image as a base64 JPEG string (the form the hover tooltips embed)."""
    buffer = BytesIO()
    image.save(buffer, format="JPEG")
    return base64.b64encode(buffer.getvalue()).decode("utf-8")


# --- SSL (tracklet) embeddings ---
embeddings = np.vstack(df["embedding"].values)
print(len(embeddings))
labels = df["label"].values

images = []
raw_images = []
for tff_id in tff_ids:
    # Crop location: <base_path>/<id mod 2^8>/<id mod 2^16>/<id>.png
    path = Path(
        base_path,
        str(tff_id % 2**8),
        str(tff_id % 2**16),
        f"{tff_id}.png",
    )
    image = Image.open(path)
    raw_images.append(image)
    # Encode inside the loop, right after opening, as the original did
    images.append(_to_base64_jpeg(image))

# --- supervised embeddings ---
embeddings_sup = np.vstack(df_sup["embedding"].values)
print(len(embeddings_sup))
labels_sup = [x.item() for x in df_sup["label"].values]
label_strings_sup = df_sup["label_string"].values

images_sup = []
raw_images_sup = []
# The supervised "id" column is passed straight to Image.open.
# The loop variable is deliberately not called `id`, which would shadow the builtin.
for image_path in df_sup["id"]:
    img = Image.open(image_path)
    raw_images_sup.append(img)
    images_sup.append(_to_base64_jpeg(img))

# Project both sets together so they share one 2-D space; the fixed seed makes the
# layout reproducible across runs.
algo = TSNE(n_components=2, perplexity=30, random_state=0)

all_embeddings = np.vstack([embeddings, embeddings_sup])
low_dim_em_all = algo.fit_transform(all_embeddings)
low_dim_em = low_dim_em_all[:len(embeddings)]
low_dim_em_sup = low_dim_em_all[len(embeddings):]

In [None]:
# Scatter of the projected SSL embeddings, coloured by tracking id; hovering shows the crop.
cds = ColumnDataSource(
    data=dict(
        x=low_dim_em[:, 0],
        y=low_dim_em[:, 1],
        video_id=video_ids,
        tracking_id=tracking_ids,
        image=images,
    )
)
exp_cmap = LinearColorMapper(palette=cc.glasbey, low=min(tracking_ids), high=max(tracking_ids))

fig = figure(tools="pan, wheel_zoom, box_zoom, reset")
fig.scatter(
    source=cds,
    size=12,
    line_color="black",
    fill_color=dict(field="tracking_id", transform=exp_cmap),
)
# Scroll wheel zooms by default
fig.toolbar.active_scroll = fig.select_one(WheelZoomTool)

hover = HoverTool(tooltips='<img src="data:image/jpeg;base64,@image" width="128" height="128">')
fig.add_tools(hover)
show(fig)

In [None]:
# Supervised embeddings (opaque, coloured by label) with the SSL embeddings overlaid at
# low opacity. Relies on `cds` and `exp_cmap` created in the previous plotting cell.
fig = figure(tools="pan, wheel_zoom, box_zoom, reset")

cds_sup = ColumnDataSource(data={'x': low_dim_em_sup[:, 0], 'y': low_dim_em_sup[:, 1], 'label': labels_sup, 'label_str': label_strings_sup, 'image': images_sup})

exp_cmap_sup = LinearColorMapper(palette=cc.glasbey, 
                             low = min(labels_sup), 
                             high = max(labels_sup))
fig.scatter(
    source=cds_sup,
    size=12,
    line_color="black",
    fill_color={"field": "label", "transform": exp_cmap_sup},
)
fig.toolbar.active_scroll = fig.select_one(WheelZoomTool)

# SSL points from the earlier cell, drawn semi-transparent on top
fig.scatter(
    source=cds,
    size=12,
    line_color="black",
    fill_color={"field": "tracking_id", "transform": exp_cmap},
    alpha=0.2
)

# NOTE(review): `cds` has no 'label_str' column, so the label line of this tooltip has
# nothing to show for the SSL points.
hover = HoverTool(tooltips="""<div>
                    <img src="data:image/jpeg;base64,@image" width="128" height="128">
                    <div><strong>Label:</strong> @label_str</div>
                  </div>""")
fig.add_tools(hover)
show(fig)

In [7]:
# k-NN on the supervised embeddings: label every SSL embedding from its k nearest
# labelled neighbours.
k = 5
distance_threshold = 20.0

knn = KNeighborsClassifier(n_neighbors=k)
knn.fit(embeddings_sup, label_strings_sup)

# Distances to, and indices of, the k nearest supervised neighbours of each SSL embedding
distances, indices = knn.kneighbors(embeddings)

# Predicted label for each SSL embedding
predictions = knn.predict(embeddings)

# A prediction is kept only if its closest neighbour lies within the threshold; -1 marks the rest
valid_predictions = [
    -1 if neighbour_distances[0] > distance_threshold else predicted_label
    for neighbour_distances, predicted_label in zip(distances, predictions)
]

display("Valid Predictions:", valid_predictions)

display("Predictions:", predictions)

'Valid Predictions:'

['US60',
 'VI40',
 'RC40',
 'GN00',
 'AP02',
 'ES02',
 'VI40',
 'VI40',
 'GA03',
 'b0',
 'OE40',
 'AP02',
 'AP00',
 'ES03',
 'HU21',
 'GN02',
 'NN20',
 'VI01',
 'VI01',
 'GA03',
 'VI40',
 'JZ20',
 'VI40',
 'GA40',
 'VI01',
 'AP02',
 'ES03',
 'PL60',
 'RX00',
 'GA01',
 'VI42',
 'RC01',
 'ES02',
 'VI40',
 'AP03',
 'GA03',
 'NN42',
 'DU40',
 'AP01',
 'DU40',
 'GA02',
 'VI41',
 'VI41',
 'GN01',
 'HU00',
 'TU03',
 'VI01',
 'TU05',
 'VI01',
 'GA03',
 'AP61',
 'RC40',
 'AP00',
 'AP00',
 'VI00',
 'NN20',
 'GA02',
 'AP00',
 'VI41',
 'PL01',
 'VI42',
 'ME01',
 'b0',
 'GN00',
 'RC42',
 'VI01',
 'VI41',
 'AP03',
 'GA02',
 'AP01',
 'NN40',
 'AP02',
 'AP00',
 'GN01',
 'TU01',
 'AP00',
 'RC00',
 'GA40',
 'PL01',
 'AP01',
 'AP01',
 'RX00',
 'AP00',
 'ES02',
 'PL45',
 'VI41',
 'VI40',
 'RC41',
 'JZ20',
 'VI01',
 'PL41',
 'VI01',
 'VI40',
 'DU40',
 'DU40',
 'AP00',
 'RX00',
 'ES03',
 'VI01',
 'ES03',
 'GA02',
 'RC00',
 'VI01',
 'VI40',
 'VI40',
 'VI42',
 'RC40',
 'AP02',
 'JZ41',
 'ES01',
 'VI41',
 'VI4

'Predictions:'

array(['US60', 'VI40', 'RC40', ..., 'AP00', 'DU40', 'VI01'], dtype=object)

In [None]:
# One page per accepted prediction: the query crop, its k nearest supervised neighbours
# as (image, label, distance), and the predicted label.
image_pages = []
for i, pred in enumerate(valid_predictions):
    if pred != -1:
        closest = [
            (raw_images_sup[indices[i][rank]], label_strings_sup[indices[i][rank]], distances[i][rank])
            for rank in range(k)
        ]
        image_pages.append((raw_images[i], closest, pred))


def display_images(page):
    """Show query image ``page`` next to its ``k`` nearest supervised neighbours.

    An out-of-range ``page`` renders a blank row of axes. The original attached the
    ``else`` to the ``for`` loop, so that case was never handled.
    """
    fig, axs = plt.subplots(1, k + 1)
    if 0 <= page < len(image_pages):
        query_img, neighbours, prediction = image_pages[page]
        axs[0].imshow(query_img)
        axs[0].set_title("Prediction: {}".format(prediction), fontsize=8)
        for rank, (img, label, dist) in enumerate(neighbours):
            ax = axs[rank + 1]
            ax.imshow(img)
            ax.set_title(f"{rank}. closest image \n dist: {round(dist, 3)}\n Label: {label}", fontsize=8)
            ax.axis("off")
    else:
        # Nothing to show for this page: hide every (empty) subplot
        for ax in axs:
            ax.axis("off")
    plt.tight_layout()
    plt.show()


# Clamp max at 0 so the slider stays valid (max >= min) even when there are no pages
page_selector = widgets.IntSlider(min=0, max=max(len(image_pages) - 1, 0), description="Page:")
widgets.interact(display_images, page=page_selector)

In [None]:
# Inputs from earlier cells:
#   raw_images        - list of images
#   valid_predictions - one predicted class per entry of raw_images

# Group the image positions by predicted class
index_dict = {}
for position, predicted_class in enumerate(valid_predictions):
    index_dict.setdefault(predicted_class, []).append(position)

# (class, [positions]) pairs, so the widgets can address a class by its position
index_list = list(index_dict.items())

# Number of images to display per page
IMAGES_PER_PAGE = 6

def display_images(class_index, page):
    """Show one page of the images predicted as the class at ``class_index``.

    Args:
        class_index: position of the class in ``index_list``.
        page: zero-based page number; each page holds up to ``IMAGES_PER_PAGE`` images.

    Prints a message instead of plotting when the class index is invalid or the page
    is empty (e.g. a page past the class's last page).
    """
    if not class_index < len(index_list):
        print("Invalid class index.")
        return

    class_name, image_indices = index_list[class_index]
    start_idx = page * IMAGES_PER_PAGE
    page_indices = image_indices[start_idx:start_idx + IMAGES_PER_PAGE]
    if not page_indices:
        # plt.subplots(1, 0) raises ValueError, so bail out before creating the figure
        print(f"No images on page {page} for class {class_name}.")
        return

    # squeeze=False always returns a 2-D array of axes, so the single-image case needs no special handling
    fig, axs = plt.subplots(1, len(page_indices), figsize=(15, 5), squeeze=False)
    fig.suptitle(f"Class {class_name} - Page {page}")

    for ax, idx in zip(axs[0], page_indices):
        ax.imshow(raw_images[idx])
        ax.axis("off")

    plt.tight_layout()
    plt.show()

def update_page_selector(*args):
    """Fit the page slider to the currently selected class and return to its first page."""
    class_image_indices = index_list[class_selector.value][1]
    # Ceiling division: number of pages needed to show every image of the class
    num_pages = -(-len(class_image_indices) // IMAGES_PER_PAGE)
    page_selector.max = num_pages - 1
    # A newly selected class always starts at its first page
    page_selector.value = 0

# Create a dropdown widget to select the class
# Each option's value is the class's position in index_list, not the class name.
class_selector = widgets.Dropdown(
    options=[(f"Class {cls}", i) for i, (cls, _) in enumerate(index_list)],
    description="Class:"
)
# Registered before page_selector exists; update_page_selector looks the slider up by
# global name when it is called, so the ordering is fine.
class_selector.observe(update_page_selector, 'value')

# Create a slider widget to select the page within the selected class
# NOTE: this rebinds the `page_selector` name used by the neighbour-page cell above.
page_selector = widgets.IntSlider(min=0, max=0, description="Page:")

# Display the widgets and images
widgets.interact(display_images, class_index=class_selector, page=page_selector)

# Initialize the page selector based on the initial class selection
update_page_selector()


In [None]:
# labels = df.loc[cluster_indices[cls], "label"].values

# rest_cluster_indices = [x for i,x in enumerate(cluster_indices) if i != cls]

# low_dim_em_class = low_dim_em[cluster_indices[cls]]

# One entry per predicted class: (class, [projected 2-D point of every image in that class]).
low_dim_em_classes = [(cls, [low_dim_em[x] for x in indices]) for cls, indices in index_dict.items()]

fig = figure(tools="pan, wheel_zoom, box_zoom, reset")

# Background layer: every SSL point in grey.
cds1 = ColumnDataSource(data={'x': low_dim_em[:, 0], 'y': low_dim_em[:, 1], 'video_id': video_ids, 'tracking_id': tracking_ids})

fig.scatter(
    source=cds1,
    size=12,
    line_color="black",
    fill_color="gray",
    alpha=0.3
)

exp_cmap = LinearColorMapper(palette=cc.glasbey, 
                             low = min(tracking_ids), 
                             high = max(tracking_ids))

# Per-class point lists; row i of `source` below holds everything belonging to class i.
x = [low_dim_em_classes[x][1] for x in range(len(index_list))]

# `source` stores the data of ALL classes (one row per class, nested lists per column);
# `source_visible` stores only the class currently drawn and starts with class 0.
# The slider callback copies one row of `source` into `source_visible`.
# NOTE(review): in `source_visible` the 'class' column has a single element while the other
# columns have one element per point.
source = ColumnDataSource(data={'x': x, 'video_id': [[video_ids[idx] for idx in index_list[x][1]] for x in range(len(index_list))], 'tracking_id': [[tracking_ids[idx] for idx in index_list[x][1]] for x in range(len(index_list))], 'class': [low_dim_em_classes[x][0] for x in range(len(index_list))]})
source_visible = ColumnDataSource(data={'x': np.array(low_dim_em_classes[0][1])[:, 0], 'y': np.array(low_dim_em_classes[0][1])[:, 1], 'video_id': [video_ids[idx] for idx in index_list[0][1]], 'tracking_id': [tracking_ids[idx] for idx in index_list[0][1]], 'class': [low_dim_em_classes[0][0]]})

display(source.data["video_id"])

# Foreground layer: the selected class, coloured by tracking id.
fig.scatter(
    source=source_visible,
    size=12,
    line_color="black",
    fill_color={"field": "tracking_id", "transform": exp_cmap},
)

div = Div(text=f"<b>Class: {low_dim_em_classes[0][0]}</b>")

# Browser-side callback: on every slider move, replace the visible columns with the row
# of `source` at the slider value and refresh the class caption.
slider = Slider(start=0, end=len(index_list)-1, value=0, step=1, title="Cluster")
slider.js_on_change("value", CustomJS(args=dict(source_visible=source_visible, source_avaliable=source, slider=slider, div=div), code="""
    var data_visible = source_visible.data;
    var data = source_avaliable.data;
    
    // Update x values
    const dataArray = data.x[slider.value.toString()];
    const firstColumn = dataArray.map(row => row[0]);
    
    data_visible.x = firstColumn;
    
    // Update y values
    const secondColumn = dataArray.map(row => row[1]);
    
    data_visible.y = secondColumn;

    // Update video_id
    data_visible.video_id = data.video_id[slider.value.toString()];

    // Update tracking_id
    data_visible.tracking_id = data.tracking_id[slider.value.toString()];
    
    data_visible.class = data.class[slider.value.toString()];

    div.text = "<b>Class: " + data_visible.class + "</b>";
    
    source_visible.change.emit();
"""))


fig.toolbar.active_scroll = fig.select_one(WheelZoomTool)

hover = HoverTool(tooltips=[("video_id", "@video_id"), ("tracking_id", "@tracking_id")])
fig.add_tools(hover)
show(column(fig, column(slider, div)))

In [None]:
# Position in index_list of the predicted class whose sightings are inspected
cls = 1

engine = GorillaDatasetKISZ().engine
session_cls = sessionmaker(bind=engine)

# Videos containing the images that were predicted as this class
videos = [video_ids[image_idx] for image_idx in index_list[cls][1]]
print(len(videos))

stmt = (
    select(Video.camera_id, Video.start_time)
    .where(Video.video_id.in_(videos))
    .where(Video.start_time.isnot(None))
)

with session_cls() as session:
    result = session.execute(stmt).all()

# Chronological order (rows are (camera_id, start_time))
sorted_videos = sorted(result, key=lambda row: row[1])

print(len(sorted_videos))

print(sorted_videos)

In [None]:
# Cameras with known coordinates: rows are (camera_id, longitude, latitude).
stmt = (
    select(Camera.camera_id, Camera.longitude, Camera.latitude).where(Camera.longitude.isnot(None)).where(Camera.latitude.isnot(None))
    )

with session_cls() as session:
    result = session.execute(stmt)
    # Sorted by camera_id
    all2 = sorted(result.fetchall(), key=lambda x: x[0])

# camera_id -> (longitude, latitude)
mapping = {x : (y,z) for x, y, z in all2}

# Parallel tuples of camera ids, longitudes and latitudes.
# NOTE(review): `id` shadows the Python builtin from here on; `x` and `y` are reassigned
# further down in this cell.
id, x, y = zip(*all2)

# Keep only sightings whose camera has known coordinates (i.e. is a key of `mapping`)
valid_videos = [vid for vid in sorted_videos if vid[0] in mapping and vid[1] is not None]

print(len(valid_videos))

# Fastest travel speed still considered plausible: 2.3148e-5 km/s is about 2 km per day
avg_speed = 2.3148e-5  # km/s

# Sightings per camera among the candidates, taken before anything is removed.
# NOTE(review): the original compared `id.count(...)`, where `id` holds the ids returned by
# the camera query. If each camera appears there once, as presumed, both counts were always
# equal and the earlier sighting was always dropped. Counting actual sightings is assumed
# to be the intent - please confirm.
camera_sightings = [vid.camera_id for vid in valid_videos]

# Walk consecutive sightings and drop one of each pair that would need an implausible
# speed. A while loop is used because the list shrinks during the walk. The original
# `for i in range(len(...))` could index past the end after a pop, and it never re-checked
# the new neighbour pair.
i = 0
while i < len(valid_videos) - 1:
    current, following = valid_videos[i], valid_videos[i + 1]
    if current.camera_id == following.camera_id:
        i += 1
        continue

    elapsed = (following.start_time - current.start_time).total_seconds()
    x1, y1 = mapping[current.camera_id]
    x2, y2 = mapping[following.camera_id]
    # geopy expects (latitude, longitude); mapping stores (longitude, latitude)
    dist = geopy.distance.distance((y1, x1), (y2, x2)).km

    if elapsed > 0:
        speed = dist / elapsed
    else:
        # Identical timestamps: any positive distance is impossible to cover
        speed = float("inf") if dist > 0 else 0.0
    print(speed)

    if speed <= avg_speed:
        i += 1
        continue

    print(i)
    if camera_sightings.count(current.camera_id) > camera_sightings.count(following.camera_id):
        # Drop the later sighting; stay at i to compare `current` with its new successor
        valid_videos.pop(i + 1)
    else:
        # Drop the earlier sighting; step back so its predecessor is checked against the new neighbour
        valid_videos.pop(i)
        i = max(i - 1, 0)

print(len(valid_videos))
# Figure `p`: camera ids drawn as text on the 'wald.jpg' background, plus one red marker
# that the slider moves through the sightings in `valid_videos`.
# NOTE(review): `valid_videos[0]` below raises IndexError when the filter left no sightings.
p = figure(x_range=(26.98, 27.12), y_range=(12.78, 12.89))
p.image_url(url=['wald.jpg'], x=26.5, y=13, w=1, h=0.5)

# Camera labels at (longitude, latitude); `id`, `x`, `y` come from the camera query above.
cam_source = ColumnDataSource(data=dict(x=x, y=y, text=id))
glyph = Text(x="x", y="y", text="text", text_align="center", text_font_size="10pt", text_color="white")
p.add_glyph(cam_source, glyph)

# `source` holds the position and time of every sighting; `source_visible` holds only the
# sighting currently shown and starts with the first one.
source = ColumnDataSource(data=dict(x=[mapping[x][0] for x, _ in valid_videos], y=[mapping[x][1] for x, _ in valid_videos], time=[x.strftime('%Y-%m-%d %H:%M:%S') for _, x in valid_videos]))
source_visible = ColumnDataSource(data=dict(x=[mapping[valid_videos[0][0]][0]], y=[mapping[valid_videos[0][0]][1]], time=[valid_videos[0][1].strftime('%Y-%m-%d %H:%M:%S')]))
p.scatter(x="x", y="y", size=10, source=source_visible, fill_color="red", line_color="black")

div = Div(text=f"<b>Time: {valid_videos[0][1]}</b>")

# Browser-side callback: copy the sighting at the slider value into `source_visible` and
# update the time caption.
slider = Slider(start=0, end=len(valid_videos)-1, value=0, step=1, title="Time")
slider.js_on_change("value", CustomJS(args=dict(source_visible=source_visible, source_avaliable=source, slider=slider, div=div), code="""
    var data_visible = source_visible.data;
    var data = source_avaliable.data;
    data_visible.x = [data.x[slider.value.toString()]];
    data_visible.y = [data.y[slider.value.toString()]];
    data_visible.time = data.time[slider.value.toString()];
    div.text = "<b>Time: " + data_visible.time + "</b>";
    source_visible.change.emit();
"""))

p.xaxis.axis_label = "Longitude"
p.yaxis.axis_label = "Latitude"

# Figure `q`: all sightings at once, each jittered around its camera position.
q = figure(x_range=(26.98, 27.12), y_range=(12.78, 12.89))
q.image_url(url=['wald.jpg'], x=26.5, y=13, w=1, h=0.5)

points = []
for (camera_id, start_time) in valid_videos:
    (long, lat) = mapping[camera_id]
    # The camera query above already filters out NULL coordinates, so this is a safety net.
    if long is None or lat is None:
        continue
    # Gaussian jitter with mean 0.0005 and standard deviation 0.0015 per axis; unseeded,
    # so the picture differs between runs.
    points.append((long + np.random.normal(0.0005, 0.0015), lat + np.random.normal(0.0005, 0.0015)))
        
q.add_glyph(cam_source, glyph)

# Rebinds `x` and `y` (previously the camera coordinates) to the jittered sighting coordinates.
x,y = zip(*points)

q.scatter(x, y, size=10, fill_color="red", line_color="black")

# plt.hist2d(x,y, bins=1000, cmap="twilight", range=[[26.98, 27.12], [12.78, 12.89]])
# plt.xlabel("Longitude (2.17km between ticks)")
# plt.ylabel("Latitude (2.22km between ticks)")
# plt.xlim(26.98, 27.12)
# plt.ylim(12.78, 12.89)
# plt.title("Tracklet distribution heatmap over camera positions")
# for i, (x, y) in mapping.items():
#     if x is None or y is None:
#         continue
#     plt.text(x, y, f"{i}", ha="center", fontsize=6)

show(column(p, column(slider, div)))
show(q)

# plt.show()

# OLD CODE

In [None]:
from sklearn.cluster import KMeans

# Number of k-means clusters
n_clusters = 60

# Cluster the SSL embeddings with a fixed seed; fit() returns the fitted estimator itself
kmeans = KMeans(n_clusters=n_clusters, random_state=0)
cluster_labels = kmeans.fit(embeddings).labels_

# For every cluster, the positions of the embeddings assigned to it
cluster_indices = [np.flatnonzero(cluster_labels == cluster_id) for cluster_id in range(n_clusters)]


In [None]:
from bokeh.models import Slider, CustomJS
from bokeh.layouts import column


# labels = df.loc[cluster_indices[cls], "label"].values

# rest_cluster_indices = [x for i,x in enumerate(cluster_indices) if i != cls]

# low_dim_em_class = low_dim_em[cluster_indices[cls]]

# Projected 2-D points of each k-means cluster (one array per cluster).
low_dim_em_classes = [low_dim_em[x] for x in cluster_indices]

fig = figure(tools="pan, wheel_zoom, box_zoom, reset")

# Background layer: every SSL point in grey.
cds1 = ColumnDataSource(data={'x': low_dim_em[:, 0], 'y': low_dim_em[:, 1], 'video_id': [video_ids[idx] for idx in range(len(low_dim_em))], 'tracking_id': [tracking_ids[idx] for idx in range(len(low_dim_em))]})

fig.scatter(
    source=cds1,
    size=12,
    line_color="black",
    fill_color="gray",
    alpha=0.3
)

exp_cmap = LinearColorMapper(palette=cc.glasbey, 
                             low = min(tracking_ids), 
                             high = max(tracking_ids))

# `source` stores the data of ALL clusters (one row per cluster, nested lists per column);
# `source_visible` stores only the cluster currently drawn and starts with cluster 0.
source = ColumnDataSource(data={'x': [low_dim_em_classes[x].tolist() for x in range(n_clusters)], 'video_id': [[video_ids[idx] for idx in cluster_indices[x]] for x in range(len(cluster_indices))], 'tracking_id': [[tracking_ids[idx] for idx in cluster_indices[x]] for x in range(len(cluster_indices))]})
source_visible = ColumnDataSource(data={'x': low_dim_em_classes[0][:, 0], 'y': low_dim_em_classes[0][:, 1], 'video_id': [video_ids[idx] for idx in cluster_indices[0]], 'tracking_id': [tracking_ids[idx] for idx in cluster_indices[0]]})

display(source.data["video_id"])

# Foreground layer: the selected cluster, coloured by tracking id.
fig.scatter(
    source=source_visible,
    size=12,
    line_color="black",
    fill_color={"field": "tracking_id", "transform": exp_cmap},
)

# Browser-side callback: on every slider move, replace the visible columns with the row
# of `source` at the slider value.
slider = Slider(start=0, end=n_clusters-1, value=0, step=1, title="Cluster")
slider.js_on_change("value", CustomJS(args=dict(source_visible=source_visible, source_avaliable=source, slider=slider), code="""
    var data_visible = source_visible.data;
    var data = source_avaliable.data;
    
    // Update x values
    const dataArray = data.x[slider.value.toString()];
    const firstColumn = dataArray.map(row => row[0]);
    
    data_visible.x = firstColumn;
    
    // Update y values
    const secondColumn = dataArray.map(row => row[1]);
    
    data_visible.y = secondColumn;

    // Update video_id
    data_visible.video_id = data.video_id[slider.value.toString()];

    // Update tracking_id
    data_visible.tracking_id = data.tracking_id[slider.value.toString()];

    source_visible.change.emit();
"""))


fig.toolbar.active_scroll = fig.select_one(WheelZoomTool)

hover = HoverTool(tooltips=[("video_id", "@video_id"), ("tracking_id", "@tracking_id")])
fig.add_tools(hover)
show(column(fig, slider))

In [None]:
from bokeh.models import Slider, CustomJS
from bokeh.layouts import column


# NOTE(review): this cell repeats the preceding cluster-slider cell; only the construction
# of `cds1` differs (the id tuples are passed directly instead of being rebuilt by index).

# labels = df.loc[cluster_indices[cls], "label"].values

# rest_cluster_indices = [x for i,x in enumerate(cluster_indices) if i != cls]

# low_dim_em_class = low_dim_em[cluster_indices[cls]]

# Projected 2-D points of each k-means cluster (one array per cluster).
low_dim_em_classes = [low_dim_em[x] for x in cluster_indices]

fig = figure(tools="pan, wheel_zoom, box_zoom, reset")

# Background layer: every SSL point in grey.
cds1 = ColumnDataSource(data={'x': low_dim_em[:, 0], 'y': low_dim_em[:, 1], 'video_id': video_ids, 'tracking_id': tracking_ids})

fig.scatter(
    source=cds1,
    size=12,
    line_color="black",
    fill_color="gray",
    alpha=0.3
)

exp_cmap = LinearColorMapper(palette=cc.glasbey, 
                             low = min(tracking_ids), 
                             high = max(tracking_ids))

# `source` stores the data of ALL clusters (one row per cluster, nested lists per column);
# `source_visible` stores only the cluster currently drawn and starts with cluster 0.
source = ColumnDataSource(data={'x': [low_dim_em_classes[x].tolist() for x in range(n_clusters)], 'video_id': [[video_ids[idx] for idx in cluster_indices[x]] for x in range(len(cluster_indices))], 'tracking_id': [[tracking_ids[idx] for idx in cluster_indices[x]] for x in range(len(cluster_indices))]})
source_visible = ColumnDataSource(data={'x': low_dim_em_classes[0][:, 0], 'y': low_dim_em_classes[0][:, 1], 'video_id': [video_ids[idx] for idx in cluster_indices[0]], 'tracking_id': [tracking_ids[idx] for idx in cluster_indices[0]]})

display(source.data["video_id"])

# Foreground layer: the selected cluster, coloured by tracking id.
fig.scatter(
    source=source_visible,
    size=12,
    line_color="black",
    fill_color={"field": "tracking_id", "transform": exp_cmap},
)

# Browser-side callback: on every slider move, replace the visible columns with the row
# of `source` at the slider value.
slider = Slider(start=0, end=n_clusters-1, value=0, step=1, title="Cluster")
slider.js_on_change("value", CustomJS(args=dict(source_visible=source_visible, source_avaliable=source, slider=slider), code="""
    var data_visible = source_visible.data;
    var data = source_avaliable.data;
    
    // Update x values
    const dataArray = data.x[slider.value.toString()];
    const firstColumn = dataArray.map(row => row[0]);
    
    data_visible.x = firstColumn;
    
    // Update y values
    const secondColumn = dataArray.map(row => row[1]);
    
    data_visible.y = secondColumn;

    // Update video_id
    data_visible.video_id = data.video_id[slider.value.toString()];

    // Update tracking_id
    data_visible.tracking_id = data.tracking_id[slider.value.toString()];

    source_visible.change.emit();
"""))


fig.toolbar.active_scroll = fig.select_one(WheelZoomTool)

hover = HoverTool(tooltips=[("video_id", "@video_id"), ("tracking_id", "@tracking_id")])
fig.add_tools(hover)
show(column(fig, slider))

In [None]:
from bokeh.models import ColumnDataSource, Text, Circle
from sqlalchemy.orm import sessionmaker
from bokeh.plotting import figure, show
from bokeh.io import output_notebook
from bokeh.resources import INLINE

# Route Bokeh output to the notebook, using the INLINE resources object.
output_notebook(INLINE)

# Index of the k-means cluster whose sightings are inspected.
cls = 1

engine = GorillaDatasetKISZ().engine
session_cls = sessionmaker(bind=engine)

# Videos containing the embeddings that were assigned to this cluster.
videos = [video_ids[idx] for idx in cluster_indices[cls]]

# (camera_id, start_time) of those videos, skipping videos without a start time.
stmt = select(Video.camera_id, Video.start_time).where(Video.video_id.in_(videos)).where(Video.start_time.isnot(None))

with session_cls() as session:
    result = session.execute(stmt).all()

# Chronological order by start_time.
sorted_videos = sorted(result, key=lambda x: x[1])

print(sorted_videos)