In [166]:
from collections import defaultdict
from dataclasses import dataclass, asdict
from datetime import timedelta
from typing import List, Union, Optional
from tempfile import NamedTemporaryFile
import arrow
from arrow import Arrow
import pandas as pd
import sqlite3
import logging 

@dataclass
class Asset:
    category: str
    ticker: str
    price: float
    units: int
    date: Arrow
    fmv: Optional[float] = None


@dataclass
class Event:
    action: str
    ticker: str
    price: float
    units: int
    date: Arrow


class Portfolio:
    def __init__(self):
        with NamedTemporaryFile('wb+', delete=False) as f:
            self.c = sqlite3.connect(f.name)
            
        self.c.execute("""
            CREATE TABLE ASSETS (
                _ID INTEGER PRIMARY KEY AUTOINCREMENT,
                CATEGORY TEXT NOT NULL,
                TICKER TEXT NOT NULL,
                PRICE FLOAT NOT NULL,
                UNITS INTEGER NOT NULL,
                DATE DATETIME NOT NULL,
                FMV FLOAT
            )
        """)
        self.c.execute("""
            CREATE TABLE EVENTS (
                _ID INTEGER PRIMARY KEY AUTOINCREMENT,
                ACTION TEXT NOT NULL,
                TICKER TEXT NOT NULL,
                PRICE FLOAT NOT NULL,
                UNITS INTEGER NOT NULL,
                DATE DATETIME NOT NULL
            )
        """)

    @property
    def assets(self):
        return self.query('assets')
                       
    @property
    def events(self):
        return self.query('events')
                       
    def query(self, table: str, order_by: Optional[str]="date", **filters):
        where_args = [f"{k.upper()} = '{v}'" for k, v in filters.items()]
        sep = "\n AND "
        filters_formatted = f"WHERE {sep.join(where_args)}"
        q = f"""
            SELECT * FROM {table.upper()}
            {filters_formatted if filters else ''}
            {f'ORDER BY {order_by.upper()}'}
        """
        df = pd.read_sql(q, self.c, index_col="_ID", parse_dates=["DATE"])
        df.columns = [c.lower() for c in df.columns]
        return df
               
    def _remove_assets(self, ids: list):
        q = f"""
        DELETE FROM ASSETS
        WHERE _ID IN ({', '.join([str(i) for i in ids])})
        """
        self.c.execute(q)
        return self.assets
                       
    def _upsert_asset(self, asset: Asset, _id: Optional[int] = None):
        """Inserts an asset into the database, optionally replacing the 
        one that exists at the index, then sorts by date.
        :param asset: The asset to insert
        :param index: The index of the asset to replace
        """
        if _id:
            q = f"""
            UPDATE ASSETS
                SET CATEGORY = '{asset.category}',
                TICKER = '{asset.ticker}',
                PRICE = {asset.price},
                UNITS = {asset.units},
                DATE = '{asset.date.isoformat()}',
                FMV = {asset.fmv or 'null'}
                WHERE _ID = {_id}
            """
        else:
            sep = ', '
            q = f"""
            INSERT INTO ASSETS ({', '.join(asdict(asset).keys())})
            VALUES (
                '{asset.category}',
                '{asset.ticker}',
                {asset.price},
                {asset.units},
                '{asset.date.isoformat()}',
                {asset.fmv or 'null'}
            )
            """
        self.c.execute(q)
        return self.assets
        
    def _upsert_event(self, event: Event, _id: Optional[int] = None):
        """Inserts an asset into the database, optionally replacing the 
        one that exists at the index, then sorts by date.
        :param asset: The asset to insert
        :param index: The index of the asset to replace
        """
        if _id:
            q = f"""
            UPDATE EVENTS
                SET ACTION = '{event.action}',
                TICKER = '{event.ticker}',
                PRICE = {event.price},
                UNITS = {event.units},
                DATE = '{event.date.isoformat()}',
                WHERE _ID = {_id}
            """
        else:
            q = f"""
            INSERT INTO EVENTS ({', '.join(asdict(event).keys())})
            VALUES (
                '{event.action}',
                '{event.ticker}',
                {event.price},
                {event.units},
                '{event.date.isoformat()}'
            )
            """
        self.c.execute(q)
        return self.events
        
    def cost_basis(self, category: str, ticker: str) -> float:
        assets = self.query(category=category, ticker=ticker)
        return (self.assets.price * self.assets.units).sum()
    
    def total_amount(self, category: str, ticker: str) -> int:
        assets = self.query(category=category, ticker=ticker)
        return self.assets.units.sum()
    
    def grant_option(self, ticker: str, price: float, units: int, date: Arrow) -> Event:
        event = Event('option grant', ticker, price, units, date)
        self._upsert_event(event)
        option = Asset('option', ticker, price, units, date)
        self._upsert_asset(option)
        return event
    
    def grant_options_from_schedule(self, 
                                    ticker: str, 
                                    price: float, 
                                    units: int, 
                                    begin_date: Arrow, 
                                    cliff_date: Arrow,
                                    cutoff_date=Arrow.utcnow(),
                                    num_months=48):
        raw_chunk_size, remainder = divmod(units, num_months)
        amount_to_grant = 0
        for i in range(num_months):
            # "consume" from the remainder while it exists
            chunk_size = raw_chunk_size + 1 if i < remainder else raw_chunk_size
            amount_to_grant += chunk_size
            
            # get date at this chunk of options
            year_delta, month_delta = divmod(i, 12)
            chunk_year = begin_date.year + year_delta
            chunk_month = begin_date.month + month_delta

            # roll chunk month into year if greater than 12
            year_roll, chunk_month = divmod(chunk_month, 12)
            chunk_year += year_roll
            
            chunk_date = Arrow(chunk_year, chunk_month + 1, begin_date.day)
            
            # we can't get grants until the cliff (if it exists) is over
            beyond_cliff = cliff_date is None or chunk_date >= cliff_date
            # we want to add chunks up until this point in time unless
            # the cutoff date is manually set into the future
            not_cut_off = cutoff_date is None or chunk_date <= cutoff_date
            if beyond_cliff and not_cut_off:
                self.grant_option(ticker, price, amount_to_grant, chunk_date)
                amount_to_grant = 0            
        
    def exercise_options(self, ticker: str, price: float, units: int, date: Arrow, fmv: float) -> Event:
        event = Event('option exercise', ticker, price, units, date)
        self._upsert_event(event)
        a = self.query('assets', category='option', ticker=ticker)

        o = a[a.price == price].units.cumsum().reset_index()
        # index of first unit that brings the cumsum over the amount we want to exercise
        over_target_ix = o[o.units >= units].iloc[0].name
        # get the original `p.assets` indeces of all the target units
        target_grants = o.iloc[0:over_target_ix + 1]
        # subtract the units to exercise from the cumulative amount to get the 
        # remainder we want to replace as the value of the last grant in our set
        remainder = target_grants.iloc[-1].units - units
        
        grant_to_update_ix = target_grants['_ID'].iloc[-1]
        grant_to_update = a.iloc[grant_to_update_ix]
        partial_grant = Asset(**grant_to_update.to_dict())
        partial_grant.units = remainder

        self._remove_assets(target_grants['_ID'].iloc[:-1].tolist())
        self._upsert_asset(partial_grant, _id=grant_to_update_ix)
        
        stock = Asset('stock', ticker, price, units, date, fmv)
        self._upsert_asset(stock)
        
        return event

In [167]:
p = Portfolio()
p.grant_options_from_schedule('JEFF', 2.18, 8000, Arrow(2019, 1, 7), Arrow(2020, 1, 7))
p.grant_options_from_schedule('JEFF', 3.08, 5000, Arrow(2020, 6, 1), None)
p.exercise_options('JEFF', 2.18, 2400, Arrow(2020, 12, 30), 15)
p.assets

Unnamed: 0_level_0,category,ticker,price,units,date,fmv
_ID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
4,option,JEFF,2.18,105,2020-05-07 00:00:00+00:00,
5,option,JEFF,2.18,167,2020-05-07 00:00:00+00:00,
6,option,JEFF,2.18,167,2020-06-07 00:00:00+00:00,
14,option,JEFF,3.08,105,2020-07-01 00:00:00+00:00,
7,option,JEFF,2.18,167,2020-07-07 00:00:00+00:00,
15,option,JEFF,3.08,105,2020-08-01 00:00:00+00:00,
8,option,JEFF,2.18,167,2020-08-07 00:00:00+00:00,
16,option,JEFF,3.08,105,2020-09-01 00:00:00+00:00,
9,option,JEFF,2.18,167,2020-09-07 00:00:00+00:00,
17,option,JEFF,3.08,105,2020-10-01 00:00:00+00:00,


In [175]:
import requests as rq
def get_tax_info(region: str, filing_status: str, capital_gains=False) -> (int, list):
    """Fetch tax brackets and deduction amount for a region of the US for FY2020 
    :param region: Can be any of the 50 states, `district of columbia`, or `federal`
    :param filing_status: Options are (single, married, married_separately, head_of_household)
    :param capital_gains: Whether you want to return capital gains rates instead of income
    """
    if capital_gains and region != 'federal':
        raise ValueError("Can only apply capital gains rate to `federal` region")
        
    normalized_region = region.lower().replace(' ', '_')
    url = (f"https://raw.githubusercontent.com/taxee/taxee-tax-statistics"
           f"/master/src/statistics/2020/{normalized_region}.json")
    res = rq.get(url)
    try:
        res.raise_for_status()
    except rq.RequestException as e:
        raise ValueError(f"'{region}' is not a valid region")
    if region == 'federal':
        data = res.json()['tax_withholding_percentage_method_tables']['annual'][filing_status]
    else:
        data = res.json()[filing_status]
    deduction = data['deductions'][0]['deduction_amount']
    rate_key = 'marginal_capital_gain_rate' if capital_gains else 'marginal_rate'
    brackets = data['income_tax_brackets']
    brackets = [
        {'income_level': d['bracket'], 'marginal_rate': d[rate_key] / 100.0} 
        for d in brackets
    ]
    return deduction, brackets

In [176]:
from typing import List
def taxes(income: float, brackets: List[dict]) -> float:
    """Calculate taxes owed based on progressive bracket"""
    taxes = 0
    for i, bracket in enumerate(brackets):
        if income == 0:
            break
        # if we have yet to hit the last bracket
        if i < len(brackets) - 1:
            next_bracket_income = brackets[i + 1]['income_level']
            income_level_band = next_bracket_income - bracket['income_level']
            portion = min(income, income_level_band)
        # otherwise the rest of the income will be taxed at this last bracket
        else:
            portion = income
        taxes += portion * bracket['marginal_rate']
        income -= portion

    # if there is income left over after exhausting all brackets
    # tax remainder at highest bracket
    if income:
        taxes += income * brackets[-1]['marginal_rate']
    return taxes


In [177]:
def calculate_payroll_tax(income: float) -> float:
    social_security_bracket = [{'income_level': 142_800, 'marginal_rate': 0.062}]
    medicare_rate = 0.0145
    social_security_tax = taxes(income, social_security_bracket)
    medicare_tax = gross_income * medicare_rate
    payroll_tax = social_security_tax + medicare_tax
    return payroll_tax

# payroll taxes are applied to unadjusted gross income
calculate_payroll_tax(gross_income)

NameError: name 'gross_income' is not defined

In [178]:
def calculate_income_taxes(
    income: float, 
    filing_state: str, 
    filing_status='single', 
    federal_deduction=None, 
    state_deduction=None,
) -> dict:
    standard_state_deduction, state_income_brackets = get_tax_info(filing_state, filing_status)
    standard_federal_deduction, federal_income_brackets = get_tax_info('federal', filing_status)
    
    # apply custom deductions if available
    federal_deduction = federal_deduction or standard_federal_deduction
    state_deduction = state_deduction or standard_federal_deduction
    
    federal_income_tax = taxes(income - federal_deduction, federal_income_brackets) 
    state_income_tax = taxes(income - state_deduction, state_income_brackets) 

    return {
        'federal': federal_income_tax, 
        'state': state_income_tax, 
    }
income_taxes = calculate_income_taxes(agi, filing_state)
income_taxes

NameError: name 'agi' is not defined

In [179]:
from typing import Tuple
def calculate_amt_taxes(income: float, options: Tuple[float, int]) -> float:
    amt_exemption = 72_900
    cost_basis = sum(o[0] * o[1] for o in options)
    number_of_options = sum(o[1] for o in options)
    avg_price_per_share = cost_basis / number_of_options
    valuation_at_exercise = number_of_options * fmv

    exercise_spread = valuation_at_exercise - cost_basis
    amt_base = agi + exercise_spread - amt_exemption
    amt_tax_rate = 0.26
    tmt = amt_base * amt_tax_rate
    amt_tax = max(0, tmt - income_taxes['federal'])
    
    return amt_tax

def calculate_capital_gains_taxes(
    income: float, 
    options: options: Tuple[float, int], 
    target_price: float, 
    long_term=False
    ) -> float:
    cost_basis = sum(o[0] * o[1] for o in options)
    number_of_options = sum(o[1] for o in options)
    
    valuation_at_sale = number_of_options * target_price
    gains_from_sale = valuation_at_sale - cost_basis
    _, cap_gains_brackets = get_tax_info('federal', 'single', capital_gains=True)
    cap_gains_tax = taxes(gains_from_sale, cap_gains_brackets)

SyntaxError: invalid syntax (<ipython-input-179-0bc7b09a348b>, line 19)

In [180]:
filing_state = 'georgia'
filing_status = 'single'
gross_income = 127_200
withholdings = 18_000 + 2_500 + 22 # 401k, HSA, Dental
agi = gross_income - withholdings

# Leave None if taking standard deduction,
# otherwise enter value in dollars
nonstandard_state_deduction = None
nonstandard_federal_deduction = None

fmv = 22
target_price = 120

p = Portfolio()
p.grant_options_from_schedule('JEFF', 2.18, 8000, Arrow(2019, 1, 7), Arrow(2020, 1, 7))
p.grant_options_from_schedule('JEFF', 3.08, 5000, Arrow(2020, 6, 1), None)
p.exercise_options('JEFF', 2.18, 2400, Arrow(2020, 12, 30), 15)
p.assets


long_strat_taxes = (amt_tax + cap_gains_tax)
short_cap_gains_taxes = calculate_income_taxes(gains_from_sale, filing_state)
short_strat_taxes = sum(short_cap_gains_taxes.values())
strat_diff = short_strat_taxes - long_strat_taxes

print(f"""
cost to exercise {number_of_options:,} shares at avg price of ${avg_price_per_share:.2f}: ${cost_basis:,.2f}
upfront amt owed if exercised at fmv of ${fmv}: ${amt_tax:,.2f}
total tax paid if held long term and sold at ${target_price}: {long_strat_taxes:,.2f}
total tax paid if sold short term: ${short_strat_taxes:,.2f}
potential savings by going long strat ${strat_diff:,.2f}
""")

NameError: name 'amt_tax' is not defined