## 1. Packages & Initializations

In [20]:
import os
import pickle
import random
import re
import secrets
import string
import warnings
from abc import ABC, abstractmethod
from datetime import datetime, timedelta
from itertools import product

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import yfinance as yf
from flask import Flask, request, jsonify, render_template, redirect, url_for, session
from flask_sqlalchemy import SQLAlchemy
from sklearn.metrics import r2_score
from sklearn.model_selection import TimeSeriesSplit
from statsmodels.tsa.ar_model import AutoReg
from werkzeug.security import check_password_hash, generate_password_hash

In [21]:
app = Flask(__name__)
# Configure the SQLAlchemy part of the app instance
app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite:///aidea.db'
# Modification tracking is not used anywhere in this app
app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False
# The secret key signs the session cookie (the password-reset flow keeps state in the session).
# Take it from the environment so the real key does not have to live in the source; the
# literal fallback keeps existing local setups and their sessions working.
# NOTE(review): rotate and remove the fallback literal before any deployment.
app.secret_key = os.environ.get('SECRET_KEY', 'Doggan98-ddd')
# Create the SQLAlchemy db instance
db = SQLAlchemy(app)

## 2. Models

### 2.0 Model Class Abstraction

In [22]:
class Model(ABC):
    """Common interface for forecasting models built around a dataset.

    Subclasses must implement ``train`` and ``forecast``; the base class only
    keeps a reference to the data it was constructed with.
    """

    def __init__(self, data):
        # Raw dataset; each subclass decides how to clean and use it.
        self.data = data

    @abstractmethod
    def train(self):
        """Fit the model on ``self.data``. Must be overridden by subclasses."""

    @abstractmethod
    def forecast(self, forecast_days):
        """Predict the next ``forecast_days`` values. Must be overridden by subclasses."""

### 2.1 AR Model

In [23]:
class AR_model(Model):
    """Autoregressive (statsmodels ``AutoReg``) forecaster for the ``Close`` column of ``data``.

    ``train`` grid-searches the trend specification and lag order using
    time-series cross-validation scored by mean R2, then refits the best
    configuration on the whole cleaned series.
    """

    def __init__(self, data):
        super().__init__(data)
        # Fitted statsmodels results object, set by train()
        self.trained_model = None
        self.model_type = 'AR'

    def _clean_data(self):
        """Return the rows of ``self.data`` whose ``Close`` value is present (the modelled series)."""
        return self.data.dropna(subset=['Close'])

    @staticmethod
    def _cv_score(close, folds, trend, lags):
        """Return the mean validation R2 of ``AutoReg(lags, trend)`` over ``folds``.

        Returns None if the candidate fails on any fold. An average over only
        some folds (or with failures counted as 0) is not comparable to a
        fully evaluated candidate.
        """
        scores = []
        for train_index, val_index in folds:
            train_split = close.iloc[train_index]
            val_split = close.iloc[val_index]
            try:
                fitted = AutoReg(train_split.values, lags=lags, trend=trend).fit()
                predictions = fitted.predict(
                    start=len(train_split),
                    end=len(train_split) + len(val_split) - 1,
                )
                scores.append(r2_score(val_split, predictions))
            except Exception as e:
                print(f"Error for trend={trend}, lags={lags}: {e}")
                # Remaining folds cannot rescue this candidate; stop early
                return None
        return sum(scores) / len(scores) if scores else None

    def train(self, max_lag=None, n_splits=3):
        """Select (trend, lags) by cross-validated R2 and fit the final model.

        Args:
            max_lag: exclusive upper bound of the lag grid. Defaults to the
                number of clean observations, which is the original search range.
            n_splits: number of ``TimeSeriesSplit`` folds.

        Returns:
            The fitted statsmodels results object. Returns ``None`` (printing
            the error) if refitting the best configuration fails.

        Raises:
            ValueError: if no candidate could be fitted on every fold, or if
                ``TimeSeriesSplit`` rejects the data as too short.
        """
        data = self._clean_data()
        close = data['Close']

        # Parameter grid for tuning
        trends = ['n', 'c', 't', 'ct']
        min_lag = 1
        if max_lag is None:
            max_lag = len(data)
        lags_range = range(min_lag, max_lag)

        # The folds depend only on the data, so build them once for every candidate
        folds = list(TimeSeriesSplit(n_splits=n_splits).split(data))

        best_r2 = -float('inf')
        best_params = None

        # Silence statsmodels warnings only for this search and fit. The previous
        # filterwarnings("ignore") altered the process-wide filters permanently.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            for trend, lags in product(trends, lags_range):
                avg_r2 = self._cv_score(close, folds, trend, lags)
                if avg_r2 is not None and avg_r2 > best_r2:
                    best_r2 = avg_r2
                    best_params = (trend, lags)

            if best_params is None:
                raise ValueError(
                    "No (trend, lags) candidate could be fitted on every cross-validation "
                    "fold; the price history is probably too short."
                )

            best_trend, best_lags = best_params
            print(f"Best R2 score: {best_r2:.4f}")
            print(f"Best parameters: trend={best_trend}, lags={best_lags}")

            # Fit the best configuration on the entire cleaned series
            try:
                self.trained_model = AutoReg(close.values, lags=best_lags, trend=best_trend).fit()
            except Exception as e:
                print(f'Model training failed with the error message: {e}')

        return self.trained_model

    def forecast(self, forecast_days):
        """Return out-of-sample predictions for the next ``forecast_days`` steps.

        Forecasting starts at the length of the cleaned series the model was
        fitted on. Using the raw ``self.data`` length would push the start past
        the sample whenever NaN rows were dropped during training.

        Raises:
            RuntimeError: if ``train`` has not produced a fitted model.
        """
        if self.trained_model is None:
            raise RuntimeError('Model has not been trained; call train() first.')
        start = len(self._clean_data())
        end = start + forecast_days - 1
        return self.trained_model.predict(start=start, end=end)


### 2.2 Model Creation Method

In [24]:
def create_model(model_type, data):
    """Instantiate the forecasting model registered under ``model_type``.

    Args:
        model_type: model identifier; only ``'AR'`` is currently supported.
        data: price-history DataFrame handed to the model constructor.

    Raises:
        ValueError: for an unsupported ``model_type``. Failing here is clearer
            than returning ``None`` and crashing later on ``.train()``.
    """
    if model_type == 'AR':
        return AR_model(data)
    raise ValueError(f"Unsupported model type: {model_type!r}")

## 3. Database for User and Model Tables

### 3.1 Users Table

In [25]:
class User(db.Model):
    """Registered account: login e-mail, password hash and account tier."""

    # Integer surrogate key assigned by the database
    id = db.Column(db.Integer, primary_key=True, autoincrement=True)
    # Login identifier; the unique constraint also rejects duplicates at the DB level
    email = db.Column(db.String(120), unique=True, nullable=False)
    # Werkzeug password hash, as written by the signup and reset routes
    password = db.Column(db.String(255), nullable=False)
    # 'basic' or 'premium'
    account_type = db.Column(db.String(20), nullable=False)

    def __repr__(self):
        return '<User {}>'.format(self.email)

### 3.2 Trained Models Table

In [26]:
class TrainedModels(db.Model):
    """Cache of fitted forecasting models keyed by symbol, model type and date window."""

    id = db.Column(db.Integer, primary_key=True, autoincrement=True)
    symbol = db.Column(db.String(20), nullable=False)
    model_type = db.Column(db.String(50), nullable=False)
    # Window boundaries as date strings supplied by the caller
    start_date = db.Column(db.String(50), nullable=False)
    end_date = db.Column(db.String(50), nullable=False)
    # Pickled model object. pickle.dumps produces bytes, so use a binary column rather than Text.
    trained_model = db.Column(db.LargeBinary)

    def save_trained_model(self, trained_model):
        """Serialize ``trained_model`` onto this row and persist the row.

        The row is added to the session before committing. Otherwise the
        commit would not include it and nothing would be written until the
        caller added it separately.
        """
        self.trained_model = pickle.dumps(trained_model)
        db.session.add(self)
        db.session.commit()

    @classmethod
    def load_trained_model(cls, model_type, start_date, end_date, symbol):
        """Return the unpickled model cached under the given key, or ``None``.

        Declared as a classmethod so it works both on the class itself
        (``TrainedModels.load_trained_model(...)``) and on an instance.

        NOTE(review): pickle.loads can execute arbitrary code embedded in the
        blob. This is only safe while nothing but this application can write
        to the database.
        """
        record = cls.query.filter_by(
            model_type=model_type,
            start_date=start_date,
            end_date=end_date,
            symbol=symbol,
        ).first()
        # A row without a stored blob cannot be deserialized
        if record is None or record.trained_model is None:
            return None
        return pickle.loads(record.trained_model)

### 3.3 Temporary Password

In [27]:
class TemporaryPassword(db.Model):
    """Pending password-reset code, at most one per e-mail address."""

    # The e-mail doubles as the primary key, so each address holds a single code
    email = db.Column(db.String(120), primary_key=True, nullable=False)
    # NOTE(review): stored in plaintext; consider hashing it like User.password
    temp_password = db.Column(db.String(255), nullable=False)

## 4. Fetch Historical Data

In [28]:
def fetch_data(symbol, start_date, end_date):
    """Download price history for ``symbol`` between ``start_date`` and ``end_date``.

    Returns:
        A DataFrame with one column per price field (``Adj Close`` removed),
        or ``None`` if the download raised or returned no rows.
    """
    try:
        df = yf.download(symbol, start=start_date, end=end_date)
    except Exception as e:
        print(f"Error fetching data for {symbol}: {e}")
        return None

    # yfinance typically reports failures (unknown ticker, empty range) by returning an
    # empty frame rather than raising, so treat that as a failure as well.
    if df is None or df.empty:
        print(f"No data returned for {symbol} between {start_date} and {end_date}")
        return None

    # Recent yfinance versions can return (field, ticker) MultiIndex columns even for a
    # single ticker; flatten them so callers can index plain field names such as 'Close'.
    if (
        isinstance(df.columns, pd.MultiIndex)
        and df.columns.nlevels == 2
        and df.columns.get_level_values(1).nunique() == 1
    ):
        df.columns = df.columns.get_level_values(0)

    # Drop 'Adj Close' column if present
    return df.drop(columns=['Adj Close'], errors='ignore')

## 5. PAGES

In [29]:
@app.route('/')
def index():
    """Site root: serve the login form as the entry page."""
    login_template = 'login.html'
    return render_template(login_template)

### 5.1 Login 

In [30]:
@app.route('/login', methods=['GET', 'POST'])
def login():
    """Serve the login page (GET) or check JSON credentials (POST).

    POST body: ``{"email": ..., "password": ...}``.
    Responds 200 with a redirect URL on success, 400 when no JSON body is
    sent, and 401 for missing or wrong credentials.
    """
    if request.method == 'GET':
        # Serve the login page
        return render_template('login.html')

    # silent=True: a missing or non-JSON body yields None instead of raising
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or not data:
        return jsonify({'message': 'No data provided'}), 400

    email = data.get('email')
    password = data.get('password')
    # check_password_hash needs a string; a missing or non-string value is just a bad login
    if not (isinstance(email, str) and isinstance(password, str) and email and password):
        return jsonify({'message': 'Invalid credentials'}), 401

    user = User.query.filter_by(email=email).first()
    if user is None or not check_password_hash(user.password, password):
        return jsonify({'message': 'Invalid credentials'}), 401

    # Record who logged in so later requests can identify the user
    session['user_id'] = user.id
    return jsonify({'message': 'Login successful!', 'redirect': url_for('main')}), 200


### 5.2 Signup

In [31]:
def is_valid_password(password):
    """Return True if ``password`` satisfies the strength policy.

    Policy: at least 8 characters, including at least one ASCII uppercase
    letter, one ASCII lowercase letter and one digit. Non-string input (for
    example ``None`` from a missing JSON field) is rejected rather than
    raising ``TypeError``.
    """
    if not isinstance(password, str):
        return False
    if len(password) < 8:
        return False
    required_classes = (r'[A-Z]', r'[a-z]', r'[0-9]')
    return all(re.search(pattern, password) for pattern in required_classes)

In [32]:
@app.route('/signup', methods=['GET', 'POST'])
def signup():
    """Serve the signup page (GET) or create a 'basic' account from JSON (POST).

    POST body: ``{"email": ..., "password": ..., "password_re": ...}``.
    Responds 400 for missing fields, an existing e-mail, mismatched or weak
    passwords; 500 if the insert fails; 200 with a redirect to login on success.
    """
    if request.method == 'GET':
        # Serve the signup page
        return render_template('signup.html')

    # silent=True so a missing or non-JSON body reaches the check below
    # instead of raising inside request.json
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or not data:
        return jsonify({'message': 'No data provided'}), 400

    email = data.get('email')
    password = data.get('password')
    password_re = data.get('password_re')

    # Every field must be a non-empty string
    if not all(isinstance(value, str) and value for value in (email, password, password_re)):
        return jsonify({'message': 'All fields are required.'}), 400

    if User.query.filter_by(email=email).first():
        return jsonify({'message': 'User already exists.'}), 400

    if password != password_re:
        return jsonify({'message': 'Passwords do not match.'}), 400

    if not is_valid_password(password):
        return jsonify({'message': 'Password must be at least 8 characters long, include uppercase letters, lowercase letters, and numbers.'}), 400

    hashed_password = generate_password_hash(password, method='pbkdf2:sha256')
    new_user = User(email=email, password=hashed_password, account_type='basic')

    try:
        db.session.add(new_user)
        db.session.commit()
    except Exception:
        # Keep the session usable for later requests (e.g. after a unique-email race)
        db.session.rollback()
        return jsonify({'message': 'Error creating user.'}), 500

    return jsonify({'redirect': url_for('login')}), 200


### 5.3 Request E-mail

In [33]:
def generate_temporary_password(length=8):
    """Return a random alphanumeric password of ``length`` characters.

    The value is a credential, so it is drawn with ``secrets`` (a
    cryptographically secure source). The ``random`` module is predictable
    and unsuitable for security tokens.
    """
    alphabet = string.ascii_letters + string.digits
    return ''.join(secrets.choice(alphabet) for _ in range(length))

def send_email(email, temporary_password):
    """Stand-in for e-mail delivery: print the message instead of sending it.

    NOTE(review): this writes the temporary password to stdout; replace it with
    real delivery before production use.
    """
    message = "Sending temporary password to {}: {}".format(email, temporary_password)
    print(message)

In [34]:
@app.route('/request_email', methods=['GET', 'POST'])
def request_email():
    """Start a password reset by issuing a temporary password.

    GET serves the form. POST expects ``{"email": ...}``. For a registered
    address, a fresh temporary password replaces any previous one and is
    passed to ``send_email``. The address is stored in the session under
    ``resetEmail`` for the following reset steps.
    """
    if request.method != 'POST':
        return render_template('request_email.html')

    # silent=True: a missing or non-JSON body yields None instead of raising
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({'success': False, 'message': 'No data provided'}), 400
    email = data.get('email')

    # Check if the email exists in the User database
    user = User.query.filter_by(email=email).first()
    if not user:
        return jsonify({'success': False, 'message': 'Email not found'})

    temp_password = generate_temporary_password()

    # The e-mail is the primary key, so update an existing entry in place. Either way it
    # is a single commit, instead of a separate delete-commit followed by insert-commit.
    temp_password_entry = TemporaryPassword.query.filter_by(email=email).first()
    if temp_password_entry:
        temp_password_entry.temp_password = temp_password
    else:
        db.session.add(TemporaryPassword(email=email, temp_password=temp_password))
    db.session.commit()

    send_email(email, temp_password)

    # Store the email in the session for the forgot/reset password steps
    session['resetEmail'] = email

    return jsonify({'success': True, 'message': 'Temporary password sent to your email'})


### 5.4 Forgot Password

In [35]:
@app.route('/forgot_password', methods=['GET', 'POST'])
def forgot_password():
    """Verify the temporary password for the e-mail stored in the session.

    GET serves the form. POST expects ``{"tempPassword": ...}`` and checks it
    against the TemporaryPassword row for ``session['resetEmail']``.

    NOTE(review): a successful check only produces a JSON response and is not
    recorded in the session, so later steps cannot tell whether it happened.
    """
    if request.method != 'POST':
        return render_template('forgot_password.html')

    # silent=True: a missing or non-JSON body yields None instead of raising
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({'success': False, 'message': 'No data provided'}), 400

    # Retrieve email from session
    email = session.get('resetEmail')
    if not email:
        return jsonify({'success': False, 'message': 'No email found in session. Please request an email first.'})

    temp_password = data.get('tempPassword')
    if not temp_password:
        return jsonify({'success': False, 'message': 'Invalid temporary password'})

    temp_password_entry = TemporaryPassword.query.filter_by(email=email, temp_password=temp_password).first()
    if temp_password_entry:
        return jsonify({'success': True, 'message': 'Temporary password verified. Proceed to reset password.'})
    return jsonify({'success': False, 'message': 'Invalid temporary password'})

### 5.5 Reset Password

In [36]:
@app.route('/reset_password', methods=['GET', 'POST'])
def reset_password():
    """Set a new password for the account whose e-mail is in the session.

    GET serves the form, or redirects to request_email when no reset is in
    progress. POST expects ``{"newPassword": ..., "newPassword_re": ...}``.
    On success the temporary password is deleted and ``resetEmail`` is
    removed from the session, so the same reset cannot be replayed.

    NOTE(review): the only precondition checked is ``session['resetEmail']``.
    request_email sets that key before any temporary password is verified,
    so a client can skip the verification step. Requiring a session flag that
    is set only after successful verification would close this gap.
    """
    if request.method != 'POST':
        # Without a reset in progress there is nothing to reset; send the browser
        # back to the start of the flow
        if not session.get('resetEmail'):
            return redirect(url_for('request_email'))
        return render_template('reset_password.html')

    # silent=True: a missing or non-JSON body yields None instead of raising
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({'success': False, 'message': 'No data provided'}), 400
    new_password = data.get('newPassword')
    new_password_re = data.get('newPassword_re')

    # Retrieve the email from session
    email = session.get('resetEmail')
    if not email:
        return jsonify({'success': False, 'message': 'No email found in session. Please request an email first.'})

    if new_password != new_password_re:
        return jsonify({'success': False, 'message': 'Passwords do not match'})

    # The isinstance guard keeps a missing or non-string value away from the validator and hasher
    if not isinstance(new_password, str) or not is_valid_password(new_password):
        return jsonify({'success': False, 'message': 'Password must be at least 8 characters long, include uppercase letters, lowercase letters, and numbers.'}), 400

    user = User.query.filter_by(email=email).first()
    if not user:
        return jsonify({'success': False, 'message': 'User not found'})

    # Check if the new password is different from the current password
    if check_password_hash(user.password, new_password):
        return jsonify({'success': False, 'message': 'New password must be different from the old password'})

    user.password = generate_password_hash(new_password, method='pbkdf2:sha256')
    # The temporary password has served its purpose; remove it so it cannot be reused
    temp_entry = TemporaryPassword.query.filter_by(email=email).first()
    if temp_entry:
        db.session.delete(temp_entry)
    db.session.commit()
    session.pop('resetEmail', None)

    return jsonify({'success': True, 'message': 'Password successfully updated. You can now login with your new password.'})

### 5.6 Main 

In [37]:
@app.route('/main', methods=['GET', 'POST'])
def main():
    """Serve the dashboard (GET) or return a price forecast (POST).

    POST body (JSON):
        symbol, data_length (days of history), forecast_days, model_type.

    A model cached in TrainedModels for the same symbol, model type and date
    window is reused. Otherwise the history is downloaded, a model is trained,
    and the fitted model is cached.

    Responds ``{"forecast_prices": [...], "processed": true}`` with 200, or
    ``{"error": ...}`` with 400 (bad input), 404 (no price data) or 500.
    """
    if request.method != 'POST':
        return render_template('main.html')

    # silent=True: a missing or non-JSON body yields None instead of raising
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict):
        return jsonify({'error': 'No data provided'}), 400

    try:
        symbol = payload['symbol']
        data_length = int(payload['data_length'])
        forecast_days = int(payload['forecast_days'])
        model_type = payload['model_type']
    except KeyError as e:
        return jsonify({'error': f'Missing key: {e.args[0]}'}), 400
    except (TypeError, ValueError):
        return jsonify({'error': 'data_length and forecast_days must be integers'}), 400
    if data_length <= 0 or forecast_days <= 0:
        return jsonify({'error': 'data_length and forecast_days must be positive'}), 400

    try:
        # Set start and end date
        now = datetime.now()
        start_date = (now - timedelta(days=data_length)).strftime("%Y-%m-%d")
        end_date = now.strftime("%Y-%m-%d")

        # Calling through an instance works whether load_trained_model is an instance
        # method or a classmethod. The transient instance is never added to the session.
        trained_model = TrainedModels().load_trained_model(
            model_type=model_type, start_date=start_date, end_date=end_date, symbol=symbol
        )

        if trained_model is None:
            history = fetch_data(symbol, start_date, end_date)
            if history is None or history.empty:
                return jsonify({'error': f'No price data available for {symbol}'}), 404

            try:
                model = create_model(model_type, history)
            except ValueError:
                model = None
            if model is None:
                return jsonify({'error': f'Unsupported model type: {model_type}'}), 400

            trained_model = model.train()
            if trained_model is None:
                return jsonify({'error': 'Model training failed'}), 500

            # Cache the fitted model for identical future requests
            new_model = TrainedModels(symbol=symbol, model_type=model_type, start_date=start_date, end_date=end_date)
            new_model.save_trained_model(trained_model)
            db.session.add(new_model)
            db.session.commit()

        # In both branches trained_model is the object returned by Model.train() (for AR,
        # the fitted statsmodels results), which is also what gets pickled. Forecast
        # out-of-sample through that object's own API.
        # NOTE(review): assumes every model type's train() returns an object with forecast(steps=...)
        forecast_prices = trained_model.forecast(steps=forecast_days)

        return jsonify({
            # numpy arrays are not JSON serializable
            'forecast_prices': np.asarray(forecast_prices).tolist(),
            'processed': True,
        }), 200

    except Exception as e:
        # Keep the session usable if a DB write failed mid-request
        db.session.rollback()
        return jsonify({'error': str(e)}), 500

In [None]:
if __name__ == '__main__':
    # Create any missing tables; create_all runs inside an application context
    app_ctx = app.app_context()
    with app_ctx:
        db.create_all()
    # NOTE(review): the reloader is presumably disabled because this runs from a notebook
    # kernel. debug=True enables the interactive debugger; never use it on an exposed host.
    app.run(use_reloader=False, debug=True)