In [1]:
import numpy as np
from matplotlib import pyplot as plt
import plotly.express as px
import plotly.graph_objs as go
import plotly.io as pio

from torchkf import *
pio.templates.default = "plotly_white"

In [19]:
prior = dotdict(mu=8, sigma=1)
t     = np.linspace(-10 + prior.mu, 10 + prior.mu, 100)
fe    = lambda x: (x - prior.mu)**2/(2 * prior.sigma**2) + np.log(np.sqrt(2 * np.pi * prior.sigma**2))
F     = fe(t)
dFdx  = lambda x: (x-prior.mu) / prior.sigma**2 

figw  = go.FigureWidget()
figw.add_scatter(x=t, y=fe(t), fill='tozeroy', fillcolor='black', line_color='white')

FigureWidget({
    'data': [{'fill': 'tozeroy',
              'fillcolor': 'black',
              'line': {'co…

In [20]:
x = np.random.randint(-5, 5)
if len(figw.data) < 2: 
    figw.add_scatter(x=[], y=[], marker_size=10, marker_color='red', mode='markers', opacity=0.7)
xs = []
ys = []

dt = 1.
for i in range(10000):    
    if np.isclose(x, prior.mu): 
        break
    
    a = -dFdx(x)
    x += dt * a + np.random.normal(0, 0.5)
    
    xs.append(x)
    ys.append(fe(x))
    if (i % 100) == 20: 
        figw.data[1]['x'] = xs[-20:]
        figw.data[1]['y'] = ys[-20:]

In [21]:
k1, k2, k3, k4 = 0.5, 0.5, 0.7, 6.
sm    = lambda a, u: k1 * u * a  + k2 * u
am    = lambda sm, u: k3 * sm + k4 * u
dsmda = lambda a, u: k1 * u
dsmdu = lambda a, u: k2 + 2* k1 * a
damdsm = lambda sm, u: k3 
damdu = lambda sm, u: k4 #+ k3 * sm
dFda  = lambda a, u: dFdx(am(sm(a, u) + np.random.normal(0,1.), u) + np.random.normal(0,1.)) * damdsm(sm(a, u) + np.random.normal(0,1.), u) * dsmda(a, u)

figw = go.FigureWidget()
figw.add_scatter(x=t, y=F, fill='tozeroy', fillcolor='black', line_color='white')

FigureWidget({
    'data': [{'fill': 'tozeroy',
              'fillcolor': 'black',
              'line': {'co…

In [25]:
x = np.random.randint(-10,10)
prop = dotdict(mu=0, sigma=1)
if len(figw.data) < 2: 
    figw.add_scatter(x=[], y=[], marker_size=10, marker_color='red', mode='markers', opacity=0.7)
xs = []
ys = []
ac = []
us = []
dt = 0.001
a, u, r = 0, 0, 0
for i in range(30000):    
    if np.isclose(prop.mu, prior.mu) and np.isclose(prop.sigma, prior.sigma): 
        break
    if np.isnan(prop.mu) or np.isnan(prop.sigma):
        break
    
    # Generate noise with drift
    r   += dt * np.random.normal(0.1, .1)
    uest = np.random.normal(0.0, 5.) + u + r
    da   = -dFda(a, uest) 
    
    dF_pred = (x - prop.mu) / prop.sigma **2 
    dF_true = dFdx(am(sm(a, u), u)) 
    
    prop.mu    -= dt * 2. * (dF_true - dF_pred) / prop.sigma**2
    
#     dF_pred = (x - prop.mu) / prop.sigma ** 
    prop.sigma  = prop.sigma * np.exp(-dt * 2 * (dF_true - dF_pred) * dF_pred) 
    
    du  = -(dF_pred) * ( damdsm(sm(a, u), u) * dsmdu(a, u) + damdu(sm(a, u), u)) - u/100.
    
    a   += da * dt
    u   += du * dt
    
    ac.append(a)
    us.append(u)
    
    x = am(sm(a, u), u) + np.random.normal(0, 1)
    
    xs.append(x)
    ys.append(fe(x))
    if (i % 100) == 20: 
        figw.data[1]['x'] = xs[-100::2]
        figw.data[1]['y'] = ys[-100::2]
print(prop)
px.line({'Fe': [fe(x) for x in xs], 'a': ac, 'u': us, 'e': [x - prior.mu for x in xs]})

{'mu': 7.469620269089788, 'sigma': 2.4527524985755287}
