In [None]:
import sys
sys.path.insert(0, 'lib/deep-person-reid')
import torchreid

import os
import cv2
import random
import pickle
import joblib
import time
import sklearn.metrics
import numpy as np
np.set_printoptions(suppress=True)
import gurobipy as gp

import torch
import torch.nn as nn
import torch.nn.functional as F
torch.set_printoptions(sci_mode=False)
from torch_geometric.data import Data

from lib.tracking import Tracker
from lib.qpthlocal.qp import QPFunction, QPSolvers
from lib.utils import getIoU, interpolateTrack, interpolateTracks
#from lib.qpthlocal.qp_cuda import QPFunction, QPSolvers, make_gurobi_model #If we want to fine-tune reid

# Prepare data used for training

In [None]:
# Seed every RNG used in this notebook (the shuffles in the training cell rely on np.random).
random.seed(123)
np.random.seed(123)
torch.random.manual_seed(123)

DATA_DIR = 'data/train_data/'
VAL_SEQUENCES = ('MOT16-09', 'MOT16-13')  # Files with these prefixes go to the validation set
MAX_NODES = 200                           # Graphs with at least this many nodes are dropped (avoid too big graphs)

train_data_list_full, val_data_list_full = [], []
# os.listdir order is filesystem-dependent; sorting makes the sample order -- and hence the
# seeded shuffles during training -- reproducible across machines.
for fname in sorted(os.listdir(DATA_DIR)):
    # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
    with open(os.path.join(DATA_DIR, fname), 'rb') as f:
        data_list = pickle.load(f)

    # extend() appends in place instead of rebuilding the whole list for every file.
    if fname.startswith(VAL_SEQUENCES):
        val_data_list_full.extend(data_list)
    else:
        train_data_list_full.extend(data_list)

print('Total {} samples in training set, {} in validation set'.format(
    len(train_data_list_full), len(val_data_list_full)))

# Keep only graphs whose x.shape[0] (taken here as the node count) is below MAX_NODES.
train_data_list = [d for d in train_data_list_full if d.x.shape[0] < MAX_NODES]
val_data_list = [d for d in val_data_list_full if d.x.shape[0] < MAX_NODES]

print('Used {} samples in training set, {} in validation set'.format(len(train_data_list), len(val_data_list)))

# Perform training

In [None]:
class Net(nn.Module):
    """Edge scorer: maps each edge feature vector to a matching probability.

    A two-layer MLP (Linear -> ReLU -> Linear) applied to ``data.edge_attr``,
    followed by a sigmoid.

    Args:
        in_features: Size of each edge feature vector (default 6, as originally).
        hidden_features: Width of the hidden layer (default 6, as originally).

    The submodule layout (``fc.0`` / ``fc.2``) is unchanged, so state dicts saved
    from the original definition still load.
    """

    def __init__(self, in_features=6, hidden_features=6):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(in_features, hidden_features),
            nn.ReLU(),
            nn.Linear(hidden_features, 1),
        )

    def forward(self, data):
        """Return per-edge matching probabilities.

        Args:
            data: Graph object exposing ``edge_attr``; assumed to be a
                (num_edges, in_features) float tensor -- TODO confirm.

        Returns:
            Tensor of shape (num_edges, 1) with values in [0, 1] for 2-D input.
        """
        # torch.sigmoid avoids constructing a new nn.Sigmoid module on every call.
        return torch.sigmoid(self.fc(data.edge_attr))
    
net = Net()
tracker = Tracker(net)
train_list, val_list = [], []  # Per-sample (mse, mse_edge, auc, bce) tuples, used for logging loss, AUC, etc.

gamma = 0.1        # Weight of the quadratic term Q = gamma * I in the differentiable QP
num_epochs = 24    # Epochs run from 1 to num_epochs inclusive (used in checkpoint names)
ckpt_dir = 'ckpt'
os.makedirs(ckpt_dir, exist_ok=True)  # torch.save does not create missing directories

optimizer = torch.optim.Adam(tracker.net.parameters(), lr=1e-3, weight_decay=1e-4, betas=(0.9, 0.999))


def _safe_auc(y_true, y_score):
    """ROC AUC of `y_score` against binary labels `y_true`.

    sklearn's roc_auc_score raises ValueError when `y_true` contains a single class
    (e.g. a graph whose edge labels are all 0 or all 1); return NaN instead of
    aborting the whole training run.
    """
    if len(np.unique(y_true)) < 2:
        return float('nan')
    return sklearn.metrics.roc_auc_score(y_true, y_score)


def _run_sample(tracker, data, gamma, requires_grad):
    """Score one graph with the network, solve the QP and compute all losses/metrics.

    Args:
        tracker: Tracker wrapping the network being trained.
        data: One graph sample from train_data_list / val_data_list.
        gamma: Weight of the quadratic term Q = gamma * I.
        requires_grad: Record autograd history for the network call (True for
            training, False for validation).

    Returns:
        dict with tensors 'loss' (MSE over all QP variables), 'loss_edge' (MSE over
        the edge variables), 'bce' (BCE of the edge probabilities) and floats 'auc',
        'ham_loss' (Hamming loss of the LP solution), 'obj_pred', 'obj_gt'.
    """
    A_eq, b_eq, A_ub, b_ub, x_gt, _ = tracker.build_constraint(data, max_frame_gap=1)
    A, b, G, h = tracker.build_constraint_torch(A_eq, b_eq, A_ub, b_ub)
    # NOTE(review): assumes A_eq has two equality rows per node -- confirm in Tracker.
    num_nodes = A_eq.shape[0] // 2
    n_node_vars = 3 * num_nodes  # detection, entry and exit variables precede the edge variables
    Q = gamma * torch.eye(G.shape[1])

    with torch.set_grad_enabled(requires_grad):
        prob = tracker.net(data)
        prob = torch.clamp(prob, min=1e-7, max=1 - 1e-7)  # Predicted matching probability; keeps log finite
    edge_gt = x_gt[n_node_vars:]
    auc = _safe_auc(edge_gt, prob.detach().numpy())  # Area under the ROC curve

    # Cost vector: detections cost -1, entries/exits cost 1, edges cost -log(prob).
    c_det = -1 * torch.ones(num_nodes)
    c_entry = torch.ones(num_nodes)
    c_exit = torch.ones(num_nodes)
    c_edge = -1 * torch.log(prob).reshape(-1)  # reshape, not squeeze: a single-edge graph must stay 1-D
    c_pred = torch.cat([c_det, c_entry, c_exit, c_edge])

    model_params_quad = tracker.make_gurobi_model_tracking(G.numpy(), h.numpy(), A.numpy(),
                                                           b.numpy(), Q.numpy())
    x = QPFunction(verbose=False, solver=QPSolvers.GUROBI, maxIter=50,
                   model_params=model_params_quad)(Q, c_pred, G, h, A, b)

    x_gt_t = torch.from_numpy(x_gt).float()
    edge_gt_t = torch.from_numpy(edge_gt).float()
    loss = nn.MSELoss()(x, x_gt_t.t())
    loss_edge = nn.MSELoss()(x[:, n_node_vars:], edge_gt_t.t())
    bce = nn.BCELoss()(prob, edge_gt_t)

    # Ground-truth cost: node variables keep their cost, true edges cost 0, false edges const_cost.
    const_cost = 1
    c_edge_gt = const_cost * (1 - edge_gt).reshape(-1)
    c_gt = torch.cat([c_pred[:n_node_vars], torch.from_numpy(c_edge_gt).float()])

    obj_gt = c_gt @ x_gt_t.reshape(-1)   # Ground-truth objective value
    obj_pred = c_pred @ x.reshape(-1)    # Predicted objective; should approach obj_gt after training

    x_sol = tracker.linprog(c_pred.detach().numpy(), A_eq, b_eq, A_ub, b_ub)
    ham_loss = sklearn.metrics.hamming_loss(x_gt, x_sol)

    return {'loss': loss, 'loss_edge': loss_edge, 'bce': bce, 'auc': auc,
            'ham_loss': ham_loss, 'obj_pred': obj_pred.item(), 'obj_gt': obj_gt.item()}


def _log_line(split, epoch, iteration, total, metrics):
    """Format one progress line; shared by the train and validation loops."""
    return ('{} Epoch {} iter {}/{}, Objective {:.2f}/{:.2f}, mse {:.4f} mse edge {:.4f} '
            'ce {:.3f} auc {:.3f} hamming {:.3f}').format(
        split, epoch, iteration, total, metrics['obj_pred'], metrics['obj_gt'],
        metrics['loss'].item(), metrics['loss_edge'].item(), metrics['bce'].item(),
        metrics['auc'], metrics['ham_loss'])


def _summary(metrics):
    """(mse, mse_edge, auc, bce) tuple stored in train_list / val_list."""
    return (metrics['loss'].item(), metrics['loss_edge'].item(), metrics['auc'], metrics['bce'].item())


for epoch in range(1, num_epochs + 1):
    np.random.shuffle(train_data_list)  # Shuffle data list (np.random is seeded in the data cell)
    for iteration, train_data in enumerate(train_data_list):
        metrics = _run_sample(tracker, train_data, gamma, requires_grad=True)
        train_list.append(_summary(metrics))
        print(_log_line('Train', epoch, iteration, len(train_data_list), metrics))

        optimizer.zero_grad()
        metrics['loss_edge'].backward()  # Only the edge MSE through the QP is optimized; BCE is monitored
        optimizer.step()
    torch.save(tracker.net.state_dict(), os.path.join(ckpt_dir, 'epoch_{}.pth'.format(epoch)))

    np.random.shuffle(val_data_list)
    for iteration, val_data in enumerate(val_data_list):
        metrics = _run_sample(tracker, val_data, gamma, requires_grad=False)
        val_list.append(_summary(metrics))
        print(_log_line('Val', epoch, iteration, len(val_data_list), metrics))