In [1]:
import argparse
import time
from dataclasses import dataclass, field
from os.path import join
from pathlib import Path
from typing import List, Tuple, Dict, Any, Union

import cv2
import numpy as np
import torch
import torch.backends.cudnn as cudnn
from numpy import random
import networkx as nx

from stg_utils import *
from sys import path
path.append("/Users/mohammadzainabbas/Masters/CS/Big-Data-Research-Project/src/object_detection/yolov7_with_object_tracking")

from models.experimental import attempt_load
from utils.datasets import LoadStreams, LoadImages
from utils.general import check_img_size, check_requirements, \
                check_imshow, non_max_suppression, apply_classifier, \
                scale_coords, xyxy2xywh, strip_optimizer, set_logging, \
                increment_path
from utils.plots import plot_one_box
from utils.torch_utils import select_device, load_classifier, time_synchronized, TracedModel

from sort import *

from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Sample detections, one row per box, ordered by descending confidence.
# NOTE(review): columns look like (x1, y1, x2, y2, confidence, class id) —
# confirm against the detector that produced them.
data = [
    [1647, 121, 1716, 284, 0.91595, 0],
    [1606, 609, 1716, 863, 0.90923, 0],
    [249, 787, 374, 1080, 0.89867, 0],
    [801, 245, 857, 422, 0.88594, 0],
    [1452, 35, 1535, 185, 0.88343, 0],
    [1575, 913, 1715, 1079, 0.87925, 0],
    [713, 229, 783, 409, 0.87111, 0],
    [871, 391, 985, 613, 0.86602, 0],
    [1347, 70, 1395, 151, 0.81387, 1],
    [879, 77, 943, 232, 0.80651, 0],
    [1797, 180, 1864, 367, 0.78584, 0],
    [886, 482, 1022, 643, 0.78523, 1],
    [294, 298, 364, 481, 0.71708, 0],
    [615, 78, 690, 193, 0.71245, 58],
    [616, 1007, 752, 1080, 0.69813, 0],
    [1509, 90, 1553, 192, 0.69539, 1],
    [1646, 686, 1712, 791, 0.69249, 26],
    [1347, 16, 1402, 129, 0.68978, 0],
    [1154, 0, 1197, 75, 0.65086, 0],
    [11, 277, 90, 435, 0.58677, 0],
    [985, 0, 1032, 127, 0.5406, 0],
    [1324, 0, 1358, 47, 0.53892, 0],
    [247, 951, 284, 1023, 0.34445, 26],
    [284, 398, 315, 456, 0.30317, 26],
    [593, 244, 802, 417, 0.29631, 13],
    [1837, 262, 1883, 344, 0.28662, 0],
    [1806, 208, 1857, 292, 0.26673, 24],
]

[[1647, 121, 1716, 284, 0.91595, 0],
 [1606, 609, 1716, 863, 0.90923, 0],
 [249, 787, 374, 1080, 0.89867, 0],
 [801, 245, 857, 422, 0.88594, 0],
 [1452, 35, 1535, 185, 0.88343, 0],
 [1575, 913, 1715, 1079, 0.87925, 0],
 [713, 229, 783, 409, 0.87111, 0],
 [871, 391, 985, 613, 0.86602, 0],
 [1347, 70, 1395, 151, 0.81387, 1],
 [879, 77, 943, 232, 0.80651, 0],
 [1797, 180, 1864, 367, 0.78584, 0],
 [886, 482, 1022, 643, 0.78523, 1],
 [294, 298, 364, 481, 0.71708, 0],
 [615, 78, 690, 193, 0.71245, 58],
 [616, 1007, 752, 1080, 0.69813, 0],
 [1509, 90, 1553, 192, 0.69539, 1],
 [1646, 686, 1712, 791, 0.69249, 26],
 [1347, 16, 1402, 129, 0.68978, 0],
 [1154, 0, 1197, 75, 0.65086, 0],
 [11, 277, 90, 435, 0.58677, 0],
 [985, 0, 1032, 127, 0.5406, 0],
 [1324, 0, 1358, 47, 0.53892, 0],
 [247, 951, 284, 1023, 0.34445, 26],
 [284, 398, 315, 456, 0.30317, 26],
 [593, 244, 802, 417, 0.29631, 13],
 [1837, 262, 1883, 344, 0.28662, 0],
 [1806, 208, 1857, 292, 0.26673, 24]]

In [4]:
@dataclass(unsafe_hash=True)
class Node:
    """A single detection, used as a vertex of the spatial graph.

    ``unsafe_hash=True`` gives instances a hash derived from their field
    values. A plain ``@dataclass`` (``eq=True``, not frozen) sets
    ``__hash__ = None``, which makes ``networkx.Graph.add_node(Node(...))``
    fail with ``TypeError: unhashable type``. Equality stays value-based.
    Do not mutate a Node after inserting it into a graph: its hash would no
    longer match the key the graph stored.

    Attributes:
        id: identity of the detection (defaults to 0).
        x1, y1, x2, y2: bounding-box corner coordinates.
        conf: detection confidence.
        detclass: integer class index.
        class_name: human-readable class label.
        centroid: midpoint of the box, derived in ``__post_init__`` (not an
            ``__init__`` argument).
    """
    id: int = field(default=0)
    x1: int = field(default=0)
    y1: int = field(default=0)
    x2: int = field(default=0)
    y2: int = field(default=0)
    conf: float = field(default=0.0)
    detclass: int = field(default=0)
    class_name: str = field(default="")
    centroid: Tuple[int, int] = field(init=False)

    def __post_init__(self):
        # Floor-divided midpoint, so integer corners give an integer centroid.
        self.centroid = ((self.x1 + self.x2) // 2, (self.y1 + self.y2) // 2)

In [5]:
@dataclass
class Edge:
    """Edge payload for the spatial graph: a single numeric weight (default 0)."""
    weight: Union[float, int] = 0

In [47]:
def _draw_label(img, label, origin, color, line_thickness):
    """Draw ``label`` on a filled background whose bottom-left corner is ``origin``."""
    x1, y1 = origin
    font_thickness = max(line_thickness - 1, 1)
    font_scale = line_thickness / 3
    text_w, text_h = cv2.getTextSize(label, 0, fontScale=font_scale, thickness=font_thickness)[0]
    # Filled (thickness -1) rectangle behind the text; cv2.circle takes (center, radius),
    # not two corners, so it cannot draw this background.
    cv2.rectangle(img, (x1, y1), (x1 + text_w, y1 - text_h - 3), color, -1, cv2.LINE_AA)
    cv2.putText(img, label, (x1, y1 - 2), 0, font_scale, [225, 255, 255],
                thickness=font_thickness, lineType=cv2.LINE_AA)


def generate_spatial_graph(img, bbox, identities=None, categories=None, confidences=None,
                           names=None, colors=None, opt=None):
    """
    Construct a spatial graph from detections and optionally annotate ``img``.

    Each bounding box becomes one ``Node`` in an undirected ``networkx.Graph``.
    No edges are added here. When ``img`` is not ``None``, each detection is
    also drawn onto it in place: its box and/or a filled text label.

    Args:
        img: image to draw on (modified in place), or ``None`` to only build
            the graph. Presumably an OpenCV/numpy image indexed (H, W, ...);
            only ``img.shape[0]`` and ``img.shape[1]`` are read.
        bbox: iterable of ``(x1, y1, x2, y2)`` boxes; values are cast to ``int``.
        identities: optional per-box ids; 0 is used when omitted.
        categories: optional per-box class indices; 0 when omitted.
        confidences: optional per-box confidences; 0.0 when omitted.
        names: optional mapping/sequence from class index to class name; the
            class index as a string is used when omitted.
        colors: optional mapping/sequence from class index to the colour passed
            to OpenCV; (0, 255, 0) when omitted.
        opt: optional options object read for ``thickness``, ``nobbox`` and
            ``nolabel``. When ``None``, a module-level ``opt`` is used if one
            exists. Missing attributes mean: thickness derived from the image
            size, and both boxes and labels drawn.

    Returns:
        ``(img, graph)``: the (possibly annotated) image and the graph.
    """
    if opt is None:
        # Keep supporting the notebook-level `opt` the original code read implicitly.
        opt = globals().get("opt")
    thickness = getattr(opt, "thickness", None)
    draw_bbox = not getattr(opt, "nobbox", False)
    draw_label = not getattr(opt, "nolabel", False)

    graph = nx.Graph()

    # Line/font thickness depends only on the image size, so compute it once.
    line_thickness = None
    if img is not None:
        line_thickness = thickness or round(0.002 * (img.shape[0] + img.shape[1]) / 2) + 1

    for idx, box in enumerate(bbox):
        x1, y1, x2, y2 = (int(coord) for coord in box)

        cat = int(categories[idx]) if categories is not None else 0
        node_id = int(identities[idx]) if identities is not None else 0
        conf = float(confidences[idx]) if confidences is not None else 0.0
        class_name = str(names[cat]) if names is not None else str(cat)
        graph.add_node(Node(node_id, x1, y1, x2, y2, conf, cat, class_name))

        if img is None:
            continue

        color = colors[cat] if colors is not None else (0, 255, 0)

        if draw_bbox:
            # cv2.rectangle takes two corners; cv2.circle would need (center, radius).
            cv2.rectangle(img, (x1, y1), (x2, y2), color, line_thickness)

        if draw_label:
            if identities is not None:
                label = f"Node({node_id}): {class_name} {conf:.2f}"
            else:
                label = f"{class_name} {conf:.2f}"
            _draw_label(img, label, (x1, y1), color, line_thickness)

    return img, graph

In [None]:
# Build detector-style inputs explicitly from the sample detections in `data`,
# instead of relying on `_` (IPython's last-output variable) and undefined names.
# NOTE(review): columns assumed to be (x1, y1, x2, y2, confidence, class id).
detections = np.asarray(data, dtype=float)
bbox = detections[:, :4]
confidences = detections[:, 4]
categories = detections[:, 5].astype(int)
class_ids = sorted(set(categories.tolist()))

# TODO: replace placeholder names with the detector's class names (e.g. model.names).
names = {c: f"class_{c}" for c in class_ids}
# Deterministic per-class colours (plain ints) so re-runs draw identically.
colors = {c: ((37 * c) % 256, (97 * c + 128) % 256, (193 * c + 64) % 256) for c in class_ids}

# Blank canvas; assumes a 1920x1080 frame (boxes reach x≈1883 and y=1080).
frame = np.zeros((1080, 1920, 3), dtype=np.uint8)

annotated_frame, graph = generate_spatial_graph(frame, bbox, None, categories, confidences, names, colors)
assert graph.number_of_nodes() == len(data), "every detection should become one node"
graph.number_of_nodes()