In [None]:
from sqlalchemy import select, text, create_engine
from sqlalchemy.orm import Session
from gorillatracker.ssl_pipeline.models import TrackingFrameFeature, Tracking, Video, Camera, VideoFeature
from gorillatracker.ssl_pipeline.dataset import GorillaDatasetKISZ
import importlib
from typing import List, Literal, Optional, Tuple

from PIL import Image
from torch import Tensor
from torch.utils.data import Dataset
from torchvision import transforms

import gorillatracker.type_helper as gtypes
from gorillatracker.transform_utils import SquarePad
from gorillatracker.type_helper import Id, Label
from gorillatracker.utils.labelencoder import LabelEncoder

from gorillatracker.data.nlet_dm import FlatNletBuilder, NletDataModule
from gorillatracker.data.nlet import SupervisedDataset, build_onelet
from gorillatracker.utils.embedding_generator import generate_embeddings, df_from_predictions
from gorillatracker.utils.wandb_loader import load_model

from pathlib import Path
import gorillatracker.utils.embedding_generator as eg
import pandas as pd
from io import BytesIO
import base64
from sklearn.manifold import TSNE
import numpy as np
from bokeh.plotting import figure, show
from bokeh.models import ColumnDataSource, LinearColorMapper, HoverTool, WheelZoomTool, FixedTicker, DatetimeTicker
from bokeh.io import output_notebook
from bokeh.resources import INLINE
import colorcet as cc
from sklearn.neighbors import KNeighborsClassifier
import matplotlib.pyplot as plt
import ipywidgets as widgets
from bokeh.models import Slider, CustomJS, Div
from bokeh.layouts import column, row
from bokeh.models import Text, Line, Scatter
from sqlalchemy.orm import sessionmaker
import geopy.distance
from collections import Counter
from IPython.display import display, clear_output
import dill
import torch
from gorillatracker.model.wrappers_supervised import BaseModuleSupervised
from gorillatracker.old_model import VisionTransformerWrapper

output_notebook(INLINE)

# Number of random face crops sampled per tracklet (tracking_id).
sample = 10

# Subquery: up to `sample` random 'face_45' crops of at least 184x184 px per tracking_id.
# It is consumed below as an IN (...) filter on tracking_frame_feature_id.
# NOTE: ORDER BY RANDOM() draws a different sample on every execution.
query = text(
    """WITH ranked_features AS (
    SELECT
        tracking_id,
        tracking_frame_feature_id,
        bbox_width,
        bbox_height,
        frame_nr,
        feature_type,
        ROW_NUMBER() OVER (PARTITION BY tracking_id ORDER BY RANDOM()) AS rn
    FROM tracking_frame_feature
    WHERE feature_type = 'face_45'
        AND bbox_width >= 184
        AND bbox_height >= 184
        AND tracking_id IS NOT NULL
)
SELECT
    tracking_frame_feature_id
FROM ranked_features
WHERE rn <= :sample
"""  # up to `sample` random crops per tracking_id
).bindparams(sample=sample)  # bound parameter instead of interpolating into the SQL text

engine = create_engine(GorillaDatasetKISZ.DB_URI)

In [None]:
# Crops live on the shared volume; mount it in the devcontainer with:
# "source=/mnt/vast-gorilla/cropped-images,target=/workspaces/gorillatracker/cropped-images,type=bind,ro"
base_path = "../video_data/cropped-images/2024-04-18"


class HackDataset(Dataset[Tuple[Id, Tensor, Label]]):
    """Inference-only dataset over the sampled SSL face crops selected by `query`.

    Items are ``(tracking_frame_feature_id, image, tracking_id)``; the tracking id stands in
    for a label because these crops are unlabelled.
    """

    def get_tffs(self) -> list:
        """Fetch ``(TrackingFrameFeature, video_id, tracking_id)`` rows for the sampled crops.

        Only crops from videos starting on/after 2022-01-01 are kept.
        """
        engine = create_engine(GorillaDatasetKISZ.DB_URI)
        stmt = (
            select(TrackingFrameFeature, Video.video_id, Tracking.tracking_id)
            .join(Tracking, Tracking.tracking_id == TrackingFrameFeature.tracking_id)
            .join(Video, Tracking.video_id == Video.video_id)
            .where(TrackingFrameFeature.tracking_frame_feature_id.in_(query))
            .where(Video.start_time >= "2022-01-01 00:00:00")
        )
        try:
            with Session(engine) as session:
                return session.execute(stmt).all()
        finally:
            # This engine is private to the call; release its connection pool.
            engine.dispose()

    def __init__(
        self, base_dir: str, partition: Literal["train", "val", "test"], nlet_builder: FlatNletBuilder, transform: Optional[gtypes.Transform] = None
    ):
        # base_dir: root directory of the cropped images (resolved via TrackingFrameFeature.cache_path).
        # nlet_builder is accepted for NletDataModule compatibility but unused: every item is one crop.
        self.base_dir = base_dir
        self.tffs = self.get_tffs()
        self.transform = transform
        self.partition = partition
        self.video_ids = [tff.video_id for tff in self.tffs]
        self.tracking_ids = [tff.tracking_id for tff in self.tffs]

    def __len__(self) -> int:
        return len(self.tffs)

    def __getitem__(self, idx: int) -> Tuple[Id, Tensor, Label]:
        """Return ``(tracking_frame_feature_id, image, tracking_id)``; tracklets serve as labels for now."""
        tff = self.tffs[idx][0]

        # Copy inside a context manager so the file handle is closed immediately; with many
        # DataLoader workers, lazily loaded images would otherwise keep files open until GC.
        with Image.open(tff.cache_path(str(self.base_dir))) as raw_img:
            img = raw_img.copy()
        if self.transform:
            img = self.transform(img)
        assert tff.tracking_id is not None
        return tff.tracking_frame_feature_id, img, tff.tracking_id

    @classmethod
    def get_transforms(cls) -> gtypes.Transform:
        """Square-pad, resize to 256 and convert to a tensor."""
        return transforms.Compose(
            [
                SquarePad(),
                # Uniform input, you may choose higher/lower sizes.
                transforms.Resize(256),
                transforms.ToTensor(),
            ]
        )

In [None]:
# Choose where the embedding model comes from:
#   True  -> checkpoint of a W&B run (trained with: python3 train.py --config_path cfgs/swinv2_cxl.yml)
#   False -> local Lightning checkpoint
LOAD_FROM_WANDB = False

if LOAD_FROM_WANDB:
    # Earlier artifact: gorillas/Embedding-SwinV2Base-CXL-Open/model-yl2lx567:v12
    run_url = "https://wandb.ai/gorillas/Embedding-SwinV2Large-CXL-Open/runs/olql2abq?nw=nwuseremirhan404"
    # NOTE(review): get_model_for_run_url is not imported by this notebook; it is assumed to be
    # provided by gorillatracker.utils.embedding_generator (imported as `eg`) — confirm.
    model = eg.get_model_for_run_url(run_url)
else:
    # Alternative baseline: /workspaces/gorillatracker/models/vit_large_dinov2_baseline.ckpt
    path = "/workspaces/gorillatracker/models/roberts_models/gorillas_models/vit_large_dinov2_bayes/fold-0-epoch-19-cxlkfold/fold-0/val/embeddings/knn5_crossvideo/accuracy-0.63.ckpt"
    # strict=False tolerates checkpoint keys that do not match the current module definition.
    model = BaseModuleSupervised.load_from_checkpoint(path, data_module=None, wandb_run=None, strict=False)

# Inference preprocessing: pad to square, resize to 224, normalise with the ImageNet mean/std.
model_transforms = transforms.Compose(
    [
        SquarePad(),
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)

In [None]:
importlib.reload(eg)  # pick up edits to embedding_generator without restarting the kernel

# Embeddings of the labelled CXL face crops (every individual is in the "val" split).
data_dir = "/workspaces/gorillatracker/data/supervised/splits/cxl_faces_square_all-in-val"
data_module = NletDataModule(
    data_dir=Path(data_dir),
    dataset_class=SupervisedDataset,
    nlet_builder=build_onelet,
    batch_size=32,
    workers=10,
    model_transforms=model_transforms,
    training_transforms=lambda x: x,  # inference only: no augmentation
    dataset_names=["Inference"],
)
data_module.setup("validate")
dataloader = data_module.val_dataloader()
predictions = eg.generate_embeddings(model, dataloader)
# Sanity check: the three parts of `predictions` are expected to have equal length.
print(len(predictions[0]), len(predictions[1]), len(predictions[2]))

df_sup = eg.df_from_predictions(predictions)
df_sup.to_pickle("embeddings.pkl")
display(df_sup.head(5))

In [None]:
importlib.reload(eg)  # pick up edits to embedding_generator without restarting the kernel

# Embeddings of the sampled, unlabelled SSL crops (videos from 2022 onwards, see HackDataset).
data_dir = "/workspaces/gorillatracker/video_data/cropped-images/2024-04-18"
data_module = NletDataModule(
    data_dir=Path(data_dir),
    dataset_class=HackDataset,
    nlet_builder=build_onelet,
    batch_size=32,
    workers=40,
    model_transforms=model_transforms,
    training_transforms=lambda x: x,  # inference only: no augmentation
    dataset_names=["Inference"],
)
data_module.setup("validate")
dataloader = data_module.val_dataloader()
predictions = eg.generate_ssl_embeddings(model, dataloader)
df = eg.df_from_predictions(predictions)

# Cache the result; the next cell reads this file back.
with open("ssl_embeddings_2022.pkl", "wb") as dill_file:
    dill.dump(df, dill_file)

In [None]:
# Load the cached embeddings written by the two cells above.
# NOTE: dill/pickle can execute arbitrary code on load; only open files produced by this notebook.
with open("ssl_embeddings_2022.pkl", "rb") as dill_file:  # must match the file the SSL cell writes
    df = dill.load(dill_file)
display(df.head())
df["embedding"] = df["embedding"].apply(torch.from_numpy)

with open("embeddings.pkl", "rb") as dill_file:
    df_sup = dill.load(dill_file)
df_sup["embedding"] = df_sup["embedding"].apply(torch.from_numpy)
display(df_sup.head())

print(df.shape, df_sup.shape)

In [None]:
# tracking_frame_feature ids of the SSL crops, in embedding (row) order.
# int() accepts 0-d tensors as well as plain/numpy ints.
tff_ids = [int(t) for t in df["id"].tolist()]

stmt = select(
    TrackingFrameFeature.tracking_frame_feature_id,
    TrackingFrameFeature.video_id,
    TrackingFrameFeature.tracking_id,
).where(TrackingFrameFeature.tracking_frame_feature_id.in_(tff_ids))

with Session(engine) as session:
    result = session.execute(stmt).all()

# Re-align DB rows to embedding order so video_ids[i] / tracking_ids[i] describe embeddings[i].
rows_by_id = {row[0]: row for row in result}
missing_ids = [t for t in tff_ids if t not in rows_by_id]
assert not missing_ids, f"{len(missing_ids)} embedding ids not found in the DB, e.g. {missing_ids[:5]}"
ordered_result = [rows_by_id[t] for t in tff_ids]

tff_ids2, video_ids, tracking_ids = zip(*ordered_result)
print(len(tff_ids2))

In [None]:
# Supervised (labelled) embeddings as one matrix, plus their numeric and string labels.
embeddings_sup = np.vstack(df_sup["embedding"].values)
labels_sup = [lbl.item() for lbl in df_sup["label"].values]
label_strings_sup = df_sup["label_string"].values

# SSL embeddings; row i corresponds to tff_ids[i].
embeddings = np.vstack(df["embedding"].values)
print(len(embeddings))
labels = df["label"].values

In [None]:
def _to_base64_jpeg(img):
    """Encode a PIL image as a base64 JPEG string for Bokeh ``data:image/jpeg`` tooltips."""
    # JPEG cannot store alpha/palette modes; convert those so saving does not raise.
    if img.mode not in ("RGB", "L"):
        img = img.convert("RGB")
    buffer = BytesIO()
    img.save(buffer, format="JPEG")
    return base64.b64encode(buffer.getvalue()).decode("utf-8")


# SSL crops in embedding order; layout is <base_path>/<id % 2**8>/<id % 2**16>/<id>.png.
# Encode each image right after opening it so the file is loaded and released immediately.
images = []
raw_images = []
for tff_id in tff_ids:
    image = Image.open(Path(base_path, str(tff_id % 2**8), str(tff_id % 2**16), f"{tff_id}.png"))
    raw_images.append(image)
    images.append(_to_base64_jpeg(image))

print(len(embeddings_sup))

# Supervised crops: df_sup["id"] holds the image file path.
images_sup = []
raw_images_sup = []
for img_path in df_sup["id"]:
    img = Image.open(img_path)
    raw_images_sup.append(img)
    images_sup.append(_to_base64_jpeg(img))

# One joint t-SNE so SSL and supervised points share a single 2-D coordinate system.
# A fixed random_state makes the layout reproducible across runs.
TSNE_SEED = 42
algo = TSNE(n_components=2, perplexity=30, random_state=TSNE_SEED)

all_embeddings = np.vstack([embeddings, embeddings_sup])
low_dim_em_all = algo.fit_transform(all_embeddings)
low_dim_em = low_dim_em_all[: len(embeddings)]
low_dim_em_sup = low_dim_em_all[len(embeddings):]

print(len(low_dim_em_all))

In [None]:
# t-SNE of the SSL crops, coloured by tracklet; hovering shows the crop itself.
cds = ColumnDataSource(
    data=dict(
        x=low_dim_em[:, 0],
        y=low_dim_em[:, 1],
        video_id=video_ids,
        tracking_id=tracking_ids,
        image=images,
    )
)
exp_cmap = LinearColorMapper(palette=cc.glasbey, low=min(tracking_ids), high=max(tracking_ids))

fig = figure(tools="pan, wheel_zoom, box_zoom, reset")
fig.scatter(
    source=cds,
    size=12,
    line_color="black",
    fill_color={"field": "tracking_id", "transform": exp_cmap},
)
fig.toolbar.active_scroll = fig.select_one(WheelZoomTool)

image_tooltip = '<img src="data:image/jpeg;base64,@image" width="128" height="128">'
hover = HoverTool(tooltips=image_tooltip)
fig.add_tools(hover)
show(fig)

In [None]:
# Supervised crops coloured by individual, with the SSL crops drawn faintly on top.
cds_sup = ColumnDataSource(
    data=dict(
        x=low_dim_em_sup[:, 0],
        y=low_dim_em_sup[:, 1],
        label=labels_sup,
        label_str=label_strings_sup,
        image=images_sup,
    )
)
exp_cmap_sup = LinearColorMapper(palette=cc.glasbey, low=min(labels_sup), high=max(labels_sup))

fig = figure(tools="pan, wheel_zoom, box_zoom, reset")
fig.scatter(
    source=cds_sup,
    size=12,
    line_color="black",
    fill_color={"field": "label", "transform": exp_cmap_sup},
)
fig.toolbar.active_scroll = fig.select_one(WheelZoomTool)

# SSL points (cds / exp_cmap come from the previous plot cell).
fig.scatter(
    source=cds,
    size=12,
    line_color="black",
    fill_color={"field": "tracking_id", "transform": exp_cmap},
    alpha=0.2
)

label_tooltip = """<div>
                    <img src="data:image/jpeg;base64,@image" width="128" height="128">
                    <div><strong>Label:</strong> @label_str</div>
                  </div>"""
hover = HoverTool(tooltips=label_tooltip)
fig.add_tools(hover)
show(fig)

In [None]:
# k-NN over the supervised embeddings: every SSL crop gets the label predicted by its k nearest
# labelled crops, or -1 when even the single nearest one is farther than distance_threshold.
k = 5
distance_threshold = 10

knn = KNeighborsClassifier(n_neighbors=k)
knn.fit(embeddings_sup, label_strings_sup)

# Neighbour distances/indices (sorted nearest first) and the classifier's predictions.
distances, indices = knn.kneighbors(embeddings)
predictions = knn.predict(embeddings)

nearest_too_far = [neighbour_dists[0] > distance_threshold for neighbour_dists in distances]
valid_predictions = [-1 if too_far else pred for too_far, pred in zip(nearest_too_far, predictions)]
new_len = sum(not too_far for too_far in nearest_too_far)

display("Valid Predictions:", valid_predictions)

print(new_len, len(predictions))

In [None]:
# One page per SSL crop with a valid prediction:
# (query crop, [(neighbour crop, neighbour label, distance) for the k nearest], predicted label)
image_pages = []
for i, pred in enumerate(valid_predictions):
    if pred == -1:
        continue
    closest = [(raw_images_sup[indices[i][j]], label_strings_sup[indices[i][j]], distances[i][j]) for j in range(k)]
    image_pages.append((raw_images[i], closest, pred))


def display_images(page):
    """Plot the query crop of `page` next to its k nearest labelled crops.

    Pages outside ``image_pages`` render as a blank row of k+1 panels.
    """
    fig, axs = plt.subplots(1, k + 1)
    for ax in axs:
        ax.axis("off")
    if page < len(image_pages):
        query_img, neighbours, predicted = image_pages[page]
        axs[0].imshow(query_img)
        axs[0].set_title("Prediction: {}".format(predicted), fontsize=8)
        for rank, (img, label, dist) in enumerate(neighbours):
            axs[rank + 1].imshow(img)
            axs[rank + 1].set_title(str(rank) + ". closest image \n dist: " + str(round(dist, 3)) + "\n Label: " + label, fontsize=8)
    plt.tight_layout()
    plt.show()


# max is clamped at 0 so the slider stays valid when there are no pages.
page_selector = widgets.IntSlider(min=0, max=max(len(image_pages) - 1, 0), description="Page:")
widgets.interact(display_images, page=page_selector)

In [None]:
# Map each predicted class to the indices of its SSL crops (order = first occurrence).
index_dict = {}
for i, cls in enumerate(valid_predictions):
    if cls == -1:
        continue
    index_dict.setdefault(cls, []).append(i)

# List form so widgets can address classes by position.
index_list = list(index_dict.items())

IMAGES_PER_PAGE = 6  # Number of images to display per page


def display_images(class_index, page):
    """Show crops [page*IMAGES_PER_PAGE, (page+1)*IMAGES_PER_PAGE) of class index_list[class_index]."""
    if class_index is None or class_index >= len(index_list):
        print("Invalid class index.")
        return
    class_name, class_indices = index_list[class_index]
    start_idx = page * IMAGES_PER_PAGE
    page_indices = class_indices[start_idx:start_idx + IMAGES_PER_PAGE]
    if not page_indices:
        # Can happen transiently while the page slider is being reset for a new class.
        print(f"Class {class_name} has no images on page {page}.")
        return

    # squeeze=False keeps axs 2-D even for a single image, so no special case is needed.
    fig, axs = plt.subplots(1, len(page_indices), figsize=(15, 5), squeeze=False)
    fig.suptitle(f"Class {class_name} - Page {page}")
    for ax, idx in zip(axs[0], page_indices):
        ax.imshow(raw_images[idx])
        ax.axis("off")

    plt.tight_layout()
    plt.show()


# Slider for the page within the selected class.
page_selector = widgets.IntSlider(min=0, max=0, description="Page:")

# Dropdown whose values are positions in index_list. It gets its own name so later cells that
# define another `class_selector` cannot change what update_page_selector reads.
browser_class_selector = widgets.Dropdown(
    options=[(f"Class {cls}", i) for i, (cls, _) in enumerate(index_list)],
    description="Class:"
)
class_selector = browser_class_selector


def update_page_selector(*args):
    """Resize the page slider to the selected class and jump back to its first page."""
    class_index = browser_class_selector.value
    if class_index is None:
        return
    num_pages = (len(index_list[class_index][1]) + IMAGES_PER_PAGE - 1) // IMAGES_PER_PAGE
    page_selector.max = num_pages - 1
    page_selector.value = 0  # Reset to the first page when class changes


browser_class_selector.observe(update_page_selector, 'value')

# Display the widgets and images
widgets.interact(display_images, class_index=browser_class_selector, page=page_selector)

# Initialize the page selector based on the initial class selection
update_page_selector()


In [None]:
# Interactive view: all SSL points in grey, one predicted class at a time highlighted on top.
assert index_list, "no valid predictions - nothing to plot"

fig = figure(tools="pan, wheel_zoom, box_zoom, reset")

cds1 = ColumnDataSource(data={'x': low_dim_em[:, 0], 'y': low_dim_em[:, 1], 'video_id': video_ids, 'tracking_id': tracking_ids})
fig.scatter(
    source=cds1,
    size=12,
    line_color="black",
    fill_color="gray",
    alpha=0.3
)

exp_cmap = LinearColorMapper(palette=cc.glasbey, low=min(tracking_ids), high=max(tracking_ids))

# Per-class column lists (entry c belongs to index_list[c]) so the JS callback only picks one entry.
class_names = [cls for cls, _ in index_list]
class_xs = [low_dim_em[idxs, 0].tolist() for _, idxs in index_list]
class_ys = [low_dim_em[idxs, 1].tolist() for _, idxs in index_list]
class_video_ids = [[video_ids[idx] for idx in idxs] for _, idxs in index_list]
class_tracking_ids = [[tracking_ids[idx] for idx in idxs] for _, idxs in index_list]

source = ColumnDataSource(data={'x': class_xs, 'y': class_ys, 'video_id': class_video_ids, 'tracking_id': class_tracking_ids, 'class': class_names})
# Every column has the same length (the points of the currently selected class).
source_visible = ColumnDataSource(data={'x': class_xs[0], 'y': class_ys[0], 'video_id': class_video_ids[0], 'tracking_id': class_tracking_ids[0]})

fig.scatter(
    source=source_visible,
    size=12,
    line_color="black",
    fill_color={"field": "tracking_id", "transform": exp_cmap},
)

div = Div(text=f"<b>Class: {class_names[0]}</b>")

slider = Slider(start=0, end=len(index_list) - 1, value=0, step=1, title="Cluster")
# Replace the whole data dict at once so every column stays the same length and BokehJS re-renders.
slider.js_on_change("value", CustomJS(args=dict(source_visible=source_visible, source_available=source, slider=slider, div=div), code="""
    const i = slider.value;
    const data = source_available.data;
    source_visible.data = {
        x: data.x[i],
        y: data.y[i],
        video_id: data.video_id[i],
        tracking_id: data.tracking_id[i],
    };
    div.text = "<b>Class: " + data["class"][i] + "</b>";
"""))

fig.toolbar.active_scroll = fig.select_one(WheelZoomTool)

hover = HoverTool(tooltips=[("video_id", "@video_id"), ("tracking_id", "@tracking_id")])
fig.add_tools(hover)
show(column(fig, column(slider, div)))

In [None]:
engine = GorillaDatasetKISZ().engine
session_cls = sessionmaker(bind=engine)

# For every predicted class: (class, its distinct videos with a known start_time, oldest first).
sorted_videos_list = []

for class_name, crop_indices in index_list:
    class_video_ids = [video_ids[idx] for idx in crop_indices]
    print(len(class_video_ids))
    stmt = (
        select(Video.camera_id, Video.start_time, Video.video_id)
        .where(Video.video_id.in_(class_video_ids))
        .where(Video.start_time.isnot(None))
    )

    with session_cls() as session:
        rows = session.execute(stmt).all()

    sorted_videos = sorted(rows, key=lambda row: row[1])
    print(len(sorted_videos))
    print(sorted_videos)

    sorted_videos_list.append((class_name, sorted_videos))

print(sorted_videos_list)

In [None]:
# video_id -> VideoFeature.value for the SSL videos.
# NOTE(review): there is no feature-type filter, so if a video has several VideoFeature rows the
# last one returned wins — confirm there is only one (label) feature per video.
stmt = (
    select(Video.video_id, VideoFeature.value)
    .join(Video, Video.video_id == VideoFeature.video_id)
    .where(Video.video_id.in_(video_ids))
)

with session_cls() as session:
    video_labels = session.execute(stmt).all()
    video_labels_map = {video_id: value for video_id, value in video_labels}

print(video_labels_map)

In [None]:
# Camera positions (only cameras with known coordinates), sorted by camera_id.
stmt = (
    select(Camera.camera_id, Camera.longitude, Camera.latitude)
    .where(Camera.longitude.isnot(None))
    .where(Camera.latitude.isnot(None))
)

with session_cls() as session:
    all2 = sorted(session.execute(stmt).fetchall(), key=lambda cam: cam[0])

# camera_id -> (longitude, latitude)
mapping = {cam_id: (lon, lat) for cam_id, lon, lat in all2}
camera_ids, camera_lons, camera_lats = zip(*all2)

# Plausibility limit for moving between cameras: 0.0115741 m/s is about 1000 m per day.
avg_speed = 0.0115741  # m/s


def better_camera(id1, id2, valid_videos):
    """Return True if camera `id1` occurs strictly more often than `id2` among `valid_videos`."""
    counts = Counter(vid.camera_id for vid in valid_videos)
    return counts[id1] > counts[id2]


def _has_label_mismatch(cls, video_id):
    """True if `video_id` has a known label that differs from the first two characters of `cls`."""
    return video_id in video_labels_map and video_labels_map[video_id] != cls[:2]


def _implied_speed(prev_vid, next_vid):
    """Speed (m/s) needed to get from prev_vid's camera to next_vid's camera between their starts.

    Uses total_seconds() (timedelta.seconds would drop whole days). A zero time gap is
    infinitely fast unless both videos come from the same spot.
    """
    elapsed = (next_vid.start_time - prev_vid.start_time).total_seconds()
    lon1, lat1 = mapping[prev_vid.camera_id]
    lon2, lat2 = mapping[next_vid.camera_id]
    dist = geopy.distance.distance((lat1, lon1), (lat2, lon2)).m  # geopy expects (lat, lon)
    if elapsed <= 0:
        return 0.0 if dist == 0 else float("inf")
    return dist / elapsed


# For each class, walk its videos in time order and drop implausible jumps / label mismatches.
valid_videos_list = []
model_error = 0
model_error_match = 0
for cls, videos in sorted_videos_list:
    valid_videos = [vid for vid in videos if vid[0] in mapping and vid[1] is not None]
    error_count = 0
    match_error_count = 0
    final_videos = []
    skip_next = False  # True when the current video was already kept as the previous step's successor
    for i, vid in enumerate(valid_videos):
        label_mismatch = _has_label_mismatch(cls, vid.video_id)
        if label_mismatch:
            print(cls[:2], video_labels_map[vid.video_id])
            match_error_count += 1
        if skip_next:
            skip_next = False
            continue
        if i == len(valid_videos) - 1:
            # No successor to compare against: keep the last video unless its label disagrees.
            if not label_mismatch:
                final_videos.append(vid)
            break
        nxt = valid_videos[i + 1]
        if _implied_speed(vid, nxt) > avg_speed:
            # Implausible jump: keep `vid` if its camera occurs strictly more often for this
            # class, otherwise keep the successor (if its label agrees) and skip it next step.
            error_count += 1
            if better_camera(vid.camera_id, nxt.camera_id, valid_videos):
                if not label_mismatch:
                    final_videos.append(vid)
            elif not _has_label_mismatch(cls, nxt.video_id):
                final_videos.append(nxt)
                skip_next = True
        elif not label_mismatch:
            final_videos.append(vid)
    if valid_videos:
        model_error += error_count / len(valid_videos)
        model_error_match += match_error_count / len(valid_videos)
    valid_videos_list.append((cls, final_videos))

# Mean per-class error rates (classes without valid videos contribute 0).
num_classes = len(sorted_videos_list)
if num_classes:
    model_error /= num_classes
    model_error_match /= num_classes
print("model error: ", model_error)
print("model error match: ", model_error_match)

# Map of the chosen class: pick PLOT_CLASS_INDEX if it has videos, else the first class that does.
PLOT_CLASS_INDEX = 10
plot_candidates = [i for i, (_, vids) in enumerate(valid_videos_list) if vids]
assert plot_candidates, "no class has any plausible video to plot"
plot_index = PLOT_CLASS_INDEX if PLOT_CLASS_INDEX in plot_candidates else plot_candidates[0]
plot_cls, valid_videos = valid_videos_list[plot_index]

p = figure(x_range=(26.98, 27.12), y_range=(12.78, 12.89), width=550, height=550, title=f"Class {plot_cls}")
p.image_url(url=['wald.jpg'], x=26.5, y=13, w=1, h=0.5)

cam_source = ColumnDataSource(data=dict(x=camera_lons, y=camera_lats, text=camera_ids))
glyph = Text(x="x", y="y", text="text", text_align="center", text_font_size="10pt", text_color="white")
p.add_glyph(cam_source, glyph)

source = ColumnDataSource(data=dict(
    x=[mapping[vid.camera_id][0] for vid in valid_videos],
    y=[mapping[vid.camera_id][1] for vid in valid_videos],
    time=[vid.start_time.strftime('%Y-%m-%d %H:%M:%S') for vid in valid_videos],
))
first = valid_videos[0]
source_visible = ColumnDataSource(data=dict(
    x=[mapping[first.camera_id][0]],
    y=[mapping[first.camera_id][1]],
    time=[first.start_time.strftime('%Y-%m-%d %H:%M:%S')],
))
p.scatter(x="x", y="y", size=10, source=source_visible, fill_color="red", line_color="black")

div = Div(text=f"<b>Time: {first.start_time}</b>")

slider = Slider(start=0, end=max(len(valid_videos) - 1, 1), value=0, step=1, title="Time")
# Replace the whole data dict so all columns stay length-1 arrays.
slider.js_on_change("value", CustomJS(args=dict(source_visible=source_visible, source_available=source, slider=slider, div=div), code="""
    const i = Math.min(slider.value, source_available.data.time.length - 1);
    const data = source_available.data;
    source_visible.data = {x: [data.x[i]], y: [data.y[i]], time: [data.time[i]]};
    div.text = "<b>Time: " + data.time[i] + "</b>";
"""))

p.xaxis.axis_label = "Longitude"
p.yaxis.axis_label = "Latitude"

# Second map: every video of the class as a jittered point around its camera.
q = figure(x_range=(26.98, 27.12), y_range=(12.78, 12.89), width=550, height=550)
q.image_url(url=['wald.jpg'], x=26.5, y=13, w=1, h=0.5)
q.add_glyph(cam_source, glyph)

# NOTE(review): the jitter has mean 0.0005, so points are also shifted north-east of their
# camera, not only scattered — confirm that offset is intended.
rng = np.random.default_rng(0)  # seeded so the jitter is reproducible
points = []
for vid in valid_videos:
    lon, lat = mapping[vid.camera_id]
    points.append((lon + rng.normal(0.0005, 0.0015), lat + rng.normal(0.0005, 0.0015)))

jitter_x, jitter_y = zip(*points)
q.scatter(list(jitter_x), list(jitter_y), size=10, fill_color="red", line_color="black")

show(row(column(p, column(slider, div)), q))

In [None]:
def find_starting_camera(valid_videos):
    """Return the camera_id that occurs most often in `valid_videos`."""
    camera_counts = Counter(vid.camera_id for vid in valid_videos)
    return camera_counts.most_common(1)[0][0]


# class -> its distinct camera ids, in order of first appearance.
valid_camera_ids_map = {
    cls: list(dict.fromkeys(vid.camera_id for vid in vids)) for cls, vids in valid_videos_list
}

# class -> cameras sorted by distance from the class's most frequent camera.
distance_sorted_cameras_map = {}
# class -> [(camera_id, start_time), ...] in video order.
data_values_map = {}
valid_cls_list = []
for cls, vids in valid_videos_list:
    if not vids:
        continue
    valid_cls_list.append(cls)
    # Local names only: the kNN cell's `distances` and other globals must not be clobbered.
    start_lon, start_lat = mapping[find_starting_camera(vids)]
    cam_distances = []
    for c_id in valid_camera_ids_map[cls]:
        lon, lat = mapping[c_id]
        cam_distances.append((c_id, geopy.distance.distance((start_lat, start_lon), (lat, lon)).km))
    distance_sorted_cameras_map[cls] = [c_id for c_id, _ in sorted(cam_distances, key=lambda t: t[1])]
    data_values_map[cls] = [(vid.camera_id, vid.start_time) for vid in vids]


def update_page(change):
    """Redraw the camera-vs-time scatter for the class chosen in timeline_class_selector."""
    cls = timeline_class_selector.value

    # Clear previous output
    clear_output(wait=True)
    display(timeline_class_selector)
    if cls is None:
        return

    times = [t.strftime('%Y-%m-%d-%H-%M-%S') for _, t in data_values_map[cls]]
    cams = [str(c_id) for c_id, _ in data_values_map[cls]]

    # Categorical ranges reject duplicate factors, so de-duplicate timestamps for the x range.
    fig = figure(x_range=list(dict.fromkeys(times)), y_range=list(map(str, distance_sorted_cameras_map[cls])))
    cds = ColumnDataSource(data={'x': times, 'y': cams})
    glyph = Scatter(x="x", y="y", fill_color="black", line_color="black", size=7)
    fig.add_glyph(cds, glyph)
    fig.xaxis.major_label_orientation = "vertical"

    show(fig)


# Own name so the earlier class browser's `class_selector` (int-valued) is not replaced.
timeline_class_selector = widgets.Dropdown(
    options=list(valid_cls_list),
    description="Class:"
)
timeline_class_selector.observe(update_page, names='value')

display(timeline_class_selector)

# Initial plot
update_page(None)

In [None]:
stmt = select(Video.video_id, Video.camera_id, Video.start_time).where(Video.video_id.in_(video_ids))

with session_cls() as session:
    result = session.execute(stmt).all()

# video_id -> camera_id, built once instead of a list rebuild + .index() scan per crop.
camera_by_video = {row.video_id: row.camera_id for row in result}

# Predicted class -> camera_id of every crop assigned to it (one entry per crop).
dct = {cls: [camera_by_video[video_ids[idx]] for idx in idxs] for cls, idxs in index_dict.items()}

print(dct)

# Group classes by the first two characters of their name.
social_groups = {}
for cls in index_dict.keys():
    social_groups.setdefault(cls[:2], []).append(cls)

print(social_groups)

def similarity(a, b):
    """Camera-usage *dissimilarity* of two camera-id lists (0.0 means identical histograms).

    The score is the summed absolute difference of per-camera counts. It is divided by
    len(a) plus the number of distinct cameras that only appear in `b`. Two empty inputs
    return 0.0 instead of dividing by zero.
    """
    counts_a, counts_b = Counter(a), Counter(b)
    only_in_b = [cam for cam in counts_b if cam not in counts_a]

    # Counter returns 0 for missing keys, so cameras only in `a` contribute their full count.
    score = sum(abs(n - counts_b[cam]) for cam, n in counts_a.items())
    score += sum(counts_b[cam] for cam in only_in_b)

    real_len = len(a) + len(only_in_b)
    if real_len == 0:
        return 0.0
    return score / real_len

# Pairwise camera dissimilarity between all classes of one social group.
SOCIAL_GROUP = "US"
grp = social_groups.get(SOCIAL_GROUP, [])
if not grp:
    print(f"No classes found for social group {SOCIAL_GROUP!r}.")
print(grp)
for i, cls_a in enumerate(grp):
    for cls_b in grp[i + 1:]:
        print(cls_a, cls_b)
        print(dct[cls_a])
        print(dct[cls_b])
        print(similarity(dct[cls_a], dct[cls_b]))
        

In [None]:
# Pairwise dissimilarity of all predicted classes (sorted by name); the diagonal is set to inf
# so it never wins the minimum search below.
sorted_classes = sorted(index_dict.keys())
num_classes_total = len(sorted_classes)
similarity_matrix = np.zeros((num_classes_total, num_classes_total))

for i, cls1 in enumerate(sorted_classes):
    for j, cls2 in enumerate(sorted_classes):
        if i == j:
            similarity_matrix[i, j] = np.inf
        else:
            similarity_matrix[i, j] = round(similarity(dct[cls1], dct[cls2]), 2)

print('\n'.join('\t'.join(str(cell) for cell in matrix_row) for matrix_row in similarity_matrix))

min_value = np.min(similarity_matrix)

# Find the index of the minimum value
min_index = np.unravel_index(np.argmin(similarity_matrix), similarity_matrix.shape)

print(min_value, min_index, sorted_classes[min_index[0]], sorted_classes[min_index[1]])
print(np.average(similarity_matrix[np.isfinite(similarity_matrix)]))

In [None]:
# video_id -> classes (in valid_videos_list order) that kept this video after filtering.
valid_video_ids = {}
for cls, videos in valid_videos_list:
    for vid in videos:
        valid_video_ids.setdefault(vid.video_id, []).append(cls)

print(valid_video_ids)

# Videos with a known label: video_id -> (label, predicted classes).
test_dict = {
    vid_id: (video_labels_map[vid_id], classes)
    for vid_id, classes in valid_video_ids.items()
    if vid_id in video_labels_map
}

print(test_dict)
print(video_labels_map)

# same_video_individuals[c]: Counter of classes sharing a video with c (c itself included).
# new_grouping[g]: union of all classes that share a video with any class whose prefix is g.
same_video_individuals = {}
new_grouping = {}
for cls_list in valid_video_ids.values():
    for cls in cls_list:
        same_video_individuals.setdefault(cls, Counter()).update(cls_list)
        # set.update (not add): cls_list is an unhashable list whose members must be merged.
        new_grouping.setdefault(cls[:2], set()).update(cls_list)

print(same_video_individuals)

print(new_grouping)
for key in new_grouping.keys():
    print(key, [x for x in same_video_individuals.keys() if x[:2] == key])