In [41]:
import argparse
import time
from pathlib import Path
from os.path import join
import cv2
import torch
import torch.backends.cudnn as cudnn
from numpy import random
from typing import List, Tuple, Dict, Any, Union
import networkx as nx
from dataclasses import dataclass, field
from scipy.spatial import distance
import nx_altair as nxa
from stg_utils import *
from sys import path
path.append("/Users/mohammadzainabbas/Masters/CS/Big-Data-Research-Project/src/object_detection/yolov7_with_object_tracking")

from models.experimental import attempt_load
from utils.datasets import LoadStreams, LoadImages
from utils.general import check_img_size, check_requirements, \
                check_imshow, non_max_suppression, apply_classifier, \
                scale_coords, xyxy2xywh, strip_optimizer, set_logging, \
                increment_path
from utils.plots import plot_one_box
from utils.torch_utils import select_device, load_classifier, time_synchronized, TracedModel

from sort import *

from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"

In [2]:
data = [
    [       1647,         121,        1716,         284,     0.91595,           0],
    [       1606,         609,        1716,         863,     0.90923,           0],
    [        249,         787,         374,        1080,     0.89867,           0],
    [        801,         245,         857,         422,     0.88594,           0],
    [       1452,          35,        1535,         185,     0.88343,           0],
    [       1575,         913,        1715,        1079,     0.87925,           0],
    [        713,         229,         783,         409,     0.87111,           0],
    [        871,         391,         985,         613,     0.86602,           0],
    [       1347,          70,        1395,         151,     0.81387,           1],
    [        879,          77,         943,         232,     0.80651,           0],
    [       1797,         180,        1864,         367,     0.78584,           0],
    [        886,         482,        1022,         643,     0.78523,           1],
    [        294,         298,         364,         481,     0.71708,           0],
    [        615,          78,         690,         193,     0.71245,          58],
    [        616,        1007,         752,        1080,     0.69813,           0],
    [       1509,          90,        1553,         192,     0.69539,           1],
    [       1646,         686,        1712,         791,     0.69249,          26],
    [       1347,          16,        1402,         129,     0.68978,           0],
    [       1154,           0,        1197,          75,     0.65086,           0],
    [         11,         277,          90,         435,     0.58677,           0],
    [        985,           0,        1032,         127,      0.5406,           0],
    [       1324,           0,        1358,          47,     0.53892,           0],
    [        247,         951,         284,        1023,     0.34445,          26],
    [        284,         398,         315,         456,     0.30317,          26],
    [        593,         244,         802,         417,     0.29631,          13],
    [       1837,         262,        1883,         344,     0.28662,           0],
    [       1806,         208,        1857,         292,     0.26673,          24]]

In [13]:
names = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush']

In [30]:
@dataclass(frozen=True)
class Node:
    id: int = field(default=0)
    x1: int = field(default=0)
    y1: int = field(default=0)
    x2: int = field(default=0)
    y2: int = field(default=0)
    conf: float = field(default=float(0))
    detclass: int = field(default=0)
    class_name: str = field(default="")
    centroid: tuple = field(default=(0, 0))
    # def __post_init__(self):
    #     self.centroid = ((self.x1 + self.x2) // 2, (self.y1 + self.y2) // 2)

In [31]:
@dataclass(frozen=True)
class Edge:
    weight: Union[float, int] = field(default=0)

In [32]:
def generate_spatial_graph(img, bbox, identities=None, categories=None, confidences = None, names=None, colors = None):
    """
    Construct a spatial graph from the bounding boxes, identities, categories, confidences, names and colors
    """
    graph = nx.Graph()
    for i, box in enumerate(bbox):
        x1, y1, x2, y2 = [int(i) for i in box]
        centroid = ((x1 + x2) // 2, (y1 + y2) // 2)

        cat = int(categories[i]) if categories is not None else 0
        id = int(identities[i]) if identities is not None else 0
        conf = confidences[i] if confidences is not None else 0
        class_name = names[cat]
        graph.add_node(Node(id, x1, y1, x2, y2, conf, cat, class_name, centroid))

    for node1 in graph.nodes:
        for node2 in graph.nodes:
            if node1.id == node2.id: continue
            graph.add_edge(node1, node2, weight=distance.euclidean(node1.centroid, node2.centroid))

    return img, graph

In [33]:
data = np.array(data)
bbox = data[:, :4]
identities = [f"{x}" for x in range(len(data))]
confidences = data[:, 4]
categories = data[:, 5]
colors = [[random.randint(0, 255) for _ in range(3)] for _ in range(len(names))]

In [34]:
_, graph = generate_spatial_graph(_, bbox, identities, categories, confidences, names, colors)

In [36]:
str(graph)

'Graph with 27 nodes and 351 edges'

In [40]:
graph.nodes

NodeView((Node(id=0, x1=1647, y1=121, x2=1716, y2=284, conf=0.91595, detclass=0, class_name='person', centroid=(1681, 202)), Node(id=1, x1=1606, y1=609, x2=1716, y2=863, conf=0.90923, detclass=0, class_name='person', centroid=(1661, 736)), Node(id=2, x1=249, y1=787, x2=374, y2=1080, conf=0.89867, detclass=0, class_name='person', centroid=(311, 933)), Node(id=3, x1=801, y1=245, x2=857, y2=422, conf=0.88594, detclass=0, class_name='person', centroid=(829, 333)), Node(id=4, x1=1452, y1=35, x2=1535, y2=185, conf=0.88343, detclass=0, class_name='person', centroid=(1493, 110)), Node(id=5, x1=1575, y1=913, x2=1715, y2=1079, conf=0.87925, detclass=0, class_name='person', centroid=(1645, 996)), Node(id=6, x1=713, y1=229, x2=783, y2=409, conf=0.87111, detclass=0, class_name='person', centroid=(748, 319)), Node(id=7, x1=871, y1=391, x2=985, y2=613, conf=0.86602, detclass=0, class_name='person', centroid=(928, 502)), Node(id=8, x1=1347, y1=70, x2=1395, y2=151, conf=0.81387, detclass=1, class_name=

In [39]:
# nx.draw(graph, with_labels=True, font_weight='bold')
nx.draw(graph, pos=nx.spring_layout(graph))

In [17]:
a = Node(1, 1, 1, 1, 1, 1, 1, "a")

In [21]:
Node(1, 1, 1, 1, 1, 1, 1, "a").__dict__

{'id': 1,
 'x1': 1,
 'y1': 1,
 'x2': 1,
 'y2': 1,
 'conf': 1,
 'detclass': 1,
 'class_name': 'a',
 'centroid': (1, 1)}

In [20]:
a.__dict__

{'id': 1,
 'x1': 1,
 'y1': 1,
 'x2': 1,
 'y2': 1,
 'conf': 1,
 'detclass': 1,
 'class_name': 'a',
 'centroid': (1, 1)}

In [19]:
print(f"{dir(a) = }")

dir(a) = ['__annotations__', '__class__', '__dataclass_fields__', '__dataclass_params__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__match_args__', '__module__', '__ne__', '__new__', '__post_init__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__str__', '__subclasshook__', '__weakref__', 'centroid', 'class_name', 'conf', 'detclass', 'id', 'x1', 'x2', 'y1', 'y2']
