-
Notifications
You must be signed in to change notification settings - Fork 2
/
arima_model.py
70 lines (56 loc) · 2.45 KB
/
arima_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import pandas as pd
import numpy as np
import statsmodels.api as sm
from statsmodels.tsa.arima.model import ARIMA
from typing import List, Tuple
from datetime import datetime
from db_manager import DBManager
class ARIMAModel:
    """
    ARIMA model class for predicting cryptocurrency prices.

    Intended usage: construct with a symbol and a database manager, call
    ``train`` once to fit the model, then call ``predict`` to obtain
    forecasts from the fitted model.
    """

    # The annotation is quoted (forward reference) so it is not evaluated at
    # definition time; the signature itself is unchanged.
    def __init__(self, crypto_symbol: str, db_manager: "DBManager"):
        """
        Initialize the ARIMA model instance.

        :param crypto_symbol: Symbol of the cryptocurrency
        :type crypto_symbol: str
        :param db_manager: Instance of the database manager class
        :type db_manager: DBManager
        """
        self.crypto_symbol = crypto_symbol
        self.db_manager = db_manager
        # Fitted results object; stays None until train() has been called.
        self.model_fit = None

    def train(self, end_date: datetime, start_date: datetime, p: int, d: int, q: int):
        """
        Train the ARIMA model using historical data.

        The fitted results are stored on ``self.model_fit``.

        :param end_date: End date for the historical data
        :type end_date: datetime
        :param start_date: Start date for the historical data
        :type start_date: datetime
        :param p: Order of the autoregressive part of the model
        :type p: int
        :param d: Degree of differencing
        :type d: int
        :param q: Order of the moving average part of the model
        :type q: int
        """
        df = self.db_manager.get_historical_prices(self.crypto_symbol, end_date, start_date)
        # Reverse the row order before fitting.
        # NOTE(review): presumably the database returns newest-first rows and
        # this puts them in chronological order - confirm against DBManager.
        df = df.reindex(index=df.index[::-1])
        self.model_fit = ARIMA(df, order=(p, d, q)).fit()

    def predict(self, end_date: datetime, num_periods: int) -> List[Tuple[datetime, float]]:
        """
        Predict future cryptocurrency prices using the trained ARIMA model.

        The forecast values are paired with the ``num_periods`` consecutive
        daily dates that end on ``end_date``.

        :param end_date: End date for the predicted prices
        :type end_date: datetime
        :param num_periods: Number of periods to predict
        :type num_periods: int
        :return: List of tuples containing predicted dates and prices
        :rtype: List[Tuple[datetime, float]]
        :raises RuntimeError: if the model has not been trained yet
        :raises ValueError: if ``num_periods`` is negative
        """
        # getattr keeps this safe even for instances built without __init__.
        model_fit = getattr(self, "model_fit", None)
        if model_fit is None:
            raise RuntimeError("ARIMA model has not been trained; call train() first")
        if num_periods < 0:
            raise ValueError("num_periods must be non-negative")
        if num_periods == 0:
            return []

        # The results object of statsmodels.tsa.arima.model.ARIMA returns only
        # the point forecasts here (not a 3-tuple like the removed legacy API),
        # so the value must not be tuple-unpacked.
        forecast = model_fit.forecast(steps=num_periods)

        # num_periods daily timestamps, the last one being end_date.
        fcst_dates = pd.date_range(end=end_date, periods=num_periods, freq='D')

        # to_pydatetime() keeps microseconds/tzinfo intact, unlike a
        # str()/strptime round-trip; float() turns numpy scalars into floats.
        return [
            (stamp.to_pydatetime(), float(price))
            for stamp, price in zip(fcst_dates, forecast)
        ]