---

Created for [learn-investments.rice-business.org](https://learn-investments.rice-business.org)
    
By [Kerry Back](https://kerryback.com) and [Kevin Crotty](https://kevin-crotty.com)
    
Jones Graduate School of Business, Rice University

---


# EXAMPLE DATA

In [9]:
t_oi_0 = 0.35     # Tax rate on ordinary income while saving (years 0 to T-1)
t_oi_T = 0.25     # Tax rate on ordinary income at withdrawal (year T)
t_div  = 0.15     # Tax rate on dividends
t_cg   = 0.15     # Tax rate on capital gains
dy     = 0.02     # Stock dividend yield (per year)
cg     = 0.04     # Stock capital gain (per year)
bndret = 0.03     # Bond rate of return (per year, taxable coupon)
T      = 30       # Years saving

# Portfolio weights of stock and bond in each asset location.
# The six weights are fractions of the total initial investment and
# should sum to 1.
portfolio = {
        "Brokerage:Stock":  0.25,
        "Brokerage:Bond":   0.25,
        "401k:Stock":       0.25,
        "401k:Bond":        0.00,
        "Roth:Stock":       0.25,
        "Roth:Bond":        0.00
}

# Fail fast if edited weights no longer allocate exactly the whole investment
# (tolerance absorbs floating-point rounding in the sum).
_wgt_total = sum(portfolio.values())
if abs(_wgt_total - 1) > 1e-9:
    raise ValueError(f"Portfolio weights must sum to 1; got {_wgt_total}")

# CALCULATIONS

In [10]:
import pandas as pd
import numpy as np


# Dividend-paying stock
def stockret(tax_treat, t_oi_0, t_oi_T, t_div, t_cg, dy, cg, T):
    """Return the value at year T of 1 invested in a dividend-paying stock.

    Assumes taxes are constant from t=0 to t=T-1 and then jump at t=T, and
    a constant dividend yield and capital gain per year.

    Parameters
    ----------
    tax_treat : str
        Account type: "Brokerage", "Roth", or "401k" (case-sensitive).
    t_oi_0 : float
        Ordinary-income tax rate from t=0 to t=T-1 (used by "401k").
    t_oi_T : float
        Ordinary-income tax rate at withdrawal, t=T (used by "401k").
    t_div : float
        Tax rate on dividends (used by "Brokerage").
    t_cg : float
        Capital-gains tax rate applied at sale in year T (used by "Brokerage").
    dy : float
        Annual dividend yield.
    cg : float
        Annual capital gain.
    T : int
        Year of withdrawal / number of years held.

    Returns
    -------
    float
        After-tax value at T per unit invested. For "401k" the result is
        divided by (1 - t_oi_0), so a withdrawal at T=0 returns exactly 1.

    Raises
    ------
    ValueError
        If ``tax_treat`` is not one of the supported account types.
    """
    if tax_treat == "Brokerage":
        # Dividends are taxed each year and reinvested, so the position grows
        # by r per year before any capital-gains tax.
        r = 1 + dy * (1 - t_div) + cg
        reinvested_div = dy * (1 - t_div)
        # Cost basis = initial 1 + all reinvested after-tax dividends, i.e.
        # reinvested_div * sum_{k=0}^{T-1} r**k. The closed form of that
        # geometric sum divides by zero at r == 1, where the sum equals T.
        if r == 1:
            basis = 1 + reinvested_div * T
        else:
            basis = 1 + reinvested_div * (1 - r ** T) / (1 - r)
        # Capital-gains tax on (value - basis) paid at sale in year T.
        ret = (r ** T) * (1 - t_cg) + t_cg * basis
    elif tax_treat == "Roth":
        r = dy + cg
        ret = (1 + r) ** T
    elif tax_treat == "401k":
        r = dy + cg
        # Withdrawal at T=0 is taxed at the initial rate; later withdrawals
        # at the ending rate.
        t_exit = t_oi_0 if T == 0 else t_oi_T
        ret = ((1 - t_exit) * (1 + r) ** T) / (1 - t_oi_0)
    else:
        raise ValueError(f"Tax treatment not defined: {tax_treat!r}")
    return ret


# Taxable coupon bond (with reinvestment)
def bondret(tax_treat, t_oi_0, t_oi_T, r, T):
    """Return the value at year T of 1 invested in a taxable coupon bond.

    Assumes taxes are constant from t=0 to t=T-1 and then jump at t=T.
    Coupons are taxable and reinvested at the same rate; the bond has no
    capital gain or loss.

    Parameters
    ----------
    tax_treat : str
        Account type: "Brokerage", "Roth", or "401k" (case-sensitive).
    t_oi_0 : float
        Ordinary-income tax rate from t=0 to t=T-1.
    t_oi_T : float
        Ordinary-income tax rate at t=T.
    r : float
        Annual coupon rate / rate of return.
    T : int
        Year of withdrawal / number of years held.

    Returns
    -------
    float
        After-tax value at T per unit invested. For "401k" the result is
        divided by (1 - t_oi_0), so a withdrawal at T=0 returns exactly 1.

    Raises
    ------
    ValueError
        If ``tax_treat`` is not one of the supported account types.
    """
    if tax_treat == "Brokerage":
        if T == 0:
            ret = 1
        else:
            # Coupons in years 1..T-1 are taxed at t_oi_0; the final year's
            # coupon is taxed at the ending rate t_oi_T.
            ret = (1 + r * (1 - t_oi_0)) ** (T - 1) * (1 + r * (1 - t_oi_T))
    elif tax_treat == "Roth":
        ret = (1 + r) ** T
    elif tax_treat == "401k":
        # Withdrawal at T=0 is taxed at the initial rate (result is 1);
        # later withdrawals at the ending rate.
        t_exit = t_oi_0 if T == 0 else t_oi_T
        ret = ((1 - t_exit) * (1 + r) ** T) / (1 - t_oi_0)
    else:
        raise ValueError(f"Tax treatment not defined: {tax_treat!r}")
    return ret

In [11]:
accounts = ['Brokerage', '401k', 'Roth']
subaccts = ['Stock', 'Bond', 'Total']
cols = pd.MultiIndex.from_product([accounts, subaccts])

# One row per year of withdrawal (1..T); cells start empty and are filled below.
df = pd.DataFrame(dtype=float, index=1 + np.arange(T), columns=cols)

# Fill whole columns at once: each entry is the asset's weight times its
# value per unit invested when withdrawn in that year.
for acct in accounts:
    stock_wgt = portfolio[acct + ":Stock"]
    bond_wgt = portfolio[acct + ":Bond"]
    years = df.index
    df[(acct, 'Stock')] = [
        stock_wgt * stockret(acct, t_oi_0, t_oi_T, t_div, t_cg, dy, cg, yr)
        for yr in years
    ]
    df[(acct, 'Bond')] = [
        bond_wgt * bondret(acct, t_oi_0, t_oi_T, bndret, yr)
        for yr in years
    ]
    df[(acct, 'Total')] = df[(acct, 'Stock')] + df[(acct, 'Bond')]

# Grand total across all accounts.
cols_to_sum = [(acct, 'Total') for acct in accounts]
df[("Overall", 'Total')] = df[cols_to_sum].sum(axis=1)
df.head()

Unnamed: 0_level_0,Brokerage,Brokerage,Brokerage,401k,401k,401k,Roth,Roth,Roth,Overall
Unnamed: 0_level_1,Stock,Bond,Total,Stock,Bond,Total,Stock,Bond,Total,Total
1,0.26275,0.255625,0.518375,0.305769,0.0,0.305769,0.265,0.0,0.265,1.089144
2,0.276227,0.26061,0.536836,0.324115,0.0,0.324115,0.2809,0.0,0.2809,1.141852
3,0.290472,0.265692,0.556163,0.343562,0.0,0.343562,0.297754,0.0,0.297754,1.19748
4,0.305529,0.270873,0.576401,0.364176,0.0,0.364176,0.315619,0.0,0.315619,1.256196
5,0.321444,0.276155,0.597598,0.386027,0.0,0.386027,0.334556,0.0,0.334556,1.318181


# FIGURE

In [14]:
import plotly.graph_objects as go

fig = go.Figure()

# Plot the overall total first, then each account's own total,
# all against the year of withdrawal.
line_specs = [
    ("Total", ('Overall', 'Total')),
    ("Brokerage", ('Brokerage', 'Total')),
    ("401k", ('401k', 'Total')),
    ("Roth", ('Roth', 'Total')),
]
for label, column in line_specs:
    fig.add_trace(
        go.Scatter(x=df.index, y=df[column], mode="lines", name=label)
    )

fig.update_layout(
    xaxis_title="Year of Withdrawal",
    xaxis_tickformat=",.0f",
    yaxis_title="After-Tax Future Value",
    yaxis_tickformat="$,.2f",
    hovermode="x unified",
    template="plotly_white",
    # Legend anchored at its top-left corner near the plot's top-left.
    legend=dict(yanchor="top", y=0.99, xanchor="left", x=0.01),
)
fig.show()