Skip to content

Commit

Permalink
Add some type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
hroff-1902 committed Feb 2, 2020
1 parent 2396f35 commit f3d5000
Show file tree
Hide file tree
Showing 28 changed files with 114 additions and 100 deletions.
10 changes: 6 additions & 4 deletions freqtrade/commands/data_commands.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
import sys
from typing import Any, Dict, List
from typing import Any, Dict, List, cast

import arrow

Expand Down Expand Up @@ -43,16 +43,18 @@ def start_download_data(args: Dict[str, Any]) -> None:
if config.get('download_trades'):
pairs_not_available = refresh_backtest_trades_data(
exchange, pairs=config["pairs"], datadir=config['datadir'],
timerange=timerange, erase=config.get("erase"))
timerange=timerange, erase=cast(bool, config.get("erase")))

# Convert downloaded trade data to different timeframes
convert_trades_to_ohlcv(
pairs=config["pairs"], timeframes=config["timeframes"],
datadir=config['datadir'], timerange=timerange, erase=config.get("erase"))
datadir=config['datadir'], timerange=timerange,
erase=cast(bool, config.get("erase")))
else:
pairs_not_available = refresh_backtest_ohlcv_data(
exchange, pairs=config["pairs"], timeframes=config["timeframes"],
datadir=config['datadir'], timerange=timerange, erase=config.get("erase"))
datadir=config['datadir'], timerange=timerange,
erase=cast(bool, config.get("erase")))

except KeyboardInterrupt:
sys.exit("SIGINT received, aborting ...")
Expand Down
4 changes: 2 additions & 2 deletions freqtrade/commands/deploy_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def start_create_userdir(args: Dict[str, Any]) -> None:
sys.exit(1)


def deploy_new_strategy(strategy_name, strategy_path: Path, subtemplate: str):
def deploy_new_strategy(strategy_name: str, strategy_path: Path, subtemplate: str) -> None:
"""
Deploy new strategy from template to strategy_path
"""
Expand Down Expand Up @@ -69,7 +69,7 @@ def start_new_strategy(args: Dict[str, Any]) -> None:
raise OperationalException("`new-strategy` requires --strategy to be set.")


def deploy_new_hyperopt(hyperopt_name, hyperopt_path: Path, subtemplate: str):
def deploy_new_hyperopt(hyperopt_name: str, hyperopt_path: Path, subtemplate: str) -> None:
"""
Deploys a new hyperopt template to hyperopt_path
"""
Expand Down
2 changes: 1 addition & 1 deletion freqtrade/commands/plot_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from freqtrade.state import RunMode


def validate_plot_args(args: Dict[str, Any]):
def validate_plot_args(args: Dict[str, Any]) -> None:
if not args.get('datadir') and not args.get('config'):
raise OperationalException(
"You need to specify either `--datadir` or `--config` "
Expand Down
2 changes: 1 addition & 1 deletion freqtrade/configuration/check_exchange.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
logger = logging.getLogger(__name__)


def remove_credentials(config: Dict[str, Any]):
def remove_credentials(config: Dict[str, Any]) -> None:
"""
Removes exchange keys from the configuration and specifies dry-run
Used for backtesting / hyperopt / edge and utils.
Expand Down
4 changes: 2 additions & 2 deletions freqtrade/configuration/deprecated_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

def check_conflicting_settings(config: Dict[str, Any],
section1: str, name1: str,
section2: str, name2: str):
section2: str, name2: str) -> None:
section1_config = config.get(section1, {})
section2_config = config.get(section2, {})
if name1 in section1_config and name2 in section2_config:
Expand All @@ -28,7 +28,7 @@ def check_conflicting_settings(config: Dict[str, Any],

def process_deprecated_setting(config: Dict[str, Any],
section1: str, name1: str,
section2: str, name2: str):
section2: str, name2: str) -> None:
section2_config = config.get(section2, {})

if name2 in section2_config:
Expand Down
2 changes: 1 addition & 1 deletion freqtrade/configuration/directory_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def create_datadir(config: Dict[str, Any], datadir: Optional[str] = None) -> Pat
return folder


def create_userdata_dir(directory: str, create_dir=False) -> Path:
def create_userdata_dir(directory: str, create_dir: bool = False) -> Path:
"""
Create userdata directory structure.
if create_dir is True, then the parent-directory will be created if it does not exist.
Expand Down
5 changes: 3 additions & 2 deletions freqtrade/configuration/timerange.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import arrow


logger = logging.getLogger(__name__)


Expand All @@ -30,7 +31,7 @@ def __eq__(self, other):
return (self.starttype == other.starttype and self.stoptype == other.stoptype
and self.startts == other.startts and self.stopts == other.stopts)

def subtract_start(self, seconds) -> None:
def subtract_start(self, seconds: int) -> None:
"""
Subtracts <seconds> from startts if startts is set.
:param seconds: Seconds to subtract from starttime
Expand Down Expand Up @@ -59,7 +60,7 @@ def adjust_start_if_necessary(self, timeframe_secs: int, startup_candles: int,
self.starttype = 'date'

@staticmethod
def parse_timerange(text: Optional[str]):
def parse_timerange(text: Optional[str]) -> 'TimeRange':
"""
Parse the value of the argument --timerange to determine what is the range desired
:param text: value from --timerange
Expand Down
7 changes: 4 additions & 3 deletions freqtrade/data/btanalysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"""
import logging
from pathlib import Path
from typing import Dict
from typing import Dict, Union

import numpy as np
import pandas as pd
Expand All @@ -20,7 +20,7 @@
"open_rate", "close_rate", "open_at_end", "sell_reason"]


def load_backtest_data(filename) -> pd.DataFrame:
def load_backtest_data(filename: Union[Path, str]) -> pd.DataFrame:
"""
Load backtest data file.
:param filename: pathlib.Path object, or string pointing to the file.
Expand Down Expand Up @@ -151,7 +151,8 @@ def extract_trades_of_period(dataframe: pd.DataFrame, trades: pd.DataFrame) -> p
return trades


def combine_tickers_with_mean(tickers: Dict[str, pd.DataFrame], column: str = "close"):
def combine_tickers_with_mean(tickers: Dict[str, pd.DataFrame],
column: str = "close") -> pd.DataFrame:
"""
Combine multiple dataframes "column"
:param tickers: Dict of Dataframes, dict key should be pair.
Expand Down
12 changes: 6 additions & 6 deletions freqtrade/data/history.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def load_tickerdata_file(datadir: Path, pair: str, timeframe: str,


def store_tickerdata_file(datadir: Path, pair: str,
timeframe: str, data: list, is_zip: bool = False):
timeframe: str, data: list, is_zip: bool = False) -> None:
"""
Stores tickerdata to file
"""
Expand All @@ -109,15 +109,15 @@ def load_trades_file(datadir: Path, pair: str,


def store_trades_file(datadir: Path, pair: str,
data: list, is_zip: bool = True):
data: list, is_zip: bool = True) -> None:
"""
Stores tickerdata to file
"""
filename = pair_trades_filename(datadir, pair)
misc.file_dump_json(filename, data, is_zip=is_zip)


def _validate_pairdata(pair, pairdata, timerange: TimeRange):
def _validate_pairdata(pair: str, pairdata: List[Dict], timerange: TimeRange) -> None:
if timerange.starttype == 'date' and pairdata[0][0] > timerange.startts * 1000:
logger.warning('Missing data at start for pair %s, data starts at %s',
pair, arrow.get(pairdata[0][0] // 1000).strftime('%Y-%m-%d %H:%M:%S'))
Expand Down Expand Up @@ -331,7 +331,7 @@ def _download_pair_history(datadir: Path,

def refresh_backtest_ohlcv_data(exchange: Exchange, pairs: List[str], timeframes: List[str],
datadir: Path, timerange: Optional[TimeRange] = None,
erase=False) -> List[str]:
erase: bool = False) -> List[str]:
"""
Refresh stored ohlcv data for backtesting and hyperopt operations.
Used by freqtrade download-data subcommand.
Expand Down Expand Up @@ -401,7 +401,7 @@ def _download_trades_history(datadir: Path,


def refresh_backtest_trades_data(exchange: Exchange, pairs: List[str], datadir: Path,
timerange: TimeRange, erase=False) -> List[str]:
timerange: TimeRange, erase: bool = False) -> List[str]:
"""
Refresh stored trades data for backtesting and hyperopt operations.
Used by freqtrade download-data subcommand.
Expand All @@ -428,7 +428,7 @@ def refresh_backtest_trades_data(exchange: Exchange, pairs: List[str], datadir:


def convert_trades_to_ohlcv(pairs: List[str], timeframes: List[str],
datadir: Path, timerange: TimeRange, erase=False) -> None:
datadir: Path, timerange: TimeRange, erase: bool = False) -> None:
"""
Convert stored trades data to ohlcv data
"""
Expand Down
4 changes: 2 additions & 2 deletions freqtrade/edge/edge_positioning.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# pragma pylint: disable=W0603
""" Edge positioning package """
import logging
from typing import Any, Dict, NamedTuple
from typing import Any, Dict, List, NamedTuple

import arrow
import numpy as np
Expand Down Expand Up @@ -181,7 +181,7 @@ def stoploss(self, pair: str) -> float:
'strategy stoploss is returned instead.')
return self.strategy.stoploss

def adjust(self, pairs) -> list:
def adjust(self, pairs: List[str]) -> list:
"""
Filters out and sorts "pairs" according to Edge calculated pairs
"""
Expand Down
45 changes: 26 additions & 19 deletions freqtrade/exchange/exchange.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,12 @@
from freqtrade.exchange.common import BAD_EXCHANGES, retrier, retrier_async
from freqtrade.misc import deep_merge_dicts


# Should probably use typing.Literal when we switch to python 3.8+
# CcxtModuleType = Literal[ccxt, ccxt_async]
CcxtModuleType = Any


logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -51,7 +57,7 @@ class Exchange:
}
_ft_has: Dict = {}

def __init__(self, config: dict, validate: bool = True) -> None:
def __init__(self, config: Dict[str, Any], validate: bool = True) -> None:
"""
Initializes this module with the given config,
it does basic validation whether the specified exchange and pairs are valid.
Expand Down Expand Up @@ -135,7 +141,7 @@ def __del__(self):
if self._api_async and inspect.iscoroutinefunction(self._api_async.close):
asyncio.get_event_loop().run_until_complete(self._api_async.close())

def _init_ccxt(self, exchange_config: dict, ccxt_module=ccxt,
def _init_ccxt(self, exchange_config: Dict[str, Any], ccxt_module: CcxtModuleType = ccxt,
ccxt_kwargs: dict = None) -> ccxt.Exchange:
"""
Initialize ccxt with given config and return valid
Expand Down Expand Up @@ -224,13 +230,13 @@ def get_quote_currencies(self) -> List[str]:
markets = self.markets
return sorted(set([x['quote'] for _, x in markets.items()]))

def klines(self, pair_interval: Tuple[str, str], copy=True) -> DataFrame:
def klines(self, pair_interval: Tuple[str, str], copy: bool = True) -> DataFrame:
if pair_interval in self._klines:
return self._klines[pair_interval].copy() if copy else self._klines[pair_interval]
else:
return DataFrame()

def set_sandbox(self, api, exchange_config: dict, name: str):
def set_sandbox(self, api: ccxt.Exchange, exchange_config: dict, name: str) -> None:
if exchange_config.get('sandbox'):
if api.urls.get('test'):
api.urls['api'] = api.urls['test']
Expand All @@ -240,7 +246,7 @@ def set_sandbox(self, api, exchange_config: dict, name: str):
"Please check your config.json")
raise OperationalException(f'Exchange {name} does not provide a sandbox api')

def _load_async_markets(self, reload=False) -> None:
def _load_async_markets(self, reload: bool = False) -> None:
try:
if self._api_async:
asyncio.get_event_loop().run_until_complete(
Expand Down Expand Up @@ -273,7 +279,7 @@ def _reload_markets(self) -> None:
except ccxt.BaseError:
logger.exception("Could not reload markets.")

def validate_stakecurrency(self, stake_currency) -> None:
def validate_stakecurrency(self, stake_currency: str) -> None:
"""
Checks stake-currency against available currencies on the exchange.
:param stake_currency: Stake-currency to validate
Expand Down Expand Up @@ -319,7 +325,7 @@ def validate_pairs(self, pairs: List[str]) -> None:
f"Please check if you are impacted by this restriction "
f"on the exchange and eventually remove {pair} from your whitelist.")

def get_valid_pair_combination(self, curr_1, curr_2) -> str:
def get_valid_pair_combination(self, curr_1: str, curr_2: str) -> str:
"""
Get valid pair combination of curr_1 and curr_2 by trying both combinations.
"""
Expand Down Expand Up @@ -373,7 +379,7 @@ def validate_order_time_in_force(self, order_time_in_force: Dict) -> None:
raise OperationalException(
f'Time in force policies are not supported for {self.name} yet.')

def validate_required_startup_candles(self, startup_candles) -> None:
def validate_required_startup_candles(self, startup_candles: int) -> None:
"""
Checks if required startup_candles is more than ohlcv_candle_limit.
Requires a grace-period of 5 candles - so a startup-period up to 494 is allowed by default.
Expand All @@ -392,7 +398,7 @@ def exchange_has(self, endpoint: str) -> bool:
"""
return endpoint in self._api.has and self._api.has[endpoint]

def amount_to_precision(self, pair, amount: float) -> float:
def amount_to_precision(self, pair: str, amount: float) -> float:
'''
Returns the amount to buy or sell to a precision the Exchange accepts
Reimplementation of ccxt internal methods - ensuring we can test the result is correct
Expand All @@ -406,7 +412,7 @@ def amount_to_precision(self, pair, amount: float) -> float:

return amount

def price_to_precision(self, pair, price: float) -> float:
def price_to_precision(self, pair: str, price: float) -> float:
'''
Returns the price rounded up to the precision the Exchange accepts.
Partial Reimplementation of ccxt internal method decimal_to_precision(),
Expand Down Expand Up @@ -494,7 +500,7 @@ def create_order(self, pair: str, ordertype: str, side: str, amount: float,
raise OperationalException(e) from e

def buy(self, pair: str, ordertype: str, amount: float,
rate: float, time_in_force) -> Dict:
rate: float, time_in_force: str) -> Dict:

if self._config['dry_run']:
dry_order = self.dry_run_order(pair, ordertype, "buy", amount, rate)
Expand All @@ -507,7 +513,7 @@ def buy(self, pair: str, ordertype: str, amount: float,
return self.create_order(pair, ordertype, 'buy', amount, rate, params)

def sell(self, pair: str, ordertype: str, amount: float,
rate: float, time_in_force='gtc') -> Dict:
rate: float, time_in_force: str = 'gtc') -> Dict:

if self._config['dry_run']:
dry_order = self.dry_run_order(pair, ordertype, "sell", amount, rate)
Expand Down Expand Up @@ -976,8 +982,8 @@ def get_trades_for_order(self, order_id: str, pair: str, since: datetime) -> Lis
raise OperationalException(e) from e

@retrier
def get_fee(self, symbol, type='', side='', amount=1,
price=1, taker_or_maker='maker') -> float:
def get_fee(self, symbol: str, type: str = '', side: str = '', amount: float = 1,
price: float = 1, taker_or_maker: str = 'maker') -> float:
try:
# validate that markets are loaded before trying to get fee
if self._api.markets is None or len(self._api.markets) == 0:
Expand All @@ -1000,22 +1006,22 @@ def get_exchange_bad_reason(exchange_name: str) -> str:
return BAD_EXCHANGES.get(exchange_name, "")


def is_exchange_known_ccxt(exchange_name: str, ccxt_module=None) -> bool:
def is_exchange_known_ccxt(exchange_name: str, ccxt_module: CcxtModuleType = None) -> bool:
return exchange_name in ccxt_exchanges(ccxt_module)


def is_exchange_officially_supported(exchange_name: str) -> bool:
return exchange_name in ['bittrex', 'binance']


def ccxt_exchanges(ccxt_module=None) -> List[str]:
def ccxt_exchanges(ccxt_module: CcxtModuleType = None) -> List[str]:
"""
Return the list of all exchanges known to ccxt
"""
return ccxt_module.exchanges if ccxt_module is not None else ccxt.exchanges


def available_exchanges(ccxt_module=None) -> List[str]:
def available_exchanges(ccxt_module: CcxtModuleType = None) -> List[str]:
"""
Return exchanges available to the bot, i.e. non-bad exchanges in the ccxt list
"""
Expand Down Expand Up @@ -1075,7 +1081,8 @@ def timeframe_to_next_date(timeframe: str, date: datetime = None) -> datetime:
return datetime.fromtimestamp(new_timestamp, tz=timezone.utc)


def symbol_is_pair(market_symbol: str, base_currency: str = None, quote_currency: str = None):
def symbol_is_pair(market_symbol: str, base_currency: str = None,
quote_currency: str = None) -> bool:
"""
Check if the market symbol is a pair, i.e. that its symbol consists of the base currency and the
quote currency separated by '/' character. If base_currency and/or quote_currency is passed,
Expand All @@ -1088,7 +1095,7 @@ def symbol_is_pair(market_symbol: str, base_currency: str = None, quote_currency
(symbol_parts[1] == quote_currency if quote_currency else len(symbol_parts[1]) > 0))


def market_is_active(market):
def market_is_active(market: Dict) -> bool:
"""
Return True if the market is active.
"""
Expand Down

0 comments on commit f3d5000

Please sign in to comment.