# Import libraries

In [1]:
import warnings
warnings.filterwarnings('ignore')
import sys
sys.path.insert(0, 'lib/reid')
import torchreid

import numpy as np
np.set_printoptions(suppress=True)     #I hate scientific numbers!
import datetime
import os
import pickle
import random
import time

import cv2
import joblib
import sklearn
import sklearn.metrics  # `import sklearn` alone does not load the metrics submodule

import torch
import torch.nn as nn
import torch.nn.functional as F
torch.set_printoptions(sci_mode=False) #I hate scientific numbers!
import torchvision.models as models
from torch_geometric.data import Data

import gurobipy as gp
from lib.train import build_constraint, build_constraint_torch, make_gurobi_model_tracking, _remove_redundant_rows
from lib.inference import forwardLP
from lib.qpthlocal.qp import QPFunction, QPSolvers, make_gurobi_model
#from lib.gnn import ReID
#from lib.qpthlocal.qp_cuda import QPFunction, QPSolvers, make_gurobi_model #If we want to fine-tune reid
from lib.utils import getIoU

import matplotlib.pyplot as plt
plt.style.use('bmh')
%matplotlib inline

torch.random.manual_seed(123)
np.random.seed(123)

# Prepare training and validation data

In [2]:
# Directory holding one pickled list of graph samples per sequence file.
TRAIN_DATA_DIR = 'data/train_data/'
# Files whose names start with one of these prefixes form the validation split.
VAL_SEQUENCES = ('MOT16-09', 'MOT16-13')
# Samples whose `x` has MAX_NODES or more rows are dropped
# (presumably to keep the per-sample QP tractable — TODO confirm).
MAX_NODES = 200

train_data_list_full = []
val_data_list_full = []
# Sort the listing: os.listdir order is filesystem-dependent, and the seeded
# np.random.shuffle used during training is only reproducible if the input order is.
for file in sorted(os.listdir(TRAIN_DATA_DIR)):
    file_name = os.path.join(TRAIN_DATA_DIR, file)
    # NOTE: pickle.load can execute arbitrary code — only load trusted files.
    with open(file_name, 'rb') as f:
        data_list = pickle.load(f)

    # extend() appends in place; `lst = lst + data_list` copied the list every file.
    if file.startswith(VAL_SEQUENCES):
        val_data_list_full.extend(data_list)
    else:
        train_data_list_full.extend(data_list)

print('{} samples in training set, {} in validation set'.format(
                     len(train_data_list_full), len(val_data_list_full)))

# Keep only the samples below the size limit.
train_data_list = [sample for sample in train_data_list_full if sample.x.shape[0] < MAX_NODES]
val_data_list = [sample for sample in val_data_list_full if sample.x.shape[0] < MAX_NODES]

print('{} samples in training set, {} in validation set'.format(len(train_data_list), len(val_data_list)))

748 samples in training set, 319 in validation set
460 samples in training set, 208 in validation set


In [3]:
class Model(nn.Module):
    """Small MLP that maps each row of ``data.edge_attr`` to a probability in (0, 1).

    Args:
        in_features: width of ``data.edge_attr`` (default 6, the original hard-coded value).
        hidden: width of the hidden layer (default 6).

    The layer layout (``fc.0`` / ``fc.2``) is unchanged, so checkpoints saved
    from ``state_dict()`` by the previous version still load.
    """
    def __init__(self, in_features=6, hidden=6):
        super(Model, self).__init__()
        self.fc = nn.Sequential(nn.Linear(in_features, hidden), nn.ReLU(), nn.Linear(hidden, 1))

    def forward(self, data):
        """Return sigmoid scores of shape (..., 1) for ``data.edge_attr`` of shape (..., in_features).

        Only the ``edge_attr`` attribute of ``data`` is read.
        """
        # torch.sigmoid avoids constructing a new nn.Sigmoid module on every call.
        return torch.sigmoid(self.fc(data.edge_attr))

In [None]:
model = Model()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4, betas=(0.9, 0.999))

NUM_EPOCHS = 24    # epochs are numbered 1..NUM_EPOCHS (also used in checkpoint file names)
CKPT_DIR = 'ckpt'
PROB_EPS = 1e-7    # clamp range for edge probabilities: keeps -log(prob) and BCE finite
CONST_COST = 1     # ground-truth cost of an edge absent from x_gt (edges present cost 0)
QP_MAX_ITER = 50

train_list, val_list = [], []   # per-sample (loss, loss_edge, auc, bce), appended in order
gamma = 0.1                     # Q = gamma * I, the quadratic term of the QP

TRAIN_LOG = ('Train Epoch {} iter {}/{}, Obj {:.2f}/{:.2f}, mse {:.4f} mse edge {:.4f} '
             'ce {:.3f} auc {:.3f} hamming {:.3f}')
VAL_LOG = ('Val Epoch {}, iter {}/{}, Obj {:.2f}/{:.2f}, mse {:.3f} mse edge {:.3f} '
           'ce {:.3f} auc {:.3f} hamming {:.3f}')


def _run_sample(data, train):
    """Solve the tracking QP for one graph sample and return its metrics.

    Builds the constraints for ``data``, turns the model's edge probabilities
    into costs ``-log(p)``, solves the QP with Gurobi through ``QPFunction`` and
    decodes a solution with ``forwardLP``. When ``train`` is True, ``loss_edge``
    is back-propagated and ``optimizer`` takes a step. Otherwise the model
    forward pass runs without gradients.

    Returns:
        (loss, loss_edge, auc, bce, obj_pred, obj_gt, ham_dist). ``auc`` is NaN
        when the sample's edge labels contain only one class.
    """
    A_eq, b_eq, A_ub, b_ub, x_gt, _ = build_constraint(data, 1)
    A, b, G, h = build_constraint_torch(A_eq, b_eq, A_ub, b_ub)
    # Assumes two equality rows per detection node (TODO confirm in build_constraint).
    num_nodes = A_eq.shape[0] // 2
    # The variable vector is [det, entry, exit] (num_nodes each) followed by the edges.
    num_node_vars = 3 * num_nodes
    Q = gamma * torch.eye(G.shape[1])

    model.train(train)
    with torch.set_grad_enabled(train):
        prob = torch.clamp(model(data), min=PROB_EPS, max=1 - PROB_EPS)

    x_gt_edge = x_gt[num_node_vars:]
    y_edge = x_gt_edge.squeeze()
    prob_numpy = prob.detach().squeeze().numpy()
    # roc_auc_score raises ValueError when only one class is present (e.g. every
    # candidate edge is a true link); record NaN instead of aborting the run.
    if np.unique(y_edge).size > 1:
        auc = sklearn.metrics.roc_auc_score(y_edge, prob_numpy)
    else:
        auc = float('nan')

    c_det = -1 * torch.ones(num_nodes)
    c_entry = torch.ones(num_nodes)
    c_exit = torch.ones(num_nodes)
    # reshape(-1) instead of squeeze() so a graph with a single edge stays 1-D.
    c_edge = -1 * torch.log(prob).reshape(-1)
    c_pred = torch.cat([c_det, c_entry, c_exit, c_edge])

    model_params_quad = make_gurobi_model_tracking(G.numpy(), h.numpy(), A.numpy(), b.numpy(), Q.numpy())
    x = QPFunction(verbose=False, solver=QPSolvers.GUROBI, maxIter=QP_MAX_ITER,
                   model_params=model_params_quad)(Q, c_pred, G, h, A, b)

    # x_gt is presumably a column vector; .t() lines it up with the QP output x.
    x_gt_t = torch.from_numpy(x_gt).float()
    loss = nn.MSELoss()(x, x_gt_t.t())
    loss_edge = nn.MSELoss()(x[:, num_node_vars:], x_gt_t[num_node_vars:].t())

    # Ground-truth objective: node costs as predicted, CONST_COST for every edge
    # absent from x_gt and 0 for edges present in it.
    edge_cost_gt = torch.from_numpy(CONST_COST * (1 - x_gt_edge)).float().reshape(-1)
    c_gt = torch.cat([c_pred[:num_node_vars], edge_cost_gt])
    obj_gt = c_gt @ x_gt_t.reshape(-1)
    obj_pred = c_pred @ x.reshape(-1)

    bce = nn.BCELoss()(prob, torch.from_numpy(x_gt_edge).float())
    x_sol = forwardLP(c_pred.detach().numpy(), A_eq, b_eq, A_ub, b_ub)
    ham_dist = sklearn.metrics.hamming_loss(x_gt, x_sol)

    stats = (loss.item(), loss_edge.item(), auc, bce.item(),
             obj_pred.item(), obj_gt.item(), ham_dist)

    if train:
        # Only the edge MSE drives training; BCE is logged for monitoring only.
        optimizer.zero_grad()
        loss_edge.backward()
        optimizer.step()
    return stats


def _run_split(data_list, epoch, train, history, log_template):
    """Run ``_run_sample`` over ``data_list``, logging each sample and appending
    ``(loss, loss_edge, auc, bce)`` to ``history``."""
    for iteration, data in enumerate(data_list):
        loss, loss_edge, auc, bce, obj_pred, obj_gt, ham_dist = _run_sample(data, train)
        history.append((loss, loss_edge, auc, bce))
        print(log_template.format(epoch, iteration, len(data_list), obj_pred, obj_gt,
                                  loss, loss_edge, bce, auc, ham_dist))


os.makedirs(CKPT_DIR, exist_ok=True)  # torch.save cannot create missing directories
for epoch in range(1, NUM_EPOCHS + 1):
    np.random.shuffle(train_data_list)
    _run_split(train_data_list, epoch, True, train_list, TRAIN_LOG)
    torch.save(model.state_dict(), os.path.join(CKPT_DIR, 'epoch_{}.pth'.format(epoch)))

    # Validation order does not change the metrics; the shuffle is kept so the
    # global NumPy RNG sequence, and hence later training orders, stay the same.
    np.random.shuffle(val_data_list)
    _run_split(val_data_list, epoch, False, val_list, VAL_LOG)

In [None]:
num_train = len(train_data_list)
num_val = len(val_data_list)

# Each row of train_list / val_list is (loss, loss_edge, auc, bce) for one sample.
train_array = np.array(train_list)
val_array = np.array(val_list)


def per_epoch_means(history, samples_per_epoch, num_epochs):
    """Average a flat per-sample history into one row of column means per epoch.

    ``history`` is expected to hold ``samples_per_epoch`` consecutive 4-column
    rows per epoch. NaN entries (e.g. an AUC that was undefined for a
    single-class sample) are ignored via ``np.nanmean``. Epochs with no rows
    yield NaN.

    Returns:
        Array of shape (num_epochs, 4).
    """
    rows = np.asarray(history, dtype=float).reshape(-1, 4)
    return np.array([
        np.nanmean(rows[i * samples_per_epoch:(i + 1) * samples_per_epoch], axis=0)
        for i in range(num_epochs)
    ])


# `epoch` is the last epoch number reached by the training loop.
train_loss_list, train_loss_edge_list, train_auc_list, train_bce_list = (
    list(column) for column in per_epoch_means(train_array, num_train, epoch).T)
val_loss_list, val_loss_edge_list, val_auc_list, val_bce_list = (
    list(column) for column in per_epoch_means(val_array, num_val, epoch).T)

In [None]:
# One panel per metric: (title / y-label, training curve, validation curve).
panels = [
    ('MSE', train_loss_list, val_loss_list),
    ('MSE Edge', train_loss_edge_list, val_loss_edge_list),
    ('AUC', train_auc_list, val_auc_list),
    ('BCE', train_bce_list, val_bce_list),
]

fig, axes = plt.subplots(1, len(panels), figsize=(18, 3))
plt.subplots_adjust(wspace=0.3, hspace=0)
for ax, (metric, train_curve, val_curve) in zip(axes, panels):
    # Epochs start at 1 in the training loop and checkpoint names, so label them that way.
    ax.plot(np.arange(1, len(train_curve) + 1), train_curve, color='b', label='train')
    ax.plot(np.arange(1, len(val_curve) + 1), val_curve, color='r', label='val')
    ax.set_title(metric)
    ax.set_xlabel('Epochs')
    ax.set_ylabel(metric)
axes[0].legend()  # without it, nothing says which colour is train and which is val
plt.show()

In [None]:
# Best validation epoch per metric. np.argmin returns a 0-based index, while the
# training loop numbers epochs (and the ckpt/epoch_{N}.pth files) from 1.
best_val_epoch = {
    'bce': int(np.argmin(val_bce_list)) + 1,
    'mse': int(np.argmin(val_loss_list)) + 1,
    'mse_edge': int(np.argmin(val_loss_edge_list)) + 1,
}
best_val_epoch