In [None]:
import numpy as np
import gtsam
from gtsam.symbol_shorthand import X, P
import matplotlib
import matplotlib.animation  # submodule is not loaded by `import matplotlib` alone
import matplotlib.pyplot as plt
from IPython.display import HTML
import tqdm
import tqdm.notebook  # submodule is not loaded by `import tqdm` alone
import chebyshev_fitter, spline_fitter, sln_letter_fit
from sln_letter_fit import FitParams, OptimizationLoggingParams
from sln_stroke_fit import SlnStrokeFit
from art_skills import SlnStrokeExpression
import loader
from fit_types import Solution

%load_ext autoreload
%autoreload 1
%aimport chebyshev_fitter, spline_fitter, sln_letter_fit, loader, sln_stroke_fit

# First, pretend we already have the stroke splits

In [None]:
# Load the pre-segmented strokes for the letter 'D'. Later cells iterate over each
# stroke as rows of (t, x, y).
letter = loader.load_segments('D')
dt = 1. / 120  # time step handed to the fitter (presumably a 120 Hz capture rate -- confirm)
fit_params = FitParams()

# Both noise models are 2-dimensional isotropic; their sigmas come from FitParams.
integration_noise = gtsam.noiseModel.Isotropic.Sigma(2, fit_params.noise_integration_std)
data_prior_noise = gtsam.noiseModel.Isotropic.Sigma(2, fit_params.noise_data_prior_std)
fitter = SlnStrokeFit(
    dt,
    integration_noise_model=integration_noise,
    data_prior_noise_model=data_prior_noise,
    reparameterize=fit_params.reparameterize,
    flip_parameters_at_end=fit_params.flip_parameters_at_end,
)
stroke_indices = fitter.stroke_indices(letter)

# Prior on the 6 stroke parameters: a fixed mean vector with a loose (sigma = 50)
# isotropic noise model. The mean is also reused as the initial guess for each stroke.
stroke_param_mean = np.array([
    -0.3824788860477358, 0.4333073800807866, 0.3254905871611055,
    0.20628570778082297, 0.23762343512462863, -1.104228059293241,
])
stroke_param_prior = gtsam.noiseModel.Isotropic.Sigma(6, 50)

# The incremental-fit cell reads getErrorBefore()/getErrorAfter() from each update
# result, so ask iSAM2 to evaluate the nonlinear error.
isam2params = gtsam.ISAM2Params()
isam2params.evaluateNonlinearError = True


In [None]:
# Incremental (iSAM2) fit: feed the data points in one at a time and store the full
# estimate after every point in `estimates`.
MAX_EXTRA_UPDATES = 20  # cap on extra `isam.update()` calls (no new factors) per data point
MAX_ADJUSTED_COST_INCREASE = 1e-2  # stop iterating once the adjusted increase drops below this


def adjusted_cost_increase(before, after, k):
    """Relative change from `before` to `after`, with `after` scaled by (k - 1) / k.

    The scaling is applied before comparing the error at step `k` with the error of
    the previous step (presumably to compensate for the extra data point -- confirm).
    Callers must ensure `before != 0` and `k != 0`.
    """
    return (after * (k - 1) / k - before) / before


def criteria(before, after, k):
    """Return True when no further updates are needed for step `k`.

    A previous error of exactly 0 (e.g. the very first data point) always passes;
    otherwise the adjusted cost increase must be below MAX_ADJUSTED_COST_INCREASE.
    """
    if before == 0:
        return True
    return adjusted_cost_increase(before, after, k) < MAX_ADJUSTED_COST_INCREASE


# Only the first stroke is fitted for now. Slicing (rather than breaking out of the
# stroke loop) keeps the progress-bar total equal to the number of points processed.
strokes_to_fit = letter[:1]

k = 0  # running time-step index across strokes; X(k) is the 2D position at step k
isam = gtsam.ISAM2(isam2params)
prev_error = 0
estimates = []
num_points = sum(stroke.shape[0] for stroke in strokes_to_fit)
with tqdm.notebook.trange(num_points) as progress_bar:
    progress_bar.set_description('Optimizing stroke')
    for strokei, stroke in enumerate(strokes_to_fit):
        for datai, (t, x, y) in enumerate(stroke):
            # New factors: a data prior on this sample, a stroke factor linking it to
            # the previous step, and (at the start of a stroke) a prior on its parameters.
            graph = gtsam.NonlinearFactorGraph()
            graph.push_back(fitter.data_prior_factors(np.array([[t, x, y]])))
            if k > 0:
                graph.push_back(fitter.stroke_factors(strokei, k - 1, k))
            if datai == 0:
                graph.addPriorVector(P(strokei), stroke_param_mean, stroke_param_prior)

            # New variables and their initial values.
            init = gtsam.Values()
            if datai == 0:
                init.insert(P(strokei), stroke_param_mean)
            if k == 0:
                init.insert(X(k), np.array([x, y]))
            else:
                # Predict X(k) from the previous position plus the displacement of the
                # current stroke estimate. On the first sample of a new stroke its
                # parameters are not in iSAM yet, so the previous stroke's are used.
                param_strokei = strokei if datai > 0 else strokei - 1
                stroke_params = isam.calculateEstimateVector(P(param_strokei))
                stroke_eval = SlnStrokeExpression(stroke_params)
                xy_init = isam.calculateEstimatePoint2(X(k - 1)) + stroke_eval.displacement(t, dt)
                init.insert(X(k), xy_init)

            # Update with the new factors, then keep calling update() without new
            # factors until the adjusted cost increase is acceptable or the cap is hit.
            result = isam.update(graph, init)
            last_reported_error = 0
            for _ in range(MAX_EXTRA_UPDATES):
                if criteria(prev_error, result.getErrorAfter(), k):
                    break
                result = isam.update()

                error_after = result.getErrorAfter()
                if error_after != last_reported_error:
                    # Refresh the progress-bar postfix only when the error changed.
                    last_reported_error = error_after
                    adj_cost = (adjusted_cost_increase(prev_error, error_after, k)
                                if k > 0 else float('nan'))
                    progress_bar.set_postfix(errorBefore=f'{result.getErrorBefore():.2e}',
                                             errorAfter=f'{error_after:.2e}',
                                             adjustedCostIncrease=f'{adj_cost:.3f}')
            prev_error = result.getErrorAfter()
            estimates.append(isam.calculateEstimate())

            k += 1
            progress_bar.update()
        # Disabled experiment: pin the last point of the stroke with a tight prior.
        # graph = gtsam.NonlinearFactorGraph()
        # graph.addPriorVector(X(k - 1), np.array([x, y]), gtsam.noiseModel.Isotropic.Sigma(2, 0.05))
        # isam.update(graph, gtsam.Values())


In [None]:
def create_sol(values, num_strokes, k):
    """Package one incremental estimate as a `Solution` for plotting/animation.

    Args:
        values: Estimate to query through `fitter` (one entry of `estimates`).
        num_strokes: Ignored. The number of strokes is recomputed from
            `stroke_indices` and `k`; the parameter is kept so existing callers
            keep working.
        k: Index of the last time step covered by `values`.

    Returns:
        A `Solution` with:
          - params: stroke parameters for every stroke whose start index is < k,
          - txy: (t, x, y) queried from `values` at t = 0, dt, ..., k * dt,
          - txy_from_params: (t, x, y) trajectory regenerated from `params`,
          - stroke_indices: `{0: (start of stroke 0, k)}`.
    """
    stroke_indices_ = {0: (stroke_indices[0][0], k)}
    # Build times as integer index * dt. `np.arange(0, (k + 1) * dt, dt)` yields the
    # same values, but with a float step its length can come out one too long.
    t = (np.arange(k + 1) * dt).reshape(-1, 1)
    x0 = fitter.query_estimate_at(values, 0)

    def query(t):
        # Times the fitter cannot evaluate fall back to zeros instead of aborting.
        # NOTE(review): the fallback shape (1, 2) presumably has to match what
        # `query_estimate_at` returns for the hstack below to succeed -- confirm.
        try:
            return fitter.query_estimate_at(values, t)
        except Exception:  # not a bare except: let KeyboardInterrupt/SystemExit through
            return np.zeros((1, 2))

    txy = np.hstack((t, [query(t_) for t_ in t]))
    # NOTE(review): the strict `>` means a stroke starting exactly at k is not
    # counted (so k == start yields no parameters for it) -- confirm this is intended.
    num_started_strokes = sum(1 for start, _ in stroke_indices.values() if k > start)
    params = [fitter.query_parameters(values, strokei) for strokei in range(num_started_strokes)]
    # The full `stroke_indices` (not `stroke_indices_`) is passed here, as in the
    # original code; the animation also plots samples of this trajectory beyond step k.
    txy_from_params = fitter.compute_trajectory_from_parameters(x0, params, stroke_indices)
    # One timestamp per trajectory row; index * dt guarantees the row counts match.
    t_from_params = (np.arange(txy_from_params.shape[0]) * dt).reshape(-1, 1)
    txy_from_params = np.hstack((t_from_params, txy_from_params))
    return Solution(params=params,
                    txy=txy,
                    txy_from_params=txy_from_params,
                    stroke_indices=stroke_indices_)
# One Solution snapshot per incremental estimate; the position of an estimate in
# `estimates` is passed to create_sol as its last time-step index.
history = [
    create_sol(values_at_step, 1, step)
    for step, values_at_step in enumerate(estimates)
]
# Pair the final snapshot with the full list of snapshots.
sol_and_history = (history[-1], history)

In [None]:
# Ground truth: all strokes stacked into one array with columns (t, x, y).
txy_gt = np.vstack(letter)

# Square 8x8 canvas that the animation frames are drawn on.
fig, ax = plt.subplots(figsize=(8, 8))
def update(i):
    """Redraw `ax` for animation frame `i`.

    Draws, in order:
      - black dots: ground-truth samples 0..i,
      - red dots: the estimated trajectory (`txy`) of snapshot `i`, samples 0..i,
      - red line: the trajectory regenerated from the parameters, samples 0..i,
      - blue line: the next samples (i..i+9) of that regenerated trajectory,
      - a text box (lower right) listing the parameters of each stroke, with
        spaces rendered as '~'.
    """
    ax.cla()

    seen = slice(None, i + 1)
    upcoming = slice(i, i + 10)

    ax.plot(txy_gt[seen, 1], txy_gt[seen, 2], 'k.')

    snapshot = history[i]
    estimated = snapshot['txy']
    ax.plot(estimated[seen, 1], estimated[seen, 2], 'r.', markersize=1)

    from_params = snapshot['txy_from_params']
    ax.plot(from_params[seen, 1], from_params[seen, 2], 'r-', markersize=1)
    ax.plot(from_params[upcoming, 1], from_params[upcoming, 2], 'b-', markersize=1)

    # One "params: ..." line per stroke; '~' replaces the padding spaces.
    lines = []
    for stroke_params in history[i]['params']:
        formatted = ''.join(f'{value:6.2f}'.replace(' ', '~') for value in stroke_params)
        lines.append('params: ' + formatted)
    text = '\n'.join(lines)

    box_options = dict(frameon=True,
                       loc='lower right',
                       pad=0.3,
                       bbox_to_anchor=(1, 0),
                       bbox_transform=ax.transAxes,
                       borderpad=0,
                       prop=dict(size=9))
    text_box = matplotlib.offsetbox.AnchoredText(text, **box_options)
    text_box.patch.set_alpha(0.4)
    ax.add_artist(text_box)

    # Equal aspect ratio with a fixed vertical range.
    ax.axis('equal')
    ax.set_ylim(-1.1, -0.1)

# with tqdm.notebook.trange(0, len(letter[0]), 1) as progress_bar:
#     progress_bar.set_description('Saving Animation')
#     anim = matplotlib.animation.FuncAnimation(ax.figure, update, frames=progress_bar)
#     anim.save('results/incremental_D_stroke0.mp4', writer=matplotlib.animation.FFMpegWriter(fps=15))
# anim = plotting.animate_trajectories(ax, [letter], [sol_and_history], is_notebook=True)
# Render one frame per snapshot in `history` and embed the result as an interactive
# JS/HTML player. The progress bar doubles as the frame iterator.
n_frames = len(history)
with tqdm.notebook.trange(n_frames) as progress_bar:
    progress_bar.set_description('Displaying Animation')
    anim = matplotlib.animation.FuncAnimation(ax.figure, update, frames=progress_bar)
    # Close the figure before rendering, as the original cell does.
    plt.close(fig)
    display(HTML(anim.to_jshtml()))
