## Helper functions

In [22]:
import numpy as np
import matplotlib.pyplot as plt
import sklearn as sklearn
from datetime import datetime
import tensorflow as tf
import csv

import matplotlib.patches as mpatches
from sklearn import preprocessing 
from matplotlib.colors import from_levels_and_colors
from matplotlib import animation
from mat4py import loadmat
import itertools

import io
import sys
import pandas as pd


### Manually define the conflict zone

In [23]:
def getConflictZoneCoordinates():
    """Return the four manually chosen corners (A, B, C, D) of the conflict zone.

    Each corner is a fresh two-element ``[x, y]`` list, so callers may mutate
    the result without affecting later calls.

    Returns
    -------
    tuple of list
        ``(A, B, C, D)`` in that order.
    """
    corner_a, corner_b = [-4, 8], [1, 6]
    corner_c, corner_d = [-3, -4], [-8, -2]
    return corner_a, corner_b, corner_c, corner_d


def drawConflictZone():
    """Draw the conflict zone on the current Matplotlib axes.

    The corners returned by ``getConflictZoneCoordinates()`` are plotted as
    blue '+' markers. The closed outline A -> B -> C -> D -> A is drawn as a
    thin magenta line.
    """
    # Call the sibling function directly: no `helper_functions` name is
    # imported in this file, so the qualified call raised NameError.
    corners = getConflictZoneCoordinates()

    for x, y in corners:
        plt.plot(x, y, 'b+')  # blue plus-sign marker at each corner

    # Repeat the first corner at the end so the outline closes.
    outline = [*corners, corners[0]]
    plt.plot([p[0] for p in outline], [p[1] for p in outline], color='magenta', linewidth=1)


### Coordinates

In [24]:
def getMaxMinXCoordinate(instances):
    """Return the largest and smallest x coordinate across all instances.

    Parameters
    ----------
    instances : mapping
        Maps an instance key to a record whose ``'Xcoord'`` entry holds that
        instance's x samples. Nested column-vector lists such as
        ``[[x0], [x1], ...]`` and flat sequences are both accepted; the
        values are flattened before taking the extrema.

    Returns
    -------
    tuple
        ``(max_x, min_x)``. For an empty ``instances`` this is
        ``(-inf, inf)``, the identity values of max/min.

    Raises
    ------
    ValueError
        If some instance has an empty ``'Xcoord'``.
    """
    # Start from -inf/+inf rather than 0/1e10: coordinates may be negative,
    # and a 0 starting maximum would wrongly win for all-negative data.
    max_x = float('-inf')
    min_x = float('inf')

    for key in instances:
        xs = np.ravel(instances[key]['Xcoord'])
        max_x = max(max_x, xs.max())
        min_x = min(min_x, xs.min())

    return max_x, min_x


def getMaxMinYCoordinate(instances):
    """Return the largest and smallest y coordinate across all instances.

    Parameters
    ----------
    instances : mapping
        Maps an instance key to a record whose ``'Ycoord'`` entry holds that
        instance's y samples. Nested column-vector lists such as
        ``[[y0], [y1], ...]`` and flat sequences are both accepted; the
        values are flattened before taking the extrema.

    Returns
    -------
    tuple
        ``(max_y, min_y)``. For an empty ``instances`` this is
        ``(-inf, inf)``, the identity values of max/min.

    Raises
    ------
    ValueError
        If some instance has an empty ``'Ycoord'``.
    """
    # Start from -inf/+inf rather than 0/1e10: coordinates may be negative,
    # and a 0 starting maximum would wrongly win for all-negative data.
    max_y = float('-inf')
    min_y = float('inf')

    for key in instances:
        ys = np.ravel(instances[key]['Ycoord'])
        max_y = max(max_y, ys.max())
        min_y = min(min_y, ys.min())

    return max_y, min_y


### Timestamps

In [25]:
def getAllTimeSteps(instances):
    """Gather the ``'Sec'`` timestamps of every instance into one sorted list.

    Each instance's ``'Sec'`` entry is flattened with ``np.ravel``. Timestamps
    shared by several instances are kept once per occurrence, not
    de-duplicated.

    Returns
    -------
    list
        All timestamps in ascending order; empty when ``instances`` is empty.
    """
    per_instance = (np.ravel(instances[key]['Sec']) for key in instances)
    return sorted(itertools.chain.from_iterable(per_instance))

def findIndexOfTimeStamp(instance, timeStamp):
    """Return the position of ``timeStamp`` in ``instance['Sec']``, or None.

    ``instance['Sec']`` is flattened with ``np.ravel`` first, so nested
    column-vector lists are accepted.

    Returns
    -------
    int or None
        Index of the first matching timestamp, or None if it is absent.
    """
    timeStamps = np.ravel(instance['Sec']).tolist()

    try:
        return timeStamps.index(timeStamp)
    except ValueError:
        # list.index signals "not found" with ValueError. Only that case
        # maps to None; the previous bare `except:` also hid unrelated errors.
        return None


### Animations

In [26]:
def animateTrajectories(instances, index=0, name_suffix='', saveAsVideo=True, showPlot=True):
    """Animate the trajectories of all interacting instances over time.

    One line is drawn per instance: green when ``getFeatures`` reports type 1
    (cyclist), blue otherwise (car). Each frame is one timestamp from
    ``getAllTimeSteps(instances)``. On each frame, every line is extended up
    to and including that instance's sample at the frame's timestamp. If the
    instance has no sample at that timestamp, the line stays at the last
    sample reached.

    Parameters
    ----------
    instances : mapping
        Instance key -> record with at least ``'Sec'``, ``'Xcoord'`` and
        ``'Ycoord'`` entries; each record is also passed to ``getFeatures``.
        NOTE: a ``'LastIndex'`` entry is written into every record.
    index : int
        Graph index; printed and used in the video file name.
    name_suffix : str
        Inserted into the video file name.
    saveAsVideo : bool
        If true, save ``basic_animation_<name_suffix>_<index>.mp4``. This
        passes ``-vcodec libx264`` to the writer (presumably ffmpeg, the
        Matplotlib default; confirm it is installed).
    showPlot : bool
        If true, call ``plt.show()`` at the end.
    """
    # Local import: the module header only brings in `datetime`.
    from datetime import timezone

    print('Number of instances interacting (1 car + bicycles)', len(instances))
    print('Graph index: ', index)

    green = '#00cc00'  # Cyclist
    blue = '#33ccff'  # Car

    # Call the sibling helpers directly: no `helper_functions` name is
    # imported in this file, so the qualified calls raised NameError.
    maxX, minX = getMaxMinXCoordinate(instances)
    maxY, minY = getMaxMinYCoordinate(instances)

    # Set up the figure, the axes, and the line artists to animate.
    fig = plt.figure()
    ax = plt.axes(xlim=(minX, maxX), ylim=(minY, maxY))
    # Pass visibility positionally: the `b=` keyword was deprecated in
    # Matplotlib 3.5 and removed in 3.7 (renamed `visible`).
    plt.grid(True, color='r')

    instance_array = [instances[key] for key in instances]

    # Create one (initially empty) line per interacting instance.
    lines = []
    for instance in instance_array:
        # Only the type is needed; the other features are ignored. The names
        # avoid shadowing the builtins `id` and `type`.
        _id, _t, _ts, _x, _y, _v, agent_type, _time_diff = getFeatures(instance)
        line_color = green if agent_type == 1 else blue

        instance['LastIndex'] = 0  # last sample index reached so far

        line, = ax.plot([], [], line_color, lw=1)
        lines.append(line)

    def init():
        """Clear all lines before the first frame."""
        for line in lines:
            line.set_data([], [])
        return lines

    def animate(timestamp):
        """Advance every line to ``timestamp`` (one frame per timestamp)."""
        # The title is the same for all lines, so set it once per frame.
        # NOTE(review): with blit=True the interactive view only redraws the
        # returned artists, so the title may not refresh on screen.
        formattedTs = datetime.fromtimestamp(timestamp, tz=timezone.utc).strftime('%Y-%m-%d %H:%M:%S')
        ax.set_title(formattedTs)

        for instance, line in zip(instance_array, lines):
            found = findIndexOfTimeStamp(instance, timestamp)
            if found is not None:
                instance['LastIndex'] = found

            # +1 so the sample at the current timestamp is drawn; the old
            # exclusive slice never showed an instance's latest point.
            end = instance['LastIndex'] + 1
            line.set_data(instance['Xcoord'][:end], instance['Ycoord'][:end])

        return lines

    timeSteps = getAllTimeSteps(instances)

    # blit=True means only the artists returned by init/animate are redrawn.
    anim = animation.FuncAnimation(fig, func=animate, init_func=init, frames=timeSteps, interval=100, blit=True)

    if saveAsVideo:
        video_name = 'basic_animation_{}_{}.mp4'.format(name_suffix, index)
        anim.save(video_name, fps=30, extra_args=['-vcodec', 'libx264'])

    if showPlot:
        plt.show()
