# EXAMPLE DATA

# In [1]:
# Tax rates: ordinary income (initial, ending), dividends, capital gains
t_oi_0, t_oi_T = 0.35, 0.25
t_div, t_cg = 0.15, 0.15

# Asset returns: stock dividend yield, stock capital gain, bond rate of return
dy, cg, bndret = 0.02, 0.04, 0.03

# Years saving
T = 30

# Every "<account>:<asset>" slot a portfolio assigns a weight to, in the
# order the weight tuples below are written.
asset_locations = (
    "Brokerage:Stock",
    "Brokerage:Bond",
    "401k:Stock",
    "401k:Bond",
    "Roth:Stock",
    "Roth:Bond",
)

# Portfolio #1: stock and bond both held in the brokerage account
portfolio1 = dict(zip(asset_locations, (0.50, 0.50, 0.00, 0.00, 0.00, 0.00)))

# Portfolio #2: stock in the brokerage account, bond in the 401k
portfolio2 = dict(zip(asset_locations, (0.50, 0.00, 0.00, 0.50, 0.00, 0.00)))

# Portfolio #3: stock and bond both held in the 401k
portfolio3 = dict(zip(asset_locations, (0.00, 0.00, 0.50, 0.50, 0.00, 0.00)))


# CALCULATIONS

# In [2]:
import pandas as pd
import numpy as np


# Dividend-paying stock
def stockret(tax_treat, t_oi_0, t_oi_T, t_div, t_cg, dy, cg, T):
    """Return the after-tax growth multiple of a dividend-paying stock.

    Assumes tax rates are constant from t=0 to t=T-1 and then jump at t=T,
    and that the dividend yield and capital gain are constant per year.

    Args:
        tax_treat: Account type; exactly one of "Brokerage", "Roth" or
            "401k" (the comparison is case-sensitive).
        t_oi_0: Initial tax rate on ordinary income.
        t_oi_T: Ending tax rate on ordinary income.
        t_div: Tax rate on dividends.
        t_cg: Tax rate on capital gains.
        dy: Stock dividend yield per year.
        cg: Stock capital gain per year.
        T: Number of years the stock is held.

    Returns:
        The value after T years per unit invested, as computed by the
        formula of the selected account type.

    Raises:
        ValueError: If ``tax_treat`` is not one of the supported labels.
    """
    if tax_treat == "Brokerage":
        # Dividends are taxed each year and the remainder is reinvested.
        after_tax_dy = dy * (1 - t_div)
        r = 1 + after_tax_dy + cg
        # Reinvested dividends add to the cost basis; they form a geometric
        # series in r.  When r == 1 the closed form divides by zero, so use
        # the limit of the series (T equal terms) instead.
        if r == 1:
            reinvested = after_tax_dy * T
        else:
            reinvested = after_tax_dy * (1 - r ** T) / (1 - r)
        # Capital-gains tax applies only to value in excess of the basis
        # (initial unit plus reinvested dividends).
        ret = (r ** T) * (1 - t_cg) + t_cg * (1 + reinvested)
    elif tax_treat == "Roth":
        # No tax on dividends, gains or withdrawal.
        r = dy + cg
        ret = (1 + r) ** T
    elif tax_treat == "401k":
        r = dy + cg
        if T == 0:
            # Withdrawn immediately: taxed at the initial rate (ie. 1).
            ret = ((1 - t_oi_0) * (1 + r) ** T) / (1 - t_oi_0)
        else:
            # Withdrawal taxed at the ending rate, scaled by the initial
            # ordinary-income rate in the denominator.
            ret = ((1 - t_oi_T) * (1 + r) ** T) / (1 - t_oi_0)
    else:
        # Previously this printed a message and then failed with an
        # UnboundLocalError on the return; fail explicitly instead.
        raise ValueError(
            "Tax treatment not defined: %r "
            "(expected 'Brokerage', 'Roth' or '401k')" % (tax_treat,)
        )
    return ret


# Taxable coupon bond (with reinvestment)
def bondret(tax_treat, t_oi_0, t_oi_T, r, T):
    """Return the after-tax growth multiple of a taxable coupon bond.

    Assumes tax rates are constant from t=0 to t=T-1 and then jump at t=T,
    and that the taxable coupon is reinvested at the same rate with no
    capital gain or loss on the bond.

    Args:
        tax_treat: Account type; exactly one of "Brokerage", "Roth" or
            "401k" (the comparison is case-sensitive).
        t_oi_0: Initial tax rate on ordinary income.
        t_oi_T: Ending tax rate on ordinary income.
        r: Bond rate of return (coupon rate) per year.
        T: Number of years the bond is held.

    Returns:
        The value after T years per unit invested, as computed by the
        formula of the selected account type.

    Raises:
        ValueError: If ``tax_treat`` is not one of the supported labels.
    """
    if tax_treat == "Brokerage":
        if T == 0:
            ret = 1
        else:
            # The first T-1 coupons are taxed at the initial rate and the
            # final coupon at the ending rate.
            ret = (1 + r * (1 - t_oi_0)) ** (T - 1) * (1 + r * (1 - t_oi_T))
    elif tax_treat == "Roth":
        # No tax on coupons or withdrawal.
        ret = (1 + r) ** T
    elif tax_treat == "401k":
        if T == 0:
            ret = ((1 - t_oi_0) * (1 + r) ** T) / (1 - t_oi_0)  # (ie. 1)
        else:
            # Withdrawal taxed at the ending rate, scaled by the initial
            # ordinary-income rate in the denominator.
            ret = ((1 - t_oi_T) * (1 + r) ** T) / (1 - t_oi_0)
    else:
        # Previously this printed a message and then failed with an
        # UnboundLocalError on the return; fail explicitly instead.
        raise ValueError(
            "Tax treatment not defined: %r "
            "(expected 'Brokerage', 'Roth' or '401k')" % (tax_treat,)
        )
    return ret

# In [4]:
# Map each portfolio's column name to its weight dictionary.
portfolios = {'portfolio1': portfolio1, 'portfolio2': portfolio2, 'portfolio3': portfolio3}

# Layout shared by every per-portfolio table: one row per withdrawal year
# (1..T) and one (account, sub-account) column per combination.
accounts = ['Brokerage', '401k', 'Roth']
subaccts = ['Stock', 'Bond', 'Total']
years = 1 + np.arange(T)
cols = pd.MultiIndex.from_product([accounts, subaccts])

# Overall weighted value of each portfolio, by withdrawal year.
totals = pd.DataFrame(dtype=float, index=years, columns=list(portfolios.keys()))

for p, portfolio in portfolios.items():
    df = pd.DataFrame(dtype=float, index=years, columns=cols)

    for acct in accounts:
        stock_wgt = portfolio[acct + ":Stock"]
        bond_wgt = portfolio[acct + ":Bond"]

        # Weighted stock and bond values for every holding period t.
        for t in df.index:
            df.loc[t, (acct, 'Stock')] = stock_wgt * stockret(
                acct, t_oi_0, t_oi_T, t_div, t_cg, dy, cg, t
            )
            df.loc[t, (acct, 'Bond')] = bond_wgt * bondret(
                acct, t_oi_0, t_oi_T, bndret, t
            )

        df[(acct, 'Total')] = df[(acct, 'Stock')] + df[(acct, 'Bond')]

    # Sum the per-account totals into the portfolio's overall value.
    cols_to_sum = [(acct, 'Total') for acct in accounts]
    df[("Overall", 'Total')] = df[cols_to_sum].sum(axis=1)
    totals[p] = df[("Overall", 'Total')]

# FIGURE

# In [5]:
import plotly.graph_objects as go

# One line per portfolio: after-tax value against the year of withdrawal.
fig = go.Figure()

trace_labels = (
    ("portfolio1", "Portfolio 1"),
    ("portfolio2", "Portfolio 2"),
    ("portfolio3", "Portfolio 3"),
)
for column, label in trace_labels:
    fig.add_trace(
        go.Scatter(x=totals.index, y=totals[column], mode="lines", name=label)
    )

fig.update_xaxes(title="Year of Withdrawal", tickformat=",.0f")
fig.update_yaxes(title="After-Tax FV", tickformat="$,.2f")

# Single hover box per x value; legend pinned to the top-left corner.
fig.update_layout(
    hovermode="x unified",
    legend=dict(yanchor="top", y=0.99, xanchor="left", x=0.01),
)
fig.show()