Skip to content

Commit

Permalink
Parallelize statcast calls (#168)
Browse files Browse the repository at this point in the history
* Parallelize testing

* Parallelize statcast calls; 1 call per day for future caching

* Add tqdm progress to the statcast parallelization
  • Loading branch information
TheCleric committed Nov 6, 2020
1 parent acc9701 commit 333430d
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 66 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/pytest.yml
Expand Up @@ -22,7 +22,7 @@ jobs:
python -m pip install --upgrade pip
pip install -e .[test]
- name: Run tests
run: pytest tests --doctest-modules --cov=pybaseball --cov-report term-missing
run: pytest tests --doctest-modules --cov=pybaseball --cov-report term-missing -n 5
- name: Timing test
timeout-minutes: 2
run: python -m scripts.statcast_timing
Expand Down
47 changes: 23 additions & 24 deletions pybaseball/statcast.py
@@ -1,13 +1,15 @@
from datetime import date, timedelta
from typing import List, Optional, Union
import concurrent.futures
import warnings
from datetime import date
from typing import Optional, Union

import pandas as pd
from tqdm import tqdm

import pybaseball.datasources.statcast as statcast_ds

from .utils import sanitize_date_range, statcast_date_range
from . import cache
from .utils import sanitize_date_range, statcast_date_range

_SC_SINGLE_GAME_REQUEST = "/statcast_search/csv?all=true&type=details&game_pk={game_pk}"
# pylint: disable=line-too-long
Expand All @@ -16,7 +18,7 @@


@cache.df_cache(expires=365)
def _small_request(start_dt: date, end_dt: date, team: Optional[str] = None, verbose: bool = False) -> pd.DataFrame:
def _small_request(start_dt: date, end_dt: date, team: Optional[str] = None) -> pd.DataFrame:
data = statcast_ds.get_statcast_data_from_csv_url(
_SC_SMALL_REQUEST.format(start_dt=str(start_dt), end_dt=str(end_dt), team=team if team else '')
)
Expand All @@ -25,8 +27,6 @@ def _small_request(start_dt: date, end_dt: date, team: Optional[str] = None, ver
['game_date', 'game_pk', 'at_bat_number', 'pitch_number'],
ascending=False
)
if verbose:
print(f"Completed sub-query from {start_dt} to {end_dt} ({len(data)} results)")

return data

Expand All @@ -39,35 +39,41 @@ def _small_request(start_dt: date, end_dt: date, team: Optional[str] = None, ver
you could lose a lot of progress. Enabling caching will allow you to immediately recover all the successful
subqueries if that happens.'''


def _check_warning(start_dt: date, end_dt: date) -> None:
    """Warn about large uncached queries that would lose all progress on failure.

    Queries spanning 42 days or more without caching enabled trigger the
    oversize-query warning defined at module level.
    """
    query_days = (end_dt - start_dt).days
    if not cache.config.enabled and query_days >= 42:
        warnings.warn(_OVERSIZE_WARNING)


def _handle_request(start_dt: date, end_dt: date, step: int, verbose: bool,
                    team: Optional[str] = None) -> pd.DataFrame:
    """
    Fulfill the request in sensible increments.

    Breaks [start_dt, end_dt] into sub-ranges of ``step`` days, fetches each
    sub-range in a separate process (one cached sub-query per range), and
    concatenates the results.

    :param start_dt: first date of the query (inclusive)
    :param end_dt: last date of the query (inclusive)
    :param step: number of days per sub-query
    :param verbose: if True, print progress information
    :param team: optional team abbreviation to filter the query on
    :return: combined DataFrame sorted descending by game_date, game_pk,
        at_bat_number, pitch_number; empty DataFrame if nothing was returned
    """
    _check_warning(start_dt, end_dt)

    if verbose:
        print("This is a large query, it may take a moment to complete", flush=True)

    dataframe_list = []
    # Materialize the range so we know the total for the progress bar.
    date_range = list(statcast_date_range(start_dt, end_dt, step, verbose))

    with tqdm(total=len(date_range)) as progress:
        with concurrent.futures.ProcessPoolExecutor() as executor:
            futures = {executor.submit(_small_request, subq_start, subq_end, team=team)
                       for subq_start, subq_end in date_range}
            for future in concurrent.futures.as_completed(futures):
                data = future.result()
                # Append only if not empty or failed
                # (failed requests have one row saying "Error: Query Timeout")
                if data is not None and not data.empty:
                    dataframe_list.append(data)
                progress.update(1)

    # Concatenate all sub-query dataframes into the final result set
    if dataframe_list:
        final_data = pd.concat(dataframe_list, axis=0).convert_dtypes(convert_string=False)
        final_data = final_data.sort_values(
            ['game_date', 'game_pk', 'at_bat_number', 'pitch_number'],
            ascending=False
        )
    else:
        final_data = pd.DataFrame()
    return final_data
Expand All @@ -88,14 +94,7 @@ def statcast(start_dt: str = None, end_dt: str = None, team: str = None, verbose

start_dt_date, end_dt_date = sanitize_date_range(start_dt, end_dt)

# small_query_threshold days or less -> a quick one-shot request.
# this is handled by the iterator in large_request just doing the one thingy otherwise.
# Greater than small_query_threshold days -> break it into multiple smaller queries
# The reason 7 is chosen here is because statcast will return at most 40000 rows.
# 7 seems to be the largest number of days that will guarantee no dropped rows.
small_query_threshold = 7

return _handle_request(start_dt_date, end_dt_date, step=small_query_threshold, verbose=verbose, team=team)
return _handle_request(start_dt_date, end_dt_date, 1, verbose=verbose, team=team)


def statcast_single_game(game_pk: Union[str, int]) -> pd.DataFrame:
Expand Down
7 changes: 5 additions & 2 deletions setup.py
@@ -1,9 +1,10 @@
# Always prefer setuptools over distutils
from setuptools import setup, find_packages
# To use a consistent encoding
from codecs import open
from os import path

from setuptools import find_packages, setup

here = path.abspath(path.dirname(__file__))

# Get the long description from the README file
Expand Down Expand Up @@ -83,7 +84,8 @@
'pygithub>=1.51',
'altair>=4.0.0',
'scipy>=1.4.0',
'matplotlib>=2.0.0'
'matplotlib>=2.0.0',
'tqdm>=4.50.0',
],

# List additional groups of dependencies here (e.g. development
Expand All @@ -95,6 +97,7 @@
'test': ['pytest>=6.0.2',
'mypy>=0.782',
'pytest-cov>=2.10.1',
'pytest-xdist>=2.1.0',
],
},

Expand Down
40 changes: 1 addition & 39 deletions tests/pybaseball/test_statcast.py
Expand Up @@ -3,26 +3,13 @@
import pandas as pd
import pytest

from pybaseball.statcast import (_SC_SINGLE_GAME_REQUEST, _SC_SMALL_REQUEST, sanitize_date_range, statcast,
statcast_single_game)
from pybaseball.statcast import _SC_SINGLE_GAME_REQUEST, statcast_single_game
from pybaseball.utils import DATE_FORMAT

# For an explanation of this type, see the note on GetDataFrameCallable in tests/pybaseball/conftest.py
from .conftest import GetDataFrameCallable


@pytest.fixture(name="small_request_raw")
def _small_request_raw(get_data_file_contents: Callable[[str], str]) -> str:
return get_data_file_contents('small_request_raw.csv')


@pytest.fixture(name="small_request")
def _small_request(get_data_file_dataframe: GetDataFrameCallable) -> pd.DataFrame:
data = get_data_file_dataframe('small_request.csv', parse_dates=[2])
data[data.columns[2]].apply(pd.to_datetime, errors='ignore', format=DATE_FORMAT)
return data


@pytest.fixture(name="single_game_raw")
def _single_game_raw(get_data_file_contents: Callable[[str], str]) -> str:
return get_data_file_contents('single_game_request_raw.csv')
Expand All @@ -35,31 +22,6 @@ def _single_game(get_data_file_dataframe: GetDataFrameCallable) -> pd.DataFrame:
return data


def test_statcast(response_get_monkeypatch: Callable, small_request_raw: str, small_request: pd.DataFrame) -> None:
    """statcast() with default dates returns the parsed small-request data."""
    default_start, default_end = sanitize_date_range(None, None)
    expected_url = _SC_SMALL_REQUEST.format(start_dt=default_start, end_dt=default_end, team='')
    response_get_monkeypatch(small_request_raw.encode('UTF-8'), expected_url)

    result = statcast().reset_index(drop=True)

    pd.testing.assert_frame_equal(result, small_request, check_dtype=False)


def test_statcast_team(response_get_monkeypatch: Callable, small_request_raw: str,
                       small_request: pd.DataFrame) -> None:
    """statcast() filtered to a team hits the team-specific URL and parses the data."""
    default_start, default_end = sanitize_date_range(None, None)
    expected_url = _SC_SMALL_REQUEST.format(start_dt=default_start, end_dt=default_end, team='TB')
    response_get_monkeypatch(small_request_raw.encode('UTF-8'), expected_url)

    result = statcast(None, None, team='TB').reset_index(drop=True)

    pd.testing.assert_frame_equal(result, small_request, check_dtype=False)


def test_statcast_single_game_request(response_get_monkeypatch: Callable, single_game_raw: str,
single_game: pd.DataFrame) -> None:
game_pk = '631614'
Expand Down

0 comments on commit 333430d

Please sign in to comment.