In [None]:
"""Tax Calculator notebook"""

#import datetime
import ipywidgets as widgets
import IPython
from IPython.display import display

import pandas as pd

# ############## Input Widgets ##############

# Amounts the user enters, all in dollars:
#   net pay (after 401k and other pre-tax deductions), IRA distributions,
#   interest, dividends, Social Security income (85% treated as taxable),
#   short-term capital gains, long-term capital gains, and the long-term
#   capital gains carryover (the only box that accepts negative values).
# AGI is derived from these in calc().

year_box = widgets.BoundedIntText(
    value=2024, min=2000, max=9999, step=1,
    description='Tax Year:',
    disabled=False,
    layout=widgets.Layout(width='200px'),
)


def _dollar_box(label, floor=0.0):
    """Build a 200px dollar-entry box starting at 0 and capped at 250,000."""
    return widgets.BoundedFloatText(
        value=0.00,
        min=floor,
        max=250000.00,
        step=1.00,
        description=label,
        disabled=False,
        layout=widgets.Layout(width='200px'),
    )


npay_box = _dollar_box('Net Earnings')
irad_box = _dollar_box('IRA Distrib.')
intr_box = _dollar_box('Interest')
dvdd_box = _dollar_box('Dividends')
ssai_box = _dollar_box('SSA income')
stgn_box = _dollar_box('ST Gains')
ltgn_box = _dollar_box('LT Gains')
# a carryover may be a loss, so this box alone goes negative
gcry_box = _dollar_box('Gains Carry', floor=-250000.0)


## Unit Test Code for e-ORP

# Path of the e-ORP CSV export that test() reads and checks row by row.
_EXPLORE_CSV = '../data/_explore.csv'
efname = widgets.Text(
    value=_EXPLORE_CSV,
    placeholder=_EXPLORE_CSV,
    description='Test output at:',
    style=dict(description_width='initial'),
    disabled=False,
)



# ##############      Tax Data       ##############

# No upper limit: marks the open-ended top bracket of each table.
_NO_CEILING = float('inf')

# These tables are for Married Filing Jointly only, with amounts in $000s.
# Each year maps to (top of bracket, marginal rate) pairs in ascending order.
# The last bracket of every table is open-ended, so that income or gains above
# the highest published threshold are taxed at the top rate (37% / 20%)
# instead of falling outside every bracket.
tax_brackets_rates = {
    2022: [(20.55, 0.100), (83.55, 0.120), (178.15, 0.220), (340.10, 0.240), (431.90, 0.320),
           (647.85, 0.350), (_NO_CEILING, 0.370)],
    2023: [(22.00, 0.100), (89.45, 0.120), (190.75, 0.220), (364.20, 0.240), (462.50, 0.320),
           (693.75, 0.350), (_NO_CEILING, 0.370)],
    2024: [(23.20, 0.100), (94.30, 0.120), (201.05, 0.220), (383.90, 0.240), (487.45, 0.320),
           (731.20, 0.350), (_NO_CEILING, 0.370)],
    2025: [(23.85, 0.100), (96.95, 0.120), (206.70, 0.220), (394.60, 0.240), (501.05, 0.320),
           (751.60, 0.350), (_NO_CEILING, 0.370)],
    # placeholder: tax_brackets_for_year projects years after
    # last_tax_year_with_irs_adjusted_brackets from the 2025 row plus inflation
    2026: [(23.85, 0.100), (96.95, 0.120), (206.70, 0.220), (394.60, 0.240), (501.05, 0.320),
           (751.60, 0.350), (_NO_CEILING, 0.370)],
}
# Standard deduction in $000s, including the additional amount for two
# spouses aged 65 or older.
std_deductions = {
    2022: 25.9 + (2 * 1.4),  # with 2 over 65 2022
    2023: 27.7 + (2 * 1.5),  # with 2 over 65 2023
    2024: 29.2 + (2 * 1.55), # with 2 over 65 2024
    2025: 31.5 + (2 * 1.60), # with 2 over 65 assuming SSA income 2025
    # placeholder: later years are projected from 2025 plus inflation
    2026: 31.5 + (2 * 1.60)  # with 2 over 65 assuming SSA income 2026
}
# Long-term capital gains brackets: (top in $000s, rate) for 0% / 15% / 20%.
cap_brackets_rates = {
    2022: [(83.35, 0.000), (517.20, 0.150), (_NO_CEILING, 0.200)],
    2023: [(89.25, 0.000), (553.85, 0.150), (_NO_CEILING, 0.200)],
    2024: [(94.05, 0.000), (583.75, 0.150), (_NO_CEILING, 0.200)],
    2025: [(96.70, 0.000), (600.05, 0.150), (_NO_CEILING, 0.200)],
    # placeholder: later years are projected from 2025 plus inflation
    2026: [(96.70, 0.000), (600.05, 0.150), (_NO_CEILING, 0.200)]
}
first_tax_year_with_irs_brackets = 2022
last_tax_year_with_irs_adjusted_brackets = 2025

# Output area for warnings such as an unsupported tax year.
err_out = widgets.Output(layout={'border': '1px solid black'})

    #
    # with err_out:
    #     print(f'Calc brackets for {year}\n')
    #

# ##############    Tax Calculator     ##############

def _obbba_senior_deduction(year, MAGI, seniors=2):
    """Return the OBBBA senior deduction in $000s for a joint return."""
    # OBBBA: for tax years 2025-2028 each filer aged 65+ may deduct up to
    # $6,000. On a joint return each $6,000 is reduced by 6% of MAGI above
    # $150,000, so with two seniors the deduction is gone at $250,000:
    #    6% * (250,000 - 150,000) = 6,000
    # ASSUME MFJ with two over 65.  TODO: derive `seniors` from the inputs.
    # NOTE(review): callers add the non-taxable SSA portion into MAGI; confirm
    # that matches the OBBBA definition of MAGI for this deduction.
    if not 2025 <= year <= 2028:
        return 0
    excess = max(0, MAGI - 150.0)
    return seniors * max(0, 6.0 - 0.06 * excess)


def _stack_brackets(floor, tops_rates, infla):
    """Convert (top, rate) pairs in $000s into (low, high, cummtax, rate) tuples.

    The first tuple is a 0% bracket from 0 to `floor` (the standard
    deduction). Every later bracket is shifted up by `floor`, and its top is
    scaled by `infla`. `cummtax` is the tax owed on everything below `low`
    within this table, accumulated from zero.
    """
    stacked = [(0.0, floor, 0.0, 0.0)]  # standard deduction: zero tax
    low = floor
    cummtax = 0.0
    for top, rate in tops_rates:
        high = floor + top * infla
        stacked.append((low, high, cummtax, rate))
        cummtax += (high - low) * rate
        low = high
    return stacked


def tax_brackets_for_year(year, MAGI, rate_infla):
    """Build the income-tax and capital-gains bracket tables for a tax year.

    All amounts are in $000s.

    Args:
        year: tax year. Years before `first_tax_year_with_irs_brackets` are
            reported to `err_out` and computed with that first year's data.
            Years after `last_tax_year_with_irs_adjusted_brackets` reuse that
            year's data, scaled by compound inflation.
        MAGI: modified AGI in $000s, used only to phase out the OBBBA senior
            deduction.
        rate_infla: annual inflation rate (e.g. 0.02), applied only to years
            past the last IRS-published year.

    Returns:
        (std_deduction, brackets_rates, capgains_rates). Each bracket list
        holds (low, high, cumulative_tax_below_low, rate) tuples and begins
        with a zero-rate bracket covering the standard deduction. Each list's
        cumulative-tax column starts from zero.
    """
    infla = 1.0
    data_year = year  # which row of the tax tables to read
    if year < first_tax_year_with_irs_brackets:
        with err_out:
            print('FAIL: no tax info for years before 2022')
        data_year = first_tax_year_with_irs_brackets
    elif year > last_tax_year_with_irs_adjusted_brackets:
        # project later years from the last published table
        infla = (1.0 + rate_infla) ** (year - last_tax_year_with_irs_adjusted_brackets)
        data_year = last_tax_year_with_irs_adjusted_brackets

    # the OBBBA amount is added after inflation scaling, keyed on the real year
    std_deduction = std_deductions[data_year] * infla + _obbba_senior_deduction(year, MAGI)

    brackets_rates = _stack_brackets(std_deduction, tax_brackets_rates[data_year], infla)
    # Capital-gains table gets its own cumulative-tax column (previously it
    # continued from the income brackets). calc_tax stacks gains on top of
    # ordinary income itself, so it does not rely on this column.
    capgains_rates = _stack_brackets(std_deduction, cap_brackets_rates[data_year], infla)

    return (std_deduction, brackets_rates, capgains_rates)

def calc_tax(year, rate_infla, income, capgains, MAGI, tax_bs_for_year=None):
    """Calculate income tax (Married Filing Jointly); amounts in $000s.

    Ordinary income is taxed through the income brackets. Long-term capital
    gains are then stacked on top of ordinary income and taxed through the
    capital-gains brackets.

    Args:
        year, rate_infla, MAGI: passed to `tax_brackets_for_year` when
            `tax_bs_for_year` is not supplied.
        income: ordinary income before the standard deduction. The bracket
            tables include the deduction as a 0% first bracket.
        capgains: long-term capital gains. A net loss (negative value) is
            treated as zero gains; it never produces negative tax.
        tax_bs_for_year: optional precomputed result of
            `tax_brackets_for_year(year, MAGI, rate_infla)`, to avoid
            rebuilding the tables several times per year.

    Returns:
        (total_tax, income_marginal_rate, capgains_marginal_rate,
         income_bracket_top, std_deduction, income_tax, capgains_tax)
    """
    # pylint: disable=too-many-arguments
    # pylint: disable=too-many-locals
    if tax_bs_for_year is None:
        tax_bs_for_year = tax_brackets_for_year(year, MAGI, rate_infla)
    (sd, bs, cs) = tax_bs_for_year

    ibtax = 0.0  # tax on income before capital gains
    mrate = 0.0  # marginal tax rate on income
    brend = 0.0  # high end of marginal tax bracket
    # inclusive range: take the first applicable bracket even if it is full
    bracket = next((b for b in bs if b[0] <= income <= b[1]), None)
    if bracket is None and bs and income > bs[-1][1]:
        # above the top ceiling: extend the top bracket rather than tax nothing
        bracket = bs[-1]
    if bracket is not None:
        (low, high, cummtax, rate) = bracket
        ibtax = cummtax + (income - low) * rate
        mrate = rate
        brend = high

    cgtax = 0.0  # capital gains tax
    crate = 0.0  # capgains marginal tax rate
    remaining = max(0.0, capgains)  # a net loss yields no gains tax
    stacked = income  # gains sit on top of ordinary income
    for (low, high, _, rate) in cs:
        # exclusive range: take the first bracket that still has room
        if low <= stacked < high:
            take = min(high - stacked, remaining)
            cgtax += take * rate
            crate = rate
            remaining -= take
            stacked += take
            if remaining <= 0:
                break
    else:
        if cs and stacked >= cs[-1][1]:
            # gains reaching past the last ceiling are taxed at the top rate
            crate = cs[-1][3]
            cgtax += remaining * crate
    return (ibtax + cgtax, mrate, crate, brend, sd, ibtax, cgtax)

# ##############       UI       ##############

def tax_calc(year, income, capgains, MAGI):
    """Run calc_tax on dollar amounts coming from the UI.

    OORPy (and calc_tax) work in $000s while this UI works in dollars, so each
    amount is divided by 1000 here; the returned tax figures are in $000s.
    A 2% inflation rate is assumed, which only matters for years beyond the
    last IRS-published brackets.
    """
    in_thousands = [amount / 1000 for amount in (income, capgains, MAGI)]
    return calc_tax(year, 0.02, *in_thousands)

# Shared output area where calc() and test() print their results.
out_box = widgets.Output(layout=dict(border='1px solid black'))

def _net_capital_gains(st_gains, lt_gains, loss_limit=3000.0):
    """Net short- and long-term capital results, in dollars.

    Returns (ordinary, preferential): the part taxed as ordinary income and
    the part taxed at long-term capital-gains rates. A loss in one term
    offsets gains in the other. An overall net loss reduces ordinary income
    by at most `loss_limit`; the rest would carry over to later years.
    """
    total = st_gains + lt_gains
    if total < 0:
        return max(total, -loss_limit), 0.0
    if lt_gains <= 0:  # the whole net gain is short-term
        return total, 0.0
    if st_gains <= 0:  # the whole net gain is long-term
        return 0.0, total
    return st_gains, lt_gains


def calc(_):
    """Read the input widgets, run the tax calculation, and print a summary.

    Social Security is split 85% taxable / 15% non-taxable (the latter only
    enters MAGI). The long-term carryover is netted with this year's gains
    before the preferential rates apply. Results are printed to out_box in
    dollars.
    """
    err_out.clear_output() # at start of each run
    # TODO more finesse with ssa, and also dividends, perhaps
    year = year_box.value
    ssa = ssai_box.value
    ordinary_gains, capgains = _net_capital_gains(
        stgn_box.value, ltgn_box.value + gcry_box.value)
    income = (npay_box.value
              + irad_box.value
              + intr_box.value
              + dvdd_box.value
              + ssa * 0.85
              + ordinary_gains)
    MAGI = income + capgains + ssa * 0.15 # TODO: QCD
    (tax, mrate, crate, brend, _, ibtax, cgtax) = tax_calc(year, income, capgains, MAGI)
    with out_box:
        print(f'\n{year: 11d} Tax Year'
              f'\n{income: 11,.2f} Income'
              f'\n{capgains: 11,.2f} LTCG'
              f'\n{income + capgains: 11,.2f} AGI'
              f'\n{1000 * tax: 11,.2f} Total Tax'
              f'\n{1000 * ibtax: 11,.2f} Tax on Income'
              f'\n{1000 * cgtax: 11,.2f} Tax on Capital Gains'
              f'\n{mrate: 11.0%} Marginal rate on Income'
              f'\n{crate: 11.0%} Marginal rate on Capital Gains'
              f'\n{1000 * brend: 11,.2f} End of income bracket')

def test(_):
    """Re-run the tax calculation on an e-ORP export and compare its income tax.

    Reads the CSV named in `efname` and recomputes each projected year's tax
    (rows 1..N; row 0 is the base year) with tax_brackets_for_year/calc_tax.
    For each year it prints 'ok' or 'ng' to out_box, depending on whether the
    rounded tax equals the file's `income_tax` column. Problems reading the
    file are reported in err_out instead of being raised from the callback.
    """
    err_out.clear_output() # at start of each run
    out_box.clear_output() # at start of each run
    required = ('e', 'IRMAA-buk0', 'year', 'taxable_income',
                'capgains', 'SSA_income', 'QCD', 'income_tax')
    try:
        dd = pd.read_csv(efname.value) # , index_col='year'
    except (OSError, ValueError) as exc:
        with err_out:
            print(f'FAIL: cannot read {efname.value}: {exc}')
        return
    missing = [col for col in required if col not in dd.columns]
    if missing or dd.empty:
        with err_out:
            print(f'FAIL: {efname.value} has no rows or lacks columns {missing}')
        return

    YRS = len(dd['e']) - 1 # number of years of projection from base year 0
    infl = dd['IRMAA-buk0'][0] # squirreled value
    for y in range(1, YRS + 1):
        year = dd['year'][y]
        income = dd['taxable_income'][y]
        capgains = dd['capgains'][y]
        ssa_income = dd['SSA_income'][y]
        MAGI = income + capgains + 0.15 * ssa_income + dd['QCD'][y]

        tb = tax_brackets_for_year(year, MAGI, infl)
        with out_box:
            print(" ")

        (tax, _, _, _, _, ibtax, cgtax) = calc_tax(year, infl, income, capgains, MAGI, tb)
        tax = max(0, round(tax, 3))
        expected = dd['income_tax'][y]
        with out_box:
            print(f"{year} ti {income} cg {capgains} it {ibtax} ct {cgtax}")
            print(f"{year} tc {tax} eo {expected} {'ok' if tax == expected else 'ng'}")
  
run_button = widgets.Button(description='Run Tax Calc', disabled=False)
run_button.on_click(calc)

tst_button = widgets.Button(description='Test e-ORP Tax Calc', disabled=False)
tst_button.on_click(test)

# (input widget, explanation) pairs, shown as rows of a two-column grid
_input_rows = [
    (year_box, 'The tax year for the calculation'),
    (npay_box, 'Net pay exclusive of pre-tax deductions (e.g., 401k)'),
    (irad_box, 'Taxable IRA distributions'),
    (intr_box, 'Taxable Interest'),
    (dvdd_box, 'Dividends'),
    (ssai_box, 'SSA income'),
    (stgn_box, 'Short Term Capital Gains'),
    (ltgn_box, 'Long Term Capital Gains'),
    (gcry_box, 'Long Term Capital Gains carryover from previous year'),
]

# title row and "Inputs" heading row, each with an empty right-hand cell
_grid_cells = [
    widgets.Label('e Tax Calc', style={
        'font_weight': 'bold',
        'font_size': 'large',
        'text_color': 'forestgreen',
    }),
    widgets.Label(''),
    widgets.Label('Inputs', style={
        'font_weight': 'bold',
        'font_size': 'large',
    }),
    widgets.Label(''),
]
_grid_cells.extend(cell
                   for box, text in _input_rows
                   for cell in (box, widgets.Label(text)))

winputs = widgets.VBox([
    widgets.GridBox(_grid_cells,
                    layout=widgets.Layout(grid_template_columns='35% 65%')),
    run_button,
    efname, tst_button,
    err_out,
    out_box,
])

display(winputs)
