In [None]:
@dataclass
class EKF:
    """Extended Kalman filter built from a dynamic model and a sensor model.

    ``dynamic_model`` is used through ``f``, ``F`` and ``Q``, each called as
    ``(x, Ts)``; ``sensor_model`` is used through ``h``, ``H`` and ``R``, each
    called with a state ``x``. Gaussians are ``MultiVarGaussian`` objects that
    expose ``mean``/``cov`` and unpack into ``(mean, cov)``.
    """
    dynamic_model: DynamicModel
    sensor_model: MeasurementModel

    def predict(self,
                state_upd_prev_gauss: MultiVarGaussian,
                Ts: float,
                ) -> MultiVarGaussian:
        """Predict the EKF state Ts seconds ahead.

        Prediction step (Algorithm 1 in the book):
            x_pred = f(x, Ts)
            P_pred = F P F^T + Q
        where F and Q are evaluated at the previous updated mean.

        Args:
            state_upd_prev_gauss: previous updated state estimate.
            Ts: time step to predict ahead.

        Returns:
            The predicted state Gaussian.
        """
        x_prev = state_upd_prev_gauss.mean
        P_prev = state_upd_prev_gauss.cov

        x_pred = self.dynamic_model.f(x_prev, Ts)
        F = self.dynamic_model.F(x_prev, Ts)
        Q = self.dynamic_model.Q(x_prev, Ts)
        P_pred = F @ P_prev @ F.T + Q

        return MultiVarGaussian(x_pred, P_pred)

    def predict_measurement(self,
                            state_pred_gauss: MultiVarGaussian
                            ) -> MultiVarGaussian:
        """Predict measurement pdf from using state pdf and model.

        Measurement prediction (Algorithm 1 in the book):
            z_bar = h(x_bar)
            S     = H P H^T + R
        with H and R evaluated at the predicted mean x_bar.

        Args:
            state_pred_gauss: predicted state Gaussian.

        Returns:
            The predicted measurement Gaussian (mean z_bar, covariance S).
        """
        x_bar, P = state_pred_gauss
        H = self.sensor_model.H(x_bar)
        R = self.sensor_model.R(x_bar)
        z_bar = self.sensor_model.h(x_bar)
        S = H @ P @ H.T + R
        return MultiVarGaussian(z_bar, S)

    def update(self,
               z: np.ndarray,
               state_pred_gauss: MultiVarGaussian,
               measurement_gauss: Optional[MultiVarGaussian] = None,
               ) -> MultiVarGaussian:
        """Given the prediction and innovation,
        find the updated state estimate.

        Update step (Algorithm 1 in the book):
            v     = z - z_bar
            W     = P H^T S^-1
            x_upd = x_pred + W v
            P_upd = (I - W H) P

        Args:
            z: the measurement.
            state_pred_gauss: predicted state Gaussian.
            measurement_gauss: predicted measurement Gaussian; computed from
                ``state_pred_gauss`` when not supplied.

        Returns:
            The updated state Gaussian.
        """
        x_pred, P = state_pred_gauss
        if measurement_gauss is None:
            measurement_gauss = self.predict_measurement(state_pred_gauss)

        z_bar, S = measurement_gauss
        H = self.sensor_model.H(x_pred)

        v = z - z_bar
        # Kalman gain W = P H^T S^-1, obtained by solving S^T W^T = (P H^T)^T
        # rather than forming S^-1 explicitly (numerically better, same value).
        W = np.linalg.solve(S.T, (P @ H.T).T).T
        x_upd = x_pred + W @ v
        # Identity sized from the state covariance so any state dimension works.
        P_upd = (np.eye(P.shape[0]) - W @ H) @ P

        return MultiVarGaussian(x_upd, P_upd)

    def step_with_info(self,
                       state_upd_prev_gauss: MultiVarGaussian,
                       z: np.ndarray,
                       Ts: float,
                       ) -> tuple[MultiVarGaussian,
                                  MultiVarGaussian,
                                  MultiVarGaussian]:
        """
        Predict ekfstate Ts units ahead and then update this prediction with z.

        Returns:
            state_pred_gauss: The state prediction
            measurement_pred_gauss:
                The measurement prediction after state prediction
            state_upd_gauss: The predicted state updated with measurement
        """
        state_pred_gauss = self.predict(state_upd_prev_gauss, Ts)
        measurement_pred_gauss = self.predict_measurement(state_pred_gauss)
        state_upd_gauss = self.update(z, state_pred_gauss,
                                      measurement_pred_gauss)

        return state_pred_gauss, measurement_pred_gauss, state_upd_gauss

    def step(self,
             state_upd_prev_gauss: MultiVarGaussian,
             z: np.ndarray,
             Ts: float,
             ) -> MultiVarGaussian:
        """Run one full predict + update cycle and return only the update.

        Args:
            state_upd_prev_gauss: previous updated state estimate.
            z: the measurement.
            Ts: time step to predict ahead.

        Returns:
            The updated state Gaussian.
        """
        _, _, state_upd_gauss = self.step_with_info(state_upd_prev_gauss,
                                                    z, Ts)
        return state_upd_gauss