Skip to content

Commit

Permalink
Kalman filtering. Code cleaning.
Browse files Browse the repository at this point in the history
  • Loading branch information
formatBCE committed Aug 27, 2022
1 parent 6523fe9 commit 61ace19
Show file tree
Hide file tree
Showing 9 changed files with 245 additions and 119 deletions.
138 changes: 112 additions & 26 deletions custom_components/format_ble_tracker/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@
from __future__ import annotations

import asyncio
from asyncio import events
import json
import time
import logging
import time
from typing import Any
#import numpy as np
import math

import voluptuous as vol

Expand All @@ -20,19 +21,15 @@
ALIVE_NODES_TOPIC,
DOMAIN,
MAC,
MERGE_IDS,
NAME,
ROOM,
ROOT_TOPIC,
RSSI,
TIMESTAMP,
MERGE_IDS,
)

PLATFORMS: list[Platform] = [
Platform.DEVICE_TRACKER,
Platform.SENSOR,
Platform.NUMBER
]
PLATFORMS: list[Platform] = [Platform.DEVICE_TRACKER, Platform.SENSOR, Platform.NUMBER]
_LOGGER = logging.getLogger(__name__)

MQTT_PAYLOAD = vol.Schema(
Expand All @@ -41,7 +38,7 @@
vol.Schema(
{
vol.Required(RSSI): vol.Coerce(int),
vol.Optional(TIMESTAMP): vol.Coerce(int)
vol.Optional(TIMESTAMP): vol.Coerce(int),
},
extra=vol.ALLOW_EXTRA,
),
Expand All @@ -68,7 +65,6 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
elif MERGE_IDS in entry.data:
hass.config_entries.async_setup_platforms(entry, [Platform.DEVICE_TRACKER])


return True


Expand All @@ -80,7 +76,10 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
else:
platforms = [Platform.DEVICE_TRACKER]

if unload_ok := await hass.config_entries.async_unload_platforms(entry, platforms) and entry.entry_id in hass.data[DOMAIN]:
if (
unload_ok := await hass.config_entries.async_unload_platforms(entry, platforms)
and entry.entry_id in hass.data[DOMAIN]
):
hass.data[DOMAIN].pop(entry.entry_id)

if MAC in entry.data:
Expand All @@ -93,29 +92,34 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:


class BeaconCoordinator(DataUpdateCoordinator[dict[str, Any]]):
"""Class to arrange interaction with MQTT"""
"""Class to arrange interaction with MQTT."""

def __init__(self, hass: HomeAssistant, data) -> None:
"""Initialise coordinator."""
self.mac = data[MAC]
self.expiration_time : int
self.default_expiration_time : int = 2
self.expiration_time: int
self.default_expiration_time: int = 2
given_name = data[NAME] if data.__contains__(NAME) else self.mac
self.room_data = dict[str, int]()
self.filtered_room_data = dict[str, int]()
self.room_filters = dict[str, KalmanFilter]()
self.room_expiration_timers = dict[str, asyncio.TimerHandle]()
self.room = None
self.room: str | None = None
self.last_received_adv_time = None

super().__init__(hass, _LOGGER, name=given_name)

async def _async_update_data(self) -> dict[str, Any]:
"""Update data via library."""
if len(self.room_data) == 0:
if len(self.filtered_room_data) == 0:
self.room = None
self.last_received_adv_time = None
else:
self.room = next(
iter(
dict(
sorted(
self.room_data.items(),
self.filtered_room_data.items(),
key=lambda item: item[1],
reverse=True,
)
Expand All @@ -125,7 +129,7 @@ async def _async_update_data(self) -> dict[str, Any]:
return {**{ROOM: self.room}}

async def subscribe_to_mqtt(self) -> None:
"""Subscribe coordinator to MQTT messages"""
"""Subscribe coordinator to MQTT messages."""

@callback
async def message_received(self, msg):
Expand All @@ -136,20 +140,27 @@ async def message_received(self, msg):
_LOGGER.debug("Skipping update because of malformatted data: %s", error)
return
msg_time = data.get(TIMESTAMP)
if (msg_time is not None):
if msg_time is not None:
current_time = int(time.time())
if (current_time - msg_time >= self.get_expiration_time()):
if current_time - msg_time >= self.get_expiration_time():
_LOGGER.info("Received message with old timestamp, skipping")
return

self.time_from_previous = None if self.last_received_adv_time is None else (current_time - self.last_received_adv_time)
self.last_received_adv_time = current_time

room_topic = msg.topic.split("/")[2]

await self.schedule_data_expiration(room_topic)
self.room_data[room_topic] = data.get(RSSI)

rssi = data.get(RSSI)
self.room_data[room_topic] = rssi
self.filtered_room_data[room_topic] = self.get_filtered_value(room_topic, rssi)

await self.async_refresh()

async def schedule_data_expiration(self, room):
"""Start timer for data expiration for certain room"""
"""Start timer for data expiration for certain room."""
if room in self.room_expiration_timers:
self.room_expiration_timers[room].cancel()
loop = asyncio.get_event_loop()
Expand All @@ -159,20 +170,95 @@ async def schedule_data_expiration(self, room):
)
self.room_expiration_timers[room] = timer

def get_filtered_value(self, room, value) -> int:
"""Apply Kalman filter"""
k_filter: KalmanFilter
if room in self.room_filters:
k_filter = self.room_filters[room]
else:
k_filter = KalmanFilter(0.01, 5)
self.room_filters[room] = k_filter
return int(k_filter.filter(value))

def get_expiration_time(self):
"""Calculate current expiration delay"""
"""Calculate current expiration delay."""
return getattr(self, "expiration_time", self.default_expiration_time) * 60

async def expire_data(self, room):
"""Set data for certain room expired"""
"""Set data for certain room expired."""
del self.room_data[room]
del self.filtered_room_data[room]
del self.room_filters[room]
del self.room_expiration_timers[room]
await self.async_refresh()

async def on_expiration_time_changed(self, new_time : int):
"""Respond to expiration time changed by user"""
async def on_expiration_time_changed(self, new_time: int):
"""Respond to expiration time changed by user."""
if new_time is None:
return
self.expiration_time = new_time
for room in self.room_expiration_timers.keys():
await self.schedule_data_expiration(room)

class KalmanFilter:
"""Filtering RSSI data."""

cov = float('nan')
x = float('nan')

def __init__(self, R, Q):
"""
Constructor
:param R: Process Noise
:param Q: Measurement Noise
"""
self.A = 1
self.B = 0
self.C = 1

self.R = R
self.Q = Q

def filter(self, measurement):
"""
Filters a measurement
:param measurement: The measurement value to be filtered
:return: The filtered value
"""
u = 0
if math.isnan(self.x):
self.x = (1 / self.C) * measurement
self.cov = (1 / self.C) * self.Q * (1 / self.C)
else:
pred_x = (self.A * self.x) + (self.B * u)
pred_cov = ((self.A * self.cov) * self.A) + self.R

# Kalman Gain
K = pred_cov * self.C * (1 / ((self.C * pred_cov * self.C) + self.Q));

# Correction
self.x = pred_x + K * (measurement - (self.C * pred_x));
self.cov = pred_cov - (K * self.C * pred_cov);

return self.x

def last_measurement(self):
"""
Returns the last measurement fed into the filter
:return: The last measurement fed into the filter
"""
return self.x

def set_measurement_noise(self, noise):
"""
Sets measurement noise
:param noise: The new measurement noise
"""
self.Q = noise

def set_process_noise(self, noise):
"""
Sets process noise
:param noise: The new process noise
"""
self.R = noise
13 changes: 8 additions & 5 deletions custom_components/format_ble_tracker/common.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
"""Common values"""
"""Common values."""
from homeassistant.helpers.device_registry import format_mac
from .const import DOMAIN
from .__init__ import BeaconCoordinator
from homeassistant.helpers.update_coordinator import CoordinatorEntity

from .__init__ import BeaconCoordinator
from .const import DOMAIN


class BeaconDeviceEntity(CoordinatorEntity[BeaconCoordinator]):
"""Base device class"""
"""Base device class."""

def __init__(self, coordinator: BeaconCoordinator) -> None:
"""Initialize."""
Expand All @@ -14,10 +16,11 @@ def __init__(self, coordinator: BeaconCoordinator) -> None:

@property
def device_info(self):
"""Device info creation."""
return {
"identifiers": {
# MAC addresses are unique identifiers within a specific domain
(DOMAIN, self.formatted_mac_address)
},
"name": self.coordinator.name,
}
}
Loading

0 comments on commit 61ace19

Please sign in to comment.