In [None]:
import jax
import jax.numpy as jnp
import jax.scipy.special as sp
import numpy as onp

import scipy
from functools import partial
import flax.linen as nn
import optax
import diffrax
import matplotlib.pyplot as plt
from tqdm import tqdm
import sde.markov_approximation as ma
from sde.models import FractionalSDE, StaticFunction, Function
import distrax
from ipywidgets import interact
import pandas as pd
import seaborn as sns
# Plot theme: white grid, LaTeX text rendering, 'paper' context at font_scale 1.
# (Passing font_scale together with a later set_context('paper') call has no effect,
# because set_context resets the scale to its default of 1.)
sns.set_theme(context='paper', style='whitegrid', rc={'text.usetex': True})

In [None]:
def solve_diffrax(params, model, x0, y0, ts, dt, key, solver=diffrax.Euler(), args=None):
    """Integrate the Markov-approximated fractional SDE together with its control cost.

    The integrated state is the tuple ``(x, y, c)``:
      * ``x`` -- observed process, started at ``x0``;
      * ``y`` -- auxiliary OU-type states with speeds ``model.gamma``, started at
        ``y0.T`` (so ``y0`` is laid out as ``(num_latents, num_k)``);
      * ``c`` -- running integral of ``0.5 * u**2``, started at zero.

    Args:
        params: parameter pytree forwarded to the model's component functions.
        model: object providing ``hurst``, ``omega``, ``u``, ``b``, ``s``, ``gamma``,
            ``num_k`` and ``num_latents``.
        x0, y0: initial values of ``x`` and ``y`` (see above for the ``y0`` layout).
        ts: save times; integration runs from ``ts[0]`` to ``ts[-1]``.
        dt: step size handed to diffrax as ``dt0`` (no step-size controller is used).
        key: PRNG key for the driving noise path.
        solver: diffrax solver instance.
        args: forwarded to ``b``, ``u`` and ``s``.

    Returns:
        ``(xs, cost)``: ``x`` at every time in ``ts`` and the final value of ``c``.
    """
    def weights_at(t):
        # Approximation weights omega(H(t)) for the current (learned) Hurst index.
        return model.omega(model.hurst(params, t))

    def drift(t, state, args):
        x, y, _ = state
        omega = weights_at(t)
        control = model.u(params, t, x, y, args)
        y_rate = - model.gamma[:, None] * y + control[None, :]
        b_val = model.b(params, t, x, args)
        s_val = model.s(params, t, x, args)
        # The drift of y enters x through the omega-weighted sum over the OU modes.
        x_rate = b_val + s_val * (omega[:, None] * y_rate).sum(axis=0)
        return x_rate, y_rate, .5 * control ** 2

    def diffusion(t, state, args):
        x = state[0]
        omega = weights_at(t)
        return (
            model.s(params, t, x, args) * omega.sum(),
            jnp.ones((model.num_k, model.num_latents)),
            # The cost accumulator is deterministic.
            jnp.zeros(model.num_latents),
        )

    initial_state = (x0, y0.T, jnp.zeros(model.num_latents))
    noise_path = ma.CustomPath(ts[0], ts[-1], dt / 10, model.num_latents, model.num_k, key)
    terms = diffrax.MultiTerm(
        diffrax.ODETerm(drift),
        diffrax.WeaklyDiagonalControlTerm(diffusion, noise_path),
    )
    solution = diffrax.diffeqsolve(
        terms,
        solver,
        t0=ts[0],
        t1=ts[-1],
        dt0=dt,
        y0=initial_state,
        saveat=diffrax.SaveAt(ts=ts),
        max_steps=4096,
        args=args,
    )
    xs, _, cost = solution.ys
    return xs, cost[-1]


class HurstNet(nn.Module):
    """Small MLP mapping time features to a Hurst index in (0, 1).

    One tanh hidden layer of width 10 followed by a sigmoid-squashed scalar output.
    Submodules are auto-named ``Dense_0`` (hidden) and ``Dense_1`` (output);
    ``FractionalSDEMBM.init`` refers to ``Dense_1`` by name.
    """

    @nn.compact
    def __call__(self, x):
        hidden = nn.tanh(nn.Dense(10)(x))
        return nn.sigmoid(nn.Dense(1)(hidden))


class FractionalSDEMBM(FractionalSDE):
    """Fractional SDE with a time-varying, learned Hurst index H(t).

    H(t) is produced by a `HurstNet` on the features [sin t, cos t, t]. The fractional
    noise uses the Markov approximation from `sde.markov_approximation`: auxiliary
    OU-type states with speeds `gamma`, combined with weights computed from H by
    `omega_fn`. `type` selects the weight scheme (`ma.omega_optimized_1` or
    `ma.omega_optimized_2`) and how the auxiliary states are initialised in `__call__`.
    """

    def __init__(
            self,
            b: Function,
            u: Function,
            s: Function,
            gamma: jnp.ndarray,
            type: int = 1,
            time_horizon: float = 1.,
            num_latents: int = 1,
        ):
        self.gamma = gamma
        self.type = type
        self.num_latents = num_latents
        self._b = b
        self._u = u
        self._s = s
        self._hurst = None
        self._hurst_net = HurstNet()

        # omega_fn maps a Hurst index to the approximation weights. It is presumably
        # consumed by the base class's `omega` (not visible here).
        if type == 1:
            self.omega_fn = jax.jit(lambda hurst: ma.omega_optimized_1(self.gamma, hurst, time_horizon))
        elif type == 2:
            self.omega_fn = jax.jit(lambda hurst: ma.omega_optimized_2(self.gamma, hurst, time_horizon))
        else:
            raise ValueError('type must be either 1 or 2')

    def init(self, key):
        """Initialise the parameters of `b`, `u`, `s` and the Hurst network.

        Args:
            key: JAX PRNG key, split into one sub-key per component.

        Returns:
            Dict with entries 'b', 'u', 's' and 'hurst'.
        """
        keys = jax.random.split(key, 4)
        params = {}

        params['b'] = self._b.init(keys[0])
        params['u'] = self._u.init(keys[1])
        params['s'] = self._s.init(keys[2])

        # keys[3] is the last of the four split keys (index 4 would be out of range).
        params['hurst'] = self._hurst_net.init(keys[3], jnp.zeros(3))
        # Zero the output-layer kernel so H(t) starts constant, at the sigmoid of the output bias.
        params['hurst']['params']['Dense_1']['kernel'] *= 0.
        return params

    def hurst(self, params, t):
        """Evaluate the learned Hurst index H(t), a scalar in (0, 1), at scalar time `t`."""
        return self._hurst_net.apply(params['hurst'], jnp.array([jnp.sin(t), jnp.cos(t), t])).squeeze()

    def __call__(self, params, key, x0, ts, dt, solver='euler', args=None):
        """Sample one path and its accumulated control cost.

        Args:
            params: output of `init` (possibly trained).
            key: PRNG key, used for the initial auxiliary state and the noise path.
            x0: initial state.
            ts: times at which the path is returned.
            dt: solver step size.
            solver: a diffrax solver instance. The string 'euler' (the default) is not
                supported and raises NotImplementedError.
            args: forwarded to the drift, control and diffusion functions.

        Returns:
            `(xs, log_path)` as returned by `solve_diffrax`.
        """
        keys = jax.random.split(key, 4)

        # NOTE(review): self.num_k is not set in __init__; presumably FractionalSDE
        # provides it (e.g. derived from gamma). Confirm.
        if self.type == 1:
            # Presumably the stationary covariance of OU states driven by a shared
            # Brownian motion: Cov(y_i, y_j) = 1 / (gamma_i + gamma_j).
            cov = 1 / (self.gamma[None, :] + self.gamma[:, None])
            y0 = jax.random.multivariate_normal(keys[2], jnp.zeros((self.num_latents, self.num_k)), cov)
        elif self.type == 2:
            # Auxiliary states start at zero.
            y0 = jnp.zeros((self.num_latents, self.num_k))

        if solver == 'euler':
            # `NotImplemented` is a constant, not an exception; NotImplementedError is the exception.
            raise NotImplementedError("solver='euler' is not supported; pass a diffrax solver instance")

        xs, log_path = solve_diffrax(params, self, x0, y0, ts, dt, keys[3], solver, args)
        return xs, log_path

In [None]:
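# Data from "Learning Fractional White Noises in Neural Stochastic Differential Equations"
# (Tong et al.): observation times `ts`, observed path `xs_true`, ground-truth Hurst function `hurst_true`.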

ts = jnp.array([0.1000, 0.1101, 0.1201, 0.1302, 0.1402, 0.1503, 0.1603, 0.1704, 0.1804,
         0.1905, 0.2005, 0.2106, 0.2206, 0.2307, 0.2407, 0.2508, 0.2608, 0.2709,
         0.2809, 0.2910, 0.3010, 0.3111, 0.3211, 0.3312, 0.3412, 0.3513, 0.3613,
         0.3714, 0.3814, 0.3915, 0.4015, 0.4116, 0.4216, 0.4317, 0.4417, 0.4518,
         0.4618, 0.4719, 0.4819, 0.4920, 0.5020, 0.5121, 0.5221, 0.5322, 0.5422,
         0.5523, 0.5623, 0.5724, 0.5824, 0.5925, 0.6025, 0.6126, 0.6226, 0.6327,
         0.6427, 0.6528, 0.6628, 0.6729, 0.6829, 0.6930, 0.7030, 0.7131, 0.7231,
         0.7332, 0.7432, 0.7533, 0.7633, 0.7734, 0.7834, 0.7935, 0.8035, 0.8136,
         0.8236, 0.8337, 0.8437, 0.8538, 0.8638, 0.8739, 0.8839, 0.8940, 0.9040,
         0.9141, 0.9241, 0.9342, 0.9442, 0.9543, 0.9643, 0.9744, 0.9844, 0.9945,
         1.0045, 1.0146, 1.0246, 1.0347, 1.0447, 1.0548, 1.0648, 1.0749, 1.0849,
         1.0950, 1.1050, 1.1151, 1.1251, 1.1352, 1.1452, 1.1553, 1.1653, 1.1754,
         1.1854, 1.1955, 1.2055, 1.2156, 1.2256, 1.2357, 1.2457, 1.2558, 1.2658,
         1.2759, 1.2859, 1.2960, 1.3060, 1.3161, 1.3261, 1.3362, 1.3462, 1.3563,
         1.3663, 1.3764, 1.3864, 1.3965, 1.4065, 1.4166, 1.4266, 1.4367, 1.4467,
         1.4568, 1.4668, 1.4769, 1.4869, 1.4970, 1.5070, 1.5171, 1.5271, 1.5372,
         1.5472, 1.5573, 1.5673, 1.5774, 1.5874, 1.5975, 1.6075, 1.6176, 1.6276,
         1.6377, 1.6477, 1.6578, 1.6678, 1.6779, 1.6879, 1.6980, 1.7080, 1.7181,
         1.7281, 1.7382, 1.7482, 1.7583, 1.7683, 1.7784, 1.7884, 1.7985, 1.8085,
         1.8186, 1.8286, 1.8387, 1.8487, 1.8588, 1.8688, 1.8789, 1.8889, 1.8990,
         1.9090, 1.9191, 1.9291, 1.9392, 1.9492, 1.9593, 1.9693, 1.9794, 1.9894,
         1.9995, 2.0095, 2.0196, 2.0296, 2.0397, 2.0497, 2.0598, 2.0698, 2.0799,
         2.0899, 2.1000])
xs_true = jnp.array([1.0000, 0.9937, 1.0005, 0.9876, 0.9796, 0.9791, 0.9867, 0.9856, 0.9892,
         0.9907, 0.9864, 0.9746, 0.9900, 0.9949, 1.0365, 1.0618, 1.0819, 1.1109,
         1.1092, 1.1118, 1.1220, 1.1211, 1.1225, 1.1280, 1.1270, 1.1320, 1.1024,
         1.1070, 1.1285, 1.1464, 1.1806, 1.1934, 1.2075, 1.2029, 1.2125, 1.1870,
         1.1946, 1.2084, 1.2360, 1.2497, 1.2472, 1.2509, 1.2662, 1.2734, 1.2621,
         1.2656, 1.2658, 1.2818, 1.2558, 1.2291, 1.2402, 1.2648, 1.2644, 1.2727,
         1.2878, 1.2908, 1.2989, 1.2770, 1.2863, 1.2539, 1.2127, 1.2244, 1.2001,
         1.2045, 1.2041, 1.2243, 1.2070, 1.2022, 1.1823, 1.1668, 1.1811, 1.1936,
         1.1953, 1.2049, 1.2100, 1.1610, 1.1435, 1.1079, 1.0878, 1.1071, 1.0700,
         1.0551, 1.0690, 1.0428, 1.0379, 1.0231, 0.9947, 1.0249, 1.0539, 1.0343,
         1.0240, 0.9941, 1.0338, 1.0587, 1.0798, 1.1159, 1.1448, 1.1298, 1.1363,
         1.1637, 1.1520, 1.0876, 1.2128, 1.2526, 1.1630, 1.1583, 1.2025, 1.1835,
         1.1492, 1.1744, 1.1761, 1.1400, 1.2283, 1.2251, 1.2691, 1.2795, 1.1789,
         1.1070, 1.2180, 1.1753, 1.2512, 1.2314, 1.0933, 1.1711, 1.2679, 1.0846,
         1.1814, 1.2447, 1.2575, 1.2691, 1.0366, 1.0732, 0.9331, 0.9764, 0.9705,
         0.9592, 0.9539, 1.0171, 1.0782, 0.9953, 0.9362, 0.8997, 0.8441, 0.7643,
         0.7771, 0.7892, 0.7191, 0.6609, 0.7832, 0.9225, 1.1819, 1.2010, 1.2220,
         1.2501, 1.1765, 1.4034, 1.5753, 1.5900, 1.4096, 1.5693, 1.3191, 1.4366,
         1.5673, 1.7884, 2.3375, 2.1191, 2.0380, 2.0810, 2.4327, 2.5219, 2.8853,
         3.0536, 2.5550, 2.7076, 2.5918, 2.4321, 2.5801, 2.5948, 2.1993, 1.8813,
         2.3859, 2.6075, 2.8679, 2.4236, 2.0053, 2.0748, 1.7241, 1.4495, 1.4841,
         1.4260, 1.7823, 2.3027, 2.0535, 2.3248, 2.0475, 1.5701, 1.4721, 1.1525,
         0.9354, 0.9513])
hurst_true = jnp.array([0.7995, 0.7995, 0.7995, 0.7994, 0.7994, 0.7994, 0.7993, 0.7993, 0.7992,
         0.7991, 0.7991, 0.7990, 0.7989, 0.7989, 0.7988, 0.7987, 0.7986, 0.7985,
         0.7984, 0.7983, 0.7981, 0.7980, 0.7979, 0.7977, 0.7975, 0.7974, 0.7972,
         0.7970, 0.7968, 0.7965, 0.7963, 0.7960, 0.7957, 0.7954, 0.7951, 0.7947,
         0.7943, 0.7939, 0.7935, 0.7930, 0.7925, 0.7920, 0.7914, 0.7908, 0.7901,
         0.7894, 0.7887, 0.7879, 0.7870, 0.7861, 0.7851, 0.7840, 0.7829, 0.7817,
         0.7804, 0.7791, 0.7776, 0.7761, 0.7744, 0.7726, 0.7708, 0.7688, 0.7666,
         0.7644, 0.7620, 0.7594, 0.7567, 0.7539, 0.7508, 0.7476, 0.7442, 0.7407,
         0.7369, 0.7329, 0.7287, 0.7243, 0.7197, 0.7148, 0.7097, 0.7044, 0.6988,
         0.6930, 0.6870, 0.6807, 0.6742, 0.6675, 0.6605, 0.6533, 0.6460, 0.6384,
         0.6306, 0.6226, 0.6145, 0.6062, 0.5978, 0.5892, 0.5806, 0.5719, 0.5632,
         0.5544, 0.5456, 0.5368, 0.5281, 0.5194, 0.5108, 0.5022, 0.4938, 0.4855,
         0.4774, 0.4694, 0.4616, 0.4540, 0.4467, 0.4395, 0.4325, 0.4258, 0.4193,
         0.4130, 0.4070, 0.4012, 0.3956, 0.3903, 0.3852, 0.3803, 0.3757, 0.3713,
         0.3671, 0.3631, 0.3593, 0.3558, 0.3524, 0.3492, 0.3461, 0.3433, 0.3406,
         0.3380, 0.3356, 0.3334, 0.3312, 0.3292, 0.3274, 0.3256, 0.3239, 0.3224,
         0.3209, 0.3196, 0.3183, 0.3171, 0.3160, 0.3149, 0.3139, 0.3130, 0.3121,
         0.3113, 0.3106, 0.3099, 0.3092, 0.3086, 0.3080, 0.3075, 0.3070, 0.3065,
         0.3061, 0.3057, 0.3053, 0.3049, 0.3046, 0.3043, 0.3040, 0.3037, 0.3035,
         0.3032, 0.3030, 0.3028, 0.3026, 0.3025, 0.3023, 0.3021, 0.3020, 0.3019,
         0.3017, 0.3016, 0.3015, 0.3014, 0.3013, 0.3012, 0.3011, 0.3011, 0.3010,
         0.3009, 0.3009, 0.3008, 0.3007, 0.3007, 0.3006, 0.3006, 0.3006, 0.3005,
         0.3005, 0.3005])

# Hurst-function estimate of Tong et al., extracted as (t, H) points from
# https://github.com/anh-tong/fractional_neural_sde/blob/main/fractional_neural_sde/example.ipynb
extracted_other = onp.array([
    [0.04205607476635531, 0.6878504672897197],
    [0.05368280052540575, 0.6874468774188118],
    [0.07808439564177938, 0.6860271482484046],
    [0.1024859907581529, 0.6844138196456692],
    [0.12688758587452653, 0.6826714247547149],
    [0.15128918099090005, 0.6807354304314324],
    [0.17569077610727357, 0.6786381032478763],
    [0.2000923712236472, 0.6764762429202108],
    [0.22449396634002072, 0.6741853163043264],
    [0.24889556145639424, 0.6719589228325515],
    [0.2732971565727679, 0.6699583953651596],
    [0.2976987516891414, 0.6684095999065336],
    [0.322100346805515, 0.6678610681816035],
    [0.34650194192188855, 0.6687645321991353],
    [0.37090353703826207, 0.6719266562604969],
    [0.3953051321546357, 0.6781541046670556],
    [0.4163792370278674, 0.6864466136851158],
    [0.43523501507233786, 0.6947391227031761],
    [0.4477383117435375, 0.7026766994286344],
    [0.46044327449007927, 0.7112596075951869],
    [0.4718374077468981, 0.7201296882530264],
    [0.4820663592350749, 0.7285258088191289],
    [0.49371817692149755, 0.7373632643874464],
    [0.5051123101783166, 0.7465366508226001],
    [0.5162039443221229, 0.7549485461572627],
    [0.5284047418803097, 0.7634077657976055],
    [0.5394963760241158, 0.7712754316169453],
    [0.5528063369966831, 0.7793374651203291],
    [0.5716621150411536, 0.7888445801007343],
    [0.5949545467431465, 0.7950236286492111],
    [0.6193561418595201, 0.7972822886930407],
    [0.6437577369758938, 0.7953785609418129],
    [0.6681593320922674, 0.79015137626895],
    [0.6914517637942603, 0.7826332649802029],
    [0.7114167052531113, 0.7748247545429634],
    [0.7291633198832013, 0.7671493437154495],
    [0.746909934513291, 0.7588528013758825],
    [0.7635473857290003, 0.7508414724857276],
    [0.7801848369447095, 0.7425259159161998],
    [0.7968222881604188, 0.7343117685731296],
    [0.8123505759617475, 0.7266553719755766],
    [0.8278788637630761, 0.7190496799912524],
    [0.8434071515644046, 0.7115961018466146],
    [0.8600446027801139, 0.7037939294860287],
    [0.8777912174102038, 0.6958523194390634],
    [0.8955378320402936, 0.6882212751481248],
    [0.914393610084764, 0.6804127647108853],
    [0.9343585515436152, 0.6724859435094451],
    [0.9543234930024664, 0.664953491522007],
    [0.9753975978756981, 0.6575196318380694],
    [0.998690029577691, 0.6495336552545289],
    [1.0230916246940647, 0.6417574113893441],
    [1.047493219810438, 0.6343683663888158],
    [1.0718948149268117, 0.6272051873926704],
    [1.0962964100431853, 0.6203646741170721],
    [1.1206980051595588, 0.6135886939855832],
    [1.1450996002759324, 0.6069417801423131],
    [1.169501195392306, 0.600423932587262],
    [1.1939027905086794, 0.5938092853160467],
    [1.218304385625053, 0.5871946380448314],
    [1.2427059807414267, 0.5805154576295066],
    [1.2671075758578, 0.573610411209799],
    [1.2915091709741737, 0.5664794987857082],
    [1.3159107660905474, 0.5591872535013441],
    [1.3403123612069208, 0.5515723424964328],
    [1.363604792908914, 0.5437638320591933],
    [1.385788061196526, 0.5360618013097342],
    [1.4068621660697578, 0.5284228696345157],
    [1.426827107528609, 0.5207721068828768],
    [1.44679204898746, 0.5128847226028369],
    [1.4656478270319304, 0.5050318456290221],
    [1.4833944416620202, 0.4974451678746587],
    [1.5011410562921104, 0.48959229090084394],
    [1.5188876709222001, 0.48165068085387874],
    [1.5366342855522899, 0.47362033773376305],
    [1.55438090018238, 0.4656343611502226],
    [1.5721275148124698, 0.4578258507129831],
    [1.5909832928569403, 0.44986452220531775],
    [1.6120573977301718, 0.44189826408247745],
    [1.6353498294321647, 0.4344446859378397],
    [1.6597514245485383, 0.4290561684047033],
    [1.684153019664912, 0.42686204150498314],
    [1.7085546147812856, 0.42844310353566384],
    [1.7318470464832785, 0.433912287498937],
    [1.7518119879421297, 0.44154333178987565],
    [1.7701131842794096, 0.4501504398854692],
    [1.7861960537879293, 0.4582251495421601],
    [1.799727847443372, 0.4668145110231236],
    [1.812815975733064, 0.4745520350018428],
    [1.8227984464624891, 0.4827154777316841],
    [1.8372175708494372, 0.49047328355569486],
    [1.854964185479527, 0.498293061684763],
    [1.877147453767139, 0.5036242164231355],
    [1.9015490488835127, 0.5015914223836888],
    [1.9204048269279832, 0.49401415571284135],
    [1.929278134243028, 0.4863712803454826],
    [1.9392606049724537, 0.47845629022046254],
    [1.9481339122874985, 0.4699379151980194],
    [1.9558980561881631, 0.4614195401755763],
    [1.9625530366744466, 0.4533744082099355],
    [1.9676551883805973, 0.44402785783808824],
    [1.974476543379038, 0.4345777855475654],
    [1.9797450695973462, 0.42574884476909564],
    [1.9865849106526932, 0.415751585194145],
    [1.9933323214235086, 0.4056064371639437],
    [1.9980462659346265, 0.3967035521578486],
    [2.0024829195921487, 0.38889504172060907],
    [2.0069195732496716, 0.38096822051916884],
    [2.0113562269071936, 0.3728047777893275],
    [2.0157928805647165, 0.3644638689131853],
    [2.0202295342222385, 0.3560046492728425],
    [2.0246661878797614, 0.34748627425039935],
    [2.0291028415372834, 0.33926367613845765],
    [2.0335394951948063, 0.33104107802651606],
    [2.0379761488523283, 0.3226410137682735],
    [2.042412802509851, 0.31412263874583035],
    [2.0479586195817543, 0.30509158374518464],
    [2.0490654205607477, 0.30130841121495333],
])
# Shift the extracted time axis so its first point sits at t = 0.1 (the first
# observation time), then resample Tong et al.'s curve onto the observation grid.
extracted_other[:, 0] -= extracted_other[0, 0] - .1
hurst_other = jnp.interp(ts, extracted_other[:, 0], extracted_other[:, 1])

# Coefficients of the generative SDE used in the training cell:
# drift alpha * x and constant diffusion beta.
alpha = .5
beta = .5

# Model time grid: t = 0 (initial condition) followed by the observation times.
ts_ext = jnp.concatenate([jnp.array([0.]), ts])

fig, axes = plt.subplots(nrows=2, figsize=(12, 7), sharex=True)
axes[0].plot(ts, xs_true, 'x')
axes[0].set_ylabel('$x$')
axes[1].plot(ts, hurst_true, label='True')
axes[1].plot(ts, hurst_other, label='Tong et al.')
axes[1].set_xlabel('$t$')
axes[1].set_ylabel('$H(t)$')
axes[1].legend()
plt.show()

In [None]:
# --- Noise model: Markov approximation of the fractional noise ---
fbm_type = 2        # type 2: auxiliary OU states start at zero
num_k = 5           # number of auxiliary OU states
gamma_max = 20.0    # speed of the fastest OU state
sigma = 0.025       # observation-noise std used in the Gaussian likelihood

# --- Training ---
num_training_steps = 1000

# --- Integration ---
time = ts_ext[-1]   # model time horizon (end of ts_ext)
x0 = jnp.array([1.0])
num_samples = len(xs_true)
dt = 0.005
solver = diffrax.StratonovichMilstein()

# Keep the fastest OU state well resolved by the fixed time step.
assert dt * gamma_max < 0.5

# Presumably num_k speeds up to gamma_max -- see ma.gamma_by_gamma_max.
gamma = ma.gamma_by_gamma_max(num_k, gamma_max)

In [None]:
class MLP(nn.Module):
    """Two tanh hidden layers of width 1000 followed by a linear scalar output.

    Submodules are auto-named Dense_0, Dense_1 (hidden) and Dense_2 (output);
    `Control.init` refers to Dense_2 by name.
    """

    @nn.compact
    def __call__(self, x):
        for width in (1000, 1000):
            x = nn.tanh(nn.Dense(width)(x))
        return nn.Dense(1)(x)

class Control:
    """Learned control u(t, x, y): an `MLP` on Fourier time features plus the state.

    Relies on two notebook-level globals, which must be defined before `init` /
    `__call__` run: `num_k`, for the input width, and `time`, the horizon used to
    normalise t. The input width `num_k + 1` for the state part matches x of size 1
    and y of size num_k, i.e. a single latent dimension.
    """

    def __init__(self, num_coefficients=8):
        self.num_coefficients = num_coefficients
        self.mlp = MLP()

    def init(self, key):
        # Feature layout: sines, cosines, x (size 1), then the flattened y (size num_k).
        input_size = 2 * self.num_coefficients + 1 + num_k
        params = self.mlp.init(key, jnp.zeros(input_size))
        # Initialization trick from Glow: zero the output kernel so the control
        # starts at its output bias.
        params['params']['Dense_2']['kernel'] *= 0.0
        return params

    def __call__(self, params, t, x, y, args):
        # Phase pi * 2**f * t / time for each frequency index f.
        phases = [t / time * jnp.pi * 2 ** f for f in range(self.num_coefficients)]
        sine_features = jnp.array([jnp.sin(phase) for phase in phases])
        cosine_features = jnp.array([jnp.cos(phase) for phase in phases])
        features = jnp.concatenate((sine_features, cosine_features, x, y.flatten()), axis=-1)
        return self.mlp.apply(params, features)

In [None]:
random_key = jax.random.PRNGKey(7)
random_key, key = jax.random.split(random_key, 2)

# Generative model: drift alpha * x, constant diffusion beta, learned control u.
b = StaticFunction(lambda t, x, args: alpha * x)
u = Control()
s = StaticFunction(lambda *args: jnp.array([beta]))
model = FractionalSDEMBM(b, u, s, gamma, type=fbm_type, time_horizon=ts[-1])
params = model.init(key)

lr_schedule = optax.cosine_decay_schedule(3e-3, num_training_steps, alpha=.1)
optimizer = optax.chain(
    optax.clip_by_global_norm(10.),
    optax.adam(learning_rate=lr_schedule),
)
opt_state = optimizer.init(params)

def loss_fn(params, key, step):
    """Loss for a single sampled path.

    Sum of the Gaussian negative log-likelihood of `xs_true` (noise std `sigma`) and
    the model's accumulated `kl` term. `step` is accepted but currently unused.

    Returns:
        `(loss, (xs,))`, where `xs` is the sampled path on `ts_ext`.
    """
    xs, kl = model(params, key, x0, ts_ext, dt, solver)
    # xs[0] is the state at ts_ext[0] = 0; xs[1:] lines up with the observation times ts.
    neg_log_likelihood = .5 * ((xs[1:, 0] - xs_true) ** 2 / sigma ** 2).sum()
    loss = neg_log_likelihood + kl.sum()
    return loss, (xs,)

def batched_loss_fn(params, key, kl_weight, batch_size=4):
    """Mean of `loss_fn` over `batch_size` independently keyed samples.

    NOTE: despite its name, `kl_weight` receives the training step from the loop
    below. It is forwarded to `loss_fn` as `step` and does not weight the KL term.
    """
    keys = jax.random.split(key, batch_size)
    loss, _ = jax.vmap(loss_fn, in_axes=(None, 0, None))(params, keys, kl_weight)
    return loss.mean()

loss_grad = jax.jit(jax.value_and_grad(batched_loss_fn))
loss_values = []
pbar = tqdm(range(num_training_steps))
for step in pbar:
    random_key, key = jax.random.split(random_key)
    loss, grads = loss_grad(params, key, step)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    # Store a plain Python float rather than a device array. float() waits for the
    # result, which the progress-bar update below forces anyway.
    loss_values.append(float(loss))
    pbar.set_description(f'Loss: {float(loss):.2f}, Hurst: {model.hurst(params, 1.):.2f}')

In [None]:
# Sample paths from the trained model on a dense grid and compare them with the data.
ts_vis = jnp.linspace(ts_ext[0], ts_ext[-1], 1024)
fig, ax = plt.subplots(figsize=(12, 4))
num_samples = 64
key = jax.random.PRNGKey(42)
keys = jax.random.split(key, num_samples)
x0 = jnp.array([1.])
xs, kl = jax.vmap(model, in_axes=(None, 0, None, None, None, None))(params, keys, x0, ts_vis, dt, solver)
xs = xs.squeeze()  # (num_samples, len(ts_vis)) because the model has a single latent

std_empirical = jnp.std(xs, axis=0)
mean_empirical = jnp.mean(xs, axis=0)

for i in range(4):
    ax.plot(ts_vis, xs[i], color='black', alpha=.5, label='Sample paths' if i == 0 else None)
ax.scatter(ts, xs_true, marker='x', label='Observations')

# The dashed band is the empirical mean +/- one standard deviation (not the variance).
ax.plot(ts_vis, mean_empirical - std_empirical, 'g--', label=r'Empirical mean $\pm$ std')
ax.plot(ts_vis, mean_empirical + std_empirical, 'g--')

ax.set_xlim(0, time)
ax.set_xlabel('$t$')
ax.set_ylabel('$x$')
ax.legend()
plt.show()

In [None]:
# Compare the learned Hurst function with the ground truth and with Tong et al.
t_axis = ts - .1  # shift so the plotted window starts at t = 0
hurst_learned = jax.vmap(model.hurst, (None, 0))(params, ts)

fig, ax = plt.subplots(figsize=(2.6, 2.6))
curves = (
    (hurst_true, 'True'),
    (hurst_learned, 'Ours'),
    (hurst_other, 'Tong et al.'),
)
for curve, curve_label in curves:
    ax.plot(t_axis, curve, label=curve_label)

ax.legend()
ax.set_xlabel('$t$')
ax.set_ylabel('$H(t)$')
ax.set_xlim(0, 2)
fig.savefig('mbm.pdf', dpi=300, bbox_inches='tight')
plt.show()