Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,5 @@ cython_debug/
# VS Code
.vscode/

# uv local dev dependencies
uv.lock
pyproject.toml
# Dev dependencies
ruff.toml
2 changes: 2 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
author_email = 'guillaume@pommepause.com',
url = 'https://github.com/gboudreau/ws-api-python',
keywords = ['wealthsimple'],
python_requires='>=3.10',
install_requires = [
'requests',
],
Expand All @@ -19,5 +20,6 @@
'Operating System :: OS Independent',
'License :: OSI Approved :: GNU General Public License v3 or later (GPLv3+)',
'Programming Language :: Python :: 3',
"Programming Language :: Python :: 3.13"
],
)
8 changes: 4 additions & 4 deletions ws_api/session.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import ABC
import json
from abc import ABC


class OAuthSession(ABC):
Expand All @@ -13,7 +13,7 @@ class OAuthSession(ABC):
"""

def __init__(self):
self.client_id = None
self.client_id: str | None = None
self.access_token = None
self.refresh_token = None

Expand All @@ -31,8 +31,8 @@ class WSAPISession(OAuthSession):

def __init__(self):
super().__init__()
self.session_id = None
self.wssdi = None
self.session_id: str | None = None
self.wssdi: str | None = None
self.token_info = None

def to_json(self):
Expand Down
111 changes: 60 additions & 51 deletions ws_api/wealthsimple_api.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import re
import uuid
from collections.abc import Callable
from datetime import datetime, timedelta
from inspect import signature
from typing import Any, Callable, Optional, Union
from typing import Any

import requests

Expand Down Expand Up @@ -37,24 +38,24 @@ class WealthsimpleAPIBase:
'FetchBrokerageMonthlyStatementTransactions': "query FetchBrokerageMonthlyStatementTransactions($period: String!, $accountId: String!) {\n brokerageMonthlyStatements(period: $period, accountId: $accountId) {\n id\n statementType\n createdAt\n data {\n ... on BrokerageMonthlyStatementObject {\n ...BrokerageMonthlyStatementObject\n __typename\n }\n __typename\n }\n __typename\n }\n}\n\nfragment BrokerageMonthlyStatementObject on BrokerageMonthlyStatementObject {\n custodianAccountId\n activitiesPerCurrency {\n currency\n currentTransactions {\n ...BrokerageMonthlyStatementTransactions\n __typename\n }\n __typename\n }\n currentTransactions {\n ...BrokerageMonthlyStatementTransactions\n __typename\n }\n isMultiCurrency\n __typename\n}\n\nfragment BrokerageMonthlyStatementTransactions on BrokerageMonthlyStatementTransactions {\n balance\n cashMovement\n unit\n description\n transactionDate\n transactionType\n __typename\n}",
}

def __init__(self, sess: Optional[WSAPISession] = None):
def __init__(self, sess: WSAPISession | None = None):
self.security_market_data_cache_getter = None
self.security_market_data_cache_setter = None
self.session = WSAPISession()
self.start_session(sess)

user_agent = None
user_agent: str | None = None

@staticmethod
def set_user_agent(user_agent: str):
def set_user_agent(user_agent: str) -> None:
WealthsimpleAPI.user_agent = user_agent

@staticmethod
def uuidv4() -> str:
return str(uuid.uuid4())

def send_http_request(
self, url: str, method: str = 'POST', data: Optional[dict] = None, headers: Optional[dict] = None, return_headers: bool = False
self, url: str, method: str = 'POST', data: dict | None = None, headers: dict | None = None, return_headers: bool = False
) -> Any:
headers = headers or {}
if method == 'POST':
Expand All @@ -77,20 +78,20 @@ def send_http_request(

if return_headers:
# Combine headers and body as a single string
headers = '\r\n'.join(f"{k}: {v}" for k, v in response.headers.items())
return f"{headers}\r\n\r\n{response.text}"
response_headers = '\r\n'.join(f"{k}: {v}" for k, v in response.headers.items())
return f"{response_headers}\r\n\r\n{response.text}"

return response.json()
except requests.exceptions.RequestException as e:
raise CurlException(f"HTTP request failed: {e}")

def send_get(self, url: str, headers: Optional[dict] = None, return_headers: bool = False) -> Any:
def send_get(self, url: str, headers: dict | None = None, return_headers: bool = False) -> Any:
return self.send_http_request(url, 'GET', headers=headers, return_headers=return_headers)

def send_post(self, url: str, data: dict, headers: Optional[dict] = None, return_headers: bool = False) -> Any:
def send_post(self, url: str, data: dict, headers: dict | None = None, return_headers: bool = False) -> Any:
return self.send_http_request(url, 'POST', data=data, headers=headers, return_headers=return_headers)

def start_session(self, sess: WSAPISession = None):
def start_session(self, sess: WSAPISession | None = None):
if sess:
self.session.access_token = sess.access_token
self.session.wssdi = sess.wssdi
Expand Down Expand Up @@ -138,16 +139,25 @@ def start_session(self, sess: WSAPISession = None):
if not self.session.session_id:
self.session.session_id = str(uuid.uuid4())

def check_oauth_token(self, persist_session_fct: Optional[Callable] = None, username = None):
def search_security(self, query):
# Fetch security search results using GraphQL query
return self.do_graphql_query(
'FetchSecuritySearchResult',
{'query': query},
'securitySearch.results',
'array',
)

def check_oauth_token(self, persist_session_fct: Callable | None = None, username = None):
if self.session.access_token:
try:
# noinspection PyUnresolvedReferences
self.search_security('XEQT')
return
except WSApiException as e:
if e.response['message'] != 'Not Authorized.':
raise e
if e.response is None or e.response.get('message') != 'Not Authorized.':
raise
# Access token expired; try to refresh it below
else:
return

if self.session.refresh_token:
data = {
Expand All @@ -161,7 +171,7 @@ def check_oauth_token(self, persist_session_fct: Optional[Callable] = None, user
}
response = self.send_post(f"{self.OAUTH_BASE_URL}/token", data, headers)
if 'access_token' not in response or 'refresh_token' not in response:
raise ManualLoginRequired(f"OAuth token invalid and cannot be refreshed: {response['error'] if 'error' in response else 'Invalid response from API'}")
raise ManualLoginRequired(f"OAuth token invalid and cannot be refreshed: {response.get('error', 'Invalid response from API')}")
self.session.access_token = response['access_token']
self.session.refresh_token = response['refresh_token']
if persist_session_fct:
Expand All @@ -179,8 +189,8 @@ def check_oauth_token(self, persist_session_fct: Optional[Callable] = None, user
def login_internal(self,
username: str,
password: str,
otp_answer: Optional[str] = None,
persist_session_fct: Optional[Callable] = None,
otp_answer: str | None = None,
persist_session_fct: Callable | None = None,
scope: str = SCOPE_READ_ONLY
) -> WSAPISession:
data = {
Expand Down Expand Up @@ -228,7 +238,7 @@ def login_internal(self,
return self.session

def do_graphql_query(self, query_name: str, variables: dict, data_response_path: str, expect_type: str,
filter_fn: Optional[Callable[[Any], bool]] = None, *, load_all_pages: bool = False):
filter_fn: Callable[[Any], bool] | None = None, *, load_all_pages: bool = False):
query = {
'operationName': query_name,
'query': self.GRAPHQL_QUERIES[query_name],
Expand Down Expand Up @@ -305,8 +315,8 @@ def get_token_info(self):
def login(
username: str,
password: str,
otp_answer: Optional[str] = None,
persist_session_fct: Optional[Callable] = None,
otp_answer: str | None = None,
persist_session_fct: Callable | None = None,
scope: str = SCOPE_READ_ONLY
) -> WSAPISession:
"""Login to Wealthsimple API and return a session object.
Expand All @@ -329,13 +339,13 @@ def login(
return ws.login_internal(username, password, otp_answer, persist_session_fct, scope)

@staticmethod
def from_token(sess: WSAPISession, persist_session_fct: callable = None, username: Optional[str] = None):
def from_token(sess: WSAPISession, persist_session_fct: Callable | None = None, username: str | None = None):
ws = WealthsimpleAPI(sess)
ws.check_oauth_token(persist_session_fct, username)
return ws

class WealthsimpleAPI(WealthsimpleAPIBase):
def __init__(self, sess: WSAPISession = None):
def __init__(self, sess: WSAPISession | None = None) -> None:
super().__init__(sess)
self.account_cache = {}

Expand Down Expand Up @@ -417,7 +427,7 @@ def get_account_balances(self, account_id):
for account in accounts[0]['custodianAccounts']:
for balance in account['financials']['balance']:
security = balance['securityId']
if security != 'sec-c-cad' and security != 'sec-c-usd':
if security not in {'sec-c-cad', 'sec-c-usd'}:
security = self.security_id_to_symbol(security)
balances[security] = balance['quantity']

Expand Down Expand Up @@ -457,12 +467,12 @@ def get_identity_historical_financials(self, account_ids = None, currency: str =

def get_activities(
self,
account_id: Union[str, list[str]],
account_id: str | list[str],
how_many: int = 50,
order_by: str = 'OCCURRED_AT_DESC',
ignore_rejected: bool = True,
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None,
start_date: datetime | None = None,
end_date: datetime | None = None,
load_all: bool = False
) -> list[Any]:
"""Retrieve activities for a specific account or list of accounts.
Expand Down Expand Up @@ -587,24 +597,30 @@ def _activity_add_description(self, act):
type_ = act['type'].lower().capitalize()
direction = 'from' if act['type'] == 'DEPOSIT' else 'to'
prop = 'source' if act['type'] == 'DEPOSIT' else 'destination'
bank_account = details[prop]['bankAccount']
nickname = bank_account.get('nickname')
if isinstance(details, dict):
bank_account_info = details.get(prop, {})
if isinstance(bank_account_info, dict):
bank_account = bank_account_info.get('bankAccount', {})
nickname = bank_account.get('nickname')
account_number = bank_account.get('accountNumber')
if not nickname:
nickname = bank_account['accountName']
act['description'] = (
f"{type_}: EFT {direction} {nickname} {bank_account['accountNumber']}"
)
nickname = bank_account.get('accountName')
act['description'] = f"{type_}: EFT {direction} {nickname} {account_number}"

elif act['type'] == 'REFUND' and act['subType'] == 'TRANSFER_FEE_REFUND':
act['description'] = "Reimbursement: account transfer fee"

elif act['type'] == 'INSTITUTIONAL_TRANSFER_INTENT' and act['subType'] == 'TRANSFER_IN':
details = self.get_transfer_details(act['externalCanonicalId'])
verb = details['transferType'].replace('_', '-').capitalize()
if isinstance(details, dict):
verb = details['transferType'].replace('_', '-').capitalize()
client_account_type = details['clientAccountType'].upper()
institution_name = details['institutionName']
redacted_account_number = details['redactedInstitutionAccountNumber']
act['description'] = (
f"Institutional transfer: {verb} {details['clientAccountType'].upper()} "
f"account transfer from {details['institutionName']} "
f"****{details['redactedInstitutionAccountNumber']}"
f"Institutional transfer: {verb} {client_account_type} "
f"account transfer from {institution_name} "
f"****{redacted_account_number}"
)
elif act['type'] == 'INSTITUTIONAL_TRANSFER_INTENT' and act['subType'] == 'TRANSFER_OUT':
act['description'] = (
Expand Down Expand Up @@ -646,8 +662,8 @@ def _activity_add_description(self, act):

elif act['type'] == 'P2P_PAYMENT' and act['subType'] in ('SEND', 'SEND_RECEIVED'):
direction = 'sent to' if act['subType'] == 'SEND' else 'received from'
p2pHandle = act['p2pHandle']
act['description'] = f"Cash {direction} {p2pHandle}"
p2p_handle = act['p2pHandle']
act['description'] = f"Cash {direction} {p2p_handle}"

elif act['type'] == 'PROMOTION' and act['subType'] == 'INCENTIVE_BONUS':
type_ = act['type'].capitalize()
Expand Down Expand Up @@ -685,7 +701,7 @@ def security_id_to_symbol(self, security_id: str) -> str:
security_symbol = f"[{security_id}]"
if self.security_market_data_cache_getter:
market_data = self.get_security_market_data(security_id)
if market_data and 'stock' in market_data and market_data['stock']:
if isinstance(market_data, dict) and market_data['stock']:
stock = market_data['stock']
security_symbol = f"{stock['primaryExchange']}:{stock['symbol']}"
return security_symbol
Expand All @@ -706,7 +722,7 @@ def get_transfer_details(self, transfer_id):
'object',
)

def set_security_market_data_cache(self, security_market_data_cache_getter: callable, security_market_data_cache_setter: callable):
def set_security_market_data_cache(self, security_market_data_cache_getter: Callable, security_market_data_cache_setter: Callable) -> None:
self.security_market_data_cache_getter = security_market_data_cache_getter
self.security_market_data_cache_setter = security_market_data_cache_setter

Expand All @@ -731,14 +747,7 @@ def get_security_market_data(self, security_id: str, use_cache: bool = True):

return value

def search_security(self, query):
# Fetch security search results using GraphQL query
return self.do_graphql_query(
'FetchSecuritySearchResult',
{'query': query},
'securitySearch.results',
'array',
)


def get_security_historical_quotes(self, security_id, time_range='1m'):
# Fetch historical quotes for a security using GraphQL query
Expand All @@ -765,7 +774,7 @@ def get_corporate_action_child_activities(self, activity_canonical_id):

def get_statement_transactions(self, account_id: str, period: str) -> list[Any]:
"""Retrieve transactions from account monthly statement.

Args:
account_id (str): The account ID to retrieve transactions for.
period (str): The statement start date in 'YYYY-MM-DD' format.
Expand Down Expand Up @@ -795,5 +804,5 @@ def get_statement_transactions(self, account_id: str, period: str) -> list[Any]:

if not isinstance(transactions, list):
raise WSApiException(f"Unexpected response format: {self.get_statement_transactions.__name__}", transactions)

return transactions