In [1]:
from statsmodels.tsa.statespace.kalman_filter import KalmanFilter
import numpy as np
from scipy.optimize import linear_sum_assignment
from collections import deque
import time
import imageio
import cv2
import sys
from random import randint

In [2]:
# Kalman filter: 4-element state (x, vx, y, vy); only the positions x and y are measured

class KalmanFilter(object):
    """Linear Kalman filter for a 2-D point with state ``[x, vx, y, vy]``.

    Note: this class shadows the ``statsmodels`` ``KalmanFilter`` imported at
    the top of the notebook; code below this cell uses this implementation.

    Args:
        dt: time step used to build the transition (``A``), control (``B``)
            and process-noise (``Q``) matrices.
        stateVariance: scale of the initial error covariance ``P``.
        measurementVariance: scale of the measurement-noise covariance ``R``.
        method: ``"Acceleration"`` enables a constant control input ``U = 1``
            in :meth:`predict`; any other value uses ``U = 0``. The historical
            misspelling ``"Accerelation"`` is still accepted.
    """

    # Spellings that enable the control input. The misspelled one is what this
    # class originally checked for, so it is kept for existing callers.
    _ACCELERATION_METHODS = ("Acceleration", "Accerelation")

    def __init__(self, dt=1,stateVariance=1,measurementVariance=1,method="Velocity" ):
        super(KalmanFilter, self).__init__()
        self.method = method
        self.stateVariance = stateVariance
        self.measurementVariance = measurementVariance
        self.dt = dt
        self.initModel()

    def initModel(self):
        """(Re)build the model matrices and reset state/covariance from the settings."""
        self.U = 1 if self.method in self._ACCELERATION_METHODS else 0
        dt = self.dt
        # State transition: each position advances by its velocity * dt.
        self.A = np.matrix([[1, dt, 0, 0],
                            [0, 1, 0, 0],
                            [0, 0, 1, dt],
                            [0, 0, 0, 1]])
        # Control matrix applied to the scalar input U.
        self.B = np.matrix([[dt**2/2], [dt], [dt**2/2], [dt]])
        # Measurement matrix: only state entries 0 (x) and 2 (y) are observed.
        self.H = np.matrix([[1, 0, 0, 0], [0, 0, 1, 0]])
        self.P = np.matrix(self.stateVariance*np.identity(self.A.shape[0]))
        self.R = np.matrix(self.measurementVariance*np.identity(self.H.shape[0]))
        self.Q = np.matrix([[dt**4/4, dt**3/2, 0, 0],
                            [dt**3/2, dt**2, 0, 0],
                            [0, 0, dt**4/4, dt**3/2],
                            [0, 0, dt**3/2, dt**2]])
        self.erroCov = self.P
        # Initial state: zero position, unit velocity on both axes.
        self.state = np.matrix([[0], [1], [0], [1]])

    def predict(self):
        """Propagate state and covariance one step; return the predicted ``(x, y)``.

        Each returned value is a length-1 NumPy array (rows 0 and 2 of the
        predicted state). Must be called before :meth:`correct`, which reads
        ``predictedState`` and ``predictedErrorCov``.
        """
        self.predictedState = self.A*self.state + self.B*self.U
        self.predictedErrorCov = self.A*self.erroCov*self.A.T + self.Q
        temp = np.asarray(self.predictedState)
        return temp[0], temp[2]

    def correct(self, currentMeasurement):
        """Fuse a measured position into the most recent prediction.

        Args:
            currentMeasurement: 2x1 column (``np.matrix`` or array) with the
                measured ``x`` and ``y``.
        """
        innovationCov = self.H*self.predictedErrorCov*self.H.T + self.R
        # pinv instead of inv so a singular innovation covariance cannot raise.
        self.kalmanGain = self.predictedErrorCov*self.H.T*np.linalg.pinv(innovationCov)
        innovation = currentMeasurement - (self.H*self.predictedState)
        self.state = self.predictedState + self.kalmanGain*innovation
        self.erroCov = (np.identity(self.P.shape[0]) - self.kalmanGain*self.H)*self.predictedErrorCov

In [3]:
# Track: one tracked object (its Kalman filter, id, recent predictions and missed-frame count)

class Tracks(object):
    """A single tracked object: its Kalman filter, id and recent predictions.

    Args:
        detection: initial measured position; anything that reshapes to two
            values (a ``(2,)`` or ``(1, 2)`` array, a list, ...).
        trackId: identifier assigned by the tracker.
        trace_length: maximum number of past predictions kept in ``trace``
            (``None`` keeps all). Defaults to 20, the previous fixed value.
    """
    def __init__(self, detection, trackId, trace_length=20):
        super(Tracks, self).__init__()
        self.KF = KalmanFilter()
        # Seed the filter with the first detection.
        self.KF.predict()
        self.KF.correct(np.matrix(detection).reshape(2,1))
        self.trace = deque(maxlen=trace_length)
        # asarray so plain sequences work too, matching np.matrix() above.
        self.prediction = np.asarray(detection).reshape(1,2)
        self.trackId = trackId
        # Count of updates without a matched detection; maintained by the tracker.
        self.skipped_frames = 0

    def predict(self,detection):
        """Advance the filter one step, then correct it with ``detection``.

        ``prediction`` is set to the (1, 2) predicted position from *before*
        the correction is applied.
        """
        self.prediction = np.array(self.KF.predict()).reshape(1,2)
        self.KF.correct(np.matrix(detection).reshape(2,1))
        

In [4]:
# Tracker: matches detections to tracks with the Hungarian algorithm (linear_sum_assignment)

class Tracker(object):
    """Multi-object tracker that assigns detections to Kalman-filtered tracks.

    Each :meth:`update` matches current tracks to new detections with the
    Hungarian algorithm (``linear_sum_assignment``), corrects matched tracks,
    ages unmatched ones, removes stale tracks and starts new tracks for
    detections nobody claimed.

    Args:
        dist_threshold: largest accepted assignment cost. Costs are Euclidean
            distances scaled by 0.1, so a threshold ``t`` accepts distances up
            to ``10 * t`` in detection coordinates.
        max_frame_skipped: a track is removed once its ``skipped_frames``
            count exceeds this value.
        max_trace_length: maximum number of past predictions kept in each
            track's ``trace`` (``None`` keeps all of them).
    """
    def __init__(self, dist_threshold, max_frame_skipped, max_trace_length):
        super(Tracker, self).__init__()
        self.dist_threshold = dist_threshold
        self.max_frame_skipped = max_frame_skipped
        self.max_trace_length = max_trace_length
        self.trackId = 0
        self.tracks = []

    def _new_track(self, detection):
        """Start a track for ``detection`` with the next id and store it."""
        track = Tracks(detection, self.trackId)
        # Tracks picks its own trace length; apply the one this tracker was given.
        track.trace = deque(track.trace, maxlen=self.max_trace_length)
        self.trackId += 1
        self.tracks.append(track)

    def update(self, detections):
        """Process one frame of detections.

        Args:
            detections: array-like reshapeable to ``(M, 2)``; ``M`` may be 0.
        """
        detections = np.asarray(detections).reshape(-1, 2)
        if len(self.tracks) == 0:
            for detection in detections:
                self._new_track(detection)

        N = len(self.tracks)
        M = len(detections)
        if N == 0:
            # No tracks and no detections: nothing to match.
            return

        # (N, M) matrix of scaled distances between predictions and detections.
        predictions = np.vstack([track.prediction for track in self.tracks])
        cost = 0.1 * np.linalg.norm(predictions[:, None, :] - detections[None, :, :], axis=2)

        assignment = [-1] * N
        if M > 0:
            rows, cols = linear_sum_assignment(cost)
            for r, c in zip(rows, cols):
                assignment[r] = int(c)

        for i in range(N):
            if assignment[i] != -1 and cost[i, assignment[i]] > self.dist_threshold:
                # Best available detection is too far away: treat as unmatched.
                assignment[i] = -1
            if assignment[i] == -1:
                self.tracks[i].skipped_frames += 1

        # Delete stale tracks from the back so earlier indices stay valid and
        # `assignment` stays aligned with `self.tracks`.
        for i in reversed(range(N)):
            if self.tracks[i].skipped_frames > self.max_frame_skipped:
                del self.tracks[i]
                del assignment[i]

        # Detections not claimed by any surviving track start new tracks.
        claimed = set(assignment)
        for j in range(M):
            if j not in claimed:
                self._new_track(detections[j])

        # `assignment` only covers pre-existing tracks; new ones were appended after them.
        for i, j in enumerate(assignment):
            track = self.tracks[i]
            if j != -1:
                track.skipped_frames = 0
                track.predict(detections[j])
            track.trace.append(track.prediction)


In [5]:
images = []  # NOTE(review): not referenced by any code visible in this notebook.

def createimage(w,h):
    """Return a blank white BGR canvas ``w`` pixels wide and ``h`` pixels tall.

    The array is ``uint8`` with shape ``(h, w, 3)``, the rows-first layout
    NumPy and OpenCV use, so non-square canvases are not transposed.
    """
    return np.full((h, w, 3), 255, dtype=np.uint8)

def main(detections_path='Detections.npy'):
    """Replay recorded detections through a Tracker and draw the result.

    Args:
        detections_path: ``.npy`` file indexed below as
            ``data[object, frame, coordinate]``; only the first 10 objects and
            150 frames are used. Defaults to the previously hard-coded file.

    Each frame shows every track's id, a box around its latest prediction and
    its recent trace, plus the raw detections as black dots. Press ``q`` in
    the window to stop early.
    """
    data = np.load(detections_path)[0:10, 0:150, 0:150]
    tracker = Tracker(150, 30, 5)
    track_colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0),
                    (127, 127, 255), (255, 0, 255), (255, 127, 255),
                    (127, 0, 255), (127, 0, 127), (127, 10, 255), (0, 255, 127)]

    try:
        for i in range(data.shape[1]):
            centers = data[:, i, :]
            frame = createimage(512, 512)
            if len(centers) == 0:
                continue
            tracker.update(centers)
            for track in tracker.tracks:
                if len(track.trace) <= 1:
                    continue
                # Colour keyed on the track id, wrapping around the palette, so a
                # track keeps its colour across deletions and more tracks than
                # colours cannot raise IndexError.
                color = track_colors[track.trackId % len(track_colors)]
                x = int(track.trace[-1][0, 0])
                y = int(track.trace[-1][0, 1])
                cv2.rectangle(frame, (x - 10, y - 10), (x + 10, y + 10), color, 1)
                cv2.putText(frame, str(track.trackId), (x - 10, y - 20),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
                for point in track.trace:
                    cv2.circle(frame, (int(point[0, 0]), int(point[0, 1])), 3, color, -1)
                cv2.circle(frame, (x, y), 6, color, -1)
            # Raw detections for this frame, drawn from `centers` itself so the
            # number of tracks cannot index past the available detections.
            for center in centers:
                cv2.circle(frame, (int(center[0]), int(center[1])), 6, (0, 0, 0), -1)
            cv2.imshow('image', frame)
            time.sleep(0.1)
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    finally:
        # Close the window whether playback finished, was quit, or raised.
        cv2.destroyAllWindows()

# Script entry point. Inside a Jupyter kernel __name__ is also "__main__",
# so executing this cell starts the demo as well.
if "__main__" == __name__:
    main()