In [46]:
import numpy as np
import plotly.graph_objs as go
import plotly.express as px
from ipywidgets import interact, FloatSlider, IntSlider, Layout
from plotly.subplots import make_subplots

def calculate_monthly_payment(principal, annual_rate, term_months):
    """Return the fixed monthly payment that fully repays a loan.

    Uses the standard annuity formula

        M = P * r * (1 + r)**n / ((1 + r)**n - 1),   r = annual_rate / 12

    When the rate is zero the formula degenerates to 0/0, so the payment is
    simply the principal split evenly across the months.

    Args:
        principal: Amount borrowed.
        annual_rate: Nominal annual interest rate as a fraction
            (e.g. 0.045 for 4.5%).
        term_months: Number of monthly payments; expected to be positive.

    Returns:
        The monthly payment as a float.
    """
    monthly_rate = annual_rate / 12
    if monthly_rate == 0:
        # Interest-free loan: avoid the 0/0 in the annuity formula.
        return principal / term_months
    growth = (1 + monthly_rate) ** term_months  # computed once, used twice
    return principal * monthly_rate * growth / (growth - 1)

def calculate_payments(principal, annual_rate, term_months):
    """Summarize a fixed-rate loan.

    Args:
        principal: Amount borrowed.
        annual_rate: Nominal annual interest rate as a fraction.
        term_months: Number of monthly payments.

    Returns:
        A ``(monthly_payment, total_payment, total_interest)`` tuple, where
        ``total_payment`` is every scheduled payment added together and
        ``total_interest`` is the part of that total exceeding ``principal``.
    """
    payment = calculate_monthly_payment(principal, annual_rate, term_months)
    paid_overall = payment * term_months
    return payment, paid_overall, paid_overall - principal

def _amortization_schedule(principal, annual_rate, term_months, monthly_payment):
    """Return (months, principal_parts, interest_parts, balances) for a fixed-payment loan.

    Each month, interest accrues on the balance still owed; the fixed payment
    covers that interest first and the remainder retires principal.
    ``balances[k]`` is the amount still owed after payment ``k + 1``, so with
    the matching annuity payment it ends at (approximately) zero.
    """
    monthly_rate = annual_rate / 12
    principal_parts, interest_parts, balances = [], [], []
    balance = principal
    for _ in range(term_months):
        interest = balance * monthly_rate
        paid_down = monthly_payment - interest
        balance -= paid_down
        interest_parts.append(interest)
        principal_parts.append(paid_down)
        balances.append(balance)
    return np.arange(1, term_months + 1), principal_parts, interest_parts, balances


def _term_label(years):
    """Return a display label such as '1 year' or '5 years'."""
    label = f"{years} year"
    if years > 1:
        label += 's'
    return label


def plot_loan_payments(car_price, sales_tax_rate, down_payment, annual_rate, loan_term_years):
    """Plot a car loan's amortization schedule and compare it with a 1-year loan.

    The loan principal is the taxed car price minus the down payment. The
    figure has three rows: monthly principal/interest split, remaining
    balance, and total interest paid for the 1-year vs. the selected term.
    A payment summary for both terms is printed after the figure is shown.

    Args:
        car_price: Pre-tax price of the car.
        sales_tax_rate: Sales tax as a fraction (e.g. 0.0913 for 9.13%).
        down_payment: Amount paid up front, subtracted from the taxed total.
        annual_rate: Nominal annual interest rate as a fraction.
        loan_term_years: Selected loan term in whole years.

    Returns:
        None. If the down payment covers the whole cost, a message is
        printed and nothing is plotted.
    """
    total_cost = car_price * (1 + sales_tax_rate)
    loan_principal = total_cost - down_payment
    term_months = loan_term_years * 12
    term_months_short = 12  # 1-year term used as the comparison baseline

    if loan_principal <= 0:
        print("Down payment exceeds or matches total cost. No loan needed.")
        return

    # Summary figures for the selected term and for the 1-year baseline.
    monthly_payment, total_payment, total_interest = calculate_payments(loan_principal, annual_rate, term_months)
    monthly_payment_short, total_payment_short, total_interest_short = calculate_payments(loan_principal, annual_rate, term_months_short)

    # Month-by-month schedules: interest is charged on the outstanding balance,
    # so it shrinks over time while the principal share grows.
    payments, principal_payments, interest_payments, remaining_balance = _amortization_schedule(
        loan_principal, annual_rate, term_months, monthly_payment)
    payments_short, principal_payments_short, interest_payments_short, remaining_balance_short = _amortization_schedule(
        loan_principal, annual_rate, term_months_short, monthly_payment_short)

    selected_loan_term_years = _term_label(loan_term_years)

    # Rows 1-2 plot against a numeric month axis, but row 3 is a categorical
    # bar chart. make_subplots(shared_xaxes=True) would match all three axes
    # (and hide the month tick labels), so link only rows 1 and 2 here.
    fig = make_subplots(rows=3, cols=1,
                        subplot_titles=("Monthly Loan Payments", "Remaining Loan Balance Over Time", "Interest Paid Comparison"),
                        vertical_spacing=0.1)
    fig.update_xaxes(matches='x', row=2, col=1)

    # Selected-term principal and interest payments
    fig.add_trace(go.Scatter(x=payments, y=principal_payments, mode='lines', name=f'Principal Payment ({loan_term_years}y)'), row=1, col=1)
    fig.add_trace(go.Scatter(x=payments, y=interest_payments, mode='lines', name=f'Interest Payment ({loan_term_years}y)'), row=1, col=1)

    # Selected-term remaining balance
    fig.add_trace(go.Scatter(x=payments, y=remaining_balance, mode='lines', name=f'Remaining Balance ({loan_term_years}y)', line=dict(color='firebrick')), row=2, col=1)

    # 1-year comparison traces (dotted)
    fig.add_trace(go.Scatter(x=payments_short, y=principal_payments_short, mode='lines', name='Principal Payment (1y)', line=dict(dash='dot')), row=1, col=1)
    fig.add_trace(go.Scatter(x=payments_short, y=interest_payments_short, mode='lines', name='Interest Payment (1y)', line=dict(dash='dot')), row=1, col=1)
    fig.add_trace(go.Scatter(x=payments_short, y=remaining_balance_short, mode='lines', name='Remaining Balance (1y)', line=dict(color='firebrick', dash='dot')), row=2, col=1)

    # Total interest paid comparison
    fig.add_trace(go.Bar(x=['1 year', selected_loan_term_years], y=[total_interest_short, total_interest], name='Total Interest Paid'), row=3, col=1)

    # Layout and axis titles (row 3's x axis is loan terms, not months).
    fig.update_layout(height=900, width=1200, title_text="Loan Payment Schedule and Interest Comparison")
    fig.update_xaxes(title_text="Month", row=1, col=1)
    fig.update_xaxes(title_text="Month", row=2, col=1)
    fig.update_xaxes(title_text="Loan Term", row=3, col=1)
    fig.update_yaxes(title_text="Payment ($)", row=1, col=1)
    fig.update_yaxes(title_text="Remaining Balance ($)", row=2, col=1)
    fig.update_yaxes(title_text="Total Interest Paid ($)", row=3, col=1)
    fig.update_layout(legend=dict(x=0.75, y=0.99, bordercolor="Black", borderwidth=1))
    fig.update_layout(margin=dict(l=20, r=20, t=50, b=20))

    fig.show()

    print(f"Monthly Payment (1 year): ${monthly_payment_short:.2f}")
    print(f"Total Payment (1 year): ${total_payment_short:.2f}")
    print(f"Total Interest Paid (1 year): ${total_interest_short:.2f}")
    print(f"Monthly Payment ({selected_loan_term_years}): ${monthly_payment:.2f}")
    print(f"Total Payment ({selected_loan_term_years}): ${total_payment:.2f}")
    print(f"Total Interest Paid ({selected_loan_term_years}): ${total_interest:.2f}")

# Initial values
car_price = 230000
sales_tax_rate = 0.0913
down_payment = 200000
apr = 0.045
term = 5

style = {'description_width': 'initial'}
slider_layout = Layout(width='50%', height='auto')

# Rates are fractions (0.0913 == 9.13%). FloatSlider's default readout format
# ('.2f') would display 0.0913 as "0.09" and hide the 0.0025 step, so show
# the rate sliders as percentages instead.
rate_readout = '.2%'

# interact() renders the widgets and returns the wrapped function; the
# trailing ';' stops Jupyter from echoing that function object as cell output.
interact(
    plot_loan_payments,
    car_price=IntSlider(value=car_price, min=10000, max=500000, step=5000, description='Car Price ($)', layout=slider_layout, style=style),
    sales_tax_rate=FloatSlider(value=sales_tax_rate, min=0, max=0.1, step=0.0025, readout_format=rate_readout, description='Sales Tax Rate', layout=slider_layout, style=style),
    down_payment=IntSlider(value=down_payment, min=0, max=500000, step=5000, description='Down Payment ($)', layout=slider_layout, style=style),
    annual_rate=FloatSlider(value=apr, min=0.01, max=0.1, step=0.0025, readout_format=rate_readout, description='Annual Interest Rate', layout=slider_layout, style=style),
    loan_term_years=IntSlider(value=term, min=2, max=7, step=1, description='Loan Term (years)', layout=slider_layout, style=style)
);

interactive(children=(IntSlider(value=230000, description='Car Price ($)', layout=Layout(height='auto', width=…

<function __main__.plot_loan_payments(car_price, sales_tax_rate, down_payment, annual_rate, loan_term_years)>