In [1]:
import yfinance as yf
import pandas as pd
import numpy as np
import math
from datetime import datetime, timedelta, date
import time
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.ticker as plticker
import matplotlib.patches as patches
from matplotlib.colors import TwoSlopeNorm
import seaborn as sns
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from mapping_tickers import *
from utils import *
from mapping_plot_attributes import *
from mapping_portfolio_downloads import *
from analyze_prices import AnalyzePrices
from download_data import DownloadData

In [2]:
tk = 'AAPL'
tk_market = '^GSPC'
end_date = datetime.today()
start_date = datetime(end_date.year - 1, end_date.month, end_date.day)

hist_data = DownloadData(end_date, start_date, [tk], tk_market)
analyze_prices = AnalyzePrices(end_date, start_date, [tk])

ticker_data = hist_data.download_yh_data(start_date, end_date, [tk], tk_market)

df_adj_close = pd.DataFrame()
df_close = pd.DataFrame()
df_ohlc = pd.DataFrame()

df_adj_close = ticker_data['Adj Close']
df_close = ticker_data['Close']
df_ohlc = ticker_data['OHLC'][tk] 

ohlc_tk = df_ohlc.copy()
adj_close_tk = df_adj_close[tk]
close_tk = df_close[tk]
open_tk = ohlc_tk['Open']
high_tk = ohlc_tk['High']
low_tk = ohlc_tk['Low']

price_type_map = {
    'Adj Close': adj_close_tk,
    'Adjusted Close': adj_close_tk,
    'Close': close_tk,
    'Open': open_tk,
    'High': high_tk,
    'Low': low_tk
}

# theme = 'light'
theme = 'dark'
style= theme_style[theme]

[*********************100%%**********************]  1 of 1 completed
[*********************100%%**********************]  1 of 1 completed

The portfolio data will be truncated to end at the latest available date of 2024-10-04.





In [3]:
def macd(
    df_tk,
    signal_window = 9      
):
    """
    df_tk: a series of price values, taken as a column of df_close or df_adj_close for ticker tk
    
    """ 

    if not isinstance(df_tk, pd.Series):
        print('Incorrect format of input data')
        exit
    
    ema_26 = df_tk.ewm(span = 26).mean()
    ema_12 = df_tk.ewm(span = 12).mean()
    macd_line = ema_12 - ema_26
        
    macd_signal = macd_line.ewm(span = signal_window).mean()

    macd_data = {
        'MACD': macd_line,
        'MACD Signal': macd_signal,
        'MACD Signal Window': signal_window
    }

    return macd_data

In [4]:
def plot_macd_plotly(
    tk_macd,
    macd_data,
    df_tk,
    n_ticks_max = 48,
    n_yticks_max = 16,
    plot_width = 1450,
    plot_height = 750,
    title_font_size = 32,
    legendgroup = 'upper',
    theme = 'dark',
    overlay_price = False,
    price_type = 'close'
):
    """
    MACD plot with a signal line and the original price line overlayed, if desired
    price_type: Normally 'adjusted close' or 'close' or whatever MACD is based on
    """
         
    macd = macd_data['MACD']
    macd_signal = macd_data['MACD Signal']
    macd_signal_window = macd_data['MACD Signal Window']
 
    x_min = str(macd.index.min().date())
    x_max = str(macd.index.max().date())
 
    min_macd = min(min(macd), min(macd_signal))
    max_macd = max(max(macd), max(macd_signal))
    y_macd_min, y_macd_max = set_axis_limits(min_macd, max_macd)
    
    macd_positive = macd.copy()
    macd_positive.iloc[np.where(macd_positive < 0)] = np.nan
    macd_negative = macd.copy()
    macd_negative.iloc[np.where(macd_negative >= 0)] = np.nan
    
    title_macd = f'{tk_macd} Moving Average Convergence Divergence (EMA 12-26)'

    price_types = ['adjusted close', 'adj close', 'close', 'open', 'high', 'low']
    if price_type in price_types:
        price_name = 'Adjusted Close' if price_type == 'adj close' else price_type.title()
    else:
        price_name = 'Adjusted Close'
    
    style = theme_style[theme]
    
    if overlay_price:
        fig_macd = make_subplots(specs=[[{'secondary_y': True}]])
    else:
        fig_macd = make_subplots(rows = 1, cols = 1)
    
    fig_macd.add_trace(
        go.Bar(
            x = macd_positive.index.astype(str),
            y = macd_positive,
            marker_color = style['green_color'],
            width = 1,
            name = 'MACD > 0',
            legendgroup = legendgroup
        ),
        secondary_y = False
    )
    fig_macd.add_trace(
        go.Bar(
            x = macd_negative.index.astype(str),
            y = macd_negative,
            marker_color = style['red_color'],
            width = 1,
            name = 'MACD < 0',
            legendgroup = legendgroup
        ),
        secondary_y = False
    )
    fig_macd.add_trace(
        go.Scatter(
            x = macd_signal.index.astype(str),
            y = macd_signal,
            line = dict(color = style['signal_color']),
            # name = 'Signal Line'  # 9-day span is a standard, no need to customise it
            name = f'MACD EMA {macd_signal_window} Signal',
            legendgroup = legendgroup
        ),
        secondary_y = False
    )
    if overlay_price:
        fig_macd.add_trace(
            go.Scatter(
                x = macd.index.astype(str),
                y = df_tk,
                line = dict(color = style['basecolor']),
                name = price_name,
                legendgroup = legendgroup
            ),
            secondary_y = True
        )
    
    # Add plot border
    fig_macd.add_shape(
        type = 'rect',
        xref = 'x',  # use 'x' because of seconday axis - 'paper' does not work correctly
        yref = 'paper',
        x0 = x_min,
        x1 = x_max,
        y0 = 0,
        y1 = 1,
        line_color = style['x_linecolor'],
        line_width = 0.3
    )
    # Update layout and axes
    fig_macd.update_layout(
        width = plot_width,
        height = plot_height,
        xaxis_rangeslider_visible = False,
        template = style['template'],
        legend_groupclick = 'toggleitem',
        title = dict(
            text = title_macd,
            font_size = title_font_size,
            y = 0.95,
            x = 0.45,
            xanchor = 'center',
            yanchor = 'top'
        )
    )
    fig_macd.update_yaxes(
        title_text = f'MACD',
        range = (y_macd_min, y_macd_max),
        secondary_y = False,
        gridcolor = style['y_gridcolor'],
        nticks = n_yticks_max,
        ticks = 'outside',
        ticklen = 8,
        ticklabelshift = 5,  # not working
        ticklabelstandoff = 10  # not working
    )
    if overlay_price:
        fig_macd.update_yaxes(
            title_text = price_name,
            secondary_y = True,
            ticks = 'outside',
            ticklen = 8,
            ticklabelshift = 5,  # not working
            ticklabelstandoff = 10,  # not working
            showgrid = False
        )
    fig_macd.update_xaxes(
        type = 'category',
        nticks = n_ticks_max,
        tickangle = -90,
        gridcolor = style['x_gridcolor'],
        ticks = 'outside',
        ticklen = 8,
        ticklabelshift = 5,  # not working
        ticklabelstandoff = 10,  # not working
        showgrid = True  # Thanks to Copilot for showing me this option! Googling didn't get me anywhere...
    )

    return fig_macd

In [5]:
macd_data = macd(adj_close_tk)
theme = 'light'

fig_macd = plot_macd_plotly(tk, macd_data, adj_close_tk, overlay_price = True, theme = theme)
print(fig_macd['data'])
fig_macd.show()

(Bar({
    'legendgroup': 'upper',
    'marker': {'color': 'rgba(0, 128, 0, 1)'},
    'name': 'MACD > 0',
    'width': 1,
    'x': array(['2023-10-06', '2023-10-09', '2023-10-10', ..., '2024-10-02',
                '2024-10-03', '2024-10-04'], dtype=object),
    'xaxis': 'x',
    'y': array([0.        , 0.03348233, 0.02415209, ..., 1.6034257 , 1.4124872 ,
                1.33693764]),
    'yaxis': 'y'
}), Bar({
    'legendgroup': 'upper',
    'marker': {'color': 'rgba(178, 34, 34, 1)'},
    'name': 'MACD < 0',
    'width': 1,
    'x': array(['2023-10-06', '2023-10-09', '2023-10-10', ..., '2024-10-02',
                '2024-10-03', '2024-10-04'], dtype=object),
    'xaxis': 'x',
    'y': array([nan, nan, nan, ..., nan, nan, nan]),
    'yaxis': 'y'
}), Scatter({
    'legendgroup': 'upper',
    'line': {'color': 'orange'},
    'name': 'MACD EMA 9 Signal',
    'x': array(['2023-10-06', '2023-10-09', '2023-10-10', ..., '2024-10-02',
                '2024-10-03', '2024-10-04'], dtype=object)

In [6]:
def add_macd_stacked(
    fig_data,
    tk_macd,
    macd_data,
    df_tk,
    # x_min = None,
    # x_max = None,
    legendgroup = 'lower',
    theme = 'dark'
):
    """
    MACD with a signal line added to the lower stacked plot
    """
    
    # x_min = start_date if x_min is None else x_min
    # x_max = end_date if x_max is None else x_max

    fig_macd = fig_data['fig']

    style = theme_style[theme]
    
    subplot_row = legendgroup_map[legendgroup]
        
    macd = macd_data['MACD']
    macd_signal = macd_data['MACD Signal']
    macd_signal_window = macd_data['MACD Signal Window']
 
    # x_min = str(macd.index.min().date())
    # x_max = str(macd.index.max().date())
 
    min_macd = min(min(macd), min(macd_signal))
    max_macd = max(max(macd), max(macd_signal))
    y_macd_min, y_macd_max = set_axis_limits(min_macd, max_macd)
    
    macd_positive = macd.copy()
    macd_positive.iloc[np.where(macd_positive < 0)] = np.nan
    macd_negative = macd.copy()
    macd_negative.iloc[np.where(macd_negative >= 0)] = np.nan
    
    title_macd = f'{tk_macd} Moving Average Convergence Divergence (EMA 12-26)'
     
    fig_macd.add_trace(
        go.Bar(
            x = macd_positive.index.astype(str),
            y = macd_positive,
            marker_color = style['green_color'],
            width = 1,
            name = 'MACD > 0',
            legendgroup = legendgroup
        ),
        row = subplot_row, col = 1
    )
    fig_macd.add_trace(
        go.Bar(
            x = macd_negative.index.astype(str),
            y = macd_negative,
            marker_color = style['red_color'],
            width = 1,
            name = 'MACD < 0',
            legendgroup = legendgroup
        ),
        row = subplot_row, col = 1
    )
    fig_macd.add_trace(
        go.Scatter(
            x = macd_signal.index.astype(str),
            y = macd_signal,
            line = dict(color = style['signal_color']),
            # name = 'Signal Line'  # 9-day span is a standard, no need to customise it
            name = f'MACD EMA {macd_signal_window} Signal',
            legendgroup = legendgroup
        ),
        row = subplot_row, col = 1
    )
    
    # Update layout and axes
    fig_macd.update_yaxes(
        title_text = f'MACD',
        range = (y_macd_min, y_macd_max),
        gridcolor = style['y_gridcolor'],
        ticks = 'outside',
        ticklen = 8,
        row = subplot_row, col = 1
    )

    fig_macd_data = {
        'fig': fig_macd,
        'y_min': y_macd_min,
        'y_max': y_macd_max
    }

    return fig_macd_data 

In [7]:
def weighted_mean(values):
    """
    values: a list, tuple or series of numerical values
    """
    if isinstance(values, (list, tuple)):
        values = pd.Series(values)
    
    n = len(values)
    weight_sum = n * (n + 1) / 2
    weights = range(n + 1)[1:]
    wm = values @ weights / weight_sum
    return wm

In [8]:
def moving_average(
    df_tk,
    ma_type,
    ma_window,
    min_periods = 1
):
    """
    df_tk:      
        a series of price values, taken as a column of df_close or df_adj_close for ticker tk
    ma_type:    
        simple ('sma'),
        exponential ('ema'),
        double exponential ('dema'),
        triple exponential ('tema'),
        weighted ('wma')
    window:
        length in days
    Returns ma
    """

    if not isinstance(df_tk, pd.Series):
        print('Incorrect format of input data')
        exit
    
    if ma_type in ['ema', 'dema', 'tema']:
        ma = df_tk.ewm(span = ma_window).mean()
        if ma_type in ['dema', 'tema']:
            ma = ma.ewm(span = ma_window).mean()
            if ma_type == 'tema':
                ma = ma.ewm(span = ma_window).mean()
    
    elif ma_type == 'wma':
        ma = df_tk.rolling(window = ma_window, min_periods = min_periods).apply(lambda x: weighted_mean(x))

    else:  # 'sma' or anything else
        ma = df_tk.rolling(window = ma_window, min_periods = min_periods).mean()
    
    return ma

In [9]:
def stochastic_oscillator(
    close_tk,
    high_tk,
    low_tk,
    fast_k_period = 14,
    smoothing_period = 3,
    sma_d_period = 3,
    stochastic_type = 'Slow'
):
    """
    stochastic_type: 'Fast', 'Slow', 'Full'
    NOTES:
    1) fast_k_period is also know as the look--back period
    2) smoothing_period is the period used in slow %K and full %K
    3) sma_d_period is the %D averaging period used in fast, slow and full stochastics
    4) if sma_d_period == smoothing_period, then the slow and full stochastics become equivalent
    
    """
    fast_low = low_tk.rolling(window = fast_k_period, min_periods = 1).min()
    fast_high = high_tk.rolling(window = fast_k_period, min_periods = 1).max()
    fast_k_line = 100 * (close_tk - fast_low) / (fast_high - fast_low)

    if stochastic_type.lower() == 'fast':
        
        k_line = fast_k_line.copy()    
        d_line = k_line.rolling(window = sma_d_period, min_periods = 1).mean()
        stochastic_label = f'({fast_k_period}, {sma_d_period})'
        stochastic_type = 'Fast'

    elif (stochastic_type.lower() == 'full') | (sma_d_period != smoothing_period):
        
        k_line = fast_k_line.rolling(window = smoothing_period, min_periods = 1).mean()
        d_line = k_line.rolling(window = sma_d_period, min_periods = 1).mean()
        stochastic_label = f'({fast_k_period}, {smoothing_period}, {sma_d_period})'
        stochastic_type = 'Full'

    else:
        # This includes the case of 
        # (stochastic_type == 'slow') | (sma_d_period == smoothing_period)
        # and any other stochastic_type specified.
        
        k_line = fast_k_line.rolling(window = smoothing_period, min_periods = 1).mean()
        d_line = k_line.rolling(window = sma_d_period, min_periods = 1).mean()
        stochastic_label = f'({fast_k_period}, {sma_d_period})'
        stochastic_type = 'Slow'

    k_line.index = k_line.index.astype(str)
    d_line.index = d_line.index.astype(str)

    stochastic_data = {
        'k_line': k_line,
        'd_line': d_line,
        'label': stochastic_label,
        'type': stochastic_type
    }

    return stochastic_data

In [10]:
# These must be specified by the user, except signal_type and signal_window if the user chooses not to add signal
# The price types should be capitalized as they appear in the menu, must also be consistent with price_list
# The MA acronyms can stay lower case, easy to convert 

diff_data = {
    'p_base': 'close',  # 'adjusted close', 'adj close', 'close', 'open', 'high', 'low'
    'p1_type': 'ema',  # 'adjusted close', 'adj close', 'close', 'open', 'high', 'low', 'sma', 'ema', 'dema', 'tema', 'wma'
    'p2_type': 'wma',  # 'adjusted close', 'adj close', 'close', 'open', 'high', 'low', 'sma', 'ema', 'dema', 'tema', 'wma'
    'p1_window': 10,
    'p2_window': 10,
    'signal_type': 'ema',  # 'sma', 'ema', 'dema', 'tema', 'wma'
    'signal_window': 5
}
diff_data_stochastic = {
    'p_base': 'close',  # 'adjusted close', 'adj close', 'close', 'open', 'high', 'low'
    'p1_type': 'k-line',  # 'adjusted close', 'adj close', 'close', 'open', 'high', 'low', 'sma', 'ema', 'dema', 'tema', 'wma', 'wwma', 'k-line', 'd-line'
    'p2_type': 'd-line',  # 'adjusted close', 'adj close', 'close', 'open', 'high', 'low', 'sma', 'ema', 'dema', 'tema', 'wma', 'wwma', 'k-line', 'd-line'
    'p1_window': 13,
    'p2_window': 3,
    'signal_type': 'ema',  # 'sma', 'ema', 'dema', 'tema', 'wma', 'wwma'
    'signal_window': 10
}

In [11]:
stochastic_data = stochastic_oscillator(close_tk, high_tk, low_tk)

In [24]:
def plot_diff_plotly(
    tk,
    diff_data,
    price_type_map,
    reverse_diff = False,
    plot_type = 'filled_line',
    add_signal = True,
    n_ticks_max = 48,
    n_yticks_max = 16,
    plot_width = 1450,
    plot_height = 750,
    title_font_size = 32,
    theme = 'dark'
):
    """
    price_type_map = {
        'Adj Close': adj_close_tk,
        'Adjusted Close': adj_close_tk,
        'Close': close_tk,
        'Open': open_tk,
        'High': high_tk,
        'Low': low_tk
    }
    reverse_diff:
        if True, the (p2 - p1) difference will be used instead of (p1 - p2)
    plot_type:
        'flooded_line' (or 'scatter'), 'bar'
    add_signal:
        if True, a signal will be added that is a moving average of the calculated difference
    """
    
    base = diff_data['p_base']
    p_base_name = base.title()
    p_base = price_type_map[p_base_name]

    p1_type = diff_data['p1_type']
    p2_type = diff_data['p2_type']
    p1_window = diff_data['p1_window']
    p2_window = diff_data['p2_window']
    signal_type = diff_data['signal_type']
    signal_window = diff_data['signal_window']
    
    price_types = ['adjusted close', 'adj close', 'close', 'open', 'high', 'low']
    ma_types = ['sma', 'ema', 'dema', 'tema', 'wma']
    stochastic_types = ['k-line', 'k_line', 'kline', 'd-line', 'd_line', 'dline']

    diff_title_base = f'{tk} {p_base_name} '

    if p1_type in price_types:
        p1_name = 'Adjusted Close' if p1_name == 'adj close' else p1_type.title()
        try:
            p1 = price_type_map[p1_name]
        except:
            p1 = price_type_map['Adj Close']

    elif p1_type in ma_types:
        p1 = moving_average(p_base, p1_type, p1_window)
        p1_name = f'{p1_type.upper()} {p1_window}'

    elif p1_type in stochastic_types:
        stochastic_type = stochastic_data['type']
        p1 = stochastic_data['k_line']
        p2 = stochastic_data['d_line']
        p1_name = '%K Line'
        p2_name = '%D Line'
        diff_title_base += f'Stochastic {stochastic_type} '

    if p2_type in price_types:
        p2_name = 'Adjusted Close' if p2_name == 'adj close' else p2_type.title()
        try:
            p2 = price_type_map[p2_name]
        except:
            p2 = price_type_map['Adj Close']

    elif p2_type in ma_types:
        p2 = moving_average(p_base, p2_type, p2_window)
        p2_name = f'{p2_type.upper()} {p2_window}'

    if not reverse_diff:
        diff = p1 - p2
        diff_title = diff_title_base + f'{p1_name} - {p2_name} Oscillator'
        diff_positive_name = f'{p1_name} > {p2_name}'
        diff_negative_name = f'{p1_name} < {p2_name}'
    else:
        diff = p2 - p1
        diff_title = diff_title_base + f'{p2_name} - {p1_name} Oscillator'
        diff_positive_name = f'{p2_name} > {p1_name}'
        diff_negative_name = f'{p2_name} < {p1_name}'

    diff_signal = moving_average(diff, signal_type, signal_window)
    signal_name = f'Diff {signal_type.upper()} {signal_window} Signal'

    # x_min = str(diff.index.min().date())
    # x_max = str(diff.index.max().date())

    x_min = diff.index.min()
    x_max = diff.index.max()
 
    min_diff = min(diff)
    max_diff = max(diff)
    
    print(min_diff)
    print(max_diff)

    y_diff_min, y_diff_max = set_axis_limits(min_diff, max_diff)
    
    diff_positive = diff.copy()
    diff_negative = diff.copy()

    style = theme_style[theme]
    
    fig_diff = make_subplots(rows = 1, cols = 1)

    if plot_type == 'bar':

        diff_positive.iloc[np.where(diff_positive < 0)] = np.nan
        diff_negative.iloc[np.where(diff_negative >= 0)] = np.nan

        fig_diff.add_trace(
            go.Bar(
                x = diff_positive.index.astype(str),
                y = diff_positive,
                marker_color = style['green_color'],
                width = 1,
                name = diff_positive_name,
                # legendgroup = f'{target_deck}',
                # legendgrouptitle = legendgrouptitle,
                showlegend = True
            )
            # , row = target_deck, col = 1
        )
        fig_diff.add_trace(
            go.Bar(
                x = diff_negative.index.astype(str),
                y = diff_negative,
                marker_color = style['red_color'],
                width = 1,
                name = diff_negative_name,
                # legendgroup = f'{target_deck}',
                # legendgrouptitle = legendgrouptitle,
                showlegend = True
            )
            # , row = target_deck, col = 1
        )

    else:
        # 'filled_line' or 'scatter'

        prev_v = diff.iloc[0]
        diff_positive.iloc[0] = prev_v if prev_v >= 0 else np.nan
        diff_negative.iloc[0] = prev_v if prev_v < 0 else np.nan
    
        for idx in diff.index[1:]:
            
            curr_v = diff.loc[idx]
    
            if np.sign(curr_v) != np.sign(prev_v):
                # Set both diff copies to 0 if the value is changing sign
                diff_positive[idx] = 0
                diff_negative[idx] = 0
            else:
                # Set both diff copies to current value or NaN
                diff_positive[idx] = curr_v if curr_v >= 0 else np.nan
                diff_negative[idx] = curr_v if curr_v < 0 else np.nan
        
            prev_v = curr_v

        fig_diff.add_trace(
            go.Scatter(
                x = diff_positive.index.astype(str),
                y = diff_positive,
                line_color = style['diff_green_linecolor'],
                line_width = 2,
                fill = 'tozeroy',
                fillcolor = style['diff_green_fillcolor'],
                name = diff_positive_name
            )
        )
        fig_diff.add_trace(
            go.Scatter(
                x = diff_negative.index.astype(str),
                y = diff_negative,
                line_color = 'darkred',
                line_width = 2,
                fill = 'tozeroy',
                fillcolor = style['diff_red_fillcolor'],
                name = diff_negative_name
            )
        )

    if add_signal:
        fig_diff.add_trace(
            go.Scatter(
                x = diff_signal.index.astype(str),
                y = diff_signal,
                line_color = style['signal_color'],
                line_width = 2,
                name = signal_name
            )
        )
    
    # Add plot border
    fig_diff.add_shape(
        type = 'rect',
        xref = 'x',  # use 'x' because of seconday axis - 'paper' does not work correctly
        yref = 'paper',
        x0 = x_min,
        x1 = x_max,
        y0 = 0,
        y1 = 1,
        line_color = style['x_linecolor'],
        line_width = 0.3
    )
    # Update layout and axes
    fig_diff.update_layout(
        width = plot_width,
        height = plot_height,
        xaxis_rangeslider_visible = False,
        template = style['template'],
        title = dict(
            text = diff_title,
            font_size = title_font_size,
            y = 0.95,
            x = 0.45,
            xanchor = 'center',
            yanchor = 'top'
        )
    )
    fig_diff.update_yaxes(
        title_text = f'Oscillator',
        range = (y_diff_min, y_diff_max),
        secondary_y = False,
        nticks = n_yticks_max,
        gridcolor = style['y_gridcolor'],
        zerolinecolor = style['x_gridcolor'],
        zerolinewidth = 1,        
        ticks = 'outside',
        ticklen = 8,
        ticklabelshift = 5,  # not working
        ticklabelstandoff = 10,  # not working
    )
    fig_diff.update_xaxes(
        type = 'category',
        nticks = n_ticks_max,
        tickangle = -90,
        gridcolor = style['x_gridcolor'],
        ticks = 'outside',
        ticklen = 8,
        ticklabelshift = 5,  # not working
        ticklabelstandoff = 10,  # not working
        showgrid = True
    )

    return fig_diff

In [25]:
theme = 'light'
theme = 'light'
# stochastic_data = stochastic_oscillator(close_tk, high_tk, low_tk)
fig_diff = plot_diff_plotly(tk, diff_data_stochastic, price_type_map, plot_type = 'scatter', add_signal = True, theme = theme)
fig_diff.show()

-25.199126804901443
25.899847806876238


In [14]:
import numpy as np
import math


def set_axis_limits_(
    x_min,
    x_max,
    max_n_intervals = 15
):
    """
    Returns the lower and upper limits for an axis where x_min and x_max are the min/max values.
    max_n_intervals: maximum number of intervals between y-ticks
    units: increments of values at axis ticks, will be scaled to correspond with the
        order of magntitude of x_max - x_min
    """

    if x_min == x_max:
        return x_min, x_max
    
    else:
        units = np.array([0.05, 0.1, 0.2, 0.25, 0.5])
        # intervals = np.array(range(4, max_n_intervals + 1))

        x_maxmax = max(abs(x_max), abs(x_min))
        diff = 2 * x_maxmax
        order = 10 ** round(math.log10(x_maxmax))
        print(f'order = {order}')
        eps = order * 1e-10

        for unit in units:
            unit_scaled = order * unit
            print(f'unit scaled = {unit_scaled}')

            lower_anchor = 0
            increment = unit_scaled
            while lower_anchor - abs(x_min) < eps:
                lower_anchor += increment
            lower_anchor *= np.sign(x_min)
            if x_min > eps:
                lower_anchor -= increment

            diff_lower = abs(lower_anchor - x_min)
            if diff_lower < eps:
                diff_lower = 0

            print(f'\tlower anchor = {lower_anchor}')
            print(f'\tdiff lower = {diff_lower}')

            upper_anchor = lower_anchor
            while (upper_anchor < x_max) & (abs(upper_anchor - x_max) > eps) & ((upper_anchor - lower_anchor) / unit_scaled <= max_n_intervals):
                upper_anchor += unit_scaled
                print(f'\t\tupper anchor = {upper_anchor} (while loop)')
            diff_upper = abs(upper_anchor - x_max)
            if diff_upper < eps:
                diff_upper = 0
            
            print(f'\tupper anchor = {upper_anchor}')
            print(f'\tdiff upper = {diff_upper}')

            if (upper_anchor - x_max > -eps) & (diff_lower + diff_upper < diff):
                diff = diff_lower + diff_upper
                lower_limit = lower_anchor
                upper_limit = upper_anchor
        
        return lower_limit, upper_limit

In [15]:
# print(set_axis_limits(-25.199126804901443, 25.899847806876238))
print(set_axis_limits_(-29.1, 29.1))

order = 10
unit scaled = 0.5
	lower anchor = -29.5
	diff lower = 0.3999999999999986
		upper anchor = -29.0 (while loop)
		upper anchor = -28.5 (while loop)
		upper anchor = -28.0 (while loop)
		upper anchor = -27.5 (while loop)
		upper anchor = -27.0 (while loop)
		upper anchor = -26.5 (while loop)
		upper anchor = -26.0 (while loop)
		upper anchor = -25.5 (while loop)
		upper anchor = -25.0 (while loop)
		upper anchor = -24.5 (while loop)
		upper anchor = -24.0 (while loop)
		upper anchor = -23.5 (while loop)
		upper anchor = -23.0 (while loop)
		upper anchor = -22.5 (while loop)
		upper anchor = -22.0 (while loop)
		upper anchor = -21.5 (while loop)
	upper anchor = -21.5
	diff upper = 50.6
unit scaled = 1.0
	lower anchor = -30.0
	diff lower = 0.8999999999999986
		upper anchor = -29.0 (while loop)
		upper anchor = -28.0 (while loop)
		upper anchor = -27.0 (while loop)
		upper anchor = -26.0 (while loop)
		upper anchor = -25.0 (while loop)
		upper anchor = -24.0 (while loop)
		upper an