<a href="https://colab.research.google.com/github/mridul-sahu/advance_tax_calculations/blob/main/Stocks_Calculations.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [14]:
from typing import Any, Dict, Optional, Tuple

import numpy as np
import pandas as pd

# ==============================================================================
# --- 1. CONFIGURATION & HELPERS ---
# ==============================================================================

# NOTE: One fixed USD->INR conversion rate is applied to every transaction to
# keep things simple. Real Indian tax filings must instead convert each
# transaction at its own date-specific Telegraphic Transfer Buying Rate (TTBR).
USD_TO_INR_RATE = 83.5

def clean_currency(value: Any) -> float:
    """Removes currency symbols and commas from a string, then converts to float."""
    if isinstance(value, str):
        return float(value.replace('$', '').replace(',', ''))
    return float(value)

# ==============================================================================
# --- 2. DATA LOADING AND PREPARATION ---
# ==============================================================================

def load_and_clean_data(sales_file: str, acq_file: str) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """
    Load the sales and acquisition (releases) CSV reports and normalise them.

    Column names are replaced with fixed internal names, currency columns are
    converted to floats, share counts and dates are parsed, and both frames are
    sorted chronologically for FIFO matching.

    Args:
        sales_file: Path to the capital gains (sales) CSV report.
        acq_file: Path to the releases (acquisitions) CSV report.

    Returns:
        (sales_df, acq_df), sorted by 'Sale_Date' and 'Vest_Date' respectively,
        each with a fresh 0..n-1 index.

    Raises:
        FileNotFoundError: If either file does not exist.
        ValueError: If a report does not have the expected number of columns,
            or a value cannot be parsed.
    """
    sales_columns = ['Sale_Date', 'Sale_Price', 'Shares_Sold', 'Symbol', 'Gross_Proceeds', 'Acquisition_Date_in_Report']
    acq_columns = ['Vest_Date', 'Order_Number', 'Plan', 'Type', 'Status', 'Acquisition_Price', 'Quantity', 'Net_Cash_Proceeds', 'Shares_Acquired', 'Tax_Payment_Method']

    try:
        # skiprows drops the line(s) above the header row; skipfooter drops the
        # trailing footer line (skipfooter requires the python parser engine).
        sales_df = pd.read_csv(sales_file, skiprows=2, skipfooter=1, engine='python')
        acq_df = pd.read_csv(acq_file, skiprows=1, skipfooter=1, engine='python')
    except FileNotFoundError as e:
        raise FileNotFoundError(f"File not found: {e}. Ensure CSVs are in the correct path.") from e

    # Remove any completely empty rows that might have been read
    sales_df.dropna(how='all', inplace=True)
    acq_df.dropna(how='all', inplace=True)

    # --- Standardize Column Names ---
    # Fail with a readable message if the export layout has changed, instead of
    # pandas' generic "Length mismatch" error.
    for report_name, df, expected in (('Capital Gains', sales_df, sales_columns),
                                      ('Releases', acq_df, acq_columns)):
        if len(df.columns) != len(expected):
            raise ValueError(
                f"{report_name} report has {len(df.columns)} columns; "
                f"expected {len(expected)}: {expected}"
            )
        df.columns = expected

    # --- Convert Data Types and Clean Values ---
    for col in ('Sale_Price', 'Gross_Proceeds'):
        sales_df[col] = sales_df[col].apply(clean_currency)
    for col in ('Acquisition_Price', 'Net_Cash_Proceeds'):
        acq_df[col] = acq_df[col].apply(clean_currency)

    sales_df['Shares_Sold'] = pd.to_numeric(sales_df['Shares_Sold'])
    acq_df['Shares_Acquired'] = pd.to_numeric(acq_df['Shares_Acquired'])

    # Date format is inferred by pandas.
    sales_df['Sale_Date'] = pd.to_datetime(sales_df['Sale_Date'])
    acq_df['Vest_Date'] = pd.to_datetime(acq_df['Vest_Date'])

    # Sort chronologically for FIFO. A stable sort keeps rows that share a date in
    # their original file order. The default quicksort may reorder such ties, which
    # would change which acquisition lot (and price) FIFO consumes first.
    sales_df = sales_df.sort_values(by='Sale_Date', kind='mergesort').reset_index(drop=True)
    acq_df = acq_df.sort_values(by='Vest_Date', kind='mergesort').reset_index(drop=True)
    return sales_df, acq_df

# ==============================================================================
# --- 3. CORE CALCULATION LOGIC ---
# ==============================================================================

def perform_fifo_matching(
    sales_df: pd.DataFrame,
    acq_df: pd.DataFrame,
    usd_to_inr_rate: Optional[float] = None,
) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """
    Match sales to acquisition lots using the First-In, First-Out (FIFO) method.

    Each sale draws shares from the oldest lot that still has shares remaining
    and was acquired on or before the sale date. One result row is produced for
    every (sale, lot) pair that contributes shares. Shares of a sale that cannot
    be matched to any lot are left out of the result (the share-count
    validation will flag this).

    Args:
        sales_df: Sales ordered by 'Sale_Date', with columns 'Sale_Date',
            'Sale_Price' and 'Shares_Sold'.
        acq_df: Acquisitions ordered by 'Vest_Date', with columns 'Vest_Date',
            'Acquisition_Price' and 'Shares_Acquired'.
        usd_to_inr_rate: Rate used for the INR column. Defaults to the
            module-level USD_TO_INR_RATE.

    Returns:
        (summary_df, acquisitions_info):
        - summary_df has columns 'Sale_Date', 'Shares_Sold', 'Profit/Loss (USD)',
          'Gain_Type' ('LTCG' or 'STCG') and 'Profit/Loss (INR)'. It keeps these
          columns even when no shares were matched.
        - acquisitions_info lists each lot with its 'Remaining_Shares' after
          all sales.
    """
    rate = USD_TO_INR_RATE if usd_to_inr_rate is None else usd_to_inr_rate
    tolerance = 1e-4  # share quantities at or below this are treated as zero (float residue)
    result_columns = ['Sale_Date', 'Shares_Sold', 'Profit/Loss (USD)', 'Gain_Type']

    # Prepare the acquisition data for tracking. Reset the index so the positional
    # .loc lookups below are valid regardless of the caller's index.
    acquisitions_info = (
        acq_df[['Vest_Date', 'Acquisition_Price', 'Shares_Acquired']]
        .rename(columns={'Vest_Date': 'Acquisition_Date'})
        .reset_index(drop=True)
        .copy()
    )
    # Track remaining shares as float so fractional decrements never hit an
    # integer column (incompatible-dtype assignment).
    acquisitions_info['Remaining_Shares'] = acquisitions_info['Shares_Acquired'].astype(float)

    results_list = []
    current_acq_index = 0  # Pointer to the oldest lot that may still have shares

    for sale in sales_df.itertuples():
        shares_to_match = sale.Shares_Sold

        for acq_index in range(current_acq_index, len(acquisitions_info)):
            if shares_to_match <= tolerance:  # Sale fully matched
                break

            # Shares acquired after the sale date cannot have been sold in it
            if acquisitions_info.loc[acq_index, 'Acquisition_Date'] > sale.Sale_Date:
                continue

            shares_from_lot = min(shares_to_match, acquisitions_info.loc[acq_index, 'Remaining_Shares'])

            if shares_from_lot > 0:
                acq = acquisitions_info.loc[acq_index]
                profit_loss = (sale.Sale_Price - acq.Acquisition_Price) * shares_from_lot
                # Long-term only if held for MORE than 24 months. Calendar-month
                # arithmetic is used because a fixed 730-day cut-off misclassifies
                # exact 24-month anniversaries that include a 29 February.
                is_long_term = sale.Sale_Date > acq.Acquisition_Date + pd.DateOffset(months=24)

                results_list.append({
                    'Sale_Date': sale.Sale_Date,
                    'Shares_Sold': shares_from_lot,
                    'Profit/Loss (USD)': profit_loss,
                    'Gain_Type': 'LTCG' if is_long_term else 'STCG'
                })

                acquisitions_info.loc[acq_index, 'Remaining_Shares'] -= shares_from_lot
                shares_to_match -= shares_from_lot

            # Once a lot is exhausted, later sales start from the next lot
            if acquisitions_info.loc[acq_index, 'Remaining_Shares'] <= tolerance:
                current_acq_index = acq_index + 1

    # Explicit columns keep the frame well-formed even when nothing matched
    # (e.g. an empty sales report); otherwise the INR line would raise KeyError.
    summary_df = pd.DataFrame(results_list, columns=result_columns)
    summary_df['Profit/Loss (INR)'] = summary_df['Profit/Loss (USD)'] * rate
    return summary_df, acquisitions_info

def calculate_tax_liability(
    df: pd.DataFrame,
    *,
    stcg_tax_rate: float = 0.30,
    ltcg_tax_rate: float = 0.125,
    surcharge_rate: float = 0.15,
    cess_rate: float = 0.04,
) -> Dict[str, Any]:
    """
    Calculate total tax liability after applying loss set-off rules.

    Set-off rules applied:
      * Long-term losses are set off only against long-term gains.
      * Short-term losses are set off against short-term gains first, and
        any remainder against long-term gains.

    Surcharge is charged on the base tax, and cess on base tax plus surcharge.

    Args:
        df: Per-lot results with 'Gain_Type' ('STCG' / 'LTCG') and
            'Profit/Loss (INR)' columns.
        stcg_tax_rate: Rate applied to net short-term gains.
        ltcg_tax_rate: Rate applied to net long-term gains.
        surcharge_rate: Surcharge as a fraction of base tax (depends on total
            income; adjust as needed).
        cess_rate: Health & education cess as a fraction of tax + surcharge.

    Returns:
        Dict with gross gains/losses ('stcg', 'stcl', 'ltcg', 'ltcl'), net
        taxable amounts, and the tax components plus 'total_tax_liability'.
    """
    pnl = df['Profit/Loss (INR)']
    is_short = df['Gain_Type'] == 'STCG'
    is_long = df['Gain_Type'] == 'LTCG'

    # Step 1: Separate gains and losses into four categories (losses as positive numbers)
    stcg = pnl[is_short & (pnl > 0)].sum()
    stcl = abs(pnl[is_short & (pnl < 0)].sum())
    ltcg = pnl[is_long & (pnl > 0)].sum()
    ltcl = abs(pnl[is_long & (pnl < 0)].sum())

    # --- Step 2: Apply Set-Off Rules ---
    # a) Long-term losses can ONLY be set off against long-term gains.
    ltcg_after_ltcl = max(0, ltcg - ltcl)

    # b) Short-term losses go against short-term gains first, then the remainder
    #    against long-term gains.
    net_taxable_stcg = max(0, stcg - stcl)
    stcl_remaining_after_stcg = max(0, stcl - stcg)
    net_taxable_ltcg = max(0, ltcg_after_ltcl - stcl_remaining_after_stcg)

    # --- Step 3: Calculate Final Tax on Net Gains ---
    total_base_tax = (net_taxable_stcg * stcg_tax_rate) + (net_taxable_ltcg * ltcg_tax_rate)
    total_surcharge = total_base_tax * surcharge_rate
    total_cess = (total_base_tax + total_surcharge) * cess_rate
    total_tax_liability = total_base_tax + total_surcharge + total_cess

    return {
        "stcg": stcg, "stcl": stcl, "ltcg": ltcg, "ltcl": ltcl,
        "net_taxable_stcg": net_taxable_stcg, "net_taxable_ltcg": net_taxable_ltcg,
        "total_base_tax": total_base_tax, "total_surcharge": total_surcharge,
        "total_cess": total_cess, "total_tax_liability": total_tax_liability
    }

def calculate_advance_tax_schedule(summary_df: pd.DataFrame, fy_start_year: Optional[int] = None) -> pd.DataFrame:
    """
    Calculate advance tax installments for one financial year (cumulative method).

    For each installment:
      1. Compute the tax liability on gains realised from 1 April of the
         financial year up to that installment's cut-off date.
      2. Apply the cumulative share due by then (15% / 45% / 75% / 100%).
      3. Subtract what earlier installments already scheduled.

    Args:
        summary_df: Per-lot results with 'Sale_Date' plus the columns needed by
            calculate_tax_liability.
        fy_start_year: Calendar year in which the financial year starts
            (e.g. 2024 for 1 Apr 2024 - 31 Mar 2025). Defaults to the financial
            year containing today's date.

    Returns:
        DataFrame with 'Installment Due Date' (datetime.date) and a
        non-negative 'Amount to Pay (INR)'.
    """
    if fy_start_year is None:
        today = pd.Timestamp.now()
        fy_start_year = today.year if today.month >= 4 else today.year - 1
    next_year = fy_start_year + 1

    # Sales before 1 April belong to an earlier financial year and must not be
    # counted toward this year's advance tax.
    fy_sales = summary_df[summary_df['Sale_Date'] >= pd.Timestamp(fy_start_year, 4, 1)]

    # (cut-off for assessing gains, payment due date, cumulative share of tax due).
    # The last installment is assessed up to 31 March although it is due on
    # 15 March, so gains in the final fortnight still count for this year.
    installments = [
        (pd.Timestamp(fy_start_year, 6, 15), pd.Timestamp(fy_start_year, 6, 15), 0.15),
        (pd.Timestamp(fy_start_year, 9, 15), pd.Timestamp(fy_start_year, 9, 15), 0.45),
        (pd.Timestamp(fy_start_year, 12, 15), pd.Timestamp(fy_start_year, 12, 15), 0.75),
        (pd.Timestamp(next_year, 3, 31), pd.Timestamp(next_year, 3, 15), 1.00),
    ]

    paid_so_far = 0.0
    due_dates, amounts = [], []
    for cutoff, due_date, cumulative_share in installments:
        cumulative_tax = calculate_tax_liability(fy_sales[fy_sales['Sale_Date'] <= cutoff])['total_tax_liability']
        # Clip BEFORE accumulating. A negative amount is never actually paid, so it
        # must not reduce what later installments count as already paid.
        payment = max(0.0, cumulative_tax * cumulative_share - paid_so_far)
        paid_so_far += payment
        due_dates.append(due_date.date())
        amounts.append(payment)

    return pd.DataFrame({
        'Installment Due Date': due_dates,
        'Amount to Pay (INR)': amounts
    })

# ==============================================================================
# --- 4. VALIDATION ---
# ==============================================================================

def perform_validations(sales_df: pd.DataFrame, acq_df: pd.DataFrame, summary_df: pd.DataFrame, acq_status_df: pd.DataFrame, tax_data: Dict) -> Dict:
    """
    Check inputs and computed results, print a validation report, and return
    the outcome of each check.

    Input sanity checks cover:
      * an empty sales or releases report;
      * negative share counts.

    Post-calculation checks confirm that:
      * total shares sold equals total shares matched;
      * no lot went below -1e-4 remaining shares;
      * the tax components add up to the reported total.

    Returns:
        Dict mapping each check name to 'Pass' or 'Fail'.
    """
    print("\n--- Running Final Calculation Validations ---")

    def as_word(ok) -> str:
        return 'Pass' if ok else 'Fail'

    def as_mark(ok) -> str:
        return '✅ Pass' if ok else '❌ Fail'

    # (lazy failure test, message printed when it fails), evaluated in this order
    sanity_checks = (
        (lambda: sales_df.empty,
         "🟡 Sanity Check: Capital Gains Report is empty."),
        (lambda: acq_df.empty,
         "❌ Sanity Check Fail: Releases Report is empty, cannot process sales."),
        (lambda: (sales_df['Shares_Sold'] < 0).any(),
         "❌ Sanity Check Fail: Negative values found in 'Shares_Sold'."),
        (lambda: (acq_df['Shares_Acquired'] < 0).any(),
         "❌ Sanity Check Fail: Negative values in 'Shares_Acquired'."),
    )
    sanity_errors = False
    for has_failed, message in sanity_checks:
        if has_failed():
            print(message)
            sanity_errors = True

    # Every share sold must appear in the FIFO summary
    share_match = np.isclose(sales_df['Shares_Sold'].sum(), summary_df['Shares_Sold'].sum())
    print(f"Share Count Match (Original vs Summary): {as_mark(share_match)}")

    # No lot may be drawn down meaningfully below zero
    oversold = (acq_status_df['Remaining_Shares'] < -1e-4).any()
    print(f"Overselling Check (No negative shares): {as_mark(not oversold)}")

    component_sum = tax_data['total_base_tax'] + tax_data['total_surcharge'] + tax_data['total_cess']
    tax_check = np.isclose(component_sum, tax_data['total_tax_liability'])
    print(f"Tax Calculation Integrity (Components Sum to Total): {as_mark(tax_check)}")
    print("-----------------------------------------")

    return {
        "Sanity Checks": as_word(not sanity_errors),
        "Share Match": as_word(share_match),
        "Overselling": as_word(not oversold),
        "Tax Integrity": as_word(tax_check),
    }

# ==============================================================================
# --- 5. EXCEL REPORT GENERATION ---
# ==============================================================================

def generate_excel_report(summary_df, acq_status_df, tax_data, schedule_df, sales_df, acq_df, validation_results, output_file):
    """
    Write every result table to a multi-sheet Excel workbook at ``output_file``.

    Sheets, in order:
      * 'Profit Loss Summary'
      * 'Tax Calculation': set-off table, final tax liability and advance tax
        schedule, each under a title row with a blank row between sections,
        followed by notes and the validation summary
      * 'Original Sales Report'
      * 'Original Releases Report'
      * 'Acquisition Lot Status'
    """
    tax_sheet = 'Tax Calculation'

    set_off_summary = pd.DataFrame({
        'Category': ['Short-Term', 'Long-Term'],
        'Gross Gains (INR)': [tax_data['stcg'], tax_data['ltcg']],
        'Gross Losses (INR)': [tax_data['stcl'], tax_data['ltcl']],
        'Net Taxable Gains (INR)': [tax_data['net_taxable_stcg'], tax_data['net_taxable_ltcg']]
    })

    tax_summary = pd.DataFrame({
        'Description': ['Total Base Tax', 'Total Surcharge @ 15%', 'Total Health & Education Cess @ 4%', 'TOTAL TAX LIABILITY'],
        'Amount (INR)': [tax_data['total_base_tax'], tax_data['total_surcharge'], tax_data['total_cess'], tax_data['total_tax_liability']]
    })

    validation_rows = [
        (label, f"{validation_results[key]}")
        for label, key in (
            ("Input Sanity Checks:", 'Sanity Checks'),
            ("Share Count Match:", 'Share Match'),
            ("Overselling Check:", 'Overselling'),
            ("Tax Integrity Check:", 'Tax Integrity'),
        )
    ]
    notes_df = pd.DataFrame([
        ("NOTES:", ""),
        ("1. Exchange Rate:", f"A single rate of {USD_TO_INR_RATE} has been used for simplicity."),
        ("", "For actual tax filing, use date-specific TTBR for each transaction."),
        ("2. Surcharge:", "A 15% rate is assumed (income > ₹1 Cr & ≤ ₹2 Cr). Adjust if needed."),
        ("", ""),
        ("VALIDATION SUMMARY:", ""),
        *validation_rows,
    ])

    # Sections stacked on the tax sheet: a one-cell title row, then the table
    # (with its header), then one blank row before the next section.
    sections = (
        ('TAX SET-OFF CALCULATION', set_off_summary),
        ('FINAL TAX LIABILITY', tax_summary),
        ('ADVANCE TAX PAYMENT SCHEDULE (Cumulative Method)', schedule_df),
    )

    with pd.ExcelWriter(output_file, engine='openpyxl') as writer:
        summary_df.to_excel(writer, sheet_name='Profit Loss Summary', index=False)

        row = 0
        for title, table in sections:
            pd.DataFrame([[title]]).to_excel(writer, sheet_name=tax_sheet, index=False, header=False, startrow=row)
            row += 1
            table.to_excel(writer, sheet_name=tax_sheet, index=False, startrow=row)
            row += len(table) + 2
        notes_df.to_excel(writer, sheet_name=tax_sheet, index=False, header=False, startrow=row)

        for frame, sheet in ((sales_df, 'Original Sales Report'),
                             (acq_df, 'Original Releases Report'),
                             (acq_status_df, 'Acquisition Lot Status')):
            frame.to_excel(writer, sheet_name=sheet, index=False)

    print(f"\n✅ Success! The '{output_file}' has been created with all validations.")

# ==============================================================================
# --- 6. MAIN EXECUTION ---
# ==============================================================================

def main(
    capital_gains_file: str = 'Capital Gains Report.csv',
    releases_file: str = 'Releases Report.csv',
    output_excel_file: str = 'capital_gains_summary_final.xlsx',
) -> None:
    """
    Run the full capital gains and advance tax pipeline.

    Steps:
      1. Load the two CSV reports.
      2. Match sales to acquisition lots (FIFO).
      3. Compute the tax liability and the advance tax schedule.
      4. Print the validation report.
      5. Write the Excel workbook.

    Errors are reported on stdout rather than raised, so a notebook cell shows a
    readable message. Unexpected errors also print their traceback.

    Args:
        capital_gains_file: Path to the sales (capital gains) CSV report.
        releases_file: Path to the releases (acquisitions) CSV report.
        output_excel_file: Path of the Excel workbook to write.
    """
    import traceback

    hint = "Please check your input files and the script configuration."
    try:
        # Step 1: Load and clean the data
        sales_df, acq_df = load_and_clean_data(capital_gains_file, releases_file)

        # Step 2: Perform calculations
        summary_df, acq_status_df = perform_fifo_matching(sales_df, acq_df)
        tax_data = calculate_tax_liability(summary_df)
        advance_tax_schedule = calculate_advance_tax_schedule(summary_df)

        # Step 3: Run all validations and print report
        validation_results = perform_validations(sales_df, acq_df, summary_df, acq_status_df, tax_data)

        # Step 4: Generate the final Excel report
        generate_excel_report(
            summary_df, acq_status_df, tax_data, advance_tax_schedule,
            sales_df, acq_df, validation_results, output_excel_file
        )
    except FileNotFoundError as e:
        print(f"\n❌ Missing input file: {e}")
        print(hint)
    except (ValueError, KeyError) as e:
        # Usually indicates a malformed or unexpected report layout/value.
        print(f"\n❌ Could not process the input data: {e}")
        print(hint)
    except Exception as e:
        # Anything else is likely a bug; keep the traceback visible for debugging.
        print(f"\n❌ An unexpected error occurred: {e}")
        traceback.print_exc()
        print(hint)

In [15]:
# Run the full pipeline on the default report files. Steps: load the CSVs, do
# FIFO matching, compute tax and the advance schedule, print validations, and
# write the Excel workbook.
main()


--- Running Final Calculation Validations ---
Share Count Match (Original vs Summary): ✅ Pass
Overselling Check (No negative shares): ✅ Pass
Tax Calculation Integrity (Components Sum to Total): ✅ Pass
-----------------------------------------

✅ Success! The 'capital_gains_summary_final.xlsx' has been created with all validations.
