In [18]:
%load_ext autoreload
%autoreload 2

import os, time
import cv2, random
import pickle, joblib
import sklearn.metrics
import numpy as np
np.set_printoptions(suppress=True)
import gurobipy as gp

import torch
import torch.nn as nn
import torch.nn.functional as F
torch.set_printoptions(sci_mode=False)
from torch_geometric.data import Data

from lib.tracking import Tracker
from lib.qpthlocal.qp import QPFunction, QPSolvers
from lib.utils import getIoU, interpolateTrack, interpolateTracks

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


[autoreload of lib.qpthlocal.qp failed: Traceback (most recent call last):
  File "/home/khanh/miniconda3/envs/LPT/lib/python3.9/site-packages/IPython/extensions/autoreload.py", line 276, in check
    superreload(m, reload, self.old_objects)
  File "/home/khanh/miniconda3/envs/LPT/lib/python3.9/site-packages/IPython/extensions/autoreload.py", line 500, in superreload
    update_generic(old_obj, new_obj)
  File "/home/khanh/miniconda3/envs/LPT/lib/python3.9/site-packages/IPython/extensions/autoreload.py", line 397, in update_generic
    update(a, b)
  File "/home/khanh/miniconda3/envs/LPT/lib/python3.9/site-packages/IPython/extensions/autoreload.py", line 349, in update_class
    if update_generic(old_obj, new_obj):
  File "/home/khanh/miniconda3/envs/LPT/lib/python3.9/site-packages/IPython/extensions/autoreload.py", line 397, in update_generic
    update(a, b)
  File "/home/khanh/miniconda3/envs/LPT/lib/python3.9/site-packages/IPython/extensions/autoreload.py", line 349, in update_clas

In [2]:
# Seed every RNG used by the notebook so that shuffling and weight
# initialisation are repeatable across "Restart & Run All".
random.seed(123)
np.random.seed(123)
torch.manual_seed(123)
torch.cuda.manual_seed(123)

train_data_dir = 'data/train_data/'
val_prefixes = ('MOT16-09', 'MOT16-13')  # files with these prefixes are held out for validation
max_graph_nodes = 200                    # samples whose `x` has this many rows or more are dropped

train_data_list_full, val_data_list_full = [], []
# os.listdir() returns entries in arbitrary order; sort them so the sample
# order (and therefore every seeded shuffle that follows) is reproducible.
for file in sorted(os.listdir(train_data_dir)):
    file_name = train_data_dir + file
    # NOTE(security): pickle.load can execute arbitrary code; only load files
    # from a trusted source.
    with open(file_name, 'rb') as f:
        data_list = pickle.load(f)

    if file.startswith(val_prefixes):
        val_data_list_full.extend(data_list)
    else:
        train_data_list_full.extend(data_list)

print('Total {} samples in training set, {} in validation set'.format(
                     len(train_data_list_full),len(val_data_list_full)))

# Avoid too big graphs: keep only samples with fewer than `max_graph_nodes` rows in `x`.
train_data_list = [sample for sample in train_data_list_full
                   if sample.x.shape[0] < max_graph_nodes]
val_data_list = [sample for sample in val_data_list_full
                 if sample.x.shape[0] < max_graph_nodes]

print('Used {} samples in training set, {} in validation set'.format(len(train_data_list), len(val_data_list)))

Total 748 samples in training set, 319 in validation set
Used 460 samples in training set, 208 in validation set


In [19]:
class Net(nn.Module):
    """Small MLP that scores every edge of a graph sample.

    The network reads ``data.edge_attr`` (6 features per row, as required by
    the first linear layer) and returns one sigmoid-activated value per row,
    i.e. a number in [0, 1].
    """

    def __init__(self):
        super().__init__()
        # Keep the attribute name `fc` and the layer order so that saved
        # state_dicts (fc.0.*, fc.2.*) stay loadable.
        layers = [nn.Linear(6, 6), nn.ReLU(), nn.Linear(6, 1)]
        self.fc = nn.Sequential(*layers)

    def forward(self, data):
        """Return the sigmoid of the MLP output for ``data.edge_attr``."""
        logits = self.fc(data.edge_attr)
        return torch.sigmoid(logits)
    
# Instantiate the edge-scoring network and hand it to the project's Tracker.
# The training cell uses `tracker` to build the constraint matrices
# (build_constraint_training), to set up the Gurobi model
# (make_gurobi_model_tracking) and to solve the LP (linprog).
net = Net()
tracker = Tracker(net)

In [None]:
gamma = 0.1                    # scale of the quadratic term Q = gamma * I passed to the QP layer
num_epochs = 25
const_cost = 1                 # cost given to edges whose ground-truth value is 0 in the GT objective
ckpt_dir = 'ckpt/qp'
train_list, val_list = [], []  # Used for logging loss, AUC, etc.
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3, weight_decay=1e-4, betas=(0.9, 0.999))

LOG_TEMPLATE = ('{} {} it [{}/{}] pr [{:.2f}-{:.2f}] obj [{:.2f}/{:.2f}] mse {:.4f} '
                'mse edge {:.4f} ce {:.3f} auc {:.3f} ham {:.3f}')


def _run_sample(sample, training):
    """Run one graph sample through the network and the QP layer.

    Uses the notebook-level ``net``, ``tracker``, ``gamma`` and ``const_cost``.

    Args:
        sample: one element of ``train_data_list`` / ``val_data_list``.
        training: when True the network forward pass records gradients so
            that ``loss_edge`` can be back-propagated; when False it runs
            under no-grad.

    Returns:
        dict with the loss tensors (``loss``, ``loss_edge``, ``bce``) and the
        scalar metrics (``auc``, ``ham``, ``prob_min``, ``prob_max``,
        ``obj_pred``, ``obj_gt``) used for logging.
    """
    A_eq, b_eq, A_ub, b_ub, x_gt = tracker.build_constraint_training(sample)

    A, b = torch.from_numpy(A_eq).float(), torch.from_numpy(b_eq).float().flatten()
    G, h = torch.from_numpy(A_ub).float(), torch.from_numpy(b_ub).float().flatten()
    # Cast once so every loss below compares float32 against float32.
    x_gt_t = torch.from_numpy(x_gt).float()

    num_nodes = A.shape[0] // 2
    edge_start = num_nodes * 3  # the first 3 * num_nodes variables are det / entry / exit
    Q = gamma * torch.eye(A.shape[1])

    with torch.set_grad_enabled(training):
        prob = net(sample)
        prob = torch.clamp(prob, min=1e-7, max=1-1e-7)                   # Predicted matching probability
    auc = sklearn.metrics.roc_auc_score(x_gt[edge_start:], prob.detach().numpy())  # Area under the ROC Curve

    c_det, c_entry, c_exit = -1 * torch.ones(num_nodes), torch.ones(num_nodes), torch.ones(num_nodes)
    # flatten() rather than squeeze(): keeps a 1-D tensor even for a single edge.
    c_edge = -1 * torch.log(prob).flatten()
    c_pred = torch.cat([c_det, c_entry, c_exit, c_edge])

    model_params_quad = tracker.make_gurobi_model_tracking(G.numpy(), h.numpy(), A.numpy(), b.numpy(), Q.numpy())
    x = QPFunction(verbose=False, solver=QPSolvers.GUROBI, maxIter=50,
                   model_params=model_params_quad)(Q, c_pred, G, h, A, b)

    loss = nn.MSELoss()(x.flatten(), x_gt_t)
    loss_edge = nn.MSELoss()(x[:, edge_start:].flatten(), x_gt_t[edge_start:])

    c_gt = torch.cat([c_pred[:edge_start], const_cost * (1 - x_gt_t[edge_start:])])  # Ground truth cost
    obj_gt = c_gt @ x_gt_t             # Ground truth objective value
    obj_pred = c_pred @ x.flatten()    # Predicted objective value, should be close to GT objective value after training

    bce = nn.BCELoss()(prob.flatten(), x_gt_t[edge_start:])
    x_sol = tracker.linprog(c_pred.detach().numpy(), A_eq, b_eq, A_ub, b_ub)
    ham_loss = sklearn.metrics.hamming_loss(x_gt, x_sol)

    return {
        'loss': loss, 'loss_edge': loss_edge, 'bce': bce,
        'auc': auc, 'ham': ham_loss,
        'prob_min': prob.min().item(), 'prob_max': prob.max().item(),
        'obj_pred': obj_pred.item(), 'obj_gt': obj_gt.item(),
    }


def _log_sample(tag, epoch, itr, total, stats):
    """Print one progress line and return the (mse, mse_edge, auc, bce) record."""
    record = (stats['loss'].item(), stats['loss_edge'].item(), stats['auc'], stats['bce'].item())
    print(LOG_TEMPLATE.format(tag, epoch, itr, total, stats['prob_min'], stats['prob_max'],
                              stats['obj_pred'], stats['obj_gt'], record[0], record[1],
                              record[3], stats['auc'], stats['ham']))
    return record


# Create the checkpoint directory up front; otherwise torch.save would fail
# only after a full epoch of training has already been spent.
os.makedirs(ckpt_dir, exist_ok=True)

for epoch in range(0, num_epochs):
    net.train()
    np.random.shuffle(train_data_list)
    for itr, data in enumerate(train_data_list):
        stats = _run_sample(data, training=True)
        train_list.append(_log_sample('epoch', epoch, itr, len(train_data_list), stats))

        # Optimise the edge-only MSE of the QP solution (BCE is logged but not trained on).
        optimizer.zero_grad()
        stats['loss_edge'].backward()
        optimizer.step()

    print('Saving epoch {} ...\n'.format(epoch))
    torch.save(net.state_dict(), os.path.join(ckpt_dir, 'epoch-{}.pth'.format(epoch)))

    net.eval()
    np.random.shuffle(val_data_list)
    for itr, val_data in enumerate(val_data_list):
        stats = _run_sample(val_data, training=False)
        val_list.append(_log_sample('vepo', epoch, itr, len(val_data_list), stats))

epoch 0 it [0/460] pr [0.35-0.45] obj [-2.08/-105.00] mse 0.1612 mse edge 0.1341 ce 0.597 auc 0.033 ham 0.184


LU, pivots = torch.lu(A, compute_pivots)
should be replaced with
LU, pivots = torch.linalg.lu_factor(A, compute_pivots)
and
LU, pivots, info = torch.lu(A, compute_pivots, get_infos=True)
should be replaced with
LU, pivots, info = torch.linalg.lu_factor_ex(A, compute_pivots) (Triggered internally at ../aten/src/ATen/native/BatchLinearAlgebra.cpp:1990.)
  LU, pivots, infos = torch._lu_with_info(
Note that torch.linalg.lu_solve has its arguments reversed.
X = torch.lu_solve(B, LU, pivots)
should be replaced with
X = torch.linalg.lu_solve(LU, pivots, B) (Triggered internally at ../aten/src/ATen/native/BatchLinearAlgebra.cpp:2147.)
  G_invQ_GT = torch.bmm(G, G.transpose(1, 2).lu_solve(*Q_LU))


epoch 0 it [1/460] pr [0.35-0.46] obj [-2.41/-82.00] mse 0.1750 mse edge 0.1479 ce 0.609 auc 0.020 ham 0.210
epoch 0 it [2/460] pr [0.36-0.47] obj [-4.00/-111.00] mse 0.1415 mse edge 0.1251 ce 0.611 auc 0.000 ham 0.178
epoch 0 it [3/460] pr [0.30-0.47] obj [-3.40/-123.00] mse 0.1305 mse edge 0.0985 ce 0.574 auc 0.079 ham 0.159
epoch 0 it [4/460] pr [0.36-0.47] obj [-2.23/-66.00] mse 0.2053 mse edge 0.2137 ce 0.648 auc 0.000 ham 0.254
epoch 0 it [5/460] pr [0.37-0.47] obj [-4.18/-102.00] mse 0.1499 mse edge 0.1375 ce 0.616 auc 0.000 ham 0.190
epoch 0 it [6/460] pr [0.37-0.46] obj [-1.01/-40.00] mse 0.2760 mse edge 0.2701 ce 0.663 auc 0.024 ham 0.318
epoch 0 it [7/460] pr [0.21-0.47] obj [-3.30/-150.00] mse 0.1234 mse edge 0.0873 ce 0.511 auc 0.499 ham 0.141
epoch 0 it [8/460] pr [0.23-0.47] obj [-3.59/-149.00] mse 0.1245 mse edge 0.0894 ce 0.516 auc 0.484 ham 0.143
epoch 0 it [9/460] pr [0.30-0.48] obj [-3.15/-80.00] mse 0.1720 mse edge 0.1529 ce 0.615 auc 0.044 ham 0.215
epoch 0 it [10

In [None]:
# Only `from torch_geometric.data import Data` is executed earlier, which does
# not bind the package name itself, so import it here before reading the version.
import torch_geometric
torch_geometric.__version__