In [11]:
import math

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import pyproj
import tensorflow as tf
from keras.layers import Dense, Normalization
from keras.models import Sequential
from plotly.subplots import make_subplots
from pykalman import KalmanFilter

from utils import ORBSLAMResults, umeyama_alignment


def fit_trajectory(source_points, target_points, epochs=200, batch_size=32):
    """Train a small MLP that maps ``source_points`` onto ``target_points``.

    Each point set is standardised by its own Keras ``Normalization`` layer
    (adapted inside this function). A network with two 256-unit ReLU hidden
    layers is then fitted on the normalised pairs using Adam and a
    mean-squared-error loss. Training stops early once the *training* loss has
    not improved for 10 consecutive epochs.

    Args:
        source_points: Input points, one point per row. Must expose ``.shape``
            (e.g. a NumPy array).
        target_points: Reference points with exactly the same shape as
            ``source_points``; row ``i`` is the target for source row ``i``.
        epochs: Upper bound on the number of training epochs.
        batch_size: Mini-batch size forwarded to ``model.fit``.

    Returns:
        Tuple ``(model, source_normalizer, target_normalizer)``. The model
        operates in normalised space: feed it ``source_normalizer(points)`` and
        map its output back with the statistics of ``target_normalizer``.

    Raises:
        AssertionError: If the two arrays differ in shape.
    """
    assert source_points.shape == target_points.shape, (
        f"source/target shape mismatch: {source_points.shape} vs {target_points.shape}")

    # Width of a single point. Taken from the data instead of being hard-coded
    # to 3, so the same code also fits e.g. 2-D trajectories. Source and target
    # share one shape (checked above), so this is also the output width.
    n_features = source_points.shape[-1]

    early_stopping = tf.keras.callbacks.EarlyStopping(monitor='loss', patience=10)

    source_normalizer = Normalization(input_shape=(n_features,))
    source_normalizer.adapt(source_points)

    target_normalizer = Normalization(input_shape=(n_features,))
    target_normalizer.adapt(target_points)

    X = source_normalizer(source_points)
    y = target_normalizer(target_points)

    # Network: n_features -> 256 (ReLU) -> 256 (ReLU) -> n_features (linear)
    model = Sequential()
    model.add(Dense(256, activation='relu', input_shape=(n_features,)))
    model.add(Dense(256, activation='relu'))
    model.add(Dense(n_features))

    model.compile(optimizer='adam', loss='mean_squared_error')

    # verbose=False keeps the per-epoch log out of the notebook output.
    model.fit(X, y, epochs=epochs, batch_size=batch_size, callbacks=[early_stopping], verbose=False)

    return model, source_normalizer, target_normalizer


def denormalize(normalizer, data):
    """Undo the standardisation described by ``normalizer`` on ``data``.

    Computes ``data * std + mean`` where ``mean`` and ``variance`` are read from
    the normalizer via ``.mean.numpy()`` and ``.variance.numpy()``.

    The arithmetic is carried out in float64. Keras typically hands back
    float32 statistics and predictions, and float32 can only resolve about
    0.5 m at UTM-northing magnitudes (~6.6e6 m); adding the mean in float32
    would therefore snap every de-normalised coordinate to a coarse grid.

    Args:
        normalizer: Object exposing ``mean`` and ``variance`` attributes, each
            with a ``.numpy()`` method (e.g. an adapted Keras ``Normalization``
            layer).
        data: Normalised values, broadcastable against the normalizer's
            statistics.

    Returns:
        A float64 NumPy array holding the values in original units.
    """
    mean = np.asarray(normalizer.mean.numpy(), dtype=np.float64)
    # The square root is taken in the statistics' own dtype first and only then
    # widened, so the scale matches a std computed from the stored variance.
    std = np.asarray(np.sqrt(normalizer.variance.numpy()), dtype=np.float64)
    return np.asarray(data, dtype=np.float64) * std + mean


def predict_trajectory(model, source_points, source_normalizer, target_normalizer):
    """Run ``model`` on ``source_points`` and return predictions in target units.

    The points are passed through ``source_normalizer``, the model is evaluated
    on the normalised values, and the raw output is mapped back to original
    units with ``denormalize`` using ``target_normalizer``.

    Args:
        model: Trained model exposing ``predict``.
        source_points: Points to transform, in the units the source normalizer
            was adapted on.
        source_normalizer: Callable that normalises ``source_points``.
        target_normalizer: Normalizer whose statistics are used to undo the
            target normalisation.

    Returns:
        The de-normalised model predictions.
    """
    model_input = source_normalizer(source_points)
    model_output = model.predict(model_input)
    return denormalize(target_normalizer, model_output)


In [12]:

def _apply_similarity(points, rotation, translation, scale):
    """Apply ``translation + scale * rotation @ p`` to every row ``p`` of ``points``."""
    return np.asarray(points) @ (scale * rotation).T + translation


def _utm_to_wgs_frame(transformer, points_utm):
    """Run the rows of ``points_utm`` through ``transformer`` and return a lat/lon/alt DataFrame."""
    points_utm = np.asarray(points_utm)
    # pyproj transforms whole coordinate arrays in one call.
    lat, lon, alt = transformer.transform(points_utm[:, 0], points_utm[:, 1], points_utm[:, 2])
    return pd.DataFrame({'lat': lat, 'lon': lon, 'alt': alt})


def _add_track(figure, track_wgs, name, color=None):
    """Add one lat/lon track to ``figure`` as a markers+lines Scattermapbox trace."""
    # Without an explicit colour the marker is left unset so plotly picks one.
    style = {'marker': dict(color=color)} if color is not None else {}
    figure.add_trace(
        go.Scattermapbox(lat=track_wgs['lat'],
                         lon=track_wgs['lon'],
                         mode='markers+lines',
                         name=name,
                         **style))


results = ORBSLAMResults("~/msc/shared_data/orbslam-out-utm-nn-vabadusepst")

# Keyframe 0 is excluded from every trajectory built below.
trajectory_keyframes = results.keyframes[1:]

gps_trajectory_wgs = pd.DataFrame([(kf.gps.lat, kf.gps.lon, kf.gps.alt) for kf in trajectory_keyframes],
                                  columns=['lat', 'lon', 'alt'])
slam_trajectory = np.array([(kf.x, kf.y, kf.z) for kf in trajectory_keyframes])

# Create transformers for WGS84 (EPSG:4326) <-> UTM35N (EPSG:32635).
# They are called with (lat, lon, alt) on the WGS84 side.
wgs2utm = pyproj.Transformer.from_crs(4326, 32635)
utm2wgs = pyproj.Transformer.from_crs(32635, 4326)

# Convert GPS trajectory (WGS84) to UTM35N, reusing the fixes already collected above.
gps_trajectory_utm = np.column_stack(wgs2utm.transform(
    gps_trajectory_wgs['lat'].to_numpy(),
    gps_trajectory_wgs['lon'].to_numpy(),
    gps_trajectory_wgs['alt'].to_numpy()))

# Align SLAM trajectory to GPS trajectory (rotation R, translation t, scale c)
R, t, c = umeyama_alignment(slam_trajectory.T, gps_trajectory_utm.T, True)
aligned_slam_trajectory_utm = _apply_similarity(slam_trajectory, R, t, c)

# Convert aligned SLAM trajectory (UTM35N) to WGS84
aligned_slam_trajectory_wgs = _utm_to_wgs_frame(utm2wgs, aligned_slam_trajectory_utm)

# Learn a residual correction from the aligned SLAM trajectory to the GPS trajectory.
model, source_normalizer, target_normalizer = fit_trajectory(
    aligned_slam_trajectory_utm, gps_trajectory_utm, epochs=400)

# Note: this evaluates the model on the same points it was trained on.
fitted_slam_trajectory_utm = predict_trajectory(
    model, aligned_slam_trajectory_utm, source_normalizer, target_normalizer)
fitted_slam_trajectory_wgs = _utm_to_wgs_frame(utm2wgs, fitted_slam_trajectory_utm)

# NOTE(review): the estimates expose lat/lon/alt attributes, yet they are pushed
# through the SLAM->UTM alignment exactly like the raw SLAM x/y/z positions.
# Confirm these attributes really hold SLAM-frame coordinates.
slam_estimates = np.array([(e.lat, e.lon, e.alt) for e in results.slam_estimates])
aligned_slam_estimate_utm = _apply_similarity(slam_estimates, R, t, c)
fitted_slam_estimate_utm = predict_trajectory(
    model, aligned_slam_estimate_utm, source_normalizer, target_normalizer)
fitted_slam_estimate_wgs = _utm_to_wgs_frame(utm2wgs, fitted_slam_estimate_utm)

fig = go.Figure()
_add_track(fig, gps_trajectory_wgs, 'GPS', color='blue')
_add_track(fig, aligned_slam_trajectory_wgs, 'SLAM', color='red')
_add_track(fig, fitted_slam_trajectory_wgs, 'fitted SLAM', color='forestgreen')
_add_track(fig, fitted_slam_estimate_wgs, 'fitted SLAM estimate')
# NOTE(review): update_geos targets `geo` subplots, while every trace here is a
# Scattermapbox; this call presumably has no visible effect - candidate for removal.
fig.update_geos(projection_type="transverse mercator")
fig.update_layout(mapbox_style="open-street-map",
                  mapbox=dict(center=dict(lat=np.mean(gps_trajectory_wgs['lat']),
                                          lon=np.mean(gps_trajectory_wgs['lon'])),
                              zoom=15),
                  margin={"t": 0, "b": 0, "l": 0, "r": 0},
                  height=800)
fig.show()


