In [19]:
%load_ext autoreload
%autoreload 2

import os, time
import cv2, random
import pickle, joblib
import sklearn.metrics
import numpy as np
np.set_printoptions(suppress=True)
import gurobipy as gp

import torch
import torch.nn as nn
import torch.nn.functional as F
torch.set_printoptions(sci_mode=False)
from torch_geometric.data import Data

from lib.tracking import Tracker
from lib.qpthlocal.qp import QPFunction, QPSolvers
from lib.utils import getIoU, interpolateTrack, interpolateTracks

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [20]:
os.listdir("reproduce_train_graph_MOT16/")

['MOT16-02.pkl',
 'MOT16-04.pkl',
 'MOT16-05.pkl',
 'MOT16-09.pkl',
 'MOT16-10.pkl',
 'MOT16-11.pkl',
 'MOT16-13.pkl']

In [21]:
random.seed(123)
np.random.seed(123)
torch.manual_seed(123)
torch.cuda.manual_seed(123)

train_data_list_full, val_data_list_full = [], []
# root_data_path = "data/train_data/"
root_data_path = "reproduce_train_graph_MOT16_reidFeature/"
for file in os.listdir(root_data_path):
    file_name = root_data_path + file
    with open(file_name, 'rb') as f:
        data_list = pickle.load(f)
        
    if file.startswith('MOT16-09') or file.startswith('MOT16-13'):
        val_data_list_full = val_data_list_full + data_list
    else:
        train_data_list_full = train_data_list_full + data_list
        
print('Total {} samples in training set, {} in validation set'.format(
                     len(train_data_list_full),len(val_data_list_full)))

train_data_list, val_data_list = [], []
for ind in range(len(train_data_list_full)):
    if train_data_list_full[ind].x.shape[0] < 200: #Avoid too big graph
        train_data_list.append(train_data_list_full[ind])
    else:
        continue
        
for ind in range(len(val_data_list_full)):
    if val_data_list_full[ind].x.shape[0] < 200:
        val_data_list.append(val_data_list_full[ind])
    else:
        continue
        
print('Used {} samples in training set, {} in validation set'.format(len(train_data_list), len(val_data_list)))

Total 1011 samples in training set, 319 in validation set
Used 379 samples in training set, 36 in validation set


In [23]:
class Net(nn.Module): 
    def __init__(self):
        super(Net, self).__init__()
        self.fc = nn.Sequential(nn.Linear(6,6), nn.ReLU(), nn.Linear(6,1))
        # self.fc = nn.Sequential(nn.Linear(6,1))
    def forward(self, data):
        # x = self.fc(data.edge_attr)
        x = self.fc(torch.tensor(data.edge_attr, dtype=torch.float32))
        x = nn.Sigmoid()(x)
        return x
    
net = Net()
tracker = Tracker(net)

In [None]:
gamma = 0.1
train_list, val_list = [], [] #Used for login loss, AUC, etc.
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3, weight_decay=1e-4, betas=(0.9, 0.999))

for epoch in range(0, 25):
    np.random.shuffle(train_data_list)
    for itr in range(len(train_data_list)):
        data = train_data_list[itr]
        A_eq, b_eq, A_ub, b_ub, x_gt = tracker.build_constraint_training(data)
        
        A, b = torch.from_numpy(A_eq).float(), torch.from_numpy(b_eq).float().flatten()
        G, h = torch.from_numpy(A_ub).float(), torch.from_numpy(b_ub).float().flatten()

        num_nodes = A.shape[0] // 2
        Q = gamma * torch.eye(A.shape[1])

        prob = net(data)
        prob = torch.clamp(prob, min=1e-7, max=1-1e-7)                       #Predicted matching probability
        prob_numpy = prob.detach().numpy()                                   #Predicted matching probability in np
        auc = sklearn.metrics.roc_auc_score(x_gt[num_nodes*3: ], prob_numpy) #Area under the ROC Curve

        c_det, c_entry, c_exit = -1 * torch.ones(num_nodes), torch.ones(num_nodes), torch.ones(num_nodes)
        c_pred = -1 * torch.log(prob).squeeze()
        c_pred = torch.cat([c_det, c_entry, c_exit, c_pred])

        model_params_quad = tracker.make_gurobi_model_tracking(G.numpy(),h.numpy(),A.numpy(),b.numpy(),Q.numpy())
        x = QPFunction(verbose=False, solver=QPSolvers.GUROBI, maxIter=50, 
                       model_params=model_params_quad)(Q, c_pred, G, h, A, b)

        loss = nn.MSELoss()(x.flatten(), torch.from_numpy(x_gt))
        loss_edge = nn.MSELoss()(x[:, num_nodes*3:].flatten(), torch.from_numpy(x_gt[num_nodes*3:]).float())

        const_cost = 1
        c_ = const_cost * (1 - x_gt[num_nodes*3:])
        c_gt = torch.cat([c_pred[:num_nodes*3], torch.from_numpy(c_).float()]) #Ground truth cost

        obj_gt = c_gt @ torch.from_numpy(x_gt).float()               #Ground truth objective value
        obj_pred = c_pred @ x.squeeze() #Predicted objective value, should be close to GT objective value after training

        bce = nn.BCELoss()(prob.flatten(), torch.from_numpy(x_gt[num_nodes*3:]).float())
        x_sol = tracker.linprog(c_pred.detach().numpy(), A_eq, b_eq, A_ub, b_ub)
        ham_loss = sklearn.metrics.hamming_loss(x_gt, x_sol)

        train_list.append((loss.item(), loss_edge.item(), auc, bce.item()))
        print('epoch {} it [{}/{}] pr [{:.2f}-{:.2f}] obj [{:.2f}/{:.2f}] mse {:.4f} mse edge {:.4f} ce {:.3f} \
auc {:.3f} ham {:.3f}'.format(epoch, itr, len(train_data_list), prob.min(), prob.max(), 
                              obj_pred, obj_gt, loss, loss_edge, bce, auc, ham_loss))
        
        optimizer.zero_grad()
        loss_edge.backward()
        #bce.backward()
        optimizer.step()
        
    print('Saving epoch {} ...\n'.format(epoch))
    torch.save(net.state_dict(), 'ckpt/mot16_reid_feature/epoch-{}.pth'.format(epoch))
    
    np.random.shuffle(val_data_list)
    for itr in range(len(val_data_list)):
        
        val_data = val_data_list[itr]
        A_eq, b_eq, A_ub, b_ub, x_gt = tracker.build_constraint_training(val_data)
        
        A, b = torch.from_numpy(A_eq).float(), torch.from_numpy(b_eq).float().flatten()
        G, h = torch.from_numpy(A_ub).float(), torch.from_numpy(b_ub).float().flatten()

        num_nodes = A.shape[0] // 2
        Q = gamma * torch.eye(A.shape[1])
        
        with torch.no_grad():
            prob = net(val_data)
            prob = torch.clamp(prob, min=1e-7, max=1-1e-7)
        
        prob_numpy = prob.detach().numpy() #matching probability in numpy
        auc = sklearn.metrics.roc_auc_score(x_gt[num_nodes*3: ], prob_numpy) #Area under the ROC Curve

        c_det, c_entry, c_exit = -1 * torch.ones(num_nodes), torch.ones(num_nodes), torch.ones(num_nodes)
        c_pred = -1 * torch.log(prob).squeeze()
        c_pred = torch.cat([c_det, c_entry, c_exit, c_pred])

        model_params_quad = tracker.make_gurobi_model_tracking(G.numpy(),h.numpy(),A.numpy(),b.numpy(),Q.numpy())
        x = QPFunction(verbose=False, solver=QPSolvers.GUROBI, maxIter=50, 
                       model_params=model_params_quad)(Q, c_pred, G, h, A, b)

        loss = nn.MSELoss()(x.flatten(), torch.from_numpy(x_gt))
        loss_edge = nn.MSELoss()(x[:, num_nodes*3:].flatten(), torch.from_numpy(x_gt[num_nodes*3:]).float())

        const_cost = 1
        c_ = const_cost * (1 - x_gt[num_nodes*3:])
        c_gt = torch.cat([c_pred[:num_nodes*3], torch.from_numpy(c_).float()]) #Ground truth cost

        obj_gt = c_gt @ torch.from_numpy(x_gt).float()               #Ground truth objective value
        obj_pred = c_pred @ x.squeeze() #Predicted objective value, should be close to GT objective value after training

        bce = nn.BCELoss()(prob.flatten(), torch.from_numpy(x_gt[num_nodes*3:]).float())
        x_sol = tracker.linprog(c_pred.detach().numpy(), A_eq, b_eq, A_ub, b_ub)
        ham_loss = sklearn.metrics.hamming_loss(x_gt, x_sol)
        val_list.append((loss.item(), loss_edge.item(), auc, bce.item()))
        print('vepo {} it [{}/{}] pr [{:.2f}-{:.2f}] obj [{:.2f}/{:.2f}] mse {:.4f} mse edge {:.4f} ce {:.3f} \
auc {:.3f} ham {:.3f}'.format(epoch, itr, len(val_data_list), prob.min(), prob.max(), 
                              obj_pred, obj_gt, loss, loss_edge, bce, auc, ham_loss))

epoch 0 it [0/379] pr [0.59-0.60] obj [-32.98/-76.00] mse 0.0931 mse edge 0.1452 ce 0.829 auc 0.758 ham 0.100
epoch 0 it [1/379] pr [0.58-0.61] obj [-59.01/-135.00] mse 0.0693 mse edge 0.0904 ce 0.858 auc 0.764 ham 0.081
epoch 0 it [2/379] pr [0.58-0.61] obj [-67.53/-156.00] mse 0.0641 mse edge 0.0790 ce 0.868 auc 0.561 ham 0.096
epoch 0 it [3/379] pr [0.58-0.61] obj [-55.91/-126.00] mse 0.0759 mse edge 0.1029 ce 0.860 auc 0.531 ham 0.143
epoch 0 it [4/379] pr [0.58-0.66] obj [-66.04/-147.00] mse 0.0681 mse edge 0.0882 ce 0.875 auc 0.590 ham 0.077
epoch 0 it [5/379] pr [0.59-0.60] obj [-47.17/-104.00] mse 0.0795 mse edge 0.1070 ce 0.848 auc 0.803 ham 0.097
epoch 0 it [6/379] pr [0.59-0.62] obj [-49.70/-114.00] mse 0.0788 mse edge 0.1079 ce 0.856 auc 0.645 ham 0.086
epoch 0 it [7/379] pr [0.58-0.62] obj [-62.19/-140.00] mse 0.0690 mse edge 0.0911 ce 0.864 auc 0.637 ham 0.091
epoch 0 it [8/379] pr [0.59-0.61] obj [-70.16/-158.00] mse 0.0614 mse edge 0.0752 ce 0.867 auc 0.774 ham 0.069


# SHAP analysis

In [58]:
class ShapNet(nn.Module): 
    def __init__(self):
        super(ShapNet, self).__init__()
        self.fc = nn.Sequential(nn.Linear(6,6), nn.ReLU(), nn.Linear(6,1))
    def forward(self, x):
        # x = self.fc(data.edge_attr)
        x = self.fc(x)
        x = nn.Sigmoid()(x)
        return x

In [62]:
net = ShapNet()
net.load_state_dict(torch.load('ckpt/original/epoch_20.pth'))

  net.load_state_dict(torch.load('ckpt/original/epoch_20.pth'))


<All keys matched successfully>

In [78]:
with torch.no_grad():
    out = net(torch.tensor(data_list[0].edge_attr, dtype=torch.float32))

In [77]:
# explainer = shap.DeepExplainer(net, torch.tensor(data_list[0].edge_attr, dtype=torch.float32))
shap_values = explainer.shap_values(torch.tensor(data_list[0].edge_attr[:100], dtype=torch.float32), check_additivity=True)

AssertionError: The SHAP explanations do not sum up to the model's output! This is either because of a rounding error or because an operator in your computation graph was not fully supported. If the sum difference of %f is significant compared to the scale of your model outputs, please post as a github issue, with a reproducible example so we can debug it. Used framework: pytorch - Max. diff: 0.643880556570366 - Tolerance: 0.01

In [79]:
out

tensor([[    0.9967],
        [    0.0002],
        [    0.0000],
        ...,
        [    0.0016],
        [    0.0010],
        [    0.9962]])

In [55]:
net = ShapNet()
net.load_state_dict(torch.load('ckpt/mot1/epoch-20.pth'))
net.fc[0].weight

  net.load_state_dict(torch.load('ckpt/mot16_simple/epoch-20.pth'))


Parameter containing:
tensor([[-0.3524, -6.1803,  2.2318, -5.5801,  4.4129,  2.2822]],
       requires_grad=True)

In [36]:
net

ShapNet(
  (fc): Sequential(
    (0): Linear(in_features=6, out_features=6, bias=True)
    (1): ReLU()
    (2): Linear(in_features=6, out_features=1, bias=True)
  )
)