In [None]:
import time
from datetime import datetime
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
from api_util import read_token, BASE_URL, get_headers
import requests

# Read the API token once; the request helpers below send it via get_headers(token).
token = read_token()

def _make_get_request(endpoint, params=None, quiet=None, timeout=60):
    """Shared helper: send a GET request to ``{BASE_URL}/{endpoint}`` and return the decoded JSON.

    Args:
        endpoint: Path relative to ``BASE_URL`` (no leading slash), e.g. ``'portfolio'``.
        params: Optional query-string parameters, passed through to ``requests.get``.
        quiet: When falsy (the default), the decoded response is printed.
        timeout: Seconds to wait for the server before giving up. ``None`` waits
            indefinitely, which was the previous behavior.

    Returns:
        The decoded JSON body, or ``None`` when the request raised a
        ``requests`` exception. The error is printed. In requests >= 2.27 an
        undecodable JSON body is also such an exception.
    """
    try:
        response = requests.get(
            f"{BASE_URL}/{endpoint}",
            headers=get_headers(token),
            params=params,
            # Without a timeout requests can block forever on an unresponsive server.
            timeout=timeout,
        )
        response.raise_for_status()
        result = response.json()
    except requests.exceptions.RequestException as err:
        print(f"Request failed: {err}")
        return None
    if not quiet:
        print(result)
    return result
    
def _make_post_request(endpoint, data, quiet=None, timeout=60):
    """Shared helper: POST ``data`` as JSON to ``{BASE_URL}/{endpoint}`` and return the decoded JSON.

    Write-side counterpart of ``_make_get_request``.

    Args:
        endpoint: Path relative to ``BASE_URL`` (no leading slash), e.g. ``'portfolio'``.
        data: JSON-serialisable request body.
        quiet: When falsy (the default), the decoded response is printed.
        timeout: Seconds to wait for the server before giving up. ``None`` waits
            indefinitely, which was the previous behavior.

    Returns:
        The decoded JSON body, or ``None`` when the request raised a
        ``requests`` exception. The error is printed.
    """
    try:
        response = requests.post(
            f"{BASE_URL}/{endpoint}",
            headers=get_headers(token),
            json=data,
            # Without a timeout requests can block forever on an unresponsive server.
            timeout=timeout,
        )
        response.raise_for_status()
        result = response.json()
    except requests.exceptions.RequestException as err:
        print(f"Request failed: {err}")
        return None
    if not quiet:
        print(result)
    return result

def get_stock_prices():
    """Fetch the ``market/stock/prices`` endpoint.

    Returns:
        The response's ``'data'`` field. Returns ``None`` if the request failed
        or the response has no ``'data'`` field.
    """
    res = _make_get_request('market/stock/prices')
    # _make_get_request returns None on failure; don't crash with None.get(...).
    if res is None:
        return None
    return res.get('data')

def get_stock_symbols():
    """Fetch the ``market/stock/symbols`` endpoint.

    Returns:
        The response's ``'data'`` field. Returns ``None`` if the request failed
        or the response has no ``'data'`` field.
    """
    res = _make_get_request('market/stock/symbols')
    # _make_get_request returns None on failure; don't crash with None.get(...).
    if res is None:
        return None
    return res.get('data')

def clone_db_transactions():
    """GET the ``transactions/clone`` endpoint and hand back the full decoded response (``None`` on failure)."""
    return _make_get_request('transactions/clone')

def get_all_transactions():
    """GET the ``transactions`` endpoint and hand back the full decoded response (``None`` on failure)."""
    return _make_get_request('transactions')

def delete_transactions(id=None, quiet=None, timeout=60):
    """Delete one transaction via ``DELETE {BASE_URL}/transactions/{id}``.

    Args:
        id: Identifier of the transaction to delete. The name shadows the
            built-in but is kept for keyword callers. When ``None``, nothing is
            sent and ``None`` is returned.
        quiet: When falsy (the default), the decoded response is printed.
        timeout: Seconds to wait for the server. ``None`` waits indefinitely.

    Returns:
        The decoded JSON body, ``{}`` for a successful response with an empty
        body, or ``None`` if no id was given or the request failed. Errors are
        printed.
    """
    if id is None:
        print("delete_transactions: no transaction id given; nothing deleted.")
        return None
    try:
        response = requests.delete(
            f"{BASE_URL}/transactions/{id}",
            headers=get_headers(token),
            timeout=timeout,
        )
        response.raise_for_status()
        # A successful DELETE may have no body (e.g. 204 No Content); calling
        # .json() on it would raise and be misreported as "Request failed".
        result = response.json() if response.content else {}
    except requests.exceptions.RequestException as err:
        print(f"Request failed: {err}")
        return None
    if not quiet:
        print(result)
    return result


def parse_date(date_str):
    """Normalize a ``YYYY-MM-DD`` date string.

    Args:
        date_str: Date text such as ``'2025-06-05'``. ``strptime`` also accepts
            single-digit month/day, e.g. ``'2025-6-5'``.

    Returns:
        The date re-formatted as zero-padded ``YYYY-MM-DD``. Returns ``None``
        when ``date_str`` is missing, not a string, or not a valid calendar date.
    """
    date_format = '%Y-%m-%d'
    try:
        parsed = datetime.strptime(date_str, date_format)
    except (TypeError, ValueError):
        # TypeError covers None / non-string input, so callers using
        # `parse_date(x) or <default>` get their fallback instead of a crash.
        return None
    return parsed.strftime(date_format)

def record_my_transactions(transactions, delay=0.5, timeout=60):
    """POST each transaction to the ``transactions`` endpoint, one request per entry.

    Args:
        transactions: Iterable of ``(stock_id, transaction_type, quantity, price,
            transaction_date)`` tuples. ``transaction_date`` is normalized with
            ``parse_date``. If that yields ``None``, today's date is used.
        delay: Seconds to sleep after each request. Default 0.5, as before.
        timeout: Per-request timeout in seconds passed to ``requests.post``.

    The outcome of every transaction is printed. A network error on one
    transaction is reported and the remaining transactions are still attempted.
    """
    for stock_id, transaction_type, quantity, price, transaction_date in transactions:
        data = {
            'stock_id': stock_id,
            'transaction_type': transaction_type,
            'quantity': quantity,
            'price': price,
            'transaction_date': parse_date(transaction_date) or datetime.now().strftime('%Y-%m-%d'),
        }
        # Per-item try: one failure must not silently abandon the rest of the batch.
        try:
            response = requests.post(
                f'{BASE_URL}/transactions',
                headers=get_headers(token),
                json=data,
                timeout=timeout,
            )
        except requests.exceptions.RequestException as err:
            print(f'Failed to insert: {data}, Error: {err}')
        else:
            if 200 <= response.status_code < 300:
                print(f'Successfully inserted: {data}')
            else:
                print(f'Failed to insert: {data}, Status code: {response.status_code}, Response: {response.text}')
        time.sleep(delay)

def get_my_portfolio():
    """Fetch the ``portfolio`` endpoint.

    Returns:
        The response's ``'data'`` field. Returns ``None`` if the request failed
        or the response has no ``'data'`` field.
    """
    res = _make_get_request('portfolio')
    # _make_get_request returns None on failure; don't crash with None.get(...).
    if res is None:
        return None
    return res.get('data')

def update_my_portfolio(params):
    """POST ``params`` as JSON to the ``portfolio`` endpoint; return the decoded reply (``None`` on failure)."""
    return _make_post_request('portfolio', params)

def draw_pie_chart(obj, title='Portfolio Distribution by Stock'):
    """Draw a pie chart of each key's share of the total of ``obj``'s values.

    Args:
        obj: Mapping of label -> size, e.g. stock id -> value.
        title: Chart title. The default is the previously hard-coded title, so
            existing calls render exactly as before. Pass e.g. "Cost" or
            "Market value" to tell multiple charts apart.
    """
    labels = list(obj.keys())
    sizes = list(obj.values())
    colors = plt.get_cmap('Set2').colors
    # Offset every wedge by 0.1 (fraction of the radius) for readability.
    explode = [0.1] * len(labels)

    fig, ax = plt.subplots(figsize=(5, 5))
    ax.pie(sizes, explode=explode, labels=labels, autopct='%1.1f%%', colors=colors, startangle=140)
    ax.set_title(title)
    plt.show()


In [None]:
# Transactions to send with record_my_transactions in the next cell.
# Each entry: (stock_id, transaction_type, quantity, price, transaction_date as 'YYYY-MM-DD').
transactions = [
	# ('NET', 'sell', 2, 175.9, '2025-06-05'),
]

In [None]:
record_my_transactions(transactions)
# Clear the list afterwards, presumably so re-running this cell does not insert the same
# transactions again. Note: it is cleared even if some inserts failed (check the printed output).
transactions = []

In [None]:
# Fetch the 'data' list from the market/stock/symbols endpoint and display it.
symbols = get_stock_symbols()
symbols

In [None]:
from decimal import Decimal, ROUND_HALF_UP

my_portfolio = get_my_portfolio()
stock_prices = get_stock_prices()
# If either fetch produced no data, stop here with a clear message instead of a
# TypeError further down when iterating over None.
if my_portfolio is None or stock_prices is None:
    raise RuntimeError('Could not load portfolio or stock prices; see the request output above.')
current_price_map = {item['symbol']: item['price'] for item in stock_prices}

# Amount shown as the 'USD' slice in the charts below (added by process_portfolio).
# NOTE(review): sum of two hard-coded figures, presumably USD balances; document their source.
PORTFOLIO_USD = 16478 + 9176

def validate_portfolio_item(item):
    """Check that a portfolio entry has usable ``stock_id``, ``quantity`` and ``average_price``.

    Args:
        item: One entry of the portfolio list, a dict-like object.

    Raises:
        ValueError: In any of these cases:
            - a required key is missing;
            - ``quantity`` or ``average_price`` cannot be converted to ``float``;
            - either value is negative.
    """
    required_keys = ['stock_id', 'quantity', 'average_price']
    if not all(key in item for key in required_keys):
        raise ValueError(f"Invalid portfolio item: {item}")
    try:
        quantity = float(item['quantity'])
        average_price = float(item['average_price'])
    except (ValueError, TypeError) as err:
        raise ValueError(f"Invalid data type in item: {item}") from err
    # Checked outside the try: inside it, this ValueError was caught by the
    # except above and misreported as "Invalid data type".
    if quantity < 0 or average_price < 0:
        raise ValueError(f"Negative quantity or price in item: {item}")
    
def process_portfolio(portfolio_dict, usd_value):
    """Build a chart-ready mapping: the input plus a ``'USD'`` entry, positive values only, largest first.

    Args:
        portfolio_dict: Mapping of stock id -> value. It is not modified.
        usd_value: Value stored under the ``'USD'`` key. It overrides an existing
            ``'USD'`` entry, keeping that entry's position.

    Returns:
        A new dict with only the entries whose value is > 0, ordered by value
        descending. Ties keep their original order.
    """
    # Copy instead of writing 'USD' into the caller's dict.
    combined = {**portfolio_dict, 'USD': usd_value}
    positive = [(key, value) for key, value in combined.items() if value > 0]
    positive.sort(key=lambda pair: pair[1], reverse=True)
    return dict(positive)

portfolio_values = {}
portfolio_with_current_price = {}
CENT = Decimal('0.01')

for item in my_portfolio:
    validate_portfolio_item(item)
    stock_id = item['stock_id']
    quantity = Decimal(str(item['quantity']))
    average_price = Decimal(str(item['average_price']))
    if stock_id in current_price_map:
        current_price = Decimal(str(current_price_map[stock_id]))
    else:
        # No market price for this symbol: value it at cost, but say so instead of
        # silently making its "current" slice identical to its cost slice.
        print(f"Warning: no current price for {stock_id}; using average price {item['average_price']}")
        current_price = average_price

    # Cost basis: quantity x average price, rounded half-up to cents.
    total_cost = (quantity * average_price).quantize(CENT, rounding=ROUND_HALF_UP)
    portfolio_values[stock_id] = portfolio_values.get(stock_id, 0) + float(total_cost)

    # Market value: quantity x current price, rounded half-up to cents.
    total_current = (quantity * current_price).quantize(CENT, rounding=ROUND_HALF_UP)
    portfolio_with_current_price[stock_id] = portfolio_with_current_price.get(stock_id, 0) + float(total_current)

    print(f"{stock_id}: {float(quantity):>6.2f} * ${float(average_price):>8.2f} = ${float(total_cost):>10.2f}")

# Add the 'USD' entry, keep positive values only, and sort largest first.
sorted_portfolio_values = process_portfolio(portfolio_values, PORTFOLIO_USD)
sorted_current_values = process_portfolio(portfolio_with_current_price, PORTFOLIO_USD)

print('成本：')  # "Cost:"
draw_pie_chart(sorted_portfolio_values)
print('現在價位：')  # "Current price:"
draw_pie_chart(sorted_current_values)

In [None]:

# Fields to POST to the portfolio endpoint via update_my_portfolio; leave empty to skip.
data = {
    # 'stock_id': 'TSLA',
    # 'average_price': 348,
    # 'quantity': 9
}
# Only send when something was filled in, so Run All does not POST an empty {}.
if data:
    update_my_portfolio(data)

In [None]:
# GET transactions/clone. NOTE(review): the server-side effect of "clone" is not visible here; confirm before Run All.
clone_db_transactions()

In [None]:
# GET /transactions; the decoded response is printed and shown as the cell output.
get_all_transactions()

In [None]:
# Id of the transaction to delete. None means nothing is deleted (safe for Run All).
TRANSACTION_ID_TO_DELETE = None
if TRANSACTION_ID_TO_DELETE is not None:
    delete_transactions(TRANSACTION_ID_TO_DELETE)