In [None]:
"""
Tag and categorize historical expenses from Splitwise.
The expense data is cached locally in expenses.json.
The tags are stored locally in tags.json in the following form:
{
  "wants": [
                str,
                {
                  "keyword": str,
                  "comment": str,
                  "match_word": true,
                  "match_case": false,
                  "ignore_id": [list of expense id to ignore for this],
                  "include_id": [list of expense id to include for this]
                },
           ]
}
PS: There's no ML. It's just honest labour: categories are assigned by hand, based on one's own judgement.
"""

In [None]:
import json
import os
from datetime import datetime, timedelta

import dotenv
import pytz
import requests
from dateutil.parser import parse

from utils import get_unique_by_key

# Load variables from a local .env file into the environment (presumably where SPLITWISE_TOKEN lives).
dotenv.load_dotenv()

# Splitwise expense-listing endpoint and the bearer-token auth header sent with every request.
# If SPLITWISE_TOKEN is unset, os.getenv returns None and the header becomes "Bearer None".
BASE_URL = "https://secure.splitwise.com/api/v3.0/get_expenses"
HEADERS = {"Authorization": "Bearer {}".format(os.getenv('SPLITWISE_TOKEN'))}


def _to_splitwise_timestamp(value):
    """Format a datetime as an ISO-8601 UTC string with a trailing ``Z`` (naive values are assumed UTC)."""
    from datetime import timezone

    if value.tzinfo is not None:
        # isoformat() of an aware datetime already carries an offset (e.g. "+05:30"); appending "Z"
        # on top of that produced malformed values like "...+05:30Z". Normalise to naive UTC first.
        value = value.astimezone(timezone.utc).replace(tzinfo=None)
    return value.isoformat() + "Z"


def fetch_expenses(start_date, end_date, limit=200, timeout=30):
    """
    Fetch Splitwise expenses dated between ``start_date`` and ``end_date`` in a single request.

    :param start_date: datetime sent as ``dated_after``. Aware values are converted to UTC;
        naive values are sent as-is and therefore interpreted as UTC.
    :param end_date: datetime sent as ``dated_before``, same conventions as ``start_date``.
    :param limit: maximum number of expenses requested. Only one request is made, so if more than
        ``limit`` expenses fall in the window the remainder is not fetched.
    :param timeout: seconds to wait for the HTTP response before giving up.
    :return: the decoded JSON body (callers read the expense list from its ``"expenses"`` key).
    :raises requests.exceptions.HTTPError: on a non-2xx response.
    """
    params = {
        "dated_after": _to_splitwise_timestamp(start_date),
        "dated_before": _to_splitwise_timestamp(end_date),
        "limit": limit,
    }
    response = requests.get(BASE_URL, headers=HEADERS, params=params, timeout=timeout)
    response.raise_for_status()  # surface HTTP failures instead of parsing an error body
    return response.json()


# Fetch expenses in windows of at most 30 days, resuming after the newest cached expense.
def get_all_expenses(_start_date, _end_date):
    """
    Combine the locally cached expenses with freshly fetched ones up to ``_end_date``.

    If ``expenses.json`` already holds expenses, fetching resumes one second after the newest
    cached expense date and ``_start_date`` is ignored; otherwise it starts at ``_start_date``.
    The range is walked in windows of at most 30 days, each clamped to ``_end_date``. On the
    first request error the loop stops and whatever was collected so far is returned, so a
    later run can resume from the updated cache.

    :param _start_date: datetime to start from when the cache is empty.
    :param _end_date: datetime upper bound for the last window.
    :return: the collected expense dicts passed through ``get_unique_by_key(..., key="id")``
        (presumably de-duplicating by expense id).
    """
    cached_expenses = read_from_file()
    newest_cached = max(cached_expenses, key=lambda x: parse(x["date"])) if cached_expenses else None
    if newest_cached:
        _start_date = parse(newest_cached["date"]) + timedelta(seconds=1)

    _all_expenses = list(cached_expenses)
    window = timedelta(days=30)
    while _start_date <= _end_date:
        # Clamp so the final window never requests expenses beyond the requested end date.
        next_date = min(_start_date + window, _end_date)
        print(f"Fetching expenses from {_start_date=} to {next_date}...")
        try:
            monthly_expenses = fetch_expenses(_start_date, next_date)
            _all_expenses.extend(monthly_expenses.get("expenses", []))
        except requests.exceptions.RequestException as e:
            print(f"Error fetching data: {e}")
            break
        _start_date = next_date + timedelta(seconds=1)

    return get_unique_by_key(_all_expenses, key="id")


# Persist records to disk as pretty-printed JSON, ordered by their "date" field.
def save_to_file(data, filename="expenses.json"):
    """
    Write ``data`` to ``filename`` as indented JSON, sorted by each record's ``"date"`` value.

    A sorted copy is written; the caller's list is not reordered. Any existing file is overwritten.
    """
    ordered_records = sorted(data, key=lambda record: record["date"])
    with open(filename, mode="w") as handle:
        json.dump(ordered_records, handle, indent=4)


# Load JSON written by save_to_file, or a hand-maintained file such as tags.json.
def read_from_file(filename="expenses.json"):
    """
    Load and return the JSON content of ``filename``.

    :param filename: path to a JSON file; defaults to the expense cache.
    :return: the decoded JSON value, or ``[]`` when the file does not exist.
    """
    if not os.path.exists(filename):
        return []
    # JSON is UTF-8 by spec. Decoding explicitly avoids failures/mojibake for hand-edited files
    # with non-ASCII keywords when the platform default encoding is not UTF-8 (e.g. cp1252).
    with open(filename, "r", encoding="utf-8") as file:
        return json.load(file)


In [None]:
# pytz zones must be attached with localize(); passing one as tzinfo= silently uses the zone's
# historical local-mean-time offset (about +05:53 for Asia/Kolkata) instead of IST (+05:30).
IST = pytz.timezone("Asia/Kolkata")
START_DATE = IST.localize(datetime(2022, 4, 1))
# NOTE(review): this is midnight at the start of 31 July, so expenses later that day are presumably
# outside the fetched range — confirm this is intended.
END_DATE = IST.localize(datetime(2025, 7, 31))
all_expenses = get_all_expenses(_start_date=START_DATE, _end_date=END_DATE)

In [None]:
fetched_expense_count = len(all_expenses)
fetched_expense_count

In [None]:
# Persist the fetched expenses to the local cache file (save_to_file's default filename).
save_to_file(data=all_expenses)

In [None]:
# Verify the cache round-trips: reading it back must yield as many expenses as were saved.
loaded_expenses = read_from_file()
assert len(loaded_expenses) == len(all_expenses), "expenses.json does not match the fetched expenses"
len(loaded_expenses)

In [None]:
from dateutil import parser, tz
from typing import Optional
from requests import HTTPError

USER_ID: Optional[int] = None  # cached Splitwise id of the token's owner; filled on first call


def get_user_id() -> int:
    """
    Return the Splitwise id of the user authenticated by ``SPLITWISE_TOKEN``.

    The id is fetched from the ``get_current_user`` endpoint once and cached in ``USER_ID``.

    :return: the ``user.id`` field of the ``get_current_user`` response.
    :raises HTTPError: if the endpoint does not answer with status 200.
    """
    global USER_ID

    if USER_ID is not None:
        return USER_ID

    url = "https://secure.splitwise.com/api/v3.0/get_current_user"
    headers = {'Authorization': f'Bearer {os.getenv("SPLITWISE_TOKEN")}'}

    # A timeout keeps a stalled connection from hanging the notebook indefinitely.
    response = requests.get(url, headers=headers, timeout=30)
    if response.status_code != 200:
        raise HTTPError(f'Invalid Splitwise response {response.status_code} {response.text}', response=response)

    USER_ID = response.json()["user"]["id"]
    return USER_ID


# Flatten each live expense into {id, date, name, cost}, where cost is the current user's owed share.
# Deleted expenses and the "Payment" / "Settle all balances" entries are skipped, as are expenses
# the current user does not appear in. Dates are rendered in the machine's local timezone.
items = []
for raw_expense in loaded_expenses:
    expense_date = parser.parse(raw_expense['date'])
    is_deleted = raw_expense['deleted_at']
    description = raw_expense['description'].strip()
    if is_deleted or description in ('Payment', 'Settle all balances'):
        continue

    local_day = expense_date.astimezone(tz.tzlocal()).strftime("%Y-%m-%d")
    record = {"id": raw_expense["id"], "date": local_day, "name": description}
    for participant in raw_expense['users']:
        if participant['user_id'] == get_user_id():
            record["cost"] = float(participant['owed_share'].strip())
            items.append(record)

In [None]:
# Quick sanity check: how many expenses include the current user, plus a sample record.
print(f"{len(items)} expenses include the current user")
items[0] if items else None

In [None]:
from collections import defaultdict

# How often each (case-insensitive) expense name occurs, most frequent first —
# a guide for which keywords are worth adding to tags.json.
all_expense_names_count = defaultdict(int)
for expense_entry in items:
    normalized_name = expense_entry['name'].lower()
    all_expense_names_count[normalized_name] += 1
names_by_frequency = sorted(all_expense_names_count.items(), key=lambda pair: pair[1], reverse=True)
dict(names_by_frequency)

In [None]:
import re

TAG_MAPPING = None  # cached contents of tags.json; populated lazily by get_tags()


def get_tags(reload=False):
    """
    Return the tag -> keywords mapping from ``tags.json``, reading the file only once.

    :param reload: re-read ``tags.json`` even if a mapping is already cached. Useful after editing
        the file, because the cache otherwise survives re-running the tagging loop.
    :return: dict mapping each tag name to its keyword list; ``{}`` if the file is missing or
        holds an empty JSON value.
    """
    global TAG_MAPPING

    # `is None` (rather than truthiness) so an empty mapping is cached instead of re-read on every call.
    if TAG_MAPPING is None or reload:
        TAG_MAPPING = read_from_file('tags.json') or {}
    return TAG_MAPPING


def tag_expense(expense_name, _id, tag_mapping=None):
    """
    Assign tags to an expense based on its name.

    Tags are checked in mapping order and each is added at most once. A tag matches when either
    any plain-string keyword (used as a regular expression, with keyword and name lower-cased)
    is found in the name, or any structured (dict) keyword matches via
    ``_structured_keyword_matches``.

    Args:
        expense_name (str): The name of the expense.
        _id: The expense id, checked against structured keywords' ``ignore_id`` / ``include_id``.
        tag_mapping (dict, optional): tag -> list of keywords. Defaults to ``get_tags()``.

    Returns:
        list: The matching tags in mapping order, or ``["other"]`` if none match.
    """
    if tag_mapping is None:
        tag_mapping = get_tags()

    lowered_name = expense_name.lower()
    tags = []
    for tag, keywords in tag_mapping.items():
        plain_keywords = [k for k in keywords if isinstance(k, str)]
        structured_keywords = [k for k in keywords if isinstance(k, dict)]
        if any(re.search(k.lower(), lowered_name) for k in plain_keywords) or any(
            _structured_keyword_matches(k, expense_name, _id) for k in structured_keywords
        ):
            tags.append(tag)

    return tags if tags else ["other"]


def _structured_keyword_matches(kw, expense_name, _id):
    """Return True if the dict keyword ``kw`` applies to expense ``_id`` named ``expense_name``."""
    pattern = kw["keyword"]
    if _id in kw.get("ignore_id", []):
        return False
    include_id = kw.get("include_id", [])
    if include_id and _id not in include_id:
        return False

    # Lower-case into a local copy: the caller's name must stay intact so that later
    # case-sensitive keywords (for this or other tags) still see the original casing.
    haystack = expense_name
    if not kw.get("match_case", False):
        pattern = pattern.lower()
        haystack = haystack.lower()

    regex = re.escape(pattern)
    if kw.get("match_word", False):
        regex = rf"\b{regex}\b"
    return re.search(regex, haystack) is not None

In [None]:
# Tag every expense in place and tally the names that fell through to "other",
# so tags.json can be extended to cover them.
unclassified = defaultdict(int)

for expense in items:
    tags = tag_expense(expense["name"], expense["id"])
    if not tags or tags == ["other"]:
        print(expense)
        unclassified[expense['name']] += 1
    expense["tags"] = tags

# Display the unclassified names (with counts), most frequent first.
unclassified_by_count = dict(sorted(unclassified.items(), key=lambda item: -item[1]))
unclassified_by_count

In [None]:
# Trip-tagged expenses from April 2024, most expensive first.
april_2024_trip_expenses = [
    expense
    for expense in items
    if "trip" in expense["tags"] and expense["date"].startswith("2024-04")
]
sorted(april_2024_trip_expenses, key=lambda expense: expense["cost"], reverse=True)

In [None]:
# Total owed share across trip-tagged expenses from April 2024.
april_2024_trip_total = sum(
    expense['cost']
    for expense in items
    if "trip" in expense["tags"] and expense["date"].startswith("2024-04")
)
print(april_2024_trip_total)

In [None]:
import pandas as pd

# One row per (expense, tag) pair, with a monthly period column. An expense carrying several tags
# therefore contributes its full cost to each of them.
df = (
    pd.DataFrame(items)
    .assign(date=lambda frame: pd.to_datetime(frame['date']))
    .assign(month=lambda frame: frame['date'].dt.to_period('M'))
    .explode('tags')
)
# Sum costs per month and tag, then pivot to a month x tag table (missing combinations -> 0).
monthly_data = df.groupby(['month', 'tags'])['cost'].sum().reset_index()
pivot_data = monthly_data.pivot(index='month', columns='tags', values='cost').fillna(0)
pivot_data

In [None]:
# Remove the catch-all "other" bucket and the "ignore" tag (presumably defined in tags.json) before
# plotting. errors="ignore" tolerates a tag that never occurs, and reassigning instead of inplace
# lets the cell be re-run without a KeyError.
pivot_data = pivot_data.drop(columns=['other', 'ignore'], errors='ignore')

In [None]:
# Optional: cap extreme outliers. Caution: capped cells no longer reflect the real totals, but the charts become more readable.
import matplotlib.pyplot as plt
import numpy as np

# Cap every cell above this quantile of all month x tag totals (months with zero spend included).
OUTLIER_QUANTILE = 0.99
threshold = np.quantile(pivot_data.to_numpy().ravel(), OUTLIER_QUANTILE)
print(f"Modifying any value above {threshold:.2f} to {threshold:.2f}. Cells that will be capped:")
# np.nonzero walks the matrix row by row, so cells are reported in the same order as before.
for row_pos, col_pos in zip(*np.nonzero(pivot_data.to_numpy() > threshold)):
    row_label = pivot_data.index[row_pos]
    col_label = pivot_data.columns[col_pos]
    value = pivot_data.iat[row_pos, col_pos]
    print(f"  Row: {row_label}, Column: {col_label}, Value: {value}")
pivot_data = pivot_data.clip(upper=threshold)

In [None]:
from scipy.ndimage import gaussian_filter1d

# Per-tag Gaussian smoothing strength, in samples (i.e. months for the monthly pivot).
# Tags not listed fall back to DEFAULT_SMOOTHING_SIGMA.
smoothing_sigma = {
    'needs': 2.5,
    'parents': 2.5,
    'trip': 2,
}
DEFAULT_SMOOTHING_SIGMA = 1.5

fig, ax = plt.subplots(figsize=(15, 9), dpi=300)
months = pivot_data.index.to_timestamp()

for tag in pivot_data.columns:
    # Gaussian smoothing acts like a weighted moving average, turning jagged monthly totals into
    # trend curves. Note: the plotted values are smoothed, not the raw monthly totals.
    sigma = smoothing_sigma.get(tag, DEFAULT_SMOOTHING_SIGMA)
    ax.plot(months, gaussian_filter1d(pivot_data[tag], sigma=sigma), label=tag)

ax.set_title('Monthly Aggregate Cost')
ax.set_xlabel('Month')
ax.set_ylabel('Total Cost')
ax.legend(title='Tags')
ax.minorticks_on()
# Major grid: solid and more visible; minor grid: dotted and lighter.
ax.grid(which='major', linestyle='-', linewidth=0.6, alpha=0.75)
ax.grid(which='minor', linestyle=':', linewidth=0.45, alpha=0.45)
fig.tight_layout()
fig.savefig("transactions.png", dpi=300)
plt.show()