In [23]:
from Submission.utils import *
import plotly.express as px

In [24]:
class LandmarkAeroplane():

    def __init__(self, init_pos, At, Bt, Ct, Qt, Rt, Rt_landmark, landmark_pos):

        self.state = init_pos
        self.At = At
        self.Bt = Bt
        self.Ct = Ct
        self.Qt = Qt
        self.Rt = Rt
        self.R_landmark = Rt_landmark
        self.landmarks = landmark_pos
        self.belief = None

    def fetchState(self):
        return self.state
    
    def updateState(self, u):
        self.state = np.matmul(self.At, self.state) + np.matmul(self.Bt, u) + np.random.multivariate_normal(np.zeros(self.Qt.shape[0]),self.Qt)

    def sensor_outputs(self):
        return np.matmul(self.Ct, self.state) + np.random.multivariate_normal(np.zeros(self.Rt.shape[0]),self.Rt)

    def getObservation(self):
        observation, landmark = self.evalH(self.state)
        observation = observation + np.random.multivariate_normal(np.zeros(self.R_landmark.shape[0]),self.R_landmark)
        assert abs(observation[0]) != np.inf or abs(observation[1]) != np.inf or abs(observation[2]) != np.inf
        return observation, landmark
    
    def evalH(self, x):
        dist = [np.linalg.norm(landmark - self.state[:3]) for landmark, limit in (self.landmarks)]
        best_landmark = np.argmin(dist)
        if dist[best_landmark] > self.landmarks[best_landmark][1]:
            return x[:4], -1
        else:
            return np.concatenate([x[:3], np.array([np.linalg.norm(self.landmarks[best_landmark][0] - x[:3])])]), best_landmark
    
    def getJacobian(self, x):
        H = np.zeros((4,6))
        H[:3,:3] = np.eye(3)
        dist = [np.linalg.norm(landmark - self.state[:3]) for landmark, limit in (self.landmarks)]
        best_landmark = np.argmin(dist)
        H[3,:3] = (x[:3] - self.landmarks[best_landmark][0])/np.linalg.norm(x[:3] - self.landmarks[best_landmark][0])
        return H

In [25]:
class ExtendedKalmannFilter():

    def __init__(self, agent, init_belief):
        agent.belief = init_belief
        self.agent = agent
        self.belief = init_belief
    
    def predictionUpdate(self, u):
        # Prediction update of the state of the aeroplane
        # u is the control input
        self.belief['mu'] = np.matmul(self.agent.At, self.belief['mu']) + np.matmul(self.agent.Bt, u) 
        self.belief['sigma'] = np.matmul(self.agent.At, np.matmul(self.belief['sigma'], self.agent.At.T)) + self.agent.Qt
        
    def measurementUpdate(self, z):
    
        # Measurement update of the state of the aeroplane
        # z is the measurement
    
        Kt = np.matmul(np.matmul(self.belief['sigma'], self.agent.Ct.T), np.linalg.inv(np.matmul(self.agent.Ct, np.matmul(self.belief['sigma'], self.agent.Ct.T)) + self.agent.Rt))
        self.belief['mu'] = self.belief['mu'] + np.matmul(Kt, z - np.matmul(self.agent.Ct, self.belief['mu']))
        self.belief['sigma'] = np.matmul(np.identity(self.belief['sigma'].shape[0]) - np.matmul(Kt, self.agent.Ct), self.belief['sigma'])

    def extendedMeasurementUpdate(self):
        z, landmark = self.agent.getObservation()
        if landmark == -1: #No landmark within limits
            z = self.agent.sensor_outputs()
            self.measurementUpdate(z[:3])
        else: #Landmark within limits
            H = self.agent.getJacobian(self.belief['mu'])
            inter1 = np.linalg.inv(np.matmul(H, np.matmul(self.belief['sigma'], H.T)) + self.agent.R_landmark)
            Kt = np.matmul(np.matmul(self.belief['sigma'], H.T), inter1)
            self.belief['mu'] = self.belief['mu'] + np.matmul(Kt, z - self.agent.evalH(self.belief['mu'])[0])
            self.belief['sigma'] = np.matmul(np.identity(self.belief['sigma'].shape[0]) - np.matmul(Kt, H), self.belief['sigma'])

    def extendedPredictionUpdate(self, u):
        self.predictionUpdate(u)
    
    def extendedUpdateBelief(self, u = None, z = None):
        if u is not None:
            self.extendedPredictionUpdate(u)
        if z is not None:
            self.extendedMeasurementUpdate()
        else:
            print("No update provided")

def getIncrementLandmark(t):
    return np.array([-0.128*np.cos(0.032*t), -0.128*np.sin(0.032*t), 0.01])

In [26]:
A = np.array([[1,0,0,1,0,0],[0,1,0,0,1,0],[0,0,1,0,0,1],[0,0,0,1,0,0],[0,0,0,0,1,0],[0,0,0,0,0,1]])
B = np.array([[0,0,0],[0,0,0],[0,0,0],[1,0,0],[0,1,0],[0,0,1]])
C = np.array([[1,0,0,0,0,0],[0,1,0,0,0,0],[0,0,1,0,0,0]])
rx, ry, rz = 0.01, 0.01, 0.01
rvx, rvy, rvz  = 0.01, 0.01, 0.01
Qt = np.diag([rx, ry, rz, rvx, rvy, rvz])**2
Rt = np.eye(3)*(10**2)
Rt_landmark = np.eye(4)*(10**2)
Rt_landmark[3,3] = 0.0001
mu0 = np.array([100,0,0,0,4,0])
sigma0 = np.eye(6)*((0.01)**2)
simulation_iterations = 200

In [27]:
def simulateLandmark(Agent,mu0, sigma0, simulation_iterations = 500, leave_obs_cond = lambda i : i < -1):
    AgentEstimator = ExtendedKalmannFilter(Agent, {'mu':mu0, 'sigma':sigma0})
    true_state_list = [Agent.state]
    observed_list = [Agent.getObservation()]
    estimated_list = [mu0]
    belief_covariances = [sigma0]
    beliefs = [mu0]
    for i in range(1,simulation_iterations+1):
        try:
            Agent.updateState(u = getIncrementLandmark(i))

            if leave_obs_cond(i):
                AgentEstimator.extendedUpdateBelief(u = getIncrementLandmark(i))
            else:
                AgentEstimator.extendedUpdateBelief(u = getIncrementLandmark(i), z = 1)
        except:
            print(beliefs)
            raise
        Agent.belief = AgentEstimator.belief
        true_state_list.append(Agent.state)
        observed_list.append(Agent.getObservation())
        estimated_list.append(Agent.belief['mu'])
        belief_covariances.append(Agent.belief['sigma'])
    return true_state_list, observed_list, estimated_list, belief_covariances

In [28]:
landmarks = [(np.array((150, 0, 100)),90), (np.array((-150, 0, 100)),90), (np.array((0, 150, 100)),90) , (np.array((0, -150, 100)),90)]
Agent = LandmarkAeroplane(mu0, A, B, C, Qt, Rt, Rt_landmark, landmarks)
true_state_list, observed_list, estimated_list, belief_covariances = simulateLandmark(Agent, mu0, sigma0, simulation_iterations)

In [31]:
[(np.array((150, 0, 100)),90), (np.array((-150, 0, 100)),90), (np.array((0, 150, 100)),90) , (np.array((0, -150, 100)),90)]

array([1, 2, 3])

In [37]:
tp = [(true_state_list, 'True State'), ([observed_list[i][0] for i in range(len(observed_list))], 'Observed State'), (estimated_list, 'Estimated State')]
plt = plot_trajectories(tp, 0,1,2)
plt.write_html('Q3.html')

In [39]:
plt.add_trace(go.Scatter3d(
        mode='markers',
        x=[150,-150,0,0],
        y=[0,0,150,-150],
        z=[100,100,100,100],
        marker=dict(color=px.colors.qualitative.D3,size=[90, 90, 90,90],sizemode='diameter'
        )
    )
)

In [117]:
def distance_metric(estimated, true):
    return np.linalg.norm(estimated - true)/np.sqrt(len(estimated))

In [118]:
distance_metric(np.array(estimated_list), np.array(true_state_list))

3.1675709625356165

In [119]:
def hungarian_method(cost_matrix):
    cost_matrix = np.array(cost_matrix)
    num_rows, num_cols = cost_matrix.shape

    for i in range(num_rows):
        min_val = np.min(cost_matrix[i, :])
        cost_matrix[i, :] -= min_val

    for j in range(num_cols):
        min_val = np.min(cost_matrix[:, j])
        cost_matrix[:, j] -= min_val

    print(cost_matrix)

    row_covered = np.zeros(num_rows, dtype=bool)
    col_covered = np.zeros(num_cols, dtype=bool)
    cover_zeros(cost_matrix, row_covered, col_covered)
    print(np.sum(row_covered))

    while np.sum(row_covered) < num_rows:
        min_uncovered_val = find_min_uncovered_value(cost_matrix, row_covered, col_covered)
        cost_matrix -= min_uncovered_val
        cover_zeros(cost_matrix, row_covered, col_covered)

    return np.argwhere(cost_matrix == 0)

def cover_zeros(cost_matrix, row_covered, col_covered):
    num_rows, num_cols = cost_matrix.shape

    for i in range(num_rows):
        for j in range(num_cols):
            if cost_matrix[i, j] == 0 and not row_covered[i] and not col_covered[j]:
                row_covered[i] = True
                col_covered[j] = True
                break

def find_min_uncovered_value(cost_matrix, row_covered, col_covered):
    num_rows, num_cols = cost_matrix.shape
    min_val = float('inf')

    for i in range(num_rows):
        for j in range(num_cols):
            if not row_covered[i] and not col_covered[j]:
                min_val = min(min_val, cost_matrix[i, j])

    return min_val

In [121]:
arr = hungarian_method([[1,2,3],[4,5,6],[7,8,9]])

[[0 0 0]
 [0 0 0]
 [0 0 0]]
3


In [123]:
from scipy.optimize import linear_sum_assignment

In [124]:
linear_sum_assignment([[1,2,3],[4,5,6],[7,8,9]])

(array([0, 1, 2]), array([0, 1, 2]))

In [None]:
import plotly.graph_objects as go

fig = px.line_3d(
    data=[go.Line(x=np.array(true_state_list)[:,0], y=np.array(true_state_list)[:,1], z = np.array(true_state_list)[:,2])],
    layout=go.Layout(
        title=go.layout.Title(text="A Figure Specified By A Graph Object")
    )
)

fig.show()

In [16]:
fig = dict({
    "data": [{"type": "line_3d",
              "x": np.array(true_state_list)[:,0],
              "y": np.array(true_state_list)[:,1]}],
    "layout": {"title": {"text": "A Figure Specified By Python Dictionary"}}
})

# To display the figure defined by this dict, use the low-level plotly.io.show function
import plotly.io as pio

pio.show(fig)

ValueError: 
    Invalid element(s) received for the 'data' property of 
        Invalid elements include: [{'type': 'line', 'x': array([ 100.        ,  100.00392091,   99.8721543 ,   99.61915049,
         99.23926502,   98.75671241,   98.11843962,   97.35475938,
         96.47619226,   95.45215113,   94.34665772,   93.1364379 ,
         91.80943128,   90.37629958,   88.79556422,   87.11101572,
         85.32580454,   83.42976735,   81.44053005,   79.34862426,
         77.15491534,   74.8581214 ,   72.45740293,   69.97451563,
         67.41019992,   64.76550965,   62.01661315,   59.18557908,
         56.27948499,   53.28016282,   50.21162478,   47.06691212,
         43.84907269,   40.57668441,   37.22551259,   33.80264645,
         30.34941901,   26.83014774,   23.2705912 ,   19.6481627 ,
         15.99977496,   12.32581976,    8.62592528,    4.88444225,
          1.12908014,   -2.63745437,   -6.38556826,  -10.16603697,
        -13.9323358 ,  -17.70725234,  -21.49877479,  -25.25266475,
        -29.02450065,  -32.77910567,  -36.51464736,  -40.23023116,
        -43.92462793,  -47.58106169,  -51.23303203,  -54.8232817 ,
        -58.36996284,  -61.82653417,  -65.23576536,  -68.62704332,
        -71.98412386,  -75.3010656 ,  -78.53512288,  -81.70493844,
        -84.80578776,  -87.82966952,  -90.77971418,  -93.66174251,
        -96.45249177,  -99.17061846, -101.79873878, -104.35232471,
       -106.80148859, -109.14044434, -111.39397542, -113.52187339,
       -115.55454732, -117.46841849, -119.27982955, -120.9630162 ,
       -122.54095457, -124.02909135, -125.39867182, -126.63927694,
       -127.75272396, -128.77154339, -129.66857085, -130.4179257 ,
       -131.03375172, -131.54528295, -131.94297126, -132.19849704,
       -132.32411431, -132.31357975, -132.18406224, -131.94353219,
       -131.57072153, -131.07198353, -130.46195281, -129.70512063,
       -128.83117978, -127.81657199, -126.6846311 , -125.46042756,
       -124.09333763, -122.59139431, -120.96907622, -119.22011446,
       -117.36252968, -115.38983544, -113.2876888 , -111.07979053,
       -108.77008811, -106.30924898, -103.7595225 , -101.12492994,
        -98.38420791,  -95.55398867,  -92.63091541,  -89.6266461 ,
        -86.54414549,  -83.393445  ,  -80.16642727,  -76.85000017,
        -73.46367967,  -69.98770702,  -66.45003849,  -62.82838251,
        -59.14351981,  -55.42171076,  -51.62811757,  -47.76163906,
        -43.85507709,  -39.88163897,  -35.86908384,  -31.80406453,
        -27.68633176,  -23.55878158,  -19.37308444,  -15.1500992 ,
        -10.92283091,   -6.67747614,   -2.39368929,    1.86708659,
          6.12784874,   10.40549173,   14.65749813,   18.90435024,
         23.14751217,   27.37550414,   31.5791234 ,   35.76759647,
         39.93378881,   44.06228819,   48.14231505,   52.19278735,
         56.17108975,   60.10045142,   63.99296495,   67.80352421,
         71.54242648,   75.22707651,   78.83145917,   82.35111358,
         85.78295027,   89.13396641,   92.3762373 ,   95.51181375,
         98.59457414,  101.5853731 ,  104.4952571 ,  107.30654285,
        110.02474373,  112.63988172,  115.12989796,  117.51326924,
        119.7773703 ,  121.9164189 ,  123.93247277,  125.85573143,
        127.65640494,  129.36232357,  130.90835818,  132.33553763,
        133.68515074,  134.88784555,  135.97223967,  136.93924823,
        137.77967363,  138.51398925,  139.1278663 ,  139.60757877,
        139.94913108,  140.16409654,  140.26065797,  140.21890847,
        140.04562212]), 'y': array([   0.        ,    4.01136488,    7.99405759,   11.96949667,
         15.95641953,   19.93218625,   23.89703611,   27.83246467,
         31.74136376,   35.6400594 ,   39.51508468,   43.32476675,
         47.10658254,   50.81443173,   54.48369141,   58.07901272,
         61.58586367,   65.03080156,   68.38491488,   71.69865207,
         74.93603567,   78.123844  ,   81.21704068,   84.22230661,
         87.14600309,   89.9745276 ,   92.70903703,   95.36546821,
         97.95981627,  100.44079482,  102.8400934 ,  105.15401564,
        107.34478774,  109.45361216,  111.45569303,  113.33907236,
        115.09001603,  116.73288572,  118.24065118,  119.63349447,
        120.89325734,  122.02464029,  123.02861916,  123.91394421,
        124.65314381,  125.2690315 ,  125.76823517,  126.13962882,
        126.36165741,  126.45468433,  126.40813893,  126.22929285,
        125.92138943,  125.48020019,  124.94407537,  124.25550875,
        123.43888687,  122.48195412,  121.41200226,  120.20974889,
        118.87077918,  117.4221991 ,  115.85319328,  114.20675846,
        112.395758  ,  110.47703443,  108.43263476,  106.28313244,
        104.03328757,  101.67777643,   99.21110119,   96.6157585 ,
         93.93688071,   91.17898487,   88.33944948,   85.38836267,
         82.37011256,   79.27375359,   76.11843013,   72.89612296,
         69.57552807,   66.16586442,   62.71620912,   59.18091419,
         55.57255152,   51.88935152,   48.14783725,   44.34260415,
         40.50729059,   36.64654879,   32.69141655,   28.72101717,
         24.73802434,   20.71594478,   16.65551565,   12.55154625,
          8.47607396,    4.38813326,    0.28025939,   -3.81976991,
         -7.93249981,  -12.04137326,  -16.13860221,  -20.23327769,
        -24.30808556,  -28.35114046,  -32.37221372,  -36.36789704,
        -40.35883971,  -44.28643239,  -48.14573426,  -51.97641235,
        -55.76103455,  -59.48365217,  -63.16324735,  -66.75535849,
        -70.31220688,  -73.75290608,  -77.14501695,  -80.4571081 ,
        -83.67758541,  -86.81268854,  -89.86719951,  -92.82897686,
        -95.69590411,  -98.49066449, -101.16816744, -103.73969002,
       -106.22056753, -108.58663201, -110.85517504, -112.99188503,
       -115.0179504 , -116.91879628, -118.70624474, -120.37295544,
       -121.9027835 , -123.31029406, -124.58657907, -125.76255763,
       -126.82059592, -127.74139117, -128.52806203, -129.21073112,
       -129.76551462, -130.17809489, -130.45114079, -130.58389526,
       -130.59839823, -130.49411507, -130.24889641, -129.86962744,
       -129.37691164, -128.76471942, -128.03255404, -127.17035622,
       -126.19620951, -125.0860376 , -123.85278688, -122.50158085,
       -121.02782257, -119.43131148, -117.72319382, -115.91203907,
       -113.99043235, -111.94958468, -109.80092394, -107.58534667,
       -105.27493224, -102.8458    , -100.31295325,  -97.6923041 ,
        -94.96918106,  -92.13702789,  -89.22905949,  -86.25050249,
        -83.18817167,  -80.0659684 ,  -76.86708139,  -73.57993346,
        -70.22343257,  -66.8434873 ,  -63.38066987,  -59.90120071,
        -56.36006051,  -52.78813615,  -49.130378  ,  -45.42592013,
        -41.69024041,  -37.91172679,  -34.11969234,  -30.27677483,
        -26.4495931 ,  -22.58347142,  -18.6740446 ,  -14.78241338,
        -10.89526894,   -7.0240723 ,   -3.14270341,    0.75249857,
          4.62728962])}]

    The 'data' property is a tuple of trace instances
    that may be specified as:
      - A list or tuple of trace instances
        (e.g. [Scatter(...), Bar(...)])
      - A single trace instance
        (e.g. Scatter(...), Bar(...), etc.)
      - A list or tuple of dicts of string/value properties where:
        - The 'type' property specifies the trace type
            One of: ['bar', 'barpolar', 'box', 'candlestick',
                     'carpet', 'choropleth', 'choroplethmapbox',
                     'cone', 'contour', 'contourcarpet',
                     'densitymapbox', 'funnel', 'funnelarea',
                     'heatmap', 'heatmapgl', 'histogram',
                     'histogram2d', 'histogram2dcontour', 'icicle',
                     'image', 'indicator', 'isosurface', 'mesh3d',
                     'ohlc', 'parcats', 'parcoords', 'pie',
                     'pointcloud', 'sankey', 'scatter',
                     'scatter3d', 'scattercarpet', 'scattergeo',
                     'scattergl', 'scattermapbox', 'scatterpolar',
                     'scatterpolargl', 'scattersmith',
                     'scatterternary', 'splom', 'streamtube',
                     'sunburst', 'surface', 'table', 'treemap',
                     'violin', 'volume', 'waterfall']

        - All remaining properties are passed to the constructor of
          the specified trace type

        (e.g. [{'type': 'scatter', ...}, {'type': 'bar, ...}])

array([ 100.        ,  100.00392091,   99.8721543 ,   99.61915049,
         99.23926502,   98.75671241,   98.11843962,   97.35475938,
         96.47619226,   95.45215113,   94.34665772,   93.1364379 ,
         91.80943128,   90.37629958,   88.79556422,   87.11101572,
         85.32580454,   83.42976735,   81.44053005,   79.34862426,
         77.15491534,   74.8581214 ,   72.45740293,   69.97451563,
         67.41019992,   64.76550965,   62.01661315,   59.18557908,
         56.27948499,   53.28016282,   50.21162478,   47.06691212,
         43.84907269,   40.57668441,   37.22551259,   33.80264645,
         30.34941901,   26.83014774,   23.2705912 ,   19.6481627 ,
         15.99977496,   12.32581976,    8.62592528,    4.88444225,
          1.12908014,   -2.63745437,   -6.38556826,  -10.16603697,
        -13.9323358 ,  -17.70725234,  -21.49877479,  -25.25266475,
        -29.02450065,  -32.77910567,  -36.51464736,  -40.23023116,
        -43.92462793,  -47.58106169,  -51.23303203,  -54.82328