In [18]:
import gradio as gr
import pyotp
import qrcode
import io
import bcrypt
from PIL import Image
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from datetime import datetime
import tempfile  

In [None]:
# In-memory storage for registered users
users_db = {}

In [None]:
# Function to generate a QR code image for Google Authenticator
def generate_qr_code(username, secret):
    totp = pyotp.TOTP(secret)
    uri = totp.provisioning_uri(name=username, issuer_name="RSI searcher")
    qr_img = qrcode.make(uri)
    buffer = io.BytesIO()
    qr_img.save(buffer, format="PNG")
    buffer.seek(0)
    return Image.open(buffer)

# Function to register a new user
def register_user(username, password):
    if username in users_db:
        return "Username already exists!"
    
    # Hash the password and store the user's data
    hashed_password = bcrypt.hashpw(password.encode(), bcrypt.gensalt())
    secret = pyotp.random_base32()  # Generate a secret for Google Authenticator
    users_db[username] = {"password": hashed_password, "secret": secret}
    
    return "Registration successful! You can now log in."

# Function to authenticate an existing user
def authenticate_user(username, password):
    if username not in users_db:
        return "User not found!", None
    
    # Verify the password
    user_data = users_db[username]
    if not bcrypt.checkpw(password.encode(), user_data['password']):
        return "Incorrect password!", None
    
    # Generate and return the QR code for Google Authenticator
    qr_code_img = generate_qr_code(username, user_data['secret'])
    return "Please scan the QR code with Google Authenticator and enter the OTP.", qr_code_img

# Function to verify the OTP
def verify_otp(username, otp):
    user_data = users_db.get(username)
    if not user_data:
        return "User not found!", False
    
    totp = pyotp.TOTP(user_data['secret'])
    if not totp.verify(otp):
        return "Invalid OTP!", False
    
    return "Login successful!", True

def switch_page(current_page):
        if current_page == "login":
            return "register", gr.update(visible=True), gr.update(visible=False), "Switch to Login"
        else:
            return "login", gr.update(visible=False), gr.update(visible=True), "Switch to Register"

# Actions
def handle_registration(username, password):
    msg = register_user(username, password)
    return msg

def handle_login(username, password):
    msg, qr_code_img = authenticate_user(username, password)
    if qr_code_img:
        return msg, gr.update(visible=True), qr_code_img
    else:
        return msg, gr.update(visible=False), None

# Function to verify the OTP and transition to the search UI
def handle_otp_verification(username, otp):
    msg, success = verify_otp(username, otp)
    if success:
        return msg, gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), "rsi_search"
    else:
        return msg, gr.update(visible=True), gr.update(visible=True), gr.update(visible=False), "login"

def set_func_name(selected_method):
        return 'ema' if selected_method == "Exponential Weighted Moving Average" else 'sma'

In [19]:
# Function to fetch stock data from Yahoo Finance
def fetch_stock_data(stock_symbol, start_date, end_date):
    # Convert dates to the required format
    start_timestamp = int(datetime.strptime(start_date, "%Y-%m-%d").timestamp())
    end_timestamp = int(datetime.strptime(end_date, "%Y-%m-%d").timestamp())
    
    # Yahoo Finance URL
    yahoo_url = f"https://query1.finance.yahoo.com/v7/finance/download/{stock_symbol}?period1={start_timestamp}&period2={end_timestamp}&interval=1d&events=history&includeAdjustedClose=true"
    
    # Fetch the data
    data = pd.read_csv(yahoo_url)
    
    # Convert Date to datetime
    data['Date'] = pd.to_datetime(data['Date'])
    
    return data

# Function to calculate RSI
def calculate_rsi(data, period=14, func_name='sma'):
    delta = data['Close'].diff()
    if func_name == 'ema':
        gain = (delta.where(delta > 0, 0)).ewm(span=period, adjust=False).mean()
        loss = (-delta.where(delta < 0, 0)).ewm(span=period, adjust=False).mean()
    else:
        gain = (delta.where(delta > 0, 0)).rolling(window=period).mean()
        loss = (-delta.where(delta < 0, 0)).rolling(window=period).mean()

    rs = gain / loss
    rsi = 100 - (100 / (1 + rs))
    
    plt.figure(figsize=(10, 6))
    plt.plot(rsi, label='RSI')
    plt.title('Relative Strength Index (RSI)')
    plt.legend()
    
    buffer = io.BytesIO()
    plt.savefig(buffer, format='png')
    buffer.seek(0)
    plt.close()
    return buffer

# Function to fetch stock data and plot RSI
def get_rsi_chart(start_date, end_date, period, stock_symbol, func_name):
    # Fetch stock data from Yahoo Finance
    data = fetch_stock_data(stock_symbol, start_date, end_date)
    
    # Calculate RSI
    rsi_buffer = calculate_rsi(data, period=period, func_name=func_name)
    
    # Save the RSI plot to a temporary file
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
        temp_file.write(rsi_buffer.read())
        temp_file_path = temp_file.name
    
    return temp_file_path

In [20]:
# Gradio UI setup
with gr.Blocks(css=".gradio-container {background-color: green;}") as demo:
    current_page = gr.State("login")
    # Initialize func_name state
    func_name = gr.State("sma")

    # Top left button to switch between login and registration
    with gr.Row():
        switch_button = gr.Button("Switch to Register")

    # Registration components
    with gr.Column(visible=False) as register_column:
        reg_username_input = gr.Textbox(label="Register Username")
        reg_password_input = gr.Textbox(label="Register Password", type="password")
        register_button = gr.Button("Register")
        registration_message = gr.Markdown(value="", visible=True)

    # Login components
    with gr.Column(visible=True) as login_column:
        login_username_input = gr.Textbox(label="Login Username")
        login_password_input = gr.Textbox(label="Login Password", type="password")
        login_button = gr.Button("Login")
        login_message = gr.Markdown(value="", visible=True)
        qr_code_display = gr.Image(visible=False, label="Scan this QR code with Google Authenticator")
        otp_input = gr.Textbox(label="Enter OTP from Google Authenticator", type="password")
        otp_button = gr.Button("Verify OTP")
        otp_message = gr.Markdown(value="", visible=True)

    # RSI search components (initially hidden)
    with gr.Column(visible=False) as rsi_column:
        start_date_input = gr.Textbox(label="Start Date (YYYY-MM-DD)")
        end_date_input = gr.Textbox(label="End Date (YYYY-MM-DD)")
        period_input = gr.Slider(minimum=1, maximum=100, value=14, label="RSI Period", interactive=True)
        stock_symbol_input = gr.Textbox(label="Stock Symbol")
        math_input = gr.Radio(["Exponential Weighted Moving Average", "Smooth Moving Average"], label="Moving Average Type") #sma/ewma
        calculate_button = gr.Button("Calculate RSI")
        rsi_chart_output = gr.Image()

    register_button.click(
        handle_registration,
        [reg_username_input, reg_password_input],
        registration_message
    )

    login_button.click(
        handle_login,
        [login_username_input, login_password_input],
        [login_message, qr_code_display, qr_code_display]
    )
    
    otp_button.click(
        handle_otp_verification,
        [login_username_input, otp_input],
        [otp_message, login_column, qr_code_display, rsi_column, current_page]
    )

    
    math_input.change(
        set_func_name,
        [math_input],
        [func_name]
    )
    
    # Switch page action
    switch_button.click(
        lambda current_page: switch_page(current_page),
        current_page,
        [current_page, register_column, login_column, switch_button]
    )

    #get rsi & chart
    calculate_button.click(
        lambda start_date, end_date, period, stock_symbol, func_name: get_rsi_chart(start_date, end_date, period, stock_symbol, func_name),
        [start_date_input, end_date_input, period_input, stock_symbol_input, func_name],
        rsi_chart_output
    )
    
# Launch the app
demo.launch()

Running on local URL:  http://127.0.0.1:7867

To create a public link, set `share=True` in `launch()`.




In [21]:
users_db

{'': {'password': b'$2b$12$oEbZMakjPQJ582Y5wxx5zeLKMt/bUumQ/aytSuA5YYzL793efOeE6',
  'secret': 'XD7Q7SFMO3OXWR4BSLJGMRWJSJCT2IQS'}}