In [None]:
import matplotlib.pyplot as plt
import numpy as np
from scipy.integrate import quad

import math

In [None]:
from scipy.stats import norm

# Pool of 100 draws from a standard normal (mean 0, standard deviation 1);
# serves as the default error pool for TimeSeries.
norm_err = norm(loc=0, scale=1).rvs(size=100)

class TimeSeries():
    """Simulator for a series with mean, trend, seasonality and AR(1) errors.

    Each period ``t`` (1-based) is generated as::

        e_t = rho * e_{t-1} + sigma_eps(x_{t-1}, t) * z_t
        x_t = mu + trend(t) + season[(t - 1) % len(season)] + e_t

    where ``z_t`` is picked at random with ``np.random.choice`` from
    ``error_dist`` rescaled to unit standard deviation.
    """

    def __init__(self, mu=0, rho=0, trend=lambda t: 0, season=None, sigma_eps=lambda x_t, t: 1, error_dist=None):
        """Store the model parameters.

        Args:
            mu: Constant level added to every observation.
            rho: Coefficient applied to the previous error term.
            trend: Callable ``trend(t)`` returning the trend component.
            season: Sequence of seasonal offsets that is cycled through;
                ``None`` or an empty sequence disables seasonality.
            sigma_eps: Callable ``sigma_eps(x_prev, t)`` returning the scale
                applied to the shock drawn for period ``t``.
            error_dist: Pool of values the shocks are sampled from. It is
                rescaled to unit standard deviation. ``None`` falls back to
                the module-level ``norm_err`` sample.

        Raises:
            ValueError: If ``error_dist`` has no positive standard deviation
                (empty or constant), since it cannot be standardised.
        """
        self.mu = mu
        self.rho = rho
        self.trend = trend
        # Copy into a list so array-like seasons work and the caller's
        # sequence is never shared with the instance.
        self.season = [] if season is None else list(season)
        self.sigma_eps = sigma_eps

        if error_dist is None:
            error_dist = norm_err
        error_dist = np.asarray(error_dist, dtype=float)
        spread = error_dist.std()
        # `not spread > 0` also catches the NaN produced by an empty pool.
        if not spread > 0:
            raise ValueError("error_dist must have a positive standard deviation")
        self.error_dist = error_dist / spread

    def get_trial(self, T, t=1, num=1000):
        """Simulate ``T`` periods, then forecast ``t`` periods beyond them.

        Returns:
            A tuple ``(series, forecast, realization)``: the simulated
            history, ``num`` simulated values for the period ``t`` steps
            ahead, and one of those values picked at random.
        """
        series, error = self.simulate(T)
        # With an empty history (T == 0) forecast from the default state
        # instead of failing on series[-1].
        last_x = series[-1] if series else None
        last_e = error[-1] if error else 0
        forecast = self.forecast(t, start=len(series), x0=last_x, e0=last_e, num=num)
        return series, forecast, np.random.choice(forecast)

    def simulate(self, T, start=0, x0=None, e0=0):
        """Generate periods ``start + 1`` through ``start + T``.

        Args:
            T: Number of periods to generate.
            start: Index of the last period that already exists.
            x0: Observation preceding the first simulated period; defaults
                to ``mu``. It is only passed to ``sigma_eps``.
            e0: Error term preceding the first simulated period.

        Returns:
            A tuple ``(series, error)`` of two lists of length ``T``.
        """
        prev_x = self.mu if x0 is None else x0
        prev_e = e0
        series = []
        error = []
        # len() instead of truthiness so an array assigned to self.season
        # after construction does not raise.
        period = len(self.season)
        for t in range(start + 1, start + T + 1):
            prev_e = self.rho * prev_e + self.error(prev_x, t)
            seasonal = self.season[(t - 1) % period] if period else 0
            prev_x = self.mu + self.trend(t) + seasonal + prev_e
            error.append(prev_e)
            series.append(prev_x)
        return series, error

    def error(self, x_s, t):
        """Draw one shock, scaled by ``sigma_eps(x_s, t)``."""
        sigma_eps = self.sigma_eps(x_s, t)
        # Scale the single drawn value rather than the whole pool.
        return np.random.choice(self.error_dist) * sigma_eps

    def forecast(self, t=1, start=0, x0=None, e0=0, num=1000):
        """Return ``num`` simulated values of the period ``t`` steps ahead.

        Each value is the last observation of an independent call to
        ``simulate(t, start, x0=x0, e0=e0)``.
        """
        return [self.simulate(t, start, x0=x0, e0=e0)[0][-1] for _ in range(num)]
    
def sin(wavelength, amp, start=0):
    """Return one sine cycle of ``wavelength`` samples as a list.

    Args:
        wavelength: Number of samples in one full cycle.
        amp: Amplitude the sine values are multiplied by.
        start: Index of the sample that becomes the first element; the
            samples before it are moved to the end of the list.

    Returns:
        A list of ``wavelength`` values.
    """
    # endpoint=False keeps 2*pi out of the samples. Including it would
    # repeat the zero of the next cycle, so a repeated cycle would contain
    # two consecutive near-zero values and have period wavelength - 1.
    series = amp * np.sin(np.linspace(0, 2*math.pi, num=wavelength, endpoint=False))
    return list(np.append(series[start:], series[:start]))
    
# Demo: errors accumulate as a random walk (rho=1) with a constant shock
# scale of 4; simulate 20 periods and forecast one period ahead.
timeseries = TimeSeries(sigma_eps=lambda _x, _t: 4, rho=1)
series, forecast, realization = timeseries.get_trial(T=20)

# Simulated history.
plt.plot(series)
plt.show()

# Distribution of the simulated forecast values.
plt.hist(forecast, density=True)
plt.show()

print('Realization', realization)