In [1]:
%load_ext autoreload
%autoreload 2

import os, time
import cv2, random
import pickle, joblib
import sklearn.metrics
import numpy as np
np.set_printoptions(suppress=True)
import gurobipy as gp

import torch
import torch.nn as nn
import torch.nn.functional as F
torch.set_printoptions(sci_mode=False)
from torch_geometric.data import Data

from lib.tracking import Tracker
from lib.qpthlocal.qp import QPFunction, QPSolvers
from lib.utils import getIoU, interpolateTrack, interpolateTracks

In [2]:
# random.seed(123)
# np.random.seed(123)
# torch.manual_seed(123)
# torch.cuda.manual_seed(123)

# train_data_list_full, val_data_list_full = [], []
# # root_data_path = "data/train_data/"
# root_data_path = "reproduce_dataset/"
# for file in os.listdir(root_data_path):
#     file_name = root_data_path + file
#     with open(file_name, 'rb') as f:
#         data_list = pickle.load(f)
        
#     if file.startswith('MOT16-09') or file.startswith('MOT16-13'):
#         val_data_list_full = val_data_list_full + data_list
#     else:
#         train_data_list_full = train_data_list_full + data_list
        
# print('Total {} samples in training set, {} in validation set'.format(
#                      len(train_data_list_full),len(val_data_list_full)))

# train_data_list, val_data_list = [], []
# for ind in range(len(train_data_list_full)):
#     if train_data_list_full[ind].x.shape[0] < 200: #Avoid too big graph
#         train_data_list.append(train_data_list_full[ind])
#     else:
#         continue
        
# for ind in range(len(val_data_list_full)):
#     if val_data_list_full[ind].x.shape[0] < 200:
#         val_data_list.append(val_data_list_full[ind])
#     else:
#         continue
        
# print('Used {} samples in training set, {} in validation set'.format(len(train_data_list), len(val_data_list)))

Total 584 samples in training set, 319 in validation set
Used 375 samples in training set, 36 in validation set


In [3]:
# train_vid = ['uav0000013_00000_v.pkl',
#  'uav0000013_01073_v.pkl',
#  'uav0000013_01392_v.pkl',
#  'uav0000020_00406_v.pkl',
#  'uav0000071_03240_v.pkl',
#  'uav0000072_04488_v.pkl']
# test_vid = ['uav0000013_00000_v.pkl']


In [9]:
random.seed(123)
np.random.seed(123)
torch.manual_seed(123)
torch.cuda.manual_seed(123)

train_root_path = "VisDrone/train_graph/"
val_root_path = "VisDrone/val_graph/"
MAX_GRAPH_NODES = 200  # graphs with this many nodes or more are skipped ("avoid too big graph")


def _load_graph_samples(root_path):
    """
    Concatenate the lists of graph samples pickled in every file directly under ``root_path``.

    Files are read in sorted order so the resulting sample order -- and therefore the
    seeded shuffles done during training -- does not depend on ``os.listdir`` ordering.
    NOTE: ``pickle.load`` can execute arbitrary code; only point this at trusted files.
    """
    samples = []
    for file_name in sorted(os.listdir(root_path)):
        file_path = os.path.join(root_path, file_name)
        if not os.path.isfile(file_path):  # skip sub-directories such as checkpoint folders
            continue
        with open(file_path, 'rb') as f:
            samples.extend(pickle.load(f))
    return samples


def _keep_small_graphs(samples, max_nodes=MAX_GRAPH_NODES):
    """Return the samples whose node count (``sample.x.shape[0]``) is strictly below ``max_nodes``."""
    return [sample for sample in samples if sample.x.shape[0] < max_nodes]


train_data_list_full = _load_graph_samples(train_root_path)
val_data_list_full = _load_graph_samples(val_root_path)
print('Total {} samples in training set, {} in validation set'.format(
                     len(train_data_list_full), len(val_data_list_full)))

train_data_list = _keep_small_graphs(train_data_list_full)
val_data_list = _keep_small_graphs(val_data_list_full)
print('Used {} samples in training set, {} in validation set'.format(len(train_data_list), len(val_data_list)))

Total 6052 samples in training set, 712 in validation set
Used 861 samples in training set, 63 in validation set


In [10]:
def build_constraint_training(self, data, max_frame_gap = 1):
    """
    Build the tracking constraints and the ground-truth solution vector for
    end-to-end training on one graph sample.

    Candidate links are all node pairs (src, dst) with
    ``1 <= t[dst] - t[src] <= max_frame_gap``, enumerated src-major then
    dst-minor. The k-th candidate (0-based) takes label ``data.y[k]`` and is
    stored as ``k + 1`` in the link-index matrix (0 means "no link").

    NOTE(review): this is a module-level function taking ``self``. The visible
    cells never bind it to ``Tracker``, so ``tracker.build_constraint_training``
    in the training loop resolves to Tracker's own method unless it is attached
    elsewhere (e.g. ``Tracker.build_constraint_training = build_constraint_training``).

    Args:
        self: object providing ``build_constraint(linkIndexGraph)`` that returns
            ``(A_eq, b_eq, A_ub, b_ub)``.
        data: graph sample with ``ground_truth`` (one row per detection; column 0
            is the frame index, column 1 the track id), ``x`` (one row per node)
            and ``y`` (one label per candidate link, in enumeration order).
        max_frame_gap: largest frame difference for which a link is created.

    Returns:
        ``(A_eq, b_eq, A_ub, b_ub, x_gt)`` where ``x_gt`` is the int32
        concatenation ``[det | entry | exit | link]`` of ground-truth indicators.

    Raises:
        ValueError: if ``data.y`` has fewer labels than enumerated candidate links.
        AssertionError: if the entry/exit ground truth is inconsistent.
    """
    ground_truth = np.asarray(data.ground_truth)
    num_nodes = data.x.shape[0]
    timestamps = ground_truth[:num_nodes, 0].astype(int)
    track_ids = ground_truth[:, 1]

    # Enumerate candidate links. np.nonzero scans the boolean matrix in row-major
    # (C) order, i.e. src-major / dst-minor -- the same order as a nested
    # ``for src: for dst:`` loop -- so link k lines up with data.y[k].
    frame_gap = timestamps[None, :] - timestamps[:, None]  # frame_gap[src, dst] = t[dst] - t[src]
    src_nodes, dst_nodes = np.nonzero((frame_gap >= 1) & (frame_gap <= max_frame_gap))
    num_edges = src_nodes.shape[0]

    linkIndexGraph = np.zeros((num_nodes, num_nodes), dtype = np.int32)
    linkIndexGraph[src_nodes, dst_nodes] = np.arange(1, num_edges + 1)  # 1-based; 0 = no link

    # Link labels are sized by the number of candidate links, not by the node
    # count: a graph usually has more candidate links than nodes.
    labels = data.y.detach().cpu().numpy() if torch.is_tensor(data.y) else np.asarray(data.y)
    labels = labels.reshape(-1)
    if labels.shape[0] < num_edges:
        raise ValueError('data.y has {} labels but {} candidate links were enumerated'.format(
            labels.shape[0], num_edges))
    tran_gt = labels[:num_edges].astype(np.int32)

    # det_gt is all ones; each track enters at its first row and exits at its last.
    # NOTE(review): assumes the ground_truth rows of a track are ordered by frame.
    det_gt = np.ones(num_nodes, dtype = np.int32)
    entry_gt = np.zeros(num_nodes, dtype = np.int32)
    exit_gt = np.zeros(num_nodes, dtype = np.int32)
    unique_ids = np.unique(track_ids)
    for track_id in unique_ids:
        inds = np.where(track_ids == track_id)[0]
        entry_gt[inds[0]] = 1
        exit_gt[inds[-1]] = 1
    assert entry_gt.sum() == exit_gt.sum()
    assert entry_gt.sum() == len(unique_ids), 'Ground Truth Error!'

    A_eq, b_eq, A_ub, b_ub = self.build_constraint(linkIndexGraph)
    x_gt = np.concatenate((det_gt, entry_gt, exit_gt, tran_gt))

    return A_eq, b_eq, A_ub, b_ub, x_gt

In [11]:
class Net(nn.Module):
    """
    Edge scorer: maps each row of ``data.edge_attr`` to a matching probability
    in (0, 1) with a two-layer MLP followed by a sigmoid.

    Args:
        in_features: number of features per edge (default 6, as originally hard-coded).
        hidden_features: width of the hidden layer; defaults to ``in_features``.
    """
    def __init__(self, in_features=6, hidden_features=None):
        super(Net, self).__init__()
        if hidden_features is None:
            hidden_features = in_features
        # Attribute name and layer order kept so existing checkpoints (fc.0.*, fc.2.*) still load.
        self.fc = nn.Sequential(nn.Linear(in_features, hidden_features), nn.ReLU(),
                                nn.Linear(hidden_features, 1))

    def forward(self, data):
        """
        Args:
            data: graph sample exposing ``edge_attr`` (numpy array or tensor),
                expected shape (num_edges, in_features).
        Returns:
            Float32 tensor of shape (num_edges, 1) with per-edge probabilities.
        """
        # as_tensor avoids the extra copy and UserWarning torch.tensor() emits for tensor inputs.
        edge_attr = torch.as_tensor(data.edge_attr, dtype=torch.float32,
                                    device=self.fc[0].weight.device)
        return torch.sigmoid(self.fc(edge_attr))
    
# Instantiate the edge-scoring network and wrap it in the project's Tracker; the
# training loop calls tracker.build_constraint_training, make_gurobi_model_tracking
# and linprog on this instance.
net = Net()
tracker = Tracker(net)

In [None]:
gamma = 0.1                    # weight of the quadratic term: Q = gamma * I in the QP layer
NUM_EPOCHS = 25
CKPT_DIR = 'ckpt/visdrone'
train_list, val_list = [], []  # per-sample (mse, mse_edge, auc, bce) logs for loss / AUC curves
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3, weight_decay=1e-4, betas=(0.9, 0.999))

LOG_FORMAT = ('{} {} it [{}/{}] pr [{:.2f}-{:.2f}] obj [{:.2f}/{:.2f}] '
              'mse {:.4f} mse edge {:.4f} ce {:.3f} auc {:.3f} ham {:.3f}')


def _safe_auc(y_true, y_score):
    """ROC AUC of the link scores, or NaN when y_true has a single class (roc_auc_score raises then)."""
    if np.unique(y_true).size < 2:
        return float('nan')
    return sklearn.metrics.roc_auc_score(y_true, y_score)


def _run_sample(sample, train):
    """
    Run the differentiable-QP tracking pipeline on one graph sample and compute its metrics.

    Args:
        sample: graph sample accepted by ``tracker.build_constraint_training`` and ``net``.
        train: if True the network forward pass records gradients so the returned
            ``loss_edge`` can be back-propagated; otherwise it runs under no_grad.

    Returns:
        dict with tensors 'prob', 'loss', 'loss_edge', 'bce', 'obj_pred', 'obj_gt'
        and floats 'auc', 'ham'.
    """
    A_eq, b_eq, A_ub, b_ub, x_gt = tracker.build_constraint_training(sample)

    A, b = torch.from_numpy(A_eq).float(), torch.from_numpy(b_eq).float().flatten()
    G, h = torch.from_numpy(A_ub).float(), torch.from_numpy(b_ub).float().flatten()

    # NOTE(review): assumes the equality constraints contribute two rows per node -- confirm in lib.tracking
    num_nodes = A.shape[0] // 2
    Q = gamma * torch.eye(A.shape[1])

    # x_gt layout: [det | entry | exit | link], the first three each num_nodes long
    link_start = num_nodes * 3
    x_gt_t = torch.from_numpy(x_gt).float()
    link_gt = x_gt[link_start:]
    link_gt_t = x_gt_t[link_start:]

    with torch.set_grad_enabled(train):
        prob = net(sample)
        prob = torch.clamp(prob, min=1e-7, max=1 - 1e-7)  # keeps log(prob) and BCE finite
    auc = _safe_auc(link_gt, prob.detach().numpy())

    # Costs: -1 per detection, +1 per entry / exit, -log(p) per link.
    c_det, c_entry, c_exit = -1 * torch.ones(num_nodes), torch.ones(num_nodes), torch.ones(num_nodes)
    c_link = -1 * torch.log(prob).flatten()  # flatten, not squeeze: a one-edge graph must stay 1-D
    c_pred = torch.cat([c_det, c_entry, c_exit, c_link])

    model_params_quad = tracker.make_gurobi_model_tracking(G.numpy(), h.numpy(), A.numpy(), b.numpy(), Q.numpy())
    x = QPFunction(verbose=False, solver=QPSolvers.GUROBI, maxIter=50,
                   model_params=model_params_quad)(Q, c_pred, G, h, A, b)

    loss = nn.MSELoss()(x.flatten(), x_gt_t)
    loss_edge = nn.MSELoss()(x[:, link_start:].flatten(), link_gt_t)

    # Ground-truth cost: predicted det/entry/exit costs; a link costs 0 if true, const_cost if false.
    const_cost = 1
    c_gt = torch.cat([c_pred[:link_start], torch.from_numpy(const_cost * (1 - link_gt)).float()])
    obj_gt = c_gt @ x_gt_t            # ground-truth objective value
    obj_pred = c_pred @ x.flatten()   # predicted objective; should approach obj_gt after training

    bce = nn.BCELoss()(prob.flatten(), link_gt_t)
    x_sol = tracker.linprog(c_pred.detach().numpy(), A_eq, b_eq, A_ub, b_ub)
    ham = sklearn.metrics.hamming_loss(x_gt, x_sol)

    return {'prob': prob, 'loss': loss, 'loss_edge': loss_edge, 'bce': bce,
            'obj_pred': obj_pred, 'obj_gt': obj_gt, 'auc': auc, 'ham': ham}


def _record(metrics):
    """Tuple appended to train_list / val_list: (mse, mse_edge, auc, bce)."""
    return (metrics['loss'].item(), metrics['loss_edge'].item(), metrics['auc'], metrics['bce'].item())


def _log_line(tag, epoch, itr, total, metrics):
    """Format one progress line; training uses tag 'epoch', validation 'vepo'."""
    return LOG_FORMAT.format(tag, epoch, itr, total,
                             metrics['prob'].min().item(), metrics['prob'].max().item(),
                             metrics['obj_pred'].item(), metrics['obj_gt'].item(),
                             metrics['loss'].item(), metrics['loss_edge'].item(),
                             metrics['bce'].item(), metrics['auc'], metrics['ham'])


os.makedirs(CKPT_DIR, exist_ok=True)  # torch.save does not create missing directories

for epoch in range(NUM_EPOCHS):
    net.train()
    np.random.shuffle(train_data_list)
    for itr, data in enumerate(train_data_list):
        metrics = _run_sample(data, train=True)
        train_list.append(_record(metrics))
        print(_log_line('epoch', epoch, itr, len(train_data_list), metrics))

        optimizer.zero_grad()
        metrics['loss_edge'].backward()  # only the edge MSE drives training; BCE is logged only
        optimizer.step()

    print('Saving epoch {} ...\n'.format(epoch))
    torch.save(net.state_dict(), os.path.join(CKPT_DIR, 'epoch-{}.pth'.format(epoch)))

    net.eval()
    np.random.shuffle(val_data_list)
    for itr, val_data in enumerate(val_data_list):
        metrics = _run_sample(val_data, train=False)
        val_list.append(_record(metrics))
        print(_log_line('vepo', epoch, itr, len(val_data_list), metrics))

Set parameter Username
Set parameter LicenseID to value 2641140
Academic license - for non-commercial use only - expires 2026-03-24
epoch 0 it [0/861] pr [0.59-0.64] obj [-55.02/-124.00] mse 0.0775 mse edge 0.1035 ce 0.867 auc 0.513 ham 0.102


LU, pivots = torch.lu(A, compute_pivots)
should be replaced with
LU, pivots = torch.linalg.lu_factor(A, compute_pivots)
and
LU, pivots, info = torch.lu(A, compute_pivots, get_infos=True)
should be replaced with
LU, pivots, info = torch.linalg.lu_factor_ex(A, compute_pivots) (Triggered internally at ../aten/src/ATen/native/BatchLinearAlgebra.cpp:1990.)
  LU, pivots, infos = torch._lu_with_info(
Note that torch.linalg.lu_solve has its arguments reversed.
X = torch.lu_solve(B, LU, pivots)
should be replaced with
X = torch.linalg.lu_solve(LU, pivots, B) (Triggered internally at ../aten/src/ATen/native/BatchLinearAlgebra.cpp:2147.)
  G_invQ_GT = torch.bmm(G, G.transpose(1, 2).lu_solve(*Q_LU))


epoch 0 it [1/861] pr [0.59-0.64] obj [-29.52/-66.00] mse 0.1045 mse edge 0.1612 ce 0.856 auc 0.305 ham 0.216
epoch 0 it [2/861] pr [0.58-0.64] obj [-55.20/-122.00] mse 0.0816 mse edge 0.1096 ce 0.870 auc 0.391 ham 0.147
epoch 0 it [3/861] pr [0.59-0.68] obj [-76.73/-165.00] mse 0.0724 mse edge 0.0876 ce 0.913 auc 0.186 ham 0.130
epoch 0 it [4/861] pr [0.59-0.67] obj [-51.20/-109.00] mse 0.0921 mse edge 0.1286 ce 0.885 auc 0.287 ham 0.117
epoch 0 it [5/861] pr [0.59-0.71] obj [-29.87/-60.00] mse 0.1638 mse edge 0.2784 ce 0.882 auc 0.175 ham 0.265
epoch 0 it [6/861] pr [0.59-0.72] obj [-29.28/-56.00] mse 0.1761 mse edge 0.2827 ce 0.897 auc 0.068 ham 0.273
epoch 0 it [7/861] pr [0.58-0.70] obj [-34.33/-70.00] mse 0.1367 mse edge 0.2242 ce 0.876 auc 0.200 ham 0.244
epoch 0 it [8/861] pr [0.59-0.69] obj [-48.73/-98.00] mse 0.1054 mse edge 0.1535 ce 0.910 auc 0.163 ham 0.196
epoch 0 it [9/861] pr [0.59-0.66] obj [-42.96/-91.00] mse 0.1021 mse edge 0.1455 ce 0.875 auc 0.382 ham 0.131
epoch 0