# Interactive Brokers Reporting in Local Currency
- show the balance at the start of the year vs. the balance at the end of the year
- show taxable profit and loss for stocks, options, forex and CFDs
- show taxable dividends

## TODO
- add forex trades
- track cash balances vs. local-currency balances resulting from deposits, withdrawals and trades
- next year: stock transfers and internal cash movements between accounts

In [None]:
from datetime import datetime

In [None]:
import csv
from pathlib import Path

In [None]:
# ---------------------------------------------------------------------------
# User settings: point these at your own Interactive Brokers export.
# ---------------------------------------------------------------------------
# Sub-directory holding the exported activity statement(s).
DATA_DIR = Path("data")
# Activity statement (.csv) exported from Interactive Brokers.
DATA_FILE = DATA_DIR.joinpath("2020.csv")
# SQLite database used to store and analyse the activity statement data.
DATABASE_NAME = "ib.db"
# Alternative multi-period statement, kept for reference:
# datafile2 = DATA_DIR / "MULTI_20210101_20210423.csv"

In [None]:
from csv_parser import process_csv

In [None]:
from forex import forex_rate

In [None]:
from settings import LOCAL_CURRENCY

In [None]:
from enums import NameValueType, TradeType, TradePositionStatus
from decimal import Decimal

In [None]:
from queries import calc_balance, calc_pnl_simplified, calc_net_dividend

## Parse csv

In [None]:
# Parse the IB activity statement CSV. `rows` maps a statement section name to
# that section's records; sections used further below include:
# "Net Asset Value"
# "Open Positions"
# "Trades"
# "Deposits & Withdrawals"
# "Dividends"
# "Withholding Tax"
rows = process_csv(DATA_FILE)

In [None]:
# List the statement sections found in the CSV
rows.keys()

## Test adding trades to db

In [None]:
from sqlalchemy.sql.functions import coalesce
from sqlalchemy import case, literal, func

In [None]:
from sqlalchemy import null, alias
from sqlalchemy.orm import aliased

In [None]:
from db import SQLASyncDB

In [None]:
#db2 = SQLASyncDB("sqlite:///ib_dev.db", drop=False)

In [None]:
drop=False
db = SQLASyncDB("sqlite:///" + DATABASE_NAME, drop=drop)

In [None]:
type(db.session)

In [None]:
from models.sqla import Trade, DepositsWithdrawals, Dividends, ForexBalance, WitholdingTax
from models.sqla import OpenPositions, NetAssetValue, NameValue, ChangeInDividendAccruals
from models.sqla import tradePosition

In [None]:
db.session.add_all([ChangeInDividendAccruals(**val) for val in rows['Change in Dividend Accruals']])
db.session.commit()

In [None]:
#tr = Trade(**rows['Trades'][0])
db.session.add_all([Trade(**val) for val in rows['Trades']])
db.session.commit()

In [None]:
# dw = DepositsWithdrawals(**rows['Deposits & Withdrawals'][0])
db.session.add_all([DepositsWithdrawals(**val) for val in rows['Deposits & Withdrawals']])
db.session.commit()

In [None]:
#model = Dividends(**rows['Dividends'][0])
db.session.add_all([Dividends(**val) for val in rows['Dividends']])
db.session.commit()

In [None]:
#model = ForexBalance(**rows['Forex Balances'][0])
db.session.add_all([ForexBalance(**val) for val in rows['Forex Balances']])
db.session.commit()

In [None]:
#model = WitholdingTax(**rows['Withholding Tax'][0])
db.session.add_all([WitholdingTax(**val) for val in rows['Withholding Tax']])
db.session.commit()

In [None]:
#model = OpenPositions(**rows['Open Positions'][0])
db.session.add_all([OpenPositions(**val) for val in rows['Open Positions']])
db.session.commit()

In [None]:
#model = NetAssetValue(**rows['Net Asset Value'][0])
db.session.add_all([NetAssetValue(**val) for val in rows['Net Asset Value']])
db.session.commit()

In [None]:
db.session.add_all([NameValue(**val) for val in rows["Statement"]])
db.session.add_all([NameValue(**val) for val in rows["Account Information"]])
db.session.commit()

## Queries

In [None]:
# Infer Base Currency
BASE_CURRENCY = db.session.query(NameValue.Value).filter(NameValue.type==NameValueType.ACCOUNT_INFORMATION, NameValue.Name=="Base Currency").scalar()
BASE_CURRENCY

In [None]:
# Infer balance date from statement
STATEMENT_END_DATE_STR = db.session.query(NameValue.Value).filter(NameValue.type==NameValueType.STATEMENT, NameValue.Name=="Period").scalar()
STATEMENT_END_DATE = datetime.strptime(STATEMENT_END_DATE_STR.split(" - ")[-1], "%B %d, %Y")
STATEMENT_END_DATE

In [None]:
# infer FIAT trades
ForexPairs = set()
FiatTrades = set()
symbols = db.session.query(Trade.Symbol).filter(Trade.Asset_Category.like("Forex%")).group_by(Trade.Symbol).all()
for symbol in symbols:
    ForexPairs.add(symbol[0])
    base, quote = symbol[0].split(".")
    FiatTrades.add(base)
    FiatTrades.add(quote)
print(f"Forex pairs: {ForexPairs}")
print(f"Fiat traded: {FiatTrades}")
EOY_FOREX_RATES = {}
for currency in FiatTrades:
    EOY_FOREX_RATES[currency] = forex_rate(currency, LOCAL_CURRENCY, STATEMENT_END_DATE)
print(f"Forex rates vs {LOCAL_CURRENCY} @ {STATEMENT_END_DATE}: {EOY_FOREX_RATES}")

In [None]:
# Base in Local (Quote) currency @ statement end date
EOY_BASE_LOCAL = forex_rate(BASE_CURRENCY, LOCAL_CURRENCY, STATEMENT_END_DATE)
print(f"Base currency quote in local currency {BASE_CURRENCY}.{LOCAL_CURRENCY} = {EOY_BASE_LOCAL}")

In [None]:
def calc_balance():
    """
    Calculate the balance of equities and cash at statement end date
    """
    QUANTIZE_FIAT = Decimal('1.00')
    sum_equity_base = sum_cash_base = sum_change_in_dividends_base = Decimal(0)
    #sum_balance_base = sum_balance_local_currency = Decimal(0)
    print(f"BALANCE AT {STATEMENT_END_DATE_STR}\n")
    print("*** EQUITY BALANCE ***")
    q = (
        db.session
        .query(OpenPositions.Symbol, OpenPositions.Quantity, OpenPositions.Mult, OpenPositions.Value)
        .order_by(OpenPositions.Symbol)
    )
    for symbol, quantity, multiplier, value in q:
        sum_equity_base += value
        val_local_currency = value * EOY_BASE_LOCAL
    
        print(
            f"{quantity} {symbol} @ {value} {BASE_CURRENCY}/ "
            f"{val_local_currency.quantize(Decimal('1.00'))} {LOCAL_CURRENCY}"
        )
    print("-------------------------------")
    print(
        f"SUB TOTAL EQUITY {sum_equity_base} {BASE_CURRENCY}/ "
        f"{(sum_equity_base * EOY_BASE_LOCAL).quantize(QUANTIZE_FIAT)} {LOCAL_CURRENCY}")
    print()
    print("*** CASH BALANCE ***")
    for currency, qty, close_value_in_base_at_statement_end in (
        db.session.query(ForexBalance.Description, ForexBalance.Quantity, ForexBalance.Close_Price)):
        val_base = qty * close_value_in_base_at_statement_end
        val_local_currency = val_base * EOY_BASE_LOCAL
        sum_cash_base += val_base
        print(
            f"{qty} {currency}: {val_base.quantize(Decimal('1.00'))} {BASE_CURRENCY}/ "
            f"{val_local_currency.quantize(Decimal('1.00'))} {LOCAL_CURRENCY}")
    print("-------------------------------")
    print(
        f"SUB TOTAL CASH {sum_cash_base} {BASE_CURRENCY}/ "
        f"{(sum_cash_base * EOY_BASE_LOCAL).quantize(QUANTIZE_FIAT)} {LOCAL_CURRENCY}")
    print()
    print("*** CHANGE IN DIVIDEND ACCRUALS ***")
    sum_change_in_dividends_base = db.session.query(func.sum(ChangeInDividendAccruals.Net_Amount)).scalar()
    print(sum_change_in_dividends_base)
    print()
    totals = sum_equity_base + sum_cash_base + sum_change_in_dividends_base
    print(
        f"TOTAL: {totals.quantize(QUANTIZE_FIAT)} {BASE_CURRENCY}/ "
        f"{(totals * EOY_BASE_LOCAL).quantize(QUANTIZE_FIAT)} {LOCAL_CURRENCY}")

In [None]:
calc_balance()

In [None]:
def calc_net_dividend(session):
    """
    Calc sum of dividend - witholding tax 
    """
    dividends = session.query(
        Dividends.Symbol.label("asset"), Dividends.Currency.label("currency"),
        Dividends.Amount.label("amount"), 
        (Dividends.Amount * Dividends.QuoteInLocalCurrency).label("amount_local")
    )
    witholdings = session.query(
        WitholdingTax.Symbol.label("asset"), WitholdingTax.Currency.label("currency"),
        WitholdingTax.Amount.label("amount"), 
        (WitholdingTax.Amount * WitholdingTax.QuoteInLocalCurrency).label("amount_local")
    )
    net_dividend = dividends.union_all(witholdings).subquery()
    q = (
        session.query(
            net_dividend.c.asset.label("asset"), 
            func.sum(net_dividend.c.amount).label("net_dividend"),
            func.sum(net_dividend.c.amount_local).label("net_dividend_local")
        )
        .group_by(net_dividend.c.asset)
        .order_by(net_dividend.c.asset)
    )
    
    #for _id, _account_id, _currency, _account, _description, _dt, _q, _forex in q:
    print(f"NET DIVIDENDS (DIVIDEND - WITHOLDING TAX) AT {STATEMENT_END_DATE_STR}\n")
    for row in q:
        #print(f"{_id} {_currency} {_description} {_dt} {_q} {_forex}")
        print(f"{row.asset}: {row.net_dividend} {BASE_CURRENCY}, {row.net_dividend_local} {LOCAL_CURRENCY}")

In [None]:
calc_net_dividend(db.session)

## Calculate PNL

In [None]:
from queries import process_base_query

In [None]:
# Separate SQLite database that holds the derived tradePosition rows
tradedb = SQLASyncDB("sqlite:///tradedb.db", drop=False)

In [None]:
# Drop and recreate the tables (presumably all of them) so the derived positions
# are rebuilt from scratch on every run
tradedb.drop_all()
tradedb.create_all(models=[tradePosition])

In [None]:
# Derive tradePosition rows in tradedb from the trades in db; the forex_rate
# lookup is passed in (presumably for the local-currency conversion)
process_base_query(db.session, tradedb.session, forex_rate=forex_rate)

In [None]:
# Net quantity of the still-OPEN positions per asset
(
    tradedb.session.query(tradePosition.asset, func.sum(tradePosition.qty))
    .filter(tradePosition.status==TradePositionStatus.OPEN)
    .group_by(tradePosition.asset)
    .all()
)

In [None]:
# PNL per asset - NOTE USD VALUE LOSS!
# Sum of pnl_base / pnl_local over CLOSED positions, per asset
(
    tradedb.session.query(tradePosition.asset, func.sum(tradePosition.pnl_base), func.sum(tradePosition.pnl_local))
    .filter(tradePosition.status==TradePositionStatus.CLOSED)
    .group_by(tradePosition.asset)
    .all()
)

In [None]:
# PNL total over all CLOSED positions - compare pnl_base to the realized PNL
(
    tradedb.session.query(func.sum(tradePosition.pnl_base), func.sum(tradePosition.pnl_local))
    .filter(tradePosition.status==TradePositionStatus.CLOSED)
    .all()
)

In [None]:
# PNL total over CLOSED positions, excluding the asset "USD" - compare pnl_base to realized
# NOTE(review): "USD" is hard-coded; presumably meant to exclude the forex/cash
# positions in that currency — confirm for accounts with other currencies.
(
    tradedb.session.query(func.sum(tradePosition.pnl_base), func.sum(tradePosition.pnl_local))
    .filter(tradePosition.status==TradePositionStatus.CLOSED, tradePosition.asset != "USD")
    .all()
)

## Deltas
- deposits + withdrawals
- trades
- dividends
- withholding tax

In [None]:
# Statement and Account Information
db.session.query(NameValue.type, NameValue.Name).all()

In [None]:
# trade symbol holdings with qty <> 0 ("short or long position open at balance date")
q = (db.session
     .query(Trade.Symbol, func.sum(Trade.Quantity).label("sum"))
     .group_by(Trade.Symbol)
     .order_by(Trade.Symbol).subquery()
    )
q2 = db.session.query(q.c.Symbol, q.c.sum).filter(or_(q.c.sum > Decimal("0"), q.c.sum < Decimal("0")))
for r in q2:
    print(r)

In [None]:
# alt: trade symbol holdings with qty <> 0 ("short or long position open at balance date")
# not_ is not among the sqlalchemy names imported above; import it so the cell runs.
from sqlalchemy import not_

q = (
    db.session
    .query(Trade.Symbol, func.sum(Trade.Quantity).label("sum"))
    .group_by(Trade.Symbol)
    .subquery()
)
q2 = (
    db.session.query(q.c.Symbol, q.c.sum)
    .filter(not_(q.c.sum == Decimal(0)))
    # An ORDER BY inside the subquery does not order the outer result; sort here.
    .order_by(q.c.Symbol)
)
for r in q2:
    print(r)

In [None]:
# NOTE: EUR.USD included in the above, need to factor in FIAT DepositWithdrawals

In [None]:
# All deposit/withdrawal rows
db.session.query(DepositsWithdrawals).all()

In [None]:
# Sum of deposit/withdrawal amounts per currency (withdrawals presumably negative)
db.session.query(DepositsWithdrawals.Currency, func.sum(DepositsWithdrawals.Amount)).group_by(DepositsWithdrawals.Currency).all()

In [None]:
# Open positions
db.session.query(OpenPositions.Symbol, OpenPositions.Quantity).order_by(OpenPositions.Symbol).all()

In [None]:
# Per-symbol totals of the trade columns
# NOTE(review): summing Realized_PnL_pct across trades does not give a
# meaningful percentage; treat that column as a rough indicator only.
q = (db.session
     .query(
         Trade.Symbol,
         func.sum(Trade.Quantity).label("quantity"),
         func.sum(Trade.Proceeds).label("proceeds"),
         func.sum(Trade.CommOrFee).label("comm_or_fee"),
         func.sum(Trade.Basis).label("basis"),
         func.sum(Trade.Realized_PnL).label("realized"),
         func.sum(Trade.Realized_PnL_pct).label("realized_pct"),
         func.sum(Trade.MTM_PnL).label("mtm_pnl"),
         func.sum(Trade.Comm_in_USD).label("comm_USD"),
         func.sum(Trade.MTM_in_USD).label("MTM_USD"),
     )
     .group_by(Trade.Symbol)
     .order_by(Trade.Symbol)
    )
for r in q:
    print(r)

In [None]:
# Equity/ Open positions balance, value and value in local currency @ statement period end date
# (multiplier is selected but not used in the printout)
q = (
    db.session
    .query(OpenPositions.Symbol, OpenPositions.Quantity, OpenPositions.Mult, OpenPositions.Value)
    .order_by(OpenPositions.Symbol)
)
for symbol, quantity, multiplier, value in q:
    print(
        f"{quantity} {symbol} @ {value} {BASE_CURRENCY}/ "
        f"{(value * EOY_BASE_LOCAL).quantize(Decimal('1.00'))} {LOCAL_CURRENCY}"
    )


In [None]:
# Cash balance. Note that Description holds the fiat currency, whereas Currency is the base currency;
# Close_Price is used as the per-unit value in base currency at the statement end.
for currency, qty, close_value_in_base_at_statement_end in db.session.query(ForexBalance.Description, ForexBalance.Quantity, ForexBalance.Close_Price):
    print(
        f"{qty} {currency}: {(qty * close_value_in_base_at_statement_end).quantize(Decimal('1.00'))} {BASE_CURRENCY}/ "
        f"{(qty * close_value_in_base_at_statement_end * EOY_BASE_LOCAL).quantize(Decimal('1.00'))} {LOCAL_CURRENCY}")

In [None]:
def print_trade(trade):
    print("TRADE ----")
    print(
        f"ID {trade.id} Symbol {trade.Symbol}, Type {TradeType(trade.type).value}, proceeds {trade.Proceeds}"
        f"real pnl {trade.Realized_PnL}, basis {trade.Basis}, C_price {trade.C_Price}, "
        f"T price {trade.T_Price}, quantity {trade.Quantity}, CommorFee {trade.CommOrFee}"
        f"Comm usd {trade.Comm_in_USD}"
        f"Proceeds + Comm {trade.Proceeds + trade.CommOrFee} vs {-trade.Basis + trade.Realized_PnL} basis + pnl"
    )
    '''
    Quantity = Column(Numeric(6, 2))
    Proceeds = Column(Numeric(6, 2))
    MTM_PnL = Column(Numeric(6, 2))
    MTM_in_USD = Column(Numeric(6, 2))
    DateTime = Column(DateTime, nullable=False)'''

In [None]:
print_trade(db.session.query(Trade).get(95))

In [None]:
# Scratch: inspect the first Trade row
atrade = db.session.query(Trade).first()
atrade.id

In [None]:
atrade.Quantity

In [None]:
# WARNING: experiment - overwrites the quantity of the first trade on the ORM object
atrade.Quantity = Decimal(100)

In [None]:
# WARNING: this commit persists the modified quantity to the database, which
# changes every balance/PNL computed from the Trades table afterwards.
# Re-import the statement (or roll back instead) if this was only a test.
db.session.commit()

In [None]:
def show_trade_deltas():
    """
    Show the per symbol trade quantities sorted by date
    while tracking running total of balance
    """
    q = (
        db.session.query(
            Trade.Symbol, 
            func.extract("month", Trade.DateTime),
            func.extract("day", Trade.DateTime),
            #Trade.DateTime, 
            Trade.Quantity,
            Trade.Basis,
            func.sum(Trade.Quantity).over(
                partition_by=Trade.Symbol, order_by=(Trade.DateTime)).label("Balance"),
            func.sum(Trade.Basis / Trade.Quantity).over(
                partition_by=Trade.Symbol, order_by=(Trade.DateTime)).label("AvgPrice"),
            #Trade.QuoteInLocalCurrency #, Trade.Proceeds, Trade.CommOrFee
        )
    )
    for row in q:
        print(row)

In [None]:
show_trade_deltas()

## Table experiments
- buys: average entry price from Quantity, Basis, balance and lagged balance

In [None]:
# DCA (dollar-cost average) price - WRONG: sells must not be included in the DCA
# price, but this running SUM(Basis) / SUM(Quantity) per symbol includes them.
q = (
    db.session.query(
        Trade.Symbol,
        (func.sum(Trade.Basis).over(
            partition_by=Trade.Symbol, order_by=(Trade.DateTime)) /
        func.sum(Trade.Quantity).over(
            partition_by=Trade.Symbol, order_by=(Trade.DateTime))).label('dca_price')
    )
)

In [None]:
for row in q:
    print(row)

In [None]:
# Average entry price - computed on buys only.
# Challenge: the balance must take both buys and sells into account, whereas the
# average-price calculation must only be applied on buys:
# new avg = (basis + old_avg * old_bal) / balance
# Subquery of buy trades only (Quantity > 0): basis and per-unit buy price.
buys = (
    db.session.query(
        Trade.Basis.label('basis'),
        Trade.id.label('id'),
        (Trade.Basis / Trade.Quantity).label('buy_avg')
    )
    .filter(Trade.Quantity > Decimal(0))
).subquery()

# Every trade with its running balance per symbol; the buy columns are NULL for
# sells because of the left outer join on the trade id.
balq = (
    db.session.query(
        Trade.Symbol.label('symbol'),
        Trade.DateTime.label('dt'),
        Trade.Quantity.label('delta'),
        func.sum(Trade.Quantity).over(
            partition_by=Trade.Symbol, order_by=(Trade.DateTime)).label('balance'),
        buys.c.basis.label('buybasis'),
        buys.c.buy_avg.label('buy_avg')
    )
    .outerjoin(buys, buys.c.id==Trade.id)
).subquery()
balWlag = (
    db.session.query(
        balq.c.symbol, balq.c.delta, balq.c.balance,
        coalesce(balq.c.buybasis, Decimal(0)),
        balq.c.balance - balq.c.delta,  # balance before this trade
        #func.lag(balq.c.balance, 1, 0).over(balq.c.symbol),
        #balq.c.buy_avg,
        coalesce(balq.c.buy_avg, Decimal(0)),
        # NOTE(review): this window has no order_by, so which row's buy_avg counts
        # as "first" is unspecified (and it may be NULL when that row is a sell).
        func.first_value(balq.c.buy_avg).over(partition_by=balq.c.symbol)
    )
    .order_by(balq.c.dt)
)

# NOTE(review): the header names 6 columns but each row has 7 (the last one is
# the first_value() column).
print('symbol, qty, bal, buybasis, prev bal, buy_avg')
for row in balWlag:
    print(row)

#### CHALLENGE: pull the last column's value down into the following rows where it is
#### Decimal(0), i.e. use the latest buy whose id/dt is earlier than trade.id/trade.dt

In [None]:
# avg entry, CASE based
# NOTE(review): case([...]) with a list of whens is the SQLAlchemy 1.x calling
# convention; the list form was removed in SQLAlchemy 2.0.
from sqlalchemy import case, literal

q = (
    db.session.query(
        Trade.DateTime.label('dt'),
        Trade.Symbol.label('asset'),
        Trade.Quantity.label('delta'),
        func.sum(Trade.Quantity).over(partition_by=Trade.Symbol, order_by=Trade.DateTime).label('balance'),
        # Buys: own basis. Otherwise: the previous trade's basis (one row back,
        # which is a buy's basis only if that previous trade was a buy).
        case([(Trade.Quantity > Decimal(0), Trade.Basis),],
        else_ = func.lag(Trade.Basis, 1, 0).over(partition_by=Trade.Symbol, order_by=Trade.DateTime) #literal(0)
        ).label('basis'),  # basis for buys or prev buy basis
        # Buys: per-unit price; sells: 0
        case([(Trade.Quantity > Decimal(0), Trade.Basis/ Trade.Quantity),],
        else_ = literal(0)
        ).label('avg_entry') # avg entry
    )
).subquery()

q2 = (
    db.session.query(
        q.c.asset.label('asset'),
        q.c.delta.label('delta'),
        q.c.balance.label('balance'),
        (q.c.balance - q.c.delta).label('old_balance'),
        q.c.basis.label('basis'),
        # NOTE(review): no order_by in this window, so "previous row" is unspecified
        func.lag(q.c.basis,1,0).over(partition_by=q.c.asset).label('old_basis'),
        # Carry the avg_entry of the previous row into sell rows (one row back only)
        case([(q.c.avg_entry != literal(0), q.c.avg_entry),],
        else_ = func.lag(q.c.avg_entry, 1, 0).over(partition_by=q.c.asset, order_by=q.c.dt) #literal(0)
        ).label('avg_entry')
        #coalesce(q.c.avg_entry.label('avg_entry'), func.lag(q.c.avg_entry,1,0).over(partition_by=q.c.asset, order_by=q.c.dt))
    ).order_by(q.c.dt)
).subquery()

# Weighted average entry on buys: (basis + old_balance * avg_entry) / balance; 0 on sells
q3 = (
    db.session.query(
        q2.c.asset.label('asset'),
        q2.c.delta.label('delta'),
        q2.c.balance.label('balance'),
        case([(q2.c.delta > Decimal(0), (q2.c.basis + q2.c.old_balance * q2.c.avg_entry)/q2.c.balance),],
        else_ = literal(0)
        ).label('avg_entry')

    )
)
for row in q3:
    print(row)


In [None]:
# Stock buys only: running SUM of per-buy prices (Basis / Quantity) per symbol,
# labelled avg_entry (it is a running sum, not an average).
q1 = (
    db.session
    .query(
        Trade.id.label('id'),
        func.sum(Trade.Basis/ Trade.Quantity).over(
            partition_by=Trade.Symbol, order_by=(Trade.DateTime)).label("avg_entry")
    )
    .filter(Trade.type==TradeType.STOCKS, Trade.Quantity > Decimal("0"))
).subquery()

# Every trade with its running balance, joined (left outer, on trade id) to the
# buy-side value above and to that value from the previous row of the symbol.
# The original cell had unbalanced parentheses, called func.lag() on the whole
# subquery and an argument-less .outerjoin(); this is the runnable equivalent.
q2 = (
    db.session.query(
        Trade.Symbol,
        func.extract("month", Trade.DateTime),
        func.extract("day", Trade.DateTime),
        Trade.Quantity,
        func.sum(Trade.Quantity).over(
            partition_by=Trade.Symbol, order_by=(Trade.DateTime)).label("balance"),
        q1.c.avg_entry,
        func.lag(q1.c.avg_entry, 1, 0).over(
            partition_by=Trade.Symbol, order_by=Trade.DateTime).label("prev_avg_entry"),
    )
    .outerjoin(q1, q1.c.id==Trade.id)
)
# A Subquery object is not iterable; iterate the query that uses it instead.
for row in q2:
    print(row)

In [None]:
# Buys only (WHERE is applied before the windows): running buy quantity and the
# running SUM of per-buy prices, labelled avg_entry (a sum, not an average).
q1 = (
    db.session
    .query(
        Trade.id.label('id'),
        Trade.Symbol.label('symbol'),
        func.extract('day', Trade.DateTime).label('day'),
        Trade.Quantity.label('q'),
        Trade.Basis.label('basis'),
        func.sum(Trade.Quantity).over(
            partition_by=Trade.Symbol, order_by=(Trade.DateTime)).label("balance"),
        func.sum(Trade.Basis/ Trade.Quantity).over(
            partition_by=Trade.Symbol, order_by=(Trade.DateTime)).label("avg_entry")

    )
    .filter(Trade.Quantity > Decimal("0"))
    #.order_by(Trade.Symbol, Trade.DateTime)
).subquery()

# All stock trades with their running balance, joined (left outer, on id) to the
# buy-side values above; avg_entry is NULL on sells.
q2 = (
    db.session.query(
        Trade.Symbol,
        func.extract("month", Trade.DateTime),
        func.extract("day", Trade.DateTime),
        #Trade.DateTime,
        Trade.Quantity,
        Trade.Basis,
        func.sum(Trade.Quantity).over(
            partition_by=Trade.Symbol, order_by=(Trade.DateTime)).label("balance"),
        q1.c.avg_entry,
        # NOTE(review): no order_by in this window, so "previous row" is unspecified
        q1.c.avg_entry + func.lag(q1.c.avg_entry, 1, 0).over(partition_by=Trade.Symbol)
    )
    .outerjoin(q1, q1.c.id==Trade.id)
    .filter(Trade.type==TradeType.STOCKS)
    .order_by(Trade.Symbol, Trade.DateTime)
)
for row in q2:
    print(row)

In [None]:
# Scratch: LAG of the trade id within each symbol (default 1 when there is no
# previous row). No order_by, so which row is "previous" is unspecified.
(
    db.session.query(
        Trade.id, Trade.Symbol,
        func.lag(Trade.id, 1, 1).over(partition_by=Trade.Symbol))
).all()

In [None]:
# Scratch: CASE + LAG experiment - even ids keep their own id, odd ids take the
# previous id (ordered by id); the last column shows the plain LAG for comparison.
(
    db.session.query(
        Trade.id,
        case([(Trade.id % 2 == 0, Trade.id),],
             else_=func.lag(Trade.id,1,0).over(order_by=Trade.id)
            ),
        func.lag(Trade.id,1,0).over(order_by=Trade.id)


    )
).all()

In [None]:
# Trailing basis of last bought: own basis for buys, otherwise the previous
# trade's basis (one row back only)
(
    db.session.query(
        Trade.Symbol,
        Trade.Quantity,
        case([(Trade.Quantity > Decimal(0), Trade.Basis),],
             else_=func.lag(Trade.Basis,1,0).over(partition_by=Trade.Symbol, order_by=Trade.DateTime)
            ),
    )
    .order_by(Trade.Symbol)
).all()

In [None]:
# Trailing avg_price of last bought: as above, plus the per-unit price
# (Basis / Quantity) of the buy or of the previous trade
(
    db.session.query(
        Trade.Symbol,
        Trade.Quantity,
        case([(Trade.Quantity > Decimal(0), Trade.Basis),],
             else_=func.lag(Trade.Basis,1,0).over(partition_by=Trade.Symbol, order_by=Trade.DateTime)
            ),
        case([(Trade.Quantity > Decimal(0), Trade.Basis / Trade.Quantity),],
             else_=func.lag(Trade.Basis / Trade.Quantity,1,0).over(partition_by=Trade.Symbol, order_by=Trade.DateTime)
            ),
    )
    .order_by(Trade.Symbol)
).all()

In [None]:
# delta, balance, old_balance, most recent buy basis, most recent avg price
# ("most recent" = this trade if a buy, else the previous trade, one row back)
(
    db.session.query(
        Trade.Symbol.label('asset'),
        Trade.Quantity.label('delta'),
        func.sum(Trade.Quantity).over(partition_by=Trade.Symbol, order_by=Trade.DateTime).label('balance'),
        (func.sum(Trade.Quantity).over(partition_by=Trade.Symbol, order_by=Trade.DateTime) - Trade.Quantity).label('prev_balance'),
        case([(Trade.Quantity > Decimal(0), Trade.Basis),],
             else_=func.lag(Trade.Basis,1,0).over(partition_by=Trade.Symbol, order_by=Trade.DateTime)
            ).label('last_buy_basis'),
        case([(Trade.Quantity > Decimal(0), Trade.Basis / Trade.Quantity),],
             else_=func.lag(Trade.Basis / Trade.Quantity,1,0).over(partition_by=Trade.Symbol, order_by=Trade.DateTime)
            ).label('last_avg_buy_price'),
    )
    .order_by(Trade.Symbol)

).all()

In [None]:
# delta, balance, old_balance, most recent buy basis, most recent avg price subquery
# with main query tracking change in avg price
q = (
    db.session.query(
        Trade.DateTime.label('dt'),
        Trade.Symbol.label('asset'),
        Trade.Quantity.label('delta'),
        func.sum(Trade.Quantity).over(partition_by=Trade.Symbol, order_by=Trade.DateTime).label('balance'),
        (func.sum(Trade.Quantity).over(partition_by=Trade.Symbol, order_by=Trade.DateTime) - Trade.Quantity).label('prev_balance'),
        case([(Trade.Quantity > Decimal(0), Trade.Basis),],
             else_=func.lag(Trade.Basis,1,0).over(partition_by=Trade.Symbol, order_by=Trade.DateTime)
            ).label('last_buy_basis'),
        case([(Trade.Quantity > Decimal(0), Trade.Basis / Trade.Quantity),],
             else_=func.lag(Trade.Basis / Trade.Quantity,1,0).over(partition_by=Trade.Symbol, order_by=Trade.DateTime)
            ).label('last_avg_buy_price'),
    )
    .order_by(Trade.Symbol)

).subquery()
# On buys: (last_buy_basis + prev_balance * previous row's last_avg_buy_price) / balance;
# on sells: 0
q2 = (
    db.session.query(
        q.c.asset.label('asset'), q.c.delta.label('delta'), q.c.balance.label('balance'),
        q.c.last_buy_basis.label('last_buy_basis'), q.c.prev_balance.label('prev_balance'),
        func.lag(q.c.last_avg_buy_price,1,0).over(partition_by=q.c.asset, order_by=q.c.dt).label('prev_buy_price'),
        case([(
            q.c.delta > Decimal(0),
            (q.c.last_buy_basis + q.c.prev_balance * func.lag(q.c.last_avg_buy_price,1,0).over(partition_by=q.c.asset, order_by=q.c.dt))/ q.c.balance),],
            else_=literal(0)).label('last_avg')
    )
    .order_by(q.c.asset)
)
for row in q2:
    print(row)

# NOW MAKE THE LAST COL STICKY (carry the last non-zero last_avg forward over sells)

In [None]:
def cost_basis_query(session):
    """
    Return query that provides id, asset, delta, balance, average cost price in base
    currency and average cost price in local currency
    Provides a means to derive pnl_base and pnl_local from

    Built from three nested stages. All windows are partitioned per asset
    (symbol) and ordered by trade datetime.
    - q: per trade, the running balance and previous balance, plus the basis and
      per-unit price (in base and in local currency via QuoteInLocalCurrency) of
      the trade if it is a buy, otherwise of the previous row (LAG).
    - q2: on buys, the weighted average
      (last_buy_basis + prev_balance * previous row's price) / balance, else 0.
    - q3: uses last_avg, or, when it is 0 (sells), the previous row's last_avg.

    NOTE(review): LAG(..., 1, 0) only looks one row back, so the second and later
    of consecutive sells get a cost price of 0. q2 also weights with the previous
    row's per-unit buy price rather than the running average, so from the third
    buy onward this is not a true average cost. Verify before relying on it.
    NOTE(review): case([...]) is the SQLAlchemy 1.x calling convention; the
    list form was removed in SQLAlchemy 2.0.

    :param session: SQLAlchemy session bound to the statement database
    :return: un-executed Query over (id, asset, delta, balance, cost_price_base,
             cost_price_local), ordered by asset and datetime
    """
    q = (
            session.query(
                Trade.id.label('id'),
                Trade.DateTime.label('dt'),
                Trade.Symbol.label('asset'),
                Trade.Quantity.label('delta'),
                #Trade.QuoteInLocalCurrency.label('forex_rate'),
                func.sum(Trade.Quantity).over(partition_by=Trade.Symbol, order_by=Trade.DateTime).label('balance'),
                (
                    (func.sum(Trade.Quantity)
                     .over(partition_by=Trade.Symbol, order_by=Trade.DateTime) - Trade.Quantity).label('prev_balance')
                ),
                # Buys: own basis; otherwise the previous row's basis
                case([(Trade.Quantity > Decimal(0), Trade.Basis),],
                     else_=func.lag(Trade.Basis,1,0).over(partition_by=Trade.Symbol, order_by=Trade.DateTime)
                    ).label('last_buy_basis'),
                # Same, converted to local currency with the trade's forex rate
                case([(Trade.Quantity > Decimal(0), Trade.Basis * Trade.QuoteInLocalCurrency),],
                     else_=func.lag(Trade.Basis * Trade.QuoteInLocalCurrency,1,0).over(partition_by=Trade.Symbol, order_by=Trade.DateTime)
                    ).label('last_buy_basis_local'),
                # Buys: own per-unit price; otherwise the previous row's per-unit price
                case([(Trade.Quantity > Decimal(0), Trade.Basis / Trade.Quantity),],
                     else_=(
                         func.lag(Trade.Basis / Trade.Quantity,1,0)
                         .over(
                             partition_by=Trade.Symbol,
                             order_by=Trade.DateTime))
                    ).label('last_avg_buy_price'),
                # Same, in local currency
                case([(Trade.Quantity > Decimal(0), Trade.Basis * Trade.QuoteInLocalCurrency / Trade.Quantity),],
                     else_=(
                         func.lag(Trade.Basis * Trade.QuoteInLocalCurrency / Trade.Quantity,1,0)
                         .over(
                             partition_by=Trade.Symbol,
                             order_by=Trade.DateTime))
                    ).label('last_avg_buy_price_local'),
            )
            .order_by(Trade.Symbol, Trade.DateTime)

        ).subquery()
    # Weighted average cost on buys (0 on sells), in base and local currency
    q2 = (
        session.query(
            q.c.id.label('id'),
            #q.c.forex_rate.label('forex_rate'),
            q.c.dt.label('dt'), q.c.asset.label('asset'), q.c.delta.label('delta'), q.c.balance.label('balance'),
            q.c.last_buy_basis.label('last_buy_basis'), q.c.prev_balance.label('prev_balance'),
            q.c.last_buy_basis_local.label('last_buy_basis_local'),
            q.c.last_avg_buy_price_local.label('last_avg_buy_price_local'),
            func.lag(q.c.last_avg_buy_price,1,0).over(partition_by=q.c.asset, order_by=q.c.dt).label('prev_buy_price'),
            func.lag(q.c.last_avg_buy_price_local,1,0).over(partition_by=q.c.asset, order_by=q.c.dt).label('prev_buy_price_local'),
            case([(
                q.c.delta > Decimal(0),
                (q.c.last_buy_basis + q.c.prev_balance * func.lag(q.c.last_avg_buy_price,1,0).over(partition_by=q.c.asset, order_by=q.c.dt))/ q.c.balance),],
                else_=literal(0)).label('last_avg'),
            case([(
                q.c.delta > Decimal(0),
                (q.c.last_buy_basis_local + q.c.prev_balance * func.lag(q.c.last_avg_buy_price_local,1,0).over(partition_by=q.c.asset, order_by=q.c.dt))/ q.c.balance),],
                else_=literal(0)).label('last_avg_local')
        )
        .order_by(q.c.asset, q.c.dt)
    ).subquery()
    # Cost price per row: own last_avg on buys, previous row's last_avg on sells
    q3 = (
        session.query(
            q2.c.id.label('id'),
            #q2.c.forex_rate.label('forex_rate')
            q2.c.asset.label('asset'),
            q2.c.delta.label('delta'),
            q2.c.balance.label('balance'),
            case([(
                q2.c.last_avg != literal(0), q2.c.last_avg),],
                else_=func.lag(q2.c.last_avg,1,0).over(partition_by=q2.c.asset, order_by=q2.c.dt)
            ).label('cost_price_base'),
            case([(
                q2.c.last_avg_local != literal(0), q2.c.last_avg_local),],
                else_=func.lag(q2.c.last_avg_local,1,0).over(partition_by=q2.c.asset, order_by=q2.c.dt)
            ).label('cost_price_local')
        )
        .order_by(q2.c.asset, q2.c.dt)
    )
    return q3


In [None]:
for row in cost_basis_query(db.session):
    print(row)

In [None]:
def calc_pnl(session):
    """
    calculate the pnl for trades on the basis of the average cost basis (buy) price
    """
    cb = cost_basis_query(session).subquery()
    sells = (
        session.query(
            Trade.id, Trade.DateTime,
            Trade.T_Price.label("sell_price"),
            Trade.QuoteInLocalCurrency.label("forex"),
            cb.c.asset, cb.c.delta, cb.c.balance, cb.c.cost_price_base, cb.c.cost_price_local
        )
        .filter(Trade.Quantity < Decimal(0))
        .outerjoin(cb, cb.c.id==Trade.id)
    )
    for _id, _dt, _sell_price, _forex, _asset, _delta, _balance, _cost_price_base, _cost_price_local in sells:
        print(
            f"TRADE {_id} {_delta} {_asset} PNL {_delta * (_sell_price - _cost_price_base)}"
            f" PNL LOCAL {_delta * (_sell_price * _forex - _cost_price_local)}"
        )

In [None]:
calc_pnl(db.session)

In [None]:
# delta, balance, old_balance, most recent buy basis, most recent avg price subquery
# with main query tracking change in avg price
# (base-currency-only variant of the stages used in cost_basis_query)
q = (
    db.session.query(
        Trade.DateTime.label('dt'),
        Trade.Symbol.label('asset'),
        Trade.Quantity.label('delta'),
        func.sum(Trade.Quantity).over(partition_by=Trade.Symbol, order_by=Trade.DateTime).label('balance'),
        (
            (func.sum(Trade.Quantity)
             .over(partition_by=Trade.Symbol, order_by=Trade.DateTime) - Trade.Quantity).label('prev_balance')
        ),
        case([(Trade.Quantity > Decimal(0), Trade.Basis),],
             else_=func.lag(Trade.Basis,1,0).over(partition_by=Trade.Symbol, order_by=Trade.DateTime)
            ).label('last_buy_basis'),
        case([(Trade.Quantity > Decimal(0), Trade.Basis / Trade.Quantity),],
             else_=(
                 func.lag(Trade.Basis / Trade.Quantity,1,0)
                 .over(
                     partition_by=Trade.Symbol,
                     order_by=Trade.DateTime))
            ).label('last_avg_buy_price'),
    )
    .order_by(Trade.Symbol, Trade.DateTime)

).subquery()
# Weighted average on buys, 0 on sells
q2 = (
    db.session.query(
        q.c.dt.label('dt'), q.c.asset.label('asset'), q.c.delta.label('delta'), q.c.balance.label('balance'),
        q.c.last_buy_basis.label('last_buy_basis'), q.c.prev_balance.label('prev_balance'),
        func.lag(q.c.last_avg_buy_price,1,0).over(partition_by=q.c.asset, order_by=q.c.dt).label('prev_buy_price'),
        case([(
            q.c.delta > Decimal(0),
            (q.c.last_buy_basis + q.c.prev_balance * func.lag(q.c.last_avg_buy_price,1,0).over(partition_by=q.c.asset, order_by=q.c.dt))/ q.c.balance),],
            else_=literal(0)).label('last_avg')
    )
    .order_by(q.c.asset)
).subquery()
# "Sticky" attempt: on sells use the previous row's last_avg.
# NOTE(review): LAG looks back one row only, so a second consecutive sell still gets 0.
q3 = (
    db.session.query(
        q2.c.asset, q2.c.delta, q2.c.balance,
        case([(
            q2.c.last_avg != literal(0), q2.c.last_avg),],
            else_=func.lag(q2.c.last_avg,1,0).over(partition_by=q2.c.asset, order_by=q2.c.dt)
        )
    )
    .order_by(q2.c.asset)
)
for row in q3:
    print(row)

# New idea for the PNL comparison basis
## Challenge
- Deltas, balance and previous balance can be tracked across buys/sells: include all model ids.
- The average basis is tracked across buys only: exclude the model ids of sells.
- PNL must be tracked for sells only: exclude the model ids of buys.
- **Merge everything onto the deltas table.**

- Track the weighted average basis × forex rate over time in a separate table and join on the latest aggregate.

In [None]:
# db2 (ib_dev.db) is only defined in a commented-out cell, so `db2` raised
# NameError here; use the main statement database like the following cell does.
session = db.session

# Buy trades only: basis, forex rate and per-unit buy price
buys = (
    session.query(
        Trade.id.label('id'), Trade.Basis.label('basis'), Trade.QuoteInLocalCurrency.label('forex'),
        (Trade.Basis / Trade.Quantity).label('buy_price')
    )
    .filter(Trade.Quantity > Decimal(0))
).subquery()

# Sell trades only (not used further in this cell)
sells = (
    session.query(
        Trade.id.label('id'), Trade.T_Price.label('sell_price'),
        Trade.Basis.label('basis'), Trade.QuoteInLocalCurrency.label('forex'),
        Trade.Realized_PnL.label('pnl')
    )
    .filter(Trade.Quantity < Decimal(0))
).subquery()

# Every trade with its running and previous balance per symbol.
# NOTE(review): GROUP BY (Symbol, DateTime) on non-aggregated columns collapses
# trades that share a symbol and timestamp into one (arbitrary) row in SQLite.
deltas = (
    session.query(
        Trade.id.label('id'),
        Trade.Symbol.label('asset'),
        Trade.Quantity.label('delta'),
        func.sum(Trade.Quantity).over(partition_by=Trade.Symbol, order_by=Trade.DateTime).label('balance'),
        (func.sum(Trade.Quantity).over(partition_by=Trade.Symbol, order_by=Trade.DateTime) - Trade.Quantity).label('prev_balance'),
    )
    .group_by(Trade.Symbol, Trade.DateTime)
).subquery()

d1 = aliased(deltas)
d2 = aliased(deltas)

# buys - w average price based on lagging price
# NOTE(review): d1 is both the first selected entity and the outer-join target;
# confirm the generated FROM clause is "buys LEFT OUTER JOIN d1" as intended.
buy1 = (
    session.query(
        d1.c.id.label('id'), d1.c.asset.label('asset'), d1.c.delta.label('delta'),
        buys.c.buy_price.label('buy_price'),
        d1.c.balance.label('balance'), d1.c.prev_balance.label('prev_balance'),
        func.lag(buys.c.buy_price, 1, 0).over(partition_by=d1.c.asset, order_by=d1.c.id).label('prev_buy'),
        (
            (buys.c.basis + d1.c.prev_balance * func.lag(buys.c.buy_price, 1, 0).over(partition_by=d1.c.asset, order_by=d1.c.id)) / d1.c.balance).label('avg_cost_price')
    )
    .outerjoin(d1, d1.c.id==buys.c.id)
    .order_by(d1.c.asset, d1.c.id)
).subquery()
#for _id, _asset, _delta, _buy_price, _balance, _prev_balance, _prev_buy, _avg_cost_price in buy1:
#    print(f"{_asset} {_prev_balance} {_delta} {_buy_price} {_balance} {_avg_cost_price}")

# DON'T KNOW IF THIS TRACKS THE PRICE PROPERLY
# id-table with lagging avg_cost_price (NULL on rows that are not buys)
merge = (
    session.query(
        d2.c.id.label('id'), d2.c.asset.label('asset'), d2.c.delta.label('delta'),
        buy1.c.avg_cost_price.label('avg_cost_price'), buy1.c.id.label('bid')
    )
    .outerjoin(buy1, buy1.c.id==d2.c.id)
).subquery()



#for row in merge:
#    print(row)

# Create a table that, for each trade id, finds the id of the buy to use,
# e.g. if ids 6 and 11 are buys: merge-table id 6 -> buy id 6, ..., merge-table
# id 10 -> buy id 6, and merge-table id 11 -> buy id 11.
# This is then used to fetch the average price by joining buy.c.id to that
# buy id for each merge id.
# The string below is a no-op note holding a CTE / ROW_NUMBER() sketch.
'''

users_cte = select([users.c.id, users.c.name]).where(users.c.name == 'wendy').cte()

WITH FormattedT1 AS
(
    SELECT Countryid, Stateid, Value, ROW_NUMBER() OVER(PARTITION BY Countryid ORDER BY ...) AS num
    FROM T1
    WHERE Value > 0
)
SELECT Countryid, Stateid, Value
FROM FormattedT1
WHERE num = 1
'''


In [None]:
# Exploratory query: compute a PnL for each sell from the average cost price of
# earlier buys, printed next to the realized PnL stored on the trade.
session = db.session

# Buy trades (Quantity > 0) with per-unit price Basis / Quantity.
buys = (
    session.query(
        Trade.id.label('id'), Trade.Basis.label('basis'), Trade.QuoteInLocalCurrency.label('forex'),
        (Trade.Basis / Trade.Quantity).label('buy_price')
    )
    .filter(Trade.Quantity > Decimal(0))
).cte()

# Sell trades (Quantity < 0) with trade price and realized PnL.
sells = (
    session.query(
        Trade.id.label('id'), Trade.T_Price.label('sell_price'),
        Trade.Basis.label('basis'), Trade.QuoteInLocalCurrency.label('forex'),
        Trade.Realized_PnL.label('pnl')
    )
    .filter(Trade.Quantity < Decimal(0))
).subquery()

# Per-trade quantity change (delta) plus the running position per symbol,
# ordered by trade time; prev_balance is the position before this trade.
# NOTE(review): Trade.id / Trade.Quantity are selected without being grouped,
# so trades of one symbol sharing a DateTime collapse to a single arbitrary row
# (SQLite bare-column behaviour). Confirm this de-duplication is intended.
deltas = (
    session.query(
        Trade.id.label('id'),
        Trade.Symbol.label('asset'),
        Trade.Quantity.label('delta'),
        func.sum(Trade.Quantity).over(partition_by=Trade.Symbol, order_by=Trade.DateTime).label('balance'),
        (func.sum(Trade.Quantity).over(partition_by=Trade.Symbol, order_by=Trade.DateTime) - Trade.Quantity).label('prev_balance'),
    )
    .group_by(Trade.Symbol, Trade.DateTime)
).subquery()

d1 = aliased(deltas)
d2 = aliased(deltas)

# buys - weighted average price based on the lagging (previous) buy price:
#   avg_cost_price = (basis + prev_balance * previous buy_price) / balance
buy1 = (
    session.query(
        d1.c.id.label('id'), d1.c.asset.label('asset'), d1.c.delta.label('delta'),
        buys.c.buy_price.label('buy_price'),
        (
            (buys.c.basis + d1.c.prev_balance * func.lag(buys.c.buy_price, 1, 0).over(partition_by=d1.c.asset, order_by=d1.c.id)) / d1.c.balance).label('avg_cost_price')
    )
    .outerjoin(d1, d1.c.id==buys.c.id)
    .order_by(d1.c.asset, d1.c.id)
).cte()

# Running maximum of avg_cost_price per asset, ordered by buy id.
# NOTE(review): MAX() OVER (ORDER BY id) is a running *maximum*, not the
# average cost of the most recent buy before each sell. buy1 is only cross-joined
# through the filter below, so a sell may yield several rows (one per distinct
# running value). Confirm before relying on these numbers.
running_avg_cost = func.max(buy1.c.avg_cost_price).over(partition_by=buy1.c.asset, order_by=buy1.c.id)

# Sells paired with the earlier buys of the same asset.
# Sell deltas are negative (Quantity < 0 above), so the quantity sold is
# -delta; calc_pnl is therefore positive when selling above the cost price.
merge = (
    session.query(
        d2.c.id.label('id'), d2.c.asset.label('asset'), d2.c.delta.label('delta'),
        running_avg_cost.label('avg_cost_price'),
        (-d2.c.delta * (sells.c.sell_price - running_avg_cost)).label('calc_pnl'),
        sells.c.pnl.label('pnl'),
    )
    .join(sells, sells.c.id==d2.c.id)
    .filter(buy1.c.id < d2.c.id, buy1.c.asset==d2.c.asset)
    .order_by(d2.c.asset, d2.c.id)
    .distinct()
)
for row in merge:
    print(row)

In [None]:
# Quantity change per trade ("delta") and the running position per symbol
# ("balance"), accumulated in trade-time order.
# NOTE(review): grouping by (Symbol, DateTime) while selecting Trade.id keeps a
# single arbitrary row per timestamp in SQLite; confirm that is intended.
delta_q = (
    db.session
    .query(
        Trade.id.label('id'),
        Trade.Quantity.label('delta'),
        func.sum(Trade.Quantity)
        .over(partition_by=Trade.Symbol, order_by=Trade.DateTime)
        .label('balance'),
    )
    .group_by(Trade.Symbol, Trade.DateTime)
).subquery()

# Buy trades only (Quantity > 0). basis_local = Basis * QuoteInLocalCurrency,
# presumably the basis expressed in local currency. delta/balance come from
# delta_q via an outer join, so they may be NULL for rows dropped by its grouping.
# TODO: add the weighted basis_local
forex_basis = (
    db.session
    .query(
        Trade.id.label('id'),
        Trade.Symbol.label('asset'),
        Trade.Basis.label('basis'),
        (Trade.Basis * Trade.QuoteInLocalCurrency).label('basis_local'),
        delta_q.c.delta.label('delta'),
        delta_q.c.balance.label('balance'),
    )
    .outerjoin(delta_q, Trade.id == delta_q.c.id)
    .filter(Trade.Quantity > Decimal(0))
    .order_by(Trade.Symbol, Trade.DateTime)
)
for row in forex_basis:
    print(row)